[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)

Andrzej Warzyński llvmlistbot at llvm.org
Tue May 27 09:50:51 PDT 2025


================
@@ -140,6 +140,59 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
   return %s : vector<f32>
 }
 
+
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL:  func.func @squeeze_out_prefix_unit_dim(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK:          %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
+// CHECK-SAME:               vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK:          return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL:  func.func @squeeze_out_middle_unit_dim(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK:         %[[UB:.*]] = ub.poison : vector<2x3xf32>
+// CHECK:         %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
+// CHECK:         %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
+// CHECK-SAME:                 into vector<2x3xf32>
+// CHECK:         %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
+// CHECK:         %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
+// CHECK-SAME:                 into vector<2x3xf32>
+// CHECK:         return %[[I1]] : vector<2x3xf32>
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @prepend_unit_dim(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK:          %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+// CHECK:          %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
+// CHECK-SAME:               : vector<2x3xf32> into vector<1x2x3xf32>
+// CHECK:        return %[[INSERTED]] : vector<1x2x3xf32>
----------------
banach-space wrote:

This indentation is off.

https://github.com/llvm/llvm-project/pull/140800

