[Mlir-commits] [mlir] [mlir][vector] Enable DropUnitDimFromTransposeOp (PR #93007)

Hugo Trachino llvmlistbot at llvm.org
Wed May 22 02:33:29 PDT 2024


https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/93007

>From 7d14d34c0411534f0286721ccbd6a70301daba8f Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Fri, 10 May 2024 19:53:15 +0800
Subject: [PATCH 1/2] [mlir][vector] Enable DropUnitDimFromTransposeOp

---
 .../Vector/Transforms/VectorTransforms.cpp    | 75 ++++++++++++++++++-
 .../Vector/vector-transfer-flatten.mlir       | 74 ++++++++++++++++++
 2 files changed, 147 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..fe74a6446ceed 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1695,6 +1695,77 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+
+/// Removes unit dimensions from a transpose op. Generates a vector.shape_cast
+/// on the operand and result to match types.
+///
+/// Example:
+/// ```
+///   %tr = vector.transpose %arg0, [3, 1, 2, 0]
+///       : vector<1x4x1x2xf32> to vector<2x4x1x1xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+/// %sc0 = vector.shape_cast %arg0 : vector<1x4x1x2xf32> to vector<4x2xf32>
+/// %tr = vector.transpose %sc0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+/// %sc1 = vector.shape_cast %tr : vector<2x4xf32> to vector<2x4x1x1xf32>
+/// ```
+struct DropUnitDimFromTransposeOp final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceVectorType = transposeOp.getSourceVectorType();
+    if (sourceVectorType.getRank() < 2)
+      return failure();
+
+    VectorType newVType = sourceVectorType;
+    SmallVector<int64_t> newPerm =
+        llvm::to_vector(transposeOp.getPermutation());
+    unsigned removedDims = 0;
+    // Drop fixed unit dims; scalable [1] dims are kept (size is 1 x vscale).
+    for (const auto &dim : llvm::enumerate(sourceVectorType.getShape())) {
+      if (dim.value() == 1 &&
+          !sourceVectorType.getScalableDims()[dim.index()]) {
+        newVType =
+            VectorType::Builder(newVType).dropDim(dim.index() - removedDims);
+        for (unsigned permutationIdx = 0; permutationIdx < newPerm.size();
+             ++permutationIdx) {
+          // Erase the dropped unit dimension from the permutation.
+          if ((unsigned)newPerm[permutationIdx] == dim.index() - removedDims) {
+            newPerm.erase(newPerm.begin() + permutationIdx);
+            permutationIdx--;
+          }
+          // Decrement all dimensions of higher rank to keep permutation map
+          // in range of the new rank.
+          else if ((unsigned)newPerm[permutationIdx] > dim.index() - removedDims) {
+            newPerm[permutationIdx]--;
+          }
+        }
+        removedDims++;
+      }
+    }
+    if (!removedDims)
+      return failure();
+
+    auto loc = transposeOp->getLoc();
+    auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType,
+                                                     transposeOp.getVector());
+    // Create an updated Transpose Op without unit dim.
+    vector::TransposeOp newTransposeOp =
+        rewriter.create<vector::TransposeOp>(loc, opSC, newPerm);
+
+    // Restore the unit dim by applying vector.shape_cast to the result.
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        transposeOp, transposeOp.getResultVectorType(), newTransposeOp);
+    // The IR has been rewritten above, so the pattern must report success.
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1819,8 +1890,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromTransposeOp,
+               ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..1c8fd3a40acb8 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -460,6 +460,80 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 
+// -----
+// Dropping the unit dims reduces the permutation to [1, 0] on a 2-D vector.
+func.func @fold_unit_dim_transpose(%arg0 : vector<1x4x1x2xf32>) -> vector<2x4x1x1xf32> {
+  %tr = vector.transpose %arg0, [3, 1, 2, 0]: vector<1x4x1x2xf32> to vector<2x4x1x1xf32>
+  return %tr : vector<2x4x1x1xf32>
+}
+// CHECK-LABEL:   func.func @fold_unit_dim_transpose(
+// CHECK-SAME:        %[[VAL_0:.*]]: vector<1x4x1x2xf32>) -> vector<2x4x1x1xf32> {
+// CHECK:             %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x4x1x2xf32> to vector<4x2xf32>
+// CHECK:             %[[VAL_2:.*]] = vector.transpose %[[VAL_1]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK:             %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<2x4xf32> to vector<2x4x1x1xf32>
+// CHECK:             return %[[VAL_3]] : vector<2x4x1x1xf32>
+
+// -----
+// Without the unit dims the permutation is the identity, so no transpose remains.
+func.func @fold_unit_dim_transpose_identity(%arg0 : vector<1x4x1x2xf32>) -> vector<1x4x2x1xf32> {
+  %tr = vector.transpose %arg0, [2, 1, 3, 0]: vector<1x4x1x2xf32> to vector<1x4x2x1xf32>
+  return %tr : vector<1x4x2x1xf32>
+}
+// CHECK-LABEL:   func.func @fold_unit_dim_transpose_identity(
+// CHECK-SAME:        %[[VAL_0:.*]]: vector<1x4x1x2xf32>) -> vector<1x4x2x1xf32> {
+// CHECK:             %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x4x1x2xf32> to vector<4x2xf32>
+// CHECK:             %[[VAL_2:.*]] = vector.shape_cast %[[VAL_1]] : vector<4x2xf32> to vector<1x4x2x1xf32>
+// CHECK:             return %[[VAL_2]] : vector<1x4x2x1xf32>
+// -----
+// Dropping the unit dim leaves a rank-1 vector, so no transpose remains.
+func.func @fold_unit_dim_transpose_to_1_dim(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
+  %tr = vector.transpose %arg0, [1, 0]: vector<1x4xf32> to vector<4x1xf32>
+  return %tr : vector<4x1xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_transpose_to_1_dim(
+// CHECK-SAME:        %[[VAL_0:.*]]: vector<1x4xf32>) -> vector<4x1xf32> {
+// CHECK:             %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x4xf32> to vector<4xf32>
+// CHECK:             %[[VAL_2:.*]] = vector.shape_cast %[[VAL_1]] : vector<4xf32> to vector<4x1xf32>
+// CHECK:             return %[[VAL_2]] : vector<4x1xf32>
+
+// -----
+// With only unit dims the whole sequence folds back to the function input.
+func.func @fold_unit_dim_transpose_all_one_dim(%arg0 : vector<1x1x1xf32>) -> vector<1x1x1xf32> {
+  %tr = vector.transpose %arg0, [1, 2, 0]: vector<1x1x1xf32> to vector<1x1x1xf32>
+  return %tr : vector<1x1x1xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_transpose_all_one_dim(
+// CHECK-SAME:        %[[VAL_0:.*]]: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
+// CHECK:             return %[[VAL_0]] : vector<1x1x1xf32>
+
+// -----
+// Fixed unit dims vanish end to end while scalable [1] dims are preserved.
+func.func @drop_unit_dim_full_example(%inputLHS : vector<[1]xf128>, %inputRHS : vector<[1]xf128>, %acc : vector<[1]x[1]xf128>) -> vector<[1]x[1]xf128> {
+  %lhsCast = vector.shape_cast %inputLHS : vector<[1]xf128> to vector<[1]x1xf128>
+  %lhsBcast = vector.broadcast %lhsCast : vector<[1]x1xf128> to vector<[1]x[1]x1xf128>
+  %lhsT = vector.transpose %lhsBcast, [1, 0, 2] : vector<[1]x[1]x1xf128> to vector<[1]x[1]x1xf128>
+  %rhsCast = vector.shape_cast %inputRHS : vector<[1]xf128> to vector<1x[1]xf128>
+  %rhsBcast = vector.broadcast %rhsCast : vector<1x[1]xf128> to vector<[1]x1x[1]xf128>
+  %rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[1]x1x[1]xf128> to vector<[1]x[1]x1xf128>
+  %mul = arith.mulf %lhsT, %rhs : vector<[1]x[1]x1xf128>
+  %dropDim = vector.shape_cast %mul : vector<[1]x[1]x1xf128> to vector<[1]x[1]xf128>
+  %addAcc = arith.addf %acc, %dropDim : vector<[1]x[1]xf128>
+  return %addAcc : vector<[1]x[1]xf128>
+}
+
+// CHECK-LABEL: func.func @drop_unit_dim_full_example(
+// CHECK-SAME:    %[[LHS:[a-zA-Z_]+[0-9]]]: vector<[1]xf128>,
+// CHECK-SAME:    %[[RHS:[a-zA-Z_]+[0-9]]]: vector<[1]xf128>,
+// CHECK-SAME:    %[[ACC:[a-zA-Z_]+[0-9]]]: vector<[1]x[1]xf128>) -> vector<[1]x[1]xf128> {
+// CHECK:         %[[LHSBCAST:.*]] = vector.broadcast %[[LHS]] : vector<[1]xf128> to vector<[1]x[1]xf128>
+// CHECK:         %[[TRA:.*]] = vector.transpose %[[LHSBCAST]], [1, 0] : vector<[1]x[1]xf128> to vector<[1]x[1]xf128>
+// CHECK:         %[[RHSBCAST:.*]] = vector.broadcast %[[RHS]] : vector<[1]xf128> to vector<[1]x[1]xf128>
+// CHECK:         %[[MUL:.*]] = arith.mulf %[[TRA]], %[[RHSBCAST]] : vector<[1]x[1]xf128>
+// CHECK:         %[[RES:.*]] = arith.addf %[[ACC]], %[[MUL]] : vector<[1]x[1]xf128>
+// CHECK:         return %[[RES]] : vector<[1]x[1]xf128>
+
 // -----
 
 func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,

>From b79f4e103a11ef72c89671a3bb8b2a2b26cab0fb Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Wed, 22 May 2024 17:33:13 +0800
Subject: [PATCH 2/2] Formatting.

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index fe74a6446ceed..4225ea6cac9e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1695,7 +1695,6 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
-
 /// Removes unit dimensions from a transpose op. Generates a vector.shape_cast
 /// on the operand and result to match types.
 ///
@@ -1741,7 +1740,8 @@ struct DropUnitDimFromTransposeOp final
           }
           // Decrement all dimensions of higher rank to keep permutation map
           // in range of the new rank.
-          else if ((unsigned)newPerm[permutationIdx] > dim.index() - removedDims) {
+          else if ((unsigned)newPerm[permutationIdx] >
+                   dim.index() - removedDims) {
             newPerm[permutationIdx]--;
           }
         }



More information about the Mlir-commits mailing list