[Mlir-commits] [mlir] [mlir][VectorOps] Fold extract on constant_mask (PR #183780)

Lukas Sommer llvmlistbot at llvm.org
Mon Mar 2 01:42:05 PST 2026


https://github.com/sommerlukas updated https://github.com/llvm/llvm-project/pull/183780

>From 6d250f3023e64a07c0e062947a7d285f280ee921 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Fri, 27 Feb 2026 17:27:43 +0000
Subject: [PATCH 1/3] [mlir][VectorOps] Fold extract on constant_mask

Fold `vector.extract(vector.constant_mask)` to `vector.constant_mask` if
possible.

If the static position is outside of the masked area, the pattern will
fold to a constant all-false vector instead.

Dynamic positions are only supported if the mask covers the entire
vector in that dimension.

Assisted-by: Claude Code

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 58 +++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 70 ++++++++++++++++++++++
 2 files changed, 125 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b935ad77c1c14..fa6e0c3ac7c76 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2298,6 +2298,59 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+// Pattern to rewrite an ExtractOp(ConstantMask) -> ConstantMask.
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantMaskOp =
+        extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
+    if (!constantMaskOp)
+      return failure();
+
+    auto extractedMaskType =
+        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+    if (!extractedMaskType)
+      return failure();
+
+    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+
+    VectorType maskType = constantMaskOp.getVectorType();
+
+    // Check if any extracted position is outside the mask bounds.
+    for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
+      int64_t pos = extractOpPos[dimIdx];
+      if (pos == ShapedType::kDynamic) {
+        // If the dim is all-true, a dynamic index is fine — any position
+        // is within the masked region.
+        if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
+          continue;
+        // Otherwise we don't know if the position is inside or outside of
+        // the masked area, so bail out.
+        return failure();
+      }
+
+      // If the position is statically outside of the masked area, the result
+      // will be all-false.
+      if (pos >= maskDimSizes[dimIdx]) {
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            extractOp, DenseElementsAttr::get(extractedMaskType, false));
+        return success();
+      }
+    }
+
+    // All positions are within the mask bounds. The result is a constant_mask
+    // with the remaining dimensions.
+    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+        extractOp, extractedMaskType,
+        maskDimSizes.drop_front(extractOpPos.size()));
+    return success();
+  }
+};
+
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2405,9 +2458,8 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results
-      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
-          context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+              ExtractOpFromConstantMask, ExtractToShapeCast>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3980c179b5d0a..1d38064fb383c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -272,6 +272,76 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: extract_from_constant_mask
+func.func @extract_from_constant_mask() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[1] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_all_false
+func.func @extract_from_constant_mask_all_false() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[3] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_at_boundary
+func.func @extract_from_constant_mask_at_boundary() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[2] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_multiple_indices
+func.func @extract_from_constant_mask_multiple_indices() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3, 3] : vector<4x4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[1, 2] : vector<4xi1> from vector<4x4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_all_true
+//  CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_all_true(%index: index) -> vector<4xi1> {
+  // The mask covers the entire first dimension, so a dynamic index is fine.
+  %mask = vector.constant_mask [4, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_not_all_true
+//  CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_not_all_true(%index: index) -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[MASK:.*]] = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK-NEXT: %[[RES:.*]] = vector.extract %[[MASK]][%[[INDEX]]] : vector<4xi1> from vector<4x4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
 // CHECK-LABEL: constant_mask_to_true_splat
 func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
   // CHECK: arith.constant dense<true>

>From 694158840e1a33cef006128aef2f3dea67c90aee Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Fri, 27 Feb 2026 18:00:51 +0000
Subject: [PATCH 2/3] Remove namespace

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fa6e0c3ac7c76..0f81dd8226b96 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2311,7 +2311,7 @@ class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
       return failure();
 
     auto extractedMaskType =
-        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+        dyn_cast<VectorType>(extractOp.getResult().getType());
     if (!extractedMaskType)
       return failure();
 

>From 80caf5bc303f8de59782cb72b552ff4d0ba80b2e Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Mon, 2 Mar 2026 09:40:46 +0000
Subject: [PATCH 3/3] Add support for scalar case

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 33 ++++++++++++++--------
 mlir/test/Dialect/Vector/canonicalize.mlir | 22 +++++++++++++++
 2 files changed, 44 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0f81dd8226b96..5dc3984b0a037 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2310,10 +2310,8 @@ class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
     if (!constantMaskOp)
       return failure();
 
-    auto extractedMaskType =
-        dyn_cast<VectorType>(extractOp.getResult().getType());
-    if (!extractedMaskType)
-      return failure();
+    Type resultType = extractOp.getResult().getType();
+    auto extractedMaskType = dyn_cast<VectorType>(resultType);
 
     ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
@@ -2336,17 +2334,30 @@ class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
       // If the position is statically outside of the masked area, the result
       // will be all-false.
       if (pos >= maskDimSizes[dimIdx]) {
-        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-            extractOp, DenseElementsAttr::get(extractedMaskType, false));
+        if (extractedMaskType) {
+          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+              extractOp, DenseElementsAttr::get(extractedMaskType, false));
+        } else {
+          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+              extractOp, rewriter.getIntegerAttr(resultType, false));
+        }
         return success();
       }
     }
 
-    // All positions are within the mask bounds. The result is a constant_mask
-    // with the remaining dimensions.
-    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
-        extractOp, extractedMaskType,
-        maskDimSizes.drop_front(extractOpPos.size()));
+    // All positions are within the mask bounds.
+    if (extractedMaskType) {
+      // Vector result: the result is a constant_mask with the remaining
+      // dimensions.
+      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+          extractOp, extractedMaskType,
+          maskDimSizes.drop_front(extractOpPos.size()));
+    } else {
+      // Scalar result: all positions are within the masked region, so the
+      // result is true.
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          extractOp, rewriter.getIntegerAttr(resultType, true));
+    }
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1d38064fb383c..583aa2efa49c3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -342,6 +342,28 @@ func.func @extract_from_constant_mask_dynamic_position_not_all_true(%index: inde
 
 // -----
 
+// CHECK-LABEL: extract_scalar_from_constant_mask_within_bounds
+func.func @extract_scalar_from_constant_mask_within_bounds() -> i1 {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant true
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[0, 1] : i1 from vector<4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// CHECK-LABEL: extract_scalar_from_constant_mask_outside_bounds
+func.func @extract_scalar_from_constant_mask_outside_bounds() -> i1 {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant false
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[0, 3] : i1 from vector<4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
 // CHECK-LABEL: constant_mask_to_true_splat
 func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
   // CHECK: arith.constant dense<true>



More information about the Mlir-commits mailing list