[Mlir-commits] [mlir] c910570 - [MLIR] NFC. Expose/move MDG out of Affine fusion into Affine Analysis

Uday Bondhugula llvmlistbot at llvm.org
Wed Mar 29 11:41:29 PDT 2023


Author: Uday Bondhugula
Date: 2023-03-30T00:11:13+05:30
New Revision: c910570fd22139299f219d3a5087e55968d8840d

URL: https://github.com/llvm/llvm-project/commit/c910570fd22139299f219d3a5087e55968d8840d
DIFF: https://github.com/llvm/llvm-project/commit/c910570fd22139299f219d3a5087e55968d8840d.diff

LOG: [MLIR] NFC. Expose/move MDG out of Affine fusion into Affine Analysis

Move the MemRefDependenceGraph analysis structure out of LoopFusion into
the Affine Analysis library. This had been a long pending TODO. Moving
MDG out allows its use in other affine passes and also allows building
custom affine fusion passes downstream while reusing upstream fusion
utilities. The file LoopFusion.cpp had also become lengthy and this
change makes things more modular. This change is NFC: it is purely a
code movement.

NFC.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D147105

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
    mlir/lib/Dialect/Affine/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index 99e511f152618..c16fce39484a6 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -17,11 +17,7 @@
 #define MLIR_DIALECT_AFFINE_ANALYSIS_UTILS_H
 
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Block.h"
-#include "mlir/IR/Location.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include <memory>
 #include <optional>
 
@@ -35,6 +31,186 @@ struct MemRefAccess;
 class Operation;
 class Value;
 
+// LoopNestStateCollector walks loop nests and collects affine.for ops, affine
+// load and store operations, and whether a region-holding op other than
+// AffineForOp and AffineIfOp was encountered in the loop nest.
+struct LoopNestStateCollector {
+  SmallVector<AffineForOp, 4> forOps;
+  SmallVector<Operation *, 4> loadOpInsts;
+  SmallVector<Operation *, 4> storeOpInsts;
+  bool hasNonAffineRegionOp = false;
+
+  // Walks 'opToWalk' and appends what it finds to the members above; results
+  // accumulate across calls since nothing is cleared.
+  void collect(Operation *opToWalk);
+};
+
+// MemRefDependenceGraph is a graph data structure where graph nodes are
+// top-level operations in a `Block` which contain load/store ops, and edges
+// are memref dependences between the nodes.
+// TODO: Add a more flexible dependence graph representation.
+// TODO: Add a depth parameter to dependence graph construction.
+struct MemRefDependenceGraph {
+public:
+  // Node represents a node in the graph. A Node is either an entire loop nest
+  // rooted at the top level which contains loads/stores, or a top level
+  // load/store.
+  struct Node {
+    // The unique identifier of this node in the graph.
+    unsigned id;
+    // The top-level statement which is (or contains) a load/store.
+    Operation *op;
+    // List of load operations.
+    SmallVector<Operation *, 4> loads;
+    // List of store operations.
+    SmallVector<Operation *, 4> stores;
+
+    Node(unsigned id, Operation *op) : id(id), op(op) {}
+
+    // Returns the number of load ops in this node that access 'memref'.
+    unsigned getLoadOpCount(Value memref) const;
+
+    // Returns the number of store ops in this node that access 'memref'.
+    unsigned getStoreOpCount(Value memref) const;
+
+    // Appends to 'storeOps' all store ops of this node which access 'memref'.
+    void getStoreOpsForMemref(Value memref,
+                              SmallVectorImpl<Operation *> *storeOps) const;
+
+    // Appends to 'loadOps' all load ops of this node which access 'memref'.
+    void getLoadOpsForMemref(Value memref,
+                             SmallVectorImpl<Operation *> *loadOps) const;
+
+    // Inserts into 'loadAndStoreMemrefSet' every memref that this node both
+    // loads from and stores to.
+    void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const;
+  };
+
+  // Edge represents a data dependence between nodes in the graph.
+  struct Edge {
+    // The id of the node at the other end of the edge.
+    // If this edge is stored in Edge = Node.inEdges[i], then
+    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
+    // If this edge is stored in Edge = Node.outEdges[i], then
+    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
+    unsigned id;
+    // The SSA value on which this edge represents a dependence.
+    // If the value is a memref, then the dependence is between graph nodes
+    // which contain accesses to the same memref 'value'. If the value is a
+    // non-memref value, then the dependence is between a graph node which
+    // defines an SSA value and another graph node which uses the SSA value
+    // (e.g. a constant or load operation defining a value which is used inside
+    // a loop nest).
+    Value value;
+  };
+
+  // Map from node id to Node.
+  DenseMap<unsigned, Node> nodes;
+  // Map from node id to list of input edges.
+  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
+  // Map from node id to list of output edges.
+  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
+  // Map from memref to a count on the dependence edges associated with that
+  // memref.
+  DenseMap<Value, unsigned> memrefEdgeCount;
+  // The next unique identifier to use for newly created graph nodes.
+  unsigned nextNodeId = 0;
+
+  MemRefDependenceGraph(Block &block) : block(block) {}
+
+  // Initializes the dependence graph based on operations in `block`.
+  // Returns true on success, false otherwise.
+  bool init();
+
+  // Returns the graph node for 'id', which must exist.
+  Node *getNode(unsigned id);
+
+  // Returns the graph node for 'forOp', or nullptr if there is none.
+  Node *getForOpNode(AffineForOp forOp);
+
+  // Adds a node with 'op' to the graph and returns its unique identifier.
+  unsigned addNode(Operation *op);
+
+  // Removes node 'id' (and its associated edges) from the graph.
+  void removeNode(unsigned id);
+
+  // Returns true if node 'id' writes to any memref which escapes (or is an
+  // argument to) the block. Returns false otherwise.
+  bool writesToLiveInOrEscapingMemrefs(unsigned id);
+
+  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
+  // is for 'value' if non-null, or for any value otherwise. Returns false
+  // otherwise.
+  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+
+  // Adds an edge from node 'srcId' to node 'dstId' for 'value' if not present.
+  void addEdge(unsigned srcId, unsigned dstId, Value value);
+
+  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
+  void removeEdge(unsigned srcId, unsigned dstId, Value value);
+
+  // Returns true if there is a path in the dependence graph from node 'srcId'
+  // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
+  // operations that the edges connect are expected to be from the same block.
+  bool hasDependencePath(unsigned srcId, unsigned dstId);
+
+  // Returns the input edge count for node 'id' and 'memref' from src nodes
+  // which access 'memref' with a store operation.
+  unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+
+  // Returns the output edge count for node 'id' and 'memref' (if non-null),
+  // otherwise returns the total output edge count from node 'id'.
+  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+
+  /// Inserts into 'definingNodes' all nodes defining SSA values used in 'id'.
+  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+
+  // Computes and returns an insertion point operation, before which the
+  // fused <srcId, dstId> loop nest can be inserted while preserving
+  // dependences. Returns nullptr if no such insertion point is found.
+  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+
+  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
+  // taking into account that:
+  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
+  //   *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by
+  //      a private memref.
+  void updateEdges(unsigned srcId, unsigned dstId,
+                   const DenseSet<Value> &privateMemRefs, bool removeSrcId);
+
+  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
+  // of sibling node 'sibId' into node 'dstId'.
+  void updateEdges(unsigned sibId, unsigned dstId);
+
+  // Adds ops in 'loads' and 'stores' to node at 'id'.
+  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
+                 const SmallVectorImpl<Operation *> &stores);
+
+  void clearNodeLoadAndStores(unsigned id);
+
+  // Calls 'callback' for each input edge of node 'id' which carries a memref
+  // dependence from an affine.for node (see forEachMemRefEdge).
+  void forEachMemRefInputEdge(unsigned id,
+                              const std::function<void(Edge)> &callback);
+
+  // Calls 'callback' for each output edge of node 'id' which carries a memref
+  // dependence to an affine.for node (see forEachMemRefEdge).
+  void forEachMemRefOutputEdge(unsigned id,
+                               const std::function<void(Edge)> &callback);
+
+  // Calls 'callback' for each edge in 'edges' which carries a memref
+  // dependence and whose other endpoint is an affine.for node.
+  void forEachMemRefEdge(ArrayRef<Edge> edges,
+                         const std::function<void(Edge)> &callback);
+
+  void print(raw_ostream &os) const;
+
+  void dump() const { print(llvm::errs()); }
+
+  /// The block for which this graph is created to perform fusion.
+  Block █
+};
+
 /// Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered
 /// from the outermost 'affine.for' operation to the innermost one.
 void getAffineForIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);

diff  --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 41a739d726ed5..247b3786031ce 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -32,6 +32,475 @@ using namespace presburger;
 
 using llvm::SmallDenseMap;
 
+using Node = MemRefDependenceGraph::Node;
+
+// Walks 'opToWalk' and appends every affine.for op, affine read and affine
+// write op found to the corresponding lists; sets 'hasNonAffineRegionOp' when
+// an op with regions other than AffineForOp/AffineIfOp is encountered.
+void LoopNestStateCollector::collect(Operation *opToWalk) {
+  opToWalk->walk([&](Operation *op) {
+    if (isa<AffineForOp>(op))
+      forOps.push_back(cast<AffineForOp>(op));
+    else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
+      hasNonAffineRegionOp = true;
+    else if (isa<AffineReadOpInterface>(op))
+      loadOpInsts.push_back(op);
+    else if (isa<AffineWriteOpInterface>(op))
+      storeOpInsts.push_back(op);
+  });
+}
+
+// Returns the number of load ops in this node that access 'memref'.
+unsigned Node::getLoadOpCount(Value memref) const {
+  unsigned loadOpCount = 0;
+  for (Operation *loadOp : loads) {
+    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
+      ++loadOpCount;
+  }
+  return loadOpCount;
+}
+
+// Returns the number of store ops in this node that access 'memref'.
+unsigned Node::getStoreOpCount(Value memref) const {
+  unsigned storeOpCount = 0;
+  for (Operation *storeOp : stores) {
+    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
+      ++storeOpCount;
+  }
+  return storeOpCount;
+}
+
+// Appends to 'storeOps' all store ops of this node which access 'memref'.
+void Node::getStoreOpsForMemref(Value memref,
+                                SmallVectorImpl<Operation *> *storeOps) const {
+  for (Operation *storeOp : stores) {
+    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
+      storeOps->push_back(storeOp);
+  }
+}
+
+// Appends to 'loadOps' all load ops of this node which access 'memref'.
+void Node::getLoadOpsForMemref(Value memref,
+                               SmallVectorImpl<Operation *> *loadOps) const {
+  for (Operation *loadOp : loads) {
+    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
+      loadOps->push_back(loadOp);
+  }
+}
+
+// Inserts into 'loadAndStoreMemrefSet' every memref that this node both loads
+// from and stores to.
+void Node::getLoadAndStoreMemrefSet(
+    DenseSet<Value> *loadAndStoreMemrefSet) const {
+  llvm::SmallDenseSet<Value, 2> loadMemrefs;
+  for (Operation *loadOp : loads) {
+    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
+  }
+  for (Operation *storeOp : stores) {
+    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
+    if (loadMemrefs.count(memref) > 0)
+      loadAndStoreMemrefSet->insert(memref);
+  }
+}
+
+// Returns the graph node for 'id'; asserts that such a node exists.
+Node *MemRefDependenceGraph::getNode(unsigned id) {
+  auto it = nodes.find(id);
+  assert(it != nodes.end());
+  return &it->second;
+}
+
+// Returns the graph node whose op is 'forOp', or nullptr if there is none.
+Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+  for (auto &idAndNode : nodes)
+    if (idAndNode.second.op == forOp)
+      return &idAndNode.second;
+  return nullptr;
+}
+
+// Adds a node for 'op' with a fresh id (from 'nextNodeId') and returns the id.
+unsigned MemRefDependenceGraph::addNode(Operation *op) {
+  Node node(nextNodeId++, op);
+  nodes.insert({node.id, node});
+  return node.id;
+}
+
+// Removes node 'id' together with all of its incoming and outgoing edges.
+void MemRefDependenceGraph::removeNode(unsigned id) {
+  // Remove each in edge; iterate over a copy since removeEdge() mutates it.
+  if (inEdges.count(id) > 0) {
+    SmallVector<Edge, 2> oldInEdges = inEdges[id];
+    for (auto &inEdge : oldInEdges) {
+      removeEdge(inEdge.id, id, inEdge.value);
+    }
+  }
+  // Remove each out edge, again iterating over a copy.
+  if (outEdges.count(id) > 0) {
+    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
+    for (auto &outEdge : oldOutEdges) {
+      removeEdge(id, outEdge.id, outEdge.value);
+    }
+  }
+  // Erase remaining node state.
+  inEdges.erase(id);
+  outEdges.erase(id);
+  nodes.erase(id);
+}
+
+// Returns true if node 'id' writes to any memref which escapes (or is an
+// argument to) the block. Returns false otherwise.
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
+  Node *node = getNode(id);
+  for (auto *storeOpInst : node->stores) {
+    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+    auto *op = memref.getDefiningOp();
+    // Return true if 'memref' is a block argument.
+    if (!op)
+      return true;
+    // Return true if any use of 'memref' does not dereference it in an
+    // affine way, i.e. the user is not an AffineMapAccessInterface op.
+    for (auto *user : memref.getUsers())
+      if (!isa<AffineMapAccessInterface>(*user))
+        return true;
+  }
+  return false;
+}
+
+// Returns true iff both outEdges[srcId] and inEdges[dstId] record an edge
+// between 'srcId' and 'dstId' for 'value', or for any value if 'value' is
+// null. Returns false otherwise.
+bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
+                                    Value value) {
+  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
+    return false;
+  }
+  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+    return edge.id == dstId && (!value || edge.value == value);
+  });
+  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+    return edge.id == srcId && (!value || edge.value == value);
+  });
+  return hasOutEdge && hasInEdge;
+}
+
+// Adds an edge from node 'srcId' to node 'dstId' for 'value' if not present.
+void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
+                                    Value value) {
+  if (!hasEdge(srcId, dstId, value)) {
+    outEdges[srcId].push_back({dstId, value});
+    inEdges[dstId].push_back({srcId, value});
+    if (value.getType().isa<MemRefType>())
+      memrefEdgeCount[value]++;
+  }
+}
+
+// Removes the edge from 'srcId' to 'dstId' for 'value' (first match only).
+void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
+                                       Value value) {
+  assert(inEdges.count(dstId) > 0);
+  assert(outEdges.count(srcId) > 0);
+  if (value.getType().isa<MemRefType>()) {
+    assert(memrefEdgeCount.count(value) > 0);
+    memrefEdgeCount[value]--;
+  }
+  // Remove 'srcId' from 'inEdges[dstId]'.
+  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
+    if ((*it).id == srcId && (*it).value == value) {
+      inEdges[dstId].erase(it);
+      break;
+    }
+  }
+  // Remove 'dstId' from 'outEdges[srcId]'.
+  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
+    if ((*it).id == dstId && (*it).value == value) {
+      outEdges[srcId].erase(it);
+      break;
+    }
+  }
+}
+
+// Returns true if there is a path in the dependence graph from node 'srcId'
+// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
+// operations that the edges connect are expected to be from the same block.
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+  // Worklist state is: <node-id, next-output-edge-index-to-visit>
+  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
+  worklist.push_back({srcId, 0});
+  Operation *dstOp = getNode(dstId)->op;
+  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
+  while (!worklist.empty()) {
+    auto &idAndIndex = worklist.back();
+    // Return true if we have reached 'dstId'.
+    if (idAndIndex.first == dstId)
+      return true;
+    // Pop and continue if node has no out edges, or if all out edges have
+    // already been visited.
+    if (outEdges.count(idAndIndex.first) == 0 ||
+        idAndIndex.second == outEdges[idAndIndex.first].size()) {
+      worklist.pop_back();
+      continue;
+    }
+    // Get graph edge to traverse.
+    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+    // Increment next output edge index for 'idAndIndex'.
+    ++idAndIndex.second;
+    // Add node at 'edge.id' to the worklist. We don't need to consider
+    // nodes that are "after" dstId in the containing block; one can't have a
+    // path to `dstId` from any of those nodes.
+    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
+    if (!afterDst && edge.id != idAndIndex.first)
+      worklist.push_back({edge.id, 0});
+  }
+  return false;
+}
+
+// Returns the input edge count for node 'id' and 'memref' from src nodes
+// which access 'memref' with a store operation.
+unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
+                                                          Value memref) {
+  unsigned inEdgeCount = 0;
+  if (inEdges.count(id) > 0)
+    for (auto &inEdge : inEdges[id])
+      if (inEdge.value == memref) {
+        Node *srcNode = getNode(inEdge.id);
+        // Only count in edges from 'srcNode' if 'srcNode' stores to 'memref'.
+        if (srcNode->getStoreOpCount(memref) > 0)
+          ++inEdgeCount;
+      }
+  return inEdgeCount;
+}
+
+// Returns the number of out edges of node 'id' for 'memref', or the total
+// number of out edges of node 'id' if 'memref' is null.
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+  unsigned outEdgeCount = 0;
+  if (outEdges.count(id) > 0)
+    for (auto &outEdge : outEdges[id])
+      if (!memref || outEdge.value == memref)
+        ++outEdgeCount;
+  return outEdgeCount;
+}
+
+/// Inserts into 'definingNodes' all nodes defining SSA values used in 'id'.
+void MemRefDependenceGraph::gatherDefiningNodes(
+    unsigned id, DenseSet<unsigned> &definingNodes) {
+  for (MemRefDependenceGraph::Edge edge : inEdges[id])
+    // By definition of edge, if the edge value is a non-memref value,
+    // then the dependence is between a graph node which defines an SSA value
+    // and another graph node which uses the SSA value.
+    if (!edge.value.getType().isa<MemRefType>())
+      definingNodes.insert(edge.id);
+}
+
+// Computes and returns an insertion point operation, before which the
+// fused <srcId, dstId> loop nest can be inserted while preserving
+// dependences. Returns nullptr if no such insertion point is found.
+Operation *
+MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
+                                                      unsigned dstId) {
+  if (outEdges.count(srcId) == 0)
+    return getNode(dstId)->op;
+
+  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
+  DenseSet<unsigned> definingNodes;
+  gatherDefiningNodes(dstId, definingNodes);
+  if (llvm::any_of(definingNodes,
+                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Can't fuse: a defining op with a user in the dst "
+                  "loop has dependence from the src loop\n");
+    return nullptr;
+  }
+
+  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
+  SmallPtrSet<Operation *, 2> srcDepInsts;
+  for (auto &outEdge : outEdges[srcId])
+    if (outEdge.id != dstId)
+      srcDepInsts.insert(getNode(outEdge.id)->op);
+
+  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
+  SmallPtrSet<Operation *, 2> dstDepInsts;
+  for (auto &inEdge : inEdges[dstId])
+    if (inEdge.id != srcId)
+      dstDepInsts.insert(getNode(inEdge.id)->op);
+
+  Operation *srcNodeInst = getNode(srcId)->op;
+  Operation *dstNodeInst = getNode(dstId)->op;
+
+  // Computing insertion point:
+  // *) Walk all operation positions in Block operation list in the
+  //    range (src, dst). For each operation 'op' visited in this search:
+  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
+  //      dependence edge from 'srcNode'.
+  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
+  //      dependence edge to 'dstNode'.
+  // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
+  //    operation insertion point (or return null pointer if no such
+  //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
+  SmallVector<Operation *, 2> depInsts;
+  std::optional<unsigned> firstSrcDepPos;
+  std::optional<unsigned> lastDstDepPos;
+  unsigned pos = 0;
+  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
+       it != Block::iterator(dstNodeInst); ++it) {
+    Operation *op = &(*it);
+    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
+      firstSrcDepPos = pos;
+    if (dstDepInsts.count(op) > 0)
+      lastDstDepPos = pos;
+    depInsts.push_back(op);
+    ++pos;
+  }
+
+  if (firstSrcDepPos.has_value()) {
+    if (lastDstDepPos.has_value()) {
+      if (*firstSrcDepPos <= *lastDstDepPos) {
+        // No valid insertion point exists which preserves dependences.
+        return nullptr;
+      }
+    }
+    // Return the insertion point at 'firstSrcDepPos'.
+    return depInsts[*firstSrcDepPos];
+  }
+  // No dependence targets in range (or only dst deps in range), return
+  // 'dstNodeInst' insertion point.
+  return dstNodeInst;
+}
+
+// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
+// taking into account that:
+//   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
+//   *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
+//      private memref.
+void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
+                                        const DenseSet<Value> &privateMemRefs,
+                                        bool removeSrcId) {
+  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
+  if (inEdges.count(srcId) > 0) {
+    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
+    for (auto &inEdge : oldInEdges) {
+      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
+      if (privateMemRefs.count(inEdge.value) == 0)
+        addEdge(inEdge.id, dstId, inEdge.value);
+    }
+  }
+  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
+  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
+  if (outEdges.count(srcId) > 0) {
+    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
+    for (auto &outEdge : oldOutEdges) {
+      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
+      if (outEdge.id == dstId)
+        removeEdge(srcId, outEdge.id, outEdge.value);
+      else if (removeSrcId) {
+        addEdge(dstId, outEdge.id, outEdge.value);
+        removeEdge(srcId, outEdge.id, outEdge.value);
+      }
+    }
+  }
+  // Remove any edges in 'inEdges[dstId]' on memrefs in 'privateMemRefs' (each
+  // now replaced by a private memref). Such edges may come from nodes other
+  // than 'srcId'; those from 'srcId' to 'dstId' were removed above.
+  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
+    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
+    for (auto &inEdge : oldInEdges)
+      if (privateMemRefs.count(inEdge.value) > 0)
+        removeEdge(inEdge.id, dstId, inEdge.value);
+  }
+}
+
+// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
+// of sibling node 'sibId' into node 'dstId'.
+void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
+  // For each edge in 'inEdges[sibId]':
+  // *) Add new edge from source node 'inEdge.id' to 'dstId'.
+  // *) Remove edge from source node 'inEdge.id' to 'sibId'.
+  if (inEdges.count(sibId) > 0) {
+    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
+    for (auto &inEdge : oldInEdges) {
+      addEdge(inEdge.id, dstId, inEdge.value);
+      removeEdge(inEdge.id, sibId, inEdge.value);
+    }
+  }
+
+  // For each edge in 'outEdges[sibId]':
+  // *) Add new edge from 'dstId' to 'outEdge.id'.
+  // *) Remove edge from 'sibId' to 'outEdge.id'.
+  if (outEdges.count(sibId) > 0) {
+    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
+    for (auto &outEdge : oldOutEdges) {
+      addEdge(dstId, outEdge.id, outEdge.value);
+      removeEdge(sibId, outEdge.id, outEdge.value);
+    }
+  }
+}
+
+// Appends the ops in 'loads' and 'stores' to the op lists of node 'id'.
+void MemRefDependenceGraph::addToNode(
+    unsigned id, const SmallVectorImpl<Operation *> &loads,
+    const SmallVectorImpl<Operation *> &stores) {
+  Node *node = getNode(id);
+  llvm::append_range(node->loads, loads);
+  llvm::append_range(node->stores, stores);
+}
+
+void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
+  Node *node = getNode(id);
+  node->loads.clear(); // Edges of node 'id' are left untouched.
+  node->stores.clear();
+}
+
+// Calls 'callback' for each input edge of node 'id' which carries a memref
+// dependence from an affine.for node (filtering done by forEachMemRefEdge).
+void MemRefDependenceGraph::forEachMemRefInputEdge(
+    unsigned id, const std::function<void(Edge)> &callback) {
+  if (inEdges.count(id) > 0)
+    forEachMemRefEdge(inEdges[id], callback);
+}
+
+// Calls 'callback' for each output edge of node 'id' which carries a memref
+// dependence to an affine.for node (filtering done by forEachMemRefEdge).
+void MemRefDependenceGraph::forEachMemRefOutputEdge(
+    unsigned id, const std::function<void(Edge)> &callback) {
+  if (outEdges.count(id) > 0)
+    forEachMemRefEdge(outEdges[id], callback);
+}
+
+// Calls 'callback' for each edge in 'edges' which carries a memref
+// dependence and whose other endpoint ('edge.id') is an affine.for node.
+void MemRefDependenceGraph::forEachMemRefEdge(
+    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
+  for (const auto &edge : edges) {
+    // Skip if 'edge' is not a memref dependence edge.
+    if (!edge.value.getType().isa<MemRefType>())
+      continue;
+    assert(nodes.count(edge.id) > 0);
+    // Skip if 'edge.id' is not a loop nest.
+    if (!isa<AffineForOp>(getNode(edge.id)->op))
+      continue;
+    // Visit current edge 'edge'.
+    callback(edge);
+  }
+}
+
+void MemRefDependenceGraph::print(raw_ostream &os) const {
+  os << "\nMemRefDependenceGraph\n";
+  os << "\nNodes:\n";
+  for (const auto &idAndNode : nodes) { // DenseMap order, not sorted by id.
+    os << "Node: " << idAndNode.first << "\n";
+    auto it = inEdges.find(idAndNode.first);
+    if (it != inEdges.end()) {
+      for (const auto &e : it->second)
+        os << "  InEdge: " << e.id << " " << e.value << "\n";
+    }
+    it = outEdges.find(idAndNode.first);
+    if (it != outEdges.end()) {
+      for (const auto &e : it->second)
+        os << "  OutEdge: " << e.id << " " << e.value << "\n";
+    }
+  }
+}
+
 void mlir::getAffineForIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
   auto *currOp = op.getParentOp();
   AffineForOp currAffineForOp;

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index d7a75b265a09b..3017cb0ed0a4a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -70,546 +70,6 @@ struct LoopFusion : public impl::AffineLoopFusionBase<LoopFusion> {
 
 } // namespace
 
-std::unique_ptr<Pass>
-mlir::createLoopFusionPass(unsigned fastMemorySpace,
-                           uint64_t localBufSizeThreshold, bool maximalFusion,
-                           enum FusionMode affineFusionMode) {
-  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
-                                      maximalFusion, affineFusionMode);
-}
-
-namespace {
-
-// LoopNestStateCollector walks loop nests and collects load and store
-// operations, and whether or not a region holding op other than ForOp and IfOp
-// was encountered in the loop nest.
-struct LoopNestStateCollector {
-  SmallVector<AffineForOp, 4> forOps;
-  SmallVector<Operation *, 4> loadOpInsts;
-  SmallVector<Operation *, 4> storeOpInsts;
-  bool hasNonAffineRegionOp = false;
-
-  void collect(Operation *opToWalk) {
-    opToWalk->walk([&](Operation *op) {
-      if (isa<AffineForOp>(op))
-        forOps.push_back(cast<AffineForOp>(op));
-      else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
-        hasNonAffineRegionOp = true;
-      else if (isa<AffineReadOpInterface>(op))
-        loadOpInsts.push_back(op);
-      else if (isa<AffineWriteOpInterface>(op))
-        storeOpInsts.push_back(op);
-    });
-  }
-};
-
-// MemRefDependenceGraph is a graph data structure where graph nodes are
-// top-level operations in a `Block` which contain load/store ops, and edges
-// are memref dependences between the nodes.
-// TODO: Add a more flexible dependence graph representation.
-// TODO: Add a depth parameter to dependence graph construction.
-struct MemRefDependenceGraph {
-public:
-  // Node represents a node in the graph. A Node is either an entire loop nest
-  // rooted at the top level which contains loads/stores, or a top level
-  // load/store.
-  struct Node {
-    // The unique identifier of this node in the graph.
-    unsigned id;
-    // The top-level statement which is (or contains) a load/store.
-    Operation *op;
-    // List of load operations.
-    SmallVector<Operation *, 4> loads;
-    // List of store op insts.
-    SmallVector<Operation *, 4> stores;
-    Node(unsigned id, Operation *op) : id(id), op(op) {}
-
-    // Returns the load op count for 'memref'.
-    unsigned getLoadOpCount(Value memref) const {
-      unsigned loadOpCount = 0;
-      for (Operation *loadOp : loads) {
-        if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
-          ++loadOpCount;
-      }
-      return loadOpCount;
-    }
-
-    // Returns the store op count for 'memref'.
-    unsigned getStoreOpCount(Value memref) const {
-      unsigned storeOpCount = 0;
-      for (Operation *storeOp : stores) {
-        if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
-          ++storeOpCount;
-      }
-      return storeOpCount;
-    }
-
-    // Returns all store ops in 'storeOps' which access 'memref'.
-    void getStoreOpsForMemref(Value memref,
-                              SmallVectorImpl<Operation *> *storeOps) const {
-      for (Operation *storeOp : stores) {
-        if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
-          storeOps->push_back(storeOp);
-      }
-    }
-
-    // Returns all load ops in 'loadOps' which access 'memref'.
-    void getLoadOpsForMemref(Value memref,
-                             SmallVectorImpl<Operation *> *loadOps) const {
-      for (Operation *loadOp : loads) {
-        if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
-          loadOps->push_back(loadOp);
-      }
-    }
-
-    // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
-    // has at least one load and store operation.
-    void
-    getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const {
-      llvm::SmallDenseSet<Value, 2> loadMemrefs;
-      for (Operation *loadOp : loads) {
-        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
-      }
-      for (Operation *storeOp : stores) {
-        auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
-        if (loadMemrefs.count(memref) > 0)
-          loadAndStoreMemrefSet->insert(memref);
-      }
-    }
-  };
-
-  // Edge represents a data dependence between nodes in the graph.
-  struct Edge {
-    // The id of the node at the other end of the edge.
-    // If this edge is stored in Edge = Node.inEdges[i], then
-    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
-    // If this edge is stored in Edge = Node.outEdges[i], then
-    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
-    unsigned id;
-    // The SSA value on which this edge represents a dependence.
-    // If the value is a memref, then the dependence is between graph nodes
-    // which contain accesses to the same memref 'value'. If the value is a
-    // non-memref value, then the dependence is between a graph node which
-    // defines an SSA value and another graph node which uses the SSA value
-    // (e.g. a constant or load operation defining a value which is used inside
-    // a loop nest).
-    Value value;
-  };
-
-  // Map from node id to Node.
-  DenseMap<unsigned, Node> nodes;
-  // Map from node id to list of input edges.
-  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
-  // Map from node id to list of output edges.
-  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
-  // Map from memref to a count on the dependence edges associated with that
-  // memref.
-  DenseMap<Value, unsigned> memrefEdgeCount;
-  // The next unique identifier to use for newly created graph nodes.
-  unsigned nextNodeId = 0;
-
-  MemRefDependenceGraph(Block &block) : block(block) {}
-
-  // Initializes the dependence graph based on operations in `block'.
-  // Returns true on success, false otherwise.
-  bool init();
-
-  // Returns the graph node for 'id'.
-  Node *getNode(unsigned id) {
-    auto it = nodes.find(id);
-    assert(it != nodes.end());
-    return &it->second;
-  }
-
-  // Returns the graph node for 'forOp'.
-  Node *getForOpNode(AffineForOp forOp) {
-    for (auto &idAndNode : nodes)
-      if (idAndNode.second.op == forOp)
-        return &idAndNode.second;
-    return nullptr;
-  }
-
-  // Adds a node with 'op' to the graph and returns its unique identifier.
-  unsigned addNode(Operation *op) {
-    Node node(nextNodeId++, op);
-    nodes.insert({node.id, node});
-    return node.id;
-  }
-
-  // Remove node 'id' (and its associated edges) from graph.
-  void removeNode(unsigned id) {
-    // Remove each edge in 'inEdges[id]'.
-    if (inEdges.count(id) > 0) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[id];
-      for (auto &inEdge : oldInEdges) {
-        removeEdge(inEdge.id, id, inEdge.value);
-      }
-    }
-    // Remove each edge in 'outEdges[id]'.
-    if (outEdges.count(id) > 0) {
-      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
-      for (auto &outEdge : oldOutEdges) {
-        removeEdge(id, outEdge.id, outEdge.value);
-      }
-    }
-    // Erase remaining node state.
-    inEdges.erase(id);
-    outEdges.erase(id);
-    nodes.erase(id);
-  }
-
-  // Returns true if node 'id' writes to any memref which escapes (or is an
-  // argument to) the block. Returns false otherwise.
-  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
-    Node *node = getNode(id);
-    for (auto *storeOpInst : node->stores) {
-      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
-      auto *op = memref.getDefiningOp();
-      // Return true if 'memref' is a block argument.
-      if (!op)
-        return true;
-      // Return true if any use of 'memref' does not deference it in an affine
-      // way.
-      for (auto *user : memref.getUsers())
-        if (!isa<AffineMapAccessInterface>(*user))
-          return true;
-    }
-    return false;
-  }
-
-  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
-  // is for 'value' if non-null, or for any value otherwise. Returns false
-  // otherwise.
-  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
-    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
-      return false;
-    }
-    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
-      return edge.id == dstId && (!value || edge.value == value);
-    });
-    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
-      return edge.id == srcId && (!value || edge.value == value);
-    });
-    return hasOutEdge && hasInEdge;
-  }
-
-  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
-  void addEdge(unsigned srcId, unsigned dstId, Value value) {
-    if (!hasEdge(srcId, dstId, value)) {
-      outEdges[srcId].push_back({dstId, value});
-      inEdges[dstId].push_back({srcId, value});
-      if (value.getType().isa<MemRefType>())
-        memrefEdgeCount[value]++;
-    }
-  }
-
-  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
-  void removeEdge(unsigned srcId, unsigned dstId, Value value) {
-    assert(inEdges.count(dstId) > 0);
-    assert(outEdges.count(srcId) > 0);
-    if (value.getType().isa<MemRefType>()) {
-      assert(memrefEdgeCount.count(value) > 0);
-      memrefEdgeCount[value]--;
-    }
-    // Remove 'srcId' from 'inEdges[dstId]'.
-    for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
-      if ((*it).id == srcId && (*it).value == value) {
-        inEdges[dstId].erase(it);
-        break;
-      }
-    }
-    // Remove 'dstId' from 'outEdges[srcId]'.
-    for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
-         ++it) {
-      if ((*it).id == dstId && (*it).value == value) {
-        outEdges[srcId].erase(it);
-        break;
-      }
-    }
-  }
-
-  // Returns true if there is a path in the dependence graph from node 'srcId'
-  // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
-  // operations that the edges connected are expected to be from the same block.
-  bool hasDependencePath(unsigned srcId, unsigned dstId) {
-    // Worklist state is: <node-id, next-output-edge-index-to-visit>
-    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
-    worklist.push_back({srcId, 0});
-    Operation *dstOp = getNode(dstId)->op;
-    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
-    while (!worklist.empty()) {
-      auto &idAndIndex = worklist.back();
-      // Return true if we have reached 'dstId'.
-      if (idAndIndex.first == dstId)
-        return true;
-      // Pop and continue if node has no out edges, or if all out edges have
-      // already been visited.
-      if (outEdges.count(idAndIndex.first) == 0 ||
-          idAndIndex.second == outEdges[idAndIndex.first].size()) {
-        worklist.pop_back();
-        continue;
-      }
-      // Get graph edge to traverse.
-      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
-      // Increment next output edge index for 'idAndIndex'.
-      ++idAndIndex.second;
-      // Add node at 'edge.id' to the worklist. We don't need to consider
-      // nodes that are "after" dstId in the containing block; one can't have a
-      // path to `dstId` from any of those nodes.
-      bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
-      if (!afterDst && edge.id != idAndIndex.first)
-        worklist.push_back({edge.id, 0});
-    }
-    return false;
-  }
-
-  // Returns the input edge count for node 'id' and 'memref' from src nodes
-  // which access 'memref' with a store operation.
-  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
-    unsigned inEdgeCount = 0;
-    if (inEdges.count(id) > 0)
-      for (auto &inEdge : inEdges[id])
-        if (inEdge.value == memref) {
-          Node *srcNode = getNode(inEdge.id);
-          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
-          if (srcNode->getStoreOpCount(memref) > 0)
-            ++inEdgeCount;
-        }
-    return inEdgeCount;
-  }
-
-  // Returns the output edge count for node 'id' and 'memref' (if non-null),
-  // otherwise returns the total output edge count from node 'id'.
-  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
-    unsigned outEdgeCount = 0;
-    if (outEdges.count(id) > 0)
-      for (auto &outEdge : outEdges[id])
-        if (!memref || outEdge.value == memref)
-          ++outEdgeCount;
-    return outEdgeCount;
-  }
-
-  /// Return all nodes which define SSA values used in node 'id'.
-  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
-    for (MemRefDependenceGraph::Edge edge : inEdges[id])
-      // By definition of edge, if the edge value is a non-memref value,
-      // then the dependence is between a graph node which defines an SSA value
-      // and another graph node which uses the SSA value.
-      if (!edge.value.getType().isa<MemRefType>())
-        definingNodes.insert(edge.id);
-  }
-
-  // Computes and returns an insertion point operation, before which the
-  // the fused <srcId, dstId> loop nest can be inserted while preserving
-  // dependences. Returns nullptr if no such insertion point is found.
-  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
-    if (outEdges.count(srcId) == 0)
-      return getNode(dstId)->op;
-
-    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
-    DenseSet<unsigned> definingNodes;
-    gatherDefiningNodes(dstId, definingNodes);
-    if (llvm::any_of(definingNodes, [&](unsigned id) {
-          return hasDependencePath(srcId, id);
-        })) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Can't fuse: a defining op with a user in the dst "
-                    "loop has dependence from the src loop\n");
-      return nullptr;
-    }
-
-    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
-    SmallPtrSet<Operation *, 2> srcDepInsts;
-    for (auto &outEdge : outEdges[srcId])
-      if (outEdge.id != dstId)
-        srcDepInsts.insert(getNode(outEdge.id)->op);
-
-    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
-    SmallPtrSet<Operation *, 2> dstDepInsts;
-    for (auto &inEdge : inEdges[dstId])
-      if (inEdge.id != srcId)
-        dstDepInsts.insert(getNode(inEdge.id)->op);
-
-    Operation *srcNodeInst = getNode(srcId)->op;
-    Operation *dstNodeInst = getNode(dstId)->op;
-
-    // Computing insertion point:
-    // *) Walk all operation positions in Block operation list in the
-    //    range (src, dst). For each operation 'op' visited in this search:
-    //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
-    //      dependence edge from 'srcNode'.
-    //   *) Store in 'lastDstDepPost' the last position where 'op' has a
-    //      dependence edge to 'dstNode'.
-    // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the
-    //    operation insertion point (or return null pointer if no such
-    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
-    SmallVector<Operation *, 2> depInsts;
-    std::optional<unsigned> firstSrcDepPos;
-    std::optional<unsigned> lastDstDepPos;
-    unsigned pos = 0;
-    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
-         it != Block::iterator(dstNodeInst); ++it) {
-      Operation *op = &(*it);
-      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
-        firstSrcDepPos = pos;
-      if (dstDepInsts.count(op) > 0)
-        lastDstDepPos = pos;
-      depInsts.push_back(op);
-      ++pos;
-    }
-
-    if (firstSrcDepPos.has_value()) {
-      if (lastDstDepPos.has_value()) {
-        if (*firstSrcDepPos <= *lastDstDepPos) {
-          // No valid insertion point exists which preserves dependences.
-          return nullptr;
-        }
-      }
-      // Return the insertion point at 'firstSrcDepPos'.
-      return depInsts[*firstSrcDepPos];
-    }
-    // No dependence targets in range (or only dst deps in range), return
-    // 'dstNodInst' insertion point.
-    return dstNodeInst;
-  }
-
-  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
-  // taking into account that:
-  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
-  //   *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a
-  //      private memref.
-  void updateEdges(unsigned srcId, unsigned dstId,
-                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
-    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
-    if (inEdges.count(srcId) > 0) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
-      for (auto &inEdge : oldInEdges) {
-        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
-        if (privateMemRefs.count(inEdge.value) == 0)
-          addEdge(inEdge.id, dstId, inEdge.value);
-      }
-    }
-    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
-    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
-    if (outEdges.count(srcId) > 0) {
-      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
-      for (auto &outEdge : oldOutEdges) {
-        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
-        if (outEdge.id == dstId)
-          removeEdge(srcId, outEdge.id, outEdge.value);
-        else if (removeSrcId) {
-          addEdge(dstId, outEdge.id, outEdge.value);
-          removeEdge(srcId, outEdge.id, outEdge.value);
-        }
-      }
-    }
-    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
-    // replaced by a private memref). These edges could come from nodes
-    // other than 'srcId' which were removed in the previous step.
-    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
-      for (auto &inEdge : oldInEdges)
-        if (privateMemRefs.count(inEdge.value) > 0)
-          removeEdge(inEdge.id, dstId, inEdge.value);
-    }
-  }
-
-  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
-  // of sibling node 'sibId' into node 'dstId'.
-  void updateEdges(unsigned sibId, unsigned dstId) {
-    // For each edge in 'inEdges[sibId]':
-    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
-    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
-    if (inEdges.count(sibId) > 0) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
-      for (auto &inEdge : oldInEdges) {
-        addEdge(inEdge.id, dstId, inEdge.value);
-        removeEdge(inEdge.id, sibId, inEdge.value);
-      }
-    }
-
-    // For each edge in 'outEdges[sibId]' to node 'id'
-    // *) Add new edge from 'dstId' to 'outEdge.id'.
-    // *) Remove edge from 'sibId' to 'outEdge.id'.
-    if (outEdges.count(sibId) > 0) {
-      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
-      for (auto &outEdge : oldOutEdges) {
-        addEdge(dstId, outEdge.id, outEdge.value);
-        removeEdge(sibId, outEdge.id, outEdge.value);
-      }
-    }
-  }
-
-  // Adds ops in 'loads' and 'stores' to node at 'id'.
-  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
-                 const SmallVectorImpl<Operation *> &stores) {
-    Node *node = getNode(id);
-    llvm::append_range(node->loads, loads);
-    llvm::append_range(node->stores, stores);
-  }
-
-  void clearNodeLoadAndStores(unsigned id) {
-    Node *node = getNode(id);
-    node->loads.clear();
-    node->stores.clear();
-  }
-
-  // Calls 'callback' for each input edge incident to node 'id' which carries a
-  // memref dependence.
-  void forEachMemRefInputEdge(unsigned id,
-                              const std::function<void(Edge)> &callback) {
-    if (inEdges.count(id) > 0)
-      forEachMemRefEdge(inEdges[id], callback);
-  }
-
-  // Calls 'callback' for each output edge from node 'id' which carries a
-  // memref dependence.
-  void forEachMemRefOutputEdge(unsigned id,
-                               const std::function<void(Edge)> &callback) {
-    if (outEdges.count(id) > 0)
-      forEachMemRefEdge(outEdges[id], callback);
-  }
-
-  // Calls 'callback' for each edge in 'edges' which carries a memref
-  // dependence.
-  void forEachMemRefEdge(ArrayRef<Edge> edges,
-                         const std::function<void(Edge)> &callback) {
-    for (const auto &edge : edges) {
-      // Skip if 'edge' is not a memref dependence edge.
-      if (!edge.value.getType().isa<MemRefType>())
-        continue;
-      assert(nodes.count(edge.id) > 0);
-      // Skip if 'edge.id' is not a loop nest.
-      if (!isa<AffineForOp>(getNode(edge.id)->op))
-        continue;
-      // Visit current input edge 'edge'.
-      callback(edge);
-    }
-  }
-
-  void print(raw_ostream &os) const {
-    os << "\nMemRefDependenceGraph\n";
-    os << "\nNodes:\n";
-    for (const auto &idAndNode : nodes) {
-      os << "Node: " << idAndNode.first << "\n";
-      auto it = inEdges.find(idAndNode.first);
-      if (it != inEdges.end()) {
-        for (const auto &e : it->second)
-          os << "  InEdge: " << e.id << " " << e.value << "\n";
-      }
-      it = outEdges.find(idAndNode.first);
-      if (it != outEdges.end()) {
-        for (const auto &e : it->second)
-          os << "  OutEdge: " << e.id << " " << e.value << "\n";
-      }
-    }
-  }
-  void dump() const { print(llvm::errs()); }
-
-  /// The block for which this graph is created to perform fusion.
-  Block █
-};
-
 /// Returns true if node 'srcId' can be removed after fusing it with node
 /// 'dstId'. The node can be removed if any of the following conditions are met:
 ///   1. 'srcId' has no output dependences after fusion and no escaping memrefs.
@@ -755,8 +215,8 @@ static bool isEscapingMemref(Value memref, Block *block) {
 
 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
 /// that escape the block or are accessed in a non-affine way.
-void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
-                           DenseSet<Value> &escapingMemRefs) {
+static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
+                                  DenseSet<Value> &escapingMemRefs) {
   auto *node = mdg->getNode(id);
   for (Operation *storeOp : node->stores) {
     auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
@@ -767,8 +227,6 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
   }
 }
 
-} // namespace
-
 // Initializes the data dependence graph by walking operations in `block`.
 // Assigns each node in the graph a node id based on program order in 'f'.
 bool MemRefDependenceGraph::init() {
@@ -2042,3 +1500,11 @@ void LoopFusion::runOnOperation() {
     for (Block &block : region.getBlocks())
       runOnBlock(&block);
 }
+
+std::unique_ptr<Pass> mlir::createLoopFusionPass(
+    unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
+    bool maximalFusion, enum FusionMode affineFusionMode) {
+  std::unique_ptr<Pass> fusionPass = std::make_unique<LoopFusion>(
+      fastMemorySpace, localBufSizeThreshold, maximalFusion, affineFusionMode);
+  return fusionPass;
+}


        


More information about the Mlir-commits mailing list