[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Dec 11 00:12:51 PST 2023
================
@@ -254,3 +254,62 @@ func.func @transfer_read_flattenable_negative2(
// CHECK-LABEL: func @transfer_read_flattenable_negative2
// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
+
+// -----
+
+func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
+ %arg1 : vector<1x8xi32>) -> vector<8xi32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
+ %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+ %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+ return %res : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_add(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
+// CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
+ %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+ %mul = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+ %res = vector.shape_cast %mul : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+ return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_mulf(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
+// CHECK: return %[[VAL_4]] : vector<8x[2]xf32>
+
+// -----
+
+// All shape casts are folded away
+
+func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
----------------
banach-space wrote:
The other tests don't really show how `shape_cast` gets folded away entirely. And this has actually come up in a discussion recently:
* https://github.com/openxla/iree/pull/15839#issuecomment-1849075025
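The shape-level reasoning behind these tests can be sketched in Python. This is a hypothetical model, not the PR's C++ pattern. It assumes the rewrite moves the trailing `vector.shape_cast` onto each operand of the elementwise op, and it relies on the fact that a `vector.shape_cast` whose source and result types are identical folds to its source. The helper names (`parse_shape`, `drop_unit_dims`, `only_unit_dims_differ`, `needs_cast`) are made up for illustration. Element types are ignored because `shape_cast` never changes them.

```python
from typing import List, Tuple

# Hypothetical model: a dim is (size, is_scalable), so "[2]" -> (2, True).
Dim = Tuple[int, bool]


def parse_shape(shape: str) -> List[Dim]:
    """Parse an MLIR-style shape such as "8x[2]x1" into a list of dims."""
    dims: List[Dim] = []
    for tok in shape.split("x"):
        if tok.startswith("[") and tok.endswith("]"):
            dims.append((int(tok[1:-1]), True))  # scalable dim
        else:
            dims.append((int(tok), False))  # fixed-size dim
    return dims


def drop_unit_dims(dims: List[Dim]) -> List[Dim]:
    """Drop fixed-size unit dims; a scalable [1] is kept (not a unit dim)."""
    return [d for d in dims if d != (1, False)]


def only_unit_dims_differ(a: str, b: str) -> bool:
    """Assumed precondition: the casts only insert or remove unit dims."""
    return drop_unit_dims(parse_shape(a)) == drop_unit_dims(parse_shape(b))


def needs_cast(operand: str, result: str) -> bool:
    """After the rewrite, an operand keeps a shape_cast only if its shape
    differs from the result shape; identical shapes fold away."""
    return parse_shape(operand) != parse_shape(result)


if __name__ == "__main__":
    # Shapes taken from @fold_unit_dim_add: both operands keep a cast.
    for operand in ("8x1", "1x8"):
        print(operand, only_unit_dims_differ(operand, "8"), needs_cast(operand, "8"))
```

With the shapes from `@fold_unit_dim_add`, both operands still need a cast to `8`, matching the two `shape_cast` CHECK lines. An operand that is already `vector<8xi32>` would need none, which is the "folded away entirely" case the reviewer is asking to see tested.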
https://github.com/llvm/llvm-project/pull/74817