[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dim from elementwise(a, b) (PR #74817)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Dec 13 07:41:41 PST 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/74817

>From 0d002147e2292d688267299fe28953818a113027 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Dec 2023 21:15:33 +0000
Subject: [PATCH 1/6] [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

Adds a new pattern to reorder:

  shape_cast(arithmetic(a, b))

as

  arithmetic(shape_cast(a), shape_cast(b)).

Example:
```mlir
  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
```

gets converted to:

```mlir
  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
```

This eventually helps to fold away shape_casts entirely. One specific
case where this is beneficial is when vectorising 1D depthwise
convolutions, e.g. `linalg.depthwise_conv_1d_nwc_wc`, with the channel
dimension flattened, i.e. when `flatten_1d_depthwise_conv` is set:
```
transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv}
```

Output after vectorisation:
```mlir
  func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
    %c0 = arith.constant 0 : index
    %c0_i8 = arith.constant 0 : i8
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
    %1 = vector.transfer_read %arg1[%c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
    %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
    %3 = vector.extract %1[0] : vector<3xi8> from vector<1x3xi8>
    %4 = vector.shape_cast %0 : vector<1x8x3xi8> to vector<1x24xi8>
    %5 = vector.shape_cast %2 : vector<1x8x3xi8> to vector<1x24xi8>
    %6 = vector.broadcast %3 : vector<3xi8> to vector<1x8x3xi8>
    %7 = vector.shape_cast %6 : vector<1x8x3xi8> to vector<1x24xi8>
    %8 = arith.muli %4, %7 : vector<1x24xi8>
    %9 = arith.addi %8, %5 : vector<1x24xi8>
    %10 = vector.shape_cast %9 : vector<1x24xi8> to vector<1x8x3xi8>
    vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, memref<1x8x3xi8>
    return
  }
```

Output after applying this pattern and other related patterns (e.g.
-test-vector-transfer-flatten-patterns -canonicalize):
```mlir
  func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
    %c0_i8 = arith.constant 0 : i8
    %c0 = arith.constant 0 : index
    %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %0 = vector.transfer_read %collapse_shape[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
    %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1]] : memref<1x3xi8> into memref<3xi8>
    %1 = vector.transfer_read %collapse_shape_0[%c0], %c0_i8 {in_bounds = [true]} : memref<3xi8>, vector<3xi8>
    %collapse_shape_1 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %2 = vector.transfer_read %collapse_shape_1[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
    %3 = vector.broadcast %1 : vector<3xi8> to vector<1x8x3xi8>
    %collapse_shape_2 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %4 = vector.shape_cast %3 : vector<1x8x3xi8> to vector<24xi8>
    %5 = arith.muli %0, %4 : vector<24xi8>
    %6 = arith.addi %5, %2 : vector<24xi8>
    vector.transfer_write %6, %collapse_shape_2[%c0] {in_bounds = [true]} : vector<24xi8>, memref<24xi8>
    return
  }
```

The leading unit dims have been folded away, and every shape_cast except the one flattening the broadcast (`%4`) has been eliminated.
---
 .../Vector/Transforms/VectorRewritePatterns.h |   4 +
 .../Transforms/VectorTransferOpTransforms.cpp |   1 +
 .../Vector/Transforms/VectorTransforms.cpp    | 111 ++++++++++++++++++
 .../Vector/vector-transfer-flatten.mlir       |  59 ++++++++++
 4 files changed, 175 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 08c08172d0531e..7102ed81ec57d4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -294,6 +294,10 @@ void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Collect a set of vector.shape_cast folding patterns.
+void populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
 /// memref.
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ed42e6508b4310..ac475566ccdb1e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -922,4 +922,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
                FlattenContiguousRowMajorTransferWritePattern>(
       patterns.getContext(), benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
+  populateReorderShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6e7fab293d3a1c..b38442f7338144 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1446,6 +1446,108 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Reorders:
+///   shape_cast(arithmetic(a + b))
+/// as
+///   arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), this brings shape_cast closer to vector.xfer operations.
+/// With patterns like e.g. `FlattenContiguousRowMajorTransferWritePattern`,
+/// the addition shape_cast's are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+///  %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+///  %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+///  %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+///  %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// gets folded as:
+///
+///```
+///    %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+///    %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+/// ATM this pattern is limited to `vector.shape_cast` ops that fold the unit
+/// dim, e.g.:
+/// ```
+///   vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector should be the result of an arithmetic
+/// operation, `AritOp`.
+template <typename ArithOp>
+struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::isa_and_present<ArithOp>(
+            shapeCastOp.getSource().getDefiningOp()))
+      return failure();
+
+    auto *arithOp = shapeCastOp.getSource().getDefiningOp();
+
+    auto sourceVectorType =
+        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
+    auto resultVectorType =
+        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
+    if (!sourceVectorType || !resultVectorType)
+      return failure();
+
+    // Either the leading or the trailing dims of the input should be
+    // non-scalable 1.
+    if (((sourceVectorType.getShape().back() != 1) ||
+         (sourceVectorType.getScalableDims().back())) &&
+        ((sourceVectorType.getShape().front() != 1) ||
+         (sourceVectorType.getScalableDims().front())))
+      return failure();
+
+    // Does this shape_cast fold the input vector?
+    if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
+      return failure();
+
+    // Does this shape_cast fold the _unit_ dim?
+    if (llvm::any_of(resultVectorType.getShape(),
+                     [](int64_t dim) { return (dim == 1); }))
+      return failure();
+
+    auto loc = shapeCastOp->getLoc();
+
+    // shape_cast(a)
+    auto *lhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
+                                arithOp->getOperands()[0], resultVectorType,
+                                shapeCastOp->getAttrs());
+    // shape_cast(b)
+    auto *rhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
+                                arithOp->getOperands()[1], resultVectorType,
+                                shapeCastOp->getAttrs());
+
+    // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
+    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs->getResult(0),
+                                         rhs->getResult(0));
+
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1514,6 +1616,15 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
   patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ReorderArithAndShapeCast<arith::AddIOp>,
+               ReorderArithAndShapeCast<arith::AddFOp>,
+               ReorderArithAndShapeCast<arith::MulIOp>,
+               ReorderArithAndShapeCast<arith::MulFOp>>(patterns.getContext(),
+                                                        benefit);
+}
+
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<BubbleDownVectorBitCastForExtract,
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ebec2274655e46..88a86755f9f7a6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -254,3 +254,62 @@ func.func @transfer_read_flattenable_negative2(
 
 // CHECK-LABEL: func @transfer_read_flattenable_negative2
 //       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
+
+// -----
+
+func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
+                             %arg1 : vector<1x8xi32>) -> vector<8xi32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
+   %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+   return %res : vector<8xi32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_add(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x1xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
+// CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
+                              %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+   %add = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+   return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_mulf(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
+// CHECK:           return %[[VAL_4]] : vector<8x[2]xf32>
+
+// -----
+
+// All shape casts are folded away
+
+func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
+                                   %arg1 : vector<8xi32>,
+                                   %arg2 : vector<8xi32>) -> vector<8xi32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+   %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+   %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+   %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+   %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+   return %res : vector<8xi32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dims_entirely(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8xi32>, %[[VAL_1:.*]]: vector<8xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
+// CHECK:           return %[[VAL_4]] : vector<8xi32>

>From f54addbc2b07cd6a99e684304b89bae753bce1ae Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 11 Dec 2023 08:18:30 +0000
Subject: [PATCH 2/6] fixup! [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

Address PR comments
---
 .../Vector/Transforms/VectorTransforms.cpp    | 40 +++++++++----------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b38442f7338144..03aaf85226fbcd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1492,39 +1492,40 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 ///   vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
 /// ```
 /// In addition, the input vector should be the result of an arithmetic
-/// operation, `AritOp`.
+/// operation, `ArithOp`.
 template <typename ArithOp>
 struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::isa_and_present<ArithOp>(
-            shapeCastOp.getSource().getDefiningOp()))
-      return failure();
-
     auto *arithOp = shapeCastOp.getSource().getDefiningOp();
+    if (!llvm::isa_and_present<ArithOp>(arithOp))
+      return failure();
 
     auto sourceVectorType =
-        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
+        dyn_cast<VectorType>(shapeCastOp.getSource().getType());
     auto resultVectorType =
-        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
+        dyn_cast<VectorType>(shapeCastOp.getResult().getType());
     if (!sourceVectorType || !resultVectorType)
       return failure();
 
     // Either the leading or the trailing dims of the input should be
     // non-scalable 1.
-    if (((sourceVectorType.getShape().back() != 1) ||
-         (sourceVectorType.getScalableDims().back())) &&
-        ((sourceVectorType.getShape().front() != 1) ||
-         (sourceVectorType.getScalableDims().front())))
+    bool leadDimUnitFixed = ((sourceVectorType.getShape().back() != 1) ||
+                             (sourceVectorType.getScalableDims().back()));
+    bool trailinDimUnitFixed = ((sourceVectorType.getShape().front() != 1) ||
+                                (sourceVectorType.getScalableDims().front()));
+    if (!leadDimUnitFixed && !trailinDimUnitFixed)
       return failure();
 
     // Does this shape_cast fold the input vector?
     if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
       return failure();
 
-    // Does this shape_cast fold the _unit_ dim?
+    // Does this shape_cast fold the traling/leading _unit_ dim?
+    // TODO: Even when the trailing/leading unit dims are folded, there might
+    // still be some "inner" unit dims left.
     if (llvm::any_of(resultVectorType.getShape(),
                      [](int64_t dim) { return (dim == 1); }))
       return failure();
@@ -1532,17 +1533,16 @@ struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
     auto loc = shapeCastOp->getLoc();
 
     // shape_cast(a)
-    auto *lhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
-                                arithOp->getOperands()[0], resultVectorType,
-                                shapeCastOp->getAttrs());
+    auto lhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+                                                    arithOp->getOperands()[0],
+                                                    shapeCastOp->getAttrs());
     // shape_cast(b)
-    auto *rhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
-                                arithOp->getOperands()[1], resultVectorType,
-                                shapeCastOp->getAttrs());
+    auto rhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+                                                    arithOp->getOperands()[1],
+                                                    shapeCastOp->getAttrs());
 
     // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
-    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs->getResult(0),
-                                         rhs->getResult(0));
+    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs, rhs);
 
     return success();
   }

>From 0ebb9c31b8cac0be465fae8a3cd1d4385d3b45b8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 12 Dec 2023 19:53:30 +0000
Subject: [PATCH 3/6] fixup! [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

Switch to matching an elementwise Op
---
 .../Vector/Transforms/VectorTransforms.cpp    | 127 +++++++-----------
 1 file changed, 50 insertions(+), 77 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 03aaf85226fbcd..dcac4f925c5c32 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1446,10 +1446,15 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-/// Reorders:
-///   shape_cast(arithmetic(a + b))
-/// as
-///   arithmetic(shape_cast(a), shape_cast(b)).
+/// Replace:
+///   elementwise(a, b)
+/// with:
+///   sc_a = shape_cast(a)
+///   sc_b = shape_cast(b)
+///   res = elementwise(sc_a, sc_b)
+///   return shape_cast(res)
+/// for which `a` and `b` are vectors of rank > 2 and have unit leading and/or
+/// trailing dimension.
 ///
 /// Ex:
 /// ```
@@ -1464,85 +1469,57 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 ///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
 /// ```
-///
-/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
-/// replaced with 2), this brings shape_cast closer to vector.xfer operations.
-/// With patterns like e.g. `FlattenContiguousRowMajorTransferWritePattern`,
-/// the addition shape_cast's are eventually folded away.
-///
-/// Here is another example where this pattern is helpful:
-/// ```
-///  %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
-///  %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
-///  %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
-///  %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
-///  %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
-///  %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
-/// ```
-///
-/// gets folded as:
-///
-///```
-///    %0 = arith.muli %arg0, %arg1 : vector<8xi32>
-///    %res = arith.addi %0, %arg2 : vector<8xi32>
-/// ```
-/// ATM this pattern is limited to `vector.shape_cast` ops that fold the unit
-/// dim, e.g.:
-/// ```
-///   vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
-/// ```
-/// In addition, the input vector should be the result of an arithmetic
-/// operation, `ArithOp`.
-template <typename ArithOp>
-struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+struct DropUnitDimFromElementwiseOps final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    auto *arithOp = shapeCastOp.getSource().getDefiningOp();
-    if (!llvm::isa_and_present<ArithOp>(arithOp))
-      return failure();
-
-    auto sourceVectorType =
-        dyn_cast<VectorType>(shapeCastOp.getSource().getType());
-    auto resultVectorType =
-        dyn_cast<VectorType>(shapeCastOp.getResult().getType());
-    if (!sourceVectorType || !resultVectorType)
+    if (op->getNumResults() != 1)
       return failure();
 
-    // Either the leading or the trailing dims of the input should be
-    // non-scalable 1.
-    bool leadDimUnitFixed = ((sourceVectorType.getShape().back() != 1) ||
-                             (sourceVectorType.getScalableDims().back()));
-    bool trailinDimUnitFixed = ((sourceVectorType.getShape().front() != 1) ||
-                                (sourceVectorType.getScalableDims().front()));
-    if (!leadDimUnitFixed && !trailinDimUnitFixed)
+    // Check the pre-condiitions. For `Elementwise` Ops all operands
+    // are guaranteed to have identical shapes and it suffices to only check the
+    // first one.
+    auto op1 = op->getOperands()[0];
+    auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
+    if (!sourceVectorType)
       return failure();
 
-    // Does this shape_cast fold the input vector?
-    if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
+    if (sourceVectorType.getRank() < 2)
       return failure();
 
-    // Does this shape_cast fold the traling/leading _unit_ dim?
-    // TODO: Even when the trailing/leading unit dims are folded, there might
-    // still be some "inner" unit dims left.
-    if (llvm::any_of(resultVectorType.getShape(),
-                     [](int64_t dim) { return (dim == 1); }))
+    bool trailingDimUnitFixed = ((sourceVectorType.getShape().back() == 1) &&
+                                 (!sourceVectorType.getScalableDims().back()));
+    bool leadDimUnitFixed = ((sourceVectorType.getShape().front() == 1) &&
+                             (!sourceVectorType.getScalableDims().front()));
+    if (!leadDimUnitFixed && !trailingDimUnitFixed)
       return failure();
 
-    auto loc = shapeCastOp->getLoc();
+    // Drop leading/trailing unit dim by applying vector.shape_cast to all
+    // operands
+    auto elTy = sourceVectorType.getElementType();
+    VectorType newVType =
+        leadDimUnitFixed
+            ? VectorType::get(sourceVectorType.getShape().drop_front(1), elTy,
+                              sourceVectorType.getScalableDims().drop_front(1))
+            : VectorType::get(sourceVectorType.getShape().drop_back(1), elTy,
+                              sourceVectorType.getScalableDims().drop_back(1));
+    SmallVector<Value> newOperands;
+    auto loc = op->getLoc();
+    for (auto operand : op->getOperands()) {
+      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
+      newOperands.push_back(opSC);
+    }
 
-    // shape_cast(a)
-    auto lhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
-                                                    arithOp->getOperands()[0],
-                                                    shapeCastOp->getAttrs());
-    // shape_cast(b)
-    auto rhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
-                                                    arithOp->getOperands()[1],
-                                                    shapeCastOp->getAttrs());
+    // Create an updated elementwise Op without leading/trailing unit dim
+    Operation *elementwiseOp =
+        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+                        newOperands, newVType, op->getAttrs());
 
-    // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
-    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs, rhs);
+    // Restore the leading/trailing unit dim by applying vector.shape_cast to
+    // the result
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
+                                             elementwiseOp->getResults()[0]);
 
     return success();
   }
@@ -1618,11 +1595,7 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit) {
-  patterns.add<ReorderArithAndShapeCast<arith::AddIOp>,
-               ReorderArithAndShapeCast<arith::AddFOp>,
-               ReorderArithAndShapeCast<arith::MulIOp>,
-               ReorderArithAndShapeCast<arith::MulFOp>>(patterns.getContext(),
-                                                        benefit);
+  patterns.add<DropUnitDimFromElementwiseOps>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(

>From ac5c0163cf354998d0977a925b8fa80338b4d13e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 13 Dec 2023 08:58:26 +0000
Subject: [PATCH 4/6] fixup! fixup! [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

Update comment, add ShapeCastOpFolder to the list of patterns
---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index dcac4f925c5c32..2ec1183000beac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1468,7 +1468,12 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 ///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
+///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
 /// ```
+///
+/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
+/// `%cast`.
 struct DropUnitDimFromElementwiseOps final
     : public OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -1595,7 +1600,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps>(patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(

>From 5abdab6bc52c5ce4bbbffca6f2d8b7491711280b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 13 Dec 2023 10:02:52 +0000
Subject: [PATCH 5/6] fixup! [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

Rename + add comments
---
 .../Vector/Transforms/VectorRewritePatterns.h  | 11 ++++++++---
 .../Transforms/VectorTransferOpTransforms.cpp  |  2 +-
 .../Vector/Transforms/VectorTransforms.cpp     | 18 ++++++++++--------
 3 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7102ed81ec57d4..17173c01ab762a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -294,9 +294,14 @@ void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
-/// Collect a set of vector.shape_cast folding patterns.
-void populateReorderShapeCastPatterns(RewritePatternSet &patterns,
-                                      PatternBenefit benefit = 1);
+/// Collect a set of patterns that use vector.shape_cast to help fold unit dims.
+///
+/// These patterns use vector.shape_cast to remove unit dims from e.g.
+/// arithmetic operations on Vectors. The newly inserted shape_casts will either
+/// cancel each other out or will be folded away when combined with other
+/// patterns.
+void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
 
 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
 /// memref.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ac475566ccdb1e..b761d1ed888973 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -922,5 +922,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
                FlattenContiguousRowMajorTransferWritePattern>(
       patterns.getContext(), benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
-  populateReorderShapeCastPatterns(patterns, benefit);
+  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2ec1183000beac..6d86b867c82391 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1493,18 +1493,20 @@ struct DropUnitDimFromElementwiseOps final
     if (sourceVectorType.getRank() < 2)
       return failure();
 
-    bool trailingDimUnitFixed = ((sourceVectorType.getShape().back() == 1) &&
-                                 (!sourceVectorType.getScalableDims().back()));
-    bool leadDimUnitFixed = ((sourceVectorType.getShape().front() == 1) &&
-                             (!sourceVectorType.getScalableDims().front()));
-    if (!leadDimUnitFixed && !trailingDimUnitFixed)
+    bool hasTrailingDimUnitFixed =
+        ((sourceVectorType.getShape().back() == 1) &&
+         (!sourceVectorType.getScalableDims().back()));
+    bool hasLeadingDimUnitFixed =
+        ((sourceVectorType.getShape().front() == 1) &&
+         (!sourceVectorType.getScalableDims().front()));
+    if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
       return failure();
 
     // Drop leading/trailing unit dim by applying vector.shape_cast to all
     // operands
     auto elTy = sourceVectorType.getElementType();
     VectorType newVType =
-        leadDimUnitFixed
+        hasLeadingDimUnitFixed
             ? VectorType::get(sourceVectorType.getShape().drop_front(1), elTy,
                               sourceVectorType.getScalableDims().drop_front(1))
             : VectorType::get(sourceVectorType.getShape().drop_back(1), elTy,
@@ -1598,8 +1600,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
   patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
-void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
-                                                    PatternBenefit benefit) {
+void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
       patterns.getContext(), benefit);
 }

>From 1e6015586bf6aed45d875df84a58aee4be45eefe Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 13 Dec 2023 11:40:26 +0000
Subject: [PATCH 6/6] fixup! fixup! [mlir][vector] Add pattern to reorder
 shape_cast(arithmetic(a, b))

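Add the missing VectorDialect dependency to TestFlattenVectorTransferPatterns, add a basic test, and refine the implementation (use VectorType::Builder::dropDim to drop the unit dim)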
---
 .../Vector/Transforms/VectorTransforms.cpp      | 17 +++++++----------
 .../Dialect/Vector/vector-transfer-flatten.mlir | 15 +++++++++++++++
 .../lib/Dialect/Vector/TestVectorTransforms.cpp |  1 +
 3 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6d86b867c82391..afb35a39f6f07c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1453,7 +1453,7 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 ///   sc_b = shape_cast(b)
 ///   res = elementwise(sc_a, sc_b)
 ///   return shape_cast(res)
-/// for which `a` and `b` are vectors of rank > 2 and have unit leading and/or
+/// for which `a` and `b` are vectors of rank > 1 and have unit leading and/or
 /// trailing dimension.
 ///
 /// Ex:
@@ -1505,12 +1505,9 @@ struct DropUnitDimFromElementwiseOps final
     // Drop leading/trailing unit dim by applying vector.shape_cast to all
     // operands
     auto elTy = sourceVectorType.getElementType();
-    VectorType newVType =
-        hasLeadingDimUnitFixed
-            ? VectorType::get(sourceVectorType.getShape().drop_front(1), elTy,
-                              sourceVectorType.getScalableDims().drop_front(1))
-            : VectorType::get(sourceVectorType.getShape().drop_back(1), elTy,
-                              sourceVectorType.getScalableDims().drop_back(1));
+    int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
+    VectorType newVType = VectorType::Builder(sourceVectorType).dropDim(dim);
+
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
@@ -1520,13 +1517,13 @@ struct DropUnitDimFromElementwiseOps final
 
     // Create an updated elementwise Op without leading/trailing unit dim
     Operation *elementwiseOp =
-        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
-                        newOperands, newVType, op->getAttrs());
+        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
+                        newVType, op->getAttrs());
 
     // Restore the leading/trailing unit dim by applying vector.shape_cast to
     // the result
     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, sourceVectorType,
-                                             elementwiseOp->getResults()[0]);
+                                             elementwiseOp->getResult(0));
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 88a86755f9f7a6..5ec1cb8b467258 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -257,6 +257,20 @@ func.func @transfer_read_flattenable_negative2(
 
 // -----
 
+func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
+   %add = arith.addi %arg0, %arg0 : vector<1x8xi32>
+   return %add : vector<1x8xi32>
+}
+// CHECK-LABEL:   func.func @fold_unit_dim_add_basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32>
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32>
+// CHECK:           %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32>
+// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
+// CHECK:           return %[[VAL_4]] : vector<1x8xi32>
+
+// -----
+
 func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
                              %arg1 : vector<1x8xi32>) -> vector<8xi32> {
    %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
@@ -313,3 +327,4 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 86b8d5f9b0995a..21e8299b72398b 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -455,6 +455,7 @@ struct TestFlattenVectorTransferPatterns
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
     registry.insert<affine::AffineDialect>();
+    registry.insert<vector::VectorDialect>();
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());



More information about the Mlir-commits mailing list