[Mlir-commits] [mlir] 3a492ab - [mlir][vector] Add linearization pattern for vector.splat (#137651)
llvmlistbot at llvm.org
Thu May 1 14:26:29 PDT 2025
Author: Nishant Patel
Date: 2025-05-01T14:26:26-07:00
New Revision: 3a492abf05521467bc882094272d03c3eb6251c4
URL: https://github.com/llvm/llvm-project/commit/3a492abf05521467bc882094272d03c3eb6251c4
DIFF: https://github.com/llvm/llvm-project/commit/3a492abf05521467bc882094272d03c3eb6251c4.diff
LOG: [mlir][vector] Add linearization pattern for vector.splat (#137651)
This PR is part 2 of 4 of the breakdown of PR #136193.
It adds a linearization pattern for vector.splat and makes the vector.extract
linearization pattern bail out on scalar (non-vector) results.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/test/Dialect/Vector/linearize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9fdede535112..b9cef003fa365 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -293,6 +293,10 @@ struct LinearizeVectorExtract final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Skip if result is not a vector type
+ if (!isa<VectorType>(extractOp.getType()))
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalar extract is not supported.");
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
assert(dstTy && "expected 1-D vector type");
@@ -415,6 +419,32 @@ struct LinearizeVectorBitCast final
}
};
+/// This pattern converts a vector.splat so that it produces a linearized
+/// (1-D) vector. For example,
+/// vector.splat %value : vector<4x4xf32>
+/// is converted to:
+/// %out_1d = vector.splat %value : vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// Only the 1-D splat is created by this pattern; the shape_cast back to the
+/// original type is not built here (presumably it comes from the type
+/// converter's materialization — confirm). The pattern fails to match when
+/// the type converter cannot convert the splat's result type.
+struct LinearizeVectorSplat final
+ : public OpConversionPattern<vector::SplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Result type after linearization, as decided by the type converter.
+ auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+ // The splatted input is forwarded unchanged; only the result type differs.
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+ dstTy);
+ return success();
+ }
+};
+
} // namespace
/// Return true if the operation `op` does not support scalable vectors and
@@ -501,7 +531,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
+ // Every base pattern is constructed with the same type converter; `target`
+ // is not referenced in this body.
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast>(typeConverter, patterns.getContext());
+ LinearizeVectorBitCast, LinearizeVectorSplat>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 06eaf58b225ae..20169c15eb2c1 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -413,3 +413,37 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
return %1 : vector<[4]x4xf16>
}
+
+// -----
+// Fixed-size 2-D splat: depending on the run configuration it is either
+// rewritten to a 1-D splat followed by vector.shape_cast, or left unchanged.
+// ALL-LABEL: linearize_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+ // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+ // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+ // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+ // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+ // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+ // BW-128: return %[[CAST]] : vector<4x2xi32>
+
+ // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+ // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+ %0 = vector.splat %arg0 : vector<4x2xi32>
+ return %0 : vector<4x2xi32>
+}
+
+// -----
+// 2-D splat with a scalable trailing dimension: when linearized, 4x[2] becomes
+// a scalable [8] splat followed by vector.shape_cast; otherwise it is unchanged.
+// ALL-LABEL: linearize_scalable_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
+func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
+ // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
+ // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
+ // DEFAULT: return %[[CAST]] : vector<4x[2]xi32>
+ // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
+ // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
+ // BW-128: return %[[CAST]] : vector<4x[2]xi32>
+
+ // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x[2]xi32>
+ // BW-0: return %[[SPLAT]] : vector<4x[2]xi32>
+ %0 = vector.splat %arg0 : vector<4x[2]xi32>
+ return %0 : vector<4x[2]xi32>
+}
More information about the Mlir-commits mailing list