[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.create_mask (PR #169119)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Mon Nov 24 06:00:34 PST 2025
================
@@ -1003,6 +1003,96 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.create_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// its local mask size in each dimension (d) as:
+/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+/// Example:
+/// Given a create_mask operation that masks the leading 6x10 elements:
+/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1>
+///
+/// and a target unroll shape of <4x8>, the pattern produces:
+///
+/// %false = arith.constant dense<false> : vector<8x16xi1>
+///
+/// Slice [0,0]:
+/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
+/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
+/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [0,8]:
+/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
+/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
+/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [4,0]:
+/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
+/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
+/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [4,8]:
+/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
+/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
+/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
+ UnrollCreateMaskPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, createMaskOp);
----------------
amd-eochoalo wrote:
Can you spell out the type? https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
https://github.com/llvm/llvm-project/pull/169119
More information about the Mlir-commits
mailing list