[Mlir-commits] [mlir] [MLIR][Affine] Make affine fusion MDG API const correct (PR #125994)
Uday Bondhugula
llvmlistbot at llvm.org
Mon Feb 10 07:45:03 PST 2025
https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/125994
>From 5650e0acfea676bd8b4e045fdc61cfabce5e3134 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Wed, 5 Feb 2025 13:19:40 +0530
Subject: [PATCH] [MLIR][Affine] Make affine fusion MDG API const correct
Make the affine fusion MDG API const-correct. NFC otherwise.
---
.../mlir/Dialect/Affine/Analysis/Utils.h | 37 +++++++---
mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 74 ++++++++++---------
2 files changed, 64 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..7355509807b4dbc 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -139,9 +139,11 @@ struct MemRefDependenceGraph {
// Map from node id to Node.
DenseMap<unsigned, Node> nodes;
- // Map from node id to list of input edges.
+ // Map from node id to list of input edges. The absence of an entry for a
+ // node is equivalent to the node having no input edges.
DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
- // Map from node id to list of output edges.
+ // Map from node id to list of output edges. The absence of an entry for a
+ // node is equivalent to the node having no output edges.
DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
// Map from memref to a count on the dependence edges associated with that
// memref.
@@ -156,10 +158,21 @@ struct MemRefDependenceGraph {
bool init();
// Returns the graph node for 'id'.
- Node *getNode(unsigned id);
+ const Node *getNode(unsigned id) const;
+ Node *getNode(unsigned id) {
+ return const_cast<Node *>(
+ static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
+ }
+
+ // Returns true if the graph has a node with ID `id`.
+ bool hasNode(unsigned id) const { return nodes.contains(id); }
// Returns the graph node for 'forOp'.
- Node *getForOpNode(AffineForOp forOp);
+ const Node *getForOpNode(AffineForOp forOp) const;
+ Node *getForOpNode(AffineForOp forOp) {
+ return const_cast<Node *>(
+ static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
+ }
// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned addNode(Operation *op);
@@ -169,12 +182,12 @@ struct MemRefDependenceGraph {
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
- bool writesToLiveInOrEscapingMemrefs(unsigned id);
+ bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
- bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+ bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +198,25 @@ struct MemRefDependenceGraph {
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
- bool hasDependencePath(unsigned srcId, unsigned dstId);
+ bool hasDependencePath(unsigned srcId, unsigned dstId) const;
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
- unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+ unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
- unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+ unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
/// Return all nodes which define SSA values used in node 'id'.
- void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+ void gatherDefiningNodes(unsigned id,
+ DenseSet<unsigned> &definingNodes) const;
// Computes and returns an insertion point operation, before which the
// the fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
- Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+ Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
+ unsigned dstId) const;
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..92e7667ff2c72f0 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
/// Add `op` to MDG creating a new node and adding its memory accesses (affine
/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
-Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
- DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
+static Node *
+addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
+ DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
auto &nodes = mdg.nodes;
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
@@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
}
// Returns the graph node for 'id'.
-Node *MemRefDependenceGraph::getNode(unsigned id) {
+const Node *MemRefDependenceGraph::getNode(unsigned id) const {
auto it = nodes.find(id);
assert(it != nodes.end());
return &it->second;
}
// Returns the graph node for 'forOp'.
-Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
for (auto &idAndNode : nodes)
if (idAndNode.second.op == forOp)
return &idAndNode.second;
@@ -389,7 +390,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
}
}
// Remove each edge in 'outEdges[id]'.
- if (outEdges.count(id) > 0) {
+ if (outEdges.contains(id)) {
SmallVector<Edge, 2> oldOutEdges = outEdges[id];
for (auto &outEdge : oldOutEdges) {
removeEdge(id, outEdge.id, outEdge.value);
@@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
-bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
- Node *node = getNode(id);
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
+ const Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
auto *op = memref.getDefiningOp();
@@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
- Value value) {
- if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
+ Value value) const {
+ if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
return false;
}
- bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+ bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
return edge.id == dstId && (!value || edge.value == value);
});
- bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+ bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
return edge.id == srcId && (!value || edge.value == value);
});
return hasOutEdge && hasInEdge;
@@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
-bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
+ unsigned dstId) const {
// Worklist state is: <node-id, next-output-edge-index-to-visit>
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
worklist.push_back({srcId, 0});
@@ -489,13 +491,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
return true;
// Pop and continue if node has no out edges, or if all out edges have
// already been visited.
- if (outEdges.count(idAndIndex.first) == 0 ||
- idAndIndex.second == outEdges[idAndIndex.first].size()) {
+ if (!outEdges.contains(idAndIndex.first) ||
+ idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
worklist.pop_back();
continue;
}
// Get graph edge to traverse.
- Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+ const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
// Increment next output edge index for 'idAndIndex'.
++idAndIndex.second;
// Add node at 'edge.id' to the worklist. We don't need to consider
@@ -511,34 +513,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
- Value memref) {
+ Value memref) const {
unsigned inEdgeCount = 0;
- if (inEdges.count(id) > 0)
- for (auto &inEdge : inEdges[id])
- if (inEdge.value == memref) {
- Node *srcNode = getNode(inEdge.id);
- // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
- if (srcNode->getStoreOpCount(memref) > 0)
- ++inEdgeCount;
- }
+ for (const Edge &inEdge : inEdges.lookup(id)) {
+ if (inEdge.value == memref) {
+ const Node *srcNode = getNode(inEdge.id);
+ // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+ if (srcNode->getStoreOpCount(memref) > 0)
+ ++inEdgeCount;
+ }
+ }
return inEdgeCount;
}
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
-unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
+ Value memref) const {
unsigned outEdgeCount = 0;
- if (outEdges.count(id) > 0)
- for (auto &outEdge : outEdges[id])
- if (!memref || outEdge.value == memref)
- ++outEdgeCount;
+ for (const auto &outEdge : outEdges.lookup(id))
+ if (!memref || outEdge.value == memref)
+ ++outEdgeCount;
return outEdgeCount;
}
/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
- unsigned id, DenseSet<unsigned> &definingNodes) {
- for (MemRefDependenceGraph::Edge edge : inEdges[id])
+ unsigned id, DenseSet<unsigned> &definingNodes) const {
+ for (const Edge &edge : inEdges.lookup(id))
// By definition of edge, if the edge value is a non-memref value,
// then the dependence is between a graph node which defines an SSA value
// and another graph node which uses the SSA value.
@@ -551,8 +553,8 @@ void MemRefDependenceGraph::gatherDefiningNodes(
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
- unsigned dstId) {
- if (outEdges.count(srcId) == 0)
+ unsigned dstId) const {
+ if (!outEdges.contains(srcId))
return getNode(dstId)->op;
// Skip if there is any defining node of 'dstId' that depends on 'srcId'.
@@ -568,13 +570,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
SmallPtrSet<Operation *, 2> srcDepInsts;
- for (auto &outEdge : outEdges[srcId])
+ for (auto &outEdge : outEdges.lookup(srcId))
if (outEdge.id != dstId)
srcDepInsts.insert(getNode(outEdge.id)->op);
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
SmallPtrSet<Operation *, 2> dstDepInsts;
- for (auto &inEdge : inEdges[dstId])
+ for (auto &inEdge : inEdges.lookup(dstId))
if (inEdge.id != srcId)
dstDepInsts.insert(getNode(inEdge.id)->op);
@@ -634,7 +636,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
for (auto &inEdge : oldInEdges) {
// Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
- if (privateMemRefs.count(inEdge.value) == 0)
+ if (!privateMemRefs.contains(inEdge.value))
addEdge(inEdge.id, dstId, inEdge.value);
}
}
More information about the Mlir-commits
mailing list