[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)

Diego Caballero llvmlistbot at llvm.org
Mon Dec 11 08:28:41 PST 2023


================
@@ -1446,6 +1446,117 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Reorders:
+///   shape_cast(arithmetic(a, b))
+/// as
+///   arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// While this pattern introduces an extra shape_cast op (1 shape_cast is
+/// replaced with 2), it moves the shape_casts closer to the vector transfer
+/// operations. With patterns such as
+/// `FlattenContiguousRowMajorTransferWritePattern`, the additional
+/// shape_casts are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+///  %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+///  %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+///  %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+///  %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+///  %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// is simplified to:
+///
+/// ```
+///    %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+///    %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+/// At the moment, this pattern is limited to `vector.shape_cast` ops that
+/// drop a leading or trailing non-scalable unit dim, e.g.:
+/// ```
+///   vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector must be the result of a binary arithmetic
+/// operation of type `ArithOp`.
+template <typename ArithOp>
+struct ReorderShapeCastWithUnitDimAndArith
+    : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto arithOp = shapeCastOp.getSource().getDefiningOp<ArithOp>();
+    if (!arithOp)
+      return failure();
+
+    // All arith ops are elementwise - filter out everything else.
+    if (!arithOp.template hasTrait<OpTrait::Elementwise>())
+      return failure();
+
+    // TODO: Add support for unary ops
+    if (arithOp->getOperands().size() != 2)
+      return failure();
+
+    auto sourceVectorType =
+        dyn_cast<VectorType>(shapeCastOp.getSource().getType());
+    auto resultVectorType =
+        dyn_cast<VectorType>(shapeCastOp.getResult().getType());
+    if (!sourceVectorType || !resultVectorType)
+      return failure();
+
+    // Either the leading or the trailing dims of the input should be
+    // non-scalable 1.
+    bool leadDimUnitFixed = ((sourceVectorType.getShape().back() != 1) ||
+                             (sourceVectorType.getScalableDims().back()));
+    bool trailinDimUnitFixed = ((sourceVectorType.getShape().front() != 1) ||
+                                (sourceVectorType.getScalableDims().front()));
+    if (!leadDimUnitFixed && !trailinDimUnitFixed)
+      return failure();
+
+    // Does this shape_cast fold the input vector?
+    if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
+      return failure();
+
+    // Does this shape_cast fold the traling/leading _unit_ dim?
+    // TODO: Even when the trailing/leading unit dims are folded, there might
+    // still be some "inner" unit dims left.
+    if (llvm::any_of(resultVectorType.getShape(),
+                     [](int64_t dim) { return (dim == 1); }))
+      return failure();
+
+    auto loc = shapeCastOp->getLoc();
+
+    // shape_cast(a)
+    auto lhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+                                                    arithOp->getOperands()[0],
+                                                    shapeCastOp->getAttrs());
+    // shape_cast(b)
+    auto rhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+                                                    arithOp->getOperands()[1],
+                                                    shapeCastOp->getAttrs());
+
+    // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
+    rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs, rhs);
----------------
dcaballe wrote:

To replace it with a new op in a generic way, we can use the OperationState approach and extract all the op fields (operands, result types, attributes) from the original operation.

https://github.com/llvm/llvm-project/pull/74817


More information about the Mlir-commits mailing list