[Mlir-commits] [mlir] [mlir][vector] vector.splat deprecation: folding/canonicalizing parity with broadcast (PR #150284)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 23 11:57:44 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: James Newling (newling)
<details>
<summary>Changes</summary>
This PR ensures that vector.broadcast (from a scalar) and vector.splat are folded and canonicalized identically. As a result, using vector.broadcast instead of vector.splat (which is now deprecated) loses no optimizations. All tests that previously checked folding/canonicalization of vector.splat now check vector.broadcast instead. The vector.splat canonicalization tests are now in a separate file, ready for removal when vector.splat is removed completely in the future.
This PR also adds a canonicalizer to vector.splat that always converts it to vector.broadcast, in order to reduce the 'traffic' through vector.splat.
---
Patch is 31.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150284.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+4)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+83-50)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+44-42)
- (modified) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+5-5)
- (added) mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir (+155)
- (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+4-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 0a5c1e5d9ab97..c3afa64fa08c3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2879,6 +2879,10 @@ def Vector_SplatOp : Vector_Op<"splat", [
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
let hasFolder = 1;
+
+ // vector.splat is deprecated, and vector.broadcast should be used instead.
+ // Canonicalize vector.splat to vector.broadcast.
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..28a573353ecf4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1288,19 +1288,47 @@ LogicalResult vector::ExtractElementOp::verify() {
return success();
}
+/// Consider the defining operation `defOp` of `value`. If `defOp` is a
+/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
+/// value that is splatted. Otherwise return null.
+///
+/// Cases where null is not returned:
+///
+/// scalar_source --> vector.splat --> value - return scalar_source
+/// scalar_source --> vector.broadcast --> value - return scalar_source
+static Value getSplatSource(Value value) {
+
+ // Block argument:
+ Operation *defOp = value.getDefiningOp();
+ if (!defOp)
+ return {};
+
+ // Splat:
+ auto splat = dyn_cast<vector::SplatOp>(defOp);
+ if (splat)
+ return splat.getInput();
+
+ auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
+
+ // Not broadcast (and not splat):
+ if (!broadcast)
+ return {};
+
+ // Broadcast of a vector:
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return {};
+
+ // Broadcast of a scalar:
+ return broadcast.getSource();
+}
+
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// Skip the 0-D vector here now.
if (!adaptor.getPosition())
return {};
- // Fold extractelement (splat X) -> X.
- if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
- return splat.getInput();
-
- // Fold extractelement(broadcast(X)) -> X.
- if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
- if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
- return broadcast.getSource();
+ if (auto splatValue = getSplatSource(getVector()))
+ return splatValue;
auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
@@ -2539,12 +2567,14 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
///
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
/// ==> rewrite to vector.splat %a : vector<3xf32>
-static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
- PatternRewriter &rewriter) {
+static LogicalResult
+rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
+ PatternRewriter &rewriter) {
if (!llvm::all_equal(fromElementsOp.getElements()))
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
- fromElementsOp.getElements().front());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ fromElementsOp, fromElementsOp.getType(),
+ fromElementsOp.getElements().front());
return success();
}
@@ -2575,7 +2605,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {
- // Handled by `rewriteFromElementsAsSplat`
+ // Handled by `rewriteFromElementsAsBroadcast`
if (fromElements.getType().getNumElements() == 1)
return failure();
@@ -2669,7 +2699,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add(rewriteFromElementsAsSplat);
+ results.add(rewriteFromElementsAsBroadcast);
results.add<FromElementsToShapeCast>(context);
}
@@ -3117,23 +3147,18 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
}
};
-/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
+/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v)
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
- auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
- auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
-
- if (!v1Splat || !v2Splat)
- return failure();
-
- if (v1Splat.getInput() != v2Splat.getInput())
+ Value splat = getSplatSource(op.getV1());
+ if (!splat || getSplatSource(op.getV2()) != splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -3343,23 +3368,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
}
};
-/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
+/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v)
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
- auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
- auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
-
- if (!srcSplat || !dstSplat)
- return failure();
- if (srcSplat.getInput() != dstSplat.getInput())
+ Value splat = getSplatSource(op.getValueToStore());
+ if (!splat || getSplatSource(op.getDest()) != splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -3627,8 +3648,7 @@ LogicalResult InsertStridedSliceOp::verify() {
}
namespace {
-/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
-/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
+/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
@@ -3636,18 +3656,13 @@ class FoldInsertStridedSliceSplat final
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
- auto srcSplatOp =
- insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
- auto destSplatOp =
- insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
- if (!srcSplatOp || !destSplatOp)
+ auto dst = insertStridedSliceOp.getDest();
+ auto splat = getSplatSource(insertStridedSliceOp.getValueToStore());
+ if (!splat || getSplatSource(dst) != splat)
return failure();
- if (srcSplatOp.getInput() != destSplatOp.getInput())
- return failure();
-
- rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+ rewriter.replaceOp(insertStridedSliceOp, dst);
return success();
}
};
@@ -4302,17 +4317,18 @@ class StridedSliceBroadcast final
}
};
-/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
+/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v)
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
- auto splat = op.getVector().getDefiningOp<SplatOp>();
+
+ Value splat = getSplatSource(op.getVector());
if (!splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -6463,19 +6479,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};
-// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
+/// Replace transpose(splat-like(v)) with broadcast(v)
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
- if (!splatOp)
+ Value splat = getSplatSource(transposeOp.getVector());
+ if (!splat)
return failure();
- rewriter.replaceOpWithNewOp<vector::SplatOp>(
- transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ transposeOp, transposeOp.getResultVectorType(), splat);
return success();
}
};
@@ -7226,6 +7242,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}
+// Canonicalizer for vector.splat. It always gets canonicalized to a
+// vector.broadcast.
+class SplatToBroadcastPattern : public OpRewritePattern<SplatOp> {
+public:
+ using OpRewritePattern<SplatOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(SplatOp splatOp,
+ PatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+ splatOp.getOperand());
+ return success();
+ }
+};
+void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SplatToBroadcastPattern>(context);
+}
+
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..166df205358c7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -823,11 +823,11 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
// -----
-// CHECK-LABEL: fold_extract_scalar_from_splat
+// CHECK-LABEL: fold_extract_splatlike
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
- %b = vector.splat %a : vector<1x2x4xf32>
+func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+ %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
}
@@ -2033,11 +2033,11 @@ func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: ve
// -----
-// CHECK-LABEL: extract_strided_splat
-// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
+// CHECK-LABEL: extract_strided_splatlike
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
-func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
- %0 = vector.splat %arg0 : vector<16x4xf16>
+func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
+ %0 = vector.broadcast %arg0 : f16 to vector<16x4xf16>
%1 = vector.extract_strided_slice %0
{offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
vector<16x4xf16> to vector<2x4xf16>
@@ -2323,10 +2323,10 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>)
// -----
-// CHECK-LABEL: func @splat_fold
-func.func @splat_fold() -> vector<4xf32> {
+// CHECK-LABEL: func @splatlike_fold
+func.func @splatlike_fold() -> vector<4xf32> {
%c = arith.constant 1.0 : f32
- %v = vector.splat %c : vector<4xf32>
+ %v = vector.broadcast %c : f32 to vector<4xf32>
return %v : vector<4xf32>
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
@@ -2469,10 +2469,10 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
// -----
-// CHECK-LABEL: func @transpose_splat_constant
+// CHECK-LABEL: func @transpose_splatlike_constant
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
// CHECK: return %[[CST]]
-func.func @transpose_splat_constant() -> vector<8x4xf32> {
+func.func @transpose_splatlike_constant() -> vector<8x4xf32> {
%cst = arith.constant dense<5.0> : vector<4x8xf32>
%0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
return %0 : vector<8x4xf32>
@@ -2480,13 +2480,13 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
// -----
-// CHECK-LABEL: func @transpose_splat2(
+// CHECK-LABEL: func @transpose_splatlike2(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
+ // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
// CHECK: }
-func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
- %splat = vector.splat %arg : vector<4x3xf32>
+func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> {
+ %splat = vector.broadcast %arg : f32 to vector<4x3xf32>
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
return %0 : vector<3x4xf32>
}
@@ -2638,11 +2638,13 @@ func.func @extract_element_fold() -> i32 {
return %1 : i32
}
-// CHECK-LABEL: func @extract_element_splat_fold
+// -----
+
+// CHECK-LABEL: func @extract_element_splatlike_fold
// CHECK-SAME: (%[[ARG:.+]]: i32)
// CHECK: return %[[ARG]]
-func.func @extract_element_splat_fold(%a : i32) -> i32 {
- %v = vector.splat %a : vector<4xi32>
+func.func @extract_element_splatlike_fold(%a : i32) -> i32 {
+ %v = vector.broadcast %a : i32 to vector<4xi32>
%i = arith.constant 2 : i32
%1 = vector.extractelement %v[%i : i32] : vector<4xi32>
return %1 : i32
@@ -2781,13 +2783,13 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
// -----
-// CHECK-LABEL: @insert_strided_slice_splat
+// CHECK-LABEL: @insert_strided_slice_splatlike
// CHECK-SAME: (%[[ARG:.*]]: f32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
-func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
- %splat0 = vector.splat %x : vector<4x4xf32>
- %splat1 = vector.splat %x : vector<8x16xf32>
+func.func @insert_strided_slice_splatlike(%x: f32) -> (vector<8x16xf32>) {
+ %splat0 = vector.broadcast %x : f32 to vector<4x4xf32>
+ %splat1 = vector.broadcast %x : f32 to vector<8x16xf32>
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
return %0 : vector<8x16xf32>
@@ -2860,13 +2862,13 @@ func.func @insert_strided_2d_constant() ->
// -----
-// CHECK-LABEL: func @shuffle_splat
+// CHECK-LABEL: func @shuffle_splatlike
// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
-func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
- %v0 = vector.splat %x : vector<4xi32>
- %v1 = vector.splat %x : vector<2xi32>
+func.func @shuffle_splatlike(%x : i32) -> vector<4xi32> {
+ %v0 = vector.broadcast %x : i32 to vector<4xi32>
+ %v1 = vector.broadcast %x : i32 to vector<2xi32>
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
return %shuffle : vector<4xi32>
}
@@ -2874,13 +2876,13 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
// -----
-// CHECK-LABEL: func @insert_splat
+// CHECK-LABEL: func @insert_splatlike
// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
-func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
- %v0 = vector.splat %x : vector<4x3xi32>
- %v1 = vector.splat %x : vector<2x4x3xi32>
+func.func @insert_splatlike(%x : i32) -> vector<2x4x3xi32> {
+ %v0 = vector.broadcast %x : i32 to vector<4x3xi32>
+ %v1 = vector.broadcast %x : i32 to vector<2x4x3xi32>
%insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
return %insert : vector<2x4x3xi32>
}
@@ -3124,11 +3126,11 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
// -----
-// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
+// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
-func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
- // Splat scalar to 0D and extract scalar.
- %0 = vector.splat %a : vector<f32>
+func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+ // Splat/broadcast scalar to 0D and extract scalar.
+ %0 = vector.broadcast %a : f32 to vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>
// Broadcast scalar to 0D and extract scalar.
@@ -3140,8 +3142,8 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
- // Splat scalar to 2D and extract scalar.
- %6 = vector.splat %a : vector<2x3xf32>
+ // Splat/broadcast scalar to 2D and extract scalar.
+ %6 = vector.broadcast %a : f32 to vector<2x3xf32>
%7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
// Broadcast scalar to 3D and extract scalar.
@@ -3598,7 +3600,7 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
%v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32>
%v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32>
%v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32>
- return %v_2 : vector<4x4xf32>
+ return %v_2 : vector<4x4xf32>
}
// -----
@@ -3612,5 +3614,5 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> {
%v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32>
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
- return %v_1 : vector<4xf32>
+ return %v_1 : vector<4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..f43328f621787 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -36,9 +36,9 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vect...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/150284
More information about the Mlir-commits
mailing list