[Mlir-commits] [mlir] [MLIR][Vector] Remove vector.splat (PR #162167)
James Newling
llvmlistbot at llvm.org
Mon Oct 6 14:50:35 PDT 2025
https://github.com/newling created https://github.com/llvm/llvm-project/pull/162167
None
>From 847ce9666291aa8c7c957e6131af9119ce54fcbc Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 6 Oct 2025 14:52:34 -0700
Subject: [PATCH] complete removal of vector.splat
---
mlir/docs/Dialects/Vector.md | 2 +-
.../mlir/Dialect/Vector/IR/VectorOps.td | 47 -------
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 4 -
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 2 +-
.../VectorToArmSME/VectorToArmSME.cpp | 22 +--
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 15 +--
.../VectorToSPIRV/VectorToSPIRV.cpp | 23 +---
.../Transforms/EmulateUnsupportedFloats.cpp | 3 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 52 +-------
.../Transforms/VectorEmulateNarrowType.cpp | 2 +-
.../Vector/Transforms/VectorLinearize.cpp | 32 +----
.../Vector/Transforms/VectorTransforms.cpp | 47 +++----
.../Conversion/ConvertToSPIRV/vector.mlir | 2 +-
.../VectorToArmSME/vector-to-arm-sme.mlir | 32 -----
.../vector-to-llvm-interface.mlir | 17 ---
.../test/Dialect/Math/canonicalize_ipowi.mlir | 4 +-
mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
.../Vector/canonicalize/vector-splat.mlir | 126 ------------------
.../Dialect/Vector/int-range-interface.mlir | 2 +-
mlir/test/Dialect/Vector/invalid.mlir | 43 ++----
mlir/test/Dialect/Vector/linearize.mlir | 27 ----
mlir/test/Dialect/Vector/ops.mlir | 36 +----
.../vector-emulate-masked-load-store.mlir | 4 +-
mlir/test/IR/invalid-ops.mlir | 2 +-
.../Dialect/Vector/CPU/0-d-vectors.mlir | 7 -
mlir/test/mlir-runner/utils.mlir | 2 +-
.../tree-sitter-mlir/queries/highlights.scm | 1 -
27 files changed, 68 insertions(+), 490 deletions(-)
delete mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md
index 6c8949d70b4a3..ccf7130dbfd7d 100644
--- a/mlir/docs/Dialects/Vector.md
+++ b/mlir/docs/Dialects/Vector.md
@@ -125,7 +125,7 @@ Some existing Arith and Vector Dialect on `n-D` `vector` types comprise:
// Produces a vector<3x7x8xf32>
%b = arith.mulf %0, %1 : vector<3x7x8xf32>
// Produces a vector<3x7x8xf32>
-%c = vector.splat %1 : vector<3x7x8xf32>
+%c = vector.broadcast %1 : f32 to vector<3x7x8xf32>
%d = vector.extract %0[1]: vector<7x8xf32> from vector<3x7x8xf32>
%e = vector.extract %0[1, 5]: vector<8xf32> from vector<3x7x8xf32>
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..41e075467910f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2880,53 +2880,6 @@ def Vector_PrintOp :
}];
}
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-def Vector_SplatOp : Vector_Op<"splat", [
- Pure,
- DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
- TypesMatchWith<"operand type matches element type of result",
- "aggregate", "input",
- "::llvm::cast<VectorType>($_self).getElementType()">
- ]> {
- let summary = "vector splat or broadcast operation";
- let description = [{
- Note: This operation is deprecated. Please use vector.broadcast.
-
- Broadcast the operand to all elements of the result vector. The type of the
- operand must match the element type of the vector type.
-
- Example:
-
- ```mlir
- %s = arith.constant 10.1 : f32
- %t = vector.splat %s : vector<8x16xf32>
- ```
-
- This operation is deprecated, the preferred representation of the above is:
-
- ```mlir
- %s = arith.constant 10.1 : f32
- %t = vector.broadcast %s : f32 to vector<8x16xf32>
- ```
- }];
-
- let arguments = (ins AnyType:$input);
- let results = (outs AnyVectorOfAnyRank:$aggregate);
-
- let builders = [
- OpBuilder<(ins "Value":$element, "Type":$aggregateType),
- [{ build($_builder, $_state, aggregateType, element); }]>];
- let assemblyFormat = "$input attr-dict `:` type($aggregate)";
-
- let hasFolder = 1;
-
- // vector.splat is deprecated, and vector.broadcast should be used instead.
- // Canonicalize vector.splat to vector.broadcast.
- let hasCanonicalizer = 1;
-}
//===----------------------------------------------------------------------===//
// VectorScaleOp
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index dcbaa5698d767..247dba101cfc1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
- .Case<vector::SplatOp>([&current](auto op) {
- current = op.getInput();
- return false;
- })
.Default([](Operation *) { return false; });
if (!skipOp) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index bad53c0a4a97a..1002ebe6875b6 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// AFTER:
/// ```mlir
/// ...
-/// %pad_1d = vector.splat %pad : vector<[4]xi32>
+/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 363685a691180..778c616f1bf44 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
}
};
-// Convert all `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to ArmSME via another pattern.
-struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
- using Base::Base;
-
- LogicalResult matchAndRewrite(vector::SplatOp splatOp,
- PatternRewriter &rewriter) const final {
-
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
- splatOp.getInput());
- return success();
- }
-};
-
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
- TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
- TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
- VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
+ TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+ VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+ VectorOuterProductToArmSMELowering,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
ExtractFromCreateMaskToPselLowering>(&ctx);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 546164628b795..5355909b62a7f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2161,19 +2161,6 @@ class TransposeOpToMatrixTransposeOpLowering
}
};
-/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
-/// `vector.broadcast` through other patterns.
-struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
- adaptor.getInput());
- return success();
- }
-};
-
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+ VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 311ff6f5fbeee..56e8fee191432 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -22,7 +22,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};
-// Convert `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to SPIRV via other patterns.
-struct VectorSplatToBroadcast final
- : public OpConversionPattern<vector::SplatOp> {
- using Base::Base;
- LogicalResult
- matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
- adaptor.getInput());
- return success();
- }
-};
-
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using Base::Base;
@@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
- VectorShuffleOpConvert, VectorInterleaveOpConvert,
- VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
- VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+ VectorScalarBroadcastPattern, VectorLoadOpConverter,
+ VectorStoreOpConverter, VectorStepOpConvert>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index c64e10f534f8e..d018cddeb8dc1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
- arith::ConstantOp, arith::SelectOp, vector::SplatOp,
- vector::BroadcastOp>();
+ arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
}
void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..dc58ac3cdee6f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1664,10 +1664,10 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
-/// 1s, are considered to be 'broadcastlike'.
+/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
+/// considered to be 'broadcastlike'.
static bool isBroadcastLike(Operation *op) {
- if (isa<BroadcastOp, SplatOp>(op))
+ if (isa<BroadcastOp>(op))
return true;
auto shapeCast = dyn_cast<ShapeCastOp>(op);
@@ -3131,12 +3131,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
};
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
-/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
-/// value that is splatted. Otherwise return null.
+/// vector.broadcast with a scalar operand, return the scalar value that is
+/// splatted. Otherwise return null.
///
-/// Examples:
+/// Example:
///
-/// scalar_source --> vector.splat --> value - return scalar_source
/// scalar_source --> vector.broadcast --> value - return scalar_source
static Value getScalarSplatSource(Value value) {
// Block argument:
@@ -3144,10 +3143,6 @@ static Value getScalarSplatSource(Value value) {
if (!defOp)
return {};
- // Splat:
- if (auto splat = dyn_cast<vector::SplatOp>(defOp))
- return splat.getInput();
-
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
// Not broadcast (and not splat):
@@ -7393,41 +7388,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
patterns.getContext(), benefit);
}
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
- auto constOperand = adaptor.getInput();
- if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
- return {};
-
- // SplatElementsAttr::get treats single value for second arg as being a splat.
- return SplatElementsAttr::get(getType(), {constOperand});
-}
-
-// Canonicalizer for vector.splat. It always gets canonicalized to a
-// vector.broadcast.
-class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
-public:
- using Base::Base;
- LogicalResult matchAndRewrite(SplatOp splatOp,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
- splatOp.getOperand());
- return success();
- }
-};
-void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<SplatToBroadcastPattern>(context);
-}
-
-void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
arith::FastMathFlagsAttr fastmath,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 255f2bf5a8161..3a3231d513369 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
- // TODO: add support to `vector.splat`.
+ // TODO: add support to `vector.broadcast`.
// Finding the mask creation operation.
while (maskOp &&
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 71fba71c9f15f..1b656d82f3201 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
}
};
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-/// vector.splat %value : vector<4x4xf32>
-/// is converted to:
-/// %out_1d = vector.splat %value : vector<16xf32>
-/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
-struct LinearizeVectorSplat final
- : public OpConversionPattern<vector::SplatOp> {
- using Base::Base;
-
- LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit) {}
-
- LogicalResult
- matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto dstTy = getTypeConverter()->convertType(splatOp.getType());
- if (!dstTy)
- return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
- rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
- dstTy);
- return success();
- }
-};
-
/// This pattern converts the CreateMaskOp to work on a linearized vector.
/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
@@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements,
- LinearizeVectorToElements>(typeConverter, patterns.getContext());
+ LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
+ LinearizeVectorFromElements, LinearizeVectorToElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d6a6d7cdba673..726da1e9a3d14 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -878,7 +878,7 @@ struct BubbleUpBitCastForStridedSliceInsert
// This transforms IR like:
// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
-// %cst = vector.splat %c0_f32 : vector<4xf32>
+// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
// %1 = vector.extract_strided_slice %0 {
// offsets = [0], sizes = [4], strides = [1]
// } : vector<8xf16> to vector<4xf16>
@@ -987,8 +987,8 @@ static Type cloneOrReplace(Type type, Type newElementType) {
return newElementType;
}
-/// If `value` is the result of a splat or broadcast operation, return the input
-/// of the splat/broadcast operation.
+/// If `value` is the result of a broadcast operation, return the input
+/// of the broadcast operation.
static Value getBroadcastLikeSource(Value value) {
Operation *op = value.getDefiningOp();
@@ -998,13 +998,10 @@ static Value getBroadcastLikeSource(Value value) {
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
return broadcast.getSource();
- if (auto splat = dyn_cast<vector::SplatOp>(op))
- return splat.getInput();
-
return {};
}
-/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
///
/// Example:
/// ```
@@ -1017,9 +1014,6 @@ static Value getBroadcastLikeSource(Value value) {
/// %r = arith.addi %arg0, %arg1 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
-///
-/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
-/// ops.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -1045,29 +1039,29 @@ struct ReorderElementwiseOpsOnBroadcast final
Type resultElemType = resultType.getElementType();
// Get the type of the first non-constant operand
- Value splatSource;
+ Value broadcastSource;
for (Value operand : op->getOperands()) {
Operation *definingOp = operand.getDefiningOp();
if (!definingOp)
return failure();
if (definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
- splatSource = getBroadcastLikeSource(operand);
+ broadcastSource = getBroadcastLikeSource(operand);
break;
}
- if (!splatSource)
+ if (!broadcastSource)
return failure();
Type unbroadcastResultType =
- cloneOrReplace(splatSource.getType(), resultElemType);
+ cloneOrReplace(broadcastSource.getType(), resultElemType);
// Make sure that all operands are broadcast from identically-shaped types:
- // * scalar (`vector.broadcast` + `vector.splat`), or
+ // * scalar (`vector.broadcast`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
+ if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
if (auto source = getBroadcastLikeSource(val))
return haveSameShapeAndScaling(source.getType(),
- splatSource.getType());
+ broadcastSource.getType());
SplatElementsAttr splatConst;
return matchPattern(val, m_Constant(&splatConst));
})) {
@@ -1271,19 +1265,18 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
}
};
-/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
///
/// Example:
/// ```
-/// %0 = vector.splat %arg2 : vector<1xf32>
+/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
-class StoreOpFromSplatOrBroadcast final
- : public OpRewritePattern<vector::StoreOp> {
+class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
public:
using Base::Base;
@@ -1308,9 +1301,9 @@ class StoreOpFromSplatOrBroadcast final
return rewriter.notifyMatchFailure(
op, "value to store is not from a broadcast");
- // Checking for single use so we can remove splat.
- Operation *splat = toStore.getDefiningOp();
- if (!splat->hasOneUse())
+ // Checking for single use so we can remove broadcast.
+ Operation *broadcast = toStore.getDefiningOp();
+ if (!broadcast->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
Value base = op.getBase();
@@ -1321,7 +1314,7 @@ class StoreOpFromSplatOrBroadcast final
} else {
rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
}
- rewriter.eraseOp(splat);
+ rewriter.eraseOp(broadcast);
return success();
}
};
@@ -2391,8 +2384,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
// TODO: Consider converting these patterns to canonicalizations.
- patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
- patterns.getContext(), benefit);
+ patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
+ benefit);
}
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index eb9feaad15c5b..a75f30d57fa74 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -86,7 +86,7 @@ func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf
// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
// CHECK: spirv.ReturnValue %[[VAL]] : vector<4xf32>
func.func @splat(%f : f32) -> vector<4xf32> {
- %splat = vector.splat %f : vector<4xf32>
+ %splat = vector.broadcast %f : f32 to vector<4xf32>
return %splat : vector<4xf32>
}
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index c8a434bb8f5de..1735e08782528 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -429,38 +429,6 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
return
}
-//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func.func @splat_vec2d_from_i32(
-// CHECK-SAME: %[[SRC:.*]]: i32) {
-// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: %[[VSCALE:.*]] = vector.vscale
-// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
-// CHECK: scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
-// CHECK: arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
-func.func @splat_vec2d_from_i32(%arg0: i32) {
- %0 = vector.splat %arg0 : vector<[4]x[4]xi32>
- "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
- return
-}
-
-// -----
-
-// CHECK-LABEL: func.func @splat_vec2d_from_f16(
-// CHECK-SAME: %[[SRC:.*]]: f16) {
-// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
-// CHECK: scf.for
-// CHECK: arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
-func.func @splat_vec2d_from_f16(%arg0: f16) {
- %0 = vector.splat %arg0 : vector<[8]x[8]xf16>
- "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
- return
-}
//===----------------------------------------------------------------------===//
// vector.transpose
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 5973c2ba2cbd0..cb48ca3374e8d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2216,23 +2216,6 @@ func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vecto
// -----
-//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// vector.splat is converted to vector.broadcast. Then, vector.broadcast is converted to LLVM.
-// CHECK-LABEL: @splat_0d
-// CHECK-NOT: splat
-// CHECK: return
-func.func @splat_0d(%elt: f32) -> (vector<f32>, vector<4xf32>, vector<[4]xf32>) {
- %a = vector.splat %elt : vector<f32>
- %b = vector.splat %elt : vector<4xf32>
- %c = vector.splat %elt : vector<[4]xf32>
- return %a, %b, %c : vector<f32>, vector<4xf32>, vector<[4]xf32>
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// vector.scalable_insert
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize_ipowi.mlir b/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
index 9e65a96869460..681209276ad6b 100644
--- a/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
+++ b/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
@@ -105,9 +105,9 @@ func.func @ipowi32_fold(%result : memref<?xi32>) {
// --- Test vector folding ---
%arg11_base = arith.constant 2 : i32
- %arg11_base_vec = vector.splat %arg11_base : vector<2x2xi32>
+ %arg11_base_vec = vector.broadcast %arg11_base : i32 to vector<2x2xi32>
%arg11_power = arith.constant 30 : i32
- %arg11_power_vec = vector.splat %arg11_power : vector<2x2xi32>
+ %arg11_power_vec = vector.broadcast %arg11_power : i32 to vector<2x2xi32>
%res11_vec = math.ipowi %arg11_base_vec, %arg11_power_vec : vector<2x2xi32>
%i11 = arith.constant 11 : index
%res11 = vector.extract %res11_vec[1, 1] : i32 from vector<2x2xi32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bccf5d5b77b0e..d093bc92cd8c4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -837,7 +837,7 @@ func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2
// CHECK-LABEL: fold_extract_vector_from_splat
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
- %b = vector.splat %a : vector<1x2x4xf32>
+ %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
deleted file mode 100644
index e4a9391770b6c..0000000000000
--- a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
+++ /dev/null
@@ -1,126 +0,0 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
-
-// This file should be removed when vector.splat is removed.
-// This file tests canonicalization/folding with vector.splat.
-// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir
-
-
-// CHECK-LABEL: fold_extract_splat
-// CHECK-SAME: %[[A:.*]]: f32
-// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
- %b = vector.splat %a : vector<1x2x4xf32>
- %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
- return %r : f32
-}
-
-// -----
-
-// CHECK-LABEL: extract_strided_splat
-// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
-// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
-func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
- %0 = vector.splat %arg0 : vector<16x4xf16>
- %1 = vector.extract_strided_slice %0
- {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
- vector<16x4xf16> to vector<2x4xf16>
- return %1 : vector<2x4xf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @splat_fold
-// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
-// CHECK-NEXT: return [[V]] : vector<4xf32>
-func.func @splat_fold() -> vector<4xf32> {
- %c = arith.constant 1.0 : f32
- %v = vector.splat %c : vector<4xf32>
- return %v : vector<4xf32>
-
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_splat2(
-// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
-// CHECK: return %[[VAL_1]] : vector<3x4xf32>
-func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
- %splat = vector.splat %arg : vector<4x3xf32>
- %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
- return %0 : vector<3x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_strided_slice_splat
-// CHECK-SAME: (%[[ARG:.*]]: f32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
-func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
- %splat0 = vector.splat %x : vector<4x4xf32>
- %splat1 = vector.splat %x : vector<8x16xf32>
- %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
- : vector<4x4xf32> into vector<8x16xf32>
- return %0 : vector<8x16xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @shuffle_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
-func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
- %v0 = vector.splat %x : vector<4xi32>
- %v1 = vector.splat %x : vector<2xi32>
- %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
- return %shuffle : vector<4xi32>
-}
-
-
-// -----
-
-// CHECK-LABEL: func @insert_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
-func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
- %v0 = vector.splat %x : vector<4x3xi32>
- %v1 = vector.splat %x : vector<2x4x3xi32>
- %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
- return %insert : vector<2x4x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
-// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>)
-func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
- // Splat scalar to 0D and extract scalar.
- %0 = vector.splat %a : vector<f32>
- %1 = vector.extract %0[] : f32 from vector<f32>
-
- // Broadcast scalar to 0D and extract scalar.
- %2 = vector.splat %a : vector<f32>
- %3 = vector.extract %2[] : f32 from vector<f32>
-
- // Splat scalar to 2D and extract scalar.
- %6 = vector.splat %a : vector<2x3xf32>
- %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
-
- // Broadcast scalar to 3D and extract scalar.
- %8 = vector.splat %a : vector<5x6x7xf32>
- %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
-
- // Extract 2D from 3D that was broadcasted from a scalar.
- // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
- %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
-
- // Extract 1D from 2D that was splat'ed from a scalar.
- // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
- %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
-
- // CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
- return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
-}
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index b2f16bb3dac9c..4da8d8a967c73 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -28,7 +28,7 @@ func.func @float_constant_splat() -> vector<8xf32> {
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_splat() -> vector<4xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
- %1 = vector.splat %0 : vector<4xindex>
+ %1 = vector.broadcast %0 : index to vector<4xindex>
%2 = test.reflect_bounds %1 : vector<4xindex>
func.return %2 : vector<4xindex>
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6ee70fdd89a85..5f035e35a1b86 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -320,7 +320,7 @@ func.func @test_vector.transfer_write(%m: memref<1xi32>, %2: vector<1x32xi32>)
func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires memref or ranked tensor type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
}
@@ -330,7 +330,7 @@ func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
func.func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires vector type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<4x3xf32>, f32
}
@@ -414,7 +414,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
// expected-note@+1 {{prior use here}}
- %mask = vector.splat %c1 : vector<3x8x7xi1>
+ %mask = vector.broadcast %c1 : i1 to vector<3x8x7xi1>
// expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
%0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
}
@@ -424,7 +424,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{requires source vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}
@@ -434,7 +434,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<6xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<6xf32>
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
@@ -444,7 +444,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<2x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
// expected-error@+1 {{ expects the in_bounds attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@@ -454,8 +454,8 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<2x3xf32>
- %mask = vector.splat %c1 : vector<2x3xi1>
+ %vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
+ %mask = vector.broadcast %c1 : i1 to vector<2x3xi1>
// expected-error@+1 {{does not support masks with vector element type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@@ -492,7 +492,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func.func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires vector type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : memref<vector<4x3xf32>>, vector<4x3xf32>
}
@@ -502,7 +502,7 @@ func.func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
func.func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires memref or ranked tensor type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
}
@@ -1980,29 +1980,6 @@ func.func @invalid_step_2d() {
// -----
-//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @vector_splat_invalid_result(%v : f32) {
- // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'memref<8xf32>'}}
- vector.splat %v : memref<8xf32>
- return
-}
-
-// -----
-
-// expected-note @+1 {{prior use here}}
-func.func @vector_splat_type_mismatch(%a: f32) {
- // expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
- %0 = vector.splat %a : vector<1xi32>
- return
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index fe697c8b9c057..ee5cfbcda5c19 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -428,33 +428,6 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
// -----
-// CHECK-LABEL: linearize_vector_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
-func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
-
- // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
- // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
- // CHECK: return %[[CAST]] : vector<4x2xi32>
- %0 = vector.splat %arg0 : vector<4x2xi32>
- return %0 : vector<4x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: linearize_scalable_vector_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
-func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
-
- // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
- // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
- // CHECK: return %[[CAST]] : vector<4x[2]xi32>
- %0 = vector.splat %arg0 : vector<4x[2]xi32>
- return %0 : vector<4x[2]xi32>
-
-}
-
-// -----
-
// CHECK-LABEL: linearize_create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 550e52af7874b..da9a1a8180a05 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -45,11 +45,11 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : i1
- %vf0 = vector.splat %f0 : vector<4x3xf32>
- %v0 = vector.splat %c0 : vector<4x3xi32>
- %vi0 = vector.splat %i0 : vector<4x3xindex>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
+ %v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
+ %vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
- %m2 = vector.splat %i1 : vector<4x5xi1>
+ %m2 = vector.broadcast %i1 : i1 to vector<4x5xi1>
//
// CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
@@ -106,9 +106,9 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%c0 = arith.constant 0 : i32
%i0 = arith.constant 0 : index
- %vf0 = vector.splat %f0 : vector<4x3xf32>
- %v0 = vector.splat %c0 : vector<4x3xi32>
- %vi0 = vector.splat %i0 : vector<4x3xindex>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
+ %v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
+ %vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
//
// CHECK: vector.transfer_read
@@ -922,28 +922,6 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
return %2#0 : vector<4x8x16x32xf32>
}
-// CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: %[[s:.*]]: f32, %[[s2:.*]]: !llvm.ptr<1>
-func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
- // CHECK: vector.splat %[[s]] : vector<8xf32>
- %v = vector.splat %s : vector<8xf32>
-
- // CHECK: vector.splat %[[s]] : vector<4xf32>
- %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
-
- // CHECK: vector.splat %[[s2]] : vector<16x!llvm.ptr<1>>
- %w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
- return
-}
-
-// CHECK-LABEL: func @vector_splat_0d(
-func.func @vector_splat_0d(%a: f32) -> vector<f32> {
- // CHECK: vector.splat %{{.*}} : vector<f32>
- %0 = vector.splat %a : vector<f32>
- return %0 : vector<f32>
-}
-
-
// CHECK-LABEL: func @vector_mask
func.func @vector_mask(%a: vector<8xi32>, %m0: vector<8xi1>) -> i32 {
// CHECK-NEXT: %{{.*}} = vector.mask %{{.*}} { vector.reduction <add>, %{{.*}} : vector<8xi32> into i32 } : vector<8xi1> -> i32
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index e74eb08339684..6e5d68c859e2c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -49,7 +49,7 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
%idx_4 = arith.constant 4 : index
%mask = vector.create_mask %idx_1 : vector<4xi1>
%s = arith.constant 0.0 : f32
- %pass_thru = vector.splat %s : vector<4xf32>
+ %pass_thru = vector.broadcast %s : f32 to vector<4xf32>
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
return %0: vector<4xf32>
}
@@ -65,7 +65,7 @@ func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4
%idx_4 = arith.constant 4 : index
%mask = vector.create_mask %idx_1 : vector<4xi1>
%s = arith.constant 0.0 : f32
- %pass_thru = vector.splat %s : vector<4xf32>
+ %pass_thru = vector.broadcast %s : f32 to vector<4xf32>
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
return %0: vector<4xf32>
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 12a911ca8c826..0c5fec8c4055a 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -107,7 +107,7 @@ func.func @return_not_in_function() {
// -----
func.func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
- vector.splat %v : vector<8xf64>
+ vector.broadcast %v : f64 to vector<8xf64>
// expected-error@-1 {{expects different type than prior uses}}
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
index 6ec103193ac6b..1938a3c8ab484 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
@@ -21,13 +21,6 @@ func.func @print_vector_0d(%a: vector<f32>) {
return
}
-func.func @splat_0d(%a: f32) {
- %1 = vector.splat %a : vector<f32>
- // CHECK: ( 42 )
- vector.print %1: vector<f32>
- return
-}
-
func.func @broadcast_0d(%a: f32) {
%1 = vector.broadcast %a : f32 to vector<f32>
// CHECK: ( 42 )
diff --git a/mlir/test/mlir-runner/utils.mlir b/mlir/test/mlir-runner/utils.mlir
index 0c25078449987..d3fc23b423a56 100644
--- a/mlir/test/mlir-runner/utils.mlir
+++ b/mlir/test/mlir-runner/utils.mlir
@@ -56,7 +56,7 @@ func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interf
func.func @vector_splat_2d() {
%c0 = arith.constant 0 : index
%f10 = arith.constant 10.0 : f32
- %vf10 = vector.splat %f10: !vector_type_C
+ %vf10 = vector.broadcast %f10: f32 to !vector_type_C
%C = memref.alloc() : !matrix_type_CC
memref.store %vf10, %C[%c0, %c0]: !matrix_type_CC
diff --git a/mlir/utils/tree-sitter-mlir/queries/highlights.scm b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
index 59e280bab414a..ca52bcce042f7 100644
--- a/mlir/utils/tree-sitter-mlir/queries/highlights.scm
+++ b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
@@ -181,7 +181,6 @@
"vector.insert_strided_slice"
"vector.matrix_multiply"
"vector.print"
- "vector.splat"
"vector.transfer_read"
"vector.transfer_write"
"vector.yield"
More information about the Mlir-commits
mailing list