[Mlir-commits] [mlir] [mlir][VectorOps] Fold extract on constant_mask (PR #183780)
Lukas Sommer
llvmlistbot at llvm.org
Fri Feb 27 09:35:35 PST 2026
https://github.com/sommerlukas created https://github.com/llvm/llvm-project/pull/183780
Fold `vector.extract(vector.constant_mask)` to `vector.constant_mask` if possible.
If the static position is outside of the masked area, the pattern will fold to a constant all-false vector instead.
Dynamic positions are only supported if the mask covers the entire vector in that dimension.
Assisted-by: Claude Code
>From ed1d62ecb941321778f30b155c2110223d45b420 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Fri, 27 Feb 2026 17:27:43 +0000
Subject: [PATCH] [mlir][VectorOps] Fold extract on constant_mask
Fold `vector.extract(vector.constant_mask)` to `vector.constant_mask` if
possible.
If the static position is outside of the masked area, the pattern will
fold to a constant all-false vector instead.
Dynamic positions are only supported if the mask covers the entire
vector in that dimension.
Assisted-by: Claude Code
Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 58 +++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 70 ++++++++++++++++++++++
2 files changed, 125 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index aa6734049bbd7..644e4bec620ef 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2298,6 +2298,59 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
}
};
+// Pattern to rewrite an ExtractOp(ConstantMask) -> ConstantMask.
+//
+// The masked region of a constant_mask is the conjunction of one prefix
+// interval per dimension, therefore:
+//  * if any static position lies at or past the mask size of its dimension,
+//    the extracted sub-mask is all-false, whatever the other positions are;
+//  * a dynamic position is only known to be inside the mask when the mask
+//    covers that whole dimension; otherwise (and absent an out-of-bounds
+//    static position) the pattern bails out;
+//  * if every position is inside the mask, the result is a constant_mask of
+//    the remaining (non-indexed) mask dim sizes, or a dense<true> constant
+//    when no dims remain (a fully-indexed extract returning a 1-element
+//    vector such as `vector<1xi1>`, for which a constant_mask with an empty
+//    size list would not verify).
+// Scalar (non-vector) results are not handled by this pattern.
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantMaskOp =
+        extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
+    if (!constantMaskOp)
+      return failure();
+
+    auto extractedMaskType =
+        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+    if (!extractedMaskType)
+      return failure();
+
+    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+    VectorType maskType = constantMaskOp.getVectorType();
+
+    // Set when a dynamic position indexes a dim that is not fully masked, so
+    // we cannot tell whether it lands inside or outside the masked area.
+    bool containsUnknownDims = false;
+    for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); ++dimIdx) {
+      int64_t pos = extractOpPos[dimIdx];
+      if (pos == ShapedType::kDynamic) {
+        // Any position in an all-true dim is within the masked region.
+        // Otherwise, keep scanning instead of bailing out right away: a later
+        // static position may still prove the result all-false.
+        if (maskDimSizes[dimIdx] != maskType.getDimSize(dimIdx))
+          containsUnknownDims = true;
+        continue;
+      }
+
+      // A static position outside of the masked area makes the whole
+      // extracted mask all-false.
+      if (pos >= maskDimSizes[dimIdx]) {
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            extractOp, DenseElementsAttr::get(extractedMaskType, false));
+        return success();
+      }
+    }
+
+    if (containsUnknownDims)
+      return failure();
+
+    // All positions are within the mask bounds.
+    ArrayRef<int64_t> remainingMaskDimSizes =
+        maskDimSizes.drop_front(extractOpPos.size());
+    if (remainingMaskDimSizes.empty()) {
+      // Fully-indexed extract returning a 1-element vector: the selected
+      // element is inside the mask, so the result is all-true.
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          extractOp, DenseElementsAttr::get(extractedMaskType, true));
+      return success();
+    }
+    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+        extractOp, extractedMaskType, remainingMaskDimSizes);
+    return success();
+  }
+};
+
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2405,9 +2458,8 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
+// Canonicalization patterns for vector.extract; ExtractOpFromConstantMask
+// folds vector.extract of vector.constant_mask.
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
- context);
+ results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+ ExtractOpFromConstantMask, ExtractToShapeCast>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 82b2cb633d1c9..0e78f4f723641 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -272,6 +272,76 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
// -----
+// CHECK-LABEL: extract_from_constant_mask
+func.func @extract_from_constant_mask() -> vector<4xi1> {
+ // Row 1 is inside the 2 masked rows, so the result is that row's own mask:
+ // the first 3 of 4 lanes set.
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[1] : vector<4xi1> from vector<4x4xi1>
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_all_false
+func.func @extract_from_constant_mask_all_false() -> vector<4xi1> {
+ // Row 3 lies past the 2 masked rows, so the extracted row is all-false.
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[3] : vector<4xi1> from vector<4x4xi1>
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_at_boundary
+func.func @extract_from_constant_mask_at_boundary() -> vector<4xi1> {
+ // Row 2 is the first unmasked row (the mask covers rows 0 and 1), so the
+ // extracted row is all-false.
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[2] : vector<4xi1> from vector<4x4xi1>
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_multiple_indices
+func.func @extract_from_constant_mask_multiple_indices() -> vector<4xi1> {
+ // Position [1, 2] is inside the masked sizes [2, 3] of the two indexed
+ // dims, so the result keeps the innermost mask size [3].
+ %mask = vector.constant_mask [2, 3, 3] : vector<4x4x4xi1>
+ // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[1, 2] : vector<4xi1> from vector<4x4x4xi1>
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_all_true
+// CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_all_true(%index: index) -> vector<4xi1> {
+ // The mask covers the entire first dimension, so a dynamic index is fine.
+ // Whichever row is selected, its mask is the trailing size [3].
+ %mask = vector.constant_mask [4, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1>
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_not_all_true
+// CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_not_all_true(%index: index) -> vector<4xi1> {
+ // Only rows 0-1 of 4 are masked, so a dynamic row index may land inside or
+ // outside the mask; the extract must be left untouched.
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[MASK:.*]] = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK-NEXT: %[[RES:.*]] = vector.extract %[[MASK]][%[[INDEX]]] : vector<4xi1> from vector<4x4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1>
+ return %extract : vector<4xi1>
+}
+
+// -----
+
// CHECK-LABEL: constant_mask_to_true_splat
func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
// CHECK: arith.constant dense<true>
More information about the Mlir-commits
mailing list