[Mlir-commits] [mlir] [mlir][vector] Drop trailing 1-dims from constant_mask (PR #187383)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Mar 20 01:48:49 PDT 2026


================
@@ -522,21 +532,23 @@ class TransferReadDropUnitDimsPattern
     Value maskOp = transferReadOp.getMask();
     if (maskOp) {
       LDBG() << "  -> Processing mask operation";
-      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
-      if (!createMaskOp) {
-        LDBG()
-            << "  -> Unsupported mask op, only 'vector.create_mask' supported";
-        return rewriter.notifyMatchFailure(
-            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
-                            "currently supported");
-      }
-      FailureOr<Value> rankReducedCreateMask =
-          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
-      if (failed(rankReducedCreateMask)) {
+      FailureOr<Value> rankReducedMask = failure();
+      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+        rankReducedMask =
+            maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      else if (auto constantMaskOp =
+                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
+        rankReducedMask =
+            maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
----------------
banach-space wrote:

Is this required? Why not:
```cpp
if (!isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp.getDefiningOp()))
   return failure();

rankReducedMask =
            maskDropNonScalableUnitDims(rewriter, loc, maskOp);
```

?

https://github.com/llvm/llvm-project/pull/187383


More information about the Mlir-commits mailing list