[Mlir-commits] [mlir] 0e5cad0 - [mlir][vector] Fold vector extract from insert when trailing unit dims (#192109)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 16 06:34:42 PDT 2026


Author: Erick Ochoa Lopez
Date: 2026-04-16T09:34:37-04:00
New Revision: 0e5cad062cd55499ae12fd3ef4bdb555dc2169f4

URL: https://github.com/llvm/llvm-project/commit/0e5cad062cd55499ae12fd3ef4bdb555dc2169f4
DIFF: https://github.com/llvm/llvm-project/commit/0e5cad062cd55499ae12fd3ef4bdb555dc2169f4.diff

LOG: [mlir][vector] Fold vector extract from insert when trailing unit dims (#192109)

Upstreamed from https://github.com/iree-org/iree/pull/23789

Folds vector.extract from vector.insert when the extract position is a
prefix of the insert position and the remaining (un-indexed) dimensions
of the extracted sub-vector are all size 1. In that case the extracted
value is fully determined by the inserted value.

Example:
  %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
  %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
folds to:
  %ext = vector.broadcast %s : f32 to vector<1xf32>

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply at anthropic.com>

Added: 
    mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 238f37ae57ac6..0c5bbc04394ac 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2453,12 +2453,88 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
   }
 };
 
+/// Folds vector.extract from vector.insert when the extract position is a
+/// prefix of the insert position and the remaining (un-indexed) dimensions
+/// of the extracted sub-vector are all size 1. In that case the extracted
+/// value is fully determined by the inserted value.
+///
+/// Dimensions indexed by the insert but not by the extract must also be
+/// fixed-size: a scalable unit dim `[1]` holds vscale elements, of which the
+/// insert overwrites only one. Scalable dims past the insert position belong
+/// to the inserted value itself and are therefore fine.
+///
+/// Examples:
+///   %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
+///   %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
+/// folds to:
+///   %ext = vector.broadcast %s : f32 to vector<1xf32>
+///
+///   %ins = vector.insert %s, %v [0, 0] : vector<1xf32> into vector<16x1x1xf32>
+///   %ext = vector.extract %ins [0] : vector<1x1xf32> from vector<16x1x1xf32>
+/// folds to:
+///   %ext = vector.shape_cast %s : vector<1xf32> to vector<1x1xf32>
+struct FoldExtractFromInsertUnitDim final
+    : OpRewritePattern<vector::ExtractOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.hasDynamicPosition())
+      return failure();
+
+    auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
+    if (!insertOp || insertOp.hasDynamicPosition())
+      return failure();
+
+    ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
+    ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
+
+    // The extract position must be a strict prefix of the insert position.
+    if (extractPos.size() >= insertPos.size() ||
+        extractPos != insertPos.take_front(extractPos.size()))
+      return failure();
+
+    // The remaining dimensions (those not indexed by the extract) must all
+    // be size 1 in the source vector type. This guarantees that the inserted
+    // value fully determines the extracted sub-vector.
+    VectorType srcVecType = extractOp.getSourceVectorType();
+    ArrayRef<int64_t> srcShape = srcVecType.getShape();
+    ArrayRef<bool> srcScalableDims = srcVecType.getScalableDims();
+    int64_t numInsertIndexedDims = insertPos.size();
+    for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i) {
+      if (srcShape[i] != 1)
+        return failure();
+      // A scalable `[1]` dim indexed by the insert has vscale elements, but
+      // only one of them is overwritten; the rest would still come from the
+      // insert destination.
+      if (i < numInsertIndexedDims && srcScalableDims[i])
+        return failure();
+    }
+
+    Value inserted = insertOp.getValueToStore();
+    Type extractedType = extractOp.getResult().getType();
+    if (isa<VectorType>(inserted.getType())) {
+      // The extracted type is the inserted vector type with extra leading
+      // fixed-size unit dims, so a shape_cast reproduces it exactly.
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractedType,
+                                                       inserted);
+    } else {
+      // A scalar insert indexes every dim, so all remaining dims were checked
+      // above to be fixed-size unit dims; broadcast the scalar to fill them.
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, extractedType,
+                                                       inserted);
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
   results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
-              ExtractOpFromConstantMask, ExtractToShapeCast>(context);
+              ExtractOpFromConstantMask, ExtractToShapeCast,
+              FoldExtractFromInsertUnitDim>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }

diff  --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
new file mode 100644
index 0000000000000..615d79dab77ab
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
+
+// Tests for FoldExtractFromInsertUnitDim: fold vector.extract from
+// vector.insert when the extract position is a strict prefix of the insert
+// position and all remaining dimensions are size 1, so the extracted
+// sub-vector is fully determined by the inserted value.
+
+// Basic case: extract row from a vector<4x1xf32> insert chain.
+// The extract at [i] from insert at [i, 0] should fold to a broadcast.
+
+// CHECK-LABEL: func.func @extract_from_insert_trailing_unit_dim
+// CHECK-SAME:    %[[S0:.*]]: f32, %[[S1:.*]]: f32
+// CHECK-NOT:     ub.poison
+// CHECK-NOT:     vector.insert {{.*}} vector<4x1xf32>
+// CHECK-NOT:     vector.extract {{.*}} vector<4x1xf32>
+// CHECK-DAG:     vector.broadcast %[[S0]] : f32 to vector<1xf32>
+// CHECK-DAG:     vector.broadcast %[[S1]] : f32 to vector<1xf32>
+func.func @extract_from_insert_trailing_unit_dim(%s0: f32, %s1: f32) -> (vector<1xf32>, vector<1xf32>) {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins0 = vector.insert %s0, %poison [0, 0] : f32 into vector<4x1xf32>
+  %ins1 = vector.insert %s1, %ins0 [1, 0] : f32 into vector<4x1xf32>
+  %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
+  %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
+  return %ext0, %ext1 : vector<1xf32>, vector<1xf32>
+}
+
+// -----
+
+// Multiple trailing unit dims: vector<4x1x1xf32>.
+// Extract at [i] gives vector<1x1xf32>; the inserted value fully
+// determines the result.
+
+// CHECK-LABEL: func.func @extract_from_insert_multiple_trailing_unit_dims
+// CHECK-SAME:    %[[S:.*]]: f32
+// CHECK-NOT:     ub.poison
+// CHECK:         vector.broadcast %[[S]] : f32 to vector<1x1xf32>
+func.func @extract_from_insert_multiple_trailing_unit_dims(%s: f32) -> vector<1x1xf32> {
+  %poison = ub.poison : vector<4x1x1xf32>
+  %ins = vector.insert %s, %poison [2, 0, 0] : f32 into vector<4x1x1xf32>
+  %ext = vector.extract %ins [2] : vector<1x1xf32> from vector<4x1x1xf32>
+  return %ext : vector<1x1xf32>
+}
+
+// -----
+
+// Negative: extract position does not match insert position.
+// The existing ExtractOp fold forwards through to the poison dest,
+// but FoldExtractFromInsertUnitDim must not fire.
+
+// CHECK-LABEL: func.func @negative_extract_from_insert_position_mismatch
+// CHECK-NOT:     vector.broadcast
+func.func @negative_extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [1, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [0] : vector<1xf32> from vector<4x1xf32>
+  return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: trailing dim is not unit size -- should NOT fold.
+
+// CHECK-LABEL: func.func @negative_extract_from_insert_non_unit_trailing_dim
+// CHECK:         vector.insert
+// CHECK:         vector.extract
+func.func @negative_extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
+  %poison = ub.poison : vector<3x4xf32>
+  %ins = vector.insert %s, %poison [1, 2] : f32 into vector<3x4xf32>
+  %ext = vector.extract %ins [1] : vector<4xf32> from vector<3x4xf32>
+  return %ext : vector<4xf32>
+}
+
+// -----
+
+// Negative: dynamic extract position -- should NOT fold.
+
+// CHECK-LABEL: func.func @negative_extract_from_insert_dynamic_position
+// CHECK:         vector.insert
+// CHECK:         vector.extract
+func.func @negative_extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<1xf32> {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [%idx] : vector<1xf32> from vector<4x1xf32>
+  return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: dynamic insert position -- should NOT fold.
+
+// CHECK-LABEL: func.func @negative_insert_dynamic_position_not_folded
+// CHECK:         vector.insert
+// CHECK:         vector.extract
+func.func @negative_insert_dynamic_position_not_folded(%s: f32, %idx: index) -> vector<1xf32> {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [%idx, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [2] : vector<1xf32> from vector<4x1xf32>
+  return %ext : vector<1xf32>
+}
+
+// -----
+
+// Exact position match (same number of indices) is handled by the existing
+// ExtractOp fold, not FoldExtractFromInsertUnitDim. Verify it still works.
+
+// CHECK-LABEL: func.func @extract_from_insert_exact_match
+// CHECK-SAME:    %[[S:.*]]: f32
+// CHECK:         return %[[S]]
+func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
+  return %ext : f32
+}
+
+// -----
+
+// FoldExtractFromInsertUnitDim rewrites a vector insert directly to a shape_cast.
+
+// CHECK-LABEL: func.func @extract_from_insert_vector_to_vector_broadcast
+// CHECK-SAME:    %[[SRC:.*]]: vector<1xf32>
+// CHECK:         vector.shape_cast %[[SRC]] : vector<1xf32> to vector<1x1xf32>
+func.func @extract_from_insert_vector_to_vector_broadcast(%src: vector<1xf32>) -> vector<1x1xf32> {
+  %poison = ub.poison : vector<16x1x1xf32>
+  %vec1 = vector.insert %src, %poison [0, 0] : vector<1xf32> into vector<16x1x1xf32>
+  %vec2 = vector.extract %vec1[0] : vector<1x1xf32> from vector<16x1x1xf32>
+  return %vec2 : vector<1x1xf32>
+}
+
+// -----
+
+// Negative: a scalable unit dim indexed by the insert holds vscale elements,
+// of which the insert overwrites only one -- should NOT fold.
+
+// CHECK-LABEL: func.func @negative_extract_from_insert_scalable_unit_dim
+// CHECK:         vector.insert
+// CHECK:         vector.extract
+func.func @negative_extract_from_insert_scalable_unit_dim(%s: f32, %v: vector<4x[1]xf32>) -> vector<[1]xf32> {
+  %ins = vector.insert %s, %v [2, 0] : f32 into vector<4x[1]xf32>
+  %ext = vector.extract %ins [2] : vector<[1]xf32> from vector<4x[1]xf32>
+  return %ext : vector<[1]xf32>
+}


        


More information about the Mlir-commits mailing list