[Mlir-commits] [mlir] [mlir][VectorOps] Fold extract on constant_mask (PR #183780)

Lukas Sommer llvmlistbot at llvm.org
Fri Feb 27 09:35:35 PST 2026


https://github.com/sommerlukas created https://github.com/llvm/llvm-project/pull/183780

Fold `vector.extract(vector.constant_mask)` to `vector.constant_mask` if possible.

If any static position lies outside the masked area, the extract folds to a constant all-false vector instead.

Dynamic positions are only supported if the mask covers the entire vector in that dimension.

Assisted-by: Claude Code

>From ed1d62ecb941321778f30b155c2110223d45b420 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Fri, 27 Feb 2026 17:27:43 +0000
Subject: [PATCH] [mlir][VectorOps] Fold extract on constant_mask

Fold `vector.extract(vector.constant_mask)` to `vector.constant_mask` if
possible.

If any static position lies outside the masked area, the extract folds
to a constant all-false vector instead.

Dynamic positions are only supported if the mask covers the entire
vector in that dimension.

Assisted-by: Claude Code

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 58 +++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 70 ++++++++++++++++++++++
 2 files changed, 125 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index aa6734049bbd7..644e4bec620ef 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2298,6 +2298,59 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+// Pattern to rewrite ExtractOp(ConstantMaskOp) -> ConstantMaskOp, or to an
+// all-false constant if the extracted slice lies outside the masked area.
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantMaskOp =
+        extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
+    if (!constantMaskOp)
+      return failure();
+
+    // Extracting a single scalar i1 is not handled by this pattern.
+    auto extractedMaskType =
+        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+    if (!extractedMaskType)
+      return failure();
+
+    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+    VectorType maskType = constantMaskOp.getVectorType();
+
+    // A static position at or beyond the mask size of its dimension makes the
+    // extracted slice all-false regardless of any dynamic positions, so scan
+    // all positions before giving up because of a dynamic one.
+    bool hasUnresolvedDynamicPos = false;
+    for (auto [dimIdx, pos] : llvm::enumerate(extractOpPos)) {
+      if (pos == ShapedType::kDynamic) {
+        // A dynamic position is only known to be inside the masked area if
+        // the mask covers the entire dimension.
+        if (maskDimSizes[dimIdx] != maskType.getDimSize(dimIdx))
+          hasUnresolvedDynamicPos = true;
+        continue;
+      }
+      if (pos >= maskDimSizes[dimIdx]) {
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            extractOp, DenseElementsAttr::get(extractedMaskType, false));
+        return success();
+      }
+    }
+    if (hasUnresolvedDynamicPos)
+      return failure();
+
+    // All positions are within the mask bounds, so the result is a
+    // constant_mask over the remaining trailing dimensions.
+    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+        extractOp, extractedMaskType,
+        maskDimSizes.drop_front(extractOpPos.size()));
+    return success();
+  }
+};
+
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2405,9 +2458,8 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results
-      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
-          context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromConstantMask, ExtractToShapeCast>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 82b2cb633d1c9..0e78f4f723641 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -272,6 +272,88 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>

 // -----

+// CHECK-LABEL: extract_from_constant_mask
+func.func @extract_from_constant_mask() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[1] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_all_false
+func.func @extract_from_constant_mask_all_false() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[3] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_at_boundary
+func.func @extract_from_constant_mask_at_boundary() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[2] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_multiple_indices
+func.func @extract_from_constant_mask_multiple_indices() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3, 3] : vector<4x4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[1, 2] : vector<4xi1> from vector<4x4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_inner_dim_all_false
+func.func @extract_from_constant_mask_inner_dim_all_false() -> vector<4xi1> {
+  // The outer position is inside the mask, but the inner one is not.
+  %mask = vector.constant_mask [2, 3, 3] : vector<4x4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[1, 3] : vector<4xi1> from vector<4x4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_all_true
+//  CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_all_true(%index: index) -> vector<4xi1> {
+  // The mask covers the entire first dimension, so a dynamic index is fine.
+  %mask = vector.constant_mask [4, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_not_all_true
+//  CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_not_all_true(%index: index) -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[MASK:.*]] = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK-NEXT: %[[RES:.*]] = vector.extract %[[MASK]][%[[INDEX]]] : vector<4xi1> from vector<4x4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
 // CHECK-LABEL: constant_mask_to_true_splat
 func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
   // CHECK: arith.constant dense<true>



More information about the Mlir-commits mailing list