[Mlir-commits] [mlir] [MLIR][Affine] Fix fusion in the presence of cyclic deps in source nests (PR #128397)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 24 01:10:29 PST 2025


================
@@ -28,10 +28,138 @@
 #include <optional>
 #include <type_traits>
 
+#define DEBUG_TYPE "affine-loop-analysis"
+
 using namespace mlir;
 using namespace mlir::affine;
 
-#define DEBUG_TYPE "affine-loop-analysis"
+namespace {
+
+/// A directed graph to model relationships between MLIR Operations.
+class DirectedOpGraph {
+public:
+  /// Add a node for `op` to the graph; `op` must not already be present.
+  /// An empty out-edge list is created for it so that later lookups of `op`
+  /// in `edges` (e.g. during depth-first traversal) always find an entry.
+  /// NOTE(review): `addEdge` stores pointers to elements of `nodes`; if
+  /// `nodes` is a growable vector, calling this after edges have been added
+  /// may invalidate those pointers — confirm all nodes are added first.
+  void addNode(Operation *op) {
+    assert(!hasNode(op) && "node already added");
+    nodes.emplace_back(op);
+    edges[op] = {};
+  }
+
+  /// Add a directed edge from `src` to `dest`. Both operations must already
+  /// have been added as nodes. Parallel edges are permitted (multi-graph):
+  /// adding the same (`src`, `dest`) pair twice records two edges.
+  void addEdge(Operation *src, Operation *dest) {
+    assert(hasNode(src) && "src node does not exist in graph");
+    assert(hasNode(dest) && "dest node does not exist in graph");
+    DGNode *target = getNode(dest);
+    edges[src].push_back(target);
+  }
+
+  /// Returns true if the graph contains at least one directed cycle, using a
+  /// depth-first traversal that stops at the first cycle found.
+  bool hasCycle() {
+    const bool foundCycle = dfsImpl(/*cycleCheck=*/true);
+    return foundCycle;
+  }
+
+  /// Debugging aid: for every entry in `edges`, print the source operation and
+  /// its address, the number of out-edges, and then the operation at the far
+  /// end of each out-edge (one per line) to `llvm::dbgs()`.
+  void printEdges() {
+    for (auto &entry : edges) {
+      auto *src = entry.first;
+      const auto &targets = entry.second;
+      llvm::dbgs() << *src << " (" << src << ")"
+                   << " has " << targets.size() << " edges:\n";
+      for (auto *target : targets)
+        llvm::dbgs() << '\t' << *target->op << '\n';
+    }
+  }
+
+private:
+  /// A node of a directed graph between MLIR Operations to model various
+  /// relationships. This is meant to be used internally.
+  struct DGNode {
+    DGNode(Operation *op) : op(op) {}
+
+    /// The operation this node stands for.
+    Operation *op;
+
+    // Start and finish visit numbers are standard in DFS to implement things
+    // like strongly connected components. These numbers are modified during
+    // analyses on the graph and so seemingly const API methods will be
+    // non-const.
+
+    /// Start visit number: assigned when a depth-first traversal first reaches
+    /// this node. `dfsImpl` resets it to 0, meaning "not yet visited".
+    int vn = -1;
+
+    /// Finish visit number: assigned once all successors of this node have
+    /// been explored. A value of -1 on a started node means its visit is still
+    /// in progress, which the cycle check treats as a back edge.
+    int fn = -1;
+  };
+
+  /// Returns the internal node that wraps `op`. `op` must already have been
+  /// added to the graph (asserted).
+  DGNode *getNode(Operation *op) {
+    auto it = llvm::find_if(
+        nodes, [op](const DGNode &candidate) { return candidate.op == op; });
+    assert(it != nodes.end() && "node doesn't exist in graph");
+    return &*it;
+  }
+
+  /// Returns true if some node of the graph wraps the operation `key`.
+  /// Performs a linear scan over all nodes.
+  bool hasNode(Operation *key) const {
+    for (const DGNode &candidate : nodes)
+      if (candidate.op == key)
+        return true;
+    return false;
+  }
+
+  /// Perform a depth-first traversal of the whole graph, assigning start
+  /// (`vn`) and finish (`fn`) visit numbers to the nodes. If `cycleCheck` is
+  /// set, returns true as soon as the first cycle is detected, and false if
+  /// there are no cycles. If `cycleCheck` is not set, completes the DFS and the
+  /// return value doesn't have a meaning.
+  bool dfsImpl(bool cycleCheck = false) {
+    // Reset BOTH visit numbers. `dfsNode` reports a cycle when it reaches a
+    // node that was started (`vn != 0`) but not finished (`fn == -1`). If `fn`
+    // kept its value from a previous traversal (e.g. `hasCycle()` called
+    // again after more edges were added), such back edges would be missed.
+    for (DGNode &node : nodes) {
+      node.vn = 0;
+      node.fn = -1;
+    }
+
+    unsigned time = 0;
+    for (DGNode &node : nodes) {
+      if (node.vn == 0) {
+        bool ret = dfsNode(node, cycleCheck, time);
+        // Check if a cycle was already found.
+        if (cycleCheck && ret)
+          return true;
+      } else if (cycleCheck && node.fn == -1) {
+        // We have encountered a node whose visit has started but it's not
+        // finished. So we have a cycle.
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /// Perform depth-first traversal starting at `node`. Return true
+  /// as soon as a cycle is found if `cycleCheck` was set. Update `time`.
+  bool dfsNode(DGNode &node, bool cycleCheck, unsigned &time) const {
+    auto nodeEdges = edges.find(node.op);
+    assert(nodeEdges != edges.end() && "missing node in graph");
+    // Depth first search from a given vertex.
+    ++time;
+    node.vn = time;
+
+    for (auto &neighbour : nodeEdges->second) {
+      if (neighbour->vn == 0) {
+        bool ret = dfsNode(*neighbour, cycleCheck, time);
+        if (cycleCheck && ret)
+          return true;
+      } else if (cycleCheck && neighbour->fn == -1) {
+        // We have encountered a node whose visit has started but it's not
+        // finished. So we have a cycle.
+        return true;
+      }
+    }
+
+    ++time;
+    // Update finish time.
+    node.fn = time;
----------------
patel-vimal wrote:

Should this be simplified to `node.fn = ++time;`?

https://github.com/llvm/llvm-project/pull/128397


More information about the Mlir-commits mailing list