[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 4 12:45:47 PDT 2024
https://github.com/srcarroll created https://github.com/llvm/llvm-project/pull/94391
None
>From 76cc04002d2ca1bfd44d2da72bf441a1fbe717a4 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 4 Jun 2024 14:44:53 -0500
Subject: [PATCH] Refactor LoopFuseSiblingOp and support parallel fusion
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 16 ++
.../SCF/TransformOps/SCFTransformOps.cpp | 53 +++--
.../SCF/Transforms/ParallelLoopFusion.cpp | 204 +----------------
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 208 ++++++++++++++++++
.../SCF/transform-loop-fuse-sibling.mlir | 53 +++++
5 files changed, 304 insertions(+), 230 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..2944d8ffac022 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -156,6 +156,12 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
scf::ForOp root);
+/// If legal, fuses firstPloop into secondPloop (firstPloop's body comes first)
+/// and rebinds secondPloop to the new loop; otherwise changes nothing.
+void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+ OpBuilder builder,
+ llvm::function_ref<bool(Value, Value)> mayAlias);
+
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
@@ -177,6 +183,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
RewriterBase &rewriter);
+/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// NOTE(review): the implementation in this patch delegates to `fuseIfLegal`,
+/// which does run legality checks; if they fail, nothing is fused and `source`
+/// is returned unchanged.
+scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
+ scf::ParallelOp source,
+ RewriterBase &rewriter);
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..1c53e89d69040 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -442,39 +442,32 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
return DiagnosedSilenceableFailure::success();
}
-/// Check if `target` scf.forall can be fused into `source` scf.forall.
+/// Check if `target` scf loop can be fused into `source` scf loop.
+/// Applies to scf.for, scf.forall, and scf.parallel (mapping: forall only).
///
/// This simply checks if both loops have the same bounds, steps and mapping.
/// No attempt is made at checking that the side effects of `target` and
/// `source` are independent of each other.
-static bool isForallWithIdenticalConfiguration(Operation *target,
- Operation *source) {
- auto targetOp = dyn_cast<scf::ForallOp>(target);
- auto sourceOp = dyn_cast<scf::ForallOp>(source);
- if (!targetOp || !sourceOp)
- return false;
-
- return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
- targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
- targetOp.getMixedStep() == sourceOp.getMixedStep() &&
- targetOp.getMapping() == sourceOp.getMapping();
-}
-
-/// Check if `target` scf.for can be fused into `source` scf.for.
-///
-/// This simply checks if both loops have the same bounds and steps. No attempt
-/// is made at checking that the side effects of `target` and `source` are
-/// independent of each other.
-static bool isForWithIdenticalConfiguration(Operation *target,
- Operation *source) {
- auto targetOp = dyn_cast<scf::ForOp>(target);
- auto sourceOp = dyn_cast<scf::ForOp>(source);
+template <typename LoopTy>
+static bool isLoopWithIdenticalConfiguration(Operation *target,
+ Operation *source) {
+ static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
+ scf::ParallelOp>::value,
+ "applies to only `forall`, `for` and `parallel`");
+ auto targetOp = dyn_cast<LoopTy>(target);
+ auto sourceOp = dyn_cast<LoopTy>(source);
if (!targetOp || !sourceOp)
return false;
- return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
- targetOp.getUpperBound() == sourceOp.getUpperBound() &&
- targetOp.getStep() == sourceOp.getStep();
+ if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
+ return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+ targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+ targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+ targetOp.getMapping() == sourceOp.getMapping();
+ else
+ return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+ targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+ targetOp.getStep() == sourceOp.getStep();
}
DiagnosedSilenceableFailure
@@ -502,12 +495,16 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
Operation *fusedLoop;
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
- if (isForWithIdenticalConfiguration(target, source)) {
+ if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
- } else if (isForallWithIdenticalConfiguration(target, source)) {
+ } else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
+ } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
+ source)) {
+ fusedLoop = fuseIndependentSiblingParallelLoops(
+ cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
<< "operations cannot be fused";
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 5934d85373b03..abac91cfaf7d9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
@@ -30,207 +31,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::scf;
-/// Verify there are no nested ParallelOps.
-static bool hasNestedParallelOp(ParallelOp ploop) {
- auto walkResult =
- ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
- return walkResult.wasInterrupted();
-}
-
-/// Verify equal iteration spaces.
-static bool equalIterationSpaces(ParallelOp firstPloop,
- ParallelOp secondPloop) {
- if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
- return false;
-
- auto matchOperands = [&](const OperandRange &lhs,
- const OperandRange &rhs) -> bool {
- // TODO: Extend this to support aliases and equal constants.
- return std::equal(lhs.begin(), lhs.end(), rhs.begin());
- };
- return matchOperands(firstPloop.getLowerBound(),
- secondPloop.getLowerBound()) &&
- matchOperands(firstPloop.getUpperBound(),
- secondPloop.getUpperBound()) &&
- matchOperands(firstPloop.getStep(), secondPloop.getStep());
-}
-
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
- ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
- SmallVector<Value> bufferStoresVec;
- firstPloop.getBody()->walk([&](memref::StoreOp store) {
- bufferStores[store.getMemRef()].push_back(store.getIndices());
- bufferStoresVec.emplace_back(store.getMemRef());
- });
- auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
- Value loadMem = load.getMemRef();
- // Stop if the memref is defined in secondPloop body. Careful alias analysis
- // is needed.
- auto *memrefDef = loadMem.getDefiningOp();
- if (memrefDef && memrefDef->getBlock() == load->getBlock())
- return WalkResult::interrupt();
-
- for (Value store : bufferStoresVec)
- if (store != loadMem && mayAlias(store, loadMem))
- return WalkResult::interrupt();
-
- auto write = bufferStores.find(loadMem);
- if (write == bufferStores.end())
- return WalkResult::advance();
-
- // Check that at last one store was retrieved
- if (!write->second.size())
- return WalkResult::interrupt();
-
- auto storeIndices = write->second.front();
-
- // Multiple writes to the same memref are allowed only on the same indices
- for (const auto &othStoreIndices : write->second) {
- if (othStoreIndices != storeIndices)
- return WalkResult::interrupt();
- }
-
- // Check that the load indices of secondPloop coincide with store indices of
- // firstPloop for the same memrefs.
- auto loadIndices = load.getIndices();
- if (storeIndices.size() != loadIndices.size())
- return WalkResult::interrupt();
- for (int i = 0, e = storeIndices.size(); i < e; ++i) {
- if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
- loadIndices[i]) {
- auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
- auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
- if (storeIndexDefOp && loadIndexDefOp) {
- if (!isMemoryEffectFree(storeIndexDefOp))
- return WalkResult::interrupt();
- if (!isMemoryEffectFree(loadIndexDefOp))
- return WalkResult::interrupt();
- if (!OperationEquivalence::isEquivalentTo(
- storeIndexDefOp, loadIndexDefOp,
- [&](Value storeIndex, Value loadIndex) {
- if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
- firstToSecondPloopIndices.lookupOrDefault(loadIndex))
- return failure();
- else
- return success();
- },
- /*markEquivalent=*/nullptr,
- OperationEquivalence::Flags::IgnoreLocations)) {
- return WalkResult::interrupt();
- }
- } else
- return WalkResult::interrupt();
- }
- }
- return WalkResult::advance();
- });
- return !walkResult.wasInterrupted();
-}
-
-/// Analyzes dependencies in the most primitive way by checking simple read and
-/// write patterns.
-static LogicalResult
-verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- if (!haveNoReadsAfterWriteExceptSameIndex(
- firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
- return failure();
-
- IRMapping secondToFirstPloopIndices;
- secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
- firstPloop.getBody()->getArguments());
- return success(haveNoReadsAfterWriteExceptSameIndex(
- secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
-}
-
-static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- return !hasNestedParallelOp(firstPloop) &&
- !hasNestedParallelOp(secondPloop) &&
- equalIterationSpaces(firstPloop, secondPloop) &&
- succeeded(verifyDependencies(firstPloop, secondPloop,
- firstToSecondPloopIndices, mayAlias));
-}
-
-/// Prepends operations of firstPloop's body into secondPloop's body.
-/// Updates secondPloop with new loop.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
- OpBuilder builder,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- Block *block1 = firstPloop.getBody();
- Block *block2 = secondPloop.getBody();
- IRMapping firstToSecondPloopIndices;
- firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
-
- if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
- mayAlias))
- return;
-
- DominanceInfo dom;
- // We are fusing first loop into second, make sure there are no users of the
- // first loop results between loops.
- for (Operation *user : firstPloop->getUsers())
- if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
- return;
-
- ValueRange inits1 = firstPloop.getInitVals();
- ValueRange inits2 = secondPloop.getInitVals();
-
- SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
- newInitVars.append(inits2.begin(), inits2.end());
-
- IRRewriter b(builder);
- b.setInsertionPoint(secondPloop);
- auto newSecondPloop = b.create<ParallelOp>(
- secondPloop.getLoc(), secondPloop.getLowerBound(),
- secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
- Block *newBlock = newSecondPloop.getBody();
- auto term1 = cast<ReduceOp>(block1->getTerminator());
- auto term2 = cast<ReduceOp>(block2->getTerminator());
-
- b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
- newBlock->getArguments());
- b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
- newBlock->getArguments());
-
- ValueRange results = newSecondPloop.getResults();
- if (!results.empty()) {
- b.setInsertionPointToEnd(newBlock);
-
- ValueRange reduceArgs1 = term1.getOperands();
- ValueRange reduceArgs2 = term2.getOperands();
- SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
- newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
- auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
- for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
- term1.getReductions(), term2.getReductions()))) {
- Block &oldRedBlock = reg.front();
- Block &newRedBlock = newReduceOp.getReductions()[i].front();
- b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
- newRedBlock.getArguments());
- }
-
- firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
- secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
- }
- term1->erase();
- term2->erase();
- firstPloop.erase();
- secondPloop.erase();
- secondPloop = newSecondPloop;
-}
-
void mlir::scf::naivelyFuseParallelOps(
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
OpBuilder b(region);
@@ -259,7 +59,7 @@ void mlir::scf::naivelyFuseParallelOps(
}
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
- fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
+ mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
}
}
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..d85339f32dbe3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
@@ -1070,6 +1071,206 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
+/// Checks loads in `secondPloop` against stores in `firstPloop`. Returns
+/// `true` only if no loaded memref may alias a different stored memref and
+/// each load of a stored memref uses the same indices as the stores to it.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+ scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+ SmallVector<Value> bufferStoresVec;
+ firstPloop.getBody()->walk([&](memref::StoreOp store) {
+ bufferStores[store.getMemRef()].push_back(store.getIndices());
+ bufferStoresVec.emplace_back(store.getMemRef());
+ });
+ auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+ Value loadMem = load.getMemRef();
+ // Stop if the memref is defined in secondPloop body. Careful alias analysis
+ // is needed.
+ auto *memrefDef = loadMem.getDefiningOp();
+ if (memrefDef && memrefDef->getBlock() == load->getBlock())
+ return WalkResult::interrupt();
+
+ for (Value store : bufferStoresVec)
+ if (store != loadMem && mayAlias(store, loadMem))
+ return WalkResult::interrupt();
+
+ auto write = bufferStores.find(loadMem);
+ if (write == bufferStores.end())
+ return WalkResult::advance();
+
+ // Check that at least one store was retrieved
+ if (!write->second.size())
+ return WalkResult::interrupt();
+
+ auto storeIndices = write->second.front();
+
+ // Multiple writes to the same memref are allowed only on the same indices
+ for (const auto &othStoreIndices : write->second) {
+ if (othStoreIndices != storeIndices)
+ return WalkResult::interrupt();
+ }
+
+ // Check that the load indices of secondPloop coincide with store indices of
+ // firstPloop for the same memrefs.
+ auto loadIndices = load.getIndices();
+ if (storeIndices.size() != loadIndices.size())
+ return WalkResult::interrupt();
+ for (int i = 0, e = storeIndices.size(); i < e; ++i) {
+ if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
+ loadIndices[i]) {
+ auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+ auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+ if (storeIndexDefOp && loadIndexDefOp) {
+ if (!isMemoryEffectFree(storeIndexDefOp))
+ return WalkResult::interrupt();
+ if (!isMemoryEffectFree(loadIndexDefOp))
+ return WalkResult::interrupt();
+ if (!OperationEquivalence::isEquivalentTo(
+ storeIndexDefOp, loadIndexDefOp,
+ [&](Value storeIndex, Value loadIndex) {
+ if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+ firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+ return failure();
+ else
+ return success();
+ },
+ /*markEquivalent=*/nullptr,
+ OperationEquivalence::Flags::IgnoreLocations)) {
+ return WalkResult::interrupt();
+ }
+ } else
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ });
+ return !walkResult.wasInterrupted();
+}
+
+/// Analyzes dependencies in the most primitive way: runs the store/load index
+/// check in both directions (first -> second, then second -> first).
+static LogicalResult
+verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ if (!haveNoReadsAfterWriteExceptSameIndex(
+ firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
+ return failure();
+
+ IRMapping secondToFirstPloopIndices;
+ secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
+ firstPloop.getBody()->getArguments());
+ return success(haveNoReadsAfterWriteExceptSameIndex(
+ secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+}
+
+/// Returns true if loop counts, bounds and steps are the same SSA values.
+static bool equalIterationSpaces(scf::ParallelOp firstPloop,
+ scf::ParallelOp secondPloop) {
+ if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+ return false;
+
+ auto matchOperands = [&](const OperandRange &lhs,
+ const OperandRange &rhs) -> bool {
+ // TODO: Extend this to support aliases and equal constants.
+ return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+ };
+ return matchOperands(firstPloop.getLowerBound(),
+ secondPloop.getLowerBound()) &&
+ matchOperands(firstPloop.getUpperBound(),
+ secondPloop.getUpperBound()) &&
+ matchOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
+/// Returns true if `ploop`'s body contains a nested scf.parallel op.
+static bool hasNestedParallelOp(scf::ParallelOp ploop) {
+ auto walkResult = ploop.getBody()->walk(
+ [](scf::ParallelOp) { return WalkResult::interrupt(); });
+ return walkResult.wasInterrupted();
+}
+
+static bool isFusionLegal(scf::ParallelOp firstPloop,
+ scf::ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ return !hasNestedParallelOp(firstPloop) &&
+ !hasNestedParallelOp(secondPloop) &&
+ equalIterationSpaces(firstPloop, secondPloop) &&
+ succeeded(verifyDependencies(firstPloop, secondPloop,
+ firstToSecondPloopIndices, mayAlias));
+}
+
+void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+ OpBuilder builder,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ Block *block1 = firstPloop.getBody();
+ Block *block2 = secondPloop.getBody();
+ IRMapping firstToSecondPloopIndices;
+ firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
+
+ if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+ mayAlias))
+ return;
+
+ DominanceInfo dom;
+ // We are fusing first loop into second, make sure there are no users of the
+ // first loop results between loops.
+ for (Operation *user : firstPloop->getUsers())
+ if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+ return;
+
+ ValueRange inits1 = firstPloop.getInitVals();
+ ValueRange inits2 = secondPloop.getInitVals();
+
+ SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+ newInitVars.append(inits2.begin(), inits2.end());
+
+ IRRewriter b(builder);
+ b.setInsertionPoint(secondPloop);
+ auto newSecondPloop = b.create<scf::ParallelOp>(
+ secondPloop.getLoc(), secondPloop.getLowerBound(),
+ secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+ Block *newBlock = newSecondPloop.getBody();
+ auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+ auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+ b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+ newBlock->getArguments());
+ b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+ newBlock->getArguments());
+
+ ValueRange results = newSecondPloop.getResults();
+ if (!results.empty()) {
+ b.setInsertionPointToEnd(newBlock);
+
+ ValueRange reduceArgs1 = term1.getOperands();
+ ValueRange reduceArgs2 = term2.getOperands();
+ SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+ newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+ auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+ for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+ term1.getReductions(), term2.getReductions()))) {
+ Block &oldRedBlock = reg.front();
+ Block &newRedBlock = newReduceOp.getReductions()[i].front();
+ b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+ newRedBlock.getArguments());
+ }
+
+ firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+ secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+ }
+ term1->erase();
+ term2->erase();
+ firstPloop.erase();
+ secondPloop.erase();
+ secondPloop = newSecondPloop;
+}
+
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter) {
@@ -1171,3 +1372,10 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
return fusedLoop;
}
+
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+ scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+ auto mayAlias = [&](Value val1, Value val2) -> bool { return false; };
+ mlir::fuseIfLegal(target, source, rewriter, mayAlias);
+ return source;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1cdbe0cf..46c6be36c3271 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -47,6 +47,59 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func @fuse_two_parallel
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c1fp = arith.constant 1.0 : f32
+// CHECK: [[SUM:%.*]] = memref.alloc()
+ %sum = memref.alloc() : memref<2x2xf32>
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %c1fp : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+// CHECK: memref.dealloc [[SUM]]
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
More information about the Mlir-commits
mailing list