[Mlir-commits] [mlir] [mlir] [vector] Add linearization pattern for vector.create_mask (PR #138214)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 1 16:06:57 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
This PR is part 3 of 4 of a breakdown of PR #<!-- -->136193.
It adds a linearization pattern for vector.create_mask.
---
Full diff: https://github.com/llvm/llvm-project/pull/138214.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+62-3)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+38)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..cdd937eed6569 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
}
};
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// For example,
+/// vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+/// %out_1d = vector.create_mask %dims : vector<4xi1>
+/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ // Compare the first operand with 0. If it's less than or equal to 0,
+ // create a zero mask, else strip the first operand and create a mask
+ // using the second operand.
+ auto firstOperand = adaptor.getOperands().front();
+ auto zero =
+ rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+ auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+ createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+ zero);
+ auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+ createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+ // Use a select operation to choose between the masks.
+ auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+ createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+ auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+ createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+ auto result = rewriter.create<mlir::arith::SelectOp>(
+ createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+
+ rewriter.replaceOp(createMaskOp, result.getResult());
+ return success();
+ }
+};
+
} // namespace
/// Return true if the operation `op` does not support scalable vectors and
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
void mlir::vector::populateVectorLinearizeBasePatterns(
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
- patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast, LinearizeVectorSplat>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+ LinearizeVectorSplat, LinearizeVectorCreateMask>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..01872426c77bb 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -345,3 +345,41 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
%0 = vector.splat %arg0 : vector<4x[2]xi32>
return %0 : vector<4x[2]xi32>
}
+
+// -----
+// ALL-LABEL: linearize_create_mask
+func.func @linearize_create_mask() -> vector<1x16xi1> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C20:.*]] = arith.constant 20 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+ // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+ // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+ // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+ // CHECK: return %[[CAST]] : vector<1x16xi1>
+ %c0 = arith.constant 0 : index
+ %c20 = arith.constant 20 : index
+ %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+ return %0 : vector<1x16xi1>
+}
+
+// -----
+// ALL-LABEL: linearize_scalable_create_mask
+func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C20:.*]] = arith.constant 20 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+ // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
+ // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1>
+ // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1>
+ // CHECK: return %[[CAST]] : vector<1x[16]xi1>
+ %c0 = arith.constant 0 : index
+ %c20 = arith.constant 20 : index
+ %0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1>
+ return %0 : vector<1x[16]xi1>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..2d5e90908d4d0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,7 +973,8 @@ struct TestVectorLinearize final
return "Linearizes ND vectors for N >= 2 into 1D vectors";
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect>();
+ registry.insert<vector::VectorDialect, memref::MemRefDialect,
+ arith::ArithDialect>();
}
void runOnOperation() override {
``````````
</details>
https://github.com/llvm/llvm-project/pull/138214
More information about the Mlir-commits
mailing list