[Mlir-commits] [mlir] [mlir][vector] Fold vector extract from insert when trailing unit dims (PR #192109)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Wed Apr 15 11:01:01 PDT 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/192109
>From b610e7fc68e0352c799f5ed71dafaad4f33b2c16 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 14 Apr 2026 14:33:33 -0400
Subject: [PATCH 1/4] [mlir][vector] Fold vector extract from insert when
trailing unit dims
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 57 ++++++-
.../Vector/canonicalize/vector-extract.mlir | 142 ++++++++++++++++++
2 files changed, 198 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 238f37ae57ac6..8b068150390ee 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2453,12 +2453,67 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
}
};
+/// Folds vector.extract from vector.insert when the extract position is a
+/// strict prefix of the insert position and the remaining (un-indexed) dims
+/// of the extracted sub-vector are all size 1. In that case the extracted
+/// value is fully determined by the inserted value.
+///
+/// Example:
+/// %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
+/// %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
+/// folds to:
+/// %ext = vector.broadcast %s : f32 to vector<1xf32>
+struct FoldExtractFromInsertUnitDim final
+ : OpRewritePattern<vector::ExtractOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ if (extractOp.hasDynamicPosition()) {
+ return failure();
+ }
+
+ auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
+ if (!insertOp || insertOp.hasDynamicPosition()) {
+ return failure();
+ }
+
+ ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
+ ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
+
+ // The extract position must be a strict prefix of the insert position.
+ if (extractPos.size() >= insertPos.size()) {
+ return failure();
+ }
+ if (extractPos != insertPos.take_front(extractPos.size())) {
+ return failure();
+ }
+
+ // The remaining dimensions (those not indexed by the extract) must all
+ // be size 1 in the source vector type. This guarantees that the inserted
+ // value fully determines the extracted sub-vector.
+ auto srcVecType = extractOp.getSourceVectorType();
+ for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i) {
+ if (srcVecType.getDimSize(i) != 1) {
+ return failure();
+ }
+ }
+
+ // The inserted value fully determines the extracted sub-vector; broadcast
+ // it to the extracted type.
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ extractOp, extractOp.getResult().getType(), insertOp.getValueToStore());
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
- ExtractOpFromConstantMask, ExtractToShapeCast>(context);
+ ExtractOpFromConstantMask, ExtractToShapeCast,
+ FoldExtractFromInsertUnitDim>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
new file mode 100644
index 0000000000000..86509610cd464
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// Tests for FoldExtractFromInsertUnitDim: fold vector.extract from
+// vector.insert when the extract position is a strict prefix of the insert
+// position and all remaining dimensions are size 1, so the extracted
+// sub-vector is fully determined by the inserted value.
+
+// Basic case: extract row from a vector<4x1xf32> insert chain.
+// The extract at [i] from insert at [i, 0] should fold to a broadcast.
+
+// CHECK-LABEL: func.func @extract_from_insert_trailing_unit_dim
+// CHECK-SAME: %[[S0:.*]]: f32, %[[S1:.*]]: f32
+// CHECK-NOT: ub.poison
+// CHECK-NOT: vector.insert {{.*}} vector<4x1xf32>
+// CHECK-NOT: vector.extract {{.*}} vector<4x1xf32>
+// CHECK-DAG: vector.broadcast %[[S0]] : f32 to vector<1xf32>
+// CHECK-DAG: vector.broadcast %[[S1]] : f32 to vector<1xf32>
+func.func @extract_from_insert_trailing_unit_dim(%s0: f32, %s1: f32) -> (vector<1xf32>, vector<1xf32>) {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins0 = vector.insert %s0, %poison [0, 0] : f32 into vector<4x1xf32>
+ %ins1 = vector.insert %s1, %ins0 [1, 0] : f32 into vector<4x1xf32>
+ %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
+ %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
+ return %ext0, %ext1 : vector<1xf32>, vector<1xf32>
+}
+
+// -----
+
+// Multiple trailing unit dims: vector<4x1x1xf32>.
+// Extract at [i] gives vector<1x1xf32>; the inserted value fully
+// determines the result.
+
+// CHECK-LABEL: func.func @extract_from_insert_multiple_trailing_unit_dims
+// CHECK-SAME: %[[S:.*]]: f32
+// CHECK-NOT: ub.poison
+// CHECK: vector.broadcast %[[S]] : f32 to vector<1x1xf32>
+func.func @extract_from_insert_multiple_trailing_unit_dims(%s: f32) -> vector<1x1xf32> {
+ %poison = ub.poison : vector<4x1x1xf32>
+ %ins = vector.insert %s, %poison [2, 0, 0] : f32 into vector<4x1x1xf32>
+ %ext = vector.extract %ins [2] : vector<1x1xf32> from vector<4x1x1xf32>
+ return %ext : vector<1x1xf32>
+}
+
+// -----
+
+// Negative: extract position does not match insert position.
+// The upstream extract fold forwards through to the poison dest,
+// but our pattern should not fire.
+
+// CHECK-LABEL: func.func @extract_from_insert_position_mismatch
+// CHECK-NOT: vector.broadcast
+func.func @extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins = vector.insert %s, %poison [1, 0] : f32 into vector<4x1xf32>
+ %ext = vector.extract %ins [0] : vector<1xf32> from vector<4x1xf32>
+ return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: trailing dim is not unit size -- should NOT fold.
+
+// CHECK-LABEL: func.func @extract_from_insert_non_unit_trailing_dim
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
+ %poison = ub.poison : vector<3x4xf32>
+ %ins = vector.insert %s, %poison [1, 2] : f32 into vector<3x4xf32>
+ %ext = vector.extract %ins [1] : vector<4xf32> from vector<3x4xf32>
+ return %ext : vector<4xf32>
+}
+
+// -----
+
+// Negative: dynamic extract position -- should NOT fold.
+
+// CHECK-LABEL: func.func @extract_from_insert_dynamic_position
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<1xf32> {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+ %ext = vector.extract %ins [%idx] : vector<1xf32> from vector<4x1xf32>
+ return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: dynamic insert position -- should NOT fold.
+
+// CHECK-LABEL: func.func @insert_dynamic_position_not_folded
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @insert_dynamic_position_not_folded(%s: f32, %idx: index) -> vector<1xf32> {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins = vector.insert %s, %poison [%idx, 0] : f32 into vector<4x1xf32>
+ %ext = vector.extract %ins [2] : vector<1xf32> from vector<4x1xf32>
+ return %ext : vector<1xf32>
+}
+
+// -----
+
+// Exact position match (same number of indices) is handled by the
+// existing extract fold, not by FoldExtractFromInsertUnitDim. Verify it still works.
+
+// CHECK-LABEL: func.func @extract_from_insert_exact_match
+// CHECK-SAME: %[[S:.*]]: f32
+// CHECK: return %[[S]]
+func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+ %ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
+ return %ext : f32
+}
+
+// -----
+
+// End-to-end: full insert chain with all extracts reading from intermediates.
+// This mirrors the pattern from conv_1d lowering that triggered the original
+// bug. After folding, the insert chain becomes dead and is eliminated.
+
+// CHECK-LABEL: func.func @full_insert_chain_with_extracts
+// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32, %[[D:.*]]: f32
+// CHECK-NOT: ub.poison : vector<4x1xf32>
+// CHECK-NOT: vector<4x1xf32>
+// CHECK-DAG: %[[BA:.*]] = vector.broadcast %[[A]] : f32 to vector<1xf32>
+// CHECK-DAG: %[[BB:.*]] = vector.broadcast %[[B]] : f32 to vector<1xf32>
+// CHECK-DAG: %[[BC:.*]] = vector.broadcast %[[C]] : f32 to vector<1xf32>
+// CHECK-DAG: %[[BD:.*]] = vector.broadcast %[[D]] : f32 to vector<1xf32>
+// CHECK: return %[[BA]], %[[BB]], %[[BC]], %[[BD]]
+func.func @full_insert_chain_with_extracts(%a: f32, %b: f32, %c: f32, %d: f32) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins0 = vector.insert %a, %poison [0, 0] : f32 into vector<4x1xf32>
+ %ins1 = vector.insert %b, %ins0 [1, 0] : f32 into vector<4x1xf32>
+ %ins2 = vector.insert %c, %ins1 [2, 0] : f32 into vector<4x1xf32>
+ %ins3 = vector.insert %d, %ins2 [3, 0] : f32 into vector<4x1xf32>
+ %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
+ %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
+ %ext2 = vector.extract %ins2 [2] : vector<1xf32> from vector<4x1xf32>
+ %ext3 = vector.extract %ins3 [3] : vector<1xf32> from vector<4x1xf32>
+ return %ext0, %ext1, %ext2, %ext3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
+}
>From d283c38e8ddd87be96a3b07ebbc2591181c63204 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 15 Apr 2026 10:18:00 -0400
Subject: [PATCH 2/4] Address review comments
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 19 +++-----
.../Vector/canonicalize/vector-extract.mlir | 46 ++++---------------
2 files changed, 15 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8b068150390ee..b316ec1bfa041 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2469,35 +2469,28 @@ struct FoldExtractFromInsertUnitDim final
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- if (extractOp.hasDynamicPosition()) {
+ if (extractOp.hasDynamicPosition())
return failure();
- }
auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
- if (!insertOp || insertOp.hasDynamicPosition()) {
+ if (!insertOp || insertOp.hasDynamicPosition())
return failure();
- }
ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
// The extract position must be a strict prefix of the insert position.
- if (extractPos.size() >= insertPos.size()) {
- return failure();
- }
- if (extractPos != insertPos.take_front(extractPos.size())) {
+ if (extractPos.size() >= insertPos.size() ||
+ extractPos != insertPos.take_front(extractPos.size()))
return failure();
- }
// The remaining dimensions (those not indexed by the extract) must all
// be size 1 in the source vector type. This guarantees that the inserted
// value fully determines the extracted sub-vector.
auto srcVecType = extractOp.getSourceVectorType();
- for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i) {
- if (srcVecType.getDimSize(i) != 1) {
+ for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
+ if (srcVecType.getDimSize(i) != 1)
return failure();
- }
- }
// The inserted value fully determines the extracted sub-vector; broadcast
// it to the extracted type.
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
index 86509610cd464..43deb60b27060 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
// Tests for FoldExtractFromInsertUnitDim: fold vector.extract from
// vector.insert when the extract position is a strict prefix of the insert
@@ -47,9 +47,9 @@ func.func @extract_from_insert_multiple_trailing_unit_dims(%s: f32) -> vector<1x
// The upstream extract fold forwards through to the poison dest,
// but our pattern should not fire.
-// CHECK-LABEL: func.func @extract_from_insert_position_mismatch
+// CHECK-LABEL: func.func @negative_extract_from_insert_position_mismatch
// CHECK-NOT: vector.broadcast
-func.func @extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
+func.func @negative_extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
%poison = ub.poison : vector<4x1xf32>
%ins = vector.insert %s, %poison [1, 0] : f32 into vector<4x1xf32>
%ext = vector.extract %ins [0] : vector<1xf32> from vector<4x1xf32>
@@ -60,10 +60,10 @@ func.func @extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
// Negative: trailing dim is not unit size -- should NOT fold.
-// CHECK-LABEL: func.func @extract_from_insert_non_unit_trailing_dim
+// CHECK-LABEL: func.func @negative_extract_from_insert_non_unit_trailing_dim
// CHECK: vector.insert
// CHECK: vector.extract
-func.func @extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
+func.func @negative_extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
%poison = ub.poison : vector<3x4xf32>
%ins = vector.insert %s, %poison [1, 2] : f32 into vector<3x4xf32>
%ext = vector.extract %ins [1] : vector<4xf32> from vector<3x4xf32>
@@ -74,10 +74,10 @@ func.func @extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
// Negative: dynamic extract position -- should NOT fold.
-// CHECK-LABEL: func.func @extract_from_insert_dynamic_position
+// CHECK-LABEL: func.func @negative_extract_from_insert_dynamic_position
// CHECK: vector.insert
// CHECK: vector.extract
-func.func @extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<1xf32> {
+func.func @negative_extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<1xf32> {
%poison = ub.poison : vector<4x1xf32>
%ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
%ext = vector.extract %ins [%idx] : vector<1xf32> from vector<4x1xf32>
@@ -88,10 +88,10 @@ func.func @extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<
// Negative: dynamic insert position -- should NOT fold.
-// CHECK-LABEL: func.func @insert_dynamic_position_not_folded
+// CHECK-LABEL: func.func @negative_insert_dynamic_position_not_folded
// CHECK: vector.insert
// CHECK: vector.extract
-func.func @insert_dynamic_position_not_folded(%s: f32, %idx: index) -> vector<1xf32> {
+func.func @negative_insert_dynamic_position_not_folded(%s: f32, %idx: index) -> vector<1xf32> {
%poison = ub.poison : vector<4x1xf32>
%ins = vector.insert %s, %poison [%idx, 0] : f32 into vector<4x1xf32>
%ext = vector.extract %ins [2] : vector<1xf32> from vector<4x1xf32>
@@ -112,31 +112,3 @@ func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
%ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
return %ext : f32
}
-
-// -----
-
-// End-to-end: full insert chain with all extracts reading from intermediates.
-// This mirrors the pattern from conv_1d lowering that triggered the original
-// bug. After folding, the insert chain becomes dead and is eliminated.
-
-// CHECK-LABEL: func.func @full_insert_chain_with_extracts
-// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32, %[[D:.*]]: f32
-// CHECK-NOT: ub.poison : vector<4x1xf32>
-// CHECK-NOT: vector<4x1xf32>
-// CHECK-DAG: %[[BA:.*]] = vector.broadcast %[[A]] : f32 to vector<1xf32>
-// CHECK-DAG: %[[BB:.*]] = vector.broadcast %[[B]] : f32 to vector<1xf32>
-// CHECK-DAG: %[[BC:.*]] = vector.broadcast %[[C]] : f32 to vector<1xf32>
-// CHECK-DAG: %[[BD:.*]] = vector.broadcast %[[D]] : f32 to vector<1xf32>
-// CHECK: return %[[BA]], %[[BB]], %[[BC]], %[[BD]]
-func.func @full_insert_chain_with_extracts(%a: f32, %b: f32, %c: f32, %d: f32) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
- %poison = ub.poison : vector<4x1xf32>
- %ins0 = vector.insert %a, %poison [0, 0] : f32 into vector<4x1xf32>
- %ins1 = vector.insert %b, %ins0 [1, 0] : f32 into vector<4x1xf32>
- %ins2 = vector.insert %c, %ins1 [2, 0] : f32 into vector<4x1xf32>
- %ins3 = vector.insert %d, %ins2 [3, 0] : f32 into vector<4x1xf32>
- %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
- %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
- %ext2 = vector.extract %ins2 [2] : vector<1xf32> from vector<4x1xf32>
- %ext3 = vector.extract %ins3 [3] : vector<1xf32> from vector<4x1xf32>
- return %ext0, %ext1, %ext2, %ext3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
-}
>From 0f70df8387c79ba1023c9a60120d109ed6b86464 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 15 Apr 2026 11:04:20 -0400
Subject: [PATCH 3/4] Add vector to vector broadcast test.
---
.../Vector/canonicalize/vector-extract.mlir | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
index 43deb60b27060..615d79dab77ab 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -112,3 +112,17 @@ func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
%ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
return %ext : f32
}
+
+// -----
+
+// Vector-valued insert: FoldExtractFromInsertUnitDim ends in a shape_cast.
+
+// CHECK-LABEL: func.func @extract_from_insert_vector_to_vector_broadcast
+// CHECK-SAME: %[[SRC:.*]]: vector<1xf32>
+// CHECK: vector.shape_cast %[[SRC]] : vector<1xf32> to vector<1x1xf32>
+func.func @extract_from_insert_vector_to_vector_broadcast(%src: vector<1xf32>) -> vector<1x1xf32> {
+ %poison = ub.poison : vector<16x1x1xf32>
+ %vec1 = vector.insert %src, %poison [0, 0] : vector<1xf32> into vector<16x1x1xf32>
+ %vec2 = vector.extract %vec1[0] : vector<1x1xf32> from vector<16x1x1xf32>
+ return %vec2 : vector<1x1xf32>
+}
>From 9f6f333e51e9efbf484dad75eaeea13e4df0dc5f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 15 Apr 2026 14:00:38 -0400
Subject: [PATCH 4/4] Use shape_cast if the inserted value is a vector
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 23 ++++++++++++++++++-----
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b316ec1bfa041..0c5bbc04394ac 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2458,11 +2458,16 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
/// of the extracted sub-vector are all size 1. In that case the extracted
/// value is fully determined by the inserted value.
///
-/// Example:
+/// Examples:
/// %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
/// %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
/// folds to:
/// %ext = vector.broadcast %s : f32 to vector<1xf32>
+///
+/// %ins = vector.insert %s, %v [0, 0] : vector<1xf32> into vector<16x1x1xf32>
+// %ext = vector.extract %ins [0] : vector<1x1xf32> from vector<16x1x1xf32>
+/// folds to:
+/// %ext = vector.shape_cast %arg0 : vector<1xf32> to vector<1x1xf32>
struct FoldExtractFromInsertUnitDim final
: OpRewritePattern<vector::ExtractOp> {
using Base::Base;
@@ -2492,10 +2497,18 @@ struct FoldExtractFromInsertUnitDim final
if (srcVecType.getDimSize(i) != 1)
return failure();
- // The inserted value fully determines the extracted sub-vector; broadcast
- // it to the extracted type.
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- extractOp, extractOp.getResult().getType(), insertOp.getValueToStore());
+ Value inserted = insertOp.getValueToStore();
+ Type extractedType = extractOp.getResult().getType();
+ if (isa<VectorType>(inserted.getType())) {
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractedType,
+ inserted);
+ } else {
+ // The inserted value fully determines the extracted sub-vector; broadcast
+ // it to the extracted type.
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ extractOp, extractOp.getResult().getType(),
+ insertOp.getValueToStore());
+ }
return success();
}
};
More information about the Mlir-commits
mailing list