[Mlir-commits] [mlir] [WIP][ViewOpGraph] Improve GraphViz output (PR #125509)

Eric Hein llvmlistbot at llvm.org
Mon Feb 3 06:44:54 PST 2025


https://github.com/ehein6 created https://github.com/llvm/llvm-project/pull/125509

This patch improves the Graphviz output of ViewOpGraph (--view-op-graph).

See the before/after comparison below:

- Switch to rectangular record-based nodes.
- Add input and output ports for each operand and result.
- Remove edge labels.

The Graphviz documentation notes some limitations of record-based nodes and
recommends using HTML-like labels instead. I'm sticking with record-based
nodes for now since they are less verbose and I don't need any of the newer features,
but it might be worth switching to HTML-like labels in the future.

This visualization displays the short name of each operand along the top edge of each node.
Results (their short names and, optionally, their types) are displayed along the bottom edge of each node.
This differs somewhat from reading MLIR source code, where the result names come before the operation name.

But this style works better in the visualization for the following reasons:

1. Most ops have many operands but few results, so the operands need a short
representation, while the results can use a longer one.
2. Since a value has only one producer, we only need to print the full type in one place for each value.
3. The dataflow dependencies flow from top to bottom, so it makes sense to put the
operands on top and the results on the bottom.
4. The viewer's eye can naturally follow the edge upwards to match the value number to its type.

>From 1f327be80e5102f01b995910d1dfba1014bd0c28 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Fri, 31 Jan 2025 14:47:36 +0000
Subject: [PATCH 01/16] [WIP] Improve graphviz output

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index fa0af7665ba4c4c..920de8b22f1ac97 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -29,7 +29,7 @@ using namespace mlir;
 
 static const StringRef kLineStyleControlFlow = "dashed";
 static const StringRef kLineStyleDataFlow = "solid";
-static const StringRef kShapeNode = "ellipse";
+static const StringRef kShapeNode = "record";
 static const StringRef kShapeNone = "plain";
 
 /// Return the size limits for eliding large attributes.
@@ -59,6 +59,21 @@ static std::string quoteString(const std::string &str) {
   return "\"" + str + "\"";
 }
 
+/// For Graphviz record nodes:
+/// " Braces, vertical bars and angle brackets must be escaped with a backslash
+/// character if you wish them to appear as a literal character "
+static std::string escapeLabelString(const std::string &str) {
+  std::string buf;
+  llvm::raw_string_ostream os(buf);
+  for (char c : str) {
+    if (c == '{' || c == '|' || c == '<' || c == '}' || c == '>') {
+      os << "\\\\";
+    }
+    os << c;
+  }
+  return buf;
+}
+
 using AttributeMap = std::map<std::string, std::string>;
 
 namespace {
@@ -240,7 +255,8 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
                     StringRef background = "") {
     int nodeId = ++counter;
     AttributeMap attrs;
-    attrs["label"] = quoteString(escapeString(std::move(label)));
+    attrs["label"] =
+        quoteString(escapeString(escapeLabelString(std::move(label))));
     attrs["shape"] = shape.str();
     if (!background.empty()) {
       attrs["style"] = "filled";

>From b08a117695679d418cc56a305960b8c003567ec6 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Fri, 31 Jan 2025 13:26:14 -0500
Subject: [PATCH 02/16] Moving to record nodes and port-based edges.

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 43 ++++++++++++++++++-----------
 1 file changed, 27 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 920de8b22f1ac97..53e4bcfaa57f6e9 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -62,12 +62,12 @@ static std::string quoteString(const std::string &str) {
 /// For Graphviz record nodes:
 /// " Braces, vertical bars and angle brackets must be escaped with a backslash
 /// character if you wish them to appear as a literal character "
-static std::string escapeLabelString(const std::string &str) {
+std::string escapeLabelString(const std::string &str) {
   std::string buf;
   llvm::raw_string_ostream os(buf);
   for (char c : str) {
     if (c == '{' || c == '|' || c == '<' || c == '}' || c == '>') {
-      os << "\\\\";
+      os << '\\';
     }
     os << c;
   }
@@ -145,7 +145,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   void emitAllEdgeStmts() {
     if (printDataFlowEdges) {
       for (const auto &[value, node, label] : dataFlowEdges) {
-        emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow);
+        emitEdgeStmt(valueToNode[value], "", node, "", kLineStyleDataFlow);
       }
     }
 
@@ -219,14 +219,10 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
 
   /// Append an edge to the list of edges.
   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
-  void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
+  void emitEdgeStmt(Node n1, std::string outPort, Node n2, std::string inPort,
+                    StringRef style) {
     AttributeMap attrs;
     attrs["style"] = style.str();
-    // Do not label edges that start/end at a cluster boundary. Such edges are
-    // clipped at the boundary, but labels are not. This can lead to labels
-    // floating around without any edge next to them.
-    if (!n1.clusterId && !n2.clusterId)
-      attrs["label"] = quoteString(escapeString(std::move(label)));
     // Use `ltail` and `lhead` to draw edges between clusters.
     if (n1.clusterId)
       attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
@@ -234,7 +230,13 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
       attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
 
     edges.push_back(strFromOs([&](raw_ostream &os) {
-      os << llvm::format("v%i -> v%i ", n1.id, n2.id);
+      os << "v" << n1.id;
+      if (!outPort.empty())
+        os << ":" << outPort;
+      os << " -> ";
+      os << "v" << n2.id;
+      if (!inPort.empty())
+        os << ":" << inPort;
       emitAttrList(os, attrs);
     }));
   }
@@ -255,8 +257,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
                     StringRef background = "") {
     int nodeId = ++counter;
     AttributeMap attrs;
-    attrs["label"] =
-        quoteString(escapeString(escapeLabelString(std::move(label))));
+    attrs["label"] = quoteString(escapeString(std::move(label)));
     attrs["shape"] = shape.str();
     if (!background.empty()) {
       attrs["style"] = "filled";
@@ -271,6 +272,16 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   /// Generate a label for an operation.
   std::string getLabel(Operation *op) {
     return strFromOs([&](raw_ostream &os) {
+      os << "{{";
+      // Print operation inputs.
+      interleave(
+          op->getOperands(), os,
+          [&](Value operand) {
+            OpPrintingFlags flags;
+            operand.printAsOperand(os, flags);
+          },
+          "|");
+      os << "}|";
       // Print operation name and type.
       os << op->getName();
       if (printResultTypes) {
@@ -289,6 +300,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
           emitMlirAttr(os, attr.getValue());
         }
       }
+      os << "}";
     });
   }
 
@@ -301,16 +313,15 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   /// operation inside the cluster.
   void processBlock(Block &block) {
     emitClusterStmt([&]() {
-      for (BlockArgument &blockArg : block.getArguments())
+      for (BlockArgument &blockArg : block.getArguments()) {
         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
-
+      }
       // Emit a node for each operation.
       std::optional<Node> prevNode;
       for (Operation &op : block) {
         Node nextNode = processOperation(&op);
         if (printControlFlowEdges && prevNode)
-          emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
-                       kLineStyleControlFlow);
+          emitEdgeStmt(*prevNode, "", nextNode, "", kLineStyleControlFlow);
         prevNode = nextNode;
       }
     });

>From 0803fa6cd0aa63ef01c3f4c5ae2f71966b1070de Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Fri, 31 Jan 2025 14:44:13 -0500
Subject: [PATCH 03/16] Fix input port linking

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 54 +++++++++++++++++++++--------
 1 file changed, 39 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 53e4bcfaa57f6e9..88ec22d88109525 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -94,6 +94,13 @@ struct Node {
   std::optional<int> clusterId;
 };
 
+struct DataFlowEdge {
+  Value value;
+  std::string outPort;
+  Node node;
+  std::string inPort;
+};
+
 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
 /// about the Graphviz DOT language.
@@ -144,8 +151,9 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   /// emitted.
   void emitAllEdgeStmts() {
     if (printDataFlowEdges) {
-      for (const auto &[value, node, label] : dataFlowEdges) {
-        emitEdgeStmt(valueToNode[value], "", node, "", kLineStyleDataFlow);
+      for (const auto &e : dataFlowEdges) {
+        emitEdgeStmt(valueToNode[e.value], e.outPort, e.node, e.inPort,
+                     kLineStyleDataFlow);
       }
     }
 
@@ -269,19 +277,33 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     return Node(nodeId);
   }
 
+  std::string getOperandPortName(Value operand) {
+    // Print value as an operand and omit the leading '%' character.
+    return strFromOs([&](raw_ostream &os) {
+             operand.printAsOperand(os, OpPrintingFlags());
+           })
+        .substr(1, std::string::npos);
+  }
+
   /// Generate a label for an operation.
   std::string getLabel(Operation *op) {
     return strFromOs([&](raw_ostream &os) {
-      os << "{{";
+      os << "{";
+
       // Print operation inputs.
-      interleave(
-          op->getOperands(), os,
-          [&](Value operand) {
-            OpPrintingFlags flags;
-            operand.printAsOperand(os, flags);
-          },
-          "|");
-      os << "}|";
+      if (op->getNumOperands() > 0) {
+        os << "{";
+        interleave(
+            op->getOperands(), os,
+            [&](Value operand) {
+              os << "<";
+              os << getOperandPortName(operand);
+              os << "> ";
+              operand.printAsOperand(os, OpPrintingFlags());
+            },
+            "|");
+        os << "}|";
+      }
       // Print operation name and type.
       os << op->getName();
       if (printResultTypes) {
@@ -347,9 +369,11 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     // Insert data flow edges originating from each operand.
     if (printDataFlowEdges) {
       unsigned numOperands = op->getNumOperands();
-      for (unsigned i = 0; i < numOperands; i++)
-        dataFlowEdges.push_back({op->getOperand(i), node,
-                                 numOperands == 1 ? "" : std::to_string(i)});
+      for (unsigned i = 0; i < numOperands; i++) {
+        auto operand = op->getOperand(i);
+        auto inPort = getOperandPortName(operand);
+        dataFlowEdges.push_back({operand, "", node, inPort});
+      }
     }
 
     for (Value result : op->getResults())
@@ -379,7 +403,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   /// Mapping of SSA values to Graphviz nodes/clusters.
   DenseMap<Value, Node> valueToNode;
   /// Output for data flow edges is delayed until the end to handle cycles
-  std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
+  std::vector<DataFlowEdge> dataFlowEdges;
   /// Counter for generating unique node/subgraph identifiers.
   int counter = 0;
 

>From 04586df2a729f6a35b1a4ecaab298ee0278e9a18 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Fri, 31 Jan 2025 15:04:46 -0500
Subject: [PATCH 04/16] Fix result type printing

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 88ec22d88109525..bf8e2a845751b89 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -265,7 +265,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
                     StringRef background = "") {
     int nodeId = ++counter;
     AttributeMap attrs;
-    attrs["label"] = quoteString(escapeString(std::move(label)));
+    attrs["label"] = quoteString(label);
     attrs["shape"] = shape.str();
     if (!background.empty()) {
       attrs["style"] = "filled";
@@ -306,12 +306,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
       }
       // Print operation name and type.
       os << op->getName();
-      if (printResultTypes) {
-        os << " : (";
+      if (printResultTypes && op->getNumResults() > 0) {
+        os << "|{";
         std::string buf;
         llvm::raw_string_ostream ss(buf);
-        interleaveComma(op->getResultTypes(), ss);
-        os << truncateString(buf) << ")";
+        interleave(
+            op->getResultTypes(), ss,
+            [&](Type type) {
+              ss << escapeLabelString(
+                  strFromOs([&](raw_ostream &os) { os << type; }));
+            },
+            "|");
+        // TODO: how to truncate string without breaking the layout?
+        os << buf << "}";
       }
 
       // Print attributes.

>From 0fe91630485696735a52564fef945f85afcbbb00 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Fri, 31 Jan 2025 15:22:33 -0500
Subject: [PATCH 05/16] Fix attribute printing

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index bf8e2a845751b89..2ddefe971f38a19 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -66,7 +66,7 @@ std::string escapeLabelString(const std::string &str) {
   std::string buf;
   llvm::raw_string_ostream os(buf);
   for (char c : str) {
-    if (c == '{' || c == '|' || c == '<' || c == '}' || c == '>') {
+    if (c == '{' || c == '|' || c == '<' || c == '}' || c == '>' || c == '\n') {
       os << '\\';
     }
     os << c;
@@ -222,7 +222,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     std::string buf;
     llvm::raw_string_ostream ss(buf);
     attr.print(ss);
-    os << truncateString(buf);
+    os << escapeLabelString(truncateString(buf));
   }
 
   /// Append an edge to the list of edges.
@@ -306,6 +306,16 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
       }
       // Print operation name and type.
       os << op->getName();
+
+      // Print attributes.
+      if (printAttrs && !op->getAttrs().empty()) {
+        os << "\\n";
+        for (const NamedAttribute &attr : op->getAttrs()) {
+          os << "\\n" << attr.getName().getValue() << ": ";
+          emitMlirAttr(os, attr.getValue());
+        }
+      }
+
       if (printResultTypes && op->getNumResults() > 0) {
         os << "|{";
         std::string buf;
@@ -321,14 +331,6 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
         os << buf << "}";
       }
 
-      // Print attributes.
-      if (printAttrs) {
-        os << "\n";
-        for (const NamedAttribute &attr : op->getAttrs()) {
-          os << '\n' << attr.getName().getValue() << ": ";
-          emitMlirAttr(os, attr.getValue());
-        }
-      }
       os << "}";
     });
   }

>From 96ec4fe799b18479b309e9476d73c129dee12e82 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Fri, 31 Jan 2025 15:41:02 -0500
Subject: [PATCH 06/16] Switch to Mrecord shape (rounded rectangle)

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 2ddefe971f38a19..7493eab592f8ec9 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -29,7 +29,7 @@ using namespace mlir;
 
 static const StringRef kLineStyleControlFlow = "dashed";
 static const StringRef kLineStyleDataFlow = "solid";
-static const StringRef kShapeNode = "record";
+static const StringRef kShapeNode = "Mrecord";
 static const StringRef kShapeNone = "plain";
 
 /// Return the size limits for eliding large attributes.

>From 5094ab1700002a9b6fca63d6ac19305c61d2f5ec Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Fri, 31 Jan 2025 15:41:27 -0500
Subject: [PATCH 07/16] Handle more special characters in node labels

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 7493eab592f8ec9..246aa88a045b581 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -65,8 +65,9 @@ static std::string quoteString(const std::string &str) {
 std::string escapeLabelString(const std::string &str) {
   std::string buf;
   llvm::raw_string_ostream os(buf);
+  llvm::DenseSet<char> shouldEscape = {'{', '|', '<', '}', '>', '\n', '"'};
   for (char c : str) {
-    if (c == '{' || c == '|' || c == '<' || c == '}' || c == '>' || c == '\n') {
+    if (shouldEscape.contains(c)) {
       os << '\\';
     }
     os << c;
@@ -279,10 +280,13 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
 
   std::string getOperandPortName(Value operand) {
     // Print value as an operand and omit the leading '%' character.
-    return strFromOs([&](raw_ostream &os) {
-             operand.printAsOperand(os, OpPrintingFlags());
-           })
-        .substr(1, std::string::npos);
+    auto str = strFromOs([&](raw_ostream &os) {
+      operand.printAsOperand(os, OpPrintingFlags());
+    });
+    // Replace % and # with _
+    std::replace(str.begin(), str.end(), '%', '_');
+    std::replace(str.begin(), str.end(), '#', '_');
+    return str;
   }
 
   /// Generate a label for an operation.

>From bbd7b257137c8d53d54de8ed8bf52a015b49ae0d Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Sat, 1 Feb 2025 14:30:55 -0500
Subject: [PATCH 08/16] Fix output port names

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 246aa88a045b581..27d7131543c454f 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -320,19 +320,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
         }
       }
 
-      if (printResultTypes && op->getNumResults() > 0) {
+      if (op->getNumResults() > 0) {
         os << "|{";
-        std::string buf;
-        llvm::raw_string_ostream ss(buf);
         interleave(
-            op->getResultTypes(), ss,
-            [&](Type type) {
-              ss << escapeLabelString(
-                  strFromOs([&](raw_ostream &os) { os << type; }));
+            op->getResults(), os,
+            [&](Value result) {
+              os << "<" << getOperandPortName(result) << "> ";
+              if (printResultTypes)
+                os << truncateString(escapeLabelString(strFromOs(
+                    [&](raw_ostream &os) { os << result.getType(); })));
             },
             "|");
         // TODO: how to truncate string without breaking the layout?
-        os << buf << "}";
+        os << "}";
       }
 
       os << "}";
@@ -385,7 +385,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
       for (unsigned i = 0; i < numOperands; i++) {
         auto operand = op->getOperand(i);
         auto inPort = getOperandPortName(operand);
-        dataFlowEdges.push_back({operand, "", node, inPort});
+        dataFlowEdges.push_back({operand, inPort, node, inPort});
       }
     }
 

>From 89f780c56b0b23768cf37833179369912fac1302 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Sat, 1 Feb 2025 15:16:34 -0500
Subject: [PATCH 09/16] Clean up port generation code

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 37 +++++++++++++----------------
 1 file changed, 16 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 27d7131543c454f..af8095583de5a88 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -278,7 +278,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     return Node(nodeId);
   }
 
-  std::string getOperandPortName(Value operand) {
+  std::string getValuePortName(Value operand) {
     // Print value as an operand and omit the leading '%' character.
     auto str = strFromOs([&](raw_ostream &os) {
       operand.printAsOperand(os, OpPrintingFlags());
@@ -297,15 +297,11 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
       // Print operation inputs.
       if (op->getNumOperands() > 0) {
         os << "{";
-        interleave(
-            op->getOperands(), os,
-            [&](Value operand) {
-              os << "<";
-              os << getOperandPortName(operand);
-              os << "> ";
-              operand.printAsOperand(os, OpPrintingFlags());
-            },
-            "|");
+        auto operandToPort = [&](Value operand) {
+          os << "<" << getValuePortName(operand) << "> ";
+          operand.printAsOperand(os, OpPrintingFlags());
+        };
+        interleave(op->getOperands(), os, operandToPort, "|");
         os << "}|";
       }
       // Print operation name and type.
@@ -322,16 +318,15 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
 
       if (op->getNumResults() > 0) {
         os << "|{";
-        interleave(
-            op->getResults(), os,
-            [&](Value result) {
-              os << "<" << getOperandPortName(result) << "> ";
-              if (printResultTypes)
-                os << truncateString(escapeLabelString(strFromOs(
-                    [&](raw_ostream &os) { os << result.getType(); })));
-            },
-            "|");
-        // TODO: how to truncate string without breaking the layout?
+        auto resultToPort = [&](Value result) {
+          os << "<" << getValuePortName(result) << "> ";
+          if (printResultTypes)
+            os << truncateString(escapeLabelString(
+                strFromOs([&](raw_ostream &os) { os << result.getType(); })));
+          else
+            result.printAsOperand(os, OpPrintingFlags());
+        };
+        interleave(op->getResults(), os, resultToPort, "|");
         os << "}";
       }
 
@@ -384,7 +379,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
       unsigned numOperands = op->getNumOperands();
       for (unsigned i = 0; i < numOperands; i++) {
         auto operand = op->getOperand(i);
-        auto inPort = getOperandPortName(operand);
+        auto inPort = getValuePortName(operand);
         dataFlowEdges.push_back({operand, inPort, node, inPort});
       }
     }

>From 5bfa8a273137c4a45966db782a06d55ff11aad94 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Sat, 1 Feb 2025 15:40:45 -0500
Subject: [PATCH 10/16] Left-justify the label text.

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index af8095583de5a88..3bc6a460d01ed7a 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -305,14 +305,16 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
         os << "}|";
       }
       // Print operation name and type.
-      os << op->getName();
+      os << op->getName() << "\\l";
 
       // Print attributes.
       if (printAttrs && !op->getAttrs().empty()) {
-        os << "\\n";
+        // Extra line break to separate attributes from the operation name.
+        os << "\\l";
         for (const NamedAttribute &attr : op->getAttrs()) {
-          os << "\\n" << attr.getName().getValue() << ": ";
+          os << attr.getName().getValue() << ": ";
           emitMlirAttr(os, attr.getValue());
+          os << "\\l";
         }
       }
 

>From dadfa4389bfbdfcfbf3d9a0009a0adba1a8c9b07 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Sat, 1 Feb 2025 15:41:09 -0500
Subject: [PATCH 11/16] Always print result short name.

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 3bc6a460d01ed7a..3576ed601843707 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -322,11 +322,11 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
         os << "|{";
         auto resultToPort = [&](Value result) {
           os << "<" << getValuePortName(result) << "> ";
+          result.printAsOperand(os, OpPrintingFlags());
           if (printResultTypes)
-            os << truncateString(escapeLabelString(
-                strFromOs([&](raw_ostream &os) { os << result.getType(); })));
-          else
-            result.printAsOperand(os, OpPrintingFlags());
+            os << " "
+               << truncateString(escapeLabelString(strFromOs(
+                      [&](raw_ostream &os) { os << result.getType(); })));
         };
         interleave(op->getResults(), os, resultToPort, "|");
         os << "}";

>From 9afd8cf3b34bf85215e4108f6ec2db41f34d8120 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Sat, 1 Feb 2025 15:58:31 -0500
Subject: [PATCH 12/16] Fix edge attachment points.

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 3576ed601843707..ac9ce6b0c641cfe 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -241,11 +241,11 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     edges.push_back(strFromOs([&](raw_ostream &os) {
       os << "v" << n1.id;
       if (!outPort.empty())
-        os << ":" << outPort;
+        os << ":" << outPort << ":s";
       os << " -> ";
       os << "v" << n2.id;
       if (!inPort.empty())
-        os << ":" << inPort;
+        os << ":" << inPort << ":n";
       emitAttrList(os, attrs);
     }));
   }

>From 744eb5f6c93a7d1c4ca7ba48cb45db378b1ebd55 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Sun, 2 Feb 2025 13:30:47 -0500
Subject: [PATCH 13/16] Merge in and out ports into a single port attribute.

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index ac9ce6b0c641cfe..43949ca289590bd 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -97,9 +97,8 @@ struct Node {
 
 struct DataFlowEdge {
   Value value;
-  std::string outPort;
   Node node;
-  std::string inPort;
+  std::string port;
 };
 
 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
@@ -153,8 +152,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   void emitAllEdgeStmts() {
     if (printDataFlowEdges) {
       for (const auto &e : dataFlowEdges) {
-        emitEdgeStmt(valueToNode[e.value], e.outPort, e.node, e.inPort,
-                     kLineStyleDataFlow);
+        emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
       }
     }
 
@@ -228,8 +226,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
 
   /// Append an edge to the list of edges.
   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
-  void emitEdgeStmt(Node n1, std::string outPort, Node n2, std::string inPort,
-                    StringRef style) {
+  void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
     AttributeMap attrs;
     attrs["style"] = style.str();
     // Use `ltail` and `lhead` to draw edges between clusters.
@@ -240,12 +237,14 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
 
     edges.push_back(strFromOs([&](raw_ostream &os) {
       os << "v" << n1.id;
-      if (!outPort.empty())
-        os << ":" << outPort << ":s";
+      if (!port.empty())
+        // Attach edge to south compass point of the result
+        os << ":" << port << ":s";
       os << " -> ";
       os << "v" << n2.id;
-      if (!inPort.empty())
-        os << ":" << inPort << ":n";
+      if (!port.empty())
+        // Attach edge to north compass point of the operand
+        os << ":" << port << ":n";
       emitAttrList(os, attrs);
     }));
   }
@@ -353,7 +352,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
       for (Operation &op : block) {
         Node nextNode = processOperation(&op);
         if (printControlFlowEdges && prevNode)
-          emitEdgeStmt(*prevNode, "", nextNode, "", kLineStyleControlFlow);
+          emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
         prevNode = nextNode;
       }
     });
@@ -381,8 +380,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
       unsigned numOperands = op->getNumOperands();
       for (unsigned i = 0; i < numOperands; i++) {
         auto operand = op->getOperand(i);
-        auto inPort = getValuePortName(operand);
-        dataFlowEdges.push_back({operand, inPort, node, inPort});
+        dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
       }
     }
 

>From 904193fd50e80877c0b9a03ccb6bd4eb6e599682 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Sun, 2 Feb 2025 21:16:49 -0500
Subject: [PATCH 14/16] Fix formatting of cluster labels

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 38 +++++++++++++++++++++--------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 43949ca289590bd..2fef6aa4d9a082d 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -49,11 +49,6 @@ static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
   return buf;
 }
 
-/// Escape special characters such as '\n' and quotation marks.
-static std::string escapeString(std::string str) {
-  return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
-}
-
 /// Put quotation marks around a given string.
 static std::string quoteString(const std::string &str) {
   return "\"" + str + "\"";
@@ -169,8 +164,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     os.indent();
     // Emit invisible anchor node from/to which arrows can be drawn.
     Node anchorNode = emitNodeStmt(" ", kShapeNone);
-    os << attrStmt("label", quoteString(escapeString(std::move(label))))
-       << ";\n";
+    os << attrStmt("label", quoteString(label)) << ";\n";
     builder();
     os.unindent();
     os << "}\n";
@@ -288,8 +282,32 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     return str;
   }
 
+  std::string getClusterLabel(Operation *op) {
+    return strFromOs([&](raw_ostream &os) {
+      // Print operation name and type.
+      os << op->getName();
+      if (printResultTypes) {
+        os << " : (";
+        std::string buf;
+        llvm::raw_string_ostream ss(buf);
+        interleaveComma(op->getResultTypes(), ss);
+        os << truncateString(buf) << ")";
+      }
+
+      // Print attributes.
+      if (printAttrs) {
+        os << "\\l";
+        for (const NamedAttribute &attr : op->getAttrs()) {
+          os << attr.getName().getValue() << ": ";
+          emitMlirAttr(os, attr.getValue());
+          os << "\\l";
+        }
+      }
+    });
+  }
+
   /// Generate a label for an operation.
-  std::string getLabel(Operation *op) {
+  std::string getRecordLabel(Operation *op) {
     return strFromOs([&](raw_ostream &os) {
       os << "{";
 
@@ -369,9 +387,9 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
             for (Region &region : op->getRegions())
               processRegion(region);
           },
-          getLabel(op));
+          getClusterLabel(op));
     } else {
-      node = emitNodeStmt(getLabel(op), kShapeNode,
+      node = emitNodeStmt(getRecordLabel(op), kShapeNode,
                           backgroundColors[op->getName()].second);
     }
 

>From 291350c4a7da2930194a0d5aa3072e2333b1eabf Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Sun, 2 Feb 2025 21:32:26 -0500
Subject: [PATCH 15/16] Print result types in block arguments

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 2fef6aa4d9a082d..abf0d8e36bc5ca7 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -355,7 +355,14 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
 
   /// Generate a label for a block argument.
   std::string getLabel(BlockArgument arg) {
-    return "arg" + std::to_string(arg.getArgNumber());
+    return strFromOs([&](raw_ostream &os) {
+      os << "<" << getValuePortName(arg) << "> ";
+      arg.printAsOperand(os, OpPrintingFlags());
+      if (printResultTypes)
+        os << " "
+           << truncateString(escapeLabelString(
+                  strFromOs([&](raw_ostream &os) { os << arg.getType(); })));
+    });
   }
 
   /// Process a block. Emit a cluster and one node per block argument and

>From b709ff7e6fb1b0ad1dabeb3a29e9526d6d46a316 Mon Sep 17 00:00:00 2001
From: Eric Hein <ehein at modular.com>
Date: Mon, 3 Feb 2025 09:13:16 -0500
Subject: [PATCH 16/16] Improved node fill color choices

---
 mlir/lib/Transforms/ViewOpGraph.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index abf0d8e36bc5ca7..065b4da568b588d 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -124,7 +124,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
 private:
   /// Generate a color mapping that will color every operation with the same
   /// name the same way. It'll interpolate the hue in the HSV color-space,
-  /// attempting to keep the contrast suitable for black text.
+  /// using muted colors that provide good contrast for black text.
   template <typename T>
   void initColorMapping(T &irEntity) {
     backgroundColors.clear();
@@ -137,8 +137,10 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     });
     for (auto indexedOps : llvm::enumerate(ops)) {
       double hue = ((double)indexedOps.index()) / ops.size();
+      // Use a lower saturation (0.3) and a higher value (0.95) for better
+      // readability of black text.
       backgroundColors[indexedOps.value()->getName()].second =
-          std::to_string(hue) + " 1.0 1.0";
+          std::to_string(hue) + " 0.3 0.95";
     }
   }
 
@@ -263,7 +265,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     attrs["shape"] = shape.str();
     if (!background.empty()) {
       attrs["style"] = "filled";
-      attrs["fillcolor"] = ("\"" + background + "\"").str();
+      attrs["fillcolor"] = quoteString(background.str());
     }
     os << llvm::format("v%i ", nodeId);
     emitAttrList(os, attrs);



More information about the Mlir-commits mailing list