[Mlir-commits] [mlir] 001ba42 - [MLIR][Affine] Make affine fusion MDG API const correct (#125994)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 10 15:58:18 PST 2025


Author: Uday Bondhugula
Date: 2025-02-11T05:28:15+05:30
New Revision: 001ba42fe057de10942ac886c3bd82ee54373ddf

URL: https://github.com/llvm/llvm-project/commit/001ba42fe057de10942ac886c3bd82ee54373ddf
DIFF: https://github.com/llvm/llvm-project/commit/001ba42fe057de10942ac886c3bd82ee54373ddf.diff

LOG: [MLIR][Affine] Make affine fusion MDG API const correct (#125994)

Make affine fusion MDG API const correct. NFC changes otherwise.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
    mlir/lib/Dialect/Affine/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index 7164ade6ea53a60..5b386868cb0042e 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -139,9 +139,11 @@ struct MemRefDependenceGraph {
 
   // Map from node id to Node.
   DenseMap<unsigned, Node> nodes;
-  // Map from node id to list of input edges.
+  // Map from node id to list of input edges. The absence of an entry for a
+  // node is also equivalent to the absence of any edges.
   DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
-  // Map from node id to list of output edges.
+  // Map from node id to list of output edges. The absence of an entry for a
+  // node is also equivalent to the absence of any edges.
   DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
   // Map from memref to a count on the dependence edges associated with that
   // memref.
@@ -156,10 +158,21 @@ struct MemRefDependenceGraph {
   bool init();
 
   // Returns the graph node for 'id'.
-  Node *getNode(unsigned id);
+  const Node *getNode(unsigned id) const;
+  Node *getNode(unsigned id) {
+    return const_cast<Node *>(
+        static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
+  }
+
+  // Returns true if the graph has a node with ID `id`.
+  bool hasNode(unsigned id) const { return nodes.contains(id); }
 
   // Returns the graph node for 'forOp'.
-  Node *getForOpNode(AffineForOp forOp);
+  const Node *getForOpNode(AffineForOp forOp) const;
+  Node *getForOpNode(AffineForOp forOp) {
+    return const_cast<Node *>(
+        static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
+  }
 
   // Adds a node with 'op' to the graph and returns its unique identifier.
   unsigned addNode(Operation *op);
@@ -169,12 +182,12 @@ struct MemRefDependenceGraph {
 
   // Returns true if node 'id' writes to any memref which escapes (or is an
   // argument to) the block. Returns false otherwise.
-  bool writesToLiveInOrEscapingMemrefs(unsigned id);
+  bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
 
   // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   // is for 'value' if non-null, or for any value otherwise. Returns false
   // otherwise.
-  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
 
   // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
   void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +198,25 @@ struct MemRefDependenceGraph {
   // Returns true if there is a path in the dependence graph from node 'srcId'
   // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
   // operations that the edges connected are expected to be from the same block.
-  bool hasDependencePath(unsigned srcId, unsigned dstId);
+  bool hasDependencePath(unsigned srcId, unsigned dstId) const;
 
   // Returns the input edge count for node 'id' and 'memref' from src nodes
   // which access 'memref' with a store operation.
-  unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
 
   // Returns the output edge count for node 'id' and 'memref' (if non-null),
   // otherwise returns the total output edge count from node 'id'.
-  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
 
   /// Return all nodes which define SSA values used in node 'id'.
-  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+  void gatherDefiningNodes(unsigned id,
+                           DenseSet<unsigned> &definingNodes) const;
 
   // Computes and returns an insertion point operation, before which the
   // the fused <srcId, dstId> loop nest can be inserted while preserving
   // dependences. Returns nullptr if no such insertion point is found.
-  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+  Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
+                                            unsigned dstId) const;
 
   // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
   // taking into account that:

diff  --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 10de0d04cbea640..b829633252fdd72 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -188,8 +188,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
 
 /// Add `op` to MDG creating a new node and adding its memory accesses (affine
 /// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
-Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
-                   DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
+static Node *
+addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
+             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
   auto &nodes = mdg.nodes;
   // Create graph node 'id' to represent top-level 'forOp' and record
   // all loads and store accesses it contains.
@@ -359,14 +360,14 @@ bool MemRefDependenceGraph::init() {
 }
 
 // Returns the graph node for 'id'.
-Node *MemRefDependenceGraph::getNode(unsigned id) {
+const Node *MemRefDependenceGraph::getNode(unsigned id) const {
   auto it = nodes.find(id);
   assert(it != nodes.end());
   return &it->second;
 }
 
 // Returns the graph node for 'forOp'.
-Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
   for (auto &idAndNode : nodes)
     if (idAndNode.second.op == forOp)
       return &idAndNode.second;
@@ -390,7 +391,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
     }
   }
   // Remove each edge in 'outEdges[id]'.
-  if (outEdges.count(id) > 0) {
+  if (outEdges.contains(id)) {
     SmallVector<Edge, 2> oldOutEdges = outEdges[id];
     for (auto &outEdge : oldOutEdges) {
       removeEdge(id, outEdge.id, outEdge.value);
@@ -404,8 +405,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
 
 // Returns true if node 'id' writes to any memref which escapes (or is an
 // argument to) the block. Returns false otherwise.
-bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
-  Node *node = getNode(id);
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
+  const Node *node = getNode(id);
   for (auto *storeOpInst : node->stores) {
     auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
     auto *op = memref.getDefiningOp();
@@ -425,14 +426,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
 // is for 'value' if non-null, or for any value otherwise. Returns false
 // otherwise.
 bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
-                                    Value value) {
-  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
+                                    Value value) const {
+  if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
     return false;
   }
-  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+  bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
     return edge.id == dstId && (!value || edge.value == value);
   });
-  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+  bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
     return edge.id == srcId && (!value || edge.value == value);
   });
   return hasOutEdge && hasInEdge;
@@ -477,7 +478,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
 // Returns true if there is a path in the dependence graph from node 'srcId'
 // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
 // operations that the edges connected are expected to be from the same block.
-bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
+                                              unsigned dstId) const {
   // Worklist state is: <node-id, next-output-edge-index-to-visit>
   SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
   worklist.push_back({srcId, 0});
@@ -490,13 +492,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
       return true;
     // Pop and continue if node has no out edges, or if all out edges have
     // already been visited.
-    if (outEdges.count(idAndIndex.first) == 0 ||
-        idAndIndex.second == outEdges[idAndIndex.first].size()) {
+    if (!outEdges.contains(idAndIndex.first) ||
+        idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
       worklist.pop_back();
       continue;
     }
     // Get graph edge to traverse.
-    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+    const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
     // Increment next output edge index for 'idAndIndex'.
     ++idAndIndex.second;
     // Add node at 'edge.id' to the worklist. We don't need to consider
@@ -512,34 +514,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
 // Returns the input edge count for node 'id' and 'memref' from src nodes
 // which access 'memref' with a store operation.
 unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
-                                                          Value memref) {
+                                                          Value memref) const {
   unsigned inEdgeCount = 0;
-  if (inEdges.count(id) > 0)
-    for (auto &inEdge : inEdges[id])
-      if (inEdge.value == memref) {
-        Node *srcNode = getNode(inEdge.id);
-        // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
-        if (srcNode->getStoreOpCount(memref) > 0)
-          ++inEdgeCount;
-      }
+  for (const Edge &inEdge : inEdges.lookup(id)) {
+    if (inEdge.value == memref) {
+      const Node *srcNode = getNode(inEdge.id);
+      // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+      if (srcNode->getStoreOpCount(memref) > 0)
+        ++inEdgeCount;
+    }
+  }
   return inEdgeCount;
 }
 
 // Returns the output edge count for node 'id' and 'memref' (if non-null),
 // otherwise returns the total output edge count from node 'id'.
-unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
+                                                Value memref) const {
   unsigned outEdgeCount = 0;
-  if (outEdges.count(id) > 0)
-    for (auto &outEdge : outEdges[id])
-      if (!memref || outEdge.value == memref)
-        ++outEdgeCount;
+  for (const auto &outEdge : outEdges.lookup(id))
+    if (!memref || outEdge.value == memref)
+      ++outEdgeCount;
   return outEdgeCount;
 }
 
 /// Return all nodes which define SSA values used in node 'id'.
 void MemRefDependenceGraph::gatherDefiningNodes(
-    unsigned id, DenseSet<unsigned> &definingNodes) {
-  for (MemRefDependenceGraph::Edge edge : inEdges[id])
+    unsigned id, DenseSet<unsigned> &definingNodes) const {
+  for (const Edge &edge : inEdges.lookup(id))
     // By definition of edge, if the edge value is a non-memref value,
     // then the dependence is between a graph node which defines an SSA value
     // and another graph node which uses the SSA value.
@@ -552,8 +554,8 @@ void MemRefDependenceGraph::gatherDefiningNodes(
 // dependences. Returns nullptr if no such insertion point is found.
 Operation *
 MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
-                                                      unsigned dstId) {
-  if (outEdges.count(srcId) == 0)
+                                                      unsigned dstId) const {
+  if (!outEdges.contains(srcId))
     return getNode(dstId)->op;
 
   // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
@@ -569,13 +571,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
 
   // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
   SmallPtrSet<Operation *, 2> srcDepInsts;
-  for (auto &outEdge : outEdges[srcId])
+  for (auto &outEdge : outEdges.lookup(srcId))
     if (outEdge.id != dstId)
       srcDepInsts.insert(getNode(outEdge.id)->op);
 
   // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
   SmallPtrSet<Operation *, 2> dstDepInsts;
-  for (auto &inEdge : inEdges[dstId])
+  for (auto &inEdge : inEdges.lookup(dstId))
     if (inEdge.id != srcId)
       dstDepInsts.insert(getNode(inEdge.id)->op);
 
@@ -635,7 +637,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
     SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
     for (auto &inEdge : oldInEdges) {
       // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
-      if (privateMemRefs.count(inEdge.value) == 0)
+      if (!privateMemRefs.contains(inEdge.value))
         addEdge(inEdge.id, dstId, inEdge.value);
     }
   }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index fe6cf0f434cb7eb..b38dd8effe669df 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -78,13 +78,13 @@ struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
 static bool canRemoveSrcNodeAfterFusion(
     unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
     Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
-    MemRefDependenceGraph *mdg) {
+    const MemRefDependenceGraph &mdg) {
 
-  Operation *dstNodeOp = mdg->getNode(dstId)->op;
+  Operation *dstNodeOp = mdg.getNode(dstId)->op;
   bool hasOutDepsAfterFusion = false;
 
-  for (auto &outEdge : mdg->outEdges[srcId]) {
-    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
+  for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
+    Operation *depNodeOp = mdg.getNode(outEdge.id)->op;
     // Skip dependence with dstOp since it will be removed after fusion.
     if (depNodeOp == dstNodeOp)
       continue;
@@ -134,22 +134,23 @@ static bool canRemoveSrcNodeAfterFusion(
 /// held if the 'mdg' is reused from a previous fusion step or if the node
 /// creation order changes in the future to support more advance cases.
 // TODO: Move this to a loop fusion utility once 'mdg' is also moved.
-static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
+static void getProducerCandidates(unsigned dstId,
+                                  const MemRefDependenceGraph &mdg,
                                   SmallVectorImpl<unsigned> &srcIdCandidates) {
   // Skip if no input edges along which to fuse.
-  if (mdg->inEdges.count(dstId) == 0)
+  if (mdg.inEdges.count(dstId) == 0)
     return;
 
   // Gather memrefs from loads in 'dstId'.
-  auto *dstNode = mdg->getNode(dstId);
+  auto *dstNode = mdg.getNode(dstId);
   DenseSet<Value> consumedMemrefs;
   for (Operation *load : dstNode->loads)
     consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
 
   // Traverse 'dstId' incoming edges and gather the nodes that contain a store
   // to one of the consumed memrefs.
-  for (auto &srcEdge : mdg->inEdges[dstId]) {
-    auto *srcNode = mdg->getNode(srcEdge.id);
+  for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
+    const auto *srcNode = mdg.getNode(srcEdge.id);
     // Skip if 'srcNode' is not a loop nest.
     if (!isa<AffineForOp>(srcNode->op))
       continue;
@@ -169,10 +170,10 @@ static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
 /// producer-consumer dependence between 'srcId' and 'dstId'.
 static void
 gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
-                              MemRefDependenceGraph *mdg,
+                              const MemRefDependenceGraph &mdg,
                               DenseSet<Value> &producerConsumerMemrefs) {
-  auto *dstNode = mdg->getNode(dstId);
-  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg.getNode(dstId);
+  auto *srcNode = mdg.getNode(srcId);
   gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
                                 producerConsumerMemrefs);
 }
@@ -214,14 +215,14 @@ static bool isEscapingMemref(Value memref, Block *block) {
 
 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
 /// that escape the block or are accessed in a non-affine way.
-static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
+static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
                                   DenseSet<Value> &escapingMemRefs) {
-  auto *node = mdg->getNode(id);
+  auto *node = mdg.getNode(id);
   for (Operation *storeOp : node->stores) {
     auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
     if (escapingMemRefs.count(memref))
       continue;
-    if (isEscapingMemref(memref, &mdg->block))
+    if (isEscapingMemref(memref, &mdg.block))
       escapingMemRefs.insert(memref);
   }
 }
@@ -826,7 +827,7 @@ struct GreedyFusion {
       // in 'srcIdCandidates'.
       dstNodeChanged = false;
       SmallVector<unsigned, 16> srcIdCandidates;
-      getProducerCandidates(dstId, mdg, srcIdCandidates);
+      getProducerCandidates(dstId, *mdg, srcIdCandidates);
 
       for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
         // Get 'srcNode' from which to attempt fusion into 'dstNode'.
@@ -841,7 +842,7 @@ struct GreedyFusion {
           continue;
 
         DenseSet<Value> producerConsumerMemrefs;
-        gatherProducerConsumerMemrefs(srcId, dstId, mdg,
+        gatherProducerConsumerMemrefs(srcId, dstId, *mdg,
                                       producerConsumerMemrefs);
 
         // Skip if 'srcNode' out edge count on any memref is greater than
@@ -856,7 +857,7 @@ struct GreedyFusion {
         // block (e.g., memref block arguments, returned memrefs,
         // memrefs passed to function calls, etc.).
         DenseSet<Value> srcEscapingMemRefs;
-        gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
+        gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs);
 
         // Compute an operation list insertion point for the fused loop
         // nest which preserves dependences.
@@ -950,7 +951,7 @@ struct GreedyFusion {
         // insertion point.
         bool removeSrcNode = canRemoveSrcNodeAfterFusion(
             srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
-            mdg);
+            *mdg);
 
         DenseSet<Value> privateMemrefs;
         for (Value memref : producerConsumerMemrefs) {


        


More information about the Mlir-commits mailing list