[Mlir-commits] [mlir] [mlir][VectorOps] Fold extract on constant_mask (PR #183780)
Lukas Sommer
llvmlistbot at llvm.org
Mon Mar 2 01:42:05 PST 2026
https://github.com/sommerlukas updated https://github.com/llvm/llvm-project/pull/183780
>From 6d250f3023e64a07c0e062947a7d285f280ee921 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Fri, 27 Feb 2026 17:27:43 +0000
Subject: [PATCH 1/3] [mlir][VectorOps] Fold extract on constant_mask
Fold `vector.extract(vector.constant_mask)` to `vector.constant_mask` if
possible.
If a static position lies outside of the masked area, the pattern will
fold to a constant all-false vector instead.
Dynamic positions are only supported in dimensions that the mask covers
entirely (i.e. where the mask dim size equals the vector dim size).
Assisted-by: Claude Code
Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 58 +++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 70 ++++++++++++++++++++++
2 files changed, 125 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b935ad77c1c14..fa6e0c3ac7c76 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2298,6 +2298,59 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
}
};
+// Rewrites ExtractOp(ConstantMask) -> ConstantMask, or -> all-false constant.
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto constantMaskOp =
+ extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
+ if (!constantMaskOp)
+ return failure();
+
+ auto extractedMaskType =
+ llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+ if (!extractedMaskType)
+ return failure();
+
+ ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+ ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+
+ VectorType maskType = constantMaskOp.getVectorType();
+
+ // Check each extracted position (leading dims) against the mask bounds.
+ for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
+ int64_t pos = extractOpPos[dimIdx];
+ if (pos == ShapedType::kDynamic) {
+ // If the dim is all-true, a dynamic index is fine: any position
+ // is within the masked region.
+ if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
+ continue;
+ // Otherwise we don't know if the position is inside or outside of
+ // the masked area; bail out without inspecting later positions.
+ return failure();
+ }
+
+ // If the position is statically outside of the masked area, the result
+ // will be all-false.
+ if (pos >= maskDimSizes[dimIdx]) {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ extractOp, DenseElementsAttr::get(extractedMaskType, false));
+ return success();
+ }
+ }
+
+ // All positions are within the mask bounds. The result is a constant_mask
+ // with the remaining dimensions.
+ rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+ extractOp, extractedMaskType,
+ maskDimSizes.drop_front(extractOpPos.size()));
+ return success();
+ }
+};
+
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2405,9 +2458,8 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
- context);
+ results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+ ExtractOpFromConstantMask, ExtractToShapeCast>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3980c179b5d0a..1d38064fb383c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -272,6 +272,76 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
// -----
+// CHECK-LABEL: extract_from_constant_mask
+func.func @extract_from_constant_mask() -> vector<4xi1> {
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[1] : vector<4xi1> from vector<4x4xi1> // row 1 < 2: inside the mask, keeps [3]
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_all_false
+func.func @extract_from_constant_mask_all_false() -> vector<4xi1> {
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[3] : vector<4xi1> from vector<4x4xi1> // row 3 >= 2: outside the mask
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_at_boundary
+func.func @extract_from_constant_mask_at_boundary() -> vector<4xi1> {
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = arith.constant dense<false> : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[2] : vector<4xi1> from vector<4x4xi1> // row 2 is the first row outside [0, 2)
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_multiple_indices
+func.func @extract_from_constant_mask_multiple_indices() -> vector<4xi1> {
+ %mask = vector.constant_mask [2, 3, 3] : vector<4x4x4xi1>
+ // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[1, 2] : vector<4xi1> from vector<4x4x4xi1> // 1 < 2 and 2 < 3: keeps the trailing [3]
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_all_true
+// CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_all_true(%index: index) -> vector<4xi1> {
+ // The mask covers the entire first dimension, so a dynamic index is fine.
+ %mask = vector.constant_mask [4, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = vector.constant_mask [3] : vector<4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1> // mask dim 0 is 4 of 4: every row is inside
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_constant_mask_dynamic_position_not_all_true
+// CHECK-SAME: %[[INDEX:.*]]: index
+func.func @extract_from_constant_mask_dynamic_position_not_all_true(%index: index) -> vector<4xi1> {
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[MASK:.*]] = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK-NEXT: %[[RES:.*]] = vector.extract %[[MASK]][%[[INDEX]]] : vector<4xi1> from vector<4x4xi1>
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[%index] : vector<4xi1> from vector<4x4xi1> // row may be >= 2: must not fold
+ return %extract : vector<4xi1>
+}
+
+// -----
+
// CHECK-LABEL: constant_mask_to_true_splat
func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
// CHECK: arith.constant dense<true>
>From 694158840e1a33cef006128aef2f3dea67c90aee Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Fri, 27 Feb 2026 18:00:51 +0000
Subject: [PATCH 2/3] Drop redundant llvm:: qualifier on dyn_cast
Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fa6e0c3ac7c76..0f81dd8226b96 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2311,7 +2311,7 @@ class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
return failure();
auto extractedMaskType =
- llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+ dyn_cast<VectorType>(extractOp.getResult().getType());
if (!extractedMaskType)
return failure();
>From 80caf5bc303f8de59782cb72b552ff4d0ba80b2e Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Mon, 2 Mar 2026 09:40:46 +0000
Subject: [PATCH 3/3] Fold scalar (i1) extracts from constant_mask
Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 33 ++++++++++++++--------
mlir/test/Dialect/Vector/canonicalize.mlir | 22 +++++++++++++++
2 files changed, 44 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0f81dd8226b96..5dc3984b0a037 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2310,10 +2310,8 @@ class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
if (!constantMaskOp)
return failure();
- auto extractedMaskType =
- dyn_cast<VectorType>(extractOp.getResult().getType());
- if (!extractedMaskType)
- return failure();
+ Type resultType = extractOp.getResult().getType();
+ auto extractedMaskType = dyn_cast<VectorType>(resultType);
ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
@@ -2336,17 +2334,30 @@ class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
// If the position is statically outside of the masked area, the result
// will be all-false.
if (pos >= maskDimSizes[dimIdx]) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- extractOp, DenseElementsAttr::get(extractedMaskType, false));
+ if (extractedMaskType) {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ extractOp, DenseElementsAttr::get(extractedMaskType, false));
+ } else {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ extractOp, rewriter.getIntegerAttr(resultType, false));
+ }
return success();
}
}
- // All positions are within the mask bounds. The result is a constant_mask
- // with the remaining dimensions.
- rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
- extractOp, extractedMaskType,
- maskDimSizes.drop_front(extractOpPos.size()));
+ // All positions are within the mask bounds.
+ if (extractedMaskType) {
+ // Vector result: the result is a constant_mask with the remaining
+ // dimensions.
+ rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+ extractOp, extractedMaskType,
+ maskDimSizes.drop_front(extractOpPos.size()));
+ } else {
+ // Scalar result: all positions are within the masked region, so the
+ // result is true.
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ extractOp, rewriter.getIntegerAttr(resultType, true));
+ }
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1d38064fb383c..583aa2efa49c3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -342,6 +342,28 @@ func.func @extract_from_constant_mask_dynamic_position_not_all_true(%index: inde
// -----
+// CHECK-LABEL: extract_scalar_from_constant_mask_within_bounds
+func.func @extract_scalar_from_constant_mask_within_bounds() -> i1 {
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = arith.constant true
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[0, 1] : i1 from vector<4x4xi1> // 0 < 2 and 1 < 3: inside the mask
+ return %extract : i1
+}
+
+// -----
+
+// CHECK-LABEL: extract_scalar_from_constant_mask_outside_bounds
+func.func @extract_scalar_from_constant_mask_outside_bounds() -> i1 {
+ %mask = vector.constant_mask [2, 3] : vector<4x4xi1>
+ // CHECK: %[[RES:.*]] = arith.constant false
+ // CHECK-NEXT: return %[[RES]]
+ %extract = vector.extract %mask[0, 3] : i1 from vector<4x4xi1> // column 3 >= 3: outside the mask
+ return %extract : i1
+}
+
+// -----
+
// CHECK-LABEL: constant_mask_to_true_splat
func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
// CHECK: arith.constant dense<true>
More information about the Mlir-commits
mailing list