[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.constant_mask (PR #171518)
Nishant Patel
llvmlistbot at llvm.org
Thu Dec 11 10:51:57 PST 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/171518
>From 5008d19c3467f396fa8eb6cf016d18de1743f7a1 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 9 Dec 2025 03:30:46 +0000
Subject: [PATCH 1/3] Add unroll pattern for vector.constant_mask
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 4 +-
.../Vector/Transforms/VectorUnroll.cpp | 89 ++++++++++++++++++-
.../Dialect/Vector/vector-unroll-options.mlir | 17 ++++
.../Dialect/Vector/TestVectorTransforms.cpp | 12 +--
4 files changed, 114 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d8ed46c2820fe..3d76f3d7fbc46 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2547,7 +2547,9 @@ def Vector_TypeCastOp :
}
def Vector_ConstantMaskOp :
- Vector_Op<"constant_mask", [Pure]>,
+ Vector_Op<"constant_mask", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+ ]>,
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 462bd8c3dc4a6..81e7d76eefcfb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1094,6 +1094,107 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.constant_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// whether its elements should be masked based on the original mask dimensions
+/// and the slice's offset position.
+///
+/// Example:
+/// Given a constant_mask operation:
+/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1> // mask first 6x10
+/// elements
+///
+/// and a target unroll shape of <4x8>, the pattern produces:
+///
+/// %false = arith.constant dense<false> : vector<8x16xi1>
+///
+/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
+/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
+/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+///
+/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
+/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
+/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+///
+/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
+/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
+/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+///
+/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
+/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
+/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+///
+/// Slices whose mask size would be 0 in some dimension are skipped: such a
+/// slice is entirely false (the all-false accumulator already covers it), and
+/// `vector.constant_mask` requires all mask dim sizes to be 0 once one is.
+/// Scalable result vectors are not unrolled.
+struct UnrollConstantMaskPattern
+ : public OpRewritePattern<vector::ConstantMaskOp> {
+ UnrollConstantMaskPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, constantMaskOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType resultType = constantMaskOp.getVectorType();
+ // Slices are materialized below as fixed-size vectors, so unrolling a
+ // scalable mask this way would produce ill-typed IR.
+ if (resultType.isScalable())
+ return rewriter.notifyMatchFailure(constantMaskOp,
+ "scalable vectors are not supported");
+ SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
+ Location loc = constantMaskOp.getLoc();
+
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ VectorType targetVectorType =
+ VectorType::get(*targetShape, rewriter.getI1Type());
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+
+ // In each dimension (d), each unrolled vector computes its mask size as:
+ // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalSize, *targetShape)) {
+ SmallVector<int64_t> unrolledMaskDims;

+ for (auto [i, originalMaskDim] :
+ llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
+ // Calculate how many elements in this dimension should be masked
+ // for this particular slice
+ int64_t adjustedMaskSize = std::max(originalMaskDim - offsets[i], 0L);
+ int64_t unrolledMaskDim = std::min(adjustedMaskSize, (*targetShape)[i]);
+ unrolledMaskDims.push_back(unrolledMaskDim);
+ }
+
+ // A zero in any dimension makes the whole slice false. `result` is
+ // already all-false there, and `vector.constant_mask` rejects a mix of
+ // zero and non-zero sizes, so skip the slice instead of emitting it.
+ if (llvm::is_contained(unrolledMaskDims, 0))
+ continue;
+
+ auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
+ loc, targetVectorType, unrolledMaskDims);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, unrolledMask, result, offsets, strides);
+ }
+ rewriter.replaceOp(constantMaskOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
/// Checks whether extractShape is a contiguous slice of shape.
/// For extractShape to be contiguous in shape:
/// 1) All but the leading dimension of extractShape and shape must match
@@ -1294,8 +1379,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
- UnrollCreateMaskPattern>(patterns.getContext(), options,
- benefit);
+ UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
+ patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 805e66f133c59..c2e7f6a9338b1 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -552,6 +552,23 @@ func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: return %[[S3]] : vector<16x16xi1>
+func.func @vector_constant_mask() -> vector<16x16xi1> {
+ %0 = vector.constant_mask [12, 10] : vector<16x16xi1>
+ return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_constant_mask
+// CHECK-SAME: () -> vector<16x16xi1>
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+// CHECK: %[[CST_TRUE:.*]] = arith.constant dense<true> : vector<8x8xi1>
+// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[CST_TRUE]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MASK01:.*]] = vector.constant_mask [8, 2] : vector<8x8xi1>
+// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MASK10:.*]] = vector.constant_mask [4, 8] : vector<8x8xi1>
+// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MASK11:.*]] = vector.constant_mask [4, 2] : vector<8x8xi1>
+// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: return %[[INS11]] : vector<16x16xi1>
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f834d0cdd42bd..2cbb5ab3067f2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -179,11 +179,13 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
- patterns, UnrollVectorOptions()
- .setNativeShape(ArrayRef<int64_t>{8, 8})
- .setFilterConstraint([](Operation *op) {
- return success(isa<vector::CreateMaskOp>(op));
- }));
+ patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8, 8})
+ .setFilterConstraint([](Operation *op) {
+ return success(
+ isa<vector::CreateMaskOp, vector::ConstantMaskOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns,
UnrollVectorOptions()
>From e159099766b113a9a2a0ddb79bcbfbded7c9dfd0 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 9 Dec 2025 22:58:52 +0000
Subject: [PATCH 2/3] Add cast
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 81e7d76eefcfb..7357e2478f3df 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1161,8 +1161,10 @@ struct UnrollConstantMaskPattern
llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
// Calculate how many elements in this dimension should be masked
// for this particular slice
- int64_t adjustedMaskSize = std::max(originalMaskDim - offsets[i], 0L);
- int64_t unrolledMaskDim = std::min(adjustedMaskSize, (*targetShape)[i]);
+ int64_t adjustedMaskSize =
+ std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
+ int64_t unrolledMaskDim =
+ std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
unrolledMaskDims.push_back(unrolledMaskDim);
}
>From a68aedf884f2e223e1aa38c01c463a19a4efab94 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 11 Dec 2025 17:29:40 +0000
Subject: [PATCH 3/3] Address feedback
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 7357e2478f3df..b62ce8a2ec398 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1101,8 +1101,7 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
///
/// Example:
/// Given a constant_mask operation:
-/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1> // mask first 6x10
-/// elements
+/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
///
/// and a target unroll shape of <4x8>, the pattern produces:
///
@@ -1137,7 +1136,8 @@ struct UnrollConstantMaskPattern
LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
PatternRewriter &rewriter) const override {
- auto targetShape = getTargetShape(options, constantMaskOp);
+ std::optional<SmallVector<int64_t>> targetShape =
+ getTargetShape(options, constantMaskOp);
if (!targetShape)
return failure();
@@ -1153,7 +1153,7 @@ struct UnrollConstantMaskPattern
// In each dimension (d), each unrolled vector computes its mask size as:
// min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
- for (SmallVector<int64_t> offsets :
+ for (const SmallVector<int64_t> &offsets :
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<int64_t> unrolledMaskDims;
More information about the Mlir-commits
mailing list