[Mlir-commits] [mlir] 34ff857 - [mlir][X86Vector] Add specialized vector.transpose lowering patterns for AVX2
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 10 23:39:16 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-11T07:33:31Z
New Revision: 34ff8573505e04c75e84a0e515af462f223f2795
URL: https://github.com/llvm/llvm-project/commit/34ff8573505e04c75e84a0e515af462f223f2795
DIFF: https://github.com/llvm/llvm-project/commit/34ff8573505e04c75e84a0e515af462f223f2795.diff
LOG: [mlir][X86Vector] Add specialized vector.transpose lowering patterns for AVX2
This revision adds an implementation of 2-D vector.transpose for 4x8 and 8x8 for
AVX2 and surfaces it to the Linalg level of control.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D113347
Added:
mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/X86Vector/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
mlir/test/Dialect/Vector/vector-mem-transforms.mlir
mlir/test/Dialect/Vector/vector-transforms.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
mlir/test/Dialect/Vector/vector-flat-transforms.mlir
mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e05dde330797b..6b816f9359c55 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
@@ -993,6 +994,12 @@ struct LinalgVectorLoweringOptions {
transposeLowering = val;
return *this;
}
+  /// Whether to also register the AVX2-specialized lowerings. Currently this
+  /// flag is only consulted together with `transposeLowering`.
+  bool avx2Lowering = false;
+  /// Sets `avx2Lowering` and returns `*this` so calls can be chained.
+  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
+    this->avx2Lowering = val;
+    return *this;
+  }
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
@@ -1009,6 +1016,13 @@ struct LinalgVectorLoweringOptions {
vectorTransformOptions = options;
return *this;
}
+  /// Configuration handed to the AVX2-specialized lowering patterns when
+  /// `avx2Lowering` is enabled.
+  x86vector::avx2::LoweringOptions avx2LoweringOptions;
+  /// Replaces `avx2LoweringOptions` and returns `*this` for chaining.
+  LinalgVectorLoweringOptions &
+  setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
+    this->avx2LoweringOptions = options;
+    return *this;
+  }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 5e55a03e9b6c6..ebdc2f67e72d1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -9,13 +9,126 @@
#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
+#include "mlir/IR/Value.h"
+
namespace mlir {
+class ImplicitLocOpBuilder;
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
+namespace x86vector {
+
+/// Helper class to factor out the creation and extraction of the small
+/// immediate masks ("nibs") consumed by the shuffle/permute helpers below.
+struct MaskHelper {
+  /// Packs four 2-bit selectors into an 8-bit mask: b01 occupies bits [0, 2),
+  /// b23 bits [2, 4), b45 bits [4, 6) and b67 bits [6, 8).
+  /// Meant to be used with instructions such as mm256ShufflePs.
+  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
+  static char shuffle() {
+    static_assert(b01 <= 0x03, "overflow");
+    static_assert(b23 <= 0x03, "overflow");
+    static_assert(b45 <= 0x03, "overflow");
+    static_assert(b67 <= 0x03, "overflow");
+    return (b67 << 6) + (b45 << 4) + (b23 << 2) + b01;
+  }
+  /// Inverse of `shuffle`: unpacks the four 2-bit selectors of `mask`
+  /// (b01 = bits [0, 2), b23 = bits [2, 4), b45 = bits [4, 6),
+  /// b67 = bits [6, 8)).
+  static void extractShuffle(char mask, char &b01, char &b23, char &b45,
+                             char &b67) {
+    b67 = (mask & (0x03 << 6)) >> 6;
+    b45 = (mask & (0x03 << 4)) >> 4;
+    b23 = (mask & (0x03 << 2)) >> 2;
+    b01 = mask & 0x03;
+  }
+  /// Packs two 4-bit selectors into an 8-bit mask: b03 occupies bits [0, 4)
+  /// and b47 bits [4, 8).
+  /// Meant to be used with instructions such as mm256Permute2f128Ps.
+  template <unsigned b47, unsigned b03>
+  static char permute() {
+    static_assert(b03 <= 0x0f, "overflow");
+    static_assert(b47 <= 0x0f, "overflow");
+    return (b47 << 4) + b03;
+  }
+  /// Inverse of `permute`: unpacks the two 4-bit selectors of `mask`
+  /// (b03 = bits [0, 4), b47 = bits [4, 8)).
+  static void extractPermute(char mask, char &b03, char &b47) {
+    b47 = (mask & (0x0f << 4)) >> 4;
+    b03 = mask & 0x0f;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+/// Helpers extracted from:
+///   - clang/lib/Headers/avxintrin.h
+///   - clang/test/CodeGen/X86/avx-builtins.c
+///   - clang/test/CodeGen/X86/avx2-builtins.c
+///   - clang/test/CodeGen/X86/avx-shuffle-builtins.c
+/// as well as the Intel Intrinsics Guide
+/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
+/// make it easier to just implement known good lowerings.
+/// All intrinsics correspond 1-1 to the Intel definition.
+//===----------------------------------------------------------------------===//
+
+namespace avx2 {
+
+/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
+Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+
+/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
+Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+
+/// Lower to a vector.shuffle of v1 and v2 driven by an 8-bit `mask` made of
+/// four 2-bit selectors (see `MaskHelper::shuffle`). In each 4-element half of
+/// the result, the first two elements come from the matching half of v1
+/// (selectors b01, b23) and the last two from the matching half of v2
+/// (selectors b45, b67):
+///   result = [v1[b01],   v1[b23],   v2[b45],   v2[b67],
+///             v1[b01+4], v1[b23+4], v2[b45+4], v2[b67+4]]
+Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, char mask);
+
+/// Lower to a vector.shuffle of v1 and v2 driven by an 8-bit `mask` made of
+/// two 4-bit selectors (see `MaskHelper::permute`). Each selector picks one
+/// 4-element half: 0 -> v1[0:4), 1 -> v1[4:8), 2 -> v2[0:4), 3 -> v2[4:8).
+/// The low selector (bits [0, 4)) fills result[0:4) and the high selector
+/// (bits [4, 8)) fills result[4:8). Selector values above 3 are unsupported.
+Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                          char mask);
+
+/// 4x8xf32-specific AVX2 transpose lowering.
+/// `vs` must hold exactly 4 values of type vector<8xf32> (the rows of a 4x8
+/// matrix). They are overwritten in place so that vs[i] holds rows 2i and
+/// 2i+1 of the 8x4 transposed matrix, i.e. read back-to-back they form the
+/// transpose in row-major order.
+void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+
+/// 8x8xf32-specific AVX2 transpose lowering.
+/// `vs` must hold exactly 8 values of type vector<8xf32> (the rows of an 8x8
+/// matrix). They are overwritten in place with the rows of the transpose.
+void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+
+/// Structure to control the behavior of specialized AVX2 transpose lowering.
+/// Every supported shape is disabled by default and must be opted into.
+struct TransposeLoweringOptions {
+  /// Rewrite transposes of vector<4x8xf32> with `transpose4x8xf32`.
+  bool lower4x8xf32_ = false;
+  TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
+    lower4x8xf32_ = lower;
+    return *this;
+  }
+  /// Rewrite transposes of vector<8x8xf32> with `transpose8x8xf32`.
+  bool lower8x8xf32_ = false;
+  TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
+    lower8x8xf32_ = lower;
+    return *this;
+  }
+};
+
+/// Options for controlling specialized AVX2 lowerings.
+struct LoweringOptions {
+  /// Configure specialized vector.transpose lowerings.
+  TransposeLoweringOptions transposeOptions;
+  LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
+    transposeOptions = options;
+    return *this;
+  }
+};
+
+/// Insert specialized transpose lowering patterns into `patterns`. Only the
+/// shapes enabled in `options.transposeOptions` are rewritten; `benefit` is
+/// the benefit assigned to the added pattern.
+void populateSpecializedTransposeLoweringPatterns(
+    RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
+    int benefit = 10);
+
+} // namespace avx2
+} // namespace x86vector
+
/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
/// intrinsics.
void populateX86VectorLegalizeForLLVMExportPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d75b5742b38d3..db91472c5a096 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -51,5 +51,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRTransforms
MLIRTransformUtils
MLIRVector
+ MLIRX86VectorTransforms
MLIRVectorToSCF
)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index b9d8d719898cb..4a97b926655fa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -334,6 +334,9 @@ struct LinalgStrategyLowerVectorsPass
if (options.transposeLowering) {
vector::populateVectorTransposeLoweringPatterns(
patterns, options.vectorTransformOptions);
+ if (options.avx2Lowering)
+ x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+ patterns, options.avx2LoweringOptions, /*benefit=*/10);
}
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
new file mode 100644
index 0000000000000..783939583ae33
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -0,0 +1,208 @@
+//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements vector.transpose rewrites as AVX patterns for particular
+// sizes of interest.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+using namespace mlir::x86vector::avx2;
+
+/// Interleaves elements {0, 1} and {4, 5} of `v1` with the same positions of
+/// `v2`: emits vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
+Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1,
+                                             Value v2) {
+  static constexpr int64_t kUnpackLoMask[] = {0, 8, 1, 9, 4, 12, 5, 13};
+  auto interleaved =
+      b.create<vector::ShuffleOp>(v1, v2, ArrayRef<int64_t>(kUnpackLoMask));
+  return interleaved;
+}
+
+/// Interleaves elements {2, 3} and {6, 7} of `v1` with the same positions of
+/// `v2`: emits vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
+Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1,
+                                             Value v2) {
+  static constexpr int64_t kUnpackHiMask[] = {2, 10, 3, 11, 6, 14, 7, 15};
+  auto interleaved =
+      b.create<vector::ShuffleOp>(v1, v2, ArrayRef<int64_t>(kUnpackHiMask));
+  return interleaved;
+}
+/// Emulates _mm256_shuffle_ps with a single vector.shuffle.
+///
+/// `mask` packs four 2-bit selectors (see `MaskHelper::shuffle`). In each
+/// 4-element half of the result, the first two elements are taken from the
+/// matching half of `v1` and the last two from the matching half of `v2`:
+///   result = [v1[b01],   v1[b23],   v2[b45],   v2[b67],
+///             v1[b01+4], v1[b23+4], v2[b45+4], v2[b67+4]]
+/// In vector.shuffle numbering, indices >= 8 address `v2`, hence the `+ 8`.
+Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
+                                            Value v2, char mask) {
+  char b01, b23, b45, b67;
+  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
+  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
+                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
+/// Emulates _mm256_permute2f128_ps with a single vector.shuffle.
+///
+/// `mask` packs two 4-bit selectors (see `MaskHelper::permute`). Each one
+/// picks a 4-element half out of the concatenation of `v1` and `v2`:
+///   0 -> v1[0:4), 1 -> v1[4:8), 2 -> v2[0:4), 3 -> v2[4:8).
+/// The low selector (bits [0, 4)) fills result[0:4); the high selector
+/// (bits [4, 8)) fills result[4:8). Any selector above 3 is unsupported and
+/// reaches llvm_unreachable.
+Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
+                                                 Value v1, Value v2,
+                                                 char mask) {
+  char lowSelector, highSelector;
+  MaskHelper::extractPermute(mask, lowSelector, highSelector);
+
+  SmallVector<int64_t> indices;
+  auto appendHalf = [&indices](char selector) {
+    switch (selector) {
+    case 0:
+      llvm::append_range(indices, ArrayRef<int64_t>{0, 1, 2, 3});
+      break;
+    case 1:
+      llvm::append_range(indices, ArrayRef<int64_t>{4, 5, 6, 7});
+      break;
+    case 2:
+      llvm::append_range(indices, ArrayRef<int64_t>{8, 9, 10, 11});
+      break;
+    case 3:
+      llvm::append_range(indices, ArrayRef<int64_t>{12, 13, 14, 15});
+      break;
+    default:
+      llvm_unreachable("control > 3 : overflow");
+    }
+  };
+  appendHalf(lowSelector);
+  appendHalf(highSelector);
+  return b.create<vector::ShuffleOp>(v1, v2, indices);
+}
+
+/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
+///
+/// Expects exactly 4 values of type vector<8xf32> (the rows of a 4x8 matrix).
+/// On return, vs[i] holds columns 2i and 2i+1 of the input, i.e. rows 2i and
+/// 2i+1 of the 8x4 transpose, so the four values read back-to-back form the
+/// transpose in row-major order.
+void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
+                                             MutableArrayRef<Value> vs) {
+  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
+  // `vt` is only read by the assertions below; without this cast, NDEBUG
+  // builds emit an unused-variable warning (fatal under -Werror).
+  (void)vt;
+  assert(vs.size() == 4 && "expects 4 vectors");
+  assert(llvm::all_of(ValueRange{vs}.getTypes(),
+                      [&](Type t) { return t == vt; }) &&
+         "expects all types to be vector<8xf32>");
+
+  // Interleave rows (0, 1) and rows (2, 3).
+  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
+  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
+  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
+  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
+  // Gather column fragments: S0 = col 0 | col 4, S1 = col 1 | col 5,
+  // S2 = col 2 | col 6, S3 = col 3 | col 7.
+  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
+  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
+  // Stitch halves so that vs[i] = col 2i | col 2i+1.
+  vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
+  vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
+  vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
+  vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
+}
+
+/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
+///
+/// `vs` must hold exactly 8 values of type vector<8xf32>, the rows of an 8x8
+/// matrix. On return, vs[i] holds column i of the input (row i of the
+/// transpose). Ops are emitted in three stages (unpack, shuffle,
+/// permute2f128), in the same order as the unrolled intrinsics sequence.
+void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
+                                             MutableArrayRef<Value> vs) {
+  auto rowType = VectorType::get({8}, Float32Type::get(ib.getContext()));
+  (void)rowType;
+  assert(vs.size() == 8 && "expects 8 vectors");
+  assert(llvm::all_of(ValueRange{vs}.getTypes(),
+                      [&](Type t) { return t == rowType; }) &&
+         "expects all types to be vector<8xf32>");
+
+  // Stage 1: interleave each pair of consecutive rows (lo, then hi).
+  Value unpacked[8];
+  for (int pair = 0; pair < 4; ++pair) {
+    Value lhs = vs[2 * pair];
+    Value rhs = vs[2 * pair + 1];
+    unpacked[2 * pair] = mm256UnpackLoPs(ib, lhs, rhs);
+    unpacked[2 * pair + 1] = mm256UnpackHiPs(ib, lhs, rhs);
+  }
+
+  // Stage 2: within rows 0-3 and rows 4-7, gather 4-element column fragments.
+  const char lowPairs = MaskHelper::shuffle<1, 0, 1, 0>();
+  const char highPairs = MaskHelper::shuffle<3, 2, 3, 2>();
+  Value shuffled[8];
+  for (int group = 0; group < 2; ++group) {
+    for (int k = 0; k < 2; ++k) {
+      Value lhs = unpacked[4 * group + k];
+      Value rhs = unpacked[4 * group + k + 2];
+      shuffled[4 * group + 2 * k] = mm256ShufflePs(ib, lhs, rhs, lowPairs);
+      shuffled[4 * group + 2 * k + 1] =
+          mm256ShufflePs(ib, lhs, rhs, highPairs);
+    }
+  }
+
+  // Stage 3: join the fragments from rows 0-3 with those from rows 4-7.
+  const char lowHalves = MaskHelper::permute<2, 0>();
+  const char highHalves = MaskHelper::permute<3, 1>();
+  for (int i = 0; i < 4; ++i)
+    vs[i] = mm256Permute2f128Ps(ib, shuffled[i], shuffled[i + 4], lowHalves);
+  for (int i = 0; i < 4; ++i)
+    vs[i + 4] =
+        mm256Permute2f128Ps(ib, shuffled[i], shuffled[i + 4], highHalves);
+}
+
+namespace {
+/// Rewrites a 2-D vector.transpose with permutation [1, 0] into the shuffle
+/// sequences produced by `transpose4x8xf32` / `transpose8x8xf32`.
+/// Only vector<4x8xf32> and vector<8x8xf32> sources are handled, each one only
+/// when enabled by the matching flag in `LoweringOptions::transposeOptions`.
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
+                      int benefit)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
+        loweringOptions(loweringOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType srcType = op.getVectorType();
+    if (srcType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");
+
+    // The transpose helpers assert that every row is vector<8xf32>; reject
+    // other element types here instead of tripping those assertions.
+    if (!srcType.getElementType().isF32())
+      return rewriter.notifyMatchFailure(op, "Unsupported element type");
+
+    SmallVector<int64_t, 4> transp;
+    for (auto attr : op.transp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+    // Only the [1, 0] permutation actually swaps the two dimensions.
+    if (transp[0] != 1 || transp[1] != 0)
+      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");
+
+    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
+    const TransposeLoweringOptions &opts = loweringOptions.transposeOptions;
+    bool is4x8 = opts.lower4x8xf32_ && m == 4 && n == 8;
+    bool is8x8 = opts.lower8x8xf32_ && m == 8 && n == 8;
+    if (!is4x8 && !is8x8)
+      return failure();
+
+    ImplicitLocOpBuilder ib(loc, rewriter);
+    // Split the source into its m rows (each a vector<8xf32>).
+    SmallVector<Value> vs;
+    for (int64_t i = 0; i < m; ++i)
+      vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
+    if (is4x8)
+      transpose4x8xf32(ib, vs);
+    else
+      transpose8x8xf32(ib, vs);
+
+    // Reassemble the transposed rows into a value of the source type.
+    Value res = ib.create<arith::ConstantOp>(
+        op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
+    for (int64_t i = 0; i < m; ++i)
+      res = ib.create<vector::InsertOp>(vs[i], res, i);
+    // For 4x8 the transposed data is still laid out as 4x8 and needs to be
+    // reinterpreted as 8x4 via shape_casts.
+    if (is4x8) {
+      Type elementType = srcType.getElementType();
+      auto flattenedType = VectorType::get({n * m}, elementType);
+      auto transposedType = VectorType::get({n, m}, elementType);
+      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
+      res = ib.create<vector::ShapeCastOp>(transposedType, res);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  LoweringOptions loweringOptions;
+};
+} // namespace
+
+/// Registers the AVX2 transpose rewrite in `patterns`, configured by `options`
+/// and added with the given pattern `benefit`.
+void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<TransposeOpLowering>(options, ctx, benefit);
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index 679379a0a51b0..8781d1f5f1242 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
+ AVXTranspose.cpp
LegalizeForLLVMExport.cpp
DEPENDS
@@ -10,4 +11,5 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMIR
+ MLIRVector
)
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index 50e5ce8aaf9cf..93cdf05c77692 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s
#matvec_accesses = [
affine_map<(i, j) -> (i, j)>,
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 3c32008fed81c..01e6a23698cc1 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
#dotp_accesses = [
affine_map<(i) -> (i)>,
@@ -149,8 +149,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
-// ... bunch of extract insert to transpose B into Bt
-// CHECK: %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
@@ -399,28 +398,6 @@ func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) ->
return %0: vector<16xi32>
}
-// CHECK-LABEL: func @transpose23
-// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
-// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
-// CHECK: return %[[T11]] : vector<3x2xf32>
-
-func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
- %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
- return %0 : vector<3x2xf32>
-}
-
// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
deleted file mode 100644
index 8d51d323a1a71..0000000000000
--- a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
+++ /dev/null
@@ -1,65 +0,0 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s
-
-// Tests for lowering 2-D vector.transpose into vector.flat_transpose.
-//
-// TODO: having ShapeCastOp2DDownCastRewritePattern and
-// ShapeCastOp2DUpCastRewritePattern too early in the greedy rewriting
-// patterns misses opportunities to fold shape casts!
-
-// No shape cast folding expected.
-//
-// CHECK-LABEL: func @transpose44_44(
-// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
-// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
-//
-func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
- %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
- return %0 : vector<4x4xf32>
-}
-
-// Folds preceding shape cast as expected,
-// no following shape cast folding expected.
-//
-// FIXME: PR49590 - shape_cast not stable.
-//
-// CHECK-LABEL: func @transpose16_44(
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// HECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-// HECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
-//
-func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> {
- %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
- %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
- return %1 : vector<4x4xf32>
-}
-
-// No preceding shape cast folding expected,
-// but FAILS to fold following cast.
-//
-// CHECK-LABEL: func @transpose44_16(
-// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
-// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> {
- %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
- %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
- return %1 : vector<16xf32>
-}
-
-// Folds preceding shape cast as expected,
-// but FAILS to fold following cast.
-//
-// FIXME: PR49590 - shape_cast not stable.
-//
-// CHECK-LABEL: func @transpose16_16(
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// HECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-//
-func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> {
- %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
- %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
- %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32>
- return %2 : vector<16xf32>
-}
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index ba943873daec2..cfa63339a99af 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
// CHECK-LABEL: func @maskedload0(
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index f10f84bcd9891..50ed05a4b956e 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s
+// RUN: mlir-opt %s -test-vector-to-vector-lowering="unroll" | FileCheck %s
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
new file mode 100644
index 0000000000000..2bb66d8f8f757
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt %s -test-vector-transpose-lowering=eltwise=1 | FileCheck %s --check-prefix=ELTWISE
+// RUN: mlir-opt %s -test-vector-transpose-lowering=shuffle=1 | FileCheck %s --check-prefix=SHUFFLE
+// RUN: mlir-opt %s -test-vector-transpose-lowering=flat=1 | FileCheck %s --check-prefix=FLAT
+// RUN: mlir-opt %s -test-vector-transpose-lowering=avx2=1 | FileCheck %s --check-prefix=AVX2
+
+// Element-wise lowering of a 2x3 -> 3x2 transpose: every scalar is extracted
+// from the source and inserted at its transposed position in a zero vector.
+// ELTWISE-LABEL: func @transpose23
+// ELTWISE-SAME: %[[A:.*]]: vector<2x3xf32>
+// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
+// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
+// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
+// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
+// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
+// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
+// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
+// ELTWISE: return %[[T11]] : vector<3x2xf32>
+func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
+ %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
+
+// Shuffle-based and flat_transpose-based lowerings of a 2x4 -> 4x2 transpose:
+// the source is flattened with shape_cast, permuted by a single op and
+// reshaped back. Both labels end with "(" so that they cannot also match the
+// preceding @transpose23 function.
+// SHUFFLE-LABEL: func @transpose(
+// FLAT-LABEL: func @transpose(
+func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
+  // SHUFFLE: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
+  //                   0 4
+  //     0 1 2 3       1 5
+  //     4 5 6 7  ->   2 6
+  //                   3 7
+  // SHUFFLE-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
+  // SHUFFLE-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
+
+  // FLAT: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32>
+  // FLAT: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32>
+  // FLAT: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// Specialized 4x8 lowering: the 4 rows are extracted and transposed with 12
+// vector.shuffle ops (unpack, shuffle and permute2f128 stages). They are then
+// reinserted, and the 4x8 result is reinterpreted as 8x4 through two
+// shape_casts.
+// AVX2-LABEL: func @transpose4x8
+func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
+ // AVX2: vector.extract {{.*}}[0]
+ // AVX2-NEXT: vector.extract {{.*}}[1]
+ // AVX2-NEXT: vector.extract {{.*}}[2]
+ // AVX2-NEXT: vector.extract {{.*}}[3]
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.insert {{.*}}[0]
+ // AVX2-NEXT: vector.insert {{.*}}[1]
+ // AVX2-NEXT: vector.insert {{.*}}[2]
+ // AVX2-NEXT: vector.insert {{.*}}[3]
+ // AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
+ // AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
+ return %0 : vector<8x4xf32>
+}
+
+// Specialized 8x8 lowering: 24 vector.shuffle ops in three stages
+// (8 unpacks, 8 in-lane shuffles, 8 permute2f128-style half swaps).
+// AVX2-LABEL: func @transpose8x8
+func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
+ // AVX2: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+ // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
+ return %0 : vector<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
deleted file mode 100644
index 1b65579b5c813..0000000000000
--- a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s
-
-// CHECK-LABEL: func @transpose
-func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
- // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
- // 0 4
- // 0 1 2 3 1 5
- // 4 5 6 7 -> 2 6
- // 3 7
- // CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
- %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
- return %0 : vector<4x2xf32>
-}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 12d57489af60b..ccd6bc5fe31a5 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -1,4 +1,4 @@
-//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
+//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,27 +11,31 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
+using namespace mlir::linalg;
using namespace mlir::vector;
namespace {
-struct TestVectorToVectorConversion
- : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
- TestVectorToVectorConversion() = default;
- TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
+struct TestVectorToVectorLowering
+ : public PassWrapper<TestVectorToVectorLowering, FunctionPass> {
+ TestVectorToVectorLowering() = default;
+ TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {}
StringRef getArgument() const final {
- return "test-vector-to-vector-conversion";
+ return "test-vector-to-vector-lowering";
}
StringRef getDescription() const final {
- return "Test conversion patterns between ops in the vector dialect";
+ return "Test lowering patterns between ops in the vector dialect";
}
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -95,31 +99,22 @@ struct TestVectorToVectorConversion
}
};
-struct TestVectorContractionConversion
- : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
+struct TestVectorContractionLowering
+ : public PassWrapper<TestVectorContractionLowering, FunctionPass> {
StringRef getArgument() const final {
- return "test-vector-contraction-conversion";
+ return "test-vector-contraction-lowering";
}
StringRef getDescription() const final {
- return "Test conversion patterns that lower contract ops in the vector "
+ return "Test lowering patterns that lower contract ops in the vector "
"dialect";
}
- TestVectorContractionConversion() = default;
- TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
- }
+ TestVectorContractionLowering() = default;
+ TestVectorContractionLowering(const TestVectorContractionLowering &pass) {}
Option<bool> lowerToFlatMatrix{
*this, "vector-lower-matrix-intrinsics",
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
llvm::cl::init(false)};
- Option<bool> lowerToFlatTranspose{
- *this, "vector-flat-transpose",
- llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
- llvm::cl::init(false)};
- Option<bool> lowerToShuffleTranspose{
- *this, "vector-shuffle-transpose",
- llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
- llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -165,31 +160,91 @@ struct TestVectorContractionConversion
contractLowering = VectorContractLowering::Matmul;
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
- VectorTransposeLowering transposeLowering =
- VectorTransposeLowering::EltWise;
- if (lowerToFlatTranspose)
- transposeLowering = VectorTransposeLowering::Flat;
- if (lowerToShuffleTranspose)
- transposeLowering = VectorTransposeLowering::Shuffle;
- VectorTransformsOptions options{
- contractLowering, vectorMultiReductionLowering, transposeLowering};
+ VectorTransformsOptions options{contractLowering,
+ vectorMultiReductionLowering,
+ VectorTransposeLowering()};
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
populateVectorMaskOpLoweringPatterns(patterns);
- if (!lowerToShuffleTranspose)
- populateVectorShapeCastLoweringPatterns(patterns);
- populateVectorTransposeLoweringPatterns(patterns, options);
+ populateVectorShapeCastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
+struct TestVectorTransposeLowering
+    : public PassWrapper<TestVectorTransposeLowering, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-transpose-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that lower transpose ops in the vector "
+           "dialect";
+  }
+  TestVectorTransposeLowering() = default;
+  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {}
+  // Command-line options selecting which transpose lowering to exercise.
+  Option<bool> lowerToEltwise{
+      *this, "eltwise",
+      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToFlatTranspose{
+      *this, "flat",
+      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToShuffleTranspose{
+      *this, "shuffle",
+      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToAvx2{
+      *this, "avx2",
+      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
+      llvm::cl::init(false)};
+  // Runs the Linalg vector-lowering strategy pass with the chosen options.
+  void runOnFunction() override {
+    // Exercise transpose lowering in isolation: shape_cast lowering is
+    // explicitly disabled so shape_casts it emits stay in the output.
+    // eltwise/flat/shuffle each replace the transform options, so if
+    // several are set the last one checked wins; avx2 is additive.
+    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
+                                              .enableVectorTransposeLowering()
+                                              .enableShapeCastLowering(false);
+    if (lowerToEltwise) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::EltWise));
+    }
+    if (lowerToFlatTranspose) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::Flat));
+    }
+    if (lowerToShuffleTranspose) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::Shuffle));
+    }
+    if (lowerToAvx2) {
+      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
+          x86vector::avx2::LoweringOptions().setTransposeOptions(
+              x86vector::avx2::TransposeLoweringOptions()
+                  .lower4x8xf32()
+                  .lower8x8xf32()));
+    }
+
+    OpPassManager dynamicPM("builtin.func");
+    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
+    if (failed(runPipeline(dynamicPM, getFunction())))
+      return signalPassFailure();
+  }
+};
+
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-unrolling-patterns";
}
StringRef getDescription() const final {
- return "Test conversion patterns to unroll contract ops in the vector "
+ return "Test lowering patterns to unroll contract ops in the vector "
"dialect";
}
TestVectorUnrollingPatterns() = default;
@@ -248,7 +303,7 @@ struct TestVectorDistributePatterns
return "test-vector-distribute-patterns";
}
StringRef getDescription() const final {
- return "Test conversion patterns to distribute vector ops in the vector "
+ return "Test lowering patterns to distribute vector ops in the vector "
"dialect";
}
TestVectorDistributePatterns() = default;
@@ -302,7 +357,7 @@ struct TestVectorToLoopPatterns
: public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
StringRef getArgument() const final { return "test-vector-to-forloop"; }
StringRef getDescription() const final {
- return "Test conversion patterns to break up a vector op into a for loop";
+ return "Test lowering patterns to break up a vector op into a for loop";
}
TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
@@ -365,7 +420,7 @@ struct TestVectorTransferUnrollingPatterns
return "test-vector-transfer-unrolling-patterns";
}
StringRef getDescription() const final {
- return "Test conversion patterns to unroll transfer ops in the vector "
+ return "Test lowering patterns to unroll transfer ops in the vector "
"dialect";
}
void runOnFunction() override {
@@ -391,7 +446,7 @@ struct TestVectorTransferFullPartialSplitPatterns
return "test-vector-transfer-full-partial-split";
}
StringRef getDescription() const final {
- return "Test conversion patterns to split "
+ return "Test lowering patterns to split "
"transfer ops via scf.if + linalg ops";
}
TestVectorTransferFullPartialSplitPatterns() = default;
@@ -439,7 +494,7 @@ struct TestVectorTransferLoweringPatterns
return "test-vector-transfer-lowering-patterns";
}
StringRef getDescription() const final {
- return "Test conversion patterns to lower transfer ops to other vector ops";
+ return "Test lowering patterns to lower transfer ops to other vector ops";
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
@@ -462,7 +517,7 @@ struct TestVectorMultiReductionLoweringPatterns
return "test-vector-multi-reduction-lowering-patterns";
}
StringRef getDescription() const final {
- return "Test conversion patterns to lower vector.multi_reduction to other "
+ return "Test lowering patterns to lower vector.multi_reduction to other "
"vector ops";
}
Option<bool> useOuterReductions{
@@ -495,7 +550,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
}
StringRef getDescription() const final {
- return "Test conversion patterns that reducedes the rank of the vector "
+ return "Test lowering patterns that reducedes the rank of the vector "
"transfer memory and vector operands.";
}
@@ -527,10 +582,12 @@ struct TestVectorReduceToContractPatternsPatterns
namespace mlir {
namespace test {
-void registerTestVectorConversions() {
- PassRegistration<TestVectorToVectorConversion>();
+void registerTestVectorLowerings() {
+ PassRegistration<TestVectorToVectorLowering>();
+
+ PassRegistration<TestVectorContractionLowering>();
- PassRegistration<TestVectorContractionConversion>();
+ PassRegistration<TestVectorTransposeLowering>();
PassRegistration<TestVectorUnrollingPatterns>();
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index dd24b4e507523..285b06d7aa831 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -107,7 +107,7 @@ void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
void registerTestSliceAnalysisPass();
-void registerTestVectorConversions();
+void registerTestVectorLowerings();
} // namespace test
} // namespace mlir
@@ -197,7 +197,7 @@ void registerTestPasses() {
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();
mlir::test::registerTestSliceAnalysisPass();
- mlir::test::registerTestVectorConversions();
+ mlir::test::registerTestVectorLowerings();
}
#endif
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index bb3e66695c5d0..1c8a45c7e9b74 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1458,6 +1458,7 @@ cc_library(
":LLVMCommonConversion",
":LLVMDialect",
":StandardOps",
+ ":VectorOps",
":X86Vector",
"//llvm:Core",
"//llvm:Support",
@@ -6401,6 +6402,7 @@ cc_library(
":TransformUtils",
":VectorOps",
":VectorToSCF",
+ ":X86VectorTransforms",
"//llvm:Support",
],
)
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 31c0a7303f2fb..a4556d9499060 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -484,6 +484,7 @@ cc_library(
"//mlir:Affine",
"//mlir:Analysis",
"//mlir:LinalgOps",
+ "//mlir:LinalgTransforms",
"//mlir:MemRefDialect",
"//mlir:Pass",
"//mlir:SCFDialect",
More information about the Mlir-commits
mailing list