[Mlir-commits] [mlir] [MLIR][Vector] Replace vector.transpose with vector.shape_cast (PR #125966)

Hyunsung Lee llvmlistbot at llvm.org
Wed Feb 5 16:05:50 PST 2025


https://github.com/ita9naiwa created https://github.com/llvm/llvm-project/pull/125966

> Define the permutation width as one plus the last index in the
permutation array whose value differs from that index. This pattern
applies to transpose operations whose input vector has at most one
non-unit dimension among its first permutation-width dimensions, and
replaces such a transpose with a shape cast operation.
For example:
%0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to
vector<1x4x1xi32>
is replaced by
 %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
given the permutation width is 2.

This work (#94912) is credited to @pashu123.
I added some tests and fixed existing ones, but I didn't know how to push to his branch, so I created a new PR.


>From f65021aeb617be59c718481554f83b5baf78130c Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Mon, 10 Jun 2024 01:27:21 +0530
Subject: [PATCH 1/2] [mlir][Vector] Replace vector.transpose with
 vector.shape_cast

Define the permutation width as one plus the last index in the
permutation array whose value differs from that index. This pattern
applies to transpose operations whose input vector has at most one
non-unit dimension among its first permutation-width dimensions, and
replaces such a transpose with a shape cast operation.
For example:
%0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to
vector<1x4x1xi32>
is replaced by
 %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
given the permutation width is 2.
---
 .../Transforms/LowerVectorTranspose.cpp       | 58 ++++++++++++++++++-
 .../Vector/vector-transpose-lowering.mlir     |  7 +++
 2 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80f..a29ba47b28cde15 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -450,6 +450,59 @@ class Transpose2DWithUnitDimToShapeCast
   }
 };
 
+/// Rewrites a vector.transpose as a vector.shape_cast when the transpose
+/// provably keeps the linearized element order. Let the permutation width be
+/// one past the last index `i` with `perm[i] != i` (0 for the identity). Dims
+/// at or beyond the width stay in place, so if the input has at most one
+/// non-unit dim among the first `width` dims the transpose is a pure reshape.
+/// For example, given a permutation width of 2 (only dim 0 is non-unit):
+///   %0 = vector.transpose %1, [1, 0, 2]
+///          : vector<4x1x1xi32> to vector<1x4x1xi32>
+/// is replaced by
+///   %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
+class TransposeWithUnitDimToShapeCast
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  TransposeWithUnitDimToShapeCast(MLIRContext *context,
+                                  PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getVector();
+    VectorType inputType = op.getSourceVectorType();
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+    VectorType resType = op.getResultVectorType();
+    ArrayRef<int64_t> transp = op.getPermutation();
+
+    // Permutation width: one past the last index with perm[i] != i. Starting
+    // at 0 means an empty permutation (0-D) never indexes the empty shape.
+    int64_t permWidth = 0;
+    for (auto &&[idx, val] : llvm::enumerate(transp)) {
+      if (static_cast<int64_t>(idx) != val)
+        permWidth = idx + 1;
+    }
+
+    // Dims past the width are not moved, so only the permuted prefix can alter
+    // element order; it must hold at most one non-unit dim.
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t countNonUnitDims = 0;
+    for (int64_t i = 0; i < permWidth; ++i) {
+      if (inputShape[i] != 1)
+        ++countNonUnitDims;
+      if (countNonUnitDims > 1)
+        return rewriter.notifyMatchFailure(
+            op, "more than one non-unit dim within the permutation width");
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+    return success();
+  }
+};
+
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -522,8 +575,9 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns, VectorTransformsOptions options,
     PatternBenefit benefit) {
-  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
-                                                  benefit);
+  // Rewrites of order-preserving transposes into vector.shape_cast.
+  patterns.add<Transpose2DWithUnitDimToShapeCast,
+               TransposeWithUnitDimToShapeCast>(patterns.getContext(), benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       options, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 219a72df52a19c9..d50d8d0d67da196 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -386,6 +386,13 @@ func.func @transpose10_4x1xf32_scalable(%arg0: vector<4x[1]xf32>) -> vector<[1]x
   return %0 : vector<[1]x4xf32>
 }
 
+// CHECK-LABEL: func @transpose_nd
+func.func @transpose_nd(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> {
+  // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
+  %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
+  return %0 : vector<1x1x2x16xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

>From 2de663983ef3078c8174171c9ea6c7a033589851 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Thu, 6 Feb 2025 09:03:37 +0900
Subject: [PATCH 2/2] fix tests

---
 .../Vector/vector-transpose-lowering.mlir     | 49 +++++--------------
 1 file changed, 12 insertions(+), 37 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index d50d8d0d67da196..68e408488cf06f0 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -23,53 +23,21 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
 
 // CHECK-LABEL: func @transpose102_1x8x8xf32
 func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // CHECK-NEXT: %{{.*}} = vector.shape_cast %arg0 : vector<1x8x8xf32> to vector<8x1x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
   return %0 : vector<8x1x8xf32>
 }
 
 // CHECK-LABEL: func @transpose102_8x1x8xf32
 func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32>
+  // CHECK-NEXT: %{{.*}} = vector.shape_cast %arg0 : vector<8x1x8xf32> to vector<1x8x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
   return %0 : vector<1x8x8xf32>
 }
 
 // CHECK-LABEL:   func @transpose1023_1x1x8x8xf32(
 func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
-  // Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
+  // CHECK-NEXT: return %arg0 : vector<1x1x8x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
   return %0 : vector<1x1x8x8xf32>
 }
@@ -386,13 +354,20 @@ func.func @transpose10_4x1xf32_scalable(%arg0: vector<4x[1]xf32>) -> vector<[1]x
   return %0 : vector<[1]x4xf32>
 }
 
-// CHECK-LABEL: func @transpose_nd
-func.func @transpose_nd(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> {
+// CHECK-LABEL: func @transpose0213_1x2x1x16xf32
+func.func @transpose0213_1x2x1x16xf32(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> {
   // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
   %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
   return %0 : vector<1x1x2x16xf32>
 }
 
+// CHECK-LABEL: func @transpose0213_1x1x2x16xf32
+func.func @transpose0213_1x1x2x16xf32(%arg0: vector<1x1x2x16xf32>) -> vector<1x2x1x16xf32> {
+  // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x1x2x16xf32> to vector<1x2x1x16xf32>
+  %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x1x2x16xf32> to vector<1x2x1x16xf32>
+  return %0 : vector<1x2x1x16xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">



More information about the Mlir-commits mailing list