[Mlir-commits] [mlir] [mlir][scf] Add reductions support to `scf.parallel` fusion (PR #75955)
Ivan Butygin
llvmlistbot at llvm.org
Wed Dec 20 09:30:38 PST 2023
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/75955
>From ab744fd1085af6a8374943fc189a0520473ee57d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 19 Dec 2023 18:10:56 +0100
Subject: [PATCH 1/4] [mlir][scf] Add reductions support to `scf.parallel`
fusion
---
.../SCF/Transforms/ParallelLoopFusion.cpp | 52 ++++++--
.../Dialect/SCF/parallel-loop-fusion.mlir | 122 ++++++++++++++++++
2 files changed, 165 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7..ea9fbee26fdeb0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -131,29 +131,63 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
}
/// Prepends operations of firstPloop's body into secondPloop's body.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
- OpBuilder b,
+/// Updates secondPloop with new loop.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
+ OpBuilder builder,
llvm::function_ref<bool(Value, Value)> mayAlias) {
+ Block *block1 = firstPloop.getBody();
+ Block *block2 = secondPloop.getBody();
IRMapping firstToSecondPloopIndices;
- firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
- secondPloop.getBody()->getArguments());
+ firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
mayAlias))
return;
- b.setInsertionPointToStart(secondPloop.getBody());
- for (auto &op : firstPloop.getBody()->without_terminator())
- b.clone(op, firstToSecondPloopIndices);
+ DominanceInfo dom;
+ for (Operation *user : firstPloop->getUsers())
+ if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+ return;
+
+ ValueRange inits1 = firstPloop.getInitVals();
+ ValueRange inits2 = secondPloop.getInitVals();
+
+ SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+ newInitVars.append(inits2.begin(), inits2.end());
+
+ IRRewriter b(builder);
+ b.setInsertionPoint(secondPloop);
+ auto newSecondPloop = b.create<ParallelOp>(
+ secondPloop.getLoc(), secondPloop.getLowerBound(),
+ secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+ Block *newBlock = newSecondPloop.getBody();
+ newBlock->getTerminator()->erase();
+
+ block1->getTerminator()->erase();
+
+ b.inlineBlockBefore(block1, newBlock, newBlock->end(),
+ newBlock->getArguments());
+ b.inlineBlockBefore(block2, newBlock, newBlock->end(),
+ newBlock->getArguments());
+
+ ValueRange results = newSecondPloop.getResults();
+ firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+ secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
firstPloop.erase();
+ secondPloop.erase();
+ secondPloop = newSecondPloop;
}
void mlir::scf::naivelyFuseParallelOps(
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
OpBuilder b(region);
// Consider every single block and attempt to fuse adjacent loops.
+ SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
for (auto &block : region) {
- SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+ ploopChains.clear();
+ ploopChains.push_back({});
+
// Not using `walk()` to traverse only top-level parallel loops and also
// make sure that there are no side-effecting ops between the parallel
// loops.
@@ -171,7 +205,7 @@ void mlir::scf::naivelyFuseParallelOps(
// TODO: Handle region side effects properly.
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
}
- for (ArrayRef<ParallelOp> ploops : ploopChains) {
+ for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
}
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e524717..5ddeb001bd6a63 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -387,3 +387,125 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init1 = arith.constant 1.0 : f32
+ %init2 = arith.constant 2.0 : f32
+ %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ scf.reduce(%A_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ scf.reduce(%B_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.mulf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ return %res1, %res2 : f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions
+// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
+// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+// CHECK: scf.reduce(%[[VAL_A]]) : f32 {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+// CHECK: scf.reduce(%[[VAL_B]]) : f32 {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: scf.yield
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+
+// -----
+
+func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init1 = arith.constant 1.0 : f32
+ %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ scf.reduce(%A_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ scf.reduce(%B_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.mulf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ return %res1, %res2 : f32, f32
+}
+
+// %res1 is used as second scf.parallel arg, cannot fuse
+// CHECK-LABEL: func @reductions_use_res
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init1 = arith.constant 1.0 : f32
+ %init2 = arith.constant 2.0 : f32
+ %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ scf.reduce(%A_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum = arith.addf %B_elem, %res1 : f32
+ scf.reduce(%sum) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.mulf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ return %res1, %res2 : f32, f32
+}
+
+// %res1 is used inside second scf.parallel arg, cannot fuse
+// CHECK-LABEL: func @reductions_use_res_inside
+// CHECK: scf.parallel
+// CHECK: scf.parallel
>From 831b1532d827312d61a4968293b126053d81eafd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 19 Dec 2023 18:17:04 +0100
Subject: [PATCH 2/4] typo
---
mlir/test/Dialect/SCF/parallel-loop-fusion.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 5ddeb001bd6a63..f5dfc49cf5bbef 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -505,7 +505,7 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
return %res1, %res2 : f32, f32
}
-// %res1 is used inside second scf.parallel arg, cannot fuse
+// %res1 is used inside second scf.parallel, cannot fuse
// CHECK-LABEL: func @reductions_use_res_inside
// CHECK: scf.parallel
// CHECK: scf.parallel
>From 8e10a4013178773f8837ca2a546443f983ef9467 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 19 Dec 2023 23:05:27 +0100
Subject: [PATCH 3/4] update test
---
mlir/test/Dialect/SCF/parallel-loop-fusion.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index f5dfc49cf5bbef..f3923763da954b 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -418,7 +418,7 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
}
// CHECK-LABEL: func @fuse_reductions
-// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
+// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -440,7 +440,7 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
// CHECK: scf.yield
-// CHECK: return %[[RES]]#0, %[[RES]]#1
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
// -----
>From d3a0a847553f10cb417591df39aebe5bb193e3ed Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 20 Dec 2023 18:29:17 +0100
Subject: [PATCH 4/4] Update to new reductions format
---
.../SCF/Transforms/ParallelLoopFusion.cpp | 34 +++++--
.../Dialect/SCF/parallel-loop-fusion.mlir | 93 +++++++++++++++----
2 files changed, 102 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index ea9fbee26fdeb0..9eb275d18f5689 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -162,18 +162,38 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
Block *newBlock = newSecondPloop.getBody();
- newBlock->getTerminator()->erase();
+ auto term1 = cast<ReduceOp>(block1->getTerminator());
+ auto term2 = cast<ReduceOp>(block2->getTerminator());
- block1->getTerminator()->erase();
-
- b.inlineBlockBefore(block1, newBlock, newBlock->end(),
+ b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
newBlock->getArguments());
- b.inlineBlockBefore(block2, newBlock, newBlock->end(),
+ b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
newBlock->getArguments());
ValueRange results = newSecondPloop.getResults();
- firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
- secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+ if (!results.empty()) {
+ b.setInsertionPointToEnd(newBlock);
+
+ ValueRange reduceArgs1 = term1.getOperands();
+ ValueRange reduceArgs2 = term2.getOperands();
+ SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+ newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+ auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+ for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+ term1.getReductions(), term2.getReductions()))) {
+ Block &oldRedBlock = reg.front();
+ Block &newRedBlock = newReduceOp.getReductions()[i].front();
+ b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+ newRedBlock.getArguments());
+ }
+
+ firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+ secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+ }
+ term1->erase();
+ term2->erase();
firstPloop.erase();
secondPloop.erase();
secondPloop = newSecondPloop;
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index f3923763da954b..46780f0abd7d05 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -390,7 +390,7 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// -----
-func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -398,26 +398,24 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
%init2 = arith.constant 2.0 : f32
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
- scf.reduce(%A_elem) : f32 {
+ scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
- scf.yield
}
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
- scf.reduce(%B_elem) : f32 {
+ scf.reduce(%B_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
- scf.yield
}
return %res1, %res2 : f32, f32
}
-// CHECK-LABEL: func @fuse_reductions
+// CHECK-LABEL: func @fuse_reductions_two
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -428,22 +426,85 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
-// CHECK: scf.reduce(%[[VAL_A]]) : f32 {
+// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
-// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
-// CHECK: scf.reduce(%[[VAL_B]]) : f32 {
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
-// CHECK: scf.yield
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
// -----
+func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init1 = arith.constant 1.0 : f32
+ %init2 = arith.constant 2.0 : f32
+ %init3 = arith.constant 3.0 : f32
+ %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ scf.reduce(%A_elem : f32) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ }
+ %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ scf.reduce(%B_elem : f32) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.mulf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ }
+ %res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 {
+ %A_elem = memref.load %C[%i, %j] : memref<2x2xf32>
+ scf.reduce(%A_elem : f32) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ }
+ return %res1, %res2, %res3 : f32, f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions_three
+// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
+// CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+// CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
+// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+// CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
+// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
+
+// -----
+
func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
@@ -451,21 +512,19 @@ func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32,
%init1 = arith.constant 1.0 : f32
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
- scf.reduce(%A_elem) : f32 {
+ scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
- scf.yield
}
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
- scf.reduce(%B_elem) : f32 {
+ scf.reduce(%B_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
- scf.yield
}
return %res1, %res2 : f32, f32
}
@@ -485,22 +544,20 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
%init2 = arith.constant 2.0 : f32
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
- scf.reduce(%A_elem) : f32 {
+ scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
- scf.yield
}
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%sum = arith.addf %B_elem, %res1 : f32
- scf.reduce(%sum) : f32 {
+ scf.reduce(%sum : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
- scf.yield
}
return %res1, %res2 : f32, f32
}
More information about the Mlir-commits
mailing list