[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 27 10:22:59 PDT 2024


================
@@ -47,6 +47,59 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @fuse_two_parallel
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+  %sum = memref.alloc()  : memref<2x2xf32>
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %c1fp : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+// CHECK:      memref.dealloc [[SUM]]
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %parallel:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
----------------
srcarroll wrote:

It's not symmetrical: the order of fusing affects the order in which the ops show up in the fused body. There are checks to ensure that the fused op doesn't violate dominance, so not all loops can be fused in an arbitrary order. In this test there would be no dominance violation if the order were switched.
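
For illustration only, here is a sketch of what switching the order would look like in this test's transform sequence. It is not part of the PR. The only intended change is swapping the two handles passed to `transform.loop.fuse_sibling`; the trailing `transform.yield` and closing braces are assumed, since the quoted hunk ends before them, and the resulting op order in the fused body has not been verified here.

```mlir
// Hypothetical variant of the test's transform sequence (not from the PR).
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Match both scf.parallel loops and split them into separate handles.
    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Swapped operands: fuse the second loop into the first instead of
    // the first into the second, as the test under review does.
    %fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op, !transform.any_op) -> !transform.any_op
    // Assumed terminator; the quoted diff is truncated before this point.
    transform.yield
  }
}
```

With this order the CHECK lines in the test would presumably need to change, because the bodies of the two loops would be expected to appear in a different order inside the fused `scf.parallel`.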

https://github.com/llvm/llvm-project/pull/94391

