[Mlir-commits] [mlir] [MLIR][Affine] Make affine fusion MDG API const correct (PR #125994)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 5 19:48:36 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Uday Bondhugula (bondhugula)

<details>
<summary>Changes</summary>

Make the affine fusion MemRefDependenceGraph (MDG) API const-correct. There are no functional changes otherwise (NFC).


---
Full diff: https://github.com/llvm/llvm-project/pull/125994.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/Analysis/Utils.h (+20-9) 
- (modified) mlir/lib/Dialect/Affine/Analysis/Utils.cpp (+31-28) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..97cf29ce045ced6 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -156,10 +156,19 @@ struct MemRefDependenceGraph {
   bool init();
 
   // Returns the graph node for 'id'.
-  Node *getNode(unsigned id);
+  const Node *getNode(unsigned id) const;
+  Node *getNode(unsigned id) {
+    return const_cast<Node *>(
+        static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
+  }
+  bool hasNode(unsigned id) const { return nodes.contains(id); }
 
   // Returns the graph node for 'forOp'.
-  Node *getForOpNode(AffineForOp forOp);
+  const Node *getForOpNode(AffineForOp forOp) const;
+  Node *getForOpNode(AffineForOp forOp) {
+    return const_cast<Node *>(
+        static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
+  }
 
   // Adds a node with 'op' to the graph and returns its unique identifier.
   unsigned addNode(Operation *op);
@@ -169,12 +178,12 @@ struct MemRefDependenceGraph {
 
   // Returns true if node 'id' writes to any memref which escapes (or is an
   // argument to) the block. Returns false otherwise.
-  bool writesToLiveInOrEscapingMemrefs(unsigned id);
+  bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
 
   // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   // is for 'value' if non-null, or for any value otherwise. Returns false
   // otherwise.
-  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
 
   // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
   void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +194,25 @@ struct MemRefDependenceGraph {
   // Returns true if there is a path in the dependence graph from node 'srcId'
   // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
   // operations that the edges connected are expected to be from the same block.
-  bool hasDependencePath(unsigned srcId, unsigned dstId);
+  bool hasDependencePath(unsigned srcId, unsigned dstId) const;
 
   // Returns the input edge count for node 'id' and 'memref' from src nodes
   // which access 'memref' with a store operation.
-  unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
 
   // Returns the output edge count for node 'id' and 'memref' (if non-null),
   // otherwise returns the total output edge count from node 'id'.
-  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
 
   /// Return all nodes which define SSA values used in node 'id'.
-  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+  void gatherDefiningNodes(unsigned id,
+                           DenseSet<unsigned> &definingNodes) const;
 
   // Computes and returns an insertion point operation, before which the
   // the fused <srcId, dstId> loop nest can be inserted while preserving
   // dependences. Returns nullptr if no such insertion point is found.
-  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+  Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
+                                            unsigned dstId) const;
 
   // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
   // taking into account that:
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..54bb529041e0914 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
 
 /// Add `op` to MDG creating a new node and adding its memory accesses (affine
 /// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
-Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
-                   DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
+static Node *
+addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
+             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
   auto &nodes = mdg.nodes;
   // Create graph node 'id' to represent top-level 'forOp' and record
   // all loads and store accesses it contains.
@@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
 }
 
 // Returns the graph node for 'id'.
-Node *MemRefDependenceGraph::getNode(unsigned id) {
+const Node *MemRefDependenceGraph::getNode(unsigned id) const {
   auto it = nodes.find(id);
   assert(it != nodes.end());
   return &it->second;
 }
 
 // Returns the graph node for 'forOp'.
-Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
   for (auto &idAndNode : nodes)
     if (idAndNode.second.op == forOp)
       return &idAndNode.second;
@@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
 
 // Returns true if node 'id' writes to any memref which escapes (or is an
 // argument to) the block. Returns false otherwise.
-bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
-  Node *node = getNode(id);
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
+  const Node *node = getNode(id);
   for (auto *storeOpInst : node->stores) {
     auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
     auto *op = memref.getDefiningOp();
@@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
 // is for 'value' if non-null, or for any value otherwise. Returns false
 // otherwise.
 bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
-                                    Value value) {
+                                    Value value) const {
   if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
     return false;
   }
-  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+  bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](Edge &edge) {
     return edge.id == dstId && (!value || edge.value == value);
   });
-  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+  bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](Edge &edge) {
     return edge.id == srcId && (!value || edge.value == value);
   });
   return hasOutEdge && hasInEdge;
@@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
 // Returns true if there is a path in the dependence graph from node 'srcId'
 // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
 // operations that the edges connected are expected to be from the same block.
-bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
+                                              unsigned dstId) const {
   // Worklist state is: <node-id, next-output-edge-index-to-visit>
   SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
   worklist.push_back({srcId, 0});
@@ -490,12 +492,12 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
     // Pop and continue if node has no out edges, or if all out edges have
     // already been visited.
     if (outEdges.count(idAndIndex.first) == 0 ||
-        idAndIndex.second == outEdges[idAndIndex.first].size()) {
+        idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
       worklist.pop_back();
       continue;
     }
     // Get graph edge to traverse.
-    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+    const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
     // Increment next output edge index for 'idAndIndex'.
     ++idAndIndex.second;
     // Add node at 'edge.id' to the worklist. We don't need to consider
@@ -511,25 +513,26 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
 // Returns the input edge count for node 'id' and 'memref' from src nodes
 // which access 'memref' with a store operation.
 unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
-                                                          Value memref) {
+                                                          Value memref) const {
   unsigned inEdgeCount = 0;
-  if (inEdges.count(id) > 0)
-    for (auto &inEdge : inEdges[id])
-      if (inEdge.value == memref) {
-        Node *srcNode = getNode(inEdge.id);
-        // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
-        if (srcNode->getStoreOpCount(memref) > 0)
-          ++inEdgeCount;
-      }
+  for (const MemRefDependenceGraph::Edge &inEdge : inEdges.lookup(id)) {
+    if (inEdge.value == memref) {
+      const Node *srcNode = getNode(inEdge.id);
+      // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+      if (srcNode->getStoreOpCount(memref) > 0)
+        ++inEdgeCount;
+    }
+  }
   return inEdgeCount;
 }
 
 // Returns the output edge count for node 'id' and 'memref' (if non-null),
 // otherwise returns the total output edge count from node 'id'.
-unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
+                                                Value memref) const {
   unsigned outEdgeCount = 0;
   if (outEdges.count(id) > 0)
-    for (auto &outEdge : outEdges[id])
+    for (auto &outEdge : outEdges.lookup(id))
       if (!memref || outEdge.value == memref)
         ++outEdgeCount;
   return outEdgeCount;
@@ -537,8 +540,8 @@ unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
 
 /// Return all nodes which define SSA values used in node 'id'.
 void MemRefDependenceGraph::gatherDefiningNodes(
-    unsigned id, DenseSet<unsigned> &definingNodes) {
-  for (MemRefDependenceGraph::Edge edge : inEdges[id])
+    unsigned id, DenseSet<unsigned> &definingNodes) const {
+  for (MemRefDependenceGraph::Edge edge : inEdges.lookup(id))
     // By definition of edge, if the edge value is a non-memref value,
     // then the dependence is between a graph node which defines an SSA value
     // and another graph node which uses the SSA value.
@@ -551,7 +554,7 @@ void MemRefDependenceGraph::gatherDefiningNodes(
 // dependences. Returns nullptr if no such insertion point is found.
 Operation *
 MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
-                                                      unsigned dstId) {
+                                                      unsigned dstId) const {
   if (outEdges.count(srcId) == 0)
     return getNode(dstId)->op;
 
@@ -568,13 +571,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
 
   // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
   SmallPtrSet<Operation *, 2> srcDepInsts;
-  for (auto &outEdge : outEdges[srcId])
+  for (auto &outEdge : outEdges.lookup(srcId))
     if (outEdge.id != dstId)
       srcDepInsts.insert(getNode(outEdge.id)->op);
 
   // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
   SmallPtrSet<Operation *, 2> dstDepInsts;
-  for (auto &inEdge : inEdges[dstId])
+  for (auto &inEdge : inEdges.lookup(dstId))
     if (inEdge.id != srcId)
       dstDepInsts.insert(getNode(inEdge.id)->op);
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/125994


More information about the Mlir-commits mailing list