[Mlir-commits] [mlir] [MLIR][Affine] Make affine fusion MDG API const correct (PR #125994)
Uday Bondhugula
llvmlistbot at llvm.org
Wed Feb 5 19:47:34 PST 2025
https://github.com/bondhugula created https://github.com/llvm/llvm-project/pull/125994
Make the affine fusion MDG (MemRefDependenceGraph) API const-correct. No functional changes otherwise (NFC).
From aa07c72c900a109c4a2ad468f370977c03519772 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Wed, 5 Feb 2025 13:19:40 +0530
Subject: [PATCH] [MLIR][Affine] Make affine fusion MDG API const correct
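Make the affine fusion MDG (MemRefDependenceGraph) API const-correct: query methods are now `const`, non-inserting map lookups replace `DenseMap::operator[]`, and a `hasNode` query is added. No functional changes otherwise (NFC).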
---
.../mlir/Dialect/Affine/Analysis/Utils.h | 29 ++++++---
mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 59 ++++++++++---------
2 files changed, 51 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..97cf29ce045ced6 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -156,10 +156,20 @@ struct MemRefDependenceGraph {
bool init();
// Returns the graph node for 'id'.
- Node *getNode(unsigned id);
+ const Node *getNode(unsigned id) const;
+ Node *getNode(unsigned id) {
+ return const_cast<Node *>(
+ static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
+ }
+ // Returns true if the graph contains a node with 'id'.
+ bool hasNode(unsigned id) const { return nodes.contains(id); }
// Returns the graph node for 'forOp'.
- Node *getForOpNode(AffineForOp forOp);
+ const Node *getForOpNode(AffineForOp forOp) const;
+ Node *getForOpNode(AffineForOp forOp) {
+ return const_cast<Node *>(
+ static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
+ }
// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned addNode(Operation *op);
@@ -169,12 +179,12 @@ struct MemRefDependenceGraph {
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
- bool writesToLiveInOrEscapingMemrefs(unsigned id);
+ bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
- bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+ bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +195,25 @@ struct MemRefDependenceGraph {
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
- bool hasDependencePath(unsigned srcId, unsigned dstId);
+ bool hasDependencePath(unsigned srcId, unsigned dstId) const;
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
- unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+ unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
- unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+ unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
/// Return all nodes which define SSA values used in node 'id'.
- void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+ void gatherDefiningNodes(unsigned id,
+ DenseSet<unsigned> &definingNodes) const;
// Computes and returns an insertion point operation, before which the
// the fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
- Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+ Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
+ unsigned dstId) const;
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..54bb529041e0914 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
/// Add `op` to MDG creating a new node and adding its memory accesses (affine
/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
-Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
- DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
+static Node *
+addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
+ DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
auto &nodes = mdg.nodes;
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
@@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
}
// Returns the graph node for 'id'.
-Node *MemRefDependenceGraph::getNode(unsigned id) {
+const Node *MemRefDependenceGraph::getNode(unsigned id) const {
auto it = nodes.find(id);
assert(it != nodes.end());
return &it->second;
}
// Returns the graph node for 'forOp'.
-Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
for (auto &idAndNode : nodes)
if (idAndNode.second.op == forOp)
return &idAndNode.second;
@@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
-bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
- Node *node = getNode(id);
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
+ const Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
auto *op = memref.getDefiningOp();
@@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
- Value value) {
+ Value value) const {
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
return false;
}
- bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+ bool hasOutEdge = llvm::any_of(outEdges.at(srcId), [=](const Edge &edge) {
return edge.id == dstId && (!value || edge.value == value);
});
- bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+ bool hasInEdge = llvm::any_of(inEdges.at(dstId), [=](const Edge &edge) {
return edge.id == srcId && (!value || edge.value == value);
});
return hasOutEdge && hasInEdge;
@@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
-bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
+ unsigned dstId) const {
// Worklist state is: <node-id, next-output-edge-index-to-visit>
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
worklist.push_back({srcId, 0});
@@ -490,12 +492,12 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
// Pop and continue if node has no out edges, or if all out edges have
// already been visited.
if (outEdges.count(idAndIndex.first) == 0 ||
- idAndIndex.second == outEdges[idAndIndex.first].size()) {
+ idAndIndex.second == outEdges.at(idAndIndex.first).size()) {
worklist.pop_back();
continue;
}
// Get graph edge to traverse.
- Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+ const Edge edge = outEdges.at(idAndIndex.first)[idAndIndex.second];
// Increment next output edge index for 'idAndIndex'.
++idAndIndex.second;
// Add node at 'edge.id' to the worklist. We don't need to consider
@@ -511,25 +513,26 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
- Value memref) {
+ Value memref) const {
unsigned inEdgeCount = 0;
- if (inEdges.count(id) > 0)
- for (auto &inEdge : inEdges[id])
- if (inEdge.value == memref) {
- Node *srcNode = getNode(inEdge.id);
- // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
- if (srcNode->getStoreOpCount(memref) > 0)
- ++inEdgeCount;
- }
+ for (const MemRefDependenceGraph::Edge &inEdge : inEdges.lookup(id)) {
+ if (inEdge.value == memref) {
+ const Node *srcNode = getNode(inEdge.id);
+ // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+ if (srcNode->getStoreOpCount(memref) > 0)
+ ++inEdgeCount;
+ }
+ }
return inEdgeCount;
}
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
-unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
+ Value memref) const {
unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0)
- for (auto &outEdge : outEdges[id])
+ for (const Edge &outEdge : outEdges.at(id))
if (!memref || outEdge.value == memref)
++outEdgeCount;
return outEdgeCount;
@@ -537,8 +540,8 @@ unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
- unsigned id, DenseSet<unsigned> &definingNodes) {
- for (MemRefDependenceGraph::Edge edge : inEdges[id])
+ unsigned id, DenseSet<unsigned> &definingNodes) const {
+ for (MemRefDependenceGraph::Edge edge : inEdges.lookup(id))
// By definition of edge, if the edge value is a non-memref value,
// then the dependence is between a graph node which defines an SSA value
// and another graph node which uses the SSA value.
@@ -551,7 +554,7 @@ void MemRefDependenceGraph::gatherDefiningNodes(
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
- unsigned dstId) {
+ unsigned dstId) const {
if (outEdges.count(srcId) == 0)
return getNode(dstId)->op;
@@ -568,13 +571,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
SmallPtrSet<Operation *, 2> srcDepInsts;
- for (auto &outEdge : outEdges[srcId])
+ for (const Edge &outEdge : outEdges.at(srcId))
if (outEdge.id != dstId)
srcDepInsts.insert(getNode(outEdge.id)->op);
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
SmallPtrSet<Operation *, 2> dstDepInsts;
- for (auto &inEdge : inEdges[dstId])
+ for (auto &inEdge : inEdges.lookup(dstId))
if (inEdge.id != srcId)
dstDepInsts.insert(getNode(inEdge.id)->op);
More information about the Mlir-commits
mailing list