[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Dec 11 03:38:44 PST 2023
================
@@ -1446,6 +1446,108 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Reorders:
+/// shape_cast(arithmetic(a, b))
+/// as
+/// arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// ```
+///
+/// Although this pattern introduces an extra shape_cast Op (1 shape_cast is
+/// replaced with 2), it moves the shape_casts closer to the vector transfer
+/// operations (vector.transfer_read/vector.transfer_write). With patterns
+/// such as `FlattenContiguousRowMajorTransferWritePattern`, the additional
+/// shape_casts are eventually folded away.
+///
+/// Here is another example where this pattern is helpful:
+/// ```
+/// %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
+/// %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
+/// %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
+/// %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+/// %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+/// ```
+///
+/// gets folded as:
+///
+/// ```
+/// %0 = arith.muli %arg0, %arg1 : vector<8xi32>
+/// %res = arith.addi %0, %arg2 : vector<8xi32>
+/// ```
+///
+/// At the moment, this pattern is limited to `vector.shape_cast` ops that
+/// drop a unit dim, e.g.:
+/// ```
+/// vector.shape_cast %mul : vector<1x4xf32> to vector<4xf32>
+/// ```
+/// In addition, the input vector must be the result of the arithmetic
+/// operation `ArithOp`.
+template <typename ArithOp>
+struct ReorderArithAndShapeCast : public OpRewritePattern<vector::ShapeCastOp> {
----------------
banach-space wrote:
Argh, I wanted to match `OpTraitRewritePattern<OpTrait::Elementwise>`, but that's not going to work (I am matching `vector.shape_cast` rather than e.g. `arith.addi`). Let me rename then.
https://github.com/llvm/llvm-project/pull/74817
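To make the mechanics of the rewrite easier to follow, here is a toy C++ model of it. This is a sketch under stated assumptions, not the PR's implementation and not MLIR's API. `Node`, `makeNode`, `shapeCast`, `dropsLeadingUnitDim`, `reorder`, `simplify`, `countShapeCasts` and `toString` are all hypothetical names. Scalable dims such as `[4]` are ignored, and the dropped unit dim is assumed to be the leading one, as in every example in the doc comment. A boolean `elementwise` flag stands in for `OpTrait::Elementwise`, whereas the PR's pattern is templated on a concrete `ArithOp`. The model also shows why a trait-rooted pattern does not fit: the match is rooted at the `vector.shape_cast`, and the elementwise check applies to the op that *produces* its operand.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using Shape = std::vector<std::int64_t>;

// Hypothetical toy IR node (NOT MLIR's Operation/Value API). Each node has a
// single vector result; scalable dims are not modelled.
struct Node;
using NodePtr = std::shared_ptr<Node>;
struct Node {
  std::string name;              // e.g. "arg0", "arith.muli", "vector.shape_cast"
  bool elementwise;              // stand-in for OpTrait::Elementwise
  std::vector<NodePtr> operands;
  Shape shape;                   // result shape
};

NodePtr makeNode(std::string name, bool elementwise,
                 std::vector<NodePtr> operands, Shape shape) {
  return std::make_shared<Node>(
      Node{std::move(name), elementwise, std::move(operands), std::move(shape)});
}

// Hypothetical legality check mirroring the doc comment's restriction: the
// cast only drops a unit dim, assumed here to be the leading one (1x4 -> 4).
bool dropsLeadingUnitDim(const Shape &src, const Shape &dst) {
  return src.size() == dst.size() + 1 && src.front() == 1 &&
         std::equal(dst.begin(), dst.end(), src.begin() + 1);
}

// Builds shape_cast(in) to `dst`. Only the round-trip case is folded:
// shape_cast(shape_cast(x)) becomes x when the result shape equals x's shape.
NodePtr shapeCast(const NodePtr &in, const Shape &dst) {
  if (in->shape == dst)
    return in;
  if (in->name == "vector.shape_cast" && in->operands[0]->shape == dst)
    return in->operands[0];
  return makeNode("vector.shape_cast", false, {in}, dst);
}

// The reorder itself, rooted at the shape_cast. Returns nullptr when it does
// not apply. The toy ignores whether the producer has other users.
NodePtr reorder(const NodePtr &cast) {
  if (cast->name != "vector.shape_cast")
    return nullptr;
  const NodePtr &producer = cast->operands[0];
  if (!producer->elementwise ||
      !dropsLeadingUnitDim(producer->shape, cast->shape))
    return nullptr;
  std::vector<NodePtr> newOperands;
  for (const NodePtr &operand : producer->operands) {
    // Elementwise ops: every operand has the same shape as the result.
    assert(operand->shape == producer->shape);
    newOperands.push_back(shapeCast(operand, cast->shape));
  }
  return makeNode(producer->name, true, std::move(newOperands), cast->shape);
}

// Naive driver (a stand-in for MLIR's greedy rewrite driver): rewrite the
// root if possible, then recurse into the operands. Inputs are never mutated.
NodePtr simplify(const NodePtr &node) {
  NodePtr root = reorder(node);
  if (!root)
    root = std::make_shared<Node>(*node);
  for (NodePtr &operand : root->operands)
    operand = simplify(operand);
  return root;
}

// Inspection helpers.
int countShapeCasts(const NodePtr &n) {
  int count = n->name == "vector.shape_cast" ? 1 : 0;
  for (const NodePtr &operand : n->operands)
    count += countShapeCasts(operand);
  return count;
}

std::string toString(const NodePtr &n) {
  if (n->operands.empty())
    return n->name;
  std::string s = n->name + "(";
  for (std::size_t i = 0; i < n->operands.size(); ++i)
    s += (i ? ", " : "") + toString(n->operands[i]);
  return s + ")";
}
```

Running `simplify` on the second doc-comment example yields `arith.addi(arith.muli(arg0, arg1), arg2)` with no shape_casts left. On the first example it yields `arith.mulf(vector.shape_cast(B_row), vector.shape_cast(A_row))`, so one shape_cast becomes two, matching the trade-off described in the doc comment.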