[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)
Diego Caballero
llvmlistbot at llvm.org
Mon Dec 11 08:28:40 PST 2023
================
@@ -1446,6 +1446,117 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Reorders:
+/// shape_cast(arithmetic(a, b))
+/// as
+/// arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
----------------
dcaballe wrote:
To keep the logic simpler and more generic (for example, to also handle cases where the shape cast comes before the arithmetic op), we can match the `arith.mulf` directly, check whether its vector type has unit dims and, if so, add shape casts around it:
```
%B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
%A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
%mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
%cast0 = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
%cast = vector.shape_cast %cast0 : vector<1x[4]xf32> to vector<[4]xf32>
```
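As a rough sketch of the type computation behind those inserted casts, the C++ below models a vector type as a plain shape plus one scalable flag per dim. `VecShape`, `dropNonScalableUnitDims` and `hasDroppableUnitDim` are hypothetical names, not MLIR APIs. Only fixed-size unit dims are dropped, because a scalable `[1]` dim is not guaranteed to be 1 at runtime. Treating an all-unit shape as "nothing to drop" is an assumption of this sketch, not something the PR specifies.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for an MLIR VectorType:
// vector<1x[4]xf32> is modeled as dims = {1, 4}, scalable = {false, true}.
struct VecShape {
  std::vector<int64_t> dims;
  std::vector<bool> scalable;
  bool operator==(const VecShape &o) const {
    return dims == o.dims && scalable == o.scalable;
  }
};

// Returns `in` with every fixed-size unit dim removed. Scalable dims are
// always kept (their runtime size is a multiple of vscale). If every dim
// would be dropped, the input is returned unchanged (assumption: no rewrite).
VecShape dropNonScalableUnitDims(const VecShape &in) {
  assert(in.dims.size() == in.scalable.size() && "one flag per dim");
  VecShape out;
  for (std::size_t i = 0; i < in.dims.size(); ++i) {
    if (in.dims[i] == 1 && !in.scalable[i])
      continue; // fixed unit dim: safe to drop
    out.dims.push_back(in.dims[i]);
    out.scalable.push_back(in.scalable[i]);
  }
  if (out.dims.empty())
    return in;
  return out;
}

// The rewrite is only worth applying if at least one dim would be dropped.
bool hasDroppableUnitDim(const VecShape &in) {
  return dropNonScalableUnitDims(in).dims.size() != in.dims.size();
}
```

Applied to `vector<1x[4]xf32>` this yields `vector<[4]xf32>`, the target type used for `%B_row_sc` and `%A_row_sc` above, while `vector<[1]xf32>` is left alone.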
Then add the ShapeCastOp canonicalization patterns to the populate function as well, so that back-to-back casts such as `%cast0` followed by `%cast` fold away.
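Why the extra casts end up harmless can be sketched with a similarly hypothetical model (again not the MLIR API; `VecShape`, `ShapeCastChain` and `foldShapeCastChain` are illustrative names). Since a shape cast only reinterprets the shape, a chain of casts matters only through its final result type: it collapses to a single cast, and that cast disappears when the final type equals the original source type. This is a simplified picture of the folding the canonicalization performs, not a copy of it.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a vector type: shape plus per-dim scalable flags.
struct VecShape {
  std::vector<int64_t> dims;
  std::vector<bool> scalable;
  bool operator==(const VecShape &o) const {
    return dims == o.dims && scalable == o.scalable;
  }
};

// A value produced by applying successive shape casts to a root value.
struct ShapeCastChain {
  VecShape source;             // type of the value feeding the first cast
  std::vector<VecShape> casts; // result type of each successive cast
};

// Collapses the chain: intermediate types are irrelevant, so only the last
// result type survives. Returns std::nullopt when no cast is needed at all
// (empty chain, or the final type equals the source type).
std::optional<VecShape> foldShapeCastChain(const ShapeCastChain &chain) {
  for (const VecShape &t : chain.casts)
    assert(t.dims.size() == t.scalable.size() && "one flag per dim");
  if (chain.casts.empty())
    return std::nullopt;
  const VecShape &finalTy = chain.casts.back();
  if (finalTy == chain.source)
    return std::nullopt; // identity cast: uses can take the source directly
  return finalTy;        // a single cast straight to the final type
}
```

For the example above, the chain `vector<[4]xf32>` to `vector<1x[4]xf32>` to `vector<[4]xf32>` folds away completely, so users of `%cast` can use `%mul` directly.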
https://github.com/llvm/llvm-project/pull/74817