[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)

Cullen Rhodes llvmlistbot at llvm.org
Fri Dec 8 09:35:17 PST 2023


================
@@ -1446,6 +1446,108 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Reorders:
+///   shape_cast(arithmetic(a, b))
+/// as
+///   arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), this brings shape_cast closer to vector.xfer operations.
+/// With patterns like e.g. `FlattenContiguousRowMajorTransferWritePattern`,
+/// the additional shape_casts are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+///  %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+///  %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+///  %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+///  %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// gets folded as:
+///
+/// ```
+///    %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+///    %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+/// ATM this pattern is limited to `vector.shape_cast` ops that fold the unit
+/// dim, e.g.:
+/// ```
+///   vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector should be the result of an arithmetic
+/// operation, `ArithOp`.
+template <typename ArithOp>
+struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::isa_and_present<ArithOp>(
+            shapeCastOp.getSource().getDefiningOp()))
+      return failure();
+
+    auto *arithOp = shapeCastOp.getSource().getDefiningOp();
+
+    auto sourceVectorType =
+        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
+    auto resultVectorType =
+        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
+    if (!sourceVectorType || !resultVectorType)
+      return failure();
+
+    // Either the leading or the trailing dims of the input should be
+    // non-scalable 1.
+    if (((sourceVectorType.getShape().back() != 1) ||
+         (sourceVectorType.getScalableDims().back())) &&
+        ((sourceVectorType.getShape().front() != 1) ||
+         (sourceVectorType.getScalableDims().front())))
----------------
c-rhodes wrote:

parens surrounding `!= 1` check can be dropped
```suggestion
    if ((sourceVectorType.getShape().back() != 1 ||
         sourceVectorType.getScalableDims().back()) &&
        (sourceVectorType.getShape().front() != 1 ||
         sourceVectorType.getScalableDims().front()))
```

https://github.com/llvm/llvm-project/pull/74817


More information about the Mlir-commits mailing list