[Mlir-commits] [mlir] Added a flag to enable flattening of Constants and Vectors based on user specified vector-lengths (PR #83314)
Balaji V. Iyer.
llvmlistbot at llvm.org
Wed Feb 28 15:05:38 PST 2024
https://github.com/bviyer updated https://github.com/llvm/llvm-project/pull/83314
>From 457ff26ae0c7dbc6902909c209ddb50a0b8e4392 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Mon, 26 Feb 2024 16:40:06 +0000
Subject: [PATCH 1/2] Added a flag to enable flattening of Constants and
Vectors based on user specified vector-lengths.
---
.../Vector/Transforms/VectorRewritePatterns.h | 2 +-
.../Vector/Transforms/VectorLinearize.cpp | 54 ++++++++++++++++---
.../Dialect/Vector/TestVectorTransforms.cpp | 12 ++++-
3 files changed, 58 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 46bb3ddec0baf6..1b634edbcb7cac 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -387,7 +387,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target);
+ ConversionTarget &target, unsigned targBitWidth);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c5352043955579..a460cfec7325d8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -19,10 +19,27 @@
using namespace mlir;
+static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+ auto resultTypes = op->getResultTypes();
+ for (auto resType : resultTypes) {
+ VectorType vecType = cast<VectorType>(resType);
+ unsigned trailingVecDimBitWidth =
+ vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ if (trailingVecDimBitWidth >= targetBitWidth)
+ return false;
+ }
+ return true;
+}
+
namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
-
+ LinearizeConstant(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -31,7 +48,9 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
getTypeConverter()->convertType<VectorType>(constOp.getType());
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
-
+ if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ loc, "Can't flatten since targetBitWidth <= OpSize");
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
if (!dstElementsAttr)
return rewriter.notifyMatchFailure(loc, "unsupported attr type");
@@ -41,15 +60,28 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
dstElementsAttr);
return success();
}
+
+private:
+ unsigned targetVectorBitWidth;
};
struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;
+public:
+ LinearizeVectorizable(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
if (failed(newOp))
@@ -58,12 +90,19 @@ struct LinearizeVectorizable final
rewriter.replaceOp(op, (*newOp)->getResults());
return success();
}
+
+private:
+ unsigned targetVectorBitWidth;
};
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target) {
+ ConversionTarget &target, unsigned targetBitWidth) {
+
+ // Can't pass a paramter into lambda function directory. So need to store
+ // it in a local variable.
+ unsigned targBitWidth = targetBitWidth;
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
// Ignore scalable vectors for now.
if (type.getRank() <= 1 || type.isScalable())
@@ -83,15 +122,16 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addArgumentMaterialization(materializeCast);
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
-
target.markUnknownOpDynamicallyLegal(
[&](Operation *op) -> std::optional<bool> {
- if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+ if ((isa<arith::ConstantOp>(op) ||
+ op->hasTrait<OpTrait::Vectorizable>()) &&
+ isLessThanTargetBitWidth(op, targBitWidth))
return typeConverter.isLegal(op);
return std::nullopt;
});
- patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
- patterns.getContext());
+ patterns.add<LinearizeConstant, LinearizeVectorizable>(
+ typeConverter, patterns.getContext(), targBitWidth);
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b246..74d2dfa44f4fe9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -842,6 +842,9 @@ struct TestVectorLinearize final
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+ TestVectorLinearize() = default;
+ TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
StringRef getArgument() const override { return "test-vector-linearize"; }
StringRef getDescription() const override {
return "Linearizes ND vectors for N >= 2 into 1D vectors";
@@ -850,6 +853,11 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}
+ Option<unsigned> targetVectorBitwidth{
+ *this, "target-vector-bitwidth",
+ llvm::cl::desc(
+ "Minimum vector bitwidth to enable the flattening transformation"),
+ llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
auto *context = &getContext();
@@ -857,8 +865,8 @@ struct TestVectorLinearize final
RewritePatternSet patterns(context);
ConversionTarget target(*context);
- vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
- patterns, target);
+ vector::populateVectorLinearizeTypeConversionsAndLegality(
+ typeConverter, patterns, target, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
>From bff51c559c7d40608b6aab3b7880e5fb69a1a2a7 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 28 Feb 2024 23:04:58 +0000
Subject: [PATCH 2/2] Fixed a couple issues and added a test.
---
.../Vector/Transforms/VectorLinearize.cpp | 16 +++++++---------
mlir/test/Dialect/Vector/linearize.mlir | 8 ++++++++
2 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a460cfec7325d8..28a7de22954f99 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -100,9 +100,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth) {
- // Can't pass a paramter into lambda function directory. So need to store
- // it in a local variable.
- unsigned targBitWidth = targetBitWidth;
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
// Ignore scalable vectors for now.
if (type.getRank() <= 1 || type.isScalable())
@@ -123,15 +120,16 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
- [&](Operation *op) -> std::optional<bool> {
+ [=](Operation *op) -> std::optional<bool> {
if ((isa<arith::ConstantOp>(op) ||
- op->hasTrait<OpTrait::Vectorizable>()) &&
- isLessThanTargetBitWidth(op, targBitWidth))
- return typeConverter.isLegal(op);
-
+ op->hasTrait<OpTrait::Vectorizable>())) {
+ return (isLessThanTargetBitWidth(op, targetBitWidth)
+ ? typeConverter.isLegal(op)
+ : true);
+ }
return std::nullopt;
});
patterns.add<LinearizeConstant, LinearizeVectorizable>(
- typeConverter, patterns.getContext(), targBitWidth);
+ typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 85e23103eaedb7..659bb021846d89 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,17 +1,25 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=12 | FileCheck %s --check-prefix=CHECK12
// CHECK-LABEL: test_linearize
+// CHECK12-LABEL: test_linearize
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+// CHECK12: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
// Arith and math ops are handled in generic way, check some of them
// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
+// CHECK12: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32>
%1 = math.sin %arg0 : vector<2x2xf32>
// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+// CHECK12: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
+
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
// CHECK: return %[[RES]] : vector<2x2xf32>
More information about the Mlir-commits
mailing list