[Mlir-commits] [mlir] f5e8e98 - [mlir][VectorOps] Fold extract on constant_mask (#183780)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 2 07:54:55 PST 2026


Author: Lukas Sommer
Date: 2026-03-02T16:54:50+01:00
New Revision: f5e8e98a4ef4a4889523101b4d4039e13bccbdf4

URL: https://github.com/llvm/llvm-project/commit/f5e8e98a4ef4a4889523101b4d4039e13bccbdf4
DIFF: https://github.com/llvm/llvm-project/commit/f5e8e98a4ef4a4889523101b4d4039e13bccbdf4.diff

LOG: [mlir][VectorOps] Fold extract on constant_mask (#183780)

Fold `vector.extract(vector.constant_mask)` to `vector.constant_mask` if
possible.

If the static position is outside of the masked area, the pattern will
fold to a constant all-false value instead.

Dynamic positions are only supported if the mask covers the entire
vector in that dimension.

Assisted-by: Claude Code

---------

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b935ad77c1c14..5dc3984b0a037 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2298,6 +2298,70 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+// Pattern to rewrite an ExtractOp(ConstantMask) -> ConstantMask.
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantMaskOp =
+        extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
+    if (!constantMaskOp)
+      return failure();
+
+    Type resultType = extractOp.getResult().getType();
+    auto extractedMaskType = dyn_cast<VectorType>(resultType);
+
+    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+
+    VectorType maskType = constantMaskOp.getVectorType();
+
+    // Check if any extracted position is outside the mask bounds.
+    for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
+      int64_t pos = extractOpPos[dimIdx];
+      if (pos == ShapedType::kDynamic) {
+        // If the dim is all-true, a dynamic index is fine — any position
+        // is within the masked region.
+        if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
+          continue;
+        // Otherwise we don't know if the position is inside or outside of
+        // the masked area, so bail out.
+        return failure();
+      }
+
+      // If the position is statically outside of the masked area, the result
+      // will be all-false.
+      if (pos >= maskDimSizes[dimIdx]) {
+        if (extractedMaskType) {
+          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+              extractOp, DenseElementsAttr::get(extractedMaskType, false));
+        } else {
+          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+              extractOp, rewriter.getIntegerAttr(resultType, false));
+        }
+        return success();
+      }
+    }
+
+    // All positions are within the mask bounds.
+    if (extractedMaskType) {
+      // Vector result: the result is a constant_mask with the remaining
+      // dimensions.
+      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+          extractOp, extractedMaskType,
+          maskDimSizes.drop_front(extractOpPos.size()));
+    } else {
+      // Scalar result: all positions are within the masked region, so the
+      // result is true.
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          extractOp, rewriter.getIntegerAttr(resultType, true));
+    }
+    return success();
+  }
+};
+
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2405,9 +2469,8 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results
-      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
-          context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+              ExtractOpFromConstantMask, ExtractToShapeCast>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3980c179b5d0a..583aa2efa49c3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -272,6 +272,98 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: extract_from_constant_mask
+func.func @extract_from_constant_mask() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[1] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_all_false
+func.func @extract_from_constant_mask_all_false() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[3] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_at_boundary
+func.func @extract_from_constant_mask_at_boundary() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[2] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_multiple_indices
+func.func @extract_from_constant_mask_multiple_indices() -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3, 3] : vector<4x4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[1, 2] : vector<4xi1> from vector<4x4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_all_true
+//  CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_all_true(%index: index) -> vector<4xi1> {
+  // The mask covers the entire first dimension, so a dynamic index is fine.
+  %mask = vector.constant_mask [4, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_not_all_true
+//  CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_not_all_true(%index: index) -> vector<4xi1> {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[MASK:.*]] = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK-NEXT: %[[RES:.*]] = vector.extract %[[MASK]][%[[INDEX]]] : vector<4xi1> from vector<4x4xi1>
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_scalar_from_constant_mask_within_bounds
+func.func @extract_scalar_from_constant_mask_within_bounds() -> i1 {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant true
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[0, 1] : i1 from vector<4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// CHECK-LABEL: extract_scalar_from_constant_mask_outside_bounds
+func.func @extract_scalar_from_constant_mask_outside_bounds() -> i1 {
+  %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+  // CHECK: %[[RES:.*]] = arith.constant false
+  // CHECK-NEXT: return %[[RES]]
+  %extract = vector.extract %mask[0, 3] : i1 from vector<4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
 // CHECK-LABEL: constant_mask_to_true_splat
 func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
   // CHECK: arith.constant dense<true>


        


More information about the Mlir-commits mailing list