[Mlir-commits] [mlir] [mlir] [vector] Add linearization pattern for vector.create_mask (PR #138214)
Nishant Patel
llvmlistbot at llvm.org
Thu May 8 21:08:46 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/138214
>From 3a83e2d5cfd5aae8c35fde4050886a96b61edd3f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 22:10:03 +0000
Subject: [PATCH 1/7] Add linearization pattern for vector.create_mask
---
.../Vector/Transforms/VectorLinearize.cpp | 65 ++++++++++++++++++-
mlir/test/Dialect/Vector/linearize.mlir | 33 ++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 3 +-
3 files changed, 97 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..cdd937eed6569 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
}
};
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+/// vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+/// %out_1d = vector.create_mask %dims : vector<4xi1>
+/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ // Compare the first operand with 0. If it's less than or equal to 0,
+ // create a zero mask, else strip the first operand and create a mask
+ // using the second operand.
+ auto firstOperand = adaptor.getOperands().front();
+ auto zero =
+ rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+ auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+ createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+ zero);
+ auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+ createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+ // Use a select operation to choose between the masks.
+ auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+ createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+ auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+ createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+ auto result = rewriter.create<mlir::arith::SelectOp>(
+ createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+
+ rewriter.replaceOp(createMaskOp, result.getResult());
+ return success();
+ }
+};
+
} // namespace
/// Return true if the operation `op` does not support scalable vectors and
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
void mlir::vector::populateVectorLinearizeBasePatterns(
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
- patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast, LinearizeVectorSplat>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+ LinearizeVectorSplat, LinearizeVectorCreateMask>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 20169c15eb2c1..2b802eed64595 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -447,3 +447,36 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
%0 = vector.splat %arg0 : vector<4x[2]xi32>
return %0 : vector<4x[2]xi32>
}
+
+// -----
+// ALL-LABEL: test_create_mask
+func.func @test_create_mask() -> vector<1x16xi1> {
+ // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
+ // BW-128: %[[C0:.*]] = arith.constant 0 : index
+ // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
+ // BW-128: %[[C20:.*]] = arith.constant 20 : index
+ // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index
+ // BW-128: %[[C0_0:.*]] = arith.constant 0 : index
+ // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+ // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+ // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+ // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+ // DEFAULT: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+ // BW-128: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+ // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+ // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+ // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+ // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>
+ // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+ // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+ // DEFAULT: return %[[CAST]] : vector<1x16xi1>
+ // BW-128: return %[[CAST]] : vector<1x16xi1>
+
+ // BW-0: %[[C0:.*]] = arith.constant 0 : index
+ // BW-0: %[[C20:.*]] = arith.constant 20 : index
+ // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1>
+ %c0 = arith.constant 0 : index
+ %c20 = arith.constant 20 : index
+ %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+ return %0 : vector<1x16xi1>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..2d5e90908d4d0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,7 +973,8 @@ struct TestVectorLinearize final
return "Linearizes ND vectors for N >= 2 into 1D vectors";
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect>();
+ registry.insert<vector::VectorDialect, memref::MemRefDialect,
+ arith::ArithDialect>();
}
void runOnOperation() override {
>From 24b0739da64b109564abbe85bb68706c0ad6d101 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 22:56:42 +0000
Subject: [PATCH 2/7] Use CHECKS
---
mlir/test/Dialect/Vector/linearize.mlir | 38 ++++++++-----------------
1 file changed, 12 insertions(+), 26 deletions(-)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index d5d4cfa4f9aa1..cc5ec1a5c036c 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -347,32 +347,18 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
}
// -----
-// ALL-LABEL: test_create_mask
-func.func @test_create_mask() -> vector<1x16xi1> {
- // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
- // BW-128: %[[C0:.*]] = arith.constant 0 : index
- // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
- // BW-128: %[[C20:.*]] = arith.constant 20 : index
- // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index
- // BW-128: %[[C0_0:.*]] = arith.constant 0 : index
- // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
- // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
- // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
- // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
- // DEFAULT: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
- // BW-128: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
- // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
- // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
- // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
- // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>
- // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
- // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
- // DEFAULT: return %[[CAST]] : vector<1x16xi1>
- // BW-128: return %[[CAST]] : vector<1x16xi1>
-
- // BW-0: %[[C0:.*]] = arith.constant 0 : index
- // BW-0: %[[C20:.*]] = arith.constant 20 : index
- // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1>
+// ALL-LABEL: linearize_create_mask
+func.func @linearize_create_mask() -> vector<1x16xi1> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C20:.*]] = arith.constant 20 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+ // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+ // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+ // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+ // CHECK: return %[[CAST]] : vector<1x16xi1>
%c0 = arith.constant 0 : index
%c20 = arith.constant 20 : index
%0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
>From 2b8a653c279b3be08fd6426316cd57dbf3fd54eb Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 23:00:13 +0000
Subject: [PATCH 3/7] Add test case for scalable vector
---
mlir/test/Dialect/Vector/linearize.mlir | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index cc5ec1a5c036c..01872426c77bb 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -364,3 +364,22 @@ func.func @linearize_create_mask() -> vector<1x16xi1> {
%0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
return %0 : vector<1x16xi1>
}
+
+// -----
+// ALL-LABEL: linearize_scalable_create_mask
+func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C20:.*]] = arith.constant 20 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+ // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
+ // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1>
+ // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1>
+ // CHECK: return %[[CAST]] : vector<1x[16]xi1>
+ %c0 = arith.constant 0 : index
+ %c20 = arith.constant 20 : index
+ %0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1>
+ return %0 : vector<1x[16]xi1>
+}
>From 8e8de7af27934145383a7ee504e35f18a5787abd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 23:07:38 +0000
Subject: [PATCH 4/7] Clean up
---
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2d5e90908d4d0..54defd949c264 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,8 +973,7 @@ struct TestVectorLinearize final
return "Linearizes ND vectors for N >= 2 into 1D vectors";
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect, memref::MemRefDialect,
- arith::ArithDialect>();
+ registry.insert<vector::VectorDialect, arith::ArithDialect>();
}
void runOnOperation() override {
>From 528f91324c54730a68fde3bb1f3f94c8a258bfce Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 8 May 2025 23:39:59 +0000
Subject: [PATCH 5/7] Address Feedback
---
.../Vector/Transforms/VectorLinearize.cpp | 35 ++++++++++---------
mlir/test/Dialect/Vector/linearize.mlir | 16 ++++-----
2 files changed, 27 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index cdd937eed6569..7e03b073fb369 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -445,14 +445,18 @@ struct LinearizeVectorSplat final
}
};
-/// This pattern converts the CreateMaskOp to work on a
-/// linearized vector. The pattern currently
-/// supports only 2D masks with a unit outer dimension.
+/// This pattern converts the CreateMaskOp to work on a linearized vector.
+// The pattern currently supports only 2D masks with a unit outer dimension.
/// Following,
-/// vector.create_mask %dims : vector<1x4xi1>
+/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
/// is converted to:
-/// %out_1d = vector.create_mask %dims : vector<4xi1>
-/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+/// %zero = arith.constant 0 : index
+/// %cmpi = arith.cmpi sle, %arg0, %zero : index
+/// %splat = vector.splat %cmpi : vector<4xi1>
+/// %cst = arith.constant dense<false> : vector<4xi1>
+/// %mask = vector.create_mask %arg1 : vector<4xi1>
+/// %out = arith.select %splat, %cst, %mask : vector<4xi1>
+/// %out_1d = vector.shape_cast %out : vector<4xi1> to vector<1x4xi1>
struct LinearizeVectorCreateMask final
: OpConversionPattern<vector::CreateMaskOp> {
using OpConversionPattern::OpConversionPattern;
@@ -464,7 +468,8 @@ struct LinearizeVectorCreateMask final
LogicalResult
matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto srcTy = createMaskOp.getType();
+ Location loc = createMaskOp.getLoc();
+ VectorType srcTy = createMaskOp.getType();
auto srcShape = srcTy.getShape();
if (srcShape.size() != 2)
return rewriter.notifyMatchFailure(createMaskOp,
@@ -482,21 +487,19 @@ struct LinearizeVectorCreateMask final
// create a zero mask, else strip the first operand and create a mask
// using the second operand.
auto firstOperand = adaptor.getOperands().front();
- auto zero =
- rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+ auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
- createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
- zero);
- auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
- createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+ loc, mlir::arith::CmpIPredicate::sle, firstOperand, zero);
+ auto isZeroOrNegativeSplat =
+ rewriter.create<mlir::vector::SplatOp>(loc, dstTy, isZeroOrNegative);
// Use a select operation to choose between the masks.
auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
- createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+ loc, dstTy, rewriter.getZeroAttr(dstTy));
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
- createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+ loc, dstTy, adaptor.getOperands().back());
auto result = rewriter.create<mlir::arith::SelectOp>(
- createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+ loc, isZeroOrNegativeSplat, zeroMask, newMask);
rewriter.replaceOp(createMaskOp, result.getResult());
return success();
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01872426c77bb..55fad7b1704c9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -350,18 +350,18 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
// ALL-LABEL: linearize_create_mask
func.func @linearize_create_mask() -> vector<1x16xi1> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C20:.*]] = arith.constant 20 : index
+ // CHECK: %[[C10:.*]] = arith.constant 10 : index
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
- // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<16xi1>
// CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
// CHECK: return %[[CAST]] : vector<1x16xi1>
%c0 = arith.constant 0 : index
- %c20 = arith.constant 20 : index
- %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+ %c10 = arith.constant 10 : index
+ %0 = vector.create_mask %c0, %c10 : vector<1x16xi1>
return %0 : vector<1x16xi1>
}
@@ -369,17 +369,17 @@ func.func @linearize_create_mask() -> vector<1x16xi1> {
// ALL-LABEL: linearize_scalable_create_mask
func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C20:.*]] = arith.constant 20 : index
+ // CHECK: %[[C10:.*]] = arith.constant 10 : index
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
- // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1>
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<[16]xi1>
// CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1>
// CHECK: return %[[CAST]] : vector<1x[16]xi1>
%c0 = arith.constant 0 : index
- %c20 = arith.constant 20 : index
- %0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1>
+ %c10 = arith.constant 10 : index
+ %0 = vector.create_mask %c0, %c10 : vector<1x[16]xi1>
return %0 : vector<1x[16]xi1>
}
>From c2c1a22a16b1271307620d743378b57d673a3889 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 9 May 2025 01:15:14 +0000
Subject: [PATCH 6/7] Replace select with mul
---
.../Vector/Transforms/VectorLinearize.cpp | 43 +++++++++----------
mlir/test/Dialect/Vector/linearize.mlir | 43 +++++--------------
2 files changed, 31 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7e03b073fb369..e10483bd1a862 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -446,17 +446,16 @@ struct LinearizeVectorSplat final
};
/// This pattern converts the CreateMaskOp to work on a linearized vector.
-// The pattern currently supports only 2D masks with a unit outer dimension.
+/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
/// is converted to:
/// %zero = arith.constant 0 : index
-/// %cmpi = arith.cmpi sle, %arg0, %zero : index
-/// %splat = vector.splat %cmpi : vector<4xi1>
-/// %cst = arith.constant dense<false> : vector<4xi1>
-/// %mask = vector.create_mask %arg1 : vector<4xi1>
-/// %out = arith.select %splat, %cst, %mask : vector<4xi1>
-/// %out_1d = vector.shape_cast %out : vector<4xi1> to vector<1x4xi1>
+/// %cmpi = arith.cmpi sgt, %arg0, %zero : index
+/// %index = arith.index_cast %cmpi : i1 to index
+/// %mul = arith.muli %index, %arg1 : index
+/// %mask = vector.create_mask %mul : vector<4xi1>
+/// %out_1d = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
struct LinearizeVectorCreateMask final
: OpConversionPattern<vector::CreateMaskOp> {
using OpConversionPattern::OpConversionPattern;
@@ -483,25 +482,23 @@ struct LinearizeVectorCreateMask final
if (!dstTy)
return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
- // Compare the first operand with 0. If it's less than or equal to 0,
- // create a zero mask, else strip the first operand and create a mask
- // using the second operand.
+ // Compare the first operand with 0. If it is greater than 0, the
+ // corresponding mask element is set to true, otherwise false.
+ // The result of the comparison is then multiplied with
+ // the second operand of create_mask to get the 1D mask.
auto firstOperand = adaptor.getOperands().front();
auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
- auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::sle, firstOperand, zero);
- auto isZeroOrNegativeSplat =
- rewriter.create<mlir::vector::SplatOp>(loc, dstTy, isZeroOrNegative);
-
- // Use a select operation to choose between the masks.
- auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
- loc, dstTy, rewriter.getZeroAttr(dstTy));
- auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
- loc, dstTy, adaptor.getOperands().back());
- auto result = rewriter.create<mlir::arith::SelectOp>(
- loc, isZeroOrNegativeSplat, zeroMask, newMask);
+ auto isNonZero = rewriter.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
+ auto isNonZeroIndex = rewriter.create<mlir::arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), isNonZero);
+ auto secondOperand = adaptor.getOperands().back();
+ auto maskSize = rewriter.create<mlir::arith::MulIOp>(
+ loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
- rewriter.replaceOp(createMaskOp, result.getResult());
+ auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+ loc, dstTy, maskSize.getResult());
+ rewriter.replaceOp(createMaskOp, newMask);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 55fad7b1704c9..3ca2721dc1201 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -347,39 +347,18 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
}
// -----
-// ALL-LABEL: linearize_create_mask
-func.func @linearize_create_mask() -> vector<1x16xi1> {
+
+// CHECK-LABEL: linearize_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
+func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C10:.*]] = arith.constant 10 : index
- // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
- // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
- // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
- // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
- // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<16xi1>
- // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
- // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+ // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+ // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
+ // CHECK: %[[MULI:.*]] = arith.muli %[[INDEXCAST]], %[[ARG1]] : index
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1>
// CHECK: return %[[CAST]] : vector<1x16xi1>
- %c0 = arith.constant 0 : index
- %c10 = arith.constant 10 : index
- %0 = vector.create_mask %c0, %c10 : vector<1x16xi1>
+ %0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1>
return %0 : vector<1x16xi1>
}
-
-// -----
-// ALL-LABEL: linearize_scalable_create_mask
-func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> {
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C10:.*]] = arith.constant 10 : index
- // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
- // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
- // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
- // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
- // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<[16]xi1>
- // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1>
- // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1>
- // CHECK: return %[[CAST]] : vector<1x[16]xi1>
- %c0 = arith.constant 0 : index
- %c10 = arith.constant 10 : index
- %0 = vector.create_mask %c0, %c10 : vector<1x[16]xi1>
- return %0 : vector<1x[16]xi1>
-}
>From c5b2e81af1fed9be5eece1966127ffc4ff87af92 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 9 May 2025 04:08:32 +0000
Subject: [PATCH 7/7] Fix typo
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index e10483bd1a862..eb18891772e80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -455,7 +455,7 @@ struct LinearizeVectorSplat final
/// %index = arith.index_cast %cmpi : i1 to index
/// %mul = arith.muli %index, %arg1 : index
/// %mask = vector.create_mask %mul : vector<4xi1>
-/// %out_1d = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
+/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
struct LinearizeVectorCreateMask final
: OpConversionPattern<vector::CreateMaskOp> {
using OpConversionPattern::OpConversionPattern;
More information about the Mlir-commits
mailing list