[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Dec 11 00:07:13 PST 2023
================
@@ -1446,6 +1446,108 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Reorders:
+/// shape_cast(arithmetic(a, b))
+/// as
+/// arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// While this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), this brings shape_cast closer to vector.xfer operations.
+/// With patterns like e.g. `FlattenContiguousRowMajorTransferWritePattern`,
+/// the additional shape_cast ops are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+/// %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+/// %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+/// %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+/// %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// gets folded as:
+///
+/// ```
+/// %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+/// %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+/// ATM this pattern is limited to `vector.shape_cast` ops that fold the unit
+/// dim, e.g.:
+/// ```
+/// vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector should be the result of an arithmetic
+/// operation, `ArithOp`.
+template <typename ArithOp>
+struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::isa_and_present<ArithOp>(
+ shapeCastOp.getSource().getDefiningOp()))
+ return failure();
+
+ auto *arithOp = shapeCastOp.getSource().getDefiningOp();
+
+ auto sourceVectorType =
+ dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
+ auto resultVectorType =
+ dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
+ if (!sourceVectorType || !resultVectorType)
+ return failure();
+
+ // Either the leading or the trailing dims of the input should be
+ // non-scalable 1.
+ if (((sourceVectorType.getShape().back() != 1) ||
+ (sourceVectorType.getScalableDims().back())) &&
+ ((sourceVectorType.getShape().front() != 1) ||
+ (sourceVectorType.getScalableDims().front())))
+ return failure();
+
+ // Does this shape_cast fold the input vector?
+ if (resultVectorType.getRank() != (sourceVectorType.getRank() - 1))
+ return failure();
+
+ // Does this shape_cast fold the _unit_ dim?
+ if (llvm::any_of(resultVectorType.getShape(),
+ [](int64_t dim) { return (dim == 1); }))
+ return failure();
+
+ auto loc = shapeCastOp->getLoc();
+
+ // shape_cast(a)
+ auto *lhs = rewriter.create(loc, shapeCastOp->getName().getIdentifier(),
+ arithOp->getOperands()[0], resultVectorType,
+ shapeCastOp->getAttrs());
----------------
banach-space wrote:
Ah, I was playing around with some idea that didn't work 😢 .
https://github.com/llvm/llvm-project/pull/74817
More information about the Mlir-commits
mailing list