[Mlir-commits] [mlir] [mlir][vector] Update `CombineContractBroadcastMask` (PR #140050)

Andrzej Warzyński llvmlistbot at llvm.org
Tue May 20 10:55:16 PDT 2025


banach-space wrote:

Hey Diego - your comments are spot on and very much aligned with what we've been thinking.

> Would it be possible to just reject these cases in the pattern when the broadcasts to be folded are actually "no-ops"? This is pointing at "we need to remove unnecessary unit dims before calling this pattern" kind of requirement...

In practice, quite a bit happens before we hit this pattern. My goal here was to minimally extend `CombineContractBroadcastMask` to unblock us, and then separately investigate a more principled solution.

For context, here’s the IR we're seeing - this is just the part that extracts arguments for the masked `vector.contract`:

```mlir
  // MASK
  %13 = vector.create_mask %c1, %c1, %c1, %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>
  %mask = vector.extract %13[0, 0] : vector<1x2x[4]x8xi1> from vector<1x1x1x2x[4]x8xi1>

  // LHS - %4 comes from an xfer Op
  %lhs = vector.extract %4[0, 0] : vector<2x[4]xi32> from vector<1x1x2x[4]xi32>

  // RHS - %2 comes from an xfer Op
  %10 = vector.extract %2[0, 0, 0] : vector<2x[4]x8xi8> from vector<1x1x1x2x[4]x8xi8>
  %11 = arith.extsi %10 : vector<2x[4]x8xi8> to vector<2x[4]x8xi32>
  %rhs = vector.broadcast %11 : vector<2x[4]x8xi32> to vector<1x2x[4]x8xi32>
```

My thinking is that introducing `vector.shape_cast` produces IR that is easy to clean up afterwards (with a separate pattern):
```mlir
  // MASK
  %13 = vector.create_mask %c1, %c1, %c1, %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>
  %mask = vector.extract %13[0, 0] : vector<1x2x[4]x8xi1> from vector<1x1x1x2x[4]x8xi1>
  %mask_sc = vector.shape_cast %mask : vector<1x2x[4]x8xi1> to vector<2x[4]x8xi1>
```

Indeed, the code above could be simplified to a single op:
```mlir
  // MASK
  %13 = vector.create_mask %c2, %dim, %c8 : vector<2x[4]x8xi1>
```

…but I don’t think we want to be rewriting arbitrary `vector.create_mask` ops directly within `CombineContractBroadcastMask`, right?

I agree it would be better if the IR were cleaned up before reaching this point - I'll look into whether earlier patterns could be improved. So far it has looked a bit tricky, but there may be some low-hanging fruit.

> It really looks like the pattern is trying to do an in-place simplification that should have happened before and adding some complexity to achieve that. 

Totally hear you - and I should mention: GitHub makes the diff look more invasive than it is. This patch mainly wraps `CombineContractBroadcastMask` in `MaskableOpRewritePattern` so that it supports masked ops. So while the diff appears large, the actual change is pretty minimal 😅

Let me know what you think - it’s possible I’m missing a simpler approach here. 🤔

https://github.com/llvm/llvm-project/pull/140050
