[Mlir-commits] [mlir] 25cc5a7 - [mlir][vector] Generalize vector.transpose lowering to n-D vectors

Hanhan Wang llvmlistbot at llvm.org
Mon May 8 10:48:39 PDT 2023


Author: Hanhan Wang
Date: 2023-05-08T10:48:26-07:00
New Revision: 25cc5a71b3a2f197fd3c31eeba1eeb1711b93de2

URL: https://github.com/llvm/llvm-project/commit/25cc5a71b3a2f197fd3c31eeba1eeb1711b93de2
DIFF: https://github.com/llvm/llvm-project/commit/25cc5a71b3a2f197fd3c31eeba1eeb1711b93de2.diff

LOG: [mlir][vector] Generalize vector.transpose lowering to n-D vectors

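The existing vector.transpose lowering patterns only trigger if the
input vector is 2-D. This revision extends them to handle n-D vectors
that are effectively 2-D, i.e. vectors with exactly two dimensions
greater than one (e.g., vector<1x4x1x8x1xf32>) whose permutation
transposes those two dimensions with respect to each other. When the
shuffle_16x16 strategy is requested but the 2-D slice is not 16x16, the
shuffle-based lowering now falls back to the 1-D shuffle lowering
instead of failing to match.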

It moves the common "transposition on a 2-D slice" check out of the
X86Vector lowering and into VectorUtils.h (as vector::isTranspose2DSlice),
so that both the X86Vector and the generic Vector lowerings can reuse it.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D149908

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
    mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 66f8cbbf4bdb6..fc00769a4aaa8 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -37,6 +37,11 @@ namespace vector {
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
+
+/// Returns two dims that are greater than one if the transposition is applied
+/// on a 2D slice. Otherwise, returns a failure.
+FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 1408c03f21456..42c1aa58c5e5e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -332,7 +332,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
     if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
-        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
+        succeeded(isTranspose2DSlice(op)))
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
@@ -411,36 +411,37 @@ class TransposeOp2DToShuffleLowering
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
+    if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
+      return rewriter.notifyMatchFailure(
+          op, "not using vector shuffle based lowering");
+
+    auto srcGtOneDims = isTranspose2DSlice(op);
+    if (failed(srcGtOneDims))
+      return rewriter.notifyMatchFailure(
+          op, "expected transposition on a 2D slice");
 
     VectorType srcType = op.getSourceVectorType();
-    if (srcType.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
+    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
+    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
 
-    SmallVector<int64_t> transp;
-    for (auto attr : op.getTransp())
-      transp.push_back(attr.cast<IntegerAttr>().getInt());
-    if (transp[0] != 1 && transp[1] != 0)
-      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
+    // Reshape the n-D input vector with only two dimensions greater than one
+    // to a 2-D vector.
+    Location loc = op.getLoc();
+    auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
+    auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
+    auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
+                                                          op.getVector());
 
     Value res;
-    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
-    switch (vectorTransformOptions.vectorTransposeLowering) {
-    case VectorTransposeLowering::Shuffle1D: {
-      Value casted = rewriter.create<vector::ShapeCastOp>(
-          loc, VectorType::get({m * n}, srcType.getElementType()),
-          op.getVector());
-      res = transposeToShuffle1D(rewriter, casted, m, n);
-      break;
-    }
-    case VectorTransposeLowering::Shuffle16x16:
-      if (m != 16 || n != 16)
-        return failure();
-      res = transposeToShuffle16x16(rewriter, op.getVector(), m, n);
-      break;
-    case VectorTransposeLowering::EltWise:
-    case VectorTransposeLowering::Flat:
-      return failure();
+    if (vectorTransformOptions.vectorTransposeLowering ==
+            VectorTransposeLowering::Shuffle16x16 &&
+        m == 16 && n == 16) {
+      reshInput =
+          rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
+      res = transposeToShuffle16x16(rewriter, reshInput, m, n);
+    } else {
+      // Fallback to shuffle on 1D approach.
+      res = transposeToShuffle1D(rewriter, reshInput, m, n);
     }
 
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 15895470eca4c..e77a13a9c653a 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -43,6 +43,63 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
   llvm_unreachable("Expected MemRefType or TensorType");
 }
 
+/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
+/// should be transposed with each other within the context of their 2D
+/// transposition slice.
+///
+/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
+///   Return true: dim0 and dim1 are transposed within the context of their 2D
+///   transposition slice ([1, 0]).
+///
+/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
+///   Return true: dim0 and dim1 are transposed within the context of their 2D
+///   transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
+///   transposed within the full context of the transposition.
+///
+/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
+///   Return false: dim0 and dim1 are *not* transposed within the context of
+///   their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
+///   and dim1 (1) are transposed within the full context of the of the
+///   transposition.
+static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
+                                       ArrayRef<int64_t> transp) {
+  // Perform a linear scan along the dimensions of the transposed pattern. If
+  // dim0 is found first, dim0 and dim1 are not transposed within the context of
+  // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
+  for (int64_t permDim : transp) {
+    if (permDim == dim0)
+      return false;
+    if (permDim == dim1)
+      return true;
+  }
+
+  llvm_unreachable("Ill-formed transpose pattern");
+}
+
+FailureOr<std::pair<int, int>>
+mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
+  VectorType srcType = op.getSourceVectorType();
+  SmallVector<int64_t> srcGtOneDims;
+  for (auto [index, size] : llvm::enumerate(srcType.getShape()))
+    if (size > 1)
+      srcGtOneDims.push_back(index);
+
+  if (srcGtOneDims.size() != 2)
+    return failure();
+
+  SmallVector<int64_t> transp;
+  for (auto attr : op.getTransp())
+    transp.push_back(attr.cast<IntegerAttr>().getInt());
+
+  // Check whether the two source vector dimensions that are greater than one
+  // must be transposed with each other so that we can apply one of the 2-D
+  // transpose pattens. Otherwise, these patterns are not applicable.
+  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+    return failure();
+
+  return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
+}
+
 /// Constructs a permutation map from memref indices to vector dimension.
 ///
 /// The implementation uses the knowledge of the mapping of enclosing loop to

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
index daeb128fecb40..ec55aac6e06fd 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
@@ -187,39 +188,6 @@ void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
   vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
 }
 
-/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
-/// should be transposed with each other within the context of their 2D
-/// transposition slice.
-///
-/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
-///   Return true: dim0 and dim1 are transposed within the context of their 2D
-///   transposition slice ([1, 0]).
-///
-/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
-///   Return true: dim0 and dim1 are transposed within the context of their 2D
-///   transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
-///   transposed within the full context of the transposition.
-///
-/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
-///   Return false: dim0 and dim1 are *not* transposed within the context of
-///   their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
-///   and dim1 (1) are transposed within the full context of the of the
-///   transposition.
-static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
-                                       ArrayRef<int64_t> transp) {
-  // Perform a linear scan along the dimensions of the transposed pattern. If
-  // dim0 is found first, dim0 and dim1 are not transposed within the context of
-  // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
-  for (int64_t permDim : transp) {
-    if (permDim == dim0)
-      return false;
-    if (permDim == dim1)
-      return true;
-  }
-
-  llvm_unreachable("Ill-formed transpose pattern");
-}
-
 /// Rewrite AVX2-specific vector.transpose, for the supported cases and
 /// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
 /// transpose cases and n-D cases that have been decomposed into 2-D
@@ -256,29 +224,16 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     if (!srcType.getElementType().isF32())
       return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
 
-    SmallVector<int64_t> srcGtOneDims;
-    for (auto [index, size] : llvm::enumerate(srcType.getShape()))
-      if (size > 1)
-        srcGtOneDims.push_back(index);
-
-    if (srcGtOneDims.size() != 2)
-      return rewriter.notifyMatchFailure(op, "Unsupported vector type");
-
-    SmallVector<int64_t, 4> transp;
-    for (auto attr : op.getTransp())
-      transp.push_back(attr.cast<IntegerAttr>().getInt());
-
-    // Check whether the two source vector dimensions that are greater than one
-    // must be transposed with each other so that we can apply one of the 2-D
-    // AVX2 transpose pattens. Otherwise, these patterns are not applicable.
-    if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+    auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
+    if (failed(srcGtOneDims))
       return rewriter.notifyMatchFailure(
-          op, "Not applicable to this transpose permutation");
+          op, "expected transposition on a 2D slice");
 
     // Retrieve the sizes of the two dimensions greater than one to be
     // transposed.
     auto srcShape = srcType.getShape();
-    int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];
+    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
+    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
 
     auto applyRewrite = [&]() {
       ImplicitLocOpBuilder ib(loc, rewriter);

diff  --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 44bba6fbed982..a668b49efc6e5 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -687,3 +687,82 @@ transform.sequence failures(propagate) {
     lowering_strategy = "shuffle_16x16"
       : (!pdl.operation) -> !pdl.operation
 }
+
+// -----
+
+// CHECK-LABEL: func @transpose021_shuffle16x16xf32
+func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x16x16xf32> to vector<1x16x16xf32>
+  return %0 : vector<1x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  transform.vector.lower_transpose %module_op
+    lowering_strategy = "shuffle_16x16"
+      : (!pdl.operation) -> !pdl.operation
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index daa5b894d7b98..eef474ee0af06 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1920,6 +1920,7 @@ cc_library(
         ":LLVMCommonConversion",
         ":LLVMDialect",
         ":VectorDialect",
+        ":VectorUtils",
         ":X86VectorDialect",
         "//llvm:Core",
         "//llvm:Support",


        


More information about the Mlir-commits mailing list