[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)
Diego Caballero
llvmlistbot at llvm.org
Mon Dec 11 08:28:41 PST 2023
================
@@ -1446,6 +1446,117 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Reorders:
+/// shape_cast(arithmetic(a, b))
+/// as
+/// arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), this brings shape_cast closer to vector.xfer operations.
+/// With patterns like e.g. `FlattenContiguousRowMajorTransferWritePattern`,
+/// the additional shape_cast ops are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+/// %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+/// %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+/// %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+/// %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// gets folded as:
+///
+/// ```
+/// %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+/// %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+/// ATM this pattern is limited to `vector.shape_cast` ops that fold the unit
+/// dim, e.g.:
+/// ```
+/// vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector should be the result of an arithmetic
+/// operation, `ArithOp`.
+template <typename ArithOp>
+struct ReorderShapeCastWithUnitDimAndArith
+ : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto arithOp = shapeCastOp.getSource().getDefiningOp<ArithOp>();
+ if (!arithOp)
+ return failure();
+
+ // All arith ops are elementwise - filter out everything else.
+ if (!arithOp.template hasTrait<OpTrait::Elementwise>())
+ return failure();
+
+ // TODO: Add support for unary ops
+ if (arithOp->getOperands().size() != 2)
+ return failure();
+
+ auto sourceVectorType =
+ dyn_cast<VectorType>(shapeCastOp.getSource().getType());
+ auto resultVectorType =
+ dyn_cast<VectorType>(shapeCastOp.getResult().getType());
+ if (!sourceVectorType || !resultVectorType)
+ return failure();
+
+ // Either the leading or the trailing dim of the source should be a
+ // non-scalable unit dim, i.e. a dim that this shape_cast can fold away.
+ bool leadDimNonUnit = ((sourceVectorType.getShape().front() != 1) ||
+ (sourceVectorType.getScalableDims().front()));
+ bool trailingDimNonUnit = ((sourceVectorType.getShape().back() != 1) ||
+ (sourceVectorType.getScalableDims().back()));
+ if (leadDimNonUnit && trailingDimNonUnit)
+ return failure();
+
+ // Does this shape_cast fold the input vector?
+ if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
+ return failure();
+
+ // Does this shape_cast fold the trailing/leading _unit_ dim?
+ // TODO: Even when the trailing/leading unit dims are folded, there might
+ // still be some "inner" unit dims left.
+ if (llvm::any_of(resultVectorType.getShape(),
+ [](int64_t dim) { return (dim == 1); }))
+ return failure();
+
+ auto loc = shapeCastOp->getLoc();
+
+ // shape_cast(a)
+ auto lhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+ arithOp->getOperands()[0],
+ shapeCastOp->getAttrs());
+ // shape_cast(b)
+ auto rhs = rewriter.create<vector::ShapeCastOp>(loc, resultVectorType,
+ arithOp->getOperands()[1],
+ shapeCastOp->getAttrs());
+
+ // Replace shape_cast(a ArithOp b) with shape_cast(a) ArithOp shape_cast(b)
+ rewriter.replaceOpWithNewOp<ArithOp>(shapeCastOp, lhs, rhs);
----------------
dcaballe wrote:
To replace a new op in a generic way we can use the OperationState approach and extract all the op fields from the original operation.
https://github.com/llvm/llvm-project/pull/74817
More information about the Mlir-commits
mailing list