[Mlir-commits] [mlir] [mlir][vector] Support index type in ND to 1D vector linearization (PR #118404)
Amy Zhuang
llvmlistbot at llvm.org
Mon Dec 2 14:08:38 PST 2024
https://github.com/ayzhuang created https://github.com/llvm/llvm-project/pull/118404
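Currently, vectors with `index` element type are not linearized, because `getElementTypeBitWidth` aborts for the index type. This patch adds an `indexBitWidth` parameter to the vector linearization patterns so the bit width of `index` can be supplied explicitly; passing 0 (the default) keeps the previous behavior of rejecting index-typed vectors.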
>From d82a741ae91b73dff02dbd19c65e0fd7c6198c9e Mon Sep 17 00:00:00 2001
From: Amy Zhuang <amy.zhuang at intel.com>
Date: Mon, 2 Dec 2024 23:54:21 +0200
Subject: [PATCH] [mlir][vector] Support index type in ND to 1D vector
linearization
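Currently, vectors with index element type are not linearized, because
getElementTypeBitWidth aborts for the index type. This patch adds an
indexBitWidth parameter to the vector linearization patterns so the bit
width of index can be supplied explicitly; passing 0 keeps the previous
behavior of rejecting index-typed vectors.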
---
.../Vector/Transforms/VectorRewritePatterns.h | 4 +-
.../Vector/Transforms/VectorLinearize.cpp | 85 +++++++++++++------
mlir/test/Dialect/Vector/linearize.mlir | 15 +++-
.../Dialect/Vector/TestVectorTransforms.cpp | 8 +-
4 files changed, 79 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..e3c19a078c18b0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..f0bf6276f0e659 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -25,34 +25,44 @@
using namespace mlir;
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
+ unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
+ if (!vecType)
+ return false;
+ bool isIndexTy = vecType.getElementType().isIndex();
+ // Reject index if `indexBitWidth` is not supplied.
+ if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ vecType.getShape().back() *
+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
if (trailingVecDimBitWidth >= targetBitWidth)
return false;
}
return true;
}
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
+ unsigned targetBitWidth) {
VectorType vecType = dyn_cast<VectorType>(t);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
+ if (!vecType)
+ return false;
+ bool isIndexTy = vecType.getElementType().isIndex();
+ // Reject index if `indexBitWidth` is not supplied.
+ if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ vecType.getShape().back() *
+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
return trailingVecDimBitWidth <= targetBitWidth;
}
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeConstant(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
- if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
shuffleOp.getV2VectorType().isScalable() ||
dstType.isScalable()) &&
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtract(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorInsert(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
"scalable vectors are not supported.");
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
- targetVectorBitWidth))
+ indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
insertOp, "Can't flatten since targetBitWidth < OpSize");
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth) {
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
@@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
[=](Operation *op) -> std::optional<bool> {
if ((isa<arith::ConstantOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
- return (isLessThanTargetBitWidth(op, targetBitWidth)
+ return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
? typeConverter.isLegal(op)
: true);
}
@@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
});
patterns.add<LinearizeConstant, LinearizeVectorizable>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned int targetBitWidth) {
+ ConversionTarget &target, unsigned indexBitWidth,
+ unsigned int targetBitWidth) {
target.addDynamicallyLegalOp<vector::ShuffleOp>(
[=](vector::ShuffleOp shuffleOp) -> bool {
- return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+ return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+ targetBitWidth)
? (typeConverter.isLegal(shuffleOp) &&
cast<mlir::VectorType>(shuffleOp.getResult().getType())
.getRank() == 1)
@@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
});
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..fe169d3e16d683 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
// ALL-LABEL: test_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// -----
-// ALL-LABEL: test_index_no_linearize
-func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
- // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+// ALL-LABEL: test_index_linearize
+func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+ // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
// ALL: return %[[RES]] : vector<2x[2]xf32>
return %2 : vector<2x[2]xf32>
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f67a24755ac09a..2589782aee1449 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}
+ Option<unsigned> indexBitwidth{*this, "index-bitwidth",
+ llvm::cl::desc("Bitwidth of the index type"),
+ llvm::cl::init(0)};
+
Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
ConversionTarget target(*context);
vector::populateVectorLinearizeTypeConversionsAndLegality(
- typeConverter, patterns, target, targetVectorBitwidth);
+ typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- typeConverter, patterns, target, targetVectorBitwidth);
+ typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits
mailing list