[Mlir-commits] [flang] [mlir] [MLIR][Vector] Remove vector.splat (PR #162167)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 6 15:53:05 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: James Newling (newling)
<details>
<summary>Changes</summary>
vector.splat has been deprecated (use the very similar vector.broadcast instead) with the last PR landing about 6 weeks ago. See https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/1
This PR completely removes vector.splat. In addition to removing vector.splat from VectorOps.td, it:
- Updates the few remaining places where vector::SplatOp is created (now vector::BroadcastOp is created)
- Removes temporary patterns where vector.splat is replaced by vector.broadcast
The only places where 'vector.splat' still appears are now the files
https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/1 and
https://github.com/llvm/llvm-project/blob/main/mlir/utils/tree-sitter-mlir/dialect/vector.js
@<!-- -->artagnon maybe you can tell me what these files are, and if they should be updated?
---
Patch is 50.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162167.diff
28 Files Affected:
- (modified) flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp (+10-9)
- (modified) mlir/docs/Dialects/Vector.md (+1-3)
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (-47)
- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (-4)
- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+1-1)
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+4-18)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+1-14)
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+4-19)
- (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1-2)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-46)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+3-29)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+20-27)
- (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+1-1)
- (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (-32)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (-17)
- (modified) mlir/test/Dialect/Math/canonicalize_ipowi.mlir (+2-2)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+1-1)
- (removed) mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir (-126)
- (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+1-1)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+10-33)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (-27)
- (modified) mlir/test/Dialect/Vector/ops.mlir (+7-29)
- (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+2-2)
- (modified) mlir/test/IR/invalid-ops.mlir (+1-1)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir (-7)
- (modified) mlir/test/mlir-runner/utils.mlir (+1-1)
- (modified) mlir/utils/tree-sitter-mlir/queries/highlights.scm (-1)
``````````diff
diff --git a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
index 03952da95b11e..265e268ab1b09 100644
--- a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
@@ -2383,7 +2383,7 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
auto context{builder.getContext()};
auto argBases{getBasesForArgs(args)};
- mlir::vector::SplatOp splatOp{nullptr};
+ mlir::vector::BroadcastOp splatOp{nullptr};
mlir::Type retTy{nullptr};
switch (vop) {
case VecOp::Splat: {
@@ -2391,9 +2391,9 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
auto vecTyInfo{getVecTypeFromFir(argBases[0])};
auto extractOp{genVecExtract(resultType, args)};
- splatOp =
- mlir::vector::SplatOp::create(builder, loc, *(extractOp.getUnboxed()),
- vecTyInfo.toMlirVectorType(context));
+ splatOp = mlir::vector::BroadcastOp::create(
+ builder, loc, vecTyInfo.toMlirVectorType(context),
+ *(extractOp.getUnboxed()));
retTy = vecTyInfo.toFirVectorType();
break;
}
@@ -2401,8 +2401,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
assert(args.size() == 1);
auto vecTyInfo{getVecTypeFromEle(argBases[0])};
- splatOp = mlir::vector::SplatOp::create(
- builder, loc, argBases[0], vecTyInfo.toMlirVectorType(context));
+ splatOp = mlir::vector::BroadcastOp::create(
+ builder, loc, vecTyInfo.toMlirVectorType(context), argBases[0]);
retTy = vecTyInfo.toFirVectorType();
break;
}
@@ -2412,8 +2412,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
auto intOp{builder.createConvert(loc, eleTy, argBases[0])};
// the intrinsic always returns vector(integer(4))
- splatOp = mlir::vector::SplatOp::create(builder, loc, intOp,
- mlir::VectorType::get(4, eleTy));
+ splatOp = mlir::vector::BroadcastOp::create(
+ builder, loc, mlir::VectorType::get(4, eleTy), intOp);
retTy = fir::VectorType::get(4, eleTy);
break;
}
@@ -2444,7 +2444,8 @@ PPCIntrinsicLibrary::genVecXlds(mlir::Type resultType,
auto addrConv{fir::ConvertOp::create(builder, loc, i64RefTy, addr)};
auto addrVal{fir::LoadOp::create(builder, loc, addrConv)};
- auto splatRes{mlir::vector::SplatOp::create(builder, loc, addrVal, i64VecTy)};
+ auto splatRes{
+ mlir::vector::BroadcastOp::create(builder, loc, i64VecTy, addrVal)};
mlir::Value result{nullptr};
if (mlirTy != splatRes.getType()) {
diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md
index 6c8949d70b4a3..839dc75ff0214 100644
--- a/mlir/docs/Dialects/Vector.md
+++ b/mlir/docs/Dialects/Vector.md
@@ -125,7 +125,7 @@ Some existing Arith and Vector Dialect on `n-D` `vector` types comprise:
// Produces a vector<3x7x8xf32>
%b = arith.mulf %0, %1 : vector<3x7x8xf32>
// Produces a vector<3x7x8xf32>
-%c = vector.splat %1 : vector<3x7x8xf32>
+%c = vector.broadcast %1 : f32 to vector<3x7x8xf32>
%d = vector.extract %0[1]: vector<7x8xf32> from vector<3x7x8xf32>
%e = vector.extract %0[1, 5]: vector<8xf32> from vector<3x7x8xf32>
@@ -176,8 +176,6 @@ infrastructure can apply iteratively.
### Virtual Vector to Hardware Vector Lowering
For now, `VV -> HWV` are specified in C++ (see for instance the
-[SplatOpLowering for n-D vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d)
-or the
[VectorOuterProductOp lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)).
Simple
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..41e075467910f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2880,53 +2880,6 @@ def Vector_PrintOp :
}];
}
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-def Vector_SplatOp : Vector_Op<"splat", [
- Pure,
- DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
- TypesMatchWith<"operand type matches element type of result",
- "aggregate", "input",
- "::llvm::cast<VectorType>($_self).getElementType()">
- ]> {
- let summary = "vector splat or broadcast operation";
- let description = [{
- Note: This operation is deprecated. Please use vector.broadcast.
-
- Broadcast the operand to all elements of the result vector. The type of the
- operand must match the element type of the vector type.
-
- Example:
-
- ```mlir
- %s = arith.constant 10.1 : f32
- %t = vector.splat %s : vector<8x16xf32>
- ```
-
- This operation is deprecated, the preferred representation of the above is:
-
- ```mlir
- %s = arith.constant 10.1 : f32
- %t = vector.broadcast %s : f32 to vector<8x16xf32>
- ```
- }];
-
- let arguments = (ins AnyType:$input);
- let results = (outs AnyVectorOfAnyRank:$aggregate);
-
- let builders = [
- OpBuilder<(ins "Value":$element, "Type":$aggregateType),
- [{ build($_builder, $_state, aggregateType, element); }]>];
- let assemblyFormat = "$input attr-dict `:` type($aggregate)";
-
- let hasFolder = 1;
-
- // vector.splat is deprecated, and vector.broadcast should be used instead.
- // Canonicalize vector.splat to vector.broadcast.
- let hasCanonicalizer = 1;
-}
//===----------------------------------------------------------------------===//
// VectorScaleOp
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index dcbaa5698d767..247dba101cfc1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
- .Case<vector::SplatOp>([&current](auto op) {
- current = op.getInput();
- return false;
- })
.Default([](Operation *) { return false; });
if (!skipOp) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index bad53c0a4a97a..1002ebe6875b6 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// AFTER:
/// ```mlir
/// ...
-/// %pad_1d = vector.splat %pad : vector<[4]xi32>
+/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 363685a691180..778c616f1bf44 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
}
};
-// Convert all `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to ArmSME via another pattern.
-struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
- using Base::Base;
-
- LogicalResult matchAndRewrite(vector::SplatOp splatOp,
- PatternRewriter &rewriter) const final {
-
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
- splatOp.getInput());
- return success();
- }
-};
-
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
- TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
- TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
- VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
+ TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+ VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+ VectorOuterProductToArmSMELowering,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
ExtractFromCreateMaskToPselLowering>(&ctx);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 546164628b795..5355909b62a7f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2161,19 +2161,6 @@ class TransposeOpToMatrixTransposeOpLowering
}
};
-/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
-/// `vector.broadcast` through other patterns.
-struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
- adaptor.getInput());
- return success();
- }
-};
-
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+ VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 311ff6f5fbeee..56e8fee191432 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -22,7 +22,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};
-// Convert `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to SPIRV via other patterns.
-struct VectorSplatToBroadcast final
- : public OpConversionPattern<vector::SplatOp> {
- using Base::Base;
- LogicalResult
- matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
- adaptor.getInput());
- return success();
- }
-};
-
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using Base::Base;
@@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
- VectorShuffleOpConvert, VectorInterleaveOpConvert,
- VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
- VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+ VectorScalarBroadcastPattern, VectorLoadOpConverter,
+ VectorStoreOpConverter, VectorStepOpConvert>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index c64e10f534f8e..d018cddeb8dc1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
- arith::ConstantOp, arith::SelectOp, vector::SplatOp,
- vector::BroadcastOp>();
+ arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
}
void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..dc58ac3cdee6f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1664,10 +1664,10 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
-/// 1s, are considered to be 'broadcastlike'.
+/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
+/// considered to be 'broadcastlike'.
static bool isBroadcastLike(Operation *op) {
- if (isa<BroadcastOp, SplatOp>(op))
+ if (isa<BroadcastOp>(op))
return true;
auto shapeCast = dyn_cast<ShapeCastOp>(op);
@@ -3131,12 +3131,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
};
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
-/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
-/// value that is splatted. Otherwise return null.
+/// vector.broadcast with a scalar operand, return the scalar value that is
+/// splatted. Otherwise return null.
///
-/// Examples:
+/// Example:
///
-/// scalar_source --> vector.splat --> value - return scalar_source
/// scalar_source --> vector.broadcast --> value - return scalar_source
static Value getScalarSplatSource(Value value) {
// Block argument:
@@ -3144,10 +3143,6 @@ static Value getScalarSplatSource(Value value) {
if (!defOp)
return {};
- // Splat:
- if (auto splat = dyn_cast<vector::SplatOp>(defOp))
- return splat.getInput();
-
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
// Not broadcast (and not splat):
@@ -7393,41 +7388,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
patterns.getContext(), benefit);
}
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
- auto constOperand = adaptor.getInput();
- if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
- return {};
-
- // SplatElementsAttr::get treats single value for second arg as being a splat.
- return SplatElementsAttr::get(getType(), {constOperand});
-}
-
-// Canonicalizer for vector.splat. It always gets canonicalized to a
-// vector.broadcast.
-class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
-public:
- using Base::Base;
- LogicalResult matchAndRewrite(SplatOp splatOp,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
- splatOp.getOperand());
- return success();
- }
-};
-void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<SplatToBroadcastPattern>(context);
-}
-
-void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
arith::FastMathFlagsAttr fastmath,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 255f2bf5a8161..3a3231d513369 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
- // TODO: add support to `vector.splat`.
+ // TODO: add support to `vector.broadcast`.
// Finding the mask creation operation.
while (maskOp &&
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 71fba71c9f15f..1b656d82f3201 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
}
};
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-/// vector.splat %value : vector<4x4xf32>
-/// is converted to:
-/// %out_1d = vector.splat %value : vector<16xf32>
-/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
-struct LinearizeVectorSplat final
- : public OpConversionPattern<vector::SplatOp> {
- using Base::Base;
-
- LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit) {}
-
- LogicalResult
- matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto dstTy = getTypeConverter()->convertType(splatOp.getType());
- if (!dstTy)
- return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
- rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
- dstTy);
- return success();
- }
-};
-
/// This pattern converts the CreateMaskOp to work on a linearized vector.
/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
@@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements,
- LinearizeVectorToElements>(typeConverter, patterns.getContext());
+ LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
+ LinearizeVectorFromElements, LinearizeVectorToElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/162167
More information about the Mlir-commits
mailing list