[Mlir-commits] [mlir] [mlir][vector] Drop trailing 1-dims from constant_mask (PR #187383)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Mar 20 01:48:49 PDT 2026
================
@@ -522,21 +532,23 @@ class TransferReadDropUnitDimsPattern
Value maskOp = transferReadOp.getMask();
if (maskOp) {
LDBG() << " -> Processing mask operation";
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp) {
- LDBG()
- << " -> Unsupported mask op, only 'vector.create_mask' supported";
- return rewriter.notifyMatchFailure(
- transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
- }
- FailureOr<Value> rankReducedCreateMask =
- createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
- if (failed(rankReducedCreateMask)) {
+ FailureOr<Value> rankReducedMask = failure();
+ if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+ rankReducedMask =
+ maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+ else if (auto constantMaskOp =
+ maskOp.getDefiningOp<vector::ConstantMaskOp>())
+ rankReducedMask =
+ maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
----------------
banach-space wrote:
Is this required? Why not:
```cpp
if (!isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp.getDefiningOp()))
return failure();
rankReducedMask =
maskDropNonScalableUnitDims(rewriter, loc, maskOp);
```
?
https://github.com/llvm/llvm-project/pull/187383
More information about the Mlir-commits
mailing list