[Mlir-commits] [mlir] [mlir][vector] Support index type in ND to 1D vector linearization (PR #118404)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 2 14:09:14 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Amy Zhuang (ayzhuang)
<details>
<summary>Changes</summary>
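Currently, the index type is not supported because `getElementTypeBitWidth` aborts for index types. This patch adds an `indexBitWidth` parameter to the vector linearization patterns so that vectors of index type can be linearized when the index bit width is supplied.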
---
Full diff: https://github.com/llvm/llvm-project/pull/118404.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+2-2)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+59-26)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+12-3)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..e3c19a078c18b0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..f0bf6276f0e659 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -25,34 +25,44 @@
using namespace mlir;
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
+ unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
+ if (!vecType)
+ return false;
+ bool isIndexTy = vecType.getElementType().isIndex();
+ // Reject index if `indexBitWidth` is not supplied.
+ if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ vecType.getShape().back() *
+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
if (trailingVecDimBitWidth >= targetBitWidth)
return false;
}
return true;
}
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
+ unsigned targetBitWidth) {
VectorType vecType = dyn_cast<VectorType>(t);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
+ if (!vecType)
+ return false;
+ bool isIndexTy = vecType.getElementType().isIndex();
+ // Reject index if `indexBitWidth` is not supplied.
+ if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ vecType.getShape().back() *
+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
return trailingVecDimBitWidth <= targetBitWidth;
}
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeConstant(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
- if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
shuffleOp.getV2VectorType().isScalable() ||
dstType.isScalable()) &&
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtract(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorInsert(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
"scalable vectors are not supported.");
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
- targetVectorBitWidth))
+ indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
insertOp, "Can't flatten since targetBitWidth < OpSize");
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth) {
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
@@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
[=](Operation *op) -> std::optional<bool> {
if ((isa<arith::ConstantOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
- return (isLessThanTargetBitWidth(op, targetBitWidth)
+ return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
? typeConverter.isLegal(op)
: true);
}
@@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
});
patterns.add<LinearizeConstant, LinearizeVectorizable>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned int targetBitWidth) {
+ ConversionTarget &target, unsigned indexBitWidth,
+ unsigned int targetBitWidth) {
target.addDynamicallyLegalOp<vector::ShuffleOp>(
[=](vector::ShuffleOp shuffleOp) -> bool {
- return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+ return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+ targetBitWidth)
? (typeConverter.isLegal(shuffleOp) &&
cast<mlir::VectorType>(shuffleOp.getResult().getType())
.getRank() == 1)
@@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
});
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..fe169d3e16d683 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
// ALL-LABEL: test_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// -----
-// ALL-LABEL: test_index_no_linearize
-func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
- // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+// ALL-LABEL: test_index_linearize
+func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+ // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
// ALL: return %[[RES]] : vector<2x[2]xf32>
return %2 : vector<2x[2]xf32>
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f67a24755ac09a..2589782aee1449 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}
+ Option<unsigned> indexBitwidth{*this, "index-bitwidth",
+ llvm::cl::desc("Bitwidth of the index type"),
+ llvm::cl::init(0)};
+
Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
ConversionTarget target(*context);
vector::populateVectorLinearizeTypeConversionsAndLegality(
- typeConverter, patterns, target, targetVectorBitwidth);
+ typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- typeConverter, patterns, target, targetVectorBitwidth);
+ typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
``````````
</details>
https://github.com/llvm/llvm-project/pull/118404
More information about the Mlir-commits mailing list