[Mlir-commits] [mlir] [mlir][vector] Update `CombineContractBroadcastMask` (PR #140050)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue May 20 10:55:16 PDT 2025
banach-space wrote:
Hey Diego - your comments are spot on and very much aligned with what we've been thinking.
> Would it be possible to just reject these cases in the pattern when the broadcasts to be folded are actually "no-ops"? This is pointing at "we need to remove unnecessary unit dims before calling this pattern" kind of requirement...
In practice, quite a bit happens before we hit this pattern. My goal here was to minimally extend `CombineContractBroadcastMask` to unblock us, and then separately investigate a more principled solution.
For context, here’s the IR we're seeing - this is just the part that extracts arguments for the masked `vector.contract`:
```mlir
// MASK
%13 = vector.create_mask %c1, %c1, %c1, %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>
%mask = vector.extract %13[0, 0] : vector<1x2x[4]x8xi1> from vector<1x1x1x2x[4]x8xi1>
// LHS - %4 comes from an xfer Op
%lhs = vector.extract %4[0, 0] : vector<2x[4]xi32> from vector<1x1x2x[4]xi32>
// RHS - %2 comes from an xfer Op
%10 = vector.extract %2[0, 0, 0] : vector<2x[4]x8xi8> from vector<1x1x1x2x[4]x8xi8>
%11 = arith.extsi %10 : vector<2x[4]x8xi8> to vector<2x[4]x8xi32>
%rhs = vector.broadcast %11 : vector<2x[4]x8xi32> to vector<1x2x[4]x8xi32>
```
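The RHS broadcast above only prepends a unit dimension, which is the kind of "no-op" broadcast Diego describes. The minimal NumPy sketch below models that broadcast. It assumes vscale = 1, so `[4]` becomes 4, because NumPy has no scalable vectors. `is_unit_dim_broadcast` is a hypothetical helper, not an MLIR API. It only covers the prepend case; `vector.broadcast` can also stretch existing size-1 dims, and the helper does not handle that.

```python
import numpy as np

def is_unit_dim_broadcast(src_shape, dst_shape):
    """Hypothetical helper (not an MLIR API): True if dst_shape only
    prepends size-1 dims to src_shape. Such a broadcast duplicates no
    elements, so it is equivalent to a reshape (a shape_cast)."""
    lead = len(dst_shape) - len(src_shape)
    if lead < 0 or tuple(dst_shape[lead:]) != tuple(src_shape):
        return False
    return all(d == 1 for d in dst_shape[:lead])

# %11 : vector<2x[4]x8xi32>, modeled with vscale = 1 so [4] -> 4.
x = np.arange(2 * 4 * 8, dtype=np.int32).reshape(2, 4, 8)
# %rhs = vector.broadcast %11 : ... to vector<1x2x[4]x8xi32>
rhs = np.broadcast_to(x, (1, 2, 4, 8))

assert is_unit_dim_broadcast(x.shape, rhs.shape)
# Same elements in the same order as a plain reshape: the broadcast moves no data.
assert np.array_equal(rhs, x.reshape(1, 2, 4, 8))
```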
My thinking is that introducing a `vector.shape_cast` produces IR that is easy to clean up afterwards (with a separate pattern):
```mlir
// MASK
%13 = vector.create_mask %c1, %c1, %c1, %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>
%mask = vector.extract %13[0, 0] : vector<1x2x[4]x8xi1> from vector<1x1x1x2x[4]x8xi1>
%mask_sc = vector.shape_cast %mask : vector<1x2x[4]x8xi1> to vector<2x[4]x8xi1>
```
Indeed, the code above could be simplified to:
```mlir
// MASK
%13 = vector.create_mask %c2, %dim, %c8 : vector<2x[4]x8xi1>
```
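You can sanity-check this simplification with a small NumPy model of `vector.create_mask`, where an element is set iff its index is below the bound in every dimension. The check shows that the three-operand `create_mask` gives the same 2x[4]x8 mask as the `create_mask`/`extract`/`shape_cast` chain. Assumptions: `create_mask`, `chain` and `direct` are hypothetical helpers, not MLIR APIs. The scalable dim `[4]` is modeled as `4 * vscale` for a few fixed vscale values, and `%dim` is swept over all of its possible runtime values.

```python
import numpy as np

def create_mask(bounds, shape):
    """Hypothetical model of vector.create_mask: an element is set iff its
    index along every dim d is below bounds[d] (bounds outside [0, shape[d]]
    behave as if clamped to that range)."""
    mask = np.ones(shape, dtype=bool)
    for d, (bound, size) in enumerate(zip(bounds, shape)):
        view = [1] * len(shape)
        view[d] = size
        mask &= (np.arange(size) < bound).reshape(view)
    return mask

def chain(dim, n):
    # create_mask 1,1,1,2,%dim,8 -> extract [0, 0] -> shape_cast to 2x[4]x8
    m = create_mask([1, 1, 1, 2, dim, 8], (1, 1, 1, 2, n, 8))
    return m[0, 0].reshape(2, n, 8)

def direct(dim, n):
    # create_mask %c2, %dim, %c8 : vector<2x[4]x8xi1>
    return create_mask([2, dim, 8], (2, n, 8))

# [4] is modeled as 4 * vscale; sweep vscale and every runtime value of %dim.
for vscale in (1, 2, 4):
    n = 4 * vscale
    for dim in range(n + 1):
        assert np.array_equal(chain(dim, n), direct(dim, n))
```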
…but I don’t think we want to be rewriting arbitrary `vector.create_mask` ops directly within `CombineContractBroadcastMask`, right?
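If that clean-up lives in a separate pattern, its core logic is small. Below is a minimal Python sketch of the shape and bound checks such a pattern might perform. It uses a toy encoding in which each `create_mask` operand is either a known constant (`int`) or a dynamic value such as `%dim` (`None`). `fold_mask_chain` is hypothetical, not an existing MLIR pattern, and it only handles the leading-unit-dim case from the IR above.

```python
from typing import List, Optional, Sequence

# Toy encoding (an assumption, not MLIR's data structures): a create_mask
# operand is a known constant (int) or a dynamic SSA value such as %dim (None).
Operand = Optional[int]

def fold_mask_chain(operands: Sequence[Operand], shape: Sequence[int],
                    extract_pos: Sequence[int],
                    cast_shape: Sequence[int]) -> Optional[List[Operand]]:
    """Hypothetical fold: create_mask -> extract -> shape_cast into a single
    create_mask. Returns the new operands, or None if the fold does not apply.
    Scalable dims are modeled by their base size; a real pattern would also
    have to compare scalable flags."""
    k = len(extract_pos)
    # Extracting position p along a dim with bound b keeps the rest of the
    # mask intact only if p < b is statically known; otherwise bail out.
    for pos, bound in zip(extract_pos, operands[:k]):
        if bound is None or pos >= bound:
            return None
    rest_ops, rest_shape = list(operands[k:]), list(shape[k:])
    # The shape_cast may only drop leading unit dims...
    drop = len(rest_shape) - len(cast_shape)
    if drop < 0 or rest_shape[drop:] != list(cast_shape):
        return None
    # ...and each dropped unit dim must be fully set (bound >= 1).
    for size, bound in zip(rest_shape[:drop], rest_ops[:drop]):
        if size != 1 or bound is None or bound < 1:
            return None
    return rest_ops[drop:]

# The IR above, with [4] modeled as 4 and %dim as None:
print(fold_mask_chain([1, 1, 1, 2, None, 8], [1, 1, 1, 2, 4, 8], [0, 0], [2, 4, 8]))
# -> [2, None, 8], i.e. create_mask %c2, %dim, %c8
```

The function bails out rather than guessing whenever a bound is dynamic. For example, the fold does not apply if a dropped dim's bound is `%dim` or 0, because the extracted slice could then be all-false.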
I agree it would be better if the IR were cleaned up before reaching this point - I'll look into whether earlier patterns could be improved. So far it has looked a bit tricky, but there may be low-hanging fruit.
> It really looks like the pattern is trying to do an in-place simplification that should have happened before and adding some complexity to achieve that.
Totally hear you - and I should mention: GitHub makes the diff look more invasive than it is. This patch mainly wraps `CombineContractBroadcastMask` in `MaskableOpRewritePattern` so that it also handles masked `vector.contract` ops. So while the diff appears large, the actual change is pretty minimal 😅
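Why does mask support need the extra `vector.shape_cast`? The mask of a masked `vector.contract` covers the iteration space. If folding the broadcast leaves a unit iteration dim unused and that dim is dropped, the mask has to drop the same dim. The NumPy model below illustrates this with hypothetical indexing maps: lhs `(d1, d2)`, rhs `(d0, d1, d2, d3)`, result `(d3)`, with `d0` being the unit dim introduced by the broadcast. These maps are chosen for illustration only and are not taken from the PR. Masked-off iterations are modeled as contributing 0 to the sum.

```python
import numpy as np

rng = np.random.default_rng(0)
U, K, M, N = 1, 2, 4, 8  # U models the unit dim d0 added by the broadcast
lhs = rng.integers(-5, 5, size=(K, M))
rhs = rng.integers(-5, 5, size=(K, M, N))
mask = rng.random((U, K, M, N)) < 0.5  # mask over the full iteration space

# Before: rhs is broadcast to (U, K, M, N); d0, d1, d2 are reduced and the
# result is indexed by d3. Masked-off iterations contribute 0.
rhs_b = np.broadcast_to(rhs, (U, K, M, N))
before = np.einsum("km,ukmn->n", lhs, np.where(mask, rhs_b, 0))

# After: the broadcast is folded away and the unused unit dim d0 is dropped
# from the iteration space, so the mask drops it too (the shape_cast).
mask_sc = mask.reshape(K, M, N)
after = np.einsum("km,kmn->n", lhs, np.where(mask_sc, rhs, 0))

assert np.array_equal(before, after)
```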
Let me know what you think - it’s possible I’m missing a simpler approach here. 🤔
https://github.com/llvm/llvm-project/pull/140050