[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.create_mask (PR #169119)

Nishant Patel llvmlistbot at llvm.org
Sun Nov 30 23:12:17 PST 2025


================
@@ -1003,6 +1003,96 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// This pattern unrolls `vector.create_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// its local mask size in each dimension (d) as:
+/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+/// Example:
+///   Given a create_mask operation:
+///     // Mask the first 6x10 elements.
+///     %0 = vector.create_mask %c6, %c10 : vector<8x16xi1>
+///
+///   and a target unroll shape of <4x8>, the pattern produces:
+///
+///     %false = arith.constant dense<false> : vector<8x16xi1>
+///
+///     Slice [0,0]:
+///     mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
+///     %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
+///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [0,8]:
+///     mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
+///     %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
+///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [4,0]:
+///     mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
+///     %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
+///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [4,8]:
+///     mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
+///     %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
+///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
+  UnrollCreateMaskPattern(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, createMaskOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType resultType = createMaskOp.getVectorType();
+    SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
+    Location loc = createMaskOp.getLoc();
+
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type());
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+
+    // In each dimension (d), each unrolled vector computes its mask size as:
+    // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
+      SmallVector<Value> unrolledOperands;
+
+      for (auto [i, originalMaskOperand] :
+           llvm::enumerate(createMaskOp.getOperands())) {
+        Value offsetVal =
+            arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
+        Value adjustedMaskSize = arith::SubIOp::create(
+            rewriter, loc, originalMaskOperand, offsetVal);
+        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        Value unrolledDimSize =
+            arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
+        Value nonNegative =
+            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+        Value unrolledOperand =
+            arith::MinSIOp::create(rewriter, loc, nonNegative, unrolledDimSize);
+        unrolledOperands.push_back(unrolledOperand);
+      }
+
+      auto unrolledMask = vector::CreateMaskOp::create(
+          rewriter, loc, targetVectorType, unrolledOperands);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, unrolledMask, result, offsets, strides);
----------------
nbpatel wrote:

Thanks for pointing that out; I have changed it.

https://github.com/llvm/llvm-project/pull/169119


More information about the Mlir-commits mailing list