[Mlir-commits] [mlir] 6b2e29c - NFC. Refactor affine fusion code for readability

Uday Bondhugula llvmlistbot at llvm.org
Thu Jan 19 23:50:37 PST 2023


Author: Uday Bondhugula
Date: 2023-01-20T13:17:10+05:30
New Revision: 6b2e29c5080e3e9870d8a2151a0feaf64eb480d0

URL: https://github.com/llvm/llvm-project/commit/6b2e29c5080e3e9870d8a2151a0feaf64eb480d0
DIFF: https://github.com/llvm/llvm-project/commit/6b2e29c5080e3e9870d8a2151a0feaf64eb480d0.diff

LOG: NFC. Refactor affine fusion code for readability

Replace a couple of check instances with llvm::any_of (clang-tidy
warnings).  Factor out "canCreatePrivateMemRef" and
"performFusionsIntoDest" into separate methods to reduce the
length/indent of the containing methods. Add doc comments and debug messages.

Mark some of the methods that should have been const as const.

NFC.

Reviewed By: vinayaka-polymage

Differential Revision: https://reviews.llvm.org/D142076

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 656158798cbc7..db39a835144ac 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -125,20 +125,20 @@ struct MemRefDependenceGraph {
     Node(unsigned id, Operation *op) : id(id), op(op) {}
 
     // Returns the load op count for 'memref'.
-    unsigned getLoadOpCount(Value memref) {
+    unsigned getLoadOpCount(Value memref) const {
       unsigned loadOpCount = 0;
-      for (auto *loadOpInst : loads) {
-        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
+      for (Operation *loadOp : loads) {
+        if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
           ++loadOpCount;
       }
       return loadOpCount;
     }
 
     // Returns the store op count for 'memref'.
-    unsigned getStoreOpCount(Value memref) {
+    unsigned getStoreOpCount(Value memref) const {
       unsigned storeOpCount = 0;
-      for (auto *storeOpInst : stores) {
-        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
+      for (Operation *storeOp : stores) {
+        if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
           ++storeOpCount;
       }
       return storeOpCount;
@@ -146,31 +146,32 @@ struct MemRefDependenceGraph {
 
     // Returns all store ops in 'storeOps' which access 'memref'.
     void getStoreOpsForMemref(Value memref,
-                              SmallVectorImpl<Operation *> *storeOps) {
-      for (auto *storeOpInst : stores) {
-        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
-          storeOps->push_back(storeOpInst);
+                              SmallVectorImpl<Operation *> *storeOps) const {
+      for (Operation *storeOp : stores) {
+        if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
+          storeOps->push_back(storeOp);
       }
     }
 
     // Returns all load ops in 'loadOps' which access 'memref'.
     void getLoadOpsForMemref(Value memref,
-                             SmallVectorImpl<Operation *> *loadOps) {
-      for (auto *loadOpInst : loads) {
-        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
-          loadOps->push_back(loadOpInst);
+                             SmallVectorImpl<Operation *> *loadOps) const {
+      for (Operation *loadOp : loads) {
+        if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
+          loadOps->push_back(loadOp);
       }
     }
 
     // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
     // has at least one load and store operation.
-    void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
+    void
+    getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const {
       llvm::SmallDenseSet<Value, 2> loadMemrefs;
-      for (auto *loadOpInst : loads) {
-        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
+      for (Operation *loadOp : loads) {
+        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
       }
-      for (auto *storeOpInst : stores) {
-        auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+      for (Operation *storeOp : stores) {
+        auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
         if (loadMemrefs.count(memref) > 0)
           loadAndStoreMemrefSet->insert(memref);
       }
@@ -744,14 +745,12 @@ static bool isEscapingMemref(Value memref, Block *block) {
 
   // Check if 'memref' is used by a non-deferencing op (including unknown ones)
   // (e.g., call ops, alias creating ops, etc.).
-  for (Operation *user : memref.getUsers()) {
+  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
     // Ignore users outside of `block`.
     if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
-      continue;
-    if (!isa<AffineMapAccessInterface>(*user))
-      return true;
-  }
-  return false;
+      return false;
+    return !isa<AffineMapAccessInterface>(*user);
+  });
 }
 
 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
@@ -1076,10 +1075,9 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
     return WalkResult::advance();
   });
   // Looking for users between node 'srcId' and node 'dstId'.
-  for (Value memref : memRefValues)
-    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
-      return true;
-  return false;
+  return llvm::any_of(memRefValues, [&](Value memref) {
+    return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
+  });
 }
 
 // Checks the profitability of fusing a backwards slice of the loop nest
@@ -1452,280 +1450,301 @@ struct GreedyFusion {
     eraseUnusedMemRefAllocations();
   }
 
-  /// Visit each node in the graph, and for each node, attempt to fuse it with
-  /// producer-consumer candidates. No fusion is performed when producers with a
-  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
-  /// are encountered.
-  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
-    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
-    init();
-    while (!worklist.empty()) {
-      unsigned dstId = worklist.back();
-      worklist.pop_back();
+  /// Returns true if a private memref can be created for `memref` given
+  /// the fusion scenario reflected by the other arguments.
+  bool canCreatePrivateMemRef(Value memref,
+                              const DenseSet<Value> &srcEscapingMemRefs,
+                              unsigned producerId, unsigned consumerId,
+                              bool removeSrcNode) {
+    const Node *consumerNode = mdg->getNode(consumerId);
+    // If `memref` is an escaping one, do not create a private memref
+    // for the below scenarios, since doing so will leave the escaping
+    // memref unmodified as all the writes originally meant for the
+    // escaping memref would be performed on the private memref:
+    // 1. The source is to be removed after fusion,
+    // OR
+    // 2. The destination writes to `memref`.
+    if (srcEscapingMemRefs.count(memref) > 0 &&
+        (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
+      return false;
 
-      // Skip if this node was removed (fused into another node).
-      if (mdg->nodes.count(dstId) == 0)
-        continue;
-      // Get 'dstNode' into which to attempt fusion.
-      auto *dstNode = mdg->getNode(dstId);
-      // Skip if 'dstNode' is not a loop nest.
-      if (!isa<AffineForOp>(dstNode->op))
-        continue;
-      // Skip if 'dstNode' is a loop nest returning values.
-      // TODO: support loop nests that return values.
-      if (dstNode->op->getNumResults() > 0)
-        continue;
+    // Don't create a private memref if 'srcNode' has in edges on
+    // 'memref' or 'dstNode' has out edges on 'memref'.
+    if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
+        mdg->getOutEdgeCount(consumerId, memref) > 0)
+      return false;
 
-      LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
-
-      // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
-      // while preserving relative order. This can increase the maximum loop
-      // depth at which we can fuse a slice of a producer loop nest into a
-      // consumer loop nest.
-      sinkSequentialLoops(dstNode);
-      auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
-
-      // Try to fuse 'dstNode' with candidate producer loops until a fixed point
-      // is reached. Fusing two loops may expose new fusion opportunities.
-      bool dstNodeChanged;
-      do {
-        // Gather src loop candidates for 'dstNode' and visit them in "quasi"
-        // reverse program order to minimize the number of iterations needed to
-        // reach the fixed point. Note that this is a best effort approach since
-        // 'getProducerCandidates' does not always guarantee that program order
-        // in 'srcIdCandidates'.
-        dstNodeChanged = false;
-        SmallVector<unsigned, 16> srcIdCandidates;
-        getProducerCandidates(dstId, mdg, srcIdCandidates);
-
-        for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
-          // Get 'srcNode' from which to attempt fusion into 'dstNode'.
-          auto *srcNode = mdg->getNode(srcId);
-          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
-          LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
-                                  << " for dst loop " << dstId << "\n");
-
-          // Skip if 'srcNode' is a loop nest returning values.
-          // TODO: support loop nests that return values.
-          if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
-            continue;
+    // If 'srcNode' will be removed but it has out edges on 'memref' to
+    // nodes other than 'dstNode', we have to preserve dependences and
+    // cannot create a private memref.
+    if (removeSrcNode &&
+        any_of(mdg->outEdges[producerId], [&](const auto &edge) {
+          return edge.value == memref && edge.id != consumerId;
+        }))
+      return false;
 
-          DenseSet<Value> producerConsumerMemrefs;
-          gatherProducerConsumerMemrefs(srcId, dstId, mdg,
-                                        producerConsumerMemrefs);
+    return true;
+  }
 
-          // Skip if 'srcNode' out edge count on any memref is greater than
-          // 'maxSrcUserCount'.
-          if (any_of(producerConsumerMemrefs, [&](Value memref) {
-                return mdg->getOutEdgeCount(srcNode->id, memref) >
-                       maxSrcUserCount;
-              }))
-            continue;
+  /// Perform fusions with node `dstId` as the destination of fusion. No fusion
+  /// is performed when producers with a user count greater than
+  /// `maxSrcUserCount` for any of the memrefs involved are encountered.
+  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
+    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+    // Skip if this node was removed (fused into another node).
+    if (mdg->nodes.count(dstId) == 0)
+      return;
+    // Get 'dstNode' into which to attempt fusion.
+    auto *dstNode = mdg->getNode(dstId);
+    // Skip if 'dstNode' is not a loop nest.
+    if (!isa<AffineForOp>(dstNode->op))
+      return;
+    // Skip if 'dstNode' is a loop nest returning values.
+    // TODO: support loop nests that return values.
+    if (dstNode->op->getNumResults() > 0)
+      return;
+
+    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+
+    // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
+    // while preserving relative order. This can increase the maximum loop
+    // depth at which we can fuse a slice of a producer loop nest into a
+    // consumer loop nest.
+    sinkSequentialLoops(dstNode);
+    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
 
-          // Gather memrefs in 'srcNode' that are written and escape out of the
-          // block (e.g., memref block arguments, returned memrefs,
-          // memrefs passed to function calls, etc.).
-          DenseSet<Value> srcEscapingMemRefs;
-          gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
-
-          // Skip if there are non-affine operations in between the 'srcNode'
-          // and 'dstNode' using their memrefs. If so, we wouldn't be able to
-          // compute a legal insertion point for now. 'srcNode' and 'dstNode'
-          // memrefs with non-affine operation users would be considered
-          // escaping memrefs so we can limit this check to only scenarios with
-          // escaping memrefs.
-          if (!srcEscapingMemRefs.empty() &&
-              hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
-            LLVM_DEBUG(
-                llvm::dbgs()
-                << "Can't fuse: non-affine users in between the loops\n.");
-            continue;
-          }
+    // Try to fuse 'dstNode' with candidate producer loops until a fixed point
+    // is reached. Fusing two loops may expose new fusion opportunities.
+    bool dstNodeChanged;
+    do {
+      // Gather src loop candidates for 'dstNode' and visit them in "quasi"
+      // reverse program order to minimize the number of iterations needed to
+      // reach the fixed point. Note that this is a best effort approach since
+      // 'getProducerCandidates' does not always guarantee that program order
+      // in 'srcIdCandidates'.
+      dstNodeChanged = false;
+      SmallVector<unsigned, 16> srcIdCandidates;
+      getProducerCandidates(dstId, mdg, srcIdCandidates);
+
+      for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
+        // Get 'srcNode' from which to attempt fusion into 'dstNode'.
+        auto *srcNode = mdg->getNode(srcId);
+        auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
+        LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
+                                << " for dst loop " << dstId << "\n");
+
+        // Skip if 'srcNode' is a loop nest returning values.
+        // TODO: support loop nests that return values.
+        if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
+          continue;
 
-          // Compute an operation list insertion point for the fused loop
-          // nest which preserves dependences.
-          Operation *fusedLoopInsPoint =
-              mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
-          if (fusedLoopInsPoint == nullptr)
-            continue;
+        DenseSet<Value> producerConsumerMemrefs;
+        gatherProducerConsumerMemrefs(srcId, dstId, mdg,
+                                      producerConsumerMemrefs);
 
-          // Compute the innermost common loop depth for dstNode
-          // producer-consumer loads/stores.
-          SmallVector<Operation *, 2> dstMemrefOps;
-          for (Operation *op : dstNode->loads)
-            if (producerConsumerMemrefs.count(
-                    cast<AffineReadOpInterface>(op).getMemRef()) > 0)
-              dstMemrefOps.push_back(op);
-          for (Operation *op : dstNode->stores)
-            if (producerConsumerMemrefs.count(
-                    cast<AffineWriteOpInterface>(op).getMemRef()))
-              dstMemrefOps.push_back(op);
-          unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
-
-          // Check the feasibility of fusing src loop nest into dst loop nest
-          // at loop depths in range [1, dstLoopDepthTest].
-          unsigned maxLegalFusionDepth = 0;
-          SmallVector<ComputationSliceState, 8> depthSliceUnions;
-          depthSliceUnions.resize(dstLoopDepthTest);
-          FusionStrategy strategy(FusionStrategy::ProducerConsumer);
-          for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
-            FusionResult result = mlir::canFuseLoops(
-                srcAffineForOp, dstAffineForOp,
-                /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
-
-            if (result.value == FusionResult::Success)
-              maxLegalFusionDepth = i;
-          }
+        // Skip if 'srcNode' out edge count on any memref is greater than
+        // 'maxSrcUserCount'.
+        if (any_of(producerConsumerMemrefs, [&](Value memref) {
+              return mdg->getOutEdgeCount(srcNode->id, memref) >
+                     maxSrcUserCount;
+            }))
+          continue;
 
-          if (maxLegalFusionDepth == 0) {
-            LLVM_DEBUG(llvm::dbgs()
-                       << "Can't fuse: fusion is not legal at any depth\n");
-            continue;
-          }
+        // Gather memrefs in 'srcNode' that are written and escape out of the
+        // block (e.g., memref block arguments, returned memrefs,
+        // memrefs passed to function calls, etc.).
+        DenseSet<Value> srcEscapingMemRefs;
+        gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
+
+        // Skip if there are non-affine operations in between the 'srcNode'
+        // and 'dstNode' using their memrefs. If so, we wouldn't be able to
+        // compute a legal insertion point for now. 'srcNode' and 'dstNode'
+        // memrefs with non-affine operation users would be considered
+        // escaping memrefs so we can limit this check to only scenarios with
+        // escaping memrefs.
+        if (!srcEscapingMemRefs.empty() &&
+            hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Can't fuse: non-affine users in between the loops\n.");
+          continue;
+        }
 
-          // Check if fusion would be profitable. We skip profitability analysis
-          // for maximal fusion since we already know the maximal legal depth to
-          // fuse.
-          unsigned bestDstLoopDepth = maxLegalFusionDepth;
-          if (!maximalFusion) {
-            // Retrieve producer stores from the src loop.
-            SmallVector<Operation *, 2> producerStores;
-            for (Operation *op : srcNode->stores)
-              if (producerConsumerMemrefs.count(
-                      cast<AffineWriteOpInterface>(op).getMemRef()))
-                producerStores.push_back(op);
-
-            // TODO: Suppport multiple producer stores in profitability
-            // analysis. We limit profitability analysis to only scenarios with
-            // a single producer store for now. Note that some multi-store
-            // producer scenarios will still go through profitability analysis
-            // if only one of the stores is involved the producer-consumer
-            // relationship of the candidate loops.
-            assert(!producerStores.empty() && "Expected producer store");
-            if (producerStores.size() > 1)
-              LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
-                                         "supported for this case\n");
-            else if (!isFusionProfitable(producerStores[0], producerStores[0],
-                                         dstAffineForOp, depthSliceUnions,
-                                         maxLegalFusionDepth, &bestDstLoopDepth,
-                                         computeToleranceThreshold))
-              continue;
-          }
+        // Compute an operation list insertion point for the fused loop
+        // nest which preserves dependences.
+        Operation *fusedLoopInsPoint =
+            mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
+        if (fusedLoopInsPoint == nullptr)
+          continue;
+
+        // Compute the innermost common loop depth for dstNode
+        // producer-consumer loads/stores.
+        SmallVector<Operation *, 2> dstMemrefOps;
+        for (Operation *op : dstNode->loads)
+          if (producerConsumerMemrefs.count(
+                  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
+            dstMemrefOps.push_back(op);
+        for (Operation *op : dstNode->stores)
+          if (producerConsumerMemrefs.count(
+                  cast<AffineWriteOpInterface>(op).getMemRef()))
+            dstMemrefOps.push_back(op);
+        unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
+
+        // Check the feasibility of fusing src loop nest into dst loop nest
+        // at loop depths in range [1, dstLoopDepthTest].
+        unsigned maxLegalFusionDepth = 0;
+        SmallVector<ComputationSliceState, 8> depthSliceUnions;
+        depthSliceUnions.resize(dstLoopDepthTest);
+        FusionStrategy strategy(FusionStrategy::ProducerConsumer);
+        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
+          FusionResult result = mlir::canFuseLoops(
+              srcAffineForOp, dstAffineForOp,
+              /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
+
+          if (result.value == FusionResult::Success)
+            maxLegalFusionDepth = i;
+        }
 
-          assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
-          ComputationSliceState &bestSlice =
-              depthSliceUnions[bestDstLoopDepth - 1];
-          assert(!bestSlice.isEmpty() && "Missing slice union for depth");
-
-          // Determine if 'srcId' can be removed after fusion, taking into
-          // account remaining dependences, escaping memrefs and the fusion
-          // insertion point.
-          bool removeSrcNode = canRemoveSrcNodeAfterFusion(
-              srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
-              mdg);
-
-          DenseSet<Value> privateMemrefs;
-          for (Value memref : producerConsumerMemrefs) {
-            // If `memref` is an escaping one, do not create a private memref
-            // for the below scenarios, since doing so will leave the escaping
-            // memref unmodified as all the writes originally meant for the
-            // escaping memref would be performed on the private memref:
-            // 1. The source is to be removed after fusion,
-            // OR
-            // 2. The destination writes to `memref`.
-            if (srcEscapingMemRefs.count(memref) > 0 &&
-                (removeSrcNode || dstNode->getStoreOpCount(memref) > 0))
-              continue;
-
-            // Don't create a private memref if 'srcNode' has in edges on
-            // 'memref' or 'dstNode' has out edges on 'memref'.
-            if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
-                mdg->getOutEdgeCount(dstId, memref) > 0)
-              continue;
-
-            // If 'srcNode' will be removed but it has out edges on 'memref' to
-            // nodes other than 'dstNode', we have to preserve dependences and
-            // cannot create a private memref.
-            if (removeSrcNode &&
-                any_of(mdg->outEdges[srcId], [&](const auto &edge) {
-                  return edge.value == memref && edge.id != dstId;
-                }))
-              continue;
+        if (maxLegalFusionDepth == 0) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Can't fuse: fusion is not legal at any depth\n");
+          continue;
+        }
 
+        // Check if fusion would be profitable. We skip profitability analysis
+        // for maximal fusion since we already know the maximal legal depth to
+        // fuse.
+        unsigned bestDstLoopDepth = maxLegalFusionDepth;
+        if (!maximalFusion) {
+          // Retrieve producer stores from the src loop.
+          SmallVector<Operation *, 2> producerStores;
+          for (Operation *op : srcNode->stores)
+            if (producerConsumerMemrefs.count(
+                    cast<AffineWriteOpInterface>(op).getMemRef()))
+              producerStores.push_back(op);
+
+          // TODO: Support multiple producer stores in profitability
+          // analysis. We limit profitability analysis to only scenarios with
+          // a single producer store for now. Note that some multi-store
+          // producer scenarios will still go through profitability analysis
+          // if only one of the stores is involved in the producer-consumer
+          // relationship of the candidate loops.
+          assert(!producerStores.empty() && "Expected producer store");
+          if (producerStores.size() > 1)
+            LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
+                                       "supported for this case\n");
+          else if (!isFusionProfitable(producerStores[0], producerStores[0],
+                                       dstAffineForOp, depthSliceUnions,
+                                       maxLegalFusionDepth, &bestDstLoopDepth,
+                                       computeToleranceThreshold))
+            continue;
+        }
+
+        assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
+        ComputationSliceState &bestSlice =
+            depthSliceUnions[bestDstLoopDepth - 1];
+        assert(!bestSlice.isEmpty() && "Missing slice union for depth");
+
+        // Determine if 'srcId' can be removed after fusion, taking into
+        // account remaining dependences, escaping memrefs and the fusion
+        // insertion point.
+        bool removeSrcNode = canRemoveSrcNodeAfterFusion(
+            srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
+            mdg);
+
+        DenseSet<Value> privateMemrefs;
+        for (Value memref : producerConsumerMemrefs) {
+          if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
+                                     removeSrcNode)) {
+            // Create a private version of this memref.
+            LLVM_DEBUG(llvm::dbgs()
+                       << "Creating private memref for " << memref << '\n');
             // Create a private version of this memref.
             privateMemrefs.insert(memref);
           }
+        }
 
-          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
-          fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
-          dstNodeChanged = true;
+        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
+        fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
+        dstNodeChanged = true;
+
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Fused src loop " << srcId << " into dst loop " << dstId
+                   << " at depth " << bestDstLoopDepth << ":\n"
+                   << dstAffineForOp << "\n");
+
+        // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+        if (fusedLoopInsPoint != dstAffineForOp)
+          dstAffineForOp->moveBefore(fusedLoopInsPoint);
+
+        // Update edges between 'srcNode' and 'dstNode'.
+        mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
+                         removeSrcNode);
+
+        // Create private memrefs.
+        if (!privateMemrefs.empty()) {
+          // Gather stores for all the private-to-be memrefs.
+          DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
+          dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
+            Value storeMemRef = storeOp.getMemRef();
+            if (privateMemrefs.count(storeMemRef) > 0)
+              privateMemRefToStores[storeMemRef].push_back(storeOp);
+          });
 
-          LLVM_DEBUG(llvm::dbgs()
-                     << "Fused src loop " << srcId << " into dst loop " << dstId
-                     << " at depth " << bestDstLoopDepth << ":\n"
-                     << dstAffineForOp << "\n");
-
-          // Move 'dstAffineForOp' before 'insertPointInst' if needed.
-          if (fusedLoopInsPoint != dstAffineForOp)
-            dstAffineForOp->moveBefore(fusedLoopInsPoint);
-
-          // Update edges between 'srcNode' and 'dstNode'.
-          mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
-                           removeSrcNode);
-
-          // Create private memrefs.
-          if (!privateMemrefs.empty()) {
-            // Gather stores for all the private-to-be memrefs.
-            DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
-            dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
-              Value storeMemRef = storeOp.getMemRef();
-              if (privateMemrefs.count(storeMemRef) > 0)
-                privateMemRefToStores[storeMemRef].push_back(storeOp);
-            });
-
-            // Replace original memrefs with private memrefs. Note that all the
-            // loads and stores on these memrefs will be replaced with a new
-            // loads and stores. Any reference to the original ones becomes
-            // invalid after this point.
-            for (auto &memrefToStoresPair : privateMemRefToStores) {
-              // TODO: Use union of memref write regions to compute
-              // private memref footprint.
-              SmallVector<Operation *, 4> &storesForMemref =
-                  memrefToStoresPair.second;
-              Value newMemRef = createPrivateMemRef(
-                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
-                  fastMemorySpace, localBufSizeThreshold);
-              // Create new node in dependence graph for 'newMemRef' alloc op.
-              unsigned newMemRefNodeId =
-                  mdg->addNode(newMemRef.getDefiningOp());
-              // Add edge from 'newMemRef' node to dstNode.
-              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
-            }
-            // One or more entries for 'newMemRef' alloc op are inserted into
-            // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
-            // reallocate, update dstNode.
-            dstNode = mdg->getNode(dstId);
+          // Replace original memrefs with private memrefs. Note that all the
+          // loads and stores on these memrefs will be replaced with a new
+          // loads and stores. Any reference to the original ones becomes
+          // invalid after this point.
+          for (auto &memrefToStoresPair : privateMemRefToStores) {
+            // TODO: Use union of memref write regions to compute
+            // private memref footprint.
+            SmallVector<Operation *, 4> &storesForMemref =
+                memrefToStoresPair.second;
+            Value newMemRef = createPrivateMemRef(
+                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+                fastMemorySpace, localBufSizeThreshold);
+            // Create new node in dependence graph for 'newMemRef' alloc op.
+            unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
+            // Add edge from 'newMemRef' node to dstNode.
+            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
           }
+          // One or more entries for 'newMemRef' alloc op are inserted into
+          // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
+          // reallocate, update dstNode.
+          dstNode = mdg->getNode(dstId);
+        }
 
-          // Collect dst loop stats after memref privatization transformation.
-          LoopNestStateCollector dstLoopCollector;
-          dstLoopCollector.collect(dstAffineForOp);
+        // Collect dst loop stats after memref privatization transformation.
+        LoopNestStateCollector dstLoopCollector;
+        dstLoopCollector.collect(dstAffineForOp);
 
-          // Clear and add back loads and stores.
-          mdg->clearNodeLoadAndStores(dstNode->id);
-          mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
-                         dstLoopCollector.storeOpInsts);
+        // Clear and add back loads and stores.
+        mdg->clearNodeLoadAndStores(dstNode->id);
+        mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
+                       dstLoopCollector.storeOpInsts);
 
-          if (removeSrcNode) {
-            LLVM_DEBUG(llvm::dbgs()
-                       << "Removing src loop " << srcId << " after fusion\n");
-            // srcNode is no longer valid after it is removed from mdg.
-            srcAffineForOp.erase();
-            mdg->removeNode(srcId);
-            srcNode = nullptr;
-          }
+        if (removeSrcNode) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Removing src loop " << srcId << " after fusion\n");
+          // srcNode is no longer valid after it is removed from mdg.
+          srcAffineForOp.erase();
+          mdg->removeNode(srcId);
+          srcNode = nullptr;
         }
-      } while (dstNodeChanged);
+      }
+    } while (dstNodeChanged);
+  }
+
+  /// Visit each node in the graph, and for each node, attempt to fuse it with
+  /// producer-consumer candidates. No fusion is performed when producers with a
+  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
+  /// are encountered.
+  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
+    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
+    init();
+    while (!worklist.empty()) {
+      unsigned dstId = worklist.back();
+      worklist.pop_back();
+      performFusionsIntoDest(dstId, maxSrcUserCount);
     }
   }
 


        


More information about the Mlir-commits mailing list