[Mlir-commits] [mlir] [mlir][vector] Fold vector extract from insert when trailing unit dims (PR #192109)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Apr 15 11:01:01 PDT 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/192109

>From b610e7fc68e0352c799f5ed71dafaad4f33b2c16 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 14 Apr 2026 14:33:33 -0400
Subject: [PATCH 1/4] [mlir][vector] Fold vector extract from insert when
 trailing unit dims

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  57 ++++++-
 .../Vector/canonicalize/vector-extract.mlir   | 142 ++++++++++++++++++
 2 files changed, 198 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 238f37ae57ac6..8b068150390ee 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2453,12 +2453,67 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
   }
 };
 
+/// Folds vector.extract from vector.insert when the extract position is a
+/// prefix of the insert position and the remaining (un-indexed) dimensions
+/// of the extracted sub-vector are all size 1. In that case the extracted
+/// value is fully determined by the inserted value.
+///
+/// Example:
+///   %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
+///   %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
+/// folds to:
+///   %ext = vector.broadcast %s : f32 to vector<1xf32>
+struct FoldExtractFromInsertUnitDim final
+    : OpRewritePattern<vector::ExtractOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.hasDynamicPosition()) {
+      return failure();
+    }
+
+    auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
+    if (!insertOp || insertOp.hasDynamicPosition()) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
+    ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
+
+    // The extract position must be a strict prefix of the insert position.
+    if (extractPos.size() >= insertPos.size()) {
+      return failure();
+    }
+    if (extractPos != insertPos.take_front(extractPos.size())) {
+      return failure();
+    }
+
+    // The remaining dimensions (those not indexed by the extract) must all
+    // be size 1 in the source vector type. This guarantees that the inserted
+    // value fully determines the extracted sub-vector.
+    auto srcVecType = extractOp.getSourceVectorType();
+    for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i) {
+      if (srcVecType.getDimSize(i) != 1) {
+        return failure();
+      }
+    }
+
+    // The inserted value fully determines the extracted sub-vector; broadcast
+    // it to the extracted type.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        extractOp, extractOp.getResult().getType(), insertOp.getValueToStore());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
-              ExtractOpFromConstantMask, ExtractToShapeCast>(context);
+              ExtractOpFromConstantMask, ExtractToShapeCast,
+              FoldExtractFromInsertUnitDim>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
new file mode 100644
index 0000000000000..86509610cd464
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// Tests for FoldExtractFromInsertUnitDim: fold vector.extract from
+// vector.insert when the extract position is a strict prefix of the insert
+// position and all remaining dimensions are size 1, so the extracted
+// sub-vector is fully determined by the inserted value.
+
+// Basic case: extract row from a vector<4x1xf32> insert chain.
+// The extract at [i] from insert at [i, 0] should fold to a broadcast.
+
+// CHECK-LABEL: func.func @extract_from_insert_trailing_unit_dim
+// CHECK-SAME:    %[[S0:.*]]: f32, %[[S1:.*]]: f32
+// CHECK-NOT:     ub.poison
+// CHECK-NOT:     vector.insert {{.*}} vector<4x1xf32>
+// CHECK-NOT:     vector.extract {{.*}} vector<4x1xf32>
+// CHECK-DAG:     vector.broadcast %[[S0]] : f32 to vector<1xf32>
+// CHECK-DAG:     vector.broadcast %[[S1]] : f32 to vector<1xf32>
+func.func @extract_from_insert_trailing_unit_dim(%s0: f32, %s1: f32) -> (vector<1xf32>, vector<1xf32>) {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins0 = vector.insert %s0, %poison [0, 0] : f32 into vector<4x1xf32>
+  %ins1 = vector.insert %s1, %ins0 [1, 0] : f32 into vector<4x1xf32>
+  %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
+  %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
+  return %ext0, %ext1 : vector<1xf32>, vector<1xf32>
+}
+
+// -----
+
+// Multiple trailing unit dims: vector<4x1x1xf32>.
+// Extract at [i] gives vector<1x1xf32>; the inserted value fully
+// determines the result.
+
+// CHECK-LABEL: func.func @extract_from_insert_multiple_trailing_unit_dims
+// CHECK-SAME:    %[[S:.*]]: f32
+// CHECK-NOT:     ub.poison
+// CHECK:         vector.broadcast %[[S]] : f32 to vector<1x1xf32>
+func.func @extract_from_insert_multiple_trailing_unit_dims(%s: f32) -> vector<1x1xf32> {
+  %poison = ub.poison : vector<4x1x1xf32>
+  %ins = vector.insert %s, %poison [2, 0, 0] : f32 into vector<4x1x1xf32>
+  %ext = vector.extract %ins [2] : vector<1x1xf32> from vector<4x1x1xf32>
+  return %ext : vector<1x1xf32>
+}
+
+// -----
+
+// Negative: extract position does not match insert position.
+// The upstream extract fold forwards through to the poison dest,
+// but our pattern should not fire.
+
+// CHECK-LABEL: func.func @extract_from_insert_position_mismatch
+// CHECK-NOT:     vector.broadcast
+func.func @extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [1, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [0] : vector<1xf32> from vector<4x1xf32>
+  return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: trailing dim is not unit size -- should NOT fold.
+
+// CHECK-LABEL: func.func @extract_from_insert_non_unit_trailing_dim
+// CHECK:         vector.insert
+// CHECK:         vector.extract
+func.func @extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
+  %poison = ub.poison : vector<3x4xf32>
+  %ins = vector.insert %s, %poison [1, 2] : f32 into vector<3x4xf32>
+  %ext = vector.extract %ins [1] : vector<4xf32> from vector<3x4xf32>
+  return %ext : vector<4xf32>
+}
+
+// -----
+
+// Negative: dynamic extract position -- should NOT fold.
+
+// CHECK-LABEL: func.func @extract_from_insert_dynamic_position
+// CHECK:         vector.insert
+// CHECK:         vector.extract
+func.func @extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<1xf32> {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [%idx] : vector<1xf32> from vector<4x1xf32>
+  return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: dynamic insert position -- should NOT fold.
+
+// CHECK-LABEL: func.func @insert_dynamic_position_not_folded
+// CHECK:         vector.insert
+// CHECK:         vector.extract
+func.func @insert_dynamic_position_not_folded(%s: f32, %idx: index) -> vector<1xf32> {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [%idx, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [2] : vector<1xf32> from vector<4x1xf32>
+  return %ext : vector<1xf32>
+}
+
+// -----
+
+// Exact position match (same number of indices) is handled by the
+// existing extract fold, not our pattern. Verify it still works.
+
+// CHECK-LABEL: func.func @extract_from_insert_exact_match
+// CHECK-SAME:    %[[S:.*]]: f32
+// CHECK:         return %[[S]]
+func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
+  return %ext : f32
+}
+
+// -----
+
+// End-to-end: full insert chain with all extracts reading from intermediates.
+// This mirrors the pattern from conv_1d lowering that triggered the original
+// bug. After folding, the insert chain becomes dead and is eliminated.
+
+// CHECK-LABEL: func.func @full_insert_chain_with_extracts
+// CHECK-SAME:    %[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32, %[[D:.*]]: f32
+// CHECK-NOT:     ub.poison : vector<4x1xf32>
+// CHECK-NOT:     vector<4x1xf32>
+// CHECK-DAG:     %[[BA:.*]] = vector.broadcast %[[A]] : f32 to vector<1xf32>
+// CHECK-DAG:     %[[BB:.*]] = vector.broadcast %[[B]] : f32 to vector<1xf32>
+// CHECK-DAG:     %[[BC:.*]] = vector.broadcast %[[C]] : f32 to vector<1xf32>
+// CHECK-DAG:     %[[BD:.*]] = vector.broadcast %[[D]] : f32 to vector<1xf32>
+// CHECK:         return %[[BA]], %[[BB]], %[[BC]], %[[BD]]
+func.func @full_insert_chain_with_extracts(%a: f32, %b: f32, %c: f32, %d: f32) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins0 = vector.insert %a, %poison [0, 0] : f32 into vector<4x1xf32>
+  %ins1 = vector.insert %b, %ins0 [1, 0] : f32 into vector<4x1xf32>
+  %ins2 = vector.insert %c, %ins1 [2, 0] : f32 into vector<4x1xf32>
+  %ins3 = vector.insert %d, %ins2 [3, 0] : f32 into vector<4x1xf32>
+  %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
+  %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
+  %ext2 = vector.extract %ins2 [2] : vector<1xf32> from vector<4x1xf32>
+  %ext3 = vector.extract %ins3 [3] : vector<1xf32> from vector<4x1xf32>
+  return %ext0, %ext1, %ext2, %ext3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
+}

>From d283c38e8ddd87be96a3b07ebbc2591181c63204 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 15 Apr 2026 10:18:00 -0400
Subject: [PATCH 2/4] Address review comments

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 19 +++-----
 .../Vector/canonicalize/vector-extract.mlir   | 46 ++++---------------
 2 files changed, 15 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8b068150390ee..b316ec1bfa041 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2469,35 +2469,28 @@ struct FoldExtractFromInsertUnitDim final
 
   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    if (extractOp.hasDynamicPosition()) {
+    if (extractOp.hasDynamicPosition())
       return failure();
-    }
 
     auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
-    if (!insertOp || insertOp.hasDynamicPosition()) {
+    if (!insertOp || insertOp.hasDynamicPosition())
       return failure();
-    }
 
     ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
     ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
 
     // The extract position must be a strict prefix of the insert position.
-    if (extractPos.size() >= insertPos.size()) {
-      return failure();
-    }
-    if (extractPos != insertPos.take_front(extractPos.size())) {
+    if (extractPos.size() >= insertPos.size() ||
+        extractPos != insertPos.take_front(extractPos.size()))
       return failure();
-    }
 
     // The remaining dimensions (those not indexed by the extract) must all
     // be size 1 in the source vector type. This guarantees that the inserted
     // value fully determines the extracted sub-vector.
     auto srcVecType = extractOp.getSourceVectorType();
-    for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i) {
-      if (srcVecType.getDimSize(i) != 1) {
+    for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
+      if (srcVecType.getDimSize(i) != 1)
         return failure();
-      }
-    }
 
     // The inserted value fully determines the extracted sub-vector; broadcast
     // it to the extracted type.
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
index 86509610cd464..43deb60b27060 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
 
 // Tests for FoldExtractFromInsertUnitDim: fold vector.extract from
 // vector.insert when the extract position is a strict prefix of the insert
@@ -47,9 +47,9 @@ func.func @extract_from_insert_multiple_trailing_unit_dims(%s: f32) -> vector<1x
 // The upstream extract fold forwards through to the poison dest,
 // but our pattern should not fire.
 
-// CHECK-LABEL: func.func @extract_from_insert_position_mismatch
+// CHECK-LABEL: func.func @negative_extract_from_insert_position_mismatch
 // CHECK-NOT:     vector.broadcast
-func.func @extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
+func.func @negative_extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
   %poison = ub.poison : vector<4x1xf32>
   %ins = vector.insert %s, %poison [1, 0] : f32 into vector<4x1xf32>
   %ext = vector.extract %ins [0] : vector<1xf32> from vector<4x1xf32>
@@ -60,10 +60,10 @@ func.func @extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
 
 // Negative: trailing dim is not unit size -- should NOT fold.
 
-// CHECK-LABEL: func.func @extract_from_insert_non_unit_trailing_dim
+// CHECK-LABEL: func.func @negative_extract_from_insert_non_unit_trailing_dim
 // CHECK:         vector.insert
 // CHECK:         vector.extract
-func.func @extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
+func.func @negative_extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
   %poison = ub.poison : vector<3x4xf32>
   %ins = vector.insert %s, %poison [1, 2] : f32 into vector<3x4xf32>
   %ext = vector.extract %ins [1] : vector<4xf32> from vector<3x4xf32>
@@ -74,10 +74,10 @@ func.func @extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
 
 // Negative: dynamic extract position -- should NOT fold.
 
-// CHECK-LABEL: func.func @extract_from_insert_dynamic_position
+// CHECK-LABEL: func.func @negative_extract_from_insert_dynamic_position
 // CHECK:         vector.insert
 // CHECK:         vector.extract
-func.func @extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<1xf32> {
+func.func @negative_extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<1xf32> {
   %poison = ub.poison : vector<4x1xf32>
   %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
   %ext = vector.extract %ins [%idx] : vector<1xf32> from vector<4x1xf32>
@@ -88,10 +88,10 @@ func.func @extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<
 
 // Negative: dynamic insert position -- should NOT fold.
 
-// CHECK-LABEL: func.func @insert_dynamic_position_not_folded
+// CHECK-LABEL: func.func @negative_insert_dynamic_position_not_folded
 // CHECK:         vector.insert
 // CHECK:         vector.extract
-func.func @insert_dynamic_position_not_folded(%s: f32, %idx: index) -> vector<1xf32> {
+func.func @negative_insert_dynamic_position_not_folded(%s: f32, %idx: index) -> vector<1xf32> {
   %poison = ub.poison : vector<4x1xf32>
   %ins = vector.insert %s, %poison [%idx, 0] : f32 into vector<4x1xf32>
   %ext = vector.extract %ins [2] : vector<1xf32> from vector<4x1xf32>
@@ -112,31 +112,3 @@ func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
   %ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
   return %ext : f32
 }
-
-// -----
-
-// End-to-end: full insert chain with all extracts reading from intermediates.
-// This mirrors the pattern from conv_1d lowering that triggered the original
-// bug. After folding, the insert chain becomes dead and is eliminated.
-
-// CHECK-LABEL: func.func @full_insert_chain_with_extracts
-// CHECK-SAME:    %[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32, %[[D:.*]]: f32
-// CHECK-NOT:     ub.poison : vector<4x1xf32>
-// CHECK-NOT:     vector<4x1xf32>
-// CHECK-DAG:     %[[BA:.*]] = vector.broadcast %[[A]] : f32 to vector<1xf32>
-// CHECK-DAG:     %[[BB:.*]] = vector.broadcast %[[B]] : f32 to vector<1xf32>
-// CHECK-DAG:     %[[BC:.*]] = vector.broadcast %[[C]] : f32 to vector<1xf32>
-// CHECK-DAG:     %[[BD:.*]] = vector.broadcast %[[D]] : f32 to vector<1xf32>
-// CHECK:         return %[[BA]], %[[BB]], %[[BC]], %[[BD]]
-func.func @full_insert_chain_with_extracts(%a: f32, %b: f32, %c: f32, %d: f32) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
-  %poison = ub.poison : vector<4x1xf32>
-  %ins0 = vector.insert %a, %poison [0, 0] : f32 into vector<4x1xf32>
-  %ins1 = vector.insert %b, %ins0 [1, 0] : f32 into vector<4x1xf32>
-  %ins2 = vector.insert %c, %ins1 [2, 0] : f32 into vector<4x1xf32>
-  %ins3 = vector.insert %d, %ins2 [3, 0] : f32 into vector<4x1xf32>
-  %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
-  %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
-  %ext2 = vector.extract %ins2 [2] : vector<1xf32> from vector<4x1xf32>
-  %ext3 = vector.extract %ins3 [3] : vector<1xf32> from vector<4x1xf32>
-  return %ext0, %ext1, %ext2, %ext3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
-}

>From 0f70df8387c79ba1023c9a60120d109ed6b86464 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 15 Apr 2026 11:04:20 -0400
Subject: [PATCH 3/4] Add vector to vector broadcast test.

---
 .../Vector/canonicalize/vector-extract.mlir        | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
index 43deb60b27060..615d79dab77ab 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -112,3 +112,17 @@ func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
   %ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
   return %ext : f32
 }
+
+// -----
+
+// First matches FoldExtractFromInsertUnitDim, then BroadcastToShapeCast.
+
+// CHECK-LABEL: func.func @extract_from_insert_vector_to_vector_broadcast
+// CHECK-SAME:    %[[SRC:.*]]: vector<1xf32>
+// CHECK:         vector.shape_cast %[[SRC]] : vector<1xf32> to vector<1x1xf32>
+func.func @extract_from_insert_vector_to_vector_broadcast(%src: vector<1xf32>) -> vector<1x1xf32> {
+  %poison = ub.poison : vector<16x1x1xf32>
+  %vec1 = vector.insert %src, %poison [0, 0] : vector<1xf32> into vector<16x1x1xf32>
+  %vec2 = vector.extract %vec1[0] : vector<1x1xf32> from vector<16x1x1xf32>
+  return %vec2 : vector<1x1xf32>
+}

>From 9f6f333e51e9efbf484dad75eaeea13e4df0dc5f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 15 Apr 2026 14:00:38 -0400
Subject: [PATCH 4/4] Use shape_cast if the inserted value is a vector

If the inserted value is itself a vector (which, given the unit-dim check,
has only unit dims), fold to vector.shape_cast instead of vector.broadcast.

Also bail out on scalable trailing unit dims. In vector<4x[1]xf32> the
sub-vector vector<[1]xf32> holds vscale elements, of which an insert at
[i, 0] writes only one, so replacing the extract with a broadcast of the
inserted scalar would be a miscompile. Add a negative test for this and
refresh the now-stale comment on the vector-to-vector test.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 34 ++++++++++++++++++-----
 .../Vector/canonicalize/vector-extract.mlir   | 18 +++++++++++++++-
 2 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2458,11 +2458,19 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
 /// of the extracted sub-vector are all size 1. In that case the extracted
 /// value is fully determined by the inserted value.
 ///
-/// Example:
+/// Examples:
 ///   %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
 ///   %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
 /// folds to:
 ///   %ext = vector.broadcast %s : f32 to vector<1xf32>
+///
+///   %ins = vector.insert %s, %v [0, 0] : vector<1xf32> into vector<16x1x1xf32>
+///   %ext = vector.extract %ins [0] : vector<1x1xf32> from vector<16x1x1xf32>
+/// folds to:
+///   %ext = vector.shape_cast %s : vector<1xf32> to vector<1x1xf32>
+///
+/// Scalable unit dims (e.g. the [1] in vector<4x[1]xf32>) do not qualify:
+/// they hold vscale elements, only one of which the insert writes.
 struct FoldExtractFromInsertUnitDim final
     : OpRewritePattern<vector::ExtractOp> {
   using Base::Base;
@@ -2485,17 +2493,27 @@ struct FoldExtractFromInsertUnitDim final
       return failure();
 
     // The remaining dimensions (those not indexed by the extract) must all
-    // be size 1 in the source vector type. This guarantees that the inserted
-    // value fully determines the extracted sub-vector.
+    // be size 1 and non-scalable in the source vector type. This guarantees
+    // that the extracted sub-vector holds a single element, which is fully
+    // determined by the inserted value. A scalable unit dim ([1]) holds vscale
+    // elements, of which the insert writes only one.
     auto srcVecType = extractOp.getSourceVectorType();
     for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
-      if (srcVecType.getDimSize(i) != 1)
+      if (srcVecType.getDimSize(i) != 1 || srcVecType.getScalableDims()[i])
         return failure();
 
-    // The inserted value fully determines the extracted sub-vector; broadcast
-    // it to the extracted type.
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        extractOp, extractOp.getResult().getType(), insertOp.getValueToStore());
+    Value inserted = insertOp.getValueToStore();
+    Type extractedType = extractOp.getResult().getType();
+    if (isa<VectorType>(inserted.getType())) {
+      // A vector value (all unit dims) already holds the single element; only
+      // the shape differs.
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractedType,
+                                                       inserted);
+    } else {
+      // A scalar value is broadcast to the single-element extracted type.
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, extractedType,
+                                                       inserted);
+    }
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
--- a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -115,14 +115,30 @@ func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
 
 // -----
 
-// First matches FoldExtractFromInsertUnitDim, then BroadcastToShapeCast.
+// The inserted value is a vector, so FoldExtractFromInsertUnitDim folds the
+// extract directly to a vector.shape_cast.
 
 // CHECK-LABEL: func.func @extract_from_insert_vector_to_vector_broadcast
 // CHECK-SAME:    %[[SRC:.*]]: vector<1xf32>
 // CHECK:         vector.shape_cast %[[SRC]] : vector<1xf32> to vector<1x1xf32>
 func.func @extract_from_insert_vector_to_vector_broadcast(%src: vector<1xf32>) -> vector<1x1xf32> {
   %poison = ub.poison : vector<16x1x1xf32>
   %vec1 = vector.insert %src, %poison [0, 0] : vector<1xf32> into vector<16x1x1xf32>
   %vec2 = vector.extract %vec1[0] : vector<1x1xf32> from vector<16x1x1xf32>
   return %vec2 : vector<1x1xf32>
 }
+
+// -----
+
+// Negative: the trailing unit dim is scalable, so the extracted
+// vector<[1]xf32> holds vscale elements and only one of them is written by
+// the insert -- should NOT fold.
+
+// CHECK-LABEL: func.func @negative_extract_from_insert_scalable_unit_dim
+// CHECK:         vector.insert
+// CHECK:         vector.extract
+func.func @negative_extract_from_insert_scalable_unit_dim(%s: f32, %v: vector<4x[1]xf32>) -> vector<[1]xf32> {
+  %ins = vector.insert %s, %v [2, 0] : f32 into vector<4x[1]xf32>
+  %ext = vector.extract %ins [2] : vector<[1]xf32> from vector<4x[1]xf32>
+  return %ext : vector<[1]xf32>
+}



More information about the Mlir-commits mailing list