[Mlir-commits] [mlir] [mlir][scf] Add reductions support to `scf.parallel` fusion (PR #75955)

Ettore Tiotto llvmlistbot at llvm.org
Tue Dec 19 09:43:22 PST 2023


================
@@ -387,3 +387,125 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
 // CHECK-LABEL: func @do_not_fuse_alias
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel
+
+// -----
+
+func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions
+//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
----------------
etiotto wrote:

Could you also match the return type in this `CHECK-SAME` line?
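
For illustration, the extended check might look like the sketch below. It assumes the printed function signature keeps its result types `-> (f32, f32)` on the same line as the arguments, so `CHECK-SAME` can match them:

```mlir
// CHECK-LABEL: func @fuse_reductions
//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
```

Parentheses outside `{{...}}` are matched literally by FileCheck, so the result-type list needs no escaping.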

https://github.com/llvm/llvm-project/pull/75955

