[Mlir-commits] [mlir] [MLIR][Linalg] Scalable Vectorization of Reduction (PR #97788)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jul 11 07:24:24 PDT 2024


================
@@ -298,6 +298,30 @@ func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) ->
 // CHECK:          %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
 // CHECK:          return %[[VAL_4]] : f32
 
+func.func @scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
+    %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:  func.func @scalable_dim_2d(
+// CHECK-SAME:                                      %[[ARG_0:.*]]: vector<2x[4]xf32>,
+// CHECK-SAME:                                      %[[ARG_1:.*]]: vector<2xf32>,
+// CHECK-SAME:                                      %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
+// CHECK-DAG:      %[[CON_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:      %[[CON_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:      %[[CON_2:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:          %[[VAL_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:          %[[VAL_1:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
+// CHECK:          %[[VAL_2:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
+// CHECK:          %[[VAL_3:.*]] = vector.mask %[[VAL_2]] { vector.reduction <add>, %[[VAL_0]], %[[VAL_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[VAL_4:.*]] = vector.insertelement %[[VAL_3]], %[[CON_2]][%[[CON_1]] : index] : vector<2xf32>
+// CHECK:          %[[VAL_5:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:          %[[VAL_6:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
+// CHECK:          %[[VAL_7:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
+// CHECK:          %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.reduction <add>, %[[VAL_5]], %[[VAL_6]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[VAL_9:.*]] = vector.insertelement %[[VAL_8]], %[[VAL_4]][%[[CON_0]] : index] : vector<2xf32>
+// CHECK:          return %[[VAL_9]] : vector<2xf32>
+
----------------
banach-space wrote:

This test is unrelated to this change - please send as a separate PR.

https://github.com/llvm/llvm-project/pull/97788


More information about the Mlir-commits mailing list