[Mlir-commits] [mlir] 6b2e29c - NFC. Refactor affine fusion code for readability
Uday Bondhugula
llvmlistbot at llvm.org
Thu Jan 19 23:50:37 PST 2023
Author: Uday Bondhugula
Date: 2023-01-20T13:17:10+05:30
New Revision: 6b2e29c5080e3e9870d8a2151a0feaf64eb480d0
URL: https://github.com/llvm/llvm-project/commit/6b2e29c5080e3e9870d8a2151a0feaf64eb480d0
DIFF: https://github.com/llvm/llvm-project/commit/6b2e29c5080e3e9870d8a2151a0feaf64eb480d0.diff
LOG: NFC. Refactor affine fusion code for readability
Replace a couple of check instances with llvm::any_of (clang-tidy
warnings). Factor out "canCreatePrivateMemRef" and
"performFusionsIntoDest" into separate methods to reduce the
length/indent of the containing methods. Add doc comments and debug messages.
Mark as const some of the methods that should have been const.
NFC.
Reviewed By: vinayaka-polymage
Differential Revision: https://reviews.llvm.org/D142076
Added:
Modified:
mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 656158798cbc7..db39a835144ac 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -125,20 +125,20 @@ struct MemRefDependenceGraph {
Node(unsigned id, Operation *op) : id(id), op(op) {}
// Returns the load op count for 'memref'.
- unsigned getLoadOpCount(Value memref) {
+ unsigned getLoadOpCount(Value memref) const {
unsigned loadOpCount = 0;
- for (auto *loadOpInst : loads) {
- if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
+ for (Operation *loadOp : loads) {
+ if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
++loadOpCount;
}
return loadOpCount;
}
// Returns the store op count for 'memref'.
- unsigned getStoreOpCount(Value memref) {
+ unsigned getStoreOpCount(Value memref) const {
unsigned storeOpCount = 0;
- for (auto *storeOpInst : stores) {
- if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
+ for (Operation *storeOp : stores) {
+ if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
++storeOpCount;
}
return storeOpCount;
@@ -146,31 +146,32 @@ struct MemRefDependenceGraph {
// Returns all store ops in 'storeOps' which access 'memref'.
void getStoreOpsForMemref(Value memref,
- SmallVectorImpl<Operation *> *storeOps) {
- for (auto *storeOpInst : stores) {
- if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
- storeOps->push_back(storeOpInst);
+ SmallVectorImpl<Operation *> *storeOps) const {
+ for (Operation *storeOp : stores) {
+ if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
+ storeOps->push_back(storeOp);
}
}
// Returns all load ops in 'loadOps' which access 'memref'.
void getLoadOpsForMemref(Value memref,
- SmallVectorImpl<Operation *> *loadOps) {
- for (auto *loadOpInst : loads) {
- if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
- loadOps->push_back(loadOpInst);
+ SmallVectorImpl<Operation *> *loadOps) const {
+ for (Operation *loadOp : loads) {
+ if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
+ loadOps->push_back(loadOp);
}
}
// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
// has at least one load and store operation.
- void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
+ void
+ getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const {
llvm::SmallDenseSet<Value, 2> loadMemrefs;
- for (auto *loadOpInst : loads) {
- loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
+ for (Operation *loadOp : loads) {
+ loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
}
- for (auto *storeOpInst : stores) {
- auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+ for (Operation *storeOp : stores) {
+ auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
if (loadMemrefs.count(memref) > 0)
loadAndStoreMemrefSet->insert(memref);
}
@@ -744,14 +745,12 @@ static bool isEscapingMemref(Value memref, Block *block) {
// Check if 'memref' is used by a non-deferencing op (including unknown ones)
// (e.g., call ops, alias creating ops, etc.).
- for (Operation *user : memref.getUsers()) {
+ return llvm::any_of(memref.getUsers(), [&](Operation *user) {
// Ignore users outside of `block`.
if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
- continue;
- if (!isa<AffineMapAccessInterface>(*user))
- return true;
- }
- return false;
+ return false;
+ return !isa<AffineMapAccessInterface>(*user);
+ });
}
/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
@@ -1076,10 +1075,9 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
return WalkResult::advance();
});
// Looking for users between node 'srcId' and node 'dstId'.
- for (Value memref : memRefValues)
- if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
- return true;
- return false;
+ return llvm::any_of(memRefValues, [&](Value memref) {
+ return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
+ });
}
// Checks the profitability of fusing a backwards slice of the loop nest
@@ -1452,280 +1450,301 @@ struct GreedyFusion {
eraseUnusedMemRefAllocations();
}
- /// Visit each node in the graph, and for each node, attempt to fuse it with
- /// producer-consumer candidates. No fusion is performed when producers with a
- /// user count greater than `maxSrcUserCount` for any of the memrefs involved
- /// are encountered.
- void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
- init();
- while (!worklist.empty()) {
- unsigned dstId = worklist.back();
- worklist.pop_back();
+ /// Returns true if a private memref can be created for `memref` given
+ /// the fusion scenario reflected by the other arguments.
+ bool canCreatePrivateMemRef(Value memref,
+ const DenseSet<Value> &srcEscapingMemRefs,
+ unsigned producerId, unsigned consumerId,
+ bool removeSrcNode) {
+ const Node *consumerNode = mdg->getNode(consumerId);
+ // If `memref` is an escaping one, do not create a private memref
+ // for the below scenarios, since doing so will leave the escaping
+ // memref unmodified as all the writes originally meant for the
+ // escaping memref would be performed on the private memref:
+ // 1. The source is to be removed after fusion,
+ // OR
+ // 2. The destination writes to `memref`.
+ if (srcEscapingMemRefs.count(memref) > 0 &&
+ (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
+ return false;
- // Skip if this node was removed (fused into another node).
- if (mdg->nodes.count(dstId) == 0)
- continue;
- // Get 'dstNode' into which to attempt fusion.
- auto *dstNode = mdg->getNode(dstId);
- // Skip if 'dstNode' is not a loop nest.
- if (!isa<AffineForOp>(dstNode->op))
- continue;
- // Skip if 'dstNode' is a loop nest returning values.
- // TODO: support loop nests that return values.
- if (dstNode->op->getNumResults() > 0)
- continue;
+ // Don't create a private memref if 'srcNode' has in edges on
+ // 'memref' or 'dstNode' has out edges on 'memref'.
+ if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
+ mdg->getOutEdgeCount(consumerId, memref) > 0)
+ return false;
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
-
- // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
- // while preserving relative order. This can increase the maximum loop
- // depth at which we can fuse a slice of a producer loop nest into a
- // consumer loop nest.
- sinkSequentialLoops(dstNode);
- auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
-
- // Try to fuse 'dstNode' with candidate producer loops until a fixed point
- // is reached. Fusing two loops may expose new fusion opportunities.
- bool dstNodeChanged;
- do {
- // Gather src loop candidates for 'dstNode' and visit them in "quasi"
- // reverse program order to minimize the number of iterations needed to
- // reach the fixed point. Note that this is a best effort approach since
- // 'getProducerCandidates' does not always guarantee that program order
- // in 'srcIdCandidates'.
- dstNodeChanged = false;
- SmallVector<unsigned, 16> srcIdCandidates;
- getProducerCandidates(dstId, mdg, srcIdCandidates);
-
- for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
- // Get 'srcNode' from which to attempt fusion into 'dstNode'.
- auto *srcNode = mdg->getNode(srcId);
- auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
- LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
- << " for dst loop " << dstId << "\n");
-
- // Skip if 'srcNode' is a loop nest returning values.
- // TODO: support loop nests that return values.
- if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
- continue;
+ // If 'srcNode' will be removed but it has out edges on 'memref' to
+ // nodes other than 'dstNode', we have to preserve dependences and
+ // cannot create a private memref.
+ if (removeSrcNode &&
+ any_of(mdg->outEdges[producerId], [&](const auto &edge) {
+ return edge.value == memref && edge.id != consumerId;
+ }))
+ return false;
- DenseSet<Value> producerConsumerMemrefs;
- gatherProducerConsumerMemrefs(srcId, dstId, mdg,
- producerConsumerMemrefs);
+ return true;
+ }
- // Skip if 'srcNode' out edge count on any memref is greater than
- // 'maxSrcUserCount'.
- if (any_of(producerConsumerMemrefs, [&](Value memref) {
- return mdg->getOutEdgeCount(srcNode->id, memref) >
- maxSrcUserCount;
- }))
- continue;
+ /// Perform fusions with node `dstId` as the destination of fusion. No
+ /// fusion is performed with producers that have a user count greater than
+ /// `maxSrcUserCount` for any of the memrefs involved.
+ void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
+ LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+ // Skip if this node was removed (fused into another node).
+ if (mdg->nodes.count(dstId) == 0)
+ return;
+ // Get 'dstNode' into which to attempt fusion.
+ auto *dstNode = mdg->getNode(dstId);
+ // Skip if 'dstNode' is not a loop nest.
+ if (!isa<AffineForOp>(dstNode->op))
+ return;
+ // Skip if 'dstNode' is a loop nest returning values.
+ // TODO: support loop nests that return values.
+ if (dstNode->op->getNumResults() > 0)
+ return;
+
+ LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+
+ // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
+ // while preserving relative order. This can increase the maximum loop
+ // depth at which we can fuse a slice of a producer loop nest into a
+ // consumer loop nest.
+ sinkSequentialLoops(dstNode);
+ auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
- // Gather memrefs in 'srcNode' that are written and escape out of the
- // block (e.g., memref block arguments, returned memrefs,
- // memrefs passed to function calls, etc.).
- DenseSet<Value> srcEscapingMemRefs;
- gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
-
- // Skip if there are non-affine operations in between the 'srcNode'
- // and 'dstNode' using their memrefs. If so, we wouldn't be able to
- // compute a legal insertion point for now. 'srcNode' and 'dstNode'
- // memrefs with non-affine operation users would be considered
- // escaping memrefs so we can limit this check to only scenarios with
- // escaping memrefs.
- if (!srcEscapingMemRefs.empty() &&
- hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "Can't fuse: non-affine users in between the loops\n.");
- continue;
- }
+ // Try to fuse 'dstNode' with candidate producer loops until a fixed point
+ // is reached. Fusing two loops may expose new fusion opportunities.
+ bool dstNodeChanged;
+ do {
+ // Gather src loop candidates for 'dstNode' and visit them in "quasi"
+ // reverse program order to minimize the number of iterations needed to
+ // reach the fixed point. Note that this is a best effort approach since
+ // 'getProducerCandidates' does not always guarantee that program order
+ // in 'srcIdCandidates'.
+ dstNodeChanged = false;
+ SmallVector<unsigned, 16> srcIdCandidates;
+ getProducerCandidates(dstId, mdg, srcIdCandidates);
+
+ for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
+ // Get 'srcNode' from which to attempt fusion into 'dstNode'.
+ auto *srcNode = mdg->getNode(srcId);
+ auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
+ LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
+ << " for dst loop " << dstId << "\n");
+
+ // Skip if 'srcNode' is a loop nest returning values.
+ // TODO: support loop nests that return values.
+ if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
+ continue;
- // Compute an operation list insertion point for the fused loop
- // nest which preserves dependences.
- Operation *fusedLoopInsPoint =
- mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
- if (fusedLoopInsPoint == nullptr)
- continue;
+ DenseSet<Value> producerConsumerMemrefs;
+ gatherProducerConsumerMemrefs(srcId, dstId, mdg,
+ producerConsumerMemrefs);
- // Compute the innermost common loop depth for dstNode
- // producer-consumer loads/stores.
- SmallVector<Operation *, 2> dstMemrefOps;
- for (Operation *op : dstNode->loads)
- if (producerConsumerMemrefs.count(
- cast<AffineReadOpInterface>(op).getMemRef()) > 0)
- dstMemrefOps.push_back(op);
- for (Operation *op : dstNode->stores)
- if (producerConsumerMemrefs.count(
- cast<AffineWriteOpInterface>(op).getMemRef()))
- dstMemrefOps.push_back(op);
- unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
-
- // Check the feasibility of fusing src loop nest into dst loop nest
- // at loop depths in range [1, dstLoopDepthTest].
- unsigned maxLegalFusionDepth = 0;
- SmallVector<ComputationSliceState, 8> depthSliceUnions;
- depthSliceUnions.resize(dstLoopDepthTest);
- FusionStrategy strategy(FusionStrategy::ProducerConsumer);
- for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
- FusionResult result = mlir::canFuseLoops(
- srcAffineForOp, dstAffineForOp,
- /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
-
- if (result.value == FusionResult::Success)
- maxLegalFusionDepth = i;
- }
+ // Skip if 'srcNode' out edge count on any memref is greater than
+ // 'maxSrcUserCount'.
+ if (any_of(producerConsumerMemrefs, [&](Value memref) {
+ return mdg->getOutEdgeCount(srcNode->id, memref) >
+ maxSrcUserCount;
+ }))
+ continue;
- if (maxLegalFusionDepth == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't fuse: fusion is not legal at any depth\n");
- continue;
- }
+ // Gather memrefs in 'srcNode' that are written and escape out of the
+ // block (e.g., memref block arguments, returned memrefs,
+ // memrefs passed to function calls, etc.).
+ DenseSet<Value> srcEscapingMemRefs;
+ gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
+
+ // Skip if there are non-affine operations in between the 'srcNode'
+ // and 'dstNode' using their memrefs. If so, we wouldn't be able to
+ // compute a legal insertion point for now. 'srcNode' and 'dstNode'
+ // memrefs with non-affine operation users would be considered
+ // escaping memrefs so we can limit this check to only scenarios with
+ // escaping memrefs.
+ if (!srcEscapingMemRefs.empty() &&
+ hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Can't fuse: non-affine users in between the loops\n.");
+ continue;
+ }
- // Check if fusion would be profitable. We skip profitability analysis
- // for maximal fusion since we already know the maximal legal depth to
- // fuse.
- unsigned bestDstLoopDepth = maxLegalFusionDepth;
- if (!maximalFusion) {
- // Retrieve producer stores from the src loop.
- SmallVector<Operation *, 2> producerStores;
- for (Operation *op : srcNode->stores)
- if (producerConsumerMemrefs.count(
- cast<AffineWriteOpInterface>(op).getMemRef()))
- producerStores.push_back(op);
-
- // TODO: Suppport multiple producer stores in profitability
- // analysis. We limit profitability analysis to only scenarios with
- // a single producer store for now. Note that some multi-store
- // producer scenarios will still go through profitability analysis
- // if only one of the stores is involved the producer-consumer
- // relationship of the candidate loops.
- assert(!producerStores.empty() && "Expected producer store");
- if (producerStores.size() > 1)
- LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
- "supported for this case\n");
- else if (!isFusionProfitable(producerStores[0], producerStores[0],
- dstAffineForOp, depthSliceUnions,
- maxLegalFusionDepth, &bestDstLoopDepth,
- computeToleranceThreshold))
- continue;
- }
+ // Compute an operation list insertion point for the fused loop
+ // nest which preserves dependences.
+ Operation *fusedLoopInsPoint =
+ mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
+ if (fusedLoopInsPoint == nullptr)
+ continue;
+
+ // Compute the innermost common loop depth for dstNode
+ // producer-consumer loads/stores.
+ SmallVector<Operation *, 2> dstMemrefOps;
+ for (Operation *op : dstNode->loads)
+ if (producerConsumerMemrefs.count(
+ cast<AffineReadOpInterface>(op).getMemRef()) > 0)
+ dstMemrefOps.push_back(op);
+ for (Operation *op : dstNode->stores)
+ if (producerConsumerMemrefs.count(
+ cast<AffineWriteOpInterface>(op).getMemRef()))
+ dstMemrefOps.push_back(op);
+ unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
+
+ // Check the feasibility of fusing src loop nest into dst loop nest
+ // at loop depths in range [1, dstLoopDepthTest].
+ unsigned maxLegalFusionDepth = 0;
+ SmallVector<ComputationSliceState, 8> depthSliceUnions;
+ depthSliceUnions.resize(dstLoopDepthTest);
+ FusionStrategy strategy(FusionStrategy::ProducerConsumer);
+ for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
+ FusionResult result = mlir::canFuseLoops(
+ srcAffineForOp, dstAffineForOp,
+ /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
+
+ if (result.value == FusionResult::Success)
+ maxLegalFusionDepth = i;
+ }
- assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
- ComputationSliceState &bestSlice =
- depthSliceUnions[bestDstLoopDepth - 1];
- assert(!bestSlice.isEmpty() && "Missing slice union for depth");
-
- // Determine if 'srcId' can be removed after fusion, taking into
- // account remaining dependences, escaping memrefs and the fusion
- // insertion point.
- bool removeSrcNode = canRemoveSrcNodeAfterFusion(
- srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
- mdg);
-
- DenseSet<Value> privateMemrefs;
- for (Value memref : producerConsumerMemrefs) {
- // If `memref` is an escaping one, do not create a private memref
- // for the below scenarios, since doing so will leave the escaping
- // memref unmodified as all the writes originally meant for the
- // escaping memref would be performed on the private memref:
- // 1. The source is to be removed after fusion,
- // OR
- // 2. The destination writes to `memref`.
- if (srcEscapingMemRefs.count(memref) > 0 &&
- (removeSrcNode || dstNode->getStoreOpCount(memref) > 0))
- continue;
-
- // Don't create a private memref if 'srcNode' has in edges on
- // 'memref' or 'dstNode' has out edges on 'memref'.
- if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
- mdg->getOutEdgeCount(dstId, memref) > 0)
- continue;
-
- // If 'srcNode' will be removed but it has out edges on 'memref' to
- // nodes other than 'dstNode', we have to preserve dependences and
- // cannot create a private memref.
- if (removeSrcNode &&
- any_of(mdg->outEdges[srcId], [&](const auto &edge) {
- return edge.value == memref && edge.id != dstId;
- }))
- continue;
+ if (maxLegalFusionDepth == 0) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Can't fuse: fusion is not legal at any depth\n");
+ continue;
+ }
+ // Check if fusion would be profitable. We skip profitability analysis
+ // for maximal fusion since we already know the maximal legal depth to
+ // fuse.
+ unsigned bestDstLoopDepth = maxLegalFusionDepth;
+ if (!maximalFusion) {
+ // Retrieve producer stores from the src loop.
+ SmallVector<Operation *, 2> producerStores;
+ for (Operation *op : srcNode->stores)
+ if (producerConsumerMemrefs.count(
+ cast<AffineWriteOpInterface>(op).getMemRef()))
+ producerStores.push_back(op);
+
+ // TODO: Suppport multiple producer stores in profitability
+ // analysis. We limit profitability analysis to only scenarios with
+ // a single producer store for now. Note that some multi-store
+ // producer scenarios will still go through profitability analysis
+ // if only one of the stores is involved the producer-consumer
+ // relationship of the candidate loops.
+ assert(!producerStores.empty() && "Expected producer store");
+ if (producerStores.size() > 1)
+ LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
+ "supported for this case\n");
+ else if (!isFusionProfitable(producerStores[0], producerStores[0],
+ dstAffineForOp, depthSliceUnions,
+ maxLegalFusionDepth, &bestDstLoopDepth,
+ computeToleranceThreshold))
+ continue;
+ }
+
+ assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
+ ComputationSliceState &bestSlice =
+ depthSliceUnions[bestDstLoopDepth - 1];
+ assert(!bestSlice.isEmpty() && "Missing slice union for depth");
+
+ // Determine if 'srcId' can be removed after fusion, taking into
+ // account remaining dependences, escaping memrefs and the fusion
+ // insertion point.
+ bool removeSrcNode = canRemoveSrcNodeAfterFusion(
+ srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
+ mdg);
+
+ DenseSet<Value> privateMemrefs;
+ for (Value memref : producerConsumerMemrefs) {
+ if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
+ removeSrcNode)) {
+ // Create a private version of this memref.
+ LLVM_DEBUG(llvm::dbgs()
+ << "Creating private memref for " << memref << '\n');
// Create a private version of this memref.
privateMemrefs.insert(memref);
}
+ }
- // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
- fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
- dstNodeChanged = true;
+ // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
+ fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
+ dstNodeChanged = true;
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "Fused src loop " << srcId << " into dst loop " << dstId
+ << " at depth " << bestDstLoopDepth << ":\n"
+ << dstAffineForOp << "\n");
+
+ // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+ if (fusedLoopInsPoint != dstAffineForOp)
+ dstAffineForOp->moveBefore(fusedLoopInsPoint);
+
+ // Update edges between 'srcNode' and 'dstNode'.
+ mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
+ removeSrcNode);
+
+ // Create private memrefs.
+ if (!privateMemrefs.empty()) {
+ // Gather stores for all the private-to-be memrefs.
+ DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
+ dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
+ Value storeMemRef = storeOp.getMemRef();
+ if (privateMemrefs.count(storeMemRef) > 0)
+ privateMemRefToStores[storeMemRef].push_back(storeOp);
+ });
- LLVM_DEBUG(llvm::dbgs()
- << "Fused src loop " << srcId << " into dst loop " << dstId
- << " at depth " << bestDstLoopDepth << ":\n"
- << dstAffineForOp << "\n");
-
- // Move 'dstAffineForOp' before 'insertPointInst' if needed.
- if (fusedLoopInsPoint != dstAffineForOp)
- dstAffineForOp->moveBefore(fusedLoopInsPoint);
-
- // Update edges between 'srcNode' and 'dstNode'.
- mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
- removeSrcNode);
-
- // Create private memrefs.
- if (!privateMemrefs.empty()) {
- // Gather stores for all the private-to-be memrefs.
- DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
- dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
- Value storeMemRef = storeOp.getMemRef();
- if (privateMemrefs.count(storeMemRef) > 0)
- privateMemRefToStores[storeMemRef].push_back(storeOp);
- });
-
- // Replace original memrefs with private memrefs. Note that all the
- // loads and stores on these memrefs will be replaced with a new
- // loads and stores. Any reference to the original ones becomes
- // invalid after this point.
- for (auto &memrefToStoresPair : privateMemRefToStores) {
- // TODO: Use union of memref write regions to compute
- // private memref footprint.
- SmallVector<Operation *, 4> &storesForMemref =
- memrefToStoresPair.second;
- Value newMemRef = createPrivateMemRef(
- dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
- fastMemorySpace, localBufSizeThreshold);
- // Create new node in dependence graph for 'newMemRef' alloc op.
- unsigned newMemRefNodeId =
- mdg->addNode(newMemRef.getDefiningOp());
- // Add edge from 'newMemRef' node to dstNode.
- mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
- }
- // One or more entries for 'newMemRef' alloc op are inserted into
- // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
- // reallocate, update dstNode.
- dstNode = mdg->getNode(dstId);
+ // Replace original memrefs with private memrefs. Note that all the
+ // loads and stores on these memrefs will be replaced with a new
+ // loads and stores. Any reference to the original ones becomes
+ // invalid after this point.
+ for (auto &memrefToStoresPair : privateMemRefToStores) {
+ // TODO: Use union of memref write regions to compute
+ // private memref footprint.
+ SmallVector<Operation *, 4> &storesForMemref =
+ memrefToStoresPair.second;
+ Value newMemRef = createPrivateMemRef(
+ dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+ fastMemorySpace, localBufSizeThreshold);
+ // Create new node in dependence graph for 'newMemRef' alloc op.
+ unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
+ // Add edge from 'newMemRef' node to dstNode.
+ mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
}
+ // One or more entries for 'newMemRef' alloc op are inserted into
+ // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
+ // reallocate, update dstNode.
+ dstNode = mdg->getNode(dstId);
+ }
- // Collect dst loop stats after memref privatization transformation.
- LoopNestStateCollector dstLoopCollector;
- dstLoopCollector.collect(dstAffineForOp);
+ // Collect dst loop stats after memref privatization transformation.
+ LoopNestStateCollector dstLoopCollector;
+ dstLoopCollector.collect(dstAffineForOp);
- // Clear and add back loads and stores.
- mdg->clearNodeLoadAndStores(dstNode->id);
- mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
- dstLoopCollector.storeOpInsts);
+ // Clear and add back loads and stores.
+ mdg->clearNodeLoadAndStores(dstNode->id);
+ mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
+ dstLoopCollector.storeOpInsts);
- if (removeSrcNode) {
- LLVM_DEBUG(llvm::dbgs()
- << "Removing src loop " << srcId << " after fusion\n");
- // srcNode is no longer valid after it is removed from mdg.
- srcAffineForOp.erase();
- mdg->removeNode(srcId);
- srcNode = nullptr;
- }
+ if (removeSrcNode) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Removing src loop " << srcId << " after fusion\n");
+ // srcNode is no longer valid after it is removed from mdg.
+ srcAffineForOp.erase();
+ mdg->removeNode(srcId);
+ srcNode = nullptr;
}
- } while (dstNodeChanged);
+ }
+ } while (dstNodeChanged);
+ }
+
+ /// Visit each node in the graph, and for each node, attempt to fuse it with
+ /// producer-consumer candidates. No fusion is performed when producers with a
+ /// user count greater than `maxSrcUserCount` for any of the memrefs involved
+ /// are encountered.
+ void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
+ LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
+ init();
+ while (!worklist.empty()) {
+ unsigned dstId = worklist.back();
+ worklist.pop_back();
+ performFusionsIntoDest(dstId, maxSrcUserCount);
}
}
More information about the Mlir-commits
mailing list