[Mlir-commits] [mlir] [mlir][vector] Add pattern to reorder shape_cast(arithmetic(a, b)) (PR #74817)
Cullen Rhodes
llvmlistbot at llvm.org
Fri Dec 8 09:35:17 PST 2023
================
@@ -254,3 +254,62 @@ func.func @transfer_read_flattenable_negative2(
// CHECK-LABEL: func @transfer_read_flattenable_negative2
// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
+
+// -----
+
+func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
+ %arg1 : vector<1x8xi32>) -> vector<8xi32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
+ %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+ %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
+ return %res : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_add(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
+// CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
+ %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+ %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+ %add = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+ %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+ return %res : vector<8x[2]xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unit_dim_mulf(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
+// CHECK: return %[[VAL_4]] : vector<8x[2]xf32>
+
+// -----
+
+// All shape casts are folded away
+
+func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
----------------
c-rhodes wrote:
I think this can be removed; what it's testing is already covered.
https://github.com/llvm/llvm-project/pull/74817
More information about the Mlir-commits mailing list