[Mlir-commits] [mlir] [mlir] [vector] Add linearization pattern for vector.create_mask (PR #138214)

Nishant Patel llvmlistbot at llvm.org
Thu May 1 16:08:09 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/138214

>From 3a83e2d5cfd5aae8c35fde4050886a96b61edd3f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 22:10:03 +0000
Subject: [PATCH 1/4] Add linearization pattern for vector.create_mask

---
 .../Vector/Transforms/VectorLinearize.cpp     | 65 ++++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 33 ++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  3 +-
 3 files changed, 97 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..cdd937eed6569 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
   }
 };
 
+/// This pattern linearizes `vector.create_mask` for 2D masks whose outer
+/// dimension is a fixed (non-scalable) unit dimension. For example,
+///   %mask = vector.create_mask %d0, %d1 : vector<1x4xi1>
+/// is converted to (pseudo-IR):
+///   %mask_1d = vector.create_mask %d1 : vector<4xi1>
+///   %out_1d = select(%d0 <= 0, all-false vector<4xi1>, %mask_1d)
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  // Inherit OpConversionPattern's (TypeConverter, MLIRContext, benefit) ctor.
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = createMaskOp.getLoc();
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    // A scalable outer dim `[1]` stands for vscale rows at runtime, which a
+    // single 1D create_mask of the second operand cannot express.
+    if (srcTy.getScalableDims()[0])
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "scalable outer dimension is not supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // With a unit outer dimension, the 2D mask equals the 1D mask built from
+    // the second operand when the first operand is positive; otherwise every
+    // lane is false. Select between those two masks.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+    auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::sle, firstOperand, zero);
+    auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+        loc, dstTy, isZeroOrNegative);
+
+    auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+        loc, dstTy, rewriter.getZeroAttr(dstTy));
+    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+        loc, dstTy, adaptor.getOperands().back());
+    auto result = rewriter.create<mlir::arith::SelectOp>(
+        loc, isZeroOrNegativeSplat, zeroMask, newMask);
+
+    rewriter.replaceOp(createMaskOp, result.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 /// Return true if the operation `op` does not support scalable vectors and
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
 void mlir::vector::populateVectorLinearizeBasePatterns(
     const TypeConverter &typeConverter, const ConversionTarget &target,
     RewritePatternSet &patterns) {
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast, LinearizeVectorSplat>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+           LinearizeVectorSplat, LinearizeVectorCreateMask>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 20169c15eb2c1..2b802eed64595 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -447,3 +447,36 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   %0 = vector.splat %arg0 : vector<4x[2]xi32>
   return %0 : vector<4x[2]xi32>
 }
+
+// -----
+// ALL-LABEL: test_create_mask
+func.func @test_create_mask() -> vector<1x16xi1> {
+  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
+  // BW-128: %[[C20:.*]] = arith.constant 20 : index
+  // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0_0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // BW-128: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+  // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // DEFAULT: return %[[CAST]] : vector<1x16xi1>
+  // BW-128: return %[[CAST]] : vector<1x16xi1>
+
+  // BW-0: %[[C0:.*]] = arith.constant 0 : index
+  // BW-0: %[[C20:.*]] = arith.constant 20 : index
+  // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1>
+  %c0 = arith.constant 0 : index
+  %c20 = arith.constant 20 : index
+  %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..2d5e90908d4d0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,7 +973,8 @@ struct TestVectorLinearize final
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
+    registry.insert<vector::VectorDialect, memref::MemRefDialect,
+                    arith::ArithDialect>();
   }
 
   void runOnOperation() override {

>From 24b0739da64b109564abbe85bb68706c0ad6d101 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 22:56:42 +0000
Subject: [PATCH 2/4] Use CHECKS

---
 mlir/test/Dialect/Vector/linearize.mlir | 38 ++++++++-----------------
 1 file changed, 12 insertions(+), 26 deletions(-)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index d5d4cfa4f9aa1..cc5ec1a5c036c 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -347,32 +347,18 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
 }
 
 // -----
-// ALL-LABEL: test_create_mask
-func.func @test_create_mask() -> vector<1x16xi1> {
-  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
-  // BW-128: %[[C0:.*]] = arith.constant 0 : index
-  // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
-  // BW-128: %[[C20:.*]] = arith.constant 20 : index
-  // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index
-  // BW-128: %[[C0_0:.*]] = arith.constant 0 : index
-  // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
-  // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
-  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
-  // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
-  // DEFAULT: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
-  // BW-128: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
-  // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
-  // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
-  // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
-  // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>
-  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
-  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
-  // DEFAULT: return %[[CAST]] : vector<1x16xi1>
-  // BW-128: return %[[CAST]] : vector<1x16xi1>
-
-  // BW-0: %[[C0:.*]] = arith.constant 0 : index
-  // BW-0: %[[C20:.*]] = arith.constant 20 : index
-  // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1>
+// CHECK-LABEL: linearize_create_mask
+func.func @linearize_create_mask() -> vector<1x16xi1> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C20:.*]] = arith.constant 20 : index
+  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // CHECK: return %[[CAST]] : vector<1x16xi1>
   %c0 = arith.constant 0 : index
   %c20 = arith.constant 20 : index
   %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>

>From 2b8a653c279b3be08fd6426316cd57dbf3fd54eb Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 23:00:13 +0000
Subject: [PATCH 3/4] Add test case for scalable vector

---
 mlir/test/Dialect/Vector/linearize.mlir | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index cc5ec1a5c036c..01872426c77bb 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -364,3 +364,22 @@ func.func @linearize_create_mask() -> vector<1x16xi1> {
   %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
   return %0 : vector<1x16xi1>
 }
+
+// -----
+// CHECK-LABEL: linearize_scalable_create_mask
+func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C20:.*]] = arith.constant 20 : index
+  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
+  // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
+  // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1>
+  // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1>
+  // CHECK: return %[[CAST]] : vector<1x[16]xi1>
+  %c0 = arith.constant 0 : index
+  %c20 = arith.constant 20 : index
+  %0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1>
+  return %0 : vector<1x[16]xi1>
+}

>From 8e8de7af27934145383a7ee504e35f18a5787abd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 23:07:38 +0000
Subject: [PATCH 4/4] Clean up

---
 mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2d5e90908d4d0..54defd949c264 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,8 +973,7 @@ struct TestVectorLinearize final
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect, memref::MemRefDialect,
-                    arith::ArithDialect>();
+    registry.insert<arith::ArithDialect, vector::VectorDialect>();
   }
 
   void runOnOperation() override {



More information about the Mlir-commits mailing list