[Mlir-commits] [flang] [mlir] [MLIR][Vector] Remove vector.splat (PR #162167)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 6 15:53:05 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: James Newling (newling)

<details>
<summary>Changes</summary>

vector.splat has been deprecated in favor of the very similar vector.broadcast; the last PR of that deprecation landed about 6 weeks ago. See https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/1

This PR completely removes vector.splat. In addition to removing vector.splat from VectorOps.td, it

- Updates the few remaining places that create vector::SplatOp so that they create vector::BroadcastOp instead
- Removes the temporary patterns that rewrite vector.splat to vector.broadcast

The only places where 'vector.splat' still appears are now the following files:

https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/1 and
https://github.com/llvm/llvm-project/blob/main/mlir/utils/tree-sitter-mlir/dialect/vector.js

@<!-- -->artagnon, maybe you can tell me what these files are, and whether they should be updated?

---

Patch is 50.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162167.diff


28 Files Affected:

- (modified) flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp (+10-9) 
- (modified) mlir/docs/Dialects/Vector.md (+1-3) 
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (-47) 
- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (-4) 
- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+1-1) 
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+4-18) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+1-14) 
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+4-19) 
- (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-46) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+3-29) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+20-27) 
- (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+1-1) 
- (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (-32) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (-17) 
- (modified) mlir/test/Dialect/Math/canonicalize_ipowi.mlir (+2-2) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+1-1) 
- (removed) mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir (-126) 
- (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+1-1) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+10-33) 
- (modified) mlir/test/Dialect/Vector/linearize.mlir (-27) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+7-29) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+2-2) 
- (modified) mlir/test/IR/invalid-ops.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir (-7) 
- (modified) mlir/test/mlir-runner/utils.mlir (+1-1) 
- (modified) mlir/utils/tree-sitter-mlir/queries/highlights.scm (-1) 


``````````diff
diff --git a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
index 03952da95b11e..265e268ab1b09 100644
--- a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
@@ -2383,7 +2383,7 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
   auto context{builder.getContext()};
   auto argBases{getBasesForArgs(args)};
 
-  mlir::vector::SplatOp splatOp{nullptr};
+  mlir::vector::BroadcastOp splatOp{nullptr};
   mlir::Type retTy{nullptr};
   switch (vop) {
   case VecOp::Splat: {
@@ -2391,9 +2391,9 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
     auto vecTyInfo{getVecTypeFromFir(argBases[0])};
 
     auto extractOp{genVecExtract(resultType, args)};
-    splatOp =
-        mlir::vector::SplatOp::create(builder, loc, *(extractOp.getUnboxed()),
-                                      vecTyInfo.toMlirVectorType(context));
+    splatOp = mlir::vector::BroadcastOp::create(
+        builder, loc, vecTyInfo.toMlirVectorType(context),
+        *(extractOp.getUnboxed()));
     retTy = vecTyInfo.toFirVectorType();
     break;
   }
@@ -2401,8 +2401,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
     assert(args.size() == 1);
     auto vecTyInfo{getVecTypeFromEle(argBases[0])};
 
-    splatOp = mlir::vector::SplatOp::create(
-        builder, loc, argBases[0], vecTyInfo.toMlirVectorType(context));
+    splatOp = mlir::vector::BroadcastOp::create(
+        builder, loc, vecTyInfo.toMlirVectorType(context), argBases[0]);
     retTy = vecTyInfo.toFirVectorType();
     break;
   }
@@ -2412,8 +2412,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
     auto intOp{builder.createConvert(loc, eleTy, argBases[0])};
 
     // the intrinsic always returns vector(integer(4))
-    splatOp = mlir::vector::SplatOp::create(builder, loc, intOp,
-                                            mlir::VectorType::get(4, eleTy));
+    splatOp = mlir::vector::BroadcastOp::create(
+        builder, loc, mlir::VectorType::get(4, eleTy), intOp);
     retTy = fir::VectorType::get(4, eleTy);
     break;
   }
@@ -2444,7 +2444,8 @@ PPCIntrinsicLibrary::genVecXlds(mlir::Type resultType,
   auto addrConv{fir::ConvertOp::create(builder, loc, i64RefTy, addr)};
 
   auto addrVal{fir::LoadOp::create(builder, loc, addrConv)};
-  auto splatRes{mlir::vector::SplatOp::create(builder, loc, addrVal, i64VecTy)};
+  auto splatRes{
+      mlir::vector::BroadcastOp::create(builder, loc, i64VecTy, addrVal)};
 
   mlir::Value result{nullptr};
   if (mlirTy != splatRes.getType()) {
diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md
index 6c8949d70b4a3..839dc75ff0214 100644
--- a/mlir/docs/Dialects/Vector.md
+++ b/mlir/docs/Dialects/Vector.md
@@ -125,7 +125,7 @@ Some existing Arith and Vector Dialect on `n-D` `vector` types comprise:
 // Produces a vector<3x7x8xf32>
 %b = arith.mulf %0, %1 : vector<3x7x8xf32>
 // Produces a vector<3x7x8xf32>
-%c = vector.splat %1 : vector<3x7x8xf32>
+%c = vector.broadcast %1 : f32 to vector<3x7x8xf32>
 
 %d = vector.extract %0[1]: vector<7x8xf32> from vector<3x7x8xf32>
 %e = vector.extract %0[1, 5]: vector<8xf32> from vector<3x7x8xf32>
@@ -176,8 +176,6 @@ infrastructure can apply iteratively.
 ### Virtual Vector to Hardware Vector Lowering
 
 For now, `VV -> HWV` are specified in C++ (see for instance the
-[SplatOpLowering for n-D vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d)
-or the
 [VectorOuterProductOp lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)).
 
 Simple
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..41e075467910f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2880,53 +2880,6 @@ def Vector_PrintOp :
     }];
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-def Vector_SplatOp : Vector_Op<"splat", [
-    Pure,
-    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
-    TypesMatchWith<"operand type matches element type of result",
-                   "aggregate", "input",
-                   "::llvm::cast<VectorType>($_self).getElementType()">
-  ]> {
-  let summary = "vector splat or broadcast operation";
-  let description = [{
-    Note: This operation is deprecated. Please use vector.broadcast.
-
-    Broadcast the operand to all elements of the result vector. The type of the
-    operand must match the element type of the vector type.
-
-    Example:
-
-    ```mlir
-    %s = arith.constant 10.1 : f32
-    %t = vector.splat %s : vector<8x16xf32>
-    ```
-
-    This operation is deprecated, the preferred representation of the above is:
-
-    ```mlir
-    %s = arith.constant 10.1 : f32
-    %t = vector.broadcast %s : f32 to vector<8x16xf32>
-    ```
-  }];
-
-  let arguments = (ins AnyType:$input);
-  let results = (outs AnyVectorOfAnyRank:$aggregate);
-
-  let builders = [
-    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
-    [{ build($_builder, $_state, aggregateType, element); }]>];
-  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
-
-  let hasFolder = 1;
-
-  // vector.splat is deprecated, and vector.broadcast should be used instead.
-  // Canonicalize vector.splat to vector.broadcast.
-  let hasCanonicalizer = 1;
-}
 
 //===----------------------------------------------------------------------===//
 // VectorScaleOp
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index dcbaa5698d767..247dba101cfc1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
                         current = op.getSource();
                         return false;
                       })
-                      .Case<vector::SplatOp>([&current](auto op) {
-                        current = op.getInput();
-                        return false;
-                      })
                       .Default([](Operation *) { return false; });
 
     if (!skipOp) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index bad53c0a4a97a..1002ebe6875b6 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  AFTER:
 ///  ```mlir
 ///  ...
-///  %pad_1d = vector.splat %pad : vector<[4]xi32>
+///  %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
 ///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
 ///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
 ///    ...
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 363685a691180..778c616f1bf44 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
   }
 };
 
-// Convert all `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to ArmSME via another pattern.
-struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
-                                PatternRewriter &rewriter) const final {
-
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
-                                                     splatOp.getInput());
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
-               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
-               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
-               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+  patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
+               TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+               VectorOuterProductToArmSMELowering,
                VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
                VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
                ExtractFromCreateMaskToPselLowering>(&ctx);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 546164628b795..5355909b62a7f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2161,19 +2161,6 @@ class TransposeOpToMatrixTransposeOpLowering
   }
 };
 
-/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
-/// `vector.broadcast` through other patterns.
-struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
-                                                     adaptor.getInput());
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-               VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+               VectorBroadcastScalarToLowRankLowering,
                VectorBroadcastScalarToNdLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 311ff6f5fbeee..56e8fee191432 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -22,7 +22,6 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
   }
 };
 
-// Convert `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to SPIRV via other patterns.
-struct VectorSplatToBroadcast final
-    : public OpConversionPattern<vector::SplatOp> {
-  using Base::Base;
-  LogicalResult
-  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
-                                                     adaptor.getInput());
-    return success();
-  }
-};
-
 struct VectorBitcastConvert final
     : public OpConversionPattern<vector::BitCastOp> {
   using Base::Base;
@@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
-      VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
-      VectorShuffleOpConvert, VectorInterleaveOpConvert,
-      VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
-      VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+      VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+      VectorScalarBroadcastPattern, VectorLoadOpConverter,
+      VectorStoreOpConverter, VectorStepOpConvert>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index c64e10f534f8e..d018cddeb8dc1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
                                vector::OuterProductOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
-                    arith::ConstantOp, arith::SelectOp, vector::SplatOp,
-                    vector::BroadcastOp>();
+                    arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
 }
 
 void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..dc58ac3cdee6f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1664,10 +1664,10 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
-/// 1s, are considered to be 'broadcastlike'.
+/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
+/// considered to be 'broadcastlike'.
 static bool isBroadcastLike(Operation *op) {
-  if (isa<BroadcastOp, SplatOp>(op))
+  if (isa<BroadcastOp>(op))
     return true;
 
   auto shapeCast = dyn_cast<ShapeCastOp>(op);
@@ -3131,12 +3131,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
 };
 
 /// Consider the defining operation `defOp` of `value`. If `defOp` is a
-/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
-/// value that is splatted. Otherwise return null.
+/// vector.broadcast with a scalar operand, return the scalar value that is
+/// splatted. Otherwise return null.
 ///
-/// Examples:
+/// Example:
 ///
-/// scalar_source --> vector.splat --> value     - return scalar_source
 /// scalar_source --> vector.broadcast --> value - return scalar_source
 static Value getScalarSplatSource(Value value) {
   // Block argument:
@@ -3144,10 +3143,6 @@ static Value getScalarSplatSource(Value value) {
   if (!defOp)
     return {};
 
-  // Splat:
-  if (auto splat = dyn_cast<vector::SplatOp>(defOp))
-    return splat.getInput();
-
   auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
 
   // Not broadcast (and not splat):
@@ -7393,41 +7388,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
           patterns.getContext(), benefit);
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = adaptor.getInput();
-  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
-    return {};
-
-  // SplatElementsAttr::get treats single value for second arg as being a splat.
-  return SplatElementsAttr::get(getType(), {constOperand});
-}
-
-// Canonicalizer for vector.splat. It always gets canonicalized to a
-// vector.broadcast.
-class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
-public:
-  using Base::Base;
-  LogicalResult matchAndRewrite(SplatOp splatOp,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
-                                                     splatOp.getOperand());
-    return success();
-  }
-};
-void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                          MLIRContext *context) {
-  results.add<SplatToBroadcastPattern>(context);
-}
-
-void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                SetIntRangeFn setResultRanges) {
-  setResultRanges(getResult(), argRanges.front());
-}
-
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
                                        arith::FastMathFlagsAttr fastmath,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 255f2bf5a8161..3a3231d513369 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
-  // TODO: add support to `vector.splat`.
+  // TODO: add support to `vector.broadcast`.
   // Finding the mask creation operation.
   while (maskOp &&
          !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 71fba71c9f15f..1b656d82f3201 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
   }
 };
 
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-///   vector.splat %value : vector<4x4xf32>
-/// is converted to:
-///   %out_1d = vector.splat %value : vector<16xf32>
-///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
-struct LinearizeVectorSplat final
-    : public OpConversionPattern<vector::SplatOp> {
-  using Base::Base;
-
-  LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
-                       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit) {}
-
-  LogicalResult
-  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
-    if (!dstTy)
-      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
-    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
-                                                 dstTy);
-    return success();
-  }
-};
-
 /// This pattern converts the CreateMaskOp to work on a linearized vector.
 /// It currently supports only 2D masks with a unit outer dimension.
 /// Following,
@@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore, LinearizeVectorFromElements,
-           LinearizeVectorToElements>(typeConverter, patterns.getContext());
+           LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
+           LinearizeVectorFromElements, LinearizeVectorToElements>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/162167


More information about the Mlir-commits mailing list