[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 4 13:29:34 PDT 2024
================
@@ -1070,6 +1071,206 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
+/// Conservative read-after-write check from `firstPloop` to `secondPloop`.
+///
+/// Every `memref.store` in the body of `firstPloop` is recorded, then every
+/// `memref.load` in the body of `secondPloop` is visited. Returns `true` only
+/// if no load is rejected. A load is rejected when
+///   * its memref has a defining op located in the same block as the load,
+///   * `mayAlias` reports that it may alias a *different* stored memref,
+///   * the stores recorded for its memref do not all use equal index ranges, or
+///   * its indices cannot be matched against the store indices, either
+///     directly through `firstToSecondPloopIndices` or through equivalent,
+///     memory-effect-free defining ops.
+/// Loads from memrefs that `firstPloop` never stores to are accepted.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+    scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+    const IRMapping &firstToSecondPloopIndices,
+    llvm::function_ref<bool(Value, Value)> mayAlias) {
+  // Index ranges of the stores, grouped by target memref, together with the
+  // list of stored memrefs (one entry per store op).
+  DenseMap<Value, SmallVector<ValueRange, 1>> storesByBuffer;
+  SmallVector<Value> storedBuffers;
+  firstPloop.getBody()->walk([&](memref::StoreOp storeOp) {
+    Value target = storeOp.getMemRef();
+    storesByBuffer[target].push_back(storeOp.getIndices());
+    storedBuffers.emplace_back(target);
+  });
+
+  // Two index operands correspond when they map to the same value.
+  auto indicesCorrespond = [&](Value storeIndex, Value loadIndex) {
+    return success(firstToSecondPloopIndices.lookupOrDefault(storeIndex) ==
+                   firstToSecondPloopIndices.lookupOrDefault(loadIndex));
+  };
+
+  WalkResult outcome = secondPloop.getBody()->walk([&](memref::LoadOp loadOp) {
+    Value source = loadOp.getMemRef();
+
+    // A memref defined in the same block as the load would need careful alias
+    // analysis, so give up.
+    Operation *sourceDef = source.getDefiningOp();
+    if (sourceDef && sourceDef->getBlock() == loadOp->getBlock())
+      return WalkResult::interrupt();
+
+    // Give up if the load may alias some other buffer that was stored to.
+    for (Value stored : storedBuffers) {
+      if (stored == source)
+        continue;
+      if (mayAlias(stored, source))
+        return WalkResult::interrupt();
+    }
+
+    // Nothing to check when the first loop never stores to this memref.
+    auto entry = storesByBuffer.find(source);
+    if (entry == storesByBuffer.end())
+      return WalkResult::advance();
+
+    // Check that at least one store was recorded.
+    const auto &recorded = entry->second;
+    if (recorded.empty())
+      return WalkResult::interrupt();
+
+    // Several stores to one memref are tolerated only if they all use the
+    // same indices as the first recorded store.
+    auto storeIndices = recorded.front();
+    for (const auto &otherIndices : recorded)
+      if (otherIndices != storeIndices)
+        return WalkResult::interrupt();
+
+    // The load must address the memref with indices that coincide with the
+    // store indices.
+    auto loadIndices = loadOp.getIndices();
+    if (storeIndices.size() != loadIndices.size())
+      return WalkResult::interrupt();
+
+    for (int pos = 0, count = storeIndices.size(); pos < count; ++pos) {
+      // Direct match through the block-argument mapping.
+      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[pos]) ==
+          loadIndices[pos])
+        continue;
+
+      // Otherwise both indices must be produced by memory-effect-free ops
+      // that are equivalent up to the mapping (locations ignored).
+      Operation *storeIndexDef = storeIndices[pos].getDefiningOp();
+      Operation *loadIndexDef = loadIndices[pos].getDefiningOp();
+      if (!storeIndexDef || !loadIndexDef)
+        return WalkResult::interrupt();
+      if (!isMemoryEffectFree(storeIndexDef))
+        return WalkResult::interrupt();
+      if (!isMemoryEffectFree(loadIndexDef))
+        return WalkResult::interrupt();
+      if (!OperationEquivalence::isEquivalentTo(
+              storeIndexDef, loadIndexDef, indicesCorrespond,
+              /*markEquivalent=*/nullptr,
+              OperationEquivalence::Flags::IgnoreLocations))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return !outcome.wasInterrupted();
+}
+
+/// Primitive dependence analysis based on simple store/load patterns.
+///
+/// Runs `haveNoReadsAfterWriteExceptSameIndex` in both directions: first with
+/// the stores of `firstPloop` against the loads of `secondPloop`, then with the
+/// roles swapped and the block-argument mapping inverted. Succeeds only if
+/// both checks pass.
+static LogicalResult
+verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+                   const IRMapping &firstToSecondPloopIndices,
+                   llvm::function_ref<bool(Value, Value)> mayAlias) {
+  // Forward direction: first loop stores, second loop loads.
+  const bool forwardOk = haveNoReadsAfterWriteExceptSameIndex(
+      firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias);
+  if (!forwardOk)
+    return failure();
+
+  // Reverse direction: map the second loop's block arguments onto the first's.
+  IRMapping reverseMapping;
+  reverseMapping.map(secondPloop.getBody()->getArguments(),
+                     firstPloop.getBody()->getArguments());
+  const bool backwardOk = haveNoReadsAfterWriteExceptSameIndex(
+      secondPloop, firstPloop, reverseMapping, mayAlias);
+  return success(backwardOk);
+}
+
+/// Returns `true` when both loops have the same number of loop dimensions and
+/// their lower bounds, upper bounds and steps are element-wise the same
+/// operands.
+static bool equalIterationSpaces(scf::ParallelOp firstPloop,
+                                 scf::ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+
+  // Element-wise operand comparison.
+  // TODO: Extend this to support aliases and equal constants.
+  auto sameOperands = [&](const OperandRange &lhs,
+                          const OperandRange &rhs) -> bool {
+    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+  };
+
+  if (!sameOperands(firstPloop.getLowerBound(), secondPloop.getLowerBound()))
+    return false;
+  if (!sameOperands(firstPloop.getUpperBound(), secondPloop.getUpperBound()))
+    return false;
+  return sameOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
+/// Returns `true` if the walk over the body of `ploop` encounters an
+/// `scf::ParallelOp`, i.e. the loop has a nested parallel loop.
+static bool hasNestedParallelOp(scf::ParallelOp ploop) {
+  // Interrupt on the first nested parallel op found; an interrupted walk
+  // therefore means "found one".
+  auto stopAtParallel = [](scf::ParallelOp) { return WalkResult::interrupt(); };
+  return ploop.getBody()->walk(stopAtParallel).wasInterrupted();
+}
+
+/// Returns `true` when the two loops may be fused: neither contains a nested
+/// parallel loop, their iteration spaces are equal, and the dependence check
+/// in `verifyDependencies` succeeds.
+static bool isFusionLegal(scf::ParallelOp firstPloop,
+                          scf::ParallelOp secondPloop,
+                          const IRMapping &firstToSecondPloopIndices,
+                          llvm::function_ref<bool(Value, Value)> mayAlias) {
+  // The checks run in this order and stop at the first one that fails.
+  if (hasNestedParallelOp(firstPloop))
+    return false;
+  if (hasNestedParallelOp(secondPloop))
+    return false;
+  if (!equalIterationSpaces(firstPloop, secondPloop))
+    return false;
+  return succeeded(verifyDependencies(firstPloop, secondPloop,
+                                      firstToSecondPloopIndices, mayAlias));
+}
+
+/// Fuses `firstPloop` into `secondPloop` if `isFusionLegal` accepts the pair
+/// and every user of `firstPloop` is properly dominated by `secondPloop`.
+///
+/// When fusion happens, a new `scf.parallel` is created at the position of
+/// `secondPloop` using `secondPloop`'s bounds and steps and the concatenated
+/// init values of both loops. Both bodies are inlined into it, a combined
+/// `scf.reduce` is built when the new loop has results, uses of the old
+/// results are redirected, the old loops are erased, and `secondPloop` is
+/// updated to refer to the new loop. When fusion does not happen the function
+/// returns before modifying anything.
+void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+                       OpBuilder builder,
+                       llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *firstBody = firstPloop.getBody();
+  Block *secondBody = secondPloop.getBody();
+
+  // Associate the block arguments of the first body with those of the second.
+  IRMapping firstToSecondPloopIndices;
+  firstToSecondPloopIndices.map(firstBody->getArguments(),
+                                secondBody->getArguments());
+
+  const bool legal = isFusionLegal(firstPloop, secondPloop,
+                                   firstToSecondPloopIndices, mayAlias);
+  if (!legal)
+    return;
+
+  // The first loop is merged into the second one, so no user of the first
+  // loop's results may sit between the two loops.
+  DominanceInfo domInfo;
+  for (Operation *user : firstPloop->getUsers()) {
+    if (domInfo.properlyDominates(secondPloop, user, /*enclosingOpOk=*/false))
+      continue;
+    return;
+  }
+
+  // Init values of the fused loop: first loop's, then second loop's.
+  ValueRange firstInits = firstPloop.getInitVals();
+  ValueRange secondInits = secondPloop.getInitVals();
+  SmallVector<Value> fusedInits(firstInits.begin(), firstInits.end());
+  fusedInits.append(secondInits.begin(), secondInits.end());
+
+  IRRewriter rewriter(builder);
+  rewriter.setInsertionPoint(secondPloop);
+  auto fusedLoop = rewriter.create<scf::ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), fusedInits);
+
+  // Grab the old terminators before the bodies are inlined.
+  Block *fusedBody = fusedLoop.getBody();
+  auto firstTerm = cast<scf::ReduceOp>(firstBody->getTerminator());
+  auto secondTerm = cast<scf::ReduceOp>(secondBody->getTerminator());
+
+  // Inline the second body at the front, then the first body in front of it,
+  // both using the fused loop's block arguments.
+  rewriter.inlineBlockBefore(secondBody, fusedBody, fusedBody->begin(),
+                             fusedBody->getArguments());
+  rewriter.inlineBlockBefore(firstBody, fusedBody, fusedBody->begin(),
+                             fusedBody->getArguments());
+
+  ValueRange fusedResults = fusedLoop.getResults();
+  if (!fusedResults.empty()) {
+    rewriter.setInsertionPointToEnd(fusedBody);
+
+    // Operands of the combined reduce: first terminator's, then second's.
+    ValueRange firstReduceArgs = firstTerm.getOperands();
+    ValueRange secondReduceArgs = secondTerm.getOperands();
+    SmallVector<Value> fusedReduceArgs(firstReduceArgs.begin(),
+                                       firstReduceArgs.end());
+    fusedReduceArgs.append(secondReduceArgs.begin(), secondReduceArgs.end());
+
+    auto fusedReduce =
+        rewriter.create<scf::ReduceOp>(secondTerm.getLoc(), fusedReduceArgs);
+
+    // Move each old reduction region's entry block into the region at the
+    // same position of the combined reduce.
+    for (auto &&[regionIdx, oldRegion] : llvm::enumerate(llvm::concat<Region>(
+             firstTerm.getReductions(), secondTerm.getReductions()))) {
+      Block &sourceBlock = oldRegion.front();
+      Block &targetBlock = fusedReduce.getReductions()[regionIdx].front();
+      rewriter.inlineBlockBefore(&sourceBlock, &targetBlock,
+                                 targetBlock.begin(),
+                                 targetBlock.getArguments());
+    }
+
+    // Leading results replace the first loop's, trailing ones the second's.
+    firstPloop.replaceAllUsesWith(fusedResults.take_front(firstInits.size()));
+    secondPloop.replaceAllUsesWith(fusedResults.take_back(secondInits.size()));
+  }
+
+  // Drop the old terminators and loops, then hand the fused loop back to the
+  // caller through `secondPloop`.
+  firstTerm->erase();
+  secondTerm->erase();
+  firstPloop.erase();
+  secondPloop.erase();
+  secondPloop = fusedLoop;
+}
+
----------------
srcarroll wrote:
This whole block of changes was copied from `ParallelLoopFusion.cpp` without modification.
https://github.com/llvm/llvm-project/pull/94391
More information about the Mlir-commits
mailing list