[Mlir-commits] [mlir] [mlir][vector] Fold vector extract from insert when trailing unit dims (PR #192109)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 14 11:47:23 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Erick Ochoa Lopez (amd-eochoalo)
<details>
<summary>Changes</summary>
Upstreamed from https://github.com/iree-org/iree/pull/23789
Folds vector.extract from vector.insert when the extract position is a
strict prefix of the insert position and the remaining (un-indexed) dimensions
of the extracted sub-vector are all size 1. In that case the extracted
value is fully determined by the inserted value.
Example:
%ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
%ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
folds to:
%ext = vector.broadcast %s : f32 to vector<1xf32>
Co-authored-by: Jakub Kuderski <kubakuderski@<!-- -->gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@<!-- -->anthropic.com>
---
Full diff: https://github.com/llvm/llvm-project/pull/192109.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+56-1)
- (added) mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir (+142)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 238f37ae57ac6..8b068150390ee 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2453,12 +2453,67 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
}
};
+/// Folds vector.extract from vector.insert when the extract position is a
+/// strict prefix of the insert position and every dimension that is indexed
+/// by the insert but not by the extract is a fixed-size (non-scalable) unit
+/// dimension. In that case the extracted sub-vector consists solely of the
+/// inserted value (plus extra leading unit dims), so it is produced by
+/// broadcasting the inserted value to the extracted type.
+///
+/// Example:
+///   %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
+///   %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
+/// folds to:
+///   %ext = vector.broadcast %s : f32 to vector<1xf32>
+///
+/// Dimensions covered by the inserted value itself need not be unit:
+///   %ins = vector.insert %w, %v [3, 0] : vector<8xf32> into vector<16x1x8xf32>
+///   %ext = vector.extract %ins [3] : vector<1x8xf32> from vector<16x1x8xf32>
+/// folds to:
+///   %ext = vector.broadcast %w : vector<8xf32> to vector<1x8xf32>
+///
+/// Scalable unit dims (e.g. `[1]`) are rejected: such a dim holds `vscale`
+/// elements, of which the insert writes only one.
+struct FoldExtractFromInsertUnitDim final
+    : OpRewritePattern<vector::ExtractOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.hasDynamicPosition())
+      return failure();
+
+    auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
+    if (!insertOp || insertOp.hasDynamicPosition())
+      return failure();
+
+    ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
+    ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
+
+    // The extract position must be a strict prefix of the insert position.
+    // Equal-length positions are left to the existing ExtractOp folder.
+    if (extractPos.size() >= insertPos.size() ||
+        extractPos != insertPos.take_front(extractPos.size()))
+      return failure();
+
+    // Dims [extractPos.size(), insertPos.size()) are indexed by the insert but
+    // not by the extract. If each is a fixed unit dim, the extracted sub-vector
+    // holds exactly one copy of the inserted value. Dims at or past
+    // insertPos.size() are covered by the inserted value itself, so their
+    // sizes (and scalability) do not matter.
+    VectorType srcVecType = extractOp.getSourceVectorType();
+    ArrayRef<bool> scalableDims = srcVecType.getScalableDims();
+    for (int64_t dim = extractPos.size(), e = insertPos.size(); dim < e; ++dim)
+      if (srcVecType.getDimSize(dim) != 1 || scalableDims[dim])
+        return failure();
+
+    // The result type is srcShape[extractPos.size():], the inserted value's
+    // type is srcShape[insertPos.size():]; broadcasting only adds the leading
+    // unit dims checked above.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        extractOp, extractOp.getResult().getType(), insertOp.getValueToStore());
+    return success();
+  }
+};
+
} // namespace
+/// Populates `results` with the vector.extract canonicalization patterns:
+/// the pattern classes listed in the first `results.add` call, plus the two
+/// rewrite functions (foldExtractFromShapeCastToShapeCast and
+/// foldExtractFromFromElements) passed directly to `results.add`.
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
- ExtractOpFromConstantMask, ExtractToShapeCast>(context);
+ ExtractOpFromConstantMask, ExtractToShapeCast,
+ FoldExtractFromInsertUnitDim>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
new file mode 100644
index 0000000000000..86509610cd464
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// Tests for FoldExtractFromInsertUnitDim: fold vector.extract from
+// vector.insert when the extract position is a strict prefix of the insert
+// position and all remaining dimensions are size 1, so the extracted
+// sub-vector is fully determined by the inserted value.
+
+// Basic case: extract a row from a vector<4x1xf32> insert chain.
+// Each row holds a single element, so extracting row [i] directly from an
+// insert at [i, 0] yields exactly the inserted scalar and folds to a
+// vector.broadcast. Once both extracts fold, the inserts and the poison
+// source are dead and get removed.
+
+// CHECK-LABEL: func.func @extract_from_insert_trailing_unit_dim
+// CHECK-SAME: %[[S0:.*]]: f32, %[[S1:.*]]: f32
+// CHECK-NOT: ub.poison
+// CHECK-NOT: vector.insert {{.*}} vector<4x1xf32>
+// CHECK-NOT: vector.extract {{.*}} vector<4x1xf32>
+// CHECK-DAG: vector.broadcast %[[S0]] : f32 to vector<1xf32>
+// CHECK-DAG: vector.broadcast %[[S1]] : f32 to vector<1xf32>
+func.func @extract_from_insert_trailing_unit_dim(%s0: f32, %s1: f32) -> (vector<1xf32>, vector<1xf32>) {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins0 = vector.insert %s0, %poison [0, 0] : f32 into vector<4x1xf32>
+ %ins1 = vector.insert %s1, %ins0 [1, 0] : f32 into vector<4x1xf32>
+ %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
+ %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
+ return %ext0, %ext1 : vector<1xf32>, vector<1xf32>
+}
+
+// -----
+
+// Multiple trailing unit dims: vector<4x1x1xf32>.
+// Extracting [2] yields a vector<1x1xf32> whose single element is the scalar
+// written by the insert at [2, 0, 0], so the extract folds to a broadcast.
+
+// CHECK-LABEL: func.func @extract_from_insert_multiple_trailing_unit_dims
+// CHECK-SAME: %[[S:.*]]: f32
+// CHECK-NOT: ub.poison
+// CHECK: vector.broadcast %[[S]] : f32 to vector<1x1xf32>
+func.func @extract_from_insert_multiple_trailing_unit_dims(%s: f32) -> vector<1x1xf32> {
+ %poison = ub.poison : vector<4x1x1xf32>
+ %ins = vector.insert %s, %poison [2, 0, 0] : f32 into vector<4x1x1xf32>
+ %ext = vector.extract %ins [2] : vector<1x1xf32> from vector<4x1x1xf32>
+ return %ext : vector<1x1xf32>
+}
+
+// -----
+
+// Negative: the extract position [0] is not a prefix of the insert position
+// [1, 0], so FoldExtractFromInsertUnitDim must not fire. The existing
+// ExtractOp folder forwards the extract through the insert to the poison
+// dest instead, so no vector.broadcast should appear.
+
+// CHECK-LABEL: func.func @extract_from_insert_position_mismatch
+// CHECK-NOT: vector.broadcast
+func.func @extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins = vector.insert %s, %poison [1, 0] : f32 into vector<4x1xf32>
+ %ext = vector.extract %ins [0] : vector<1xf32> from vector<4x1xf32>
+ return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: the trailing dim (size 4) is not a unit dim -- should NOT fold.
+// The extracted row has three elements that the insert at [1, 2] does not
+// write, so the result is not determined by the inserted scalar alone.
+
+// CHECK-LABEL: func.func @extract_from_insert_non_unit_trailing_dim
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @extract_from_insert_non_unit_trailing_dim(%s: f32) -> vector<4xf32> {
+ %poison = ub.poison : vector<3x4xf32>
+ %ins = vector.insert %s, %poison [1, 2] : f32 into vector<3x4xf32>
+ %ext = vector.extract %ins [1] : vector<4xf32> from vector<3x4xf32>
+ return %ext : vector<4xf32>
+}
+
+// -----
+
+// Negative: dynamic extract position -- should NOT fold, because the
+// pattern only matches static positions and cannot tell whether %idx
+// selects the row written by the insert.
+
+// CHECK-LABEL: func.func @extract_from_insert_dynamic_position
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @extract_from_insert_dynamic_position(%s: f32, %idx: index) -> vector<1xf32> {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+ %ext = vector.extract %ins [%idx] : vector<1xf32> from vector<4x1xf32>
+ return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: dynamic insert position -- should NOT fold, because the
+// pattern only matches static positions and cannot tell whether the
+// insert writes row [2].
+
+// CHECK-LABEL: func.func @insert_dynamic_position_not_folded
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @insert_dynamic_position_not_folded(%s: f32, %idx: index) -> vector<1xf32> {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins = vector.insert %s, %poison [%idx, 0] : f32 into vector<4x1xf32>
+ %ext = vector.extract %ins [2] : vector<1xf32> from vector<4x1xf32>
+ return %ext : vector<1xf32>
+}
+
+// -----
+
+// Exact position match (same number of indices): the extract position is not
+// a strict prefix, so FoldExtractFromInsertUnitDim does not apply; the
+// existing ExtractOp folder forwards the inserted scalar. Verify it still works.
+
+// CHECK-LABEL: func.func @extract_from_insert_exact_match
+// CHECK-SAME: %[[S:.*]]: f32
+// CHECK: return %[[S]]
+func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+ %ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
+ return %ext : f32
+}
+
+// -----
+
+// End-to-end: full insert chain with all extracts reading from intermediates.
+// This mirrors the IR produced by the conv_1d lowering that motivated this
+// pattern (see iree-org/iree PR 23789). After folding, every extract is a
+// broadcast of its inserted scalar, and the insert chain becomes dead and is
+// eliminated.
+
+// CHECK-LABEL: func.func @full_insert_chain_with_extracts
+// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32, %[[D:.*]]: f32
+// CHECK-NOT: ub.poison : vector<4x1xf32>
+// CHECK-NOT: vector<4x1xf32>
+// CHECK-DAG: %[[BA:.*]] = vector.broadcast %[[A]] : f32 to vector<1xf32>
+// CHECK-DAG: %[[BB:.*]] = vector.broadcast %[[B]] : f32 to vector<1xf32>
+// CHECK-DAG: %[[BC:.*]] = vector.broadcast %[[C]] : f32 to vector<1xf32>
+// CHECK-DAG: %[[BD:.*]] = vector.broadcast %[[D]] : f32 to vector<1xf32>
+// CHECK: return %[[BA]], %[[BB]], %[[BC]], %[[BD]]
+func.func @full_insert_chain_with_extracts(%a: f32, %b: f32, %c: f32, %d: f32) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
+ %poison = ub.poison : vector<4x1xf32>
+ %ins0 = vector.insert %a, %poison [0, 0] : f32 into vector<4x1xf32>
+ %ins1 = vector.insert %b, %ins0 [1, 0] : f32 into vector<4x1xf32>
+ %ins2 = vector.insert %c, %ins1 [2, 0] : f32 into vector<4x1xf32>
+ %ins3 = vector.insert %d, %ins2 [3, 0] : f32 into vector<4x1xf32>
+ %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
+ %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
+ %ext2 = vector.extract %ins2 [2] : vector<1xf32> from vector<4x1xf32>
+ %ext3 = vector.extract %ins3 [3] : vector<1xf32> from vector<4x1xf32>
+ return %ext0, %ext1, %ext2, %ext3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/192109
More information about the Mlir-commits mailing list