[Mlir-commits] [mlir] [mlir][vector] Add linearization pattern for vector.splat (PR #137651)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 28 08:35:00 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
This PR is part [2 / 4] of the breakdown of PR #136193.
The PR adds linearization patterns for vector.splat.
---
Full diff: https://github.com/llvm/llvm-project/pull/137651.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+53-10)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+17)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..45c7e37738898 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@
using namespace mlir;
+constexpr unsigned defaultTargetVectorBitWidth =
+ std::numeric_limits<unsigned>::max();
+
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
@@ -82,7 +85,7 @@ struct LinearizeConstantLike final
LinearizeConstantLike(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -136,7 +139,7 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -175,7 +178,7 @@ struct LinearizeVectorExtractStridedSlice final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -289,7 +292,7 @@ struct LinearizeVectorShuffle final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -362,13 +365,17 @@ struct LinearizeVectorExtract final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtract(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Skip if result is not a vector type
+ if (!isa<VectorType>(extractOp.getType()))
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalar extract is not supported.");
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(extractOp,
@@ -425,7 +432,7 @@ struct LinearizeVectorInsert final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorInsert(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -506,7 +513,7 @@ struct LinearizeVectorBitCast final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorBitCast(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -531,12 +538,48 @@ struct LinearizeVectorBitCast final
unsigned targetVectorBitWidth;
};
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// For example,
+/// vector.splat %value : vector<4x4xf32>
+/// is converted to:
+/// %out_1d = vector.splat %value : vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// The original operation is replaced with a new SplatOp that produces the
+/// converted (linearized) type. Compatibility with the target vector bit
+/// width is checked by the conversion target's legality callback.
+struct LinearizeVectorSplat final
+ : public OpConversionPattern<vector::SplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorSplat(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+ dstTy);
+ return success();
+ }
+
+private:
+ unsigned targetVectorBitWidth;
+};
+
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth) {
+ typeConverter.addConversion([](Type type) -> Type { return type; });
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
return type;
@@ -557,7 +600,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp>(op) ||
+ if ((isa<vector::BitCastOp, vector::SplatOp>(op) ||
op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,8 +611,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
});
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
- targetBitWidth);
+ LinearizeVectorBitCast, LinearizeVectorSplat>(
+ typeConverter, patterns.getContext(), targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..89f01abb79a74 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,20 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
return %1 : vector<[4]x4xf16>
}
+
+// -----
+// ALL-LABEL: linearize_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+ // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+ // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+ // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+ // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+ // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+ // BW-128: return %[[CAST]] : vector<4x2xi32>
+
+ // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+ // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+ %0 = vector.splat %arg0 : vector<4x2xi32>
+ return %0 : vector<4x2xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/137651
More information about the Mlir-commits
mailing list