[Mlir-commits] [mlir] 1778d3b - [mlir] [vector] Add linearization pattern for vector.create_mask (#138214)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 14 15:54:01 PDT 2025


Author: Nishant Patel
Date: 2025-05-14T15:53:58-07:00
New Revision: 1778d3b8245b9a7787bbd0b00f60f879ed4689c9

URL: https://github.com/llvm/llvm-project/commit/1778d3b8245b9a7787bbd0b00f60f879ed4689c9
DIFF: https://github.com/llvm/llvm-project/commit/1778d3b8245b9a7787bbd0b00f60f879ed4689c9.diff

LOG: [mlir] [vector] Add linearization pattern for vector.create_mask (#138214)

This PR is part [3 / 4] of the breakdown of PR #136193.
It adds a linearization pattern for vector.create_mask.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c7169c5297d9a..90e0479a515d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -566,6 +566,64 @@ struct LinearizeVectorSplat final
   }
 };
 
+/// This pattern converts the CreateMaskOp to work on a linearized vector.
+/// It currently supports only 2D masks whose outer dimension is a fixed
+/// (non-scalable) unit dimension; the inner dimension may be scalable.
+/// A scalable `[1]` outer dimension is rejected: it stands for vscale rows at
+/// runtime, and such a mask is not a prefix of the linearized vector, so a
+/// single 1-D create_mask size cannot describe it.
+/// For example,
+///   vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+/// is converted to:
+///   %zero = arith.constant 0 : index
+///   %cmpi = arith.cmpi sgt, %arg0, %zero : index
+///   %index = arith.index_cast %cmpi : i1 to index
+///   %size = arith.andi %index, %arg1 : index
+///   %mask = vector.create_mask %size : vector<4xi1>
+///   %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
+/// The final shape_cast is not built by this pattern itself; presumably it is
+/// the type converter's materialization back to the original 2-D type.
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = createMaskOp.getLoc();
+    VectorType srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    // A scalable unit outer dimension (`[1]`) has vscale rows at runtime, so
+    // sizing a single row (as done below) would drop every row but the first.
+    if (srcTy.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "scalable outer dimension is not supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // The only row (outer index 0) is active iff the first operand is greater
+    // than 0. The i1 comparison result is index_cast to index, which
+    // sign-extends: true becomes -1 (all bits set) and false becomes 0.
+    // Bitwise AND-ing that with the second operand therefore yields either
+    // the second operand or 0, which is the size of the 1-D mask.
+    // NOTE: this must stay an AND; multiplying by the sign-extended value
+    // would produce -secondOperand for an active row.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+    auto isPositive = rewriter.createOrFold<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
+    auto isPositiveIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), isPositive);
+    auto secondOperand = adaptor.getOperands().back();
+    auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
+        loc, rewriter.getIndexType(), isPositiveIndex, secondOperand);
+
+    auto newMask =
+        rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
+    rewriter.replaceOp(createMaskOp, newMask);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Return true if the operation `op` does not support scalable vectors and
@@ -651,9 +709,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
 void mlir::vector::populateVectorLinearizeBasePatterns(
     const TypeConverter &typeConverter, const ConversionTarget &target,
     RewritePatternSet &patterns) {
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast, LinearizeVectorSplat>(
-      typeConverter, patterns.getContext());
+  // Register the base linearization patterns; every pattern is constructed
+  // with the same `typeConverter` and the pattern set's context.
+  // NOTE(review): `target` is not referenced in this function body.
+  patterns
+      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+           LinearizeVectorSplat, LinearizeVectorCreateMask>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 3cdbef8db604b..40445d3781228 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -416,3 +416,28 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   return %0 : vector<4x[2]xi32>
 }
 
+// -----
+
+// CHECK-LABEL: linearize_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
+func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
+
+  // The mask size is %arg1 when %arg0 > 0 and 0 otherwise, computed with a
+  // bitwise AND against the sign-extended comparison result.
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+  // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
+  // CHECK: %[[ANDI:.*]] = arith.andi %[[INDEXCAST]], %[[ARG1]] : index
+  // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[ANDI]] : vector<16xi1>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1>
+  // CHECK: return %[[CAST]] : vector<1x16xi1>
+  %0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
+
+// -----
+// CHECK-LABEL: linearize_scalable_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x[16]xi1>
+func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> {
+
+  // A scalable trailing dimension is preserved in the 1-D mask type.
+  // CHECK: %[[SIZE:.*]] = arith.andi %{{.*}}, %[[ARG1]] : index
+  // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[SIZE]] : vector<[16]xi1>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<[16]xi1> to vector<1x[16]xi1>
+  // CHECK: return %[[CAST]] : vector<1x[16]xi1>
+  %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
+  return %0 : vector<1x[16]xi1>
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..54defd949c264 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,7 +973,7 @@ struct TestVectorLinearize final
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
+    // The linearization patterns can create arith ops (e.g. the cmpi /
+    // index_cast / andi built when linearizing vector.create_mask), so the
+    // arith dialect is registered alongside the vector dialect.
+    registry.insert<vector::VectorDialect, arith::ArithDialect>();
   }
 
   void runOnOperation() override {


        


More information about the Mlir-commits mailing list