[Mlir-commits] [mlir] [mlir][vector] Fold vector extract from insert when trailing unit dims (PR #192109)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Apr 15 08:11:59 PDT 2026


================
@@ -0,0 +1,114 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
+
+// Tests for FoldExtractFromInsertUnitDim: fold vector.extract from
+// vector.insert when the extract position is a strict prefix of the insert
+// position and all remaining dimensions are size 1, so the extracted
+// sub-vector is fully determined by the inserted value.
+
+// Basic case: extract row from a vector<4x1xf32> insert chain.
+// The extract at [i] from insert at [i, 0] should fold to a broadcast.
+
+// CHECK-LABEL: func.func @extract_from_insert_trailing_unit_dim
+// CHECK-SAME:    %[[S0:.*]]: f32, %[[S1:.*]]: f32
+// CHECK-NOT:     ub.poison
+// CHECK-NOT:     vector.insert {{.*}} vector<4x1xf32>
+// CHECK-NOT:     vector.extract {{.*}} vector<4x1xf32>
+// CHECK-DAG:     vector.broadcast %[[S0]] : f32 to vector<1xf32>
+// CHECK-DAG:     vector.broadcast %[[S1]] : f32 to vector<1xf32>
+func.func @extract_from_insert_trailing_unit_dim(%s0: f32, %s1: f32) -> (vector<1xf32>, vector<1xf32>) {
+  %poison = ub.poison : vector<4x1xf32>
+  %ins0 = vector.insert %s0, %poison [0, 0] : f32 into vector<4x1xf32>
+  %ins1 = vector.insert %s1, %ins0 [1, 0] : f32 into vector<4x1xf32>
+  %ext0 = vector.extract %ins0 [0] : vector<1xf32> from vector<4x1xf32>
+  %ext1 = vector.extract %ins1 [1] : vector<1xf32> from vector<4x1xf32>
+  return %ext0, %ext1 : vector<1xf32>, vector<1xf32>
+}
+
+// -----
+
+// Multiple trailing unit dims: vector<4x1x1xf32>.
+// Extract at [i] gives vector<1x1xf32>; the inserted value fully
+// determines the result.
+
+// CHECK-LABEL: func.func @extract_from_insert_multiple_trailing_unit_dims
+// CHECK-SAME:    %[[S:.*]]: f32
+// CHECK-NOT:     ub.poison
+// CHECK:         vector.broadcast %[[S]] : f32 to vector<1x1xf32>
+func.func @extract_from_insert_multiple_trailing_unit_dims(%s: f32) -> vector<1x1xf32> {
+  %poison = ub.poison : vector<4x1x1xf32>
+  %ins = vector.insert %s, %poison [2, 0, 0] : f32 into vector<4x1x1xf32>
+  %ext = vector.extract %ins [2] : vector<1x1xf32> from vector<4x1x1xf32>
----------------
amd-eochoalo wrote:

https://github.com/llvm/llvm-project/pull/192109/commits/0f70df8387c79ba1023c9a60120d109ed6b86464 thanks!

https://github.com/llvm/llvm-project/pull/192109
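The test comments describe the fold condition: the extract position must be a strict prefix of the insert position, and every dimension in between must have size 1. The Python sketch below models only that predicate, assuming static (constant) positions. The helper names `extract_fully_covered_by_insert` and `extracted_shape` are hypothetical. This is not the C++ `FoldExtractFromInsertUnitDim` pattern from the PR, and the excerpt does not show how that pattern treats dynamic indices.

```python
from typing import Sequence, Tuple


def extract_fully_covered_by_insert(
    dest_shape: Sequence[int],
    insert_pos: Sequence[int],
    extract_pos: Sequence[int],
) -> bool:
    """Return True if extracting at `extract_pos` from a vector of shape
    `dest_shape` yields exactly the value inserted at `insert_pos`.

    Hypothetical model of the condition in the test comments; static
    positions only.
    """
    m, k = len(extract_pos), len(insert_pos)
    # The extract position must be a strict prefix of the insert position.
    if m >= k or tuple(insert_pos[:m]) != tuple(extract_pos):
        return False
    # Dimensions indexed by the insert but not by the extract must be unit
    # dims, so the extracted slice contains nothing but the inserted value.
    return all(d == 1 for d in dest_shape[m:k])


def extracted_shape(dest_shape: Sequence[int], extract_pos: Sequence[int]) -> Tuple[int, ...]:
    """Shape of the sub-vector produced by extracting at `extract_pos`."""
    return tuple(dest_shape[len(extract_pos):])


# Mirrors @extract_from_insert_trailing_unit_dim: vector<4x1xf32>, insert [1, 0], extract [1].
print(extract_fully_covered_by_insert((4, 1), (1, 0), (1,)))  # True
# Mirrors @extract_from_insert_multiple_trailing_unit_dims.
print(extract_fully_covered_by_insert((4, 1, 1), (2, 0, 0), (2,)), extracted_shape((4, 1, 1), (2,)))  # True (1, 1)
# A non-unit trailing dim means the inserted scalar covers only part of the row.
print(extract_fully_covered_by_insert((4, 2), (1, 0), (1,)))  # False
```

When the predicate holds for a scalar insert, the tests expect the extract to be replaced by `vector.broadcast` of that scalar to the extracted shape, for example `f32 to vector<1x1xf32>`.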
