[Mlir-commits] [mlir] [MLIR][Vector]Generalize DropUnitDimFromElementwiseOps (PR #92934)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Jun 3 02:11:10 PDT 2024


================
@@ -459,6 +459,41 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK-128B-LABEL: func @fold_unit_dims_entirely(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
+// -----
+
+func.func @fold_unit_inner_dim(%arg0 : vector<8x1x3xf128>,
+                              %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
+   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
+   %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
+   %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
+   return %res : vector<8x3xf128>
+}
+
+// CHECK-LABEL: func.func @fold_unit_inner_dim(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x3xf128>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
+// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
+// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
+// CHECK:         return %[[VAL_4]] : vector<8x3xf128>
+
+// -----
+
+func.func @fold_unit_inner_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
----------------
banach-space wrote:

[nit] 
```suggestion
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
```

https://github.com/llvm/llvm-project/pull/92934

