[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue May 27 09:50:51 PDT 2025
================
@@ -140,6 +140,59 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
return %s : vector<f32>
}
+
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
+// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
+// CHECK-SAME: into vector<2x3xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
+// CHECK-SAME: into vector<2x3xf32>
+// CHECK: return %[[I1]] : vector<2x3xf32>
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @prepend_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
+// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
+// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
----------------
banach-space wrote:
This indentation is off.
https://github.com/llvm/llvm-project/pull/140800