[Mlir-commits] [mlir] a96d8ae - [mlir][vector] vector.splat and vector.broadcast folding/canonicalizing parity (#150284)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 1 08:57:42 PDT 2025


Author: James Newling
Date: 2025-08-01T08:57:38-07:00
New Revision: a96d8aed984210fedcee2ff1c7d7968e72843191

URL: https://github.com/llvm/llvm-project/commit/a96d8aed984210fedcee2ff1c7d7968e72843191
DIFF: https://github.com/llvm/llvm-project/commit/a96d8aed984210fedcee2ff1c7d7968e72843191.diff

LOG: [mlir][vector] vector.splat and vector.broadcast folding/canonicalizing parity (#150284)

This PR ensures parity in folding/canonicalizing of vector.broadcast
(from a scalar) and vector.splat. This means that by using
vector.broadcast instead of vector.splat (which is currently
deprecated), there is no loss in optimizations performed. All tests
which were previously checking folding/canonicalizing of vector.splat
are now done for vector.broadcast. The vector.splat canonicalization
tests are now in a separate file, ready for removal when, in the future,
we remove vector.splat completely.

This PR also adds a canonicalizer to vector.splat to always convert it
to vector.broadcast. This is to reduce the 'traffic' through
vector.splat.

There is a chance that this PR will break downstream users who create or expect
vector.splat. Changing all such logic to work with vector.broadcast instead
should fix this.

Added: 
    mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5d45508af5c06..dc55704c36183 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2780,6 +2780,10 @@ def Vector_SplatOp : Vector_Op<"splat", [
   let assemblyFormat = "$input attr-dict `:` type($aggregate)";
 
   let hasFolder = 1;
+
+  // vector.splat is deprecated, and vector.broadcast should be used instead.
+  // The canonicalizer rewrites every vector.splat into a vector.broadcast.
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a21b5ba70c520..a450056a3041a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2476,17 +2476,19 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-/// Rewrite a vector.from_elements into a vector.splat if all elements are the
-/// same SSA value. E.g.:
-///
-/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
-/// ==> rewrite to vector.splat %a : vector<3xf32>
-static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
-                                                PatternRewriter &rewriter) {
+/// Rewrite vector.from_elements as vector.broadcast if all elements are the
+/// same SSA value. Example:
+///    %0 = vector.from_elements %a, %a, %a : vector<3xf32>
+/// =>
+///    %0 = vector.broadcast %a : f32 to vector<3xf32>
+static LogicalResult
+rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
+                               PatternRewriter &rewriter) {
   if (!llvm::all_equal(fromElementsOp.getElements()))
     return failure();
-  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
-                                       fromElementsOp.getElements().front());
+  rewriter.replaceOpWithNewOp<BroadcastOp>(
+      fromElementsOp, fromElementsOp.getType(),
+      fromElementsOp.getElements().front());
   return success();
 }
 
@@ -2517,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
   LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                 PatternRewriter &rewriter) const override {
 
-    // Handled by `rewriteFromElementsAsSplat`
+    // Handled by `rewriteFromElementsAsBroadcast`.
     if (fromElements.getType().getNumElements() == 1)
       return failure();
 
@@ -2610,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add(rewriteFromElementsAsSplat);
+  results.add(rewriteFromElementsAsBroadcast);
   results.add<FromElementsToShapeCast>(context);
 }
 
@@ -3058,23 +3060,50 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
   }
 };
 
-/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
+/// Consider the defining operation `defOp` of `value`. If `defOp` is a
+/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
+/// value that is splatted. Otherwise return a null Value.
+///
+/// Examples:
+///
+/// scalar_source --> vector.splat --> value     - return scalar_source
+/// scalar_source --> vector.broadcast --> value - return scalar_source
+static Value getScalarSplatSource(Value value) {
+  // Block argument (no defining op):
+  Operation *defOp = value.getDefiningOp();
+  if (!defOp)
+    return {};
+
+  // Splat:
+  if (auto splat = dyn_cast<vector::SplatOp>(defOp))
+    return splat.getInput();
+
+  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
+
+  // Not broadcast (and not splat):
+  if (!broadcast)
+    return {};
+
+  // Broadcast of a vector:
+  if (isa<VectorType>(broadcast.getSourceType()))
+    return {};
+
+  // Broadcast of a scalar:
+  return broadcast.getSource();
+}
+
+/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ShuffleOp op,
                                 PatternRewriter &rewriter) const override {
-    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
-    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
-
-    if (!v1Splat || !v2Splat)
+    Value splat = getScalarSplatSource(op.getV1());
+    if (!splat || getScalarSplatSource(op.getV2()) != splat)
       return failure();
 
-    if (v1Splat.getInput() != v2Splat.getInput())
-      return failure();
-
-    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
     return success();
   }
 };
@@ -3230,23 +3259,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
   }
 };
 
-/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
+/// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(InsertOp op,
                                 PatternRewriter &rewriter) const override {
-    auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
-    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
-
-    if (!srcSplat || !dstSplat)
-      return failure();
 
-    if (srcSplat.getInput() != dstSplat.getInput())
+    Value splat = getScalarSplatSource(op.getValueToStore());
+    if (!splat || getScalarSplatSource(op.getDest()) != splat)
       return failure();
 
-    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
     return success();
   }
 };
@@ -3514,8 +3539,7 @@ LogicalResult InsertStridedSliceOp::verify() {
 }
 
 namespace {
-/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
-/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
+/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as its dest.
 class FoldInsertStridedSliceSplat final
     : public OpRewritePattern<InsertStridedSliceOp> {
 public:
@@ -3523,18 +3547,13 @@ class FoldInsertStridedSliceSplat final
 
   LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
-    auto srcSplatOp =
-        insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
-    auto destSplatOp =
-        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
 
-    if (!srcSplatOp || !destSplatOp)
+    auto dst = insertStridedSliceOp.getDest();
+    auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
+    if (!splat || getScalarSplatSource(dst) != splat)
       return failure();
 
-    if (srcSplatOp.getInput() != destSplatOp.getInput())
-      return failure();
-
-    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+    rewriter.replaceOp(insertStridedSliceOp, dst);
     return success();
   }
 };
@@ -4189,17 +4208,18 @@ class StridedSliceBroadcast final
   }
 };
 
-/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
+/// Rewrite extract_strided_slice(splat-like(v)) as broadcast(v).
 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
-    auto splat = op.getVector().getDefiningOp<SplatOp>();
+
+    Value splat = getScalarSplatSource(op.getVector());
     if (!splat)
       return failure();
-    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
     return success();
   }
 };
@@ -6354,19 +6374,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
-// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
+/// Replace transpose(splat-like(v)) with broadcast(v).
 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
-    if (!splatOp)
+    Value splat = getScalarSplatSource(transposeOp.getVector());
+    if (!splat)
       return failure();
 
-    rewriter.replaceOpWithNewOp<vector::SplatOp>(
-        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        transposeOp, transposeOp.getResultVectorType(), splat);
     return success();
   }
 };
@@ -7117,6 +7137,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   return SplatElementsAttr::get(getType(), {constOperand});
 }
 
+// Canonicalizer for vector.splat: it unconditionally replaces the op with a
+// vector.broadcast of the same operand to the same result type.
+class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
+public:
+  using OpRewritePattern<SplatOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SplatOp splatOp,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+                                                     splatOp.getOperand());
+    return success();
+  }
+};
+void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                          MLIRContext *context) {
+  results.add<SplatToBroadcastPattern>(context);
+}
+
 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRanges) {
   setResultRanges(getResult(), argRanges.front());

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 56996b5f364a5..f86fb387be5b8 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -823,11 +823,11 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
 
 // -----
 
-// CHECK-LABEL: fold_extract_scalar_from_splat
+// CHECK-LABEL: fold_extract_splatlike
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
-  %b = vector.splat %a : vector<1x2x4xf32>
+func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
@@ -2063,11 +2063,11 @@ func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: ve
 
 // -----
 
-// CHECK-LABEL: extract_strided_splat
-//       CHECK:   %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
+// CHECK-LABEL: extract_strided_splatlike
+//       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f16 to vector<2x4xf16>
 //  CHECK-NEXT:   return %[[B]] : vector<2x4xf16>
-func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
- %0 = vector.splat %arg0 : vector<16x4xf16>
+func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
+ %0 = vector.broadcast %arg0 : f16 to vector<16x4xf16>
  %1 = vector.extract_strided_slice %0
   {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
   vector<16x4xf16> to vector<2x4xf16>
@@ -2353,14 +2353,14 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>)
 
 // -----
 
-// CHECK-LABEL: func @splat_fold
-func.func @splat_fold() -> vector<4xf32> {
+// CHECK-LABEL: func @splatlike_fold
+//  CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+//  CHECK-NEXT: return [[V]] : vector<4xf32>
+func.func @splatlike_fold() -> vector<4xf32> {
   %c = arith.constant 1.0 : f32
-  %v = vector.splat %c : vector<4xf32>
+  %v = vector.broadcast %c : f32 to vector<4xf32>
   return %v : vector<4xf32>
 
-  // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
-  // CHECK-NEXT: return [[V]] : vector<4xf32>
 }
 
 // -----
@@ -2499,10 +2499,10 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
 
 // -----
 
-// CHECK-LABEL: func @transpose_splat_constant
+// CHECK-LABEL: func @transpose_splatlike_constant
 //       CHECK:   %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
 //       CHECK:   return %[[CST]]
-func.func @transpose_splat_constant() -> vector<8x4xf32> {
+func.func @transpose_splatlike_constant() -> vector<8x4xf32> {
   %cst = arith.constant dense<5.0> : vector<4x8xf32>
   %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
   return %0 : vector<8x4xf32>
@@ -2510,13 +2510,13 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
 
 // -----
 
-// CHECK-LABEL:   func @transpose_splat2(
-// CHECK-SAME:                           %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK:           %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
-// CHECK:           return %[[VAL_1]] : vector<3x4xf32>
-// CHECK:         }
-func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
-  %splat = vector.splat %arg : vector<4x3xf32>
+// CHECK-LABEL:   func @transpose_splatlike2(
+//  CHECK-SAME:     %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
+//       CHECK:     %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
+//       CHECK:     return %[[VAL_1]] : vector<3x4xf32>
+//       CHECK:     }
+func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> {
+  %splat = vector.broadcast %arg : f32 to vector<4x3xf32>
   %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
@@ -2699,13 +2699,13 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
 
 // -----
 
-// CHECK-LABEL: @insert_strided_slice_splat
+// CHECK-LABEL: @insert_strided_slice_splatlike
 //  CHECK-SAME: (%[[ARG:.*]]: f32)
-//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
 //  CHECK-NEXT:   return %[[SPLAT]] : vector<8x16xf32>
-func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
-  %splat0 = vector.splat %x : vector<4x4xf32>
-  %splat1 = vector.splat %x : vector<8x16xf32>
+func.func @insert_strided_slice_splatlike(%x: f32) -> (vector<8x16xf32>) {
+  %splat0 = vector.broadcast %x : f32 to vector<4x4xf32>
+  %splat1 = vector.broadcast %x : f32 to vector<8x16xf32>
   %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
     : vector<4x4xf32> into vector<8x16xf32>
   return %0 : vector<8x16xf32>
@@ -2778,13 +2778,13 @@ func.func @insert_strided_2d_constant() ->
 
 // -----
 
-// CHECK-LABEL: func @shuffle_splat
+// CHECK-LABEL: func @shuffle_splatlike
 //  CHECK-SAME:   (%[[ARG:.*]]: i32)
-//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
 //  CHECK-NEXT:   return %[[SPLAT]] : vector<4xi32>
-func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
-  %v0 = vector.splat %x : vector<4xi32>
-  %v1 = vector.splat %x : vector<2xi32>
+func.func @shuffle_splatlike(%x : i32) -> vector<4xi32> {
+  %v0 = vector.broadcast %x : i32 to vector<4xi32>
+  %v1 = vector.broadcast %x : i32 to vector<2xi32>
   %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
   return %shuffle : vector<4xi32>
 }
@@ -2792,13 +2792,13 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
 
 // -----
 
-// CHECK-LABEL: func @insert_splat
+// CHECK-LABEL: func @insert_splatlike
 //  CHECK-SAME:   (%[[ARG:.*]]: i32)
-//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
 //  CHECK-NEXT:   return %[[SPLAT]] : vector<2x4x3xi32>
-func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
-  %v0 = vector.splat %x : vector<4x3xi32>
-  %v1 = vector.splat %x : vector<2x4x3xi32>
+func.func @insert_splatlike(%x : i32) -> vector<2x4x3xi32> {
+  %v0 = vector.broadcast %x : i32 to vector<4x3xi32>
+  %v1 = vector.broadcast %x : i32 to vector<2x4x3xi32>
   %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
   return %insert : vector<2x4x3xi32>
 }
@@ -3030,11 +3030,11 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
 
 // -----
 
-// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
-func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
-  // Splat scalar to 0D and extract scalar.
-  %0 = vector.splat %a : vector<f32>
+// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression(
+//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: vector<f32>, %[[C:.*]]: vector<2xf32>)
+func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+  // Splat/broadcast scalar to 0D and extract scalar.
+  %0 = vector.broadcast %a : f32 to vector<f32>
   %1 = vector.extract %0[] : f32 from vector<f32>
 
   // Broadcast scalar to 0D and extract scalar.
@@ -3042,12 +3042,12 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
   %3 = vector.extract %2[] : f32 from vector<f32>
 
   // Broadcast 0D to 3D and extract scalar.
-  // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[B]][] : f32 from vector<f32>
   %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
   %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
 
-  // Splat scalar to 2D and extract scalar.
-  %6 = vector.splat %a : vector<2x3xf32>
+  // Splat/broadcast scalar to 2D and extract scalar.
+  %6 = vector.broadcast %a : f32 to vector<2x3xf32>
   %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
 
   // Broadcast scalar to 3D and extract scalar.
@@ -3055,14 +3055,14 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
   %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
 
   // Extract 2D from 3D that was broadcasted from a scalar.
-  // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
   %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
 
   // Extract 1D from 2D that was splat'ed from a scalar.
-  // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
   %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
 
-  // CHECK:   return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
+  // CHECK:   return %[[A]], %[[A]], %[[EXTRACT1]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
   return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
 }
 
@@ -3504,7 +3504,7 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
   %v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32>
   %v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32>
   %v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32>
-  return %v_2 : vector<4x4xf32>  
+  return %v_2 : vector<4x4xf32>
 }
 
 // -----
@@ -3518,5 +3518,5 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
 func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> {
   %v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32>
   %v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
-  return %v_1 : vector<4xf32>  
+  return %v_1 : vector<4xf32>
 }

diff  --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..f43328f621787 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -36,9 +36,9 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32
 //  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
   %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
+  // CHECK: %[[SPLAT1:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
   %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32>
+  // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[B]] : f32 to vector<3xf32>
   %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
   // CHECK: return %[[SPLAT1]], %[[SPLAT2]]
   return %1, %2 : vector<3xf32>, vector<3xf32>
@@ -63,11 +63,11 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>,
 // CHECK-LABEL: func @from_elements_to_splat(
 //  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
-  // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+  // CHECK: %[[SPLAT:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
   %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
   // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
   %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
-  // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector<f32>
+  // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[A]] : f32 to vector<f32>
   %2 = vector.from_elements %a : vector<f32>
   // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]]
   return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
@@ -170,7 +170,7 @@ func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> v
-// Could match, but handled by `rewriteFromElementsAsSplat`.
+// Could match, but handled by `rewriteFromElementsAsBroadcast`.
 // CHECK-LABEL: func @extract_single_elm(
 //  CHECK-NEXT:      vector.extract
-//  CHECK-NEXT:      vector.splat
+//  CHECK-NEXT:      vector.broadcast
 //  CHECK-NEXT:      return
 func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>

diff  --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
new file mode 100644
index 0000000000000..e4a9391770b6c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
@@ -0,0 +1,126 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file should be removed when vector.splat is removed.
+// This file tests canonicalization/folding with vector.splat.
+// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir.
+
+
+// CHECK-LABEL: fold_extract_splat
+//  CHECK-SAME:   %[[A:.*]]: f32
+//       CHECK:   return %[[A]] : f32
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+  %b = vector.splat %a : vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_splat
+//       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f16 to vector<2x4xf16>
+//  CHECK-NEXT:   return %[[B]] : vector<2x4xf16>
+func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
+ %0 = vector.splat %arg0 : vector<16x4xf16>
+ %1 = vector.extract_strided_slice %0
+  {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
+  vector<16x4xf16> to vector<2x4xf16>
+  return %1 : vector<2x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @splat_fold
+//  CHECK-NEXT:   [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+//  CHECK-NEXT:   return [[V]] : vector<4xf32>
+func.func @splat_fold() -> vector<4xf32> {
+  %c = arith.constant 1.0 : f32
+  %v = vector.splat %c : vector<4xf32>
+  return %v : vector<4xf32>
+
+}
+
+// -----
+
+// CHECK-LABEL:   func @transpose_splat2(
+//  CHECK-SAME:      %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
+//       CHECK:      %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
+//       CHECK:      return %[[VAL_1]] : vector<3x4xf32>
+func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
+  %splat = vector.splat %arg : vector<4x3xf32>
+  %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
+  return %0 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_strided_slice_splat
+//  CHECK-SAME:   (%[[ARG:.*]]: f32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<8x16xf32>
+func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
+  %splat0 = vector.splat %x : vector<4x4xf32>
+  %splat1 = vector.splat %x : vector<8x16xf32>
+  %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_splat
+//  CHECK-SAME:   (%[[ARG:.*]]: i32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<4xi32>
+func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
+  %v0 = vector.splat %x : vector<4xi32>
+  %v1 = vector.splat %x : vector<2xi32>
+  %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<4xi32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @insert_splat
+//  CHECK-SAME:   (%[[ARG:.*]]: i32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<2x4x3xi32>
+func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
+  %v0 = vector.splat %x : vector<4x3xi32>
+  %v1 = vector.splat %x : vector<2x4x3xi32>
+  %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+  return %insert : vector<2x4x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
+//  CHECK-SAME:     (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>)
+func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+  // Splat scalar to 0D and extract scalar.
+  %0 = vector.splat %a : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Splat scalar to 0D again and extract scalar.
+  %2 = vector.splat %a : vector<f32>
+  %3 = vector.extract %2[] : f32 from vector<f32>
+
+  // Splat scalar to 2D and extract scalar.
+  %6 = vector.splat %a : vector<2x3xf32>
+  %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+
+  // Splat scalar to 3D and extract scalar.
+  %8 = vector.splat %a : vector<5x6x7xf32>
+  %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
+
+  // Extract 2D from 3D that was splat'ed from a scalar.
+  // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
+  %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
+
+  // Extract 1D from 2D that was splat'ed from a scalar.
+  // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
+  %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
+
+  // CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
+  return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 511ab70f35086..1b54d54ffbd9f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -284,19 +284,19 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
 
-// CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1>
+// CHECK: %[[MASK0:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1>
   %mask0 = vector.splat %m : vector<14x7xi1>
   %0 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
 // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
 
-// CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1>
+// CHECK: %[[MASK1:.*]] = vector.broadcast %{{.*}} : i1 to vector<16x14xi1>
   %mask1 = vector.splat %m : vector<16x14xi1>
   %1 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask1 {in_bounds = [true, false, true, false], permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {in_bounds = [false, false, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
-// CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1>
+// CHECK: %[[MASK3:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1>
   %mask2 = vector.splat %m : vector<14x7xi1>
   %2 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
@@ -336,7 +336,7 @@ func.func @transfer_write_permutations_tensor_masked(
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
 
-  // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1>
+  // CHECK: %[[MASK:.*]] = vector.broadcast %[[M]] : i1 to vector<16x14x7x8xi1>
   %mask0 = vector.splat %m : vector<16x14x7x8xi1>
   %res = vector.transfer_write %vec, %dst[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
   // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>


        


More information about the Mlir-commits mailing list