[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.create_mask (PR #169119)

Nishant Patel llvmlistbot at llvm.org
Fri Nov 21 14:56:01 PST 2025


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/169119

This PR adds an unroll pattern for the `vector.create_mask` op, driven by the target unroll shape. Each unrolled vector computes its local mask size in each dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).

From 86425375b67f2e105af80e31d8e96b87fe22ad82 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 14 Nov 2025 21:09:50 +0000
Subject: [PATCH] Add unroll pattern for vector.create_mask

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  4 +-
 .../Vector/Transforms/VectorUnroll.cpp        | 94 ++++++++++++++++++-
 .../Dialect/Vector/vector-unroll-options.mlir | 40 ++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  6 ++
 4 files changed, 141 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43172ff2082df..4f9252f046eab 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2607,7 +2607,9 @@ def Vector_ConstantMaskOp :
 }
 
 def Vector_CreateMaskOp :
-  Vector_Op<"create_mask", [Pure]>,
+  Vector_Op<"create_mask", [Pure, 
+   DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+   ]>,
     Arguments<(ins Variadic<Index>:$operands)>,
     Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a vector mask";
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..ca2978c5d5a19 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,96 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unrolls `vector.create_mask` into one smaller `vector.create_mask` per tile
+/// of the target unroll shape, inserting each tile into an all-false vector
+/// of the original type. In dimension d, the tile at offset[d] has mask size
+/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+/// Example:
+///   Given a create_mask operation that masks the leading 6x10 elements
+///   of an 8x16 vector:
+///     %0 = vector.create_mask %c6, %c10 : vector<8x16xi1>
+///
+///   and a target unroll shape of <4x8>, the pattern produces:
+///
+///     %false = arith.constant dense<false> : vector<8x16xi1>
+///
+///     Slice [0,0]:
+///     mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
+///     %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
+///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [0,8]:
+///     mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
+///     %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
+///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [4,0]:
+///     mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
+///     %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
+///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [4,8]:
+///     mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
+///     %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
+///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
+  UnrollCreateMaskPattern(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, createMaskOp);
+    if (!targetShape) // No target shape for this op: leave it untouched.
+      return failure();
+
+    VectorType resultType = createMaskOp.getVectorType();
+    SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
+    Location loc = createMaskOp.getLoc();
+
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type());
+    SmallVector<int64_t> strides(targetShape->size(), 1); // Unit strides.
+
+    // Clamp the remaining mask extent of each dimension (d) to the tile:
+    // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
+      SmallVector<Value> unrolledOperands; // Mask sizes for this tile.
+
+      for (auto [i, originalMaskOperand] :
+           llvm::enumerate(createMaskOp.getOperands())) {
+        Value offsetVal =
+            arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
+        Value adjustedMaskSize = arith::SubIOp::create(
+            rewriter, loc, originalMaskOperand, offsetVal);
+        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        Value unrolledDimSize =
+            arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
+        Value nonNegative =
+            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+        Value unrolledOperand =
+            arith::MinSIOp::create(rewriter, loc, nonNegative, unrolledDimSize);
+        unrolledOperands.push_back(unrolledOperand);
+      }
+
+      auto unrolledMask = vector::CreateMaskOp::create(
+          rewriter, loc, targetVectorType, unrolledOperands);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, unrolledMask, result, offsets, strides);
+    }
+    rewriter.replaceOp(createMaskOp, result); // `result` now holds every tile.
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options; // Copied at construction.
+};
+
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1103,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
-               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
-                                                    options, benefit);
+               UnrollToElements, UnrollStepPattern, UnrollCreateMaskPattern>(
+      patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..f36c77ee8799f 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,43 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
 // CHECK-COUNT-4:   arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
 // CHECK-NOT: arith.addf
 // CHECK: return
+
+func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> {
+  %0 = vector.create_mask %size1, %size2 : vector<16x16xi1>
+  return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1>
+//       CHECK:   %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[C8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+//       CHECK:   %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index
+//       CHECK:   %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+//       CHECK:   %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index
+//       CHECK:   %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1>
+//       CHECK:   %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+//       CHECK:   %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index
+//       CHECK:   %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+//       CHECK:   %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index
+//       CHECK:   %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index
+//       CHECK:   %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1>
+//       CHECK:   %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+//       CHECK:   %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index
+//       CHECK:   %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index
+//       CHECK:   %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+//       CHECK:   %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index
+//       CHECK:   %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1>
+//       CHECK:   %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+//       CHECK:   %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index
+//       CHECK:   %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index
+//       CHECK:   %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+//       CHECK:   %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index
+//       CHECK:   %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index
+//       CHECK:   %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1>
+//       CHECK:   %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   return %[[INS11]] : vector<16x16xi1>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda71..8e69a2ab37e5e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
                                      .setFilterConstraint([](Operation *op) {
                                        return success(isa<vector::StepOp>(op));
                                      }));
+    populateVectorUnrollPatterns(patterns,
+                                UnrollVectorOptions()
+                                    .setNativeShape(ArrayRef<int64_t>{8, 8})
+                                    .setFilterConstraint([](Operation *op) {
+                                      return success(isa<vector::CreateMaskOp>(op));
+                                    }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})



More information about the Mlir-commits mailing list