[Mlir-commits] [mlir] [MLIR][Vector]Generalize DropUnitDimFromElementwiseOps (PR #92934)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Jun 3 02:11:10 PDT 2024
================
@@ -459,6 +459,41 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
// CHECK-128B-NOT: memref.collapse_shape
+// -----
+
+func.func @fold_unit_inner_dim(%arg0 : vector<8x1x3xf128>,
+ %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
+ %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
+ %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
+ %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
+ return %res : vector<8x3xf128>
+}
+
+// CHECK-LABEL: func.func @fold_unit_inner_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
+// CHECK: return %[[VAL_4]] : vector<8x3xf128>
+
+// -----
+
+func.func @fold_unit_inner_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
----------------
banach-space wrote:
[nit]
```suggestion
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
```
https://github.com/llvm/llvm-project/pull/92934
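The two tests above exercise the shape rewrite. The sketch below models only the type change, in Python rather than MLIR. It assumes that the generalized pattern drops every fixed-size unit dimension wherever it appears and keeps scalable ones such as `[1]`, which stand for `vscale` elements at runtime and so are not necessarily of size 1. `Dim`, `drop_unit_dims` and `fmt` are hypothetical helpers, not MLIR APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dim:
    """One vector dimension; scalable dims model MLIR's `[n]` syntax."""
    size: int
    scalable: bool = False


def drop_unit_dims(shape):
    """Remove every fixed-size unit dim; scalable dims (even `[1]`) are kept."""
    return [d for d in shape if not (d.size == 1 and not d.scalable)]


def fmt(shape, elt):
    """Render a shape as an MLIR-style vector type string."""
    parts = [f"[{d.size}]" if d.scalable else str(d.size) for d in shape]
    return "vector<" + "x".join(parts + [elt]) + ">"


# vector<8x1x3xf128> -> vector<8x3xf128>
print(fmt(drop_unit_dims([Dim(8), Dim(1), Dim(3)]), "f128"))
# vector<8x1x[1]x3xf128> -> vector<8x[1]x3xf128> (scalable unit dim kept)
print(fmt(drop_unit_dims([Dim(8), Dim(1), Dim(1, True), Dim(3)]), "f128"))
```

Under these assumptions, the inner `1` is removed in both cases. The scalable `[1]` is left in place because its runtime size is only known as a multiple of `vscale`.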