[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)

Diego Caballero llvmlistbot at llvm.org
Mon Dec 11 08:28:40 PST 2023


================
@@ -1446,6 +1446,117 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Reorders:
+///   shape_cast(arithmetic(a, b))
+/// as
+///   arithmetic(shape_cast(a), shape_cast(b)).
+///
+/// Ex:
+/// ```
+///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
----------------
dcaballe wrote:

To keep the logic simpler and more generic (for example, to also handle cases where the shape cast comes before the arithmetic op), we can match the `arith.mulf`, check whether its vector type has unit dims and, if so, wrap it in shape casts:

```
 %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
 %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
 %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
 %cast0 = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
 %cast = vector.shape_cast %cast0 : vector<1x[4]xf32> to vector<[4]xf32>
```
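To illustrate the "shape cast above" case, here is a sketch of the same rewrite applied when the unit dim is introduced *before* the multiply. `%x` and `%y` are hypothetical operands. This assumes that `vector.shape_cast` folds a cast followed by its inverse back to the original value, so the newly inserted casts cancel against the existing ones.

```
// Before: the unit dim is added above the mulf.
%a   = vector.shape_cast %x : vector<[4]xf32> to vector<1x[4]xf32>
%b   = vector.shape_cast %y : vector<[4]xf32> to vector<1x[4]xf32>
%mul = arith.mulf %a, %b : vector<1x[4]xf32>

// After wrapping the mulf in shape casts and folding the
// inverse cast pairs (%a -> %x, %b -> %y):
%mul_1d = arith.mulf %x, %y : vector<[4]xf32>
%mul    = vector.shape_cast %mul_1d : vector<[4]xf32> to vector<1x[4]xf32>
```

The same matcher therefore covers both placements of the shape cast. Only the cast that cannot be folded away is left in the IR.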

Then add the `ShapeCastOp` canonicalization patterns to the `populate` function as well, so that the redundant pair of shape casts folds away.
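As a sketch of the expected end state, assuming the canonicalization patterns and `vector.shape_cast` folding run after the rewrite, the example above should reduce to the following. `%cast0` takes `vector<[4]xf32>` to `vector<1x[4]xf32>` and `%cast` takes it straight back, so the round trip folds to `%mul`.

```
%B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
%A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
%mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
// %cast0 and %cast are gone: former uses of %cast now use %mul directly.
```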

https://github.com/llvm/llvm-project/pull/74817

