[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.create_mask (PR #169119)
Nishant Patel
llvmlistbot at llvm.org
Fri Nov 21 14:56:01 PST 2025
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/169119
This PR adds unrolling for the vector.create_mask op based on the target shape. Each unrolled vector computes its local mask size in each dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]), where unrolledDimSize[d] is the size of the unrolled (target) vector in dimension d.
>From 86425375b67f2e105af80e31d8e96b87fe22ad82 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 14 Nov 2025 21:09:50 +0000
Subject: [PATCH] Add unroll pattern for vector.create_mask
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 4 +-
.../Vector/Transforms/VectorUnroll.cpp | 94 ++++++++++++++++++-
.../Dialect/Vector/vector-unroll-options.mlir | 40 ++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 6 ++
4 files changed, 141 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43172ff2082df..4f9252f046eab 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2607,7 +2607,9 @@ def Vector_ConstantMaskOp :
}
def Vector_CreateMaskOp :
- Vector_Op<"create_mask", [Pure]>,
+ Vector_Op<"create_mask", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+ ]>,
Arguments<(ins Variadic<Index>:$operands)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a vector mask";
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..ca2978c5d5a19 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,96 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.create_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// its local mask size in each dimension (d) by clamping the part of the
+/// original mask size that lies past the slice offset to the slice extent:
+///   min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+/// The slices are inserted into an all-false vector of the original type.
+///
+/// Scalable result types are rejected: the slice offsets and unrolled sizes
+/// computed here are fixed element counts, and the unrolled type is built from
+/// the static target shape only.
+///
+/// Example:
+/// Given a create_mask operation:
+///   %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
+///                                                        // elements
+///
+/// and a target unroll shape of <4x8>, the pattern produces:
+///
+///   %false = arith.constant dense<false> : vector<8x16xi1>
+///
+///   Slice [0,0]:
+///     mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
+///     %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
+///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///   Slice [0,8]:
+///     mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
+///     %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
+///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///   Slice [4,0]:
+///     mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
+///     %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
+///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///   Slice [4,8]:
+///     mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
+///     %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
+///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
+  UnrollCreateMaskPattern(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = createMaskOp.getVectorType();
+    // The unroll shape and the clamp bounds below are static element counts,
+    // and the unrolled type is created without scalable dims, so a scalable
+    // mask would be split into fixed-size tiles of the wrong size.
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "scalable vectors are not supported");
+
+    auto targetShape = getTargetShape(options, createMaskOp);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
+    Location loc = createMaskOp.getLoc();
+
+    // Start from an all-false mask and insert every unrolled tile into it.
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type());
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+
+    // Loop-invariant clamp bounds: 0 as the lower bound and, per dimension,
+    // the unrolled (tile) size as the upper bound. Created once instead of
+    // once per tile and dimension.
+    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    SmallVector<Value> unrolledDimSizes;
+    unrolledDimSizes.reserve(targetShape->size());
+    for (int64_t dimSize : *targetShape)
+      unrolledDimSizes.push_back(
+          arith::ConstantIndexOp::create(rewriter, loc, dimSize));
+
+    // In each dimension (d), each unrolled vector computes its mask size as:
+    //   min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
+      SmallVector<Value> unrolledOperands;
+      unrolledOperands.reserve(offsets.size());
+
+      for (auto [i, originalMaskOperand] :
+           llvm::enumerate(createMaskOp.getOperands())) {
+        Value offsetVal =
+            arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
+        Value adjustedMaskSize = arith::SubIOp::create(
+            rewriter, loc, originalMaskOperand, offsetVal);
+        Value nonNegative =
+            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+        Value unrolledOperand = arith::MinSIOp::create(
+            rewriter, loc, nonNegative, unrolledDimSizes[i]);
+        unrolledOperands.push_back(unrolledOperand);
+      }
+
+      auto unrolledMask = vector::CreateMaskOp::create(
+          rewriter, loc, targetVectorType, unrolledOperands);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, unrolledMask, result, offsets, strides);
+    }
+    rewriter.replaceOp(createMaskOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1103,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
- UnrollToElements, UnrollStepPattern>(patterns.getContext(),
- options, benefit);
+ UnrollToElements, UnrollStepPattern, UnrollCreateMaskPattern>(
+ patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..f36c77ee8799f 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,43 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return
+
+// A 16x16 create_mask with dynamic sizes; the test pass unrolls
+// vector.create_mask to an 8x8 native shape, giving four 8x8 tiles.
+func.func @vector_create_mask(%rows: index, %cols: index) -> vector<16x16xi1> {
+  %mask = vector.create_mask %rows, %cols : vector<16x16xi1>
+  return %mask : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1>
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index
+// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index
+// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1>
+// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index
+// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index
+// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index
+// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1>
+// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index
+// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index
+// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index
+// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1>
+// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index
+// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index
+// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index
+// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index
+// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1>
+// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: return %[[INS11]] : vector<16x16xi1>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda71..8e69a2ab37e5e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::StepOp>(op));
}));
+ populateVectorUnrollPatterns(patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8, 8})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::CreateMaskOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
More information about the Mlir-commits
mailing list