[Mlir-commits] [mlir] WIP OpGraph fix (PR #67143)

Vladyslav Moisieienkov llvmlistbot at llvm.org
Fri Sep 22 07:07:44 PDT 2023


https://github.com/VMois created https://github.com/llvm/llvm-project/pull/67143

None

>From ce090bcc1edd1dc7b942e47c28880e2e989e1ebc Mon Sep 17 00:00:00 2001
From: VMois <vmois at protonmail.com>
Date: Tue, 25 Apr 2023 20:45:15 +0200
Subject: [PATCH] WIP OpGraph fix

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 69 +++++++++++++++++++++++++++--
 1 file changed, 66 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 4598b56a901d537..3e7296460d7dd5d 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -18,6 +18,7 @@
 #include <map>
 #include <optional>
 #include <utility>
+#include <iostream>
 
 namespace mlir {
 #define GEN_PASS_DEF_VIEWOPGRAPH
@@ -250,10 +251,33 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   /// Process a block. Emit a cluster and one node per block argument and
   /// operation inside the cluster.
   void processBlock(Block &block) {
+    //std::cout << "Emit cluster process block" << std::endl;
     emitClusterStmt([&]() {
       for (BlockArgument &blockArg : block.getArguments())
         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
 
+      // for (Operation &op : block) {
+      //   //if (op.getNumRegions() > 0) continue;
+      //   //std::string label = getLabel(&op);
+      //   Node node = emitNodeStmt(getLabel(&op));
+      //   for (Value result : op.getResults())
+      //     valueToNode[result] = node;
+        
+      //   for (Value result : op.getOperands())
+      //     valueToNode[result] = node;
+      //   //std::cout << "Creating node " << node.id << std::endl;
+      //   opToNode[&op] = node;
+      // }
+      
+      // for (Operation &op : block) {
+      //   Node node = emitNodeStmt(getLabel(&op));
+      //   for (Value result : op.getResults())
+      //     valueToNode[result] = node;
+        
+      //   for (Value result : op.getOperands())
+      //     valueToNode[result] = node;
+      // }
+
       // Emit a node for each operation.
       std::optional<Node> prevNode;
       for (Operation &op : block) {
@@ -279,16 +303,31 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
           },
           getLabel(op));
     } else {
-      node = emitNodeStmt(getLabel(op));
+      //if (opToNode.count(op) == 1) {
+      //node = opToNode[op];
+      //} else {
+      node = emitNodeStmt(opToName[op]);
+      //std::cout << "Using node " << node.id << std::endl;
     }
 
     // Insert data flow edges originating from each operand.
     if (printDataFlowEdges) {
       unsigned numOperands = op->getNumOperands();
-      for (unsigned i = 0; i < numOperands; i++)
-        emitEdgeStmt(valueToNode[op->getOperand(i)], node,
+      for (unsigned i = 0; i < numOperands; i++) {
+        Value operand = op->getOperand(i);
+        Node operandNode;
+        if (valueToNode.count(operand) == 0) {
+          // no Node created, created one according to pre-specified name
+          Operation* definingOp = operand.getDefiningOp();
+          operandNode = emitNodeStmt(opToName[definingOp]);
+          valueToNode[operand] = operandNode;
+        } else {
+          operandNode = valueToNode[operand];
+        }
+        emitEdgeStmt(operandNode, node,
                      /*label=*/numOperands == 1 ? "" : std::to_string(i),
                      kLineStyleDataFlow);
+      }
     }
 
     for (Value result : op->getResults())
@@ -299,6 +338,27 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
 
   /// Process a region.
   void processRegion(Region &region) {
+    // Pre-pass: record a label for every operation in this region's blocks
+    // before any block is processed. When an operand has no entry in
+    // valueToNode yet, a node for it is emitted on demand from the label
+    // recorded here for its defining op.
+    //
+    // NOTE(review): getDefiningOp() should return null for block arguments,
+    // and opToName[nullptr] would default-insert and use an empty label;
+    // confirm a block argument can never reach that on-demand path.
+    // NOTE(review): the on-demand node is distinct from the node emitted
+    // later for the defining op itself (emitNodeStmt(opToName[op])), so the
+    // same op may appear twice in the graph; confirm and deduplicate.
+    // TODO(review): remove the unused <iostream> include, the unused
+    // opToNode member, and the commented-out experiments elsewhere in
+    // this patch before landing.
+
+    for (Block &block : region.getBlocks()) {
+      for (Operation &op : block) {
+        opToName[&op] = getLabel(&op);
+      }
+    }
+
     for (Block &block : region.getBlocks())
       processBlock(block);
   }
@@ -317,6 +377,9 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   std::vector<std::string> edges;
   /// Mapping of SSA values to Graphviz nodes/clusters.
   DenseMap<Value, Node> valueToNode;
+  DenseMap<Operation *, Node> opToNode; // NOTE(review): only used in comments.
+  DenseMap<Operation *, std::string> opToName; // Labels set by processRegion.
+  // NOTE(review): opToName[...] default-inserts "" for unrecorded ops.
   /// Counter for generating unique node/subgraph identifiers.
   int counter = 0;
 };



More information about the Mlir-commits mailing list