[Mlir-commits] [mlir] 25cc5a7 - [mlir][vector] Generalize vector.transpose lowering to n-D vectors
Hanhan Wang
llvmlistbot at llvm.org
Mon May 8 10:48:39 PDT 2023
Author: Hanhan Wang
Date: 2023-05-08T10:48:26-07:00
New Revision: 25cc5a71b3a2f197fd3c31eeba1eeb1711b93de2
URL: https://github.com/llvm/llvm-project/commit/25cc5a71b3a2f197fd3c31eeba1eeb1711b93de2
DIFF: https://github.com/llvm/llvm-project/commit/25cc5a71b3a2f197fd3c31eeba1eeb1711b93de2.diff
LOG: [mlir][vector] Generalize vector.transpose lowering to n-D vectors
The existing shuffle-based vector.transpose lowering patterns only trigger
if the input vector is 2-D. This revision extends them to handle n-D
vectors that are effectively 2-D, i.e., vectors with exactly two
dimensions greater than one (e.g., vector<1x4x1x8x1xf32>).
It also moves the common "is this a transpose of a 2-D slice" check from the
X86Vector lowering into VectorUtils so it can be reused by both lowerings.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D149908
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 66f8cbbf4bdb6..fc00769a4aaa8 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -37,6 +37,11 @@ namespace vector {
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
+
+/// Returns the indices of the two source dims that are greater than one if the
+/// transposition is applied on a 2D slice. Otherwise, returns a failure.
+FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 1408c03f21456..42c1aa58c5e5e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -332,7 +332,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
- resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
+ succeeded(isTranspose2DSlice(op)))
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
@@ -411,36 +411,37 @@ class TransposeOp2DToShuffleLowering
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
+ if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
+ return rewriter.notifyMatchFailure(
+ op, "not using vector shuffle based lowering");
+
+ auto srcGtOneDims = isTranspose2DSlice(op);
+ if (failed(srcGtOneDims))
+ return rewriter.notifyMatchFailure(
+ op, "expected transposition on a 2D slice");
VectorType srcType = op.getSourceVectorType();
- if (srcType.getRank() != 2)
- return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
+ int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
+ int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
- SmallVector<int64_t> transp;
- for (auto attr : op.getTransp())
- transp.push_back(attr.cast<IntegerAttr>().getInt());
- if (transp[0] != 1 && transp[1] != 0)
- return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
+ // Reshape the n-D input vector with only two dimensions greater than one
+ // to a 2-D vector.
+ Location loc = op.getLoc();
+ auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
+ auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
+ auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
+ op.getVector());
Value res;
- int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
- switch (vectorTransformOptions.vectorTransposeLowering) {
- case VectorTransposeLowering::Shuffle1D: {
- Value casted = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get({m * n}, srcType.getElementType()),
- op.getVector());
- res = transposeToShuffle1D(rewriter, casted, m, n);
- break;
- }
- case VectorTransposeLowering::Shuffle16x16:
- if (m != 16 || n != 16)
- return failure();
- res = transposeToShuffle16x16(rewriter, op.getVector(), m, n);
- break;
- case VectorTransposeLowering::EltWise:
- case VectorTransposeLowering::Flat:
- return failure();
+ if (vectorTransformOptions.vectorTransposeLowering ==
+ VectorTransposeLowering::Shuffle16x16 &&
+ m == 16 && n == 16) {
+ reshInput =
+ rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
+ res = transposeToShuffle16x16(rewriter, reshInput, m, n);
+ } else {
+ // Fallback to shuffle on 1D approach.
+ res = transposeToShuffle1D(rewriter, reshInput, m, n);
}
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 15895470eca4c..e77a13a9c653a 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -43,6 +43,63 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
llvm_unreachable("Expected MemRefType or TensorType");
}
+/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
+/// should be transposed with each other within the context of their 2D
+/// transposition slice.
+///
+/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
+/// Return true: dim0 and dim1 are transposed within the context of their 2D
+/// transposition slice ([1, 0]).
+///
+/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
+/// Return true: dim0 and dim1 are transposed within the context of their 2D
+/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
+/// transposed within the full context of the transposition.
+///
+/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
+/// Return false: dim0 and dim1 are *not* transposed within the context of
+/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
+/// and dim1 (1) are transposed within the full context of the
+/// transposition.
+static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
+ ArrayRef<int64_t> transp) {
+ // Scan 'transp' in order: whichever of dim0/dim1 appears first decides the
+ // result. dim0 first means they are not transposed within their 2D slice;
+ // dim1 first means they are. Finding neither is treated as unreachable.
+ for (int64_t permDim : transp) {
+ if (permDim == dim0)
+ return false;
+ if (permDim == dim1)
+ return true;
+ }
+
+ llvm_unreachable("Ill-formed transpose pattern");
+}
+
+FailureOr<std::pair<int, int>>
+mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
+ VectorType srcType = op.getSourceVectorType();
+ SmallVector<int64_t> srcGtOneDims;
+ for (auto [index, size] : llvm::enumerate(srcType.getShape()))
+ if (size > 1)
+ srcGtOneDims.push_back(index);
+
+ if (srcGtOneDims.size() != 2)
+ return failure();
+
+ SmallVector<int64_t> transp;
+ for (auto attr : op.getTransp())
+ transp.push_back(attr.cast<IntegerAttr>().getInt());
+
+ // The two source vector dimensions that are greater than one must be
+ // transposed with each other so that one of the 2-D transpose patterns can
+ // be applied. Otherwise, these patterns are not applicable.
+ if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+ return failure();
+
+ return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
+}
+
/// Constructs a permutation map from memref indices to vector dimension.
///
/// The implementation uses the knowledge of the mapping of enclosing loop to
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
index daeb128fecb40..ec55aac6e06fd 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
@@ -187,39 +188,6 @@ void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}
-/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
-/// should be transposed with each other within the context of their 2D
-/// transposition slice.
-///
-/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
-/// Return true: dim0 and dim1 are transposed within the context of their 2D
-/// transposition slice ([1, 0]).
-///
-/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
-/// Return true: dim0 and dim1 are transposed within the context of their 2D
-/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
-/// transposed within the full context of the transposition.
-///
-/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
-/// Return false: dim0 and dim1 are *not* transposed within the context of
-/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
-/// and dim1 (1) are transposed within the full context of the of the
-/// transposition.
-static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
- ArrayRef<int64_t> transp) {
- // Perform a linear scan along the dimensions of the transposed pattern. If
- // dim0 is found first, dim0 and dim1 are not transposed within the context of
- // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
- for (int64_t permDim : transp) {
- if (permDim == dim0)
- return false;
- if (permDim == dim1)
- return true;
- }
-
- llvm_unreachable("Ill-formed transpose pattern");
-}
-
/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
@@ -256,29 +224,16 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
if (!srcType.getElementType().isF32())
return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
- SmallVector<int64_t> srcGtOneDims;
- for (auto [index, size] : llvm::enumerate(srcType.getShape()))
- if (size > 1)
- srcGtOneDims.push_back(index);
-
- if (srcGtOneDims.size() != 2)
- return rewriter.notifyMatchFailure(op, "Unsupported vector type");
-
- SmallVector<int64_t, 4> transp;
- for (auto attr : op.getTransp())
- transp.push_back(attr.cast<IntegerAttr>().getInt());
-
- // Check whether the two source vector dimensions that are greater than one
- // must be transposed with each other so that we can apply one of the 2-D
- // AVX2 transpose pattens. Otherwise, these patterns are not applicable.
- if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+ auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
+ if (failed(srcGtOneDims))
return rewriter.notifyMatchFailure(
- op, "Not applicable to this transpose permutation");
+ op, "expected transposition on a 2D slice");
// Retrieve the sizes of the two dimensions greater than one to be
// transposed.
auto srcShape = srcType.getShape();
- int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];
+ int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
+ int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
auto applyRewrite = [&]() {
ImplicitLocOpBuilder ib(loc, rewriter);
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 44bba6fbed982..a668b49efc6e5 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -687,3 +687,82 @@ transform.sequence failures(propagate) {
lowering_strategy = "shuffle_16x16"
: (!pdl.operation) -> !pdl.operation
}
+
+// -----
+
+// CHECK-LABEL: func @transpose021_shuffle16x16xf32
+func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x16x16xf32> to vector<1x16x16xf32>
+ return %0 : vector<1x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ transform.vector.lower_transpose %module_op
+ lowering_strategy = "shuffle_16x16"
+ : (!pdl.operation) -> !pdl.operation
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index daa5b894d7b98..eef474ee0af06 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1920,6 +1920,7 @@ cc_library(
":LLVMCommonConversion",
":LLVMDialect",
":VectorDialect",
+ ":VectorUtils",
":X86VectorDialect",
"//llvm:Core",
"//llvm:Support",
More information about the Mlir-commits
mailing list