[Mlir-commits] [mlir] 6a8ba31 - [mlir] Split std.splat into tensor.splat and vector.splat

River Riddle llvmlistbot at llvm.org
Wed Feb 2 14:46:00 PST 2022


Author: River Riddle
Date: 2022-02-02T14:45:12-08:00
New Revision: 6a8ba3186ed561bd5ac6aa31ba483d81223bf198

URL: https://github.com/llvm/llvm-project/commit/6a8ba3186ed561bd5ac6aa31ba483d81223bf198
DIFF: https://github.com/llvm/llvm-project/commit/6a8ba3186ed561bd5ac6aa31ba483d81223bf198.diff

LOG: [mlir] Split std.splat into tensor.splat and vector.splat

This is part of the larger effort to split the standard dialect. This will also allow for pruning some
additional dependencies on Standard (done in a followup).

Differential Revision: https://reviews.llvm.org/D118202

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/IR/Attributes.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
    mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Conversion/VectorToSPIRV/simple.mlir
    mlir/test/Dialect/Standard/ops.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-f32.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-i64.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
    mlir/test/Transforms/constant-fold.mlir
    mlir/test/mlir-cpu-runner/utils.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 322bba1cba2ac..ab0d93783b7bf 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -507,55 +507,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
   let hasVerifier = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-def SplatOp : Std_Op<"splat", [NoSideEffect,
-     TypesMatchWith<"operand type matches element type of result",
-                    "aggregate", "input",
-                    "$_self.cast<ShapedType>().getElementType()">]> {
-  let summary = "splat or broadcast operation";
-  let description = [{
-    Broadcast the operand to all elements of the result vector or tensor. The
-    operand has to be of integer/index/float type. When the result is a tensor,
-    it has to be statically shaped.
-
-    Example:
-
-    ```mlir
-    %s = load %A[%i] : memref<128xf32>
-    %v = splat %s : vector<4xf32>
-    %t = splat %s : tensor<8x16xi32>
-    ```
-
-    TODO: This operation is easy to extend to broadcast to dynamically shaped
-    tensors in the same way dynamically shaped memrefs are handled.
-
-    ```mlir
-    // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
-    // to the sizes of the two dynamic dimensions.
-    %m = "foo"() : () -> (index)
-    %n = "bar"() : () -> (index)
-    %t = splat %s [%m, %n] : tensor<?x?xi32>
-    ```
-  }];
-
-  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input);
-  let results = (outs AnyTypeOf<[AnyVectorOfAnyRank,
-                                 AnyStaticShapeTensor]>:$aggregate);
-
-  let builders = [
-    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
-    [{ build($_builder, $_state, aggregateType, element); }]>];
-
-  let hasFolder = 1;
-  let hasVerifier = 1;
-
-  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
-}
-
 //===----------------------------------------------------------------------===//
 // SwitchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a2f15b380b2e8..d05049c2b766a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -968,6 +968,52 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_SplatOp : Tensor_Op<"splat", [
+    NoSideEffect,
+    TypesMatchWith<"operand type matches element type of result",
+                   "aggregate", "input",
+                   "$_self.cast<TensorType>().getElementType()">
+  ]> {
+  let summary = "tensor splat or broadcast operation";
+  let description = [{
+    Broadcast the operand to all elements of the result tensor. The operand is
+    required to be of integer/index/float type, and the result tensor must be
+    statically shaped.
+
+    Example:
+
+    ```mlir
+    %s = arith.constant 10.1 : f32
+    %t = tensor.splat %s : tensor<8x16xf32>
+    ```
+
+    TODO: This operation is easy to extend to broadcast to dynamically shaped
+          tensors:
+
+    ```mlir
+    // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
+    // to the sizes of the two dynamic dimensions.
+    %m = "foo"() : () -> (index)
+    %n = "bar"() : () -> (index)
+    %t = tensor.splat %s [%m, %n] : tensor<?x?xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
+                                 "integer/index/float type">:$input);
+  let results = (outs AnyStaticShapeTensor:$aggregate);
+
+  let builders = [
+    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
+    [{ build($_builder, $_state, aggregateType, element); }]>];
+  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
+
+  let hasFolder = 1;
+}
 
 //===----------------------------------------------------------------------===//
 // YieldOp

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1a48f7d7d9bab..d2efe2c30962a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2420,6 +2420,41 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
   let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+def Vector_SplatOp : Vector_Op<"splat", [
+    NoSideEffect,
+    TypesMatchWith<"operand type matches element type of result",
+                   "aggregate", "input",
+                   "$_self.cast<VectorType>().getElementType()">
+  ]> {
+  let summary = "vector splat or broadcast operation";
+  let description = [{
+    Broadcast the operand to all elements of the result vector. The operand is
+    required to be of integer/index/float type.
+
+    Example:
+
+    ```mlir
+    %s = arith.constant 10.1 : f32
+    %t = vector.splat %s : vector<8x16xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
+                                 "integer/index/float type">:$input);
+  let results = (outs AnyVectorOfAnyRank:$aggregate);
+
+  let builders = [
+    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
+    [{ build($_builder, $_state, aggregateType, element); }]>];
+  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // VectorScaleOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 9eeb5b6d08e61..da5b78edf92be 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -49,6 +49,8 @@ class Attribute {
   template <typename U> bool isa() const;
   template <typename First, typename Second, typename... Rest>
   bool isa() const;
+  template <typename First, typename... Rest>
+  bool isa_and_nonnull() const;
   template <typename U> U dyn_cast() const;
   template <typename U> U dyn_cast_or_null() const;
   template <typename U> U cast() const;
@@ -114,6 +116,11 @@ bool Attribute::isa() const {
   return isa<First>() || isa<Second, Rest...>();
 }
 
+template <typename First, typename... Rest>
+bool Attribute::isa_and_nonnull() const {
+  return impl && isa<First, Rest...>();
+}
+
 template <typename U> U Attribute::dyn_cast() const {
   return isa<U>() ? U(impl) : U(nullptr);
 }

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index b68587d43c777..467e51c5d5fad 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -663,99 +663,6 @@ struct SwitchOpLowering
   using Super::Super;
 };
 
-// The Splat operation is lowered to an insertelement + a shufflevector
-// operation. Splat to only 0-d and 1-d vector result types are lowered.
-struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
-  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
-    if (!resultType || resultType.getRank() > 1)
-      return failure();
-
-    // First insert it into an undef vector so we can shuffle it.
-    auto vectorType = typeConverter->convertType(splatOp.getType());
-    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
-    auto zero = rewriter.create<LLVM::ConstantOp>(
-        splatOp.getLoc(),
-        typeConverter->convertType(rewriter.getIntegerType(32)),
-        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
-
-    // For 0-d vector, we simply do `insertelement`.
-    if (resultType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-          splatOp, vectorType, undef, adaptor.getInput(), zero);
-      return success();
-    }
-
-    // For 1-d vector, we additionally do a `vectorshuffle`.
-    auto v = rewriter.create<LLVM::InsertElementOp>(
-        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
-
-    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
-    SmallVector<int32_t, 4> zeroValues(width, 0);
-
-    // Shuffle the value across the desired number of elements.
-    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
-    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
-                                                       zeroAttrs);
-    return success();
-  }
-};
-
-// The Splat operation is lowered to an insertelement + a shufflevector
-// operation. Splat to only 2+-d vector result types are lowered by the
-// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
-  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
-    if (!resultType || resultType.getRank() <= 1)
-      return failure();
-
-    // First insert it into an undef vector so we can shuffle it.
-    auto loc = splatOp.getLoc();
-    auto vectorTypeInfo =
-        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
-    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
-    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
-    if (!llvmNDVectorTy || !llvm1DVectorTy)
-      return failure();
-
-    // Construct returned value.
-    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
-
-    // Construct a 1-D vector with the splatted value that we insert in all the
-    // places within the returned descriptor.
-    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
-    auto zero = rewriter.create<LLVM::ConstantOp>(
-        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
-        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
-    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
-                                                     adaptor.getInput(), zero);
-
-    // Shuffle the value across the desired number of elements.
-    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
-    SmallVector<int32_t, 4> zeroValues(width, 0);
-    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
-    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
-
-    // Iterate of linear index, convert to coords space and insert splatted 1-D
-    // vector in each position.
-    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
-      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
-                                                  position);
-    });
-    rewriter.replaceOp(splatOp, desc);
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::populateStdToLLVMFuncOpConversionPattern(
@@ -779,8 +686,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
       ConstantOpLowering,
       ReturnOpLowering,
       SelectOpLowering,
-      SplatOpLowering,
-      SplatNdOpLowering,
       SwitchOpLowering>(converter);
   // clang-format on
 }

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 8d39aa9e598d7..003b5cc75f97a 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -55,16 +55,6 @@ class SelectOpPattern final : public OpConversionPattern<SelectOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts std.splat to spv.CompositeConstruct.
-class SplatPattern final : public OpConversionPattern<SplatOp> {
-public:
-  using OpConversionPattern<SplatOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(SplatOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
 /// Converts std.br to spv.Branch.
 struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
   using OpConversionPattern<BranchOp>::OpConversionPattern;
@@ -178,22 +168,6 @@ SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor,
-                              ConversionPatternRewriter &rewriter) const {
-  auto dstVecType = op.getType().dyn_cast<VectorType>();
-  if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
-    return failure();
-  SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.getInput());
-  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
-                                                           source);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // BranchOpPattern
 //===----------------------------------------------------------------------===//
@@ -237,8 +211,8 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
       spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
 
-      ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
-      CondBranchOpPattern>(typeConverter, context);
+      ReturnOpPattern, SelectOpPattern, BranchOpPattern, CondBranchOpPattern>(
+      typeConverter, context);
 }
 
 void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6649472119959..80f50a3996e36 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -778,7 +778,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, elemType, rewriter.getZeroAttr(elemType));
-    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
+    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
@@ -1062,6 +1062,99 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
   }
 };
 
+/// The Splat operation is lowered to an insertelement + a shufflevector
+/// operation. Only splats to 0-d and 1-d vector result types are lowered.
+struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
+  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = splatOp.getType().cast<VectorType>();
+    if (resultType.getRank() > 1)
+      return failure();
+
+    // First insert it into an undef vector so we can shuffle it.
+    auto vectorType = typeConverter->convertType(splatOp.getType());
+    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        splatOp.getLoc(),
+        typeConverter->convertType(rewriter.getIntegerType(32)),
+        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+
+    // For 0-d vector, we simply do `insertelement`.
+    if (resultType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          splatOp, vectorType, undef, adaptor.input(), zero);
+      return success();
+    }
+
+    // For 1-d vector, we additionally do a `shufflevector`.
+    auto v = rewriter.create<LLVM::InsertElementOp>(
+        splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
+
+    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
+    SmallVector<int32_t, 4> zeroValues(width, 0);
+
+    // Shuffle the value across the desired number of elements.
+    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
+                                                       zeroAttrs);
+    return success();
+  }
+};
+
+/// The Splat operation is lowered to an insertelement + a shufflevector
+/// operation. Only splats to 2+-d vector result types are lowered by
+/// VectorSplatNdOpLowering; the 0-d/1-d cases use VectorSplatOpLowering.
+struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
+  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = splatOp.getType();
+    if (resultType.getRank() <= 1)
+      return failure();
+
+    // First insert it into an undef vector so we can shuffle it.
+    auto loc = splatOp.getLoc();
+    auto vectorTypeInfo =
+        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
+    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
+    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
+    if (!llvmNDVectorTy || !llvm1DVectorTy)
+      return failure();
+
+    // Construct returned value.
+    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
+
+    // Construct a 1-D vector with the splatted value that we insert in all the
+    // places within the returned descriptor.
+    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
+        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
+                                                     adaptor.input(), zero);
+
+    // Shuffle the value across the desired number of elements.
+    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
+    SmallVector<int32_t, 4> zeroValues(width, 0);
+    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
+
+    // Iterate over the linear index, convert to coords space and insert the
+    // splatted 1-D vector in each position.
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
+                                                  position);
+    });
+    rewriter.replaceOp(splatOp, desc);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1085,8 +1178,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
            VectorLoadStoreConversion<vector::MaskedStoreOp,
                                      vector::MaskedStoreOpAdaptor>,
            VectorGatherOpConversion, VectorScatterOpConversion,
-           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
-          converter);
+           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+           VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 499f4403d317f..5af5693f03ab7 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -425,7 +425,7 @@ struct Strategy<TransferReadOp> {
     Location loc = xferOp.getLoc();
     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
-    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
+    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.padding());
     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
 
     return Value();
@@ -855,8 +855,8 @@ struct UnrollTransferReadConversion
     if (auto insertOp = getInsertOp(xferOp))
       return insertOp.dest();
     Location loc = xferOp.getLoc();
-    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
-                                    xferOp.padding());
+    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
+                                            xferOp.padding());
   }
 
   /// If the result of the TransferReadOp has exactly one user, which is a
@@ -1143,7 +1143,8 @@ struct Strategy1d<TransferReadOp> {
   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
     // Inititalize vector with padding value.
     Location loc = xferOp.getLoc();
-    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
+    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
+                                     xferOp.padding());
   }
 };
 

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 55a5d1c53770b..b40052df603b9 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -247,6 +247,23 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
+class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
+public:
+  using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType dstVecType = op.getType();
+    if (!spirv::CompositeType::isValid(dstVecType))
+      return failure();
+    SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
+                                                             source);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
@@ -255,6 +272,6 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                VectorExtractElementOpConvert, VectorExtractOpConvert,
                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
                VectorInsertElementOpConvert, VectorInsertOpConvert,
-               VectorInsertStridedSliceOpConvert>(typeConverter,
-                                                  patterns.getContext());
+               VectorInsertStridedSliceOpConvert, VectorSplatPattern>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index b16d59dc9dc0e..932959517b655 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -863,34 +863,6 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult SplatOp::verify() {
-  // TODO: we could replace this by a trait.
-  if (getOperand().getType() != getType().cast<ShapedType>().getElementType())
-    return emitError("operand should be of elemental type of result type");
-
-  return success();
-}
-
-// Constant folding hook for SplatOp.
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "splat takes one operand");
-
-  auto constOperand = operands.front();
-  if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
-    return {};
-
-  auto shapedType = getType().cast<ShapedType>();
-  assert(shapedType.getElementType() == constOperand.getType() &&
-         "incorrect input attribute type for folding");
-
-  // SplatElementsAttr::get treats single value for second arg as being a splat.
-  return SplatElementsAttr::get(shapedType, {constOperand});
-}
-
 //===----------------------------------------------------------------------===//
 // SwitchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 91dfddb2dfe99..09586adc4df72 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1800,6 +1800,19 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
+  auto constOperand = operands.front();
+  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
+    return {};
+
+  // SplatElementsAttr::get treats single value for second arg as being a splat.
+  return SplatElementsAttr::get(getType(), {constOperand});
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e034b98e5d0cc..207817b6922cf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -15,7 +15,6 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -2540,7 +2539,7 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
     auto splat = op.vector().getDefiningOp<SplatOp>();
     if (!splat)
       return failure();
-    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
+    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input());
     return success();
   }
 };
@@ -4369,5 +4368,22 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
           patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
+  auto constOperand = operands.front();
+  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
+    return {};
+
+  // SplatElementsAttr::get treats single value for second arg as being a splat.
+  return SplatElementsAttr::get(getType(), {constOperand});
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 469ee756cc5ec..4308fa6a43be9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 1e7fc9a88b192..40897ef53b053 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -10,8 +10,8 @@
 // transfer_write ops.
 //
 //===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 69b76dc856ef7..4ff9d65ad6696 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -205,7 +204,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
 
     // Scalar to any vector can use splat.
     if (!srcType) {
-      rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
+      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
       return success();
     }
 
@@ -220,7 +219,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
         ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
       else
         ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
-      rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
+      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
       return success();
     }
 
@@ -1735,7 +1734,7 @@ struct TransferReadToVectorLoadLowering
     // Create vector load op.
     Operation *loadOp;
     if (read.mask()) {
-      Value fill = rewriter.create<SplatOp>(
+      Value fill = rewriter.create<vector::SplatOp>(
           read.getLoc(), unbroadcastedVectorType, read.padding());
       loadOp = rewriter.create<vector::MaskedLoadOp>(
           read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
@@ -2168,12 +2167,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   // Add in an offset if requested.
   if (off) {
     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
-    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
+    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
   }
   // Construct the vector comparison.
   Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
-  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+  Value bounds =
+      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                         bounds);
 }

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 800e19ceda71d..faa4beae8d41f 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 26f064fb03122..71bd62a009e23 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -456,36 +456,6 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
 
 // -----
 
-// CHECK-LABEL: @splat_0d
-// CHECK-SAME: %[[ARG:.*]]: f32
-func @splat_0d(%a: f32) -> vector<f32> {
-  %v = splat %a : vector<f32>
-  return %v : vector<f32>
-}
-// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<1xf32>
-// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ARG]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
-// CHECK-NEXT: llvm.return %[[V]] : vector<1xf32>
-
-// -----
-
-// CHECK-LABEL: @splat
-// CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
-// CHECK-SAME: %[[ELT:arg[0-9]+]]: f32
-func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
-  %vb = splat %b : vector<4xf32>
-  %r = arith.mulf %a, %vb : vector<4xf32>
-  return %r : vector<4xf32>
-}
-// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<4xf32>
-// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<4xf32>
-// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
-// CHECK-NEXT: %[[SCALE:[0-9]+]] = llvm.fmul %[[A]], %[[SPLAT]] : vector<4xf32>
-// CHECK-NEXT: llvm.return %[[SCALE]] : vector<4xf32>
-
-// -----
-
 // CHECK-LABEL: func @ceilf(
 // CHECK-SAME: f32
 func @ceilf(%arg0 : f32) {

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 029c161b3df45..5790c3dcbf03c 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -921,21 +921,6 @@ func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// splat
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @splat
-//  CHECK-SAME: (%[[A:.+]]: f32)
-//       CHECK:   %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
-//       CHECK:   spv.ReturnValue %[[VAL]]
-func @splat(%f : f32) -> vector<4xf32> {
-  %splat = splat %f : vector<4xf32>
-  return %splat : vector<4xf32>
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // std.br, std.cond_br
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 9ad8a4162f4fc..6599323b573d5 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -5,17 +5,19 @@
 // CMP32-SAME: %[[ARG:.*]]: index)
 // CMP32: %[[T0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>
 // CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
-// CMP32: %[[T2:.*]] = splat %[[T1]] : vector<11xi32>
-// CMP32: %[[T3:.*]] = arith.cmpi slt, %[[T0]], %[[T2]] : vector<11xi32>
-// CMP32: return %[[T3]] : vector<11xi1>
+// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
+// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<11xi32>, vector<11xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
+// CMP32: return %[[T4]] : vector<11xi1>
 
 // CMP64-LABEL: @genbool_var_1d(
 // CMP64-SAME: %[[ARG:.*]]: index)
 // CMP64: %[[T0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>
 // CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
-// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<11xi64>
-// CMP64: %[[T3:.*]] = arith.cmpi slt, %[[T0]], %[[T2]] : vector<11xi64>
-// CMP64: return %[[T3]] : vector<11xi1>
+// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
+// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<11xi64>, vector<11xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
+// CMP64: return %[[T4]] : vector<11xi1>
 
 func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   %0 = vector.create_mask %arg0 : vector<11xi1>

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ccb75b8a606a2..2f9bfdcde5b7d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -55,8 +55,9 @@ func @broadcast_vec0d_from_f32(%arg0: f32) -> vector<f32> {
 }
 // CHECK-LABEL: @broadcast_vec0d_from_f32
 // CHECK-SAME:  %[[A:.*]]: f32)
-// CHECK:       %[[T0:.*]] = splat %[[A]] : vector<f32>
-// CHECK:       return %[[T0]] : vector<f32>
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<1xf32> to vector<f32>
+// CHECK:       return %[[T1]] : vector<f32>
 
 // -----
 
@@ -76,8 +77,9 @@ func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
 }
 // CHECK-LABEL: @broadcast_vec1d_from_f32
 // CHECK-SAME:  %[[A:.*]]: f32)
-// CHECK:       %[[T0:.*]] = splat %[[A]] : vector<2xf32>
-// CHECK:       return %[[T0]] : vector<2xf32>
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       return %[[T1]] : vector<2xf32>
 
 // -----
 
@@ -87,8 +89,11 @@ func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
 }
 // CHECK-LABEL: @broadcast_vec1d_from_index
 // CHECK-SAME:  %[[A:.*]]: index)
-// CHECK:       %[[T0:.*]] = splat %[[A]] : vector<2xindex>
-// CHECK:       return %[[T0]] : vector<2xindex>
+// CHECK:       %[[A1:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A1]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<2xi64> to vector<2xindex>
+// CHECK:       return %[[T2]] : vector<2xindex>
 
 // -----
 
@@ -98,8 +103,12 @@ func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
 }
 // CHECK-LABEL: @broadcast_vec2d_from_scalar(
 // CHECK-SAME:  %[[A:.*]]: f32)
-// CHECK:       %[[T0:.*]] = splat %[[A]] : vector<2x3xf32>
-// CHECK:       return %[[T0]] : vector<2x3xf32>
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>> 
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>> 
+// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
+// CHECK:       return %[[T4]] : vector<2x3xf32>
 
 // -----
 
@@ -109,8 +118,13 @@ func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
 }
 // CHECK-LABEL: @broadcast_vec3d_from_scalar(
 // CHECK-SAME:  %[[A:.*]]: f32)
-// CHECK:       %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32>
-// CHECK:       return %[[T0]] : vector<2x3x4xf32>
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0, 0] : !llvm.array<2 x array<3 x vector<4xf32>>>
+// ...
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1, 2] : !llvm.array<2 x array<3 x vector<4xf32>>>
+// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<4xf32>>> to vector<2x3x4xf32>
+// CHECK:       return %[[T4]] : vector<2x3x4xf32>
 
 // -----
 
@@ -135,7 +149,8 @@ func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
 //       CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
 //       CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
 //       CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
-//       CHECK: %[[T6:.*]] = splat %[[T5]] : vector<2xf32>
+//       CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
+//       CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 //       CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
 //       CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>>
 //       CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>>
@@ -228,8 +243,9 @@ func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
 // CHECK-SAME:  %[[A:.*]]: vector<1xf32>)
 // CHECK:       %[[T1:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:       %[[T2:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T1]] : i64] : vector<1xf32>
-// CHECK:       %[[T3:.*]] = splat %[[T2]] : vector<4xf32>
-// CHECK:       return %[[T3]] : vector<4xf32>
+// CHECK:       %[[T3:.*]] = llvm.insertelement %[[T2]]
+// CHECK:       %[[T4:.*]] = llvm.shufflevector %[[T3]]
+// CHECK:       return %[[T4]] : vector<4xf32>
 
 // -----
 
@@ -263,22 +279,26 @@ func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
 // CHECK:       %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>>
 // CHECK:       %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:       %[[T5:.*]] = llvm.extractelement %[[T3]]{{\[}}%[[T4]] : i64] : vector<1xf32>
-// CHECK:       %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK:       %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
+// CHECK:       %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 // CHECK:       %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<4 x vector<3xf32>>
 // CHECK:       %[[T10:.*]] = llvm.extractvalue %[[T2]][1] : !llvm.array<4 x vector<1xf32>>
 // CHECK:       %[[T11:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:       %[[T12:.*]] = llvm.extractelement %[[T10]]{{\[}}%[[T11]] : i64] : vector<1xf32>
-// CHECK:       %[[T13:.*]] = splat %[[T12]] : vector<3xf32>
+// CHECK:       %[[T13Insert:.*]] = llvm.insertelement %[[T12]]
+// CHECK:       %[[T13:.*]] = llvm.shufflevector %[[T13Insert]]
 // CHECK:       %[[T14:.*]] = llvm.insertvalue %[[T13]], %[[T8]][1] : !llvm.array<4 x vector<3xf32>>
 // CHECK:       %[[T16:.*]] = llvm.extractvalue %[[T2]][2] : !llvm.array<4 x vector<1xf32>>
 // CHECK:       %[[T17:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:       %[[T18:.*]] = llvm.extractelement %[[T16]]{{\[}}%[[T17]] : i64] : vector<1xf32>
-// CHECK:       %[[T19:.*]] = splat %[[T18]] : vector<3xf32>
+// CHECK:       %[[T19Insert:.*]] = llvm.insertelement %[[T18]]
+// CHECK:       %[[T19:.*]] = llvm.shufflevector %[[T19Insert]]
 // CHECK:       %[[T20:.*]] = llvm.insertvalue %[[T19]], %[[T14]][2] : !llvm.array<4 x vector<3xf32>>
 // CHECK:       %[[T22:.*]] = llvm.extractvalue %[[T2]][3] : !llvm.array<4 x vector<1xf32>>
 // CHECK:       %[[T23:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:       %[[T24:.*]] = llvm.extractelement %[[T22]]{{\[}}%[[T23]] : i64] : vector<1xf32>
-// CHECK:       %[[T25:.*]] = splat %[[T24]] : vector<3xf32>
+// CHECK:       %[[T25Insert:.*]] = llvm.insertelement %[[T24]]
+// CHECK:       %[[T25:.*]] = llvm.shufflevector %[[T25Insert]]
 // CHECK:       %[[T26:.*]] = llvm.insertvalue %[[T25]], %[[T20]][3] : !llvm.array<4 x vector<3xf32>>
 // CHECK:       %[[T27:.*]] = builtin.unrealized_conversion_cast %[[T26]] : !llvm.array<4 x vector<3xf32>> to vector<4x3xf32>
 // CHECK:       return %[[T27]] : vector<4x3xf32>
@@ -332,12 +352,14 @@ func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32
 // CHECK:       %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T3:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:       %[[T4:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T3]] : i64] : vector<2xf32>
-// CHECK:       %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
+// CHECK:       %[[T5Insert:.*]] = llvm.insertelement %[[T4]]
+// CHECK:       %[[T5:.*]] = llvm.shufflevector %[[T5Insert]]
 // CHECK:       %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
 // CHECK:       %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T9:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:       %[[T10:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T9]] : i64] : vector<2xf32>
-// CHECK:       %[[T11:.*]] = splat %[[T10]] : vector<3xf32>
+// CHECK:       %[[T11Insert:.*]] = llvm.insertelement %[[T10]]
+// CHECK:       %[[T11:.*]] = llvm.shufflevector %[[T11Insert]]
 // CHECK:       %[[T12:.*]] = arith.mulf %[[T11]], %[[B]] : vector<3xf32>
 // CHECK:       %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T8]][1] : !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T14:.*]] = builtin.unrealized_conversion_cast %[[T13]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
@@ -357,9 +379,10 @@ func @outerproduct_index(%arg0: vector<2xindex>, %arg1: vector<3xindex>) -> vect
 // CHECK:       %[[T8:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>>
 // CHECK:       %[[T2:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:       %[[T3:.*]] = llvm.extractelement %[[T1]]{{\[}}%[[T2]] : i64] : vector<2xi64>
-// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : i64 to index
-// CHECK:       %[[T5:.*]] = splat %[[T4]] : vector<3xindex>
-// CHECK:       %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xindex>
+// CHECK:       %[[T4:.*]] = llvm.insertelement %[[T3]]
+// CHECK:       %[[T5:.*]] = llvm.shufflevector %[[T4]]
+// CHECK:       %[[T5Cast:.*]] = builtin.unrealized_conversion_cast %[[T5]] : vector<3xi64> to vector<3xindex>
+// CHECK:       %[[T6:.*]] = arith.muli %[[T5Cast]], %[[B]] : vector<3xindex>
 // CHECK:       %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T6]] : vector<3xindex> to vector<3xi64>
 // CHECK:       %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<3xi64>>
 
@@ -378,13 +401,15 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
 // CHECK:       %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:       %[[T5:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T4]] : i64] : vector<2xf32>
-// CHECK:       %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK:       %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
+// CHECK:       %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 // CHECK:       %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T9:.*]] = "llvm.intr.fmuladd"(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
 // CHECK:       %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:       %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
-// CHECK:       %[[T14:.*]] = splat %[[T13]] : vector<3xf32>
+// CHECK:       %[[T14Insert:.*]] = llvm.insertelement %[[T13]]
+// CHECK:       %[[T14:.*]] = llvm.shufflevector %[[T14Insert]]
 // CHECK:       %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T17:.*]] = "llvm.intr.fmuladd"(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
 // CHECK:       %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<3xf32>>
@@ -986,8 +1011,7 @@ func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
 // CHECK-LABEL: @extract_strided_slice3(
 //  CHECK-SAME:    %[[ARG:.*]]: vector<4x8xf32>)
 //       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
-//       CHECK:    %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
-//       CHECK:    %[[VAL_2:.*]] = splat %[[VAL_1]] : vector<2x2xf32>
+//       CHECK:    %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
 //       CHECK:    %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
 //       CHECK:    %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<8xf32>>
 //       CHECK:    %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : vector<8xf32>, vector<8xf32>
@@ -1233,17 +1257,19 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //
 // 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 //       CHECK: %[[otrunc:.*]] = arith.index_cast %[[BASE]] : index to i32
-//       CHECK: %[[offsetVec:.*]] = splat %[[otrunc]] : vector<17xi32>
+//       CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[otrunc]]
+//       CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
 //       CHECK: %[[offsetVec2:.*]] = arith.addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
 //
 // 3. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 //       CHECK: %[[dtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
-//       CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
+//       CHECK: %[[dimVecInsert:.*]] = llvm.insertelement %[[dtrunc]]
+//       CHECK: %[[dimVec:.*]] = llvm.shufflevector %[[dimVecInsert]]
 //       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
 //
 // 4. Create pass-through vector.
-//       CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
+//       CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
 //
 // 5. Bitcast to vector form.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
@@ -1262,12 +1288,12 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //  CHECK-SAME: vector<17xi32>
 //
 // 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-//       CHECK: splat %{{.*}} : vector<17xi32>
+//       CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
 //       CHECK: arith.addi
 //
 // 3. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-//       CHECK: splat %{{.*}} : vector<17xi32>
+//       CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
 //       CHECK: %[[mask_b:.*]] = arith.cmpi slt, {{.*}} : vector<17xi32>
 //
 // 4. Bitcast to vector form.
@@ -1295,8 +1321,7 @@ func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xin
 }
 // CHECK-LABEL: func @transfer_read_index_1d
 //  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
-//       CHECK: %[[C7:.*]] = arith.constant 7 : index
-//       CHECK: %[[SPLAT:.*]] = splat %[[C7]] : vector<17xindex>
+//       CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<17xindex>
 //       CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
 
 //       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
@@ -1321,12 +1346,14 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
 //
 // Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 //       CHECK: %[[trunc:.*]] = arith.index_cast %[[BASE_1]] : index to i32
-//       CHECK: %[[offsetVec:.*]] = splat %[[trunc]] : vector<17xi32>
+//       CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[trunc]]
+//       CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
 //
 // Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 //       CHECK: %[[dimtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
-//       CHECK: splat %[[dimtrunc]] : vector<17xi32>
+//       CHECK: %[[dimtruncInsert:.*]] = llvm.insertelement %[[dimtrunc]]
+//       CHECK: llvm.shufflevector %[[dimtruncInsert]]
 
 // -----
 
@@ -1451,9 +1478,11 @@ func @create_mask_0d(%a : index) -> vector<i1> {
 // CHECK-SAME: %[[arg:.*]]: index
 // CHECK:  %[[indices:.*]] = arith.constant dense<0> : vector<i32>
 // CHECK:  %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
-// CHECK:  %[[bounds:.*]] = splat %[[arg_i32]] : vector<i32>
-// CHECK:  %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<i32>
+// CHECK:  %[[bounds:.*]] = llvm.insertelement %[[arg_i32]]
+// CHECK:  %[[boundsCast:.*]] = builtin.unrealized_conversion_cast %[[bounds]] : vector<1xi32> to vector<i32>
+// CHECK:  %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[boundsCast]] : vector<i32>
 // CHECK:  return %[[result]] : vector<i1>
+
 // -----
 
 func @create_mask_1d(%a : index) -> vector<4xi1> {
@@ -1465,7 +1494,8 @@ func @create_mask_1d(%a : index) -> vector<4xi1> {
 // CHECK-SAME: %[[arg:.*]]: index
 // CHECK:  %[[indices:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
 // CHECK:  %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
-// CHECK:  %[[bounds:.*]] = splat %[[arg_i32]] : vector<4xi32>
+// CHECK:  %[[boundsInsert:.*]] = llvm.insertelement %[[arg_i32]]
+// CHECK:  %[[bounds:.*]] = llvm.shufflevector %[[boundsInsert]]
 // CHECK:  %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
 // CHECK:  return %[[result]] : vector<4xi1>
 
@@ -1728,3 +1758,34 @@ func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg
 }
 // CHECK-LABEL: func @compress_store_op_index
 // CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) : (vector<11xi64>, !llvm.ptr<i64>, vector<11xi1>) -> ()
+
+// -----
+
+// CHECK-LABEL: @splat_0d
+// CHECK-SAME: %[[ARG:.*]]: f32
+func @splat_0d(%a: f32) -> vector<f32> {
+  %v = vector.splat %a : vector<f32>
+  return %v : vector<f32>
+}
+// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<1xf32>
+// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ARG]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
+// CHECK-NEXT: %[[VCAST:[0-9]+]] = builtin.unrealized_conversion_cast %[[V]] : vector<1xf32> to vector<f32>
+// CHECK-NEXT: return %[[VCAST]] : vector<f32>
+
+// -----
+
+// CHECK-LABEL: @splat
+// CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
+// CHECK-SAME: %[[ELT:arg[0-9]+]]: f32
+func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
+  %vb = vector.splat %b : vector<4xf32>
+  %r = arith.mulf %a, %vb : vector<4xf32>
+  return %r : vector<4xf32>
+}
+// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<4xf32>
+// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<4xf32>
+// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
+// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
+// CHECK-NEXT: return %[[SCALE]] : vector<4xf32>

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 7a3e4b3289729..d93aa21f7dc36 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -168,3 +168,14 @@ func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf
   %0 = vector.fma %a, %b, %c: vector<4xf32>
 	return %0 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @splat
+//  CHECK-SAME: (%[[A:.+]]: f32)
+//       CHECK:   %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
+//       CHECK:   return %[[VAL]]
+func @splat(%f : f32) -> vector<4xf32> {
+  %splat = vector.splat %f : vector<4xf32>
+  return %splat : vector<4xf32>
+}

diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 1c3570eedb7ea..cb71b495ad674 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -50,10 +50,3 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
   ^bb3(%bb3arg : i32):
     return
 }
-
-// CHECK-LABEL: func @vector_splat_0d(
-func @vector_splat_0d(%a: f32) -> vector<f32> {
-  // CHECK: splat %{{.*}} : vector<f32>
-  %0 = splat %a : vector<f32>
-  return %0 : vector<f32>
-}

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 81be542915d69..f489c4188523f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1219,3 +1219,15 @@ func @propogate_index_cast(%arg0: tensor<1xi32>) -> index {
   %1 = tensor.extract %0[%c0] : tensor<1xindex>
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @splat_fold
+func @splat_fold() -> tensor<4xf32> {
+  %c = arith.constant 1.0 : f32
+  %t = tensor.splat %c : tensor<4xf32>
+  return %t : tensor<4xf32>
+
+  // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
+  // CHECK-NEXT: return [[T]] : tensor<4xf32>
+}

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 7dbf662fd6fe3..5c48a085d35a2 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -363,3 +363,18 @@ func @pad_yield_type(%arg0: tensor<?x4xi32>, %arg1: i8) -> tensor<?x9xi32> {
   return %0 : tensor<?x9xi32>
 }
 
+// -----
+
+func @invalid_splat(%v : f32) {
+  // expected-error @+1 {{invalid kind of type specified}}
+  tensor.splat %v : memref<8xf32>
+  return
+}
+
+// -----
+
+func @invalid_splat(%v : vector<8xf32>) {
+  // expected-error @+1 {{must be integer/index/float type}}
+  %w = tensor.splat %v : tensor<8xvector<8xf32>>
+  return
+}
\ No newline at end of file

diff  --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index a76a18d190b57..4d8e169722227 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -250,3 +250,13 @@ func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
 
 // -----
 
+// CHECK-LABEL: func @test_splat_op
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32
+func @test_splat_op(%s : f32) {
+  // CHECK: tensor.splat [[S]] : tensor<8xf32>
+  %v = tensor.splat %s : tensor<8xf32>
+  
+  // CHECK: tensor.splat [[S]] : tensor<4xf32>
+  %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
+  return
+}

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3d1923ac09ace..a25b0687ca756 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -515,7 +515,7 @@ func @fold_extract_broadcast(%a : f32) -> f32 {
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
 func @fold_extract_splat(%a : f32) -> f32 {
-  %b = splat %a : vector<1x2x4xf32>
+  %b = vector.splat %a : vector<1x2x4xf32>
   %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
   return %r : f32
 }
@@ -1121,10 +1121,10 @@ func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<
 // -----
 
 // CHECK-LABEL: extract_strided_splat
-//       CHECK:   %[[B:.*]] = splat %{{.*}} : vector<2x4xf16>
+//       CHECK:   %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
 //  CHECK-NEXT:   return %[[B]] : vector<2x4xf16>
 func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
- %0 = splat %arg0 : vector<16x4xf16>
+ %0 = vector.splat %arg0 : vector<16x4xf16>
  %1 = vector.extract_strided_slice %0
   {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
   vector<16x4xf16> to vector<2x4xf16>
@@ -1242,3 +1242,15 @@ func @extract_extract_strided2(%A: vector<2x4xf32>)
  %1 = vector.extract %0[0] : vector<1x4xf32>
  return %1 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @splat_fold
+func @splat_fold() -> vector<4xf32> {
+  %c = arith.constant 1.0 : f32
+  %v = vector.splat %c : vector<4xf32>
+  return %v : vector<4xf32>
+
+  // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK-NEXT: return [[V]] : vector<4xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index cd795b7bab083..acc810aa60451 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -300,7 +300,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
 func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error @+1 {{ requires memref or ranked tensor type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
 }
@@ -310,7 +310,7 @@ func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
 func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error @+1 {{ requires vector type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<4x3xf32>, f32
 }
@@ -376,7 +376,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant 3.0 : f32
   // expected-note @+1 {{prior use here}}
-  %mask = splat %c1 : vector<3x8x7xi1>
+  %mask = vector.splat %c1 : vector<3x8x7xi1>
   // expected-error @+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
   %0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
 }
@@ -386,7 +386,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error @+1 {{requires source vector element and vector result ranks to match}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
 }
@@ -396,7 +396,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<6xf32>
+  %vf0 = vector.splat %f0 : vector<6xf32>
   // expected-error @+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
 }
@@ -406,7 +406,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<2x3xf32>
+  %vf0 = vector.splat %f0 : vector<2x3xf32>
   // expected-error @+1 {{ expects the optional in_bounds attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
 }
@@ -416,7 +416,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<2x3xf32>
+  %vf0 = vector.splat %f0 : vector<2x3xf32>
   // expected-error @+1 {{requires broadcast dimensions to be in-bounds}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [false, true], permutation_map = affine_map<(d0, d1)->(0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
 }
@@ -426,8 +426,8 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<2x3xf32>
-  %mask = splat %c1 : vector<2x3xi1>
+  %vf0 = vector.splat %f0 : vector<2x3xf32>
+  %mask = vector.splat %c1 : vector<2x3xi1>
   // expected-error @+1 {{does not support masks with vector element type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
 }
@@ -446,7 +446,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
 func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error @+1 {{ requires vector type}}
   vector.transfer_write %arg0, %arg0[%c3, %c3] : memref<vector<4x3xf32>>, vector<4x3xf32>
 }
@@ -456,7 +456,7 @@ func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
 func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error @+1 {{ requires memref or ranked tensor type}}
   vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
 }
@@ -1506,3 +1506,11 @@ func @scan_incompatible_shapes(%arg0: vector<2x3xi32>, %arg1: vector<5xi32>) ->
     vector<2x3xi32>, vector<5xi32>
   return %0#0 : vector<2x3xi32>
 }
+
+// -----
+
+func @invalid_splat(%v : f32) {
+  // expected-error @+1 {{invalid kind of type specified}}
+  vector.splat %v : memref<8xf32>
+  return
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index a839476602141..47bbed2df4e86 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -45,11 +45,11 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %i0 = arith.constant 0 : index
   %i1 = arith.constant 1 : i1
 
-  %vf0 = splat %f0 : vector<4x3xf32>
-  %v0 = splat %c0 : vector<4x3xi32>
-  %vi0 = splat %i0 : vector<4x3xindex>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %v0 = vector.splat %c0 : vector<4x3xi32>
+  %vi0 = vector.splat %i0 : vector<4x3xindex>
   %m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
-  %m2 = splat %i1 : vector<5x4xi1>
+  %m2 = vector.splat %i1 : vector<5x4xi1>
   //
   // CHECK: vector.transfer_read
   %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
@@ -106,9 +106,9 @@ func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
   %c0 = arith.constant 0 : i32
   %i0 = arith.constant 0 : index
 
-  %vf0 = splat %f0 : vector<4x3xf32>
-  %v0 = splat %c0 : vector<4x3xi32>
-  %vi0 = splat %i0 : vector<4x3xindex>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %v0 = vector.splat %c0 : vector<4x3xi32>
+  %vi0 = vector.splat %i0 : vector<4x3xindex>
 
   //
   // CHECK: vector.transfer_read
@@ -725,3 +725,21 @@ func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
     vector<4x8x16x32xf32>, vector<4x16x32xf32>
   return %2#0 : vector<4x8x16x32xf32>
 }
+
+// CHECK-LABEL: func @test_splat_op
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32
+func @test_splat_op(%s : f32) {
+  // CHECK: vector.splat [[S]] : vector<8xf32>
+  %v = vector.splat %s : vector<8xf32>
+  
+  // CHECK: vector.splat [[S]] : vector<4xf32>
+  %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @vector_splat_0d(
+func @vector_splat_0d(%a: f32) -> vector<f32> {
+  // CHECK: vector.splat %{{.*}} : vector<f32>
+  %0 = vector.splat %a : vector<f32>
+  return %0 : vector<f32>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 01e6a23698cc1..1ed83adac71d0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -268,11 +268,11 @@ func @full_contract2(%arg0: vector<2x3xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<3xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
-// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
 // CHECK:      %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
-// CHECK:      %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
+// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
 // CHECK:      %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      return %[[T7]] : vector<2x3xf32>
@@ -289,12 +289,12 @@ func @outerproduct_noacc(%arg0: vector<2xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
-// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
 // CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
 // CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
 // CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
-// CHECK:      %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
 // CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -312,11 +312,11 @@ func @outerproduct_acc(%arg0: vector<2xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<3xi32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
-// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
 // CHECK:      %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32>
-// CHECK:      %[[T5:.*]] = splat %[[T4]] : vector<3xi32>
+// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
 // CHECK:      %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      return %[[T7]] : vector<2x3xi32>
@@ -332,13 +332,13 @@ func @outerproduct_noacc_int(%arg0: vector<2xi32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
-// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
 // CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32>
 // CHECK:      %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
 // CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32>
-// CHECK:      %[[T7:.*]] = splat %[[T6]] : vector<3xi32>
+// CHECK:      %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32>
 // CHECK:      %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -354,7 +354,7 @@ func @outerproduct_acc_int(%arg0: vector<2xi32>,
 // CHECK-LABEL: func @axpy_fp(
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
 // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -366,7 +366,7 @@ func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
 // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -377,7 +377,7 @@ func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) ->
 // CHECK-LABEL: func @axpy_int(
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: return %[[T1]] : vector<16xi32>
 func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -389,7 +389,7 @@ func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
 // CHECK: return %[[T2]] : vector<16xi32>
@@ -612,7 +612,7 @@ func @matmul(%arg0: vector<2x4xf32>,
 
 // CHECK-LABEL: func @broadcast_vec1d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = splat %[[A]] : vector<2xf32>
+// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
 // CHECK:      return %[[T0]] : vector<2xf32>
 
 func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -622,7 +622,7 @@ func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
 
 // CHECK-LABEL: func @broadcast_vec2d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = splat %[[A]] : vector<2x3xf32>
+// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
 // CHECK:      return %[[T0]] : vector<2x3xf32>
 
 func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -632,7 +632,7 @@ func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
 
 // CHECK-LABEL: func @broadcast_vec3d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32>
+// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
 // CHECK:      return %[[T0]] : vector<2x3x4xf32>
 
 func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -697,7 +697,7 @@ func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
 // CHECK-LABEL: func @broadcast_stretch
 // CHECK-SAME: %[[A:.*0]]: vector<1xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32>
-// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<4xf32>
+// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
 // CHECK:      return %[[T1]] : vector<4xf32>
 
 func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -723,16 +723,16 @@ func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
 // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32>
-// CHECK:      %[[T2:.*]] = splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32>
-// CHECK:      %[[T6:.*]] = splat %[[T4]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32>
-// CHECK:      %[[T10:.*]] = splat %[[T8]] : vector<3xf32>
+// CHECK:      %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
 // CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32>
-// CHECK:      %[[T14:.*]] = splat %[[T12]] : vector<3xf32>
+// CHECK:      %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
 // CHECK:      %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      return %[[T15]] : vector<4x3xf32>
 

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 7983a81f8d8c6..155d23b529ddf 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -282,19 +282,19 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
   %c0 = arith.constant 0 : index
   %m = arith.constant 1 : i1
 
-  %mask0 = splat %m : vector<7x14xi1>
+  %mask0 = vector.splat %m : vector<7x14xi1>
   %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
 // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
 // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
 
-  %mask1 = splat %m : vector<14x16xi1>
+  %mask1 = vector.splat %m : vector<14x16xi1>
   %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1>
 // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
-  %mask2 = splat %m : vector<7x14xi1>
+  %mask2 = vector.splat %m : vector<7x14xi1>
   %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
 // CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
@@ -336,7 +336,7 @@ func @transfer_write_permutations(
   %c0 = arith.constant 0 : index
   %m = arith.constant 1 : i1
 
-  %mask0 = splat %m : vector<7x14x8x16xi1>
+  %mask0 = vector.splat %m : vector<7x14x8x16xi1>
   %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
   // CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1>
   // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 81fdac964f76a..5f00a71c26ff7 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -295,18 +295,6 @@ func @test_dimop(%arg0: tensor<4x4x?xf32>) {
   return
 }
 
-// CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func @test_splat_op(%s : f32) {
-  %v = splat %s : vector<8xf32>
-  // CHECK: splat [[S]] : vector<8xf32>
-  %t = splat %s : tensor<8xf32>
-  // CHECK: splat [[S]] : tensor<8xf32>
-  %u = "std.splat"(%s) : (f32) -> vector<4xf32>
-  // CHECK: splat [[S]] : vector<4xf32>
-  return
-}
-
 // CHECK-LABEL: func @tensor_load_store
 func @tensor_load_store(%0 : memref<4x4xi32>, %1 : tensor<4x4xi32>) {
   // CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xi32>,

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 6650330c7eb7c..037373d1e3346 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -106,24 +106,8 @@ func @return_not_in_function() {
 
 // -----
 
-func @invalid_splat(%v : f32) {
-  splat %v : memref<8xf32>
-  // expected-error @-1 {{must be vector of any type values or statically shaped tensor of any type values}}
-  return
-}
-
-// -----
-
-func @invalid_splat(%v : vector<8xf32>) {
-  %w = splat %v : tensor<8xvector<8xf32>>
-  // expected-error @-1 {{must be integer/index/float type}}
-  return
-}
-
-// -----
-
 func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
-  splat %v : vector<8xf64>
+  vector.splat %v : vector<8xf64>
   // expected-error @-1 {{expects different type than prior uses}}
   return
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 572280f570737..c7a81018c3a3a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -22,7 +22,7 @@ func @print_vector_0d(%a: vector<f32>) {
 }
 
 func @splat_0d(%a: f32) {
-  %1 = splat %a : vector<f32>
+  %1 = vector.splat %a : vector<f32>
   // CHECK: ( 42 )
   vector.print %1: vector<f32>
   return

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-f32.mlir
index 28ad9a9521b44..901f030692104 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-f32.mlir
@@ -14,9 +14,9 @@
 !vector_type_R = type vector<7xf32>
 
 func @vector_outerproduct_splat_8x8(%fa: f32, %fb: f32, %fc: f32) -> !vector_type_C {
-  %a = splat %fa: !vector_type_A
-  %b = splat %fb: !vector_type_B
-  %c = splat %fc: !vector_type_C
+  %a = vector.splat %fa: !vector_type_A
+  %b = vector.splat %fb: !vector_type_B
+  %c = vector.splat %fc: !vector_type_C
   %d = vector.outerproduct %a, %b, %c : !vector_type_A, !vector_type_B
   return %d: !vector_type_C
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-i64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-i64.mlir
index cf07b113911e0..b6414b92595ac 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-i64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-outerproduct-i64.mlir
@@ -14,9 +14,9 @@
 !vector_type_R = type vector<7xi64>
 
 func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C {
-  %a = splat %ia: !vector_type_A
-  %b = splat %ib: !vector_type_B
-  %c = splat %ic: !vector_type_C
+  %a = vector.splat %ia: !vector_type_A
+  %b = vector.splat %ib: !vector_type_B
+  %c = vector.splat %ic: !vector_type_C
   %d = vector.outerproduct %a, %b, %c : !vector_type_A, !vector_type_B
   return %d: !vector_type_C
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
index 5ca8ec18650c8..f1a77923c9eee 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -138,7 +138,7 @@ func @transfer_read_1d_mask_in_bounds(
 // Non-contiguous, strided store.
 func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fn1 = arith.constant -1.0 : f32
-  %vf0 = splat %fn1 : vector<7xf32>
+  %vf0 = vector.splat %fn1 : vector<7xf32>
   vector.transfer_write %vf0, %A[%base1, %base2]
     {permutation_map = affine_map<(d0, d1) -> (d0)>}
     : vector<7xf32>, memref<?x?xf32>
@@ -148,7 +148,7 @@ func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
 // Non-contiguous, strided store.
 func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fn1 = arith.constant -2.0 : f32
-  %vf0 = splat %fn1 : vector<7xf32>
+  %vf0 = vector.splat %fn1 : vector<7xf32>
   %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
   vector.transfer_write %vf0, %A[%base1, %base2], %mask
     {permutation_map = affine_map<(d0, d1) -> (d0)>}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index 0c5444922654c..b2f56525738d6 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -111,7 +111,7 @@ func @transfer_read_2d_broadcast(
 // Vector store.
 func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fn1 = arith.constant -1.0 : f32
-  %vf0 = splat %fn1 : vector<1x4xf32>
+  %vf0 = vector.splat %fn1 : vector<1x4xf32>
   vector.transfer_write %vf0, %A[%base1, %base2]
     {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
     vector<1x4xf32>, memref<?x?xf32>
@@ -122,7 +122,7 @@ func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
 func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fn1 = arith.constant -2.0 : f32
   %mask = arith.constant dense<[[1, 0, 1, 0]]> : vector<1x4xi1>
-  %vf0 = splat %fn1 : vector<1x4xf32>
+  %vf0 = vector.splat %fn1 : vector<1x4xf32>
   vector.transfer_write %vf0, %A[%base1, %base2], %mask
     {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
     vector<1x4xf32>, memref<?x?xf32>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index 62a7ac5d3a8c3..3553d094324fd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -72,7 +72,7 @@ func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
 func @transfer_write_3d(%A : memref<?x?x?x?xf32>,
                         %o: index, %a: index, %b: index, %c: index) {
   %fn1 = arith.constant -1.0 : f32
-  %vf0 = splat %fn1 : vector<2x9x3xf32>
+  %vf0 = vector.splat %fn1 : vector<2x9x3xf32>
   vector.transfer_write %vf0, %A[%o, %a, %b, %c]
       : vector<2x9x3xf32>, memref<?x?x?x?xf32>
   return

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
index 0da322aa67398..427e5101000b1 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
@@ -45,7 +45,7 @@ func @transfer_read_mask_inbounds_4(%A : memref<?xf32>, %base: index) {
 
 func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4xf32>
+  %vf0 = vector.splat %f0 : vector<4xf32>
   vector.transfer_write %vf0, %A[%base]
       {permutation_map = affine_map<(d0) -> (d0)>} :
     vector<4xf32>, memref<?xf32>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
index 1fc0e88ed7a6c..b773f1f439127 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
@@ -5,7 +5,7 @@
 
 func @transfer_write16_inbounds_1d(%A : memref<?xf32>, %base: index) {
   %f = arith.constant 16.0 : f32
-  %v = splat %f : vector<16xf32>
+  %v = vector.splat %f : vector<16xf32>
   vector.transfer_write %v, %A[%base]
     {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]}
     : vector<16xf32>, memref<?xf32>
@@ -14,7 +14,7 @@ func @transfer_write16_inbounds_1d(%A : memref<?xf32>, %base: index) {
 
 func @transfer_write13_1d(%A : memref<?xf32>, %base: index) {
   %f = arith.constant 13.0 : f32
-  %v = splat %f : vector<13xf32>
+  %v = vector.splat %f : vector<13xf32>
   vector.transfer_write %v, %A[%base]
     {permutation_map = affine_map<(d0) -> (d0)>}
     : vector<13xf32>, memref<?xf32>
@@ -23,7 +23,7 @@ func @transfer_write13_1d(%A : memref<?xf32>, %base: index) {
 
 func @transfer_write17_1d(%A : memref<?xf32>, %base: index) {
   %f = arith.constant 17.0 : f32
-  %v = splat %f : vector<17xf32>
+  %v = vector.splat %f : vector<17xf32>
   vector.transfer_write %v, %A[%base]
     {permutation_map = affine_map<(d0) -> (d0)>}
     : vector<17xf32>, memref<?xf32>

diff  --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 2e720eae3439c..3542f37f61b76 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -789,18 +789,6 @@ func @custom_insertion_position() {
   return
 }
 
-// CHECK-LABEL: func @splat_fold
-func @splat_fold() -> (vector<4xf32>, tensor<4xf32>) {
-  %c = arith.constant 1.0 : f32
-  %v = splat %c : vector<4xf32>
-  %t = splat %c : tensor<4xf32>
-  return %v, %t : vector<4xf32>, tensor<4xf32>
-
-  // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
-  // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
-  // CHECK-NEXT: return [[V]], [[T]] : vector<4xf32>, tensor<4xf32>
-}
-
 // -----
 
 // CHECK-LABEL: func @subview_scalar_fold

diff  --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
index a8e0c49cec796..d6b4f472f1566 100644
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -56,7 +56,7 @@ func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface
 func @vector_splat_2d() {
   %c0 = arith.constant 0 : index
   %f10 = arith.constant 10.0 : f32
-  %vf10 = splat %f10: !vector_type_C
+  %vf10 = vector.splat %f10: !vector_type_C
   %C = memref.alloc() : !matrix_type_CC
   memref.store %vf10, %C[%c0, %c0]: !matrix_type_CC
 


        


More information about the Mlir-commits mailing list