[Mlir-commits] [mlir] [MLIR][Affine] NFC. Move misplaced MDG init method (PR #71665)

Uday Bondhugula llvmlistbot at llvm.org
Wed Nov 8 03:50:56 PST 2023


https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/71665

>From b85709cb78dbdcf4a416f1c78448a8ed47f1ac4d Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Wed, 8 Nov 2023 17:11:55 +0530
Subject: [PATCH] [MLIR][Affine] NFC. Move misplaced MDG init method

MemRefDependenceGraph::init should have been in the affine analysis utils,
since MemRefDependenceGraph is part of the affine analysis library; it was
missed when the rest of the class was moved there. Move it. NFC.
---
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp    | 123 ++++++++++++++++++
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 122 -----------------
 2 files changed, 123 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index ce3ff0a095770c1..23921b700b6669b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -20,6 +20,8 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -104,6 +106,127 @@ void Node::getLoadAndStoreMemrefSet(
   }
 }
 
+// Initializes the data dependence graph by walking operations in `block`.
+// Assigns node ids in program order; returns false on unsupported region ops.
+bool MemRefDependenceGraph::init() {
+  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
+  // Map from a memref to the ids of the nodes accessing it. SetVector keeps
+  // insertion order, i.e. ids in program order (relied on for edge direction).
+  DenseMap<Value, SetVector<unsigned>> memrefAccesses;
+  // Map from each top-level affine.for op to the id of its graph node.
+  DenseMap<Operation *, unsigned> forToNodeMap;
+  for (Operation &op : block) {
+    if (dyn_cast<AffineForOp>(op)) {
+      // Create a graph node for the top-level 'affine.for' and record the
+      // load/store ops that LoopNestStateCollector gathered from it.
+      LoopNestStateCollector collector;
+      collector.collect(&op);
+      // Return false if a region holding op other than 'affine.for' and
+      // 'affine.if' was found (not currently supported).
+      if (collector.hasNonAffineRegionOp)
+        return false;
+      Node node(nextNodeId++, &op);
+      for (auto *opInst : collector.loadOpInsts) {
+        node.loads.push_back(opInst);
+        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
+        memrefAccesses[memref].insert(node.id);
+      }
+      for (auto *opInst : collector.storeOpInsts) {
+        node.stores.push_back(opInst);
+        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
+        memrefAccesses[memref].insert(node.id);
+      }
+      forToNodeMap[&op] = node.id;
+      nodes.insert({node.id, node});
+    } else if (dyn_cast<AffineReadOpInterface>(op)) {
+      // Create a graph node for a top-level affine read (load) op.
+      Node node(nextNodeId++, &op);
+      node.loads.push_back(&op);
+      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
+      memrefAccesses[memref].insert(node.id);
+      nodes.insert({node.id, node});
+    } else if (dyn_cast<AffineWriteOpInterface>(op)) {
+      // Create a graph node for a top-level affine write (store) op.
+      Node node(nextNodeId++, &op);
+      node.stores.push_back(&op);
+      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
+      memrefAccesses[memref].insert(node.id);
+      nodes.insert({node.id, node});
+    } else if (op.getNumRegions() != 0) {
+      // Any other op with regions is not currently supported; bail out.
+      return false;
+    } else if (op.getNumResults() > 0 && !op.use_empty()) {
+      // Create a graph node for a top-level op whose results are used; these
+      // values may feed loop nest nodes (SSA edges are added below).
+      Node node(nextNodeId++, &op);
+      nodes.insert({node.id, node});
+    } else if (isa<CallOpInterface>(op)) {
+      // Create a graph node for a top-level call taking any memref operand.
+      // Calls whose results have uses were already handled by the SSA
+      // producer branch above; other calls get no node here.
+      if (llvm::any_of(op.getOperandTypes(),
+                       [&](Type t) { return isa<MemRefType>(t); })) {
+        Node node(nextNodeId++, &op);
+        nodes.insert({node.id, node});
+      }
+    } else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) {
+      // Create a graph node for a top-level op that may have a memory write
+      // or free side effect.
+      Node node(nextNodeId++, &op);
+      nodes.insert({node.id, node});
+    }
+  }
+
+  for (auto &idAndNode : nodes) {
+    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
+                            << *(idAndNode.second.op) << "\n");
+    (void)idAndNode; // Avoid an unused warning if LLVM_DEBUG is compiled out.
+  }
+
+  // Add SSA edges from each value-producing node to the top-level loop
+  // nests that use its results. Load ops count as value producers.
+  for (auto &idAndNode : nodes) {
+    const Node &node = idAndNode.second;
+    // Stores define no SSA values; note this also skips any nest that stores.
+    if (!node.stores.empty())
+      continue;
+    Operation *opInst = node.op;
+    for (Value value : opInst->getResults()) {
+      for (Operation *user : value.getUsers()) {
+        // Ignore users outside of the block.
+        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
+            &block)
+          continue;
+        SmallVector<AffineForOp, 4> loops; // NOTE(review): loops[0] is taken as
+        getAffineForIVs(*user, &loops);    // the user's top-level nest in
+        if (loops.empty())                 // `block`; confirm this still holds
+          continue;                        // when `block` is nested in a loop.
+        assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
+        unsigned userLoopNestId = forToNodeMap[loops[0]];
+        addEdge(node.id, userLoopNestId, value);
+      }
+    }
+  }
+
+  // Add memref edges between accessor pairs where at least one side stores.
+  for (auto &memrefAndList : memrefAccesses) {
+    unsigned n = memrefAndList.second.size();
+    for (unsigned i = 0; i < n; ++i) {
+      unsigned srcId = memrefAndList.second[i];
+      bool srcHasStore =
+          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
+      for (unsigned j = i + 1; j < n; ++j) {
+        unsigned dstId = memrefAndList.second[j];
+        bool dstHasStore =
+            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
+        if (srcHasStore || dstHasStore) // Read-read pairs need no edge.
+          addEdge(srcId, dstId, memrefAndList.first);
+      }
+    }
+  }
+  return true;
+}
+
 // Returns the graph node for 'id'.
 Node *MemRefDependenceGraph::getNode(unsigned id) {
   auto it = nodes.find(id);
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index d85dfc3e25c4e39..fda0156437478b1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -27,7 +27,6 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -230,127 +229,6 @@ static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
   }
 }
 
-// Initializes the data dependence graph by walking operations in `block`.
-// Assigns each node in the graph a node id based on program order in 'f'.
-bool MemRefDependenceGraph::init() {
-  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
-  // Map from a memref to the set of ids of the nodes that have ops accessing
-  // the memref.
-  DenseMap<Value, SetVector<unsigned>> memrefAccesses;
-
-  DenseMap<Operation *, unsigned> forToNodeMap;
-  for (Operation &op : block) {
-    if (dyn_cast<AffineForOp>(op)) {
-      // Create graph node 'id' to represent top-level 'forOp' and record
-      // all loads and store accesses it contains.
-      LoopNestStateCollector collector;
-      collector.collect(&op);
-      // Return false if a region holding op other than 'affine.for' and
-      // 'affine.if' was found (not currently supported).
-      if (collector.hasNonAffineRegionOp)
-        return false;
-      Node node(nextNodeId++, &op);
-      for (auto *opInst : collector.loadOpInsts) {
-        node.loads.push_back(opInst);
-        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
-        memrefAccesses[memref].insert(node.id);
-      }
-      for (auto *opInst : collector.storeOpInsts) {
-        node.stores.push_back(opInst);
-        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
-        memrefAccesses[memref].insert(node.id);
-      }
-      forToNodeMap[&op] = node.id;
-      nodes.insert({node.id, node});
-    } else if (dyn_cast<AffineReadOpInterface>(op)) {
-      // Create graph node for top-level load op.
-      Node node(nextNodeId++, &op);
-      node.loads.push_back(&op);
-      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
-      memrefAccesses[memref].insert(node.id);
-      nodes.insert({node.id, node});
-    } else if (dyn_cast<AffineWriteOpInterface>(op)) {
-      // Create graph node for top-level store op.
-      Node node(nextNodeId++, &op);
-      node.stores.push_back(&op);
-      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
-      memrefAccesses[memref].insert(node.id);
-      nodes.insert({node.id, node});
-    } else if (op.getNumRegions() != 0) {
-      // Return false if another region is found (not currently supported).
-      return false;
-    } else if (op.getNumResults() > 0 && !op.use_empty()) {
-      // Create graph node for top-level producer of SSA values, which
-      // could be used by loop nest nodes.
-      Node node(nextNodeId++, &op);
-      nodes.insert({node.id, node});
-    } else if (isa<CallOpInterface>(op)) {
-      // Create graph node for top-level Call Op that takes any argument of
-      // memref type. Call Op that returns one or more memref type results
-      // is already taken care of, by the previous conditions.
-      if (llvm::any_of(op.getOperandTypes(),
-                       [&](Type t) { return isa<MemRefType>(t); })) {
-        Node node(nextNodeId++, &op);
-        nodes.insert({node.id, node});
-      }
-    } else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) {
-      // Create graph node for top-level op, which could have a memory write
-      // side effect.
-      Node node(nextNodeId++, &op);
-      nodes.insert({node.id, node});
-    }
-  }
-
-  for (auto &idAndNode : nodes) {
-    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
-                            << *(idAndNode.second.op) << "\n");
-    (void)idAndNode;
-  }
-
-  // Add dependence edges between nodes which produce SSA values and their
-  // users. Load ops can be considered as the ones producing SSA values.
-  for (auto &idAndNode : nodes) {
-    const Node &node = idAndNode.second;
-    // Stores don't define SSA values, skip them.
-    if (!node.stores.empty())
-      continue;
-    Operation *opInst = node.op;
-    for (Value value : opInst->getResults()) {
-      for (Operation *user : value.getUsers()) {
-        // Ignore users outside of the block.
-        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
-            &block)
-          continue;
-        SmallVector<AffineForOp, 4> loops;
-        getAffineForIVs(*user, &loops);
-        if (loops.empty())
-          continue;
-        assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
-        unsigned userLoopNestId = forToNodeMap[loops[0]];
-        addEdge(node.id, userLoopNestId, value);
-      }
-    }
-  }
-
-  // Walk memref access lists and add graph edges between dependent nodes.
-  for (auto &memrefAndList : memrefAccesses) {
-    unsigned n = memrefAndList.second.size();
-    for (unsigned i = 0; i < n; ++i) {
-      unsigned srcId = memrefAndList.second[i];
-      bool srcHasStore =
-          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
-      for (unsigned j = i + 1; j < n; ++j) {
-        unsigned dstId = memrefAndList.second[j];
-        bool dstHasStore =
-            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
-        if (srcHasStore || dstHasStore)
-          addEdge(srcId, dstId, memrefAndList.first);
-      }
-    }
-  }
-  return true;
-}
-
 // Sinks all sequential loops to the innermost levels (while preserving
 // relative order among them) and moves all parallel loops to the
 // outermost (while again preserving relative order among them).



More information about the Mlir-commits mailing list