[Mlir-commits] [mlir] [mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (PR #73951)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Nov 30 08:01:14 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/73951

>From 5c8fd0de2be8e2d36d2a802c150d430ebffad467 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 30 Nov 2023 15:06:15 +0000
Subject: [PATCH 1/2] [mlir][Vector] Add fold transpose(shape_cast) ->
 shape_cast

This folds transpose(shape_cast) into a new shape_cast when the
transpose only moves unit dims of the shape_cast's result, i.e. when the
relative order of all non-unit dims is left unchanged.

Example:

```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
```

Folds to:
```
vector.shape_cast %vec : vector<[4]xf32> to vector<1x[4]xf32>
```

This is an (alternate) fix for lowering matmuls to ArmSME.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 45 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 25 ++++++++++++
 2 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c462b23e1133fc9..0f372d86e8b3de9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5548,12 +5548,55 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(shape_cast) into a new shape_cast, when the transpose only
+/// moves unit dims (leaving the relative order of other dims unchanged). E.g.:
+///   %0 = vector.shape_cast %v : vector<[4]xf32> to vector<[4]x1xf32>
+///   %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// folds to:
+///   %1 = vector.shape_cast %v : vector<[4]xf32> to vector<1x[4]xf32>
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
+                                PatternRewriter &rewriter) const override {
+    Value transposeSrc = transpOp.getVector();
+    auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+
+    VectorType sourceType = transpOp.getSourceVectorType();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<bool> sourceScalableDims = sourceType.getScalableDims();
+
+    // Only fixed-size unit dims may move; all other dims (incl. [1]) must keep
+    // their relative order. Comparing dim sizes alone is not enough: swapping
+    // the dims of a vector<2x2xf32> preserves the sizes but moves the data.
+    int64_t lastNonUnitDim = -1;
+    for (int64_t dim : transpOp.getPermutation()) {
+      bool isFixedUnitDim = sourceShape[dim] == 1 && !sourceScalableDims[dim];
+      if (isFixedUnitDim)
+        continue;
+      if (dim < lastNonUnitDim)
+        return failure();
+      lastNonUnitDim = dim;
+    }
+
+    // The folded shape_cast differs from the original one only in where unit
+    // dims appear in its result shape.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpOp, transpOp.getResultVectorType(), shapeCastOp.getSource());
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..6bfb477ecf97285 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,6 +67,31 @@ func.func @create_mask_transpose_to_transposed_create_mask(
 
 // -----
 
+// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
+//  CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
+  //     CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
+  // CHECK-NOT: vector.transpose
+  %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %1 : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: negative_transpose_non_unit_dims_of_shape_cast
+//  CHECK-SAME: %[[VEC:.*]]: vector<4xf32>
+func.func @negative_transpose_non_unit_dims_of_shape_cast(%vec: vector<4xf32>) -> vector<2x2xf32> {
+  // Swapping two non-unit dims moves data, so this transpose must be kept.
+  //     CHECK: %[[CAST:.*]] = vector.shape_cast %[[VEC]] : vector<4xf32> to vector<2x2xf32>
+  //     CHECK: vector.transpose %[[CAST]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+  %0 = vector.shape_cast %vec : vector<4xf32> to vector<2x2xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_from_create_mask
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {

>From 1e68c1e5e65a5bb251239c5dfe23e759da0ecd14 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 30 Nov 2023 16:00:01 +0000
Subject: [PATCH 2/2] Fix typo isScalble -> isScalable

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0f372d86e8b3de9..4b68520589bdfe6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5567,8 +5567,8 @@ class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
       return llvm::make_filter_range(
           llvm::zip_equal(type.getShape(), type.getScalableDims()),
           [&](auto dim) {
-            auto [size, isScalble] = dim;
-            return size != 1 || isScalble;
+            auto [size, isScalable] = dim;
+            return size != 1 || isScalable;
           });
     };
 



More information about the Mlir-commits mailing list