[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Dec 11 00:12:51 PST 2023


================
@@ -254,3 +254,62 @@ func.func @transfer_read_flattenable_negative2(
 
 // CHECK-LABEL: func @transfer_read_flattenable_negative2
 //       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
+
+// -----
+
+func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
+                             %arg1 : vector<1x8xi32>) -> vector<8xi32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
+   %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+   return %res : vector<8xi32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_add(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x1xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
+// CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
+                              %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+   %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+   %add = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+   return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL:   func.func @fold_unit_dim_mulf(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+// CHECK:           %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
+// CHECK:           return %[[VAL_4]] : vector<8x[2]xf32>
+
+// -----
+
+// All shape casts are folded away
+
+func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
----------------
banach-space wrote:

The other tests don't really show how `shape_cast` gets folded away entirely, and this actually came up in a recent discussion:
* https://github.com/openxla/iree/pull/15839#issuecomment-1849075025
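
For illustration only, here is a minimal sketch of what "folded away entirely" could look like. It assumes the pattern moves the trailing `shape_cast` onto the operands, as in the tests above, and that the resulting back-to-back casts then cancel. The function name and shapes are hypothetical and are not taken from the PR's test.

```mlir
// Hypothetical input (not the PR's test): both operands are cast to a
// unit-dim shape, combined, and the result is cast back.
func.func @illustrative_fold_entirely(%a : vector<8xi32>,
                                      %b : vector<8xi32>) -> vector<8xi32> {
  %sc_a = vector.shape_cast %a : vector<8xi32> to vector<1x8xi32>
  %sc_b = vector.shape_cast %b : vector<8xi32> to vector<1x8xi32>
  %add = arith.addi %sc_a, %sc_b : vector<1x8xi32>
  %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
  return %res : vector<8xi32>
}

// Expected shape of the output under the assumptions above: each operand
// would carry a vector<8xi32> -> vector<1x8xi32> -> vector<8xi32> round trip,
// which is a no-op, so no shape_cast should remain.
//   %add = arith.addi %a, %b : vector<8xi32>
//   return %add : vector<8xi32>
```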

https://github.com/llvm/llvm-project/pull/74817

