[Mlir-commits] [mlir] 71ee84a - [MLIR][Vector] Add unroll pattern for vector.constant_mask (#171518)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 11 13:16:59 PST 2025


Author: Nishant Patel
Date: 2025-12-11T13:16:55-08:00
New Revision: 71ee84acc4f7c93b9292af90ef5d79dd05687410

URL: https://github.com/llvm/llvm-project/commit/71ee84acc4f7c93b9292af90ef5d79dd05687410
DIFF: https://github.com/llvm/llvm-project/commit/71ee84acc4f7c93b9292af90ef5d79dd05687410.diff

LOG: [MLIR][Vector] Add unroll pattern for vector.constant_mask (#171518)

This PR adds an unrolling pattern for the vector.constant_mask op based on
the targetShape. Each unrolled vector computes its local mask size in each
dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), targetShape[d]).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d10bedef6040f..ddb04b6bbe40d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2534,7 +2534,9 @@ def Vector_TypeCastOp :
 }
 
 def Vector_ConstantMaskOp :
-  Vector_Op<"constant_mask", [Pure]>,
+  Vector_Op<"constant_mask", [Pure, 
+   DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+   ]>,
     Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
     Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a constant vector mask";

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 462bd8c3dc4a6..b62ce8a2ec398 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1094,6 +1094,93 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unrolls `vector.constant_mask` into one smaller `vector.constant_mask` per
+/// tile of the target unroll shape. For a tile at `offsets`, the mask size in
+/// dimension d is min(max(maskDimSizes[d] - offsets[d], 0), targetShape[d]);
+/// the tiles are inserted into an all-false vector of the original type.
+///
+/// Example:
+///   Given a constant_mask operation:
+///     %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
+///
+///   and a target unroll shape of <4x8>, the pattern produces (before any
+///   folding performed by `createOrFold`):
+///     %false = arith.constant dense<false> : vector<8x16xi1>
+///
+///     Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
+///     %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
+///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///
+///     Slice [0,8]: elements [0:4, 8:16] - partially within bounds
+///     %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
+///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///
+///     Slice [4,0]: elements [4:8, 0:8] - partially within bounds
+///     %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
+///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///
+///     Slice [4,8]: elements [4:8, 8:16] - partially within bounds
+///     %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
+///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollConstantMaskPattern
+    : public OpRewritePattern<vector::ConstantMaskOp> {
+  UnrollConstantMaskPattern(MLIRContext *context,
+                            const vector::UnrollVectorOptions &options,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape =
+        getTargetShape(options, constantMaskOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType resultType = constantMaskOp.getVectorType();
+    SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
+    Location loc = constantMaskOp.getLoc();
+
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    VectorType targetVectorType =
+        VectorType::get(*targetShape, rewriter.getI1Type());
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+
+    // In each dimension (d), each unrolled vector computes its mask size as:
+    // min(max(originalMaskDim[d] - offsets[d], 0), (*targetShape)[d]).
+    for (const SmallVector<int64_t> &offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
+      SmallVector<int64_t> unrolledMaskDims;
+
+      for (auto [i, originalMaskDim] :
+           llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
+        // Mask extent remaining past this slice's offset, clamped to 0 when
+        // the slice starts beyond the mask and capped at the slice size.
+        int64_t adjustedMaskSize =
+            std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
+        int64_t unrolledMaskDim =
+            std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
+        unrolledMaskDims.push_back(unrolledMaskDim);
+      }
+
+      auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
+          loc, targetVectorType, unrolledMaskDims);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, unrolledMask, result, offsets, strides);
+    }
+    rewriter.replaceOp(constantMaskOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 /// Checks whether extractShape is a contiguous slice of shape.
 /// For extractShape to be contiguous in shape:
 /// 1) All but the leading dimension of extractShape and shape must match
@@ -1294,8 +1381,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
                UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
-               UnrollCreateMaskPattern>(patterns.getContext(), options,
-                                        benefit);
+               UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
+      patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 805e66f133c59..c2e7f6a9338b1 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -552,6 +552,23 @@ func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
 // CHECK:   %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
 // CHECK:   return %[[S3]] : vector<16x16xi1>
 
+func.func @vector_constant_mask() -> vector<16x16xi1> {
+  %0 = vector.constant_mask [12, 10] : vector<16x16xi1>
+  return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_constant_mask
+// CHECK-SAME: () -> vector<16x16xi1>
+//       CHECK:   %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+//       CHECK:   %[[CST_TRUE:.*]] = arith.constant dense<true> : vector<8x8xi1>
+//       CHECK:   %[[INS00:.*]] = vector.insert_strided_slice %[[CST_TRUE]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[MASK01:.*]] = vector.constant_mask [8, 2] : vector<8x8xi1>
+//       CHECK:   %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[MASK10:.*]] = vector.constant_mask [4, 8] : vector<8x8xi1>
+//       CHECK:   %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[MASK11:.*]] = vector.constant_mask [4, 2] : vector<8x8xi1>
+//       CHECK:   %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   return %[[INS11]] : vector<16x16xi1>
 
 func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
   %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f834d0cdd42bd..2cbb5ab3067f2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -179,11 +179,13 @@ struct TestVectorUnrollingPatterns
                                        return success(isa<vector::StepOp>(op));
                                      }));
     populateVectorUnrollPatterns(
-        patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{8, 8})
-                      .setFilterConstraint([](Operation *op) {
-                        return success(isa<vector::CreateMaskOp>(op));
-                      }));
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{8, 8})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<vector::CreateMaskOp, vector::ConstantMaskOp>(op));
+            }));
     populateVectorUnrollPatterns(
         patterns,
         UnrollVectorOptions()


        


More information about the Mlir-commits mailing list