[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)

Andrzej WarzyƄski llvmlistbot at llvm.org
Mon Dec 11 05:56:23 PST 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/74817

From 0d002147e2292d688267299fe28953818a113027 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Dec 2023 21:15:33 +0000
Subject: [PATCH 1/3] [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

Adds a new pattern to reorder:

  shape_cast(arithmetic(a, b))

as

  arithmetic(shape_cast(a), shape_cast(b)).

Example:
```mlir
  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
```

gets converted to:

```mlir
  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
```

This helps to eventually fold away shape_casts entirely. One specific
case where this is beneficial is when vectorising 1D depthwise
convolutions, e.g. `linalg.depthwise_conv_1d_nwc_wc`, with the channel
dimension flattened, i.e. when `flatten_1d_depthwise_conv` is set:
```
transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv}
```

Output after vectorisation:
```mlir
  func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
    %c0 = arith.constant 0 : index
    %c0_i8 = arith.constant 0 : i8
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
    %1 = vector.transfer_read %arg1[%c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
    %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
    %3 = vector.extract %1[0] : vector<3xi8> from vector<1x3xi8>
    %4 = vector.shape_cast %0 : vector<1x8x3xi8> to vector<1x24xi8>
    %5 = vector.shape_cast %2 : vector<1x8x3xi8> to vector<1x24xi8>
    %6 = vector.broadcast %3 : vector<3xi8> to vector<1x8x3xi8>
    %7 = vector.shape_cast %6 : vector<1x8x3xi8> to vector<1x24xi8>
    %8 = arith.muli %4, %7 : vector<1x24xi8>
    %9 = arith.addi %8, %5 : vector<1x24xi8>
    %10 = vector.shape_cast %9 : vector<1x24xi8> to vector<1x8x3xi8>
    vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, memref<1x8x3xi8>
    return
  }
```

Output after applying this pattern and other related patterns (e.g.
`-test-vector-transfer-flatten-patterns -canonicalize`):
```mlir
  func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
    %c0_i8 = arith.constant 0 : i8
    %c0 = arith.constant 0 : index
    %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %0 = vector.transfer_read %collapse_shape[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
    %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1]] : memref<1x3xi8> into memref<3xi8>
    %1 = vector.transfer_read %collapse_shape_0[%c0], %c0_i8 {in_bounds = [true]} : memref<3xi8>, vector<3xi8>
    %collapse_shape_1 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %2 = vector.transfer_read %collapse_shape_1[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
    %3 = vector.broadcast %1 : vector<3xi8> to vector<1x8x3xi8>
    %collapse_shape_2 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %4 = vector.shape_cast %3 : vector<1x8x3xi8> to vector<24xi8>
    %5 = arith.muli %0, %4 : vector<24xi8>
    %6 = arith.addi %5, %2 : vector<24xi8>
    vector.transfer_write %6, %collapse_shape_2[%c0] {in_bounds = [true]} : vector<24xi8>, memref<24xi8>
    return
  }
```

Both the leading unit dims and the shape_casts between the arithmetic ops and
the transfer ops have been folded away (only the shape_cast following the
broadcast remains).
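
For reference, below is a minimal sketch of how the new entry point could be
driven; the wrapper function is hypothetical, only the two `populate*` calls
come from this patch:

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: applies the reordering together with the existing
// shape_cast folding patterns, so that the extra casts introduced by the
// reordering are cleaned up within the same greedy fixed-point iteration.
static mlir::LogicalResult reorderAndFoldShapeCasts(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Added by this patch: reorders shape_cast(arithmetic(a, b)).
  mlir::vector::populateReorderShapeCastPatterns(patterns);
  // Pre-existing: folds the shape_cast pairs exposed by the reordering.
  mlir::vector::populateShapeCastFoldingPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```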
---
 .../Vector/Transforms/VectorRewritePatterns.h |   4 +
 .../Transforms/VectorTransferOpTransforms.cpp |   1 +
 .../Vector/Transforms/VectorTransforms.cpp    | 111 ++++++++++++++++++
 .../Vector/vector-transfer-flatten.mlir       |  59 ++++++++++
 4 files changed, 175 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 08c08172d0531..7102ed81ec57d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -294,6 +294,10 @@ void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Collect a set of patterns to reorder shape_cast(arithmetic(a, b)) Ops.
+void populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
 /// memref.
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ed42e6508b431..ac475566ccdb1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -922,4 +922,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
                FlattenContiguousRowMajorTransferWritePattern>(
       patterns.getContext(), benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
+  populateReorderShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6e7fab293d3a1..b38442f733814 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1446,6 +1446,108 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Reorders:
+///   shape_cast(arithmetic(a, b))
+/// as
+///   arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), this brings the shape_casts closer to the vector.xfer
+/// Ops. With patterns like `FlattenContiguousRowMajorTransferWritePattern`,
+/// the additional shape_casts are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+///  %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+///  %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+///  %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+///  %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// gets folded as:
+///
+/// ```
+///  %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+///  %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+/// ATM this pattern is limited to `vector.shape_cast` ops that fold the unit
+/// dim, e.g.:
+/// ```
+///   vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector should be the result of an arithmetic
+/// operation, `AritOp`.
+template <typename ArithOp>
+struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::isa_and_present<ArithOp>(
+            shapeCastOp.getSource().getDefiningOp()))
+      return failure();
+
+    auto *arithOp = shapeCastOp.getSource().getDefiningOp();
+
+    auto sourceVectorType =
+        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
+    auto resultVectorType =
+        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
+    if (!sourceVectorType || !resultVectorType)
+      return failure();
+
+    // Either the leading or the trailing dims of the input should be
+    // non-scalable 1.
+    if (((sourceVectorType.getShape().back() != 1) ||
+         (sourceVectorType.getScalableDims().back())) &&
+        ((sourceVectorType.getShape().front() != 1) ||
+         (sourceVectorType.getScalableDims().front())))
+      return failure();
+
+    // Does this shape_cast fold the input vector?
+    if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
+      return failure();
+
+    // Does this shape_cast fold the _unit_ dim?
+    if (llvm::any_of(resultVectorType.getShape(),
+                     [](int64_t dim) { return (dim == 1); }))
+      return failure();
+
+    auto loc = shapeCastOp->getLoc();
+
+    // shape_cast(a)
+    auto *lhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
+                                arithOp->getOperands()[0], resultVectorType,
+                                shapeCastOp->getAttrs());
+    // shape_cast(b)
+    auto *rhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
+                                arithOp->getOperands()[1], resultVectorType,
+                                shapeCastOp->getAttrs());
+
+    // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
+    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs->getResult(0),
+                                         rhs->getResult(0));
+
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1514,6 +1616,15 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
   patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ReorderArithAndShapeCast<arith::AddIOp>,
+               ReorderArithAndShapeCast<arith::AddFOp>,
+               ReorderArithAndShapeCast<arith::MulIOp>,
+               ReorderArithAndShapeCast<arith::MulFOp>>(patterns.getContext(),
+                                                        benefit);
+}
+
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<BubbleDownVectorBitCastForExtract,
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ebec2274655e4..88a86755f9f7a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -254,3 +254,62 @@ func.func @transfer_read_flattenable_negative2(
 
 // CHECK-LABEL: func @transfer_read_flattenable_negative2
 //       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
+
+// -----
+
+func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
+                             %arg1 : vector<1x8xi32>) -> vector<8xi32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
+   %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+   return %res : vector<8xi32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_add(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x1xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
+// CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
+                              %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+   %add = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+   return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_mulf(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
+// CHECK:           return %[[VAL_4]] : vector<8x[2]xf32>
+
+// -----
+
+// All shape casts are folded away
+
+func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
+                                   %arg1 : vector<8xi32>,
+                                   %arg2 : vector<8xi32>) -> vector<8xi32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+   %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+   %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+   %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+   %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+   return %res : vector<8xi32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dims_entirely(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8xi32>, %[[VAL_1:.*]]: vector<8xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
+// CHECK:           return %[[VAL_4]] : vector<8xi32>
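
For readers skimming the diff above, the pattern's matching condition boils
down to the following standalone sketch (a simplification for illustration,
with a hypothetical helper name, not a verbatim excerpt from the patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"

// Does this shape_cast drop a single fixed (i.e. non-scalable) unit dim
// from the front or the back of `src`, leaving no unit dims in `res`?
static bool foldsBoundaryUnitDim(mlir::VectorType src, mlir::VectorType res) {
  bool trailingUnitFixed =
      src.getShape().back() == 1 && !src.getScalableDims().back();
  bool leadingUnitFixed =
      src.getShape().front() == 1 && !src.getScalableDims().front();
  if (!trailingUnitFixed && !leadingUnitFixed)
    return false;
  // Exactly one dim must be folded away...
  if (res.getRank() != src.getRank() - 1)
    return false;
  // ...and no unit dims may remain in the result.
  return llvm::none_of(res.getShape(), [](int64_t d) { return d == 1; });
}
```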

From f54addbc2b07cd6a99e684304b89bae753bce1ae Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 11 Dec 2023 08:18:30 +0000
Subject: [PATCH 2/3] fixup! [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

Address PR comments
---
 .../Vector/Transforms/VectorTransforms.cpp    | 40 +++++++++----------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b38442f733814..03aaf85226fbc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1492,39 +1492,40 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 ///   vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
 /// ```
 /// In addition, the input vector should be the result of an arithmetic
-/// operation, `AritOp`.
+/// operation, `ArithOp`.
 template <typename ArithOp>
 struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::isa_and_present<ArithOp>(
-            shapeCastOp.getSource().getDefiningOp()))
-      return failure();
-
     auto *arithOp = shapeCastOp.getSource().getDefiningOp();
+    if (!llvm::isa_and_present<ArithOp>(arithOp))
+      return failure();
 
     auto sourceVectorType =
-        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
+        dyn_cast<VectorType>(shapeCastOp.getSource().getType());
     auto resultVectorType =
-        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
+        dyn_cast<VectorType>(shapeCastOp.getResult().getType());
     if (!sourceVectorType || !resultVectorType)
       return failure();
 
     // Either the leading or the trailing dims of the input should be
     // non-scalable 1.
-    if (((sourceVectorType.getShape().back() != 1) ||
-         (sourceVectorType.getScalableDims().back())) &&
-        ((sourceVectorType.getShape().front() != 1) ||
-         (sourceVectorType.getScalableDims().front())))
+    bool trailingDimUnitFixed = ((sourceVectorType.getShape().back() == 1) &&
+                                 !sourceVectorType.getScalableDims().back());
+    bool leadingDimUnitFixed = ((sourceVectorType.getShape().front() == 1) &&
+                                !sourceVectorType.getScalableDims().front());
+    if (!trailingDimUnitFixed && !leadingDimUnitFixed)
       return failure();
 
     // Does this shape_cast fold the input vector?
     if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
       return failure();
 
-    // Does this shape_cast fold the _unit_ dim?
+    // Does this shape_cast fold the trailing/leading _unit_ dim?
+    // TODO: Even when the trailing/leading unit dims are folded, there might
+    // still be some "inner" unit dims left.
     if (llvm::any_of(resultVectorType.getShape(),
                      [](int64_t dim) { return (dim == 1); }))
       return failure();
@@ -1532,17 +1533,16 @@ struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
     auto loc = shapeCastOp->getLoc();
 
     // shape_cast(a)
-    auto *lhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
-                                arithOp->getOperands()[0], resultVectorType,
-                                shapeCastOp->getAttrs());
+    auto lhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+                                                    arithOp->getOperands()[0],
+                                                    shapeCastOp->getAttrs());
     // shape_cast(b)
-    auto *rhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
-                                arithOp->getOperands()[1], resultVectorType,
-                                shapeCastOp->getAttrs());
+    auto rhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+                                                    arithOp->getOperands()[1],
+                                                    shapeCastOp->getAttrs());
 
     // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
-    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs->getResult(0),
-                                         rhs->getResult(0));
+    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs, rhs);
 
     return success();
   }

From b4f7096ca52178ad6bed20178c7b12f171a337b0 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 11 Dec 2023 11:34:49 +0000
Subject: [PATCH 3/3] fixup! [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

Rename and restrict the pattern
---
 .../Vector/Transforms/VectorTransforms.cpp    | 25 +++++++++++++------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 03aaf85226fbc..39a331ccb8b08 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1494,13 +1494,22 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 /// In addition, the input vector should be the result of an arithmetic
 /// operation, `ArithOp`.
 template <typename ArithOp>
-struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
+struct ReorderShapeCastWithUnitDimAndArith
+    : public OpRewritePattern<vector::ShapeCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
-    auto *arithOp = shapeCastOp.getSource().getDefiningOp();
-    if (!llvm::isa_and_present<ArithOp>(arithOp))
+    auto arithOp = shapeCastOp.getSource().getDefiningOp<ArithOp>();
+    if (!arithOp)
+      return failure();
+
+    // Only elementwise ops are supported - filter out everything else.
+    if (!arithOp.template hasTrait<OpTrait::Elementwise>())
+      return failure();
+
+    // TODO: Add support for unary ops
+    if (arithOp->getOperands().size() != 2)
       return failure();
 
     auto sourceVectorType =
@@ -1618,11 +1627,11 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit) {
-  patterns.add<ReorderArithAndShapeCast<arith::AddIOp>,
-               ReorderArithAndShapeCast<arith::AddFOp>,
-               ReorderArithAndShapeCast<arith::MulIOp>,
-               ReorderArithAndShapeCast<arith::MulFOp>>(patterns.getContext(),
-                                                        benefit);
+  patterns.add<ReorderShapeCastWithUnitDimAndArith<arith::AddIOp>,
+               ReorderShapeCastWithUnitDimAndArith<arith::AddFOp>,
+               ReorderShapeCastWithUnitDimAndArith<arith::MulIOp>,
+               ReorderShapeCastWithUnitDimAndArith<arith::MulFOp>>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
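
Under the same constraints, other binary elementwise ops could presumably be
handled by extending the populate function. A hypothetical sketch, assuming
access to the file-local `ReorderShapeCastWithUnitDimAndArith` template:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

// arith.subi and arith.subf are binary and elementwise, so the same
// template should apply (hypothetical extension, not part of the patch).
void populateExtraReorderShapeCastPatterns(mlir::RewritePatternSet &patterns,
                                           mlir::PatternBenefit benefit = 1) {
  patterns.add<ReorderShapeCastWithUnitDimAndArith<mlir::arith::SubIOp>,
               ReorderShapeCastWithUnitDimAndArith<mlir::arith::SubFOp>>(
      patterns.getContext(), benefit);
}
```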


