[Mlir-commits] [mlir] 0e5cad0 - [mlir][vector] Fold vector extract from insert when trailing unit dims (#192109)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 16 06:34:42 PDT 2026
Author: Erick Ochoa Lopez
Date: 2026-04-16T09:34:37-04:00
New Revision: 0e5cad062cd55499ae12fd3ef4bdb555dc2169f4
URL: https://github.com/llvm/llvm-project/commit/0e5cad062cd55499ae12fd3ef4bdb555dc2169f4
DIFF: https://github.com/llvm/llvm-project/commit/0e5cad062cd55499ae12fd3ef4bdb555dc2169f4.diff
LOG: [mlir][vector] Fold vector extract from insert when trailing unit dims (#192109)
Upstreamed from https://github.com/iree-org/iree/pull/23789
Folds vector.extract from vector.insert when the extract position is a
prefix of the insert position and the remaining (un-indexed) dimensions
of the extracted sub-vector are all size 1. In that case the extracted
value is fully determined by the inserted value.
Example:
%ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
%ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
folds to:
%ext = vector.broadcast %s : f32 to vector<1xf32>
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply at anthropic.com>
Added:
mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 238f37ae57ac6..0c5bbc04394ac 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2453,12 +2453,73 @@ struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
}
};
+/// Folds vector.extract from vector.insert when the extract position is a
+/// prefix of the insert position and the remaining (un-indexed) dimensions
+/// of the extracted sub-vector are all size 1. In that case the extracted
+/// value is fully determined by the inserted value.
+///
+/// Dimensions indexed by the insert (but not by the extract) must also be
+/// fixed-size: a scalable `[1]` dimension holds `vscale` elements at runtime
+/// and the insert overwrites only one of them, so such cases are not folded.
+///
+/// Examples:
+///   %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
+///   %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
+/// folds to:
+///   %ext = vector.broadcast %s : f32 to vector<1xf32>
+///
+///   %ins = vector.insert %s, %v [0, 0] : vector<1xf32> into vector<16x1x1xf32>
+///   %ext = vector.extract %ins [0] : vector<1x1xf32> from vector<16x1x1xf32>
+/// folds to:
+///   %ext = vector.shape_cast %s : vector<1xf32> to vector<1x1xf32>
+///
+/// Not folded (only one of the `vscale` lanes of row 3 is written):
+///   %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x[1]xf32>
+///   %ext = vector.extract %ins [3] : vector<[1]xf32> from vector<16x[1]xf32>
+struct FoldExtractFromInsertUnitDim final
+    : OpRewritePattern<vector::ExtractOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.hasDynamicPosition())
+      return failure();
+
+    auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
+    if (!insertOp || insertOp.hasDynamicPosition())
+      return failure();
+
+    ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
+    ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
+
+    // The extract position must be a strict prefix of the insert position.
+    if (extractPos.size() >= insertPos.size() ||
+        extractPos != insertPos.take_front(extractPos.size()))
+      return failure();
+
+    // The remaining dimensions (those not indexed by the extract) must all
+    // be size 1 in the source vector type. This guarantees that the inserted
+    // value fully determines the extracted sub-vector.
+    VectorType srcVecType = extractOp.getSourceVectorType();
+    ArrayRef<bool> scalableDims = srcVecType.getScalableDims();
+    const int64_t insertRank = insertPos.size();
+    for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i) {
+      if (srcVecType.getDimSize(i) != 1)
+        return failure();
+      // A scalable `[1]` dim has `vscale` elements; when the insert indexes
+      // it, only one of them is overwritten. Dims at or past `insertRank`
+      // belong to the inserted value itself and are overwritten in full.
+      if (i < insertRank && scalableDims[i])
+        return failure();
+    }
+
+    Value inserted = insertOp.getValueToStore();
+    Type extractedType = extractOp.getResult().getType();
+    if (isa<VectorType>(inserted.getType())) {
+      // The extracted type is the inserted vector type with leading unit
+      // dims prepended, so a pure reshape suffices.
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractedType,
+                                                       inserted);
+    } else {
+      // A scalar fills the all-unit extracted sub-vector; broadcast it.
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, extractedType,
+                                                       inserted);
+    }
+    return success();
+  }
+};
+
} // namespace
+/// Populates `results` with the canonicalization patterns for vector.extract.
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
- ExtractOpFromConstantMask, ExtractToShapeCast>(context);
+ ExtractOpFromConstantMask, ExtractToShapeCast,
+ FoldExtractFromInsertUnitDim>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
new file mode 100644
index 0000000000000..615d79dab77ab
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -0,0 +1,128 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
+
+// Tests for FoldExtractFromInsertUnitDim: fold vector.extract from
+// vector.insert when the extract position is a strict prefix of the insert
+// position and all remaining dimensions are size 1, so the extracted
+// sub-vector is fully determined by the inserted value.
+
+// Basic case: two scalars written into separate rows of a vector<4x1xf32>.
+// Each row extract at [i] that follows an insert at [i, 0] is fully
+// determined by the inserted scalar, so it becomes a broadcast.

+// CHECK-LABEL: func.func @extract_from_insert_trailing_unit_dim
+// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32
+// CHECK-NOT: ub.poison
+// CHECK-NOT: vector.insert {{.*}} vector<4x1xf32>
+// CHECK-NOT: vector.extract {{.*}} vector<4x1xf32>
+// CHECK-DAG: vector.broadcast %[[A]] : f32 to vector<1xf32>
+// CHECK-DAG: vector.broadcast %[[B]] : f32 to vector<1xf32>
+func.func @extract_from_insert_trailing_unit_dim(%a: f32, %b: f32) -> (vector<1xf32>, vector<1xf32>) {
+  %init = ub.poison : vector<4x1xf32>
+  %row0_set = vector.insert %a, %init [0, 0] : f32 into vector<4x1xf32>
+  %row1_set = vector.insert %b, %row0_set [1, 0] : f32 into vector<4x1xf32>
+  %row0 = vector.extract %row0_set [0] : vector<1xf32> from vector<4x1xf32>
+  %row1 = vector.extract %row1_set [1] : vector<1xf32> from vector<4x1xf32>
+  return %row0, %row1 : vector<1xf32>, vector<1xf32>
+}
+
+// -----
+
+// Several trailing unit dims: vector<4x1x1xf32>. The vector<1x1xf32> slice
+// extracted at [i] consists solely of the scalar inserted at [i, 0, 0].

+// CHECK-LABEL: func.func @extract_from_insert_multiple_trailing_unit_dims
+// CHECK-SAME: %[[SCALAR:.*]]: f32
+// CHECK-NOT: ub.poison
+// CHECK: vector.broadcast %[[SCALAR]] : f32 to vector<1x1xf32>
+func.func @extract_from_insert_multiple_trailing_unit_dims(%x: f32) -> vector<1x1xf32> {
+  %init = ub.poison : vector<4x1x1xf32>
+  %filled = vector.insert %x, %init [2, 0, 0] : f32 into vector<4x1x1xf32>
+  %slice = vector.extract %filled [2] : vector<1x1xf32> from vector<4x1x1xf32>
+  return %slice : vector<1x1xf32>
+}
+
+// -----
+
+// Negative: extract position [0] is not a prefix of insert position [1, 0].
+// The existing vector.extract folder forwards through the insert to the
+// poison destination; FoldExtractFromInsertUnitDim must not fire, so no
+// broadcast of the inserted scalar may appear.

+// CHECK-LABEL: func.func @negative_extract_from_insert_position_mismatch
+// CHECK-NOT: vector.broadcast
+func.func @negative_extract_from_insert_position_mismatch(%s: f32) -> vector<1xf32> {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [1, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [0] : vector<1xf32> from vector<4x1xf32>
+  return %ext : vector<1xf32>
+}
+
+// -----
+
+// Negative: the trailing dim has size 4, so one scalar inserted at [1, 2]
+// does not determine the whole row extracted at [1] -- no fold.

+// CHECK-LABEL: func.func @negative_extract_from_insert_non_unit_trailing_dim
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @negative_extract_from_insert_non_unit_trailing_dim(%x: f32) -> vector<4xf32> {
+  %init = ub.poison : vector<3x4xf32>
+  %filled = vector.insert %x, %init [1, 2] : f32 into vector<3x4xf32>
+  %row = vector.extract %filled [1] : vector<4xf32> from vector<3x4xf32>
+  return %row : vector<4xf32>
+}
+
+// -----
+
+// Negative: the extract index is an SSA value, so it cannot be compared
+// with the static insert position -- no fold.

+// CHECK-LABEL: func.func @negative_extract_from_insert_dynamic_position
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @negative_extract_from_insert_dynamic_position(%x: f32, %i: index) -> vector<1xf32> {
+  %init = ub.poison : vector<4x1xf32>
+  %filled = vector.insert %x, %init [2, 0] : f32 into vector<4x1xf32>
+  %row = vector.extract %filled [%i] : vector<1xf32> from vector<4x1xf32>
+  return %row : vector<1xf32>
+}
+
+// -----
+
+// Negative: the insert position uses an SSA index, so the row it writes is
+// not known statically -- no fold.

+// CHECK-LABEL: func.func @negative_insert_dynamic_position_not_folded
+// CHECK: vector.insert
+// CHECK: vector.extract
+func.func @negative_insert_dynamic_position_not_folded(%x: f32, %i: index) -> vector<1xf32> {
+  %init = ub.poison : vector<4x1xf32>
+  %filled = vector.insert %x, %init [%i, 0] : f32 into vector<4x1xf32>
+  %row = vector.extract %filled [2] : vector<1xf32> from vector<4x1xf32>
+  return %row : vector<1xf32>
+}
+
+// -----
+
+// Exact position match (same number of indices): this is not a strict prefix,
+// so FoldExtractFromInsertUnitDim does not apply. The existing vector.extract
+// folder forwards the inserted scalar; verify it still does.

+// CHECK-LABEL: func.func @extract_from_insert_exact_match
+// CHECK-SAME: %[[S:.*]]: f32
+// CHECK: return %[[S]]
+func.func @extract_from_insert_exact_match(%s: f32) -> f32 {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins = vector.insert %s, %poison [2, 0] : f32 into vector<4x1xf32>
+  %ext = vector.extract %ins [2, 0] : f32 from vector<4x1xf32>
+  return %ext : f32
+}
+
+// -----
+
+// Vector-valued insert: FoldExtractFromInsertUnitDim rewrites the extract
+// directly into a shape_cast of the inserted vector<1xf32> (no intermediate
+// broadcast is created).

+// CHECK-LABEL: func.func @extract_from_insert_vector_to_vector_broadcast
+// CHECK-SAME: %[[SRC:.*]]: vector<1xf32>
+// CHECK: vector.shape_cast %[[SRC]] : vector<1xf32> to vector<1x1xf32>
+func.func @extract_from_insert_vector_to_vector_broadcast(%src: vector<1xf32>) -> vector<1x1xf32> {
+  %poison = ub.poison : vector<16x1x1xf32>
+  %vec1 = vector.insert %src, %poison [0, 0] : vector<1xf32> into vector<16x1x1xf32>
+  %vec2 = vector.extract %vec1 [0] : vector<1x1xf32> from vector<16x1x1xf32>
+  return %vec2 : vector<1x1xf32>
+}
More information about the Mlir-commits
mailing list