[Mlir-commits] [mlir] [MLIR][Affine] Make affine fusion MDG API const correct (PR #125994)

Uday Bondhugula llvmlistbot at llvm.org
Mon Feb 10 07:45:03 PST 2025


https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/125994

>From 5650e0acfea676bd8b4e045fdc61cfabce5e3134 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Wed, 5 Feb 2025 13:19:40 +0530
Subject: [PATCH] [MLIR][Affine] Make affine fusion MDG API const correct

Make affine fusion MDG API const correct. NFC changes otherwise.
---
 .../mlir/Dialect/Affine/Analysis/Utils.h      | 37 +++++++---
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp    | 74 ++++++++++---------
 2 files changed, 64 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..7355509807b4dbc 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -139,9 +139,11 @@ struct MemRefDependenceGraph {
 
   // Map from node id to Node.
   DenseMap<unsigned, Node> nodes;
-  // Map from node id to list of input edges.
+  // Map from node id to list of input edges. The absence of an entry for a
+  // node is also equivalent to the absence of any edges.
   DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
-  // Map from node id to list of output edges.
+  // Map from node id to list of output edges. The absence of an entry for a
+  // node is also equivalent to the absence of any edges.
   DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
   // Map from memref to a count on the dependence edges associated with that
   // memref.
@@ -156,10 +158,21 @@ struct MemRefDependenceGraph {
   bool init();
 
   // Returns the graph node for 'id'.
-  Node *getNode(unsigned id);
+  const Node *getNode(unsigned id) const;
+  Node *getNode(unsigned id) {
+    return const_cast<Node *>(
+        static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
+  }
+
+  // Returns true if the graph has a node with ID `id`.
+  bool hasNode(unsigned id) const { return nodes.contains(id); }
 
   // Returns the graph node for 'forOp'.
-  Node *getForOpNode(AffineForOp forOp);
+  const Node *getForOpNode(AffineForOp forOp) const;
+  Node *getForOpNode(AffineForOp forOp) {
+    return const_cast<Node *>(
+        static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
+  }
 
   // Adds a node with 'op' to the graph and returns its unique identifier.
   unsigned addNode(Operation *op);
@@ -169,12 +182,12 @@ struct MemRefDependenceGraph {
 
   // Returns true if node 'id' writes to any memref which escapes (or is an
   // argument to) the block. Returns false otherwise.
-  bool writesToLiveInOrEscapingMemrefs(unsigned id);
+  bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
 
   // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   // is for 'value' if non-null, or for any value otherwise. Returns false
   // otherwise.
-  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
 
   // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
   void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +198,25 @@ struct MemRefDependenceGraph {
   // Returns true if there is a path in the dependence graph from node 'srcId'
   // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
   // operations that the edges connected are expected to be from the same block.
-  bool hasDependencePath(unsigned srcId, unsigned dstId);
+  bool hasDependencePath(unsigned srcId, unsigned dstId) const;
 
   // Returns the input edge count for node 'id' and 'memref' from src nodes
   // which access 'memref' with a store operation.
-  unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
 
   // Returns the output edge count for node 'id' and 'memref' (if non-null),
   // otherwise returns the total output edge count from node 'id'.
-  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
 
   /// Return all nodes which define SSA values used in node 'id'.
-  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+  void gatherDefiningNodes(unsigned id,
+                           DenseSet<unsigned> &definingNodes) const;
 
   // Computes and returns an insertion point operation, before which the
   // the fused <srcId, dstId> loop nest can be inserted while preserving
   // dependences. Returns nullptr if no such insertion point is found.
-  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+  Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
+                                            unsigned dstId) const;
 
   // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
   // taking into account that:
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..92e7667ff2c72f0 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
 
 /// Add `op` to MDG creating a new node and adding its memory accesses (affine
 /// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
-Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
-                   DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
+static Node *
+addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
+             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
   auto &nodes = mdg.nodes;
   // Create graph node 'id' to represent top-level 'forOp' and record
   // all loads and store accesses it contains.
@@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
 }
 
 // Returns the graph node for 'id'.
-Node *MemRefDependenceGraph::getNode(unsigned id) {
+const Node *MemRefDependenceGraph::getNode(unsigned id) const {
   auto it = nodes.find(id);
   assert(it != nodes.end());
   return &it->second;
 }
 
 // Returns the graph node for 'forOp'.
-Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
   for (auto &idAndNode : nodes)
     if (idAndNode.second.op == forOp)
       return &idAndNode.second;
@@ -389,7 +390,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
     }
   }
   // Remove each edge in 'outEdges[id]'.
-  if (outEdges.count(id) > 0) {
+  if (outEdges.contains(id)) {
     SmallVector<Edge, 2> oldOutEdges = outEdges[id];
     for (auto &outEdge : oldOutEdges) {
       removeEdge(id, outEdge.id, outEdge.value);
@@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
 
 // Returns true if node 'id' writes to any memref which escapes (or is an
 // argument to) the block. Returns false otherwise.
-bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
-  Node *node = getNode(id);
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
+  const Node *node = getNode(id);
   for (auto *storeOpInst : node->stores) {
     auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
     auto *op = memref.getDefiningOp();
@@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
 // is for 'value' if non-null, or for any value otherwise. Returns false
 // otherwise.
 bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
-                                    Value value) {
-  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
+                                    Value value) const {
+  if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
     return false;
   }
-  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+  bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
     return edge.id == dstId && (!value || edge.value == value);
   });
-  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+  bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
     return edge.id == srcId && (!value || edge.value == value);
   });
   return hasOutEdge && hasInEdge;
@@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
 // Returns true if there is a path in the dependence graph from node 'srcId'
 // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
 // operations that the edges connected are expected to be from the same block.
-bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
+                                              unsigned dstId) const {
   // Worklist state is: <node-id, next-output-edge-index-to-visit>
   SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
   worklist.push_back({srcId, 0});
@@ -489,13 +491,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
       return true;
     // Pop and continue if node has no out edges, or if all out edges have
     // already been visited.
-    if (outEdges.count(idAndIndex.first) == 0 ||
-        idAndIndex.second == outEdges[idAndIndex.first].size()) {
+    if (!outEdges.contains(idAndIndex.first) ||
+        idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
       worklist.pop_back();
       continue;
     }
     // Get graph edge to traverse.
-    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+    const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
     // Increment next output edge index for 'idAndIndex'.
     ++idAndIndex.second;
     // Add node at 'edge.id' to the worklist. We don't need to consider
@@ -511,34 +513,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
 // Returns the input edge count for node 'id' and 'memref' from src nodes
 // which access 'memref' with a store operation.
 unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
-                                                          Value memref) {
+                                                          Value memref) const {
   unsigned inEdgeCount = 0;
-  if (inEdges.count(id) > 0)
-    for (auto &inEdge : inEdges[id])
-      if (inEdge.value == memref) {
-        Node *srcNode = getNode(inEdge.id);
-        // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
-        if (srcNode->getStoreOpCount(memref) > 0)
-          ++inEdgeCount;
-      }
+  for (const Edge &inEdge : inEdges.lookup(id)) {
+    if (inEdge.value == memref) {
+      const Node *srcNode = getNode(inEdge.id);
+      // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+      if (srcNode->getStoreOpCount(memref) > 0)
+        ++inEdgeCount;
+    }
+  }
   return inEdgeCount;
 }
 
 // Returns the output edge count for node 'id' and 'memref' (if non-null),
 // otherwise returns the total output edge count from node 'id'.
-unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
+                                                Value memref) const {
   unsigned outEdgeCount = 0;
-  if (outEdges.count(id) > 0)
-    for (auto &outEdge : outEdges[id])
-      if (!memref || outEdge.value == memref)
-        ++outEdgeCount;
+  for (const auto &outEdge : outEdges.lookup(id))
+    if (!memref || outEdge.value == memref)
+      ++outEdgeCount;
   return outEdgeCount;
 }
 
 /// Return all nodes which define SSA values used in node 'id'.
 void MemRefDependenceGraph::gatherDefiningNodes(
-    unsigned id, DenseSet<unsigned> &definingNodes) {
-  for (MemRefDependenceGraph::Edge edge : inEdges[id])
+    unsigned id, DenseSet<unsigned> &definingNodes) const {
+  for (const Edge &edge : inEdges.lookup(id))
     // By definition of edge, if the edge value is a non-memref value,
     // then the dependence is between a graph node which defines an SSA value
     // and another graph node which uses the SSA value.
@@ -551,8 +553,8 @@ void MemRefDependenceGraph::gatherDefiningNodes(
 // dependences. Returns nullptr if no such insertion point is found.
 Operation *
 MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
-                                                      unsigned dstId) {
-  if (outEdges.count(srcId) == 0)
+                                                      unsigned dstId) const {
+  if (!outEdges.contains(srcId))
     return getNode(dstId)->op;
 
   // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
@@ -568,13 +570,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
 
   // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
   SmallPtrSet<Operation *, 2> srcDepInsts;
-  for (auto &outEdge : outEdges[srcId])
+  for (auto &outEdge : outEdges.lookup(srcId))
     if (outEdge.id != dstId)
       srcDepInsts.insert(getNode(outEdge.id)->op);
 
   // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
   SmallPtrSet<Operation *, 2> dstDepInsts;
-  for (auto &inEdge : inEdges[dstId])
+  for (auto &inEdge : inEdges.lookup(dstId))
     if (inEdge.id != srcId)
       dstDepInsts.insert(getNode(inEdge.id)->op);
 
@@ -634,7 +636,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
     SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
     for (auto &inEdge : oldInEdges) {
       // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
-      if (privateMemRefs.count(inEdge.value) == 0)
+      if (!privateMemRefs.contains(inEdge.value))
         addEdge(inEdge.id, dstId, inEdge.value);
     }
   }



More information about the Mlir-commits mailing list