[Mlir-commits] [mlir] [mlir][scf] Add reductions support to `scf.parallel` fusion (PR #75955)

Ivan Butygin llvmlistbot at llvm.org
Thu Feb 1 06:38:00 PST 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/75955

>From 2fd5a4de51ff690cb144f9902bf20c7c16c2f036 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 19 Dec 2023 18:10:56 +0100
Subject: [PATCH 1/6] [mlir][scf] Add reductions support to `scf.parallel`
 fusion

---
 .../SCF/Transforms/ParallelLoopFusion.cpp     |  52 ++++++--
 .../Dialect/SCF/parallel-loop-fusion.mlir     | 124 +++++++++++++++++-
 2 files changed, 166 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d3dca1427e517..7d9e220518441 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -161,29 +161,63 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
 }
 
 /// Prepends operations of firstPloop's body into secondPloop's body.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
-                        OpBuilder b,
+/// Updates secondPloop with new loop.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
+                        OpBuilder builder,
                         llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *block1 = firstPloop.getBody();
+  Block *block2 = secondPloop.getBody();
   IRMapping firstToSecondPloopIndices;
-  firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
-                                secondPloop.getBody()->getArguments());
+  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
 
   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
                      mayAlias))
     return;
 
-  b.setInsertionPointToStart(secondPloop.getBody());
-  for (auto &op : firstPloop.getBody()->without_terminator())
-    b.clone(op, firstToSecondPloopIndices);
+  DominanceInfo dom;
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return;
+
+  ValueRange inits1 = firstPloop.getInitVals();
+  ValueRange inits2 = secondPloop.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  IRRewriter b(builder);
+  b.setInsertionPoint(secondPloop);
+  auto newSecondPloop = b.create<ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+  Block *newBlock = newSecondPloop.getBody();
+  newBlock->getTerminator()->erase();
+
+  block1->getTerminator()->erase();
+
+  b.inlineBlockBefore(block1, newBlock, newBlock->end(),
+                      newBlock->getArguments());
+  b.inlineBlockBefore(block2, newBlock, newBlock->end(),
+                      newBlock->getArguments());
+
+  ValueRange results = newSecondPloop.getResults();
+  firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+  secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
   firstPloop.erase();
+  secondPloop.erase();
+  secondPloop = newSecondPloop;
 }
 
 void mlir::scf::naivelyFuseParallelOps(
     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
   OpBuilder b(region);
   // Consider every single block and attempt to fuse adjacent loops.
+  SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
   for (auto &block : region) {
-    SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+    ploopChains.clear();
+    ploopChains.push_back({});
+
     // Not using `walk()` to traverse only top-level parallel loops and also
     // make sure that there are no side-effecting ops between the parallel
     // loops.
@@ -201,7 +235,7 @@ void mlir::scf::naivelyFuseParallelOps(
       // TODO: Handle region side effects properly.
       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
     }
-    for (ArrayRef<ParallelOp> ploops : ploopChains) {
+    for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
         fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
     }
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9c136bb635658..94ccbff4d8560 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -89,7 +89,7 @@ func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
     memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
     scf.reduce
   }
-  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { 
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
     %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
     %res_elem = arith.addf %A_elem, %c2fp : f32
     memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
@@ -575,3 +575,125 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
 // CHECK-NEXT:    }
 // CHECK-NEXT:    memref.dealloc %[[ALLOC]] : memref<2x3xf32>
 // CHECK-NEXT:    return
+
+// -----
+
+func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions
+//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:   %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+//  CHECK-SAME:   to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+//  CHECK-SAME:   init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
+//       CHECK:   %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+//       CHECK:   scf.reduce(%[[VAL_A]]) : f32 {
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+//       CHECK:   scf.reduce(%[[VAL_B]]) : f32 {
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   scf.yield
+//       CHECK:   return %[[RES]]#0, %[[RES]]#1
+
+// -----
+
+func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// %res1 is used as second scf.parallel arg, cannot fuse
+// CHECK-LABEL: func @reductions_use_res
+// CHECK:      scf.parallel
+// CHECK:      scf.parallel
+
+// -----
+
+func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum = arith.addf %B_elem, %res1 : f32
+    scf.reduce(%sum) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// %res1 is used inside second scf.parallel arg, cannot fuse
+// CHECK-LABEL: func @reductions_use_res_inside
+// CHECK:      scf.parallel
+// CHECK:      scf.parallel

>From d858a4af347587be86ace6a619009fdc58b2d87c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 19 Dec 2023 18:17:04 +0100
Subject: [PATCH 2/6] Fix comment wording in reductions_use_res_inside test

---
 mlir/test/Dialect/SCF/parallel-loop-fusion.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 94ccbff4d8560..7644d1bafb183 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -693,7 +693,7 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
   return %res1, %res2 : f32, f32
 }
 
-// %res1 is used inside second scf.parallel arg, cannot fuse
+// %res1 is used inside second scf.parallel, cannot fuse
 // CHECK-LABEL: func @reductions_use_res_inside
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel

>From 47ec48ee15d319e6973f3422d1bc0a15d1901d84 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 19 Dec 2023 23:05:27 +0100
Subject: [PATCH 3/6] Tighten fuse_reductions CHECK lines to match result types

---
 mlir/test/Dialect/SCF/parallel-loop-fusion.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 7644d1bafb183..9ced6d932274e 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -606,7 +606,7 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
 }
 
 // CHECK-LABEL: func @fuse_reductions
-//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
+//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
@@ -628,7 +628,7 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
 //       CHECK:     scf.reduce.return %[[R]] : f32
 //       CHECK:   }
 //       CHECK:   scf.yield
-//       CHECK:   return %[[RES]]#0, %[[RES]]#1
+//       CHECK:   return %[[RES]]#0, %[[RES]]#1 : f32, f32
 
 // -----
 

>From 65a3b05d56c3674c136ba81bcb552c8d0d2cfb6e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 20 Dec 2023 18:29:17 +0100
Subject: [PATCH 4/6] Update to new reductions format

---
 .../SCF/Transforms/ParallelLoopFusion.cpp     | 34 +++++--
 .../Dialect/SCF/parallel-loop-fusion.mlir     | 93 +++++++++++++++----
 2 files changed, 102 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 7d9e220518441..853b63f5adaf5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -192,18 +192,38 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
       secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
 
   Block *newBlock = newSecondPloop.getBody();
-  newBlock->getTerminator()->erase();
+  auto term1 = cast<ReduceOp>(block1->getTerminator());
+  auto term2 = cast<ReduceOp>(block2->getTerminator());
 
-  block1->getTerminator()->erase();
-
-  b.inlineBlockBefore(block1, newBlock, newBlock->end(),
+  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
                       newBlock->getArguments());
-  b.inlineBlockBefore(block2, newBlock, newBlock->end(),
+  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
                       newBlock->getArguments());
 
   ValueRange results = newSecondPloop.getResults();
-  firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
-  secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+  if (!results.empty()) {
+    b.setInsertionPointToEnd(newBlock);
+
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+                          newRedBlock.getArguments());
+    }
+
+    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+  }
+  term1->erase();
+  term2->erase();
   firstPloop.erase();
   secondPloop.erase();
   secondPloop = newSecondPloop;
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9ced6d932274e..d171f96811b10 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -578,7 +578,7 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
 
 // -----
 
-func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -586,26 +586,24 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
   %init2 = arith.constant 2.0 : f32
   %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
     %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
-    scf.reduce(%A_elem) : f32 {
+    scf.reduce(%A_elem : f32) {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = arith.addf %lhs, %rhs : f32
       scf.reduce.return %1 : f32
     }
-    scf.yield
   }
   %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
     %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
-    scf.reduce(%B_elem) : f32 {
+    scf.reduce(%B_elem : f32) {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = arith.mulf %lhs, %rhs : f32
       scf.reduce.return %1 : f32
     }
-    scf.yield
   }
   return %res1, %res2 : f32, f32
 }
 
-// CHECK-LABEL: func @fuse_reductions
+// CHECK-LABEL: func @fuse_reductions_two
 //  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
@@ -616,22 +614,85 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
 //  CHECK-SAME:   to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
 //  CHECK-SAME:   init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
 //       CHECK:   %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
-//       CHECK:   scf.reduce(%[[VAL_A]]) : f32 {
+//       CHECK:   %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+//       CHECK:   scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
 //       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
 //       CHECK:     %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
 //       CHECK:     scf.reduce.return %[[R]] : f32
 //       CHECK:   }
-//       CHECK:   %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
-//       CHECK:   scf.reduce(%[[VAL_B]]) : f32 {
 //       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
 //       CHECK:     %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
 //       CHECK:     scf.reduce.return %[[R]] : f32
 //       CHECK:   }
-//       CHECK:   scf.yield
 //       CHECK:   return %[[RES]]#0, %[[RES]]#1 : f32, f32
 
 // -----
 
+func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %init3 = arith.constant 3.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+  }
+  %res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 {
+    %A_elem = memref.load %C[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+  }
+  return %res1, %res2, %res3 : f32, f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions_three
+//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+//   CHECK-DAG:   %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
+//       CHECK:   %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+//  CHECK-SAME:   to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+//  CHECK-SAME:   init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
+//       CHECK:   %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+//       CHECK:   %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+//       CHECK:   %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
+//       CHECK:   scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
+
+// -----
+
 func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
@@ -639,21 +700,19 @@ func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32,
   %init1 = arith.constant 1.0 : f32
   %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
     %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
-    scf.reduce(%A_elem) : f32 {
+    scf.reduce(%A_elem : f32) {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = arith.addf %lhs, %rhs : f32
       scf.reduce.return %1 : f32
     }
-    scf.yield
   }
   %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
     %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
-    scf.reduce(%B_elem) : f32 {
+    scf.reduce(%B_elem : f32) {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = arith.mulf %lhs, %rhs : f32
       scf.reduce.return %1 : f32
     }
-    scf.yield
   }
   return %res1, %res2 : f32, f32
 }
@@ -673,22 +732,20 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
   %init2 = arith.constant 2.0 : f32
   %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
     %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
-    scf.reduce(%A_elem) : f32 {
+    scf.reduce(%A_elem : f32) {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = arith.addf %lhs, %rhs : f32
       scf.reduce.return %1 : f32
     }
-    scf.yield
   }
   %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
     %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
     %sum = arith.addf %B_elem, %res1 : f32
-    scf.reduce(%sum) : f32 {
+    scf.reduce(%sum : f32) {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = arith.mulf %lhs, %rhs : f32
       scf.reduce.return %1 : f32
     }
-    scf.yield
   }
   return %res1, %res2 : f32, f32
 }

>From 083707eae339c3f916619626f4ca8aca10022195 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 4 Jan 2024 19:46:13 +0100
Subject: [PATCH 5/6] Add comment explaining the check for first-loop result uses between the loops

---
 mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 853b63f5adaf5..5934d85373b03 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -175,6 +175,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
     return;
 
   DominanceInfo dom;
+  // We are fusing first loop into second, make sure there are no users of the
+  // first loop results between loops.
   for (Operation *user : firstPloop->getUsers())
     if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
       return;

>From 8f7b4a40ff443b25f8aac9736f581e90f612c0f6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 31 Jan 2024 23:49:43 +0100
Subject: [PATCH 6/6] Add fuse_ops_between and reductions_use_res_between tests

---
 .../Dialect/SCF/parallel-loop-fusion.mlir     | 59 +++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index d171f96811b10..0d4ea6f20e8d9 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -24,6 +24,32 @@ func.func @fuse_empty_loops() {
 
 // -----
 
+func.func @fuse_ops_between(%A: f32, %B: f32) -> f32 {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    scf.reduce
+  }
+  %res = arith.addf %A, %B : f32
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    scf.reduce
+  }
+  return %res : f32
+}
+// CHECK-LABEL: func @fuse_ops_between
+// CHECK-DAG:    [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:    [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:    [[C2:%.*]] = arith.constant 2 : index
+// CHECK:        %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
+// CHECK:        scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:          scf.reduce
+// CHECK:        }
+// CHECK-NOT:    scf.parallel
+
+// -----
+
 func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
@@ -754,3 +780,36 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
 // CHECK-LABEL: func @reductions_use_res_inside
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel
+
+// -----
+
+func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+  }
+  %res3 = arith.addf %res1, %init2 : f32
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+  }
+  return %res1, %res2, %res3 : f32, f32, f32
+}
+
+// instruction in between the loops uses the first loop result
+// CHECK-LABEL: func @reductions_use_res_between
+// CHECK:      scf.parallel
+// CHECK:      scf.parallel



More information about the Mlir-commits mailing list