[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 8 00:39:37 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Adds a new pattern to reorder:

  shape_cast(arithmetic(a, b))

as

  arithmetic(shape_cast(a), shape_cast(b)).

Example:
```mlir
  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
```

gets converted to:

```mlir
  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
```

This helps to, eventually, fold away shape_casts entirely. One specific
case where this is beneficial is when vectorising 1D depthwise
convolutions, e.g. `linalg.depthwise_conv_1d_nwc_wc`, with the channel
dimension flattened, i.e. when `flatten_1d_depthwise_conv` is set:
```
transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv}
```

Output after vectorisation:
```mlir
  func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
    %c0 = arith.constant 0 : index
    %c0_i8 = arith.constant 0 : i8
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
    %1 = vector.transfer_read %arg1[%c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
    %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
    %3 = vector.extract %1[0] : vector<3xi8> from vector<1x3xi8>
    %4 = vector.shape_cast %0 : vector<1x8x3xi8> to vector<1x24xi8>
    %5 = vector.shape_cast %2 : vector<1x8x3xi8> to vector<1x24xi8>
    %6 = vector.broadcast %3 : vector<3xi8> to vector<1x8x3xi8>
    %7 = vector.shape_cast %6 : vector<1x8x3xi8> to vector<1x24xi8>
    %8 = arith.muli %4, %7 : vector<1x24xi8>
    %9 = arith.addi %8, %5 : vector<1x24xi8>
    %10 = vector.shape_cast %9 : vector<1x24xi8> to vector<1x8x3xi8>
    vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, memref<1x8x3xi8>
    return
  }
```

Output after applying this pattern and other related patterns (e.g.
`-test-vector-transfer-flatten-patterns -canonicalize`):
```mlir
  func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_memref(%arg0: memref<1x8x3xi8>, %arg1: memref<1x3xi8>, %arg2: memref<1x8x3xi8>) {
    %c0_i8 = arith.constant 0 : i8
    %c0 = arith.constant 0 : index
    %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %0 = vector.transfer_read %collapse_shape[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
    %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1]] : memref<1x3xi8> into memref<3xi8>
    %1 = vector.transfer_read %collapse_shape_0[%c0], %c0_i8 {in_bounds = [true]} : memref<3xi8>, vector<3xi8>
    %collapse_shape_1 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %2 = vector.transfer_read %collapse_shape_1[%c0], %c0_i8 {in_bounds = [true]} : memref<24xi8>, vector<24xi8>
    %3 = vector.broadcast %1 : vector<3xi8> to vector<1x8x3xi8>
    %collapse_shape_2 = memref.collapse_shape %arg2 [[0, 1, 2]] : memref<1x8x3xi8> into memref<24xi8>
    %4 = vector.shape_cast %3 : vector<1x8x3xi8> to vector<24xi8>
    %5 = arith.muli %0, %4 : vector<24xi8>
    %6 = arith.addi %5, %2 : vector<24xi8>
    vector.transfer_write %6, %collapse_shape_2[%c0] {in_bounds = [true]} : vector<24xi8>, memref<24xi8>
    return
  }
```

The shape_casts around the arithmetic ops, as well as the leading unit dims,
have been folded away (only the shape_cast of the broadcast remains).


---
Full diff: https://github.com/llvm/llvm-project/pull/74817.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+4) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+1) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+111) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+79) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 08c08172d0531e..b0e626cf39b0f8 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -294,6 +294,13 @@ void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Collect a set of patterns that reorder `shape_cast(a <op> b)` as
+/// `shape_cast(a) <op> shape_cast(b)`, where `<op>` is one of the binary,
+/// elementwise arith ops `arith.addi`, `arith.addf`, `arith.muli` or
+/// `arith.mulf`.
+void populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
 /// memref.
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ed42e6508b4310..ac475566ccdb1e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -922,4 +922,7 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
                FlattenContiguousRowMajorTransferWritePattern>(
       patterns.getContext(), benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
+  // Sink shape_casts below elementwise arith ops so that they end up next to
+  // the flattened transfer ops, where they can be folded away.
+  populateReorderShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6e7fab293d3a1c..b38442f7338144 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1446,6 +1446,122 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Reorders:
+///   shape_cast(a <op> b)
+/// as
+///   shape_cast(a) <op> shape_cast(b),
+/// where `<op>` is the binary, elementwise arith operation `ArithOp`.
+///
+/// Ex:
+/// ```
+///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// The two forms are equivalent because `ArithOp` is elementwise and
+/// `vector.shape_cast` only changes the shape, not the order, of the elements.
+/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), it brings the shape_casts closer to the ops producing
+/// `a` and `b` (e.g. vector transfer ops). With patterns like e.g.
+/// `FlattenContiguousRowMajorTransferWritePattern`, the additional shape_casts
+/// are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+///  %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+///  %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+///  %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+///  %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// gets folded as:
+///
+/// ```
+///    %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+///    %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+///
+/// ATM this pattern is limited to `vector.shape_cast` ops that drop a single
+/// non-scalable unit dim, either the leading or the trailing one, e.g.:
+/// ```
+///   vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector must be the result of an `ArithOp` that has
+/// no other users (otherwise the arithmetic would be duplicated, not moved).
+template <typename ArithOp>
+struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto arithOp = shapeCastOp.getSource().getDefiningOp<ArithOp>();
+    if (!arithOp)
+      return failure();
+
+    // Only move the arithmetic if this shape_cast is its sole user. Otherwise
+    // the original op has to stay alive for the other users and the
+    // computation would be duplicated rather than reordered.
+    if (!arithOp->hasOneUse())
+      return failure();
+
+    auto sourceVectorType =
+        dyn_cast<VectorType>(shapeCastOp.getSource().getType());
+    auto resultVectorType =
+        dyn_cast<VectorType>(shapeCastOp.getResult().getType());
+    // A 0-D source has no unit dim to drop (and front()/back() below must not
+    // be called on an empty shape).
+    if (!sourceVectorType || !resultVectorType ||
+        sourceVectorType.getRank() == 0)
+      return failure();
+
+    // Does this shape_cast drop exactly one non-scalable unit dim (either the
+    // leading or the trailing one) while leaving the remaining dims, including
+    // their scalability, untouched?
+    ArrayRef<int64_t> srcShape = sourceVectorType.getShape();
+    ArrayRef<bool> srcScalableDims = sourceVectorType.getScalableDims();
+    ArrayRef<int64_t> resShape = resultVectorType.getShape();
+    ArrayRef<bool> resScalableDims = resultVectorType.getScalableDims();
+    bool dropsLeadingUnitDim =
+        srcShape.front() == 1 && !srcScalableDims.front() &&
+        srcShape.drop_front() == resShape &&
+        srcScalableDims.drop_front() == resScalableDims;
+    bool dropsTrailingUnitDim =
+        srcShape.back() == 1 && !srcScalableDims.back() &&
+        srcShape.drop_back() == resShape &&
+        srcScalableDims.drop_back() == resScalableDims;
+    if (!dropsLeadingUnitDim && !dropsTrailingUnitDim)
+      return failure();
+
+    Location loc = shapeCastOp->getLoc();
+
+    // shape_cast(a), shape_cast(b)
+    Value lhs = rewriter.create<vector::ShapeCastOp>(
+        loc, resultVectorType, arithOp->getOperand(0));
+    Value rhs = rewriter.create<vector::ShapeCastOp>(
+        loc, resultVectorType, arithOp->getOperand(1));
+
+    // Replace `shape_cast(a op b)` with `shape_cast(a) op shape_cast(b)`. The
+    // new op is created generically so that the attributes of the original
+    // `ArithOp` (e.g. fastmath flags) are carried over.
+    Operation *newArithOp = rewriter.create(
+        loc, arithOp->getName().getIdentifier(), ValueRange{lhs, rhs},
+        resultVectorType, arithOp->getAttrs());
+    rewriter.replaceOp(shapeCastOp, newArithOp->getResults());
+
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1514,6 +1630,15 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
   patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateReorderShapeCastPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ReorderArithAndShapeCast<arith::AddIOp>,
+               ReorderArithAndShapeCast<arith::AddFOp>,
+               ReorderArithAndShapeCast<arith::MulIOp>,
+               ReorderArithAndShapeCast<arith::MulFOp>>(patterns.getContext(),
+                                                        benefit);
+}
+
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<BubbleDownVectorBitCastForExtract,
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ebec2274655e46..232037241fdb7d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -254,3 +254,83 @@ func.func @transfer_read_flattenable_negative2(
 
 // CHECK-LABEL: func @transfer_read_flattenable_negative2
 //       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
+
+// -----
+
+func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
+                             %arg1 : vector<1x8xi32>) -> vector<8xi32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
+   %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+   return %res : vector<8xi32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_add(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x1xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
+// CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
+                              %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+   %add = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+   return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_mulf(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
+// CHECK:           return %[[VAL_4]] : vector<8x[2]xf32>
+
+// -----
+
+// All shape casts are folded away
+
+func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
+                                   %arg1 : vector<8xi32>,
+                                   %arg2 : vector<8xi32>) -> vector<8xi32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+   %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+   %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+   %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+   %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+   return %res : vector<8xi32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dims_entirely(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<8xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
+// CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+// Negative test: the unit dim is scalable, so the shape_cast must not be
+// reordered with the arith op.
+
+func.func @dont_fold_unit_scalable(%arg0 : vector<[1]x8xi32>,
+                                   %arg1 : vector<[1]x8xi32>) -> vector<[8]xi32> {
+   %add = arith.addi %arg0, %arg1 : vector<[1]x8xi32>
+   %res = vector.shape_cast %add : vector<[1]x8xi32> to vector<[8]xi32>
+   return %res : vector<[8]xi32>
+}
+
+// CHECK-LABEL:   func.func @dont_fold_unit_scalable(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<[1]x8xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<[1]x8xi32>) -> vector<[8]xi32> {
+// CHECK:           %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : vector<[1]x8xi32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<[1]x8xi32> to vector<[8]xi32>
+// CHECK:           return %[[VAL_3]] : vector<[8]xi32>
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/74817


More information about the Mlir-commits mailing list