[Mlir-commits] [mlir] [mlir][vector] VectorLinearize: `ub.poison` support (PR #128612)
Ivan Butygin
llvmlistbot at llvm.org
Mon Feb 24 17:11:52 PST 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/128612
(No PR description provided.)
>From 489bcfadfa410f3d1eb5e9189e655d5db5686b7f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 25 Feb 2025 02:07:30 +0100
Subject: [PATCH 1/2] [mlir][vector] VectorLinearize: ub.poison support
---
.../Vector/Transforms/VectorLinearize.cpp | 38 +++++++++++++++++--
mlir/test/Dialect/Vector/linearize.mlir | 16 ++++++++
2 files changed, 50 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3ecd585c5a26d..65bd982319e45 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
@@ -97,6 +98,35 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
unsigned targetVectorBitWidth;
};
+/// Linearizes an n-D `ub.poison` op: the op is replaced by a new `ub.poison`
+/// whose type is the vector type produced by the type converter. Poison has no
+/// element values to rearrange, so no data movement is needed; any cast back to
+/// the original shape is left to the type converter's materializations.
+struct LinearizePoison final : OpConversionPattern<ub::PoisonOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// \p targetVectBitWidth is the limit passed to isLessThanTargetBitWidth;
+  /// only poison ops accepted by that helper are rewritten.
+  LinearizePoison(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // A null result means the converter did not yield a vector type for this
+    // poison (e.g. a non-vector result), so there is nothing to linearize.
+    auto resType = getTypeConverter()->convertType<VectorType>(op.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+
+    // Poison carries no payload: just emit a fresh poison of the flat type.
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, resType);
+    return success();
+  }
+
+private:
+  // In-class default so objects created through the inherited
+  // OpConversionPattern constructors (which never set this member) do not read
+  // an indeterminate value; mirrors the constructor's default argument.
+  unsigned targetVectorBitWidth = std::numeric_limits<unsigned>::max();
+};
+
struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;
@@ -525,7 +555,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
+ if ((isa<arith::ConstantOp, ub::PoisonOp, vector::BitCastOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
@@ -534,9 +564,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns
- .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ patterns.add<LinearizeConstant, LinearizePoison, LinearizeVectorizable,
+ LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+ targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..22d2cd452166b 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// -----
+// ALL-LABEL: test_linearize_poison
+func.func @test_linearize_poison(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+ // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+ // As in the default run: the 2x2 poison is recreated as a 1-D vector<4xf32>
+ // poison and shape_cast back to the original 2-D type.
+ // BW-128: %[[P:.*]] = ub.poison : vector<4xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+ // NOTE(review): presumably the zero target-bit-width run (RUN lines not shown);
+ // the op is not linearized and keeps its vector<2x2xf32> type.
+ // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
+ %0 = ub.poison : vector<2x2xf32>
+ // ALL: return %[[RES]] : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// -----
+
// ALL-LABEL: test_partial_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
>From df42371ba5d10921e017a490b05fbd8f4465a701 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 25 Feb 2025 02:10:57 +0100
Subject: [PATCH 2/2] Fix test: remove unused %arg0 from test_linearize_poison
---
mlir/test/Dialect/Vector/linearize.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 22d2cd452166b..f859ffd0e19d7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -33,7 +33,7 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// -----
// ALL-LABEL: test_linearize_poison
-func.func @test_linearize_poison(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+func.func @test_linearize_poison() -> vector<2x2xf32> {
// DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
More information about the Mlir-commits
mailing list