[Mlir-commits] [mlir] 8d15b7d - [mlir] Improve Graphviz visualization in PrintOpPass

Matthias Springer llvmlistbot at llvm.org
Tue Aug 3 19:56:53 PDT 2021


Author: Matthias Springer
Date: 2021-08-04T11:56:26+09:00
New Revision: 8d15b7dcbaa1469d7e147ebdce988cca861ace6d

URL: https://github.com/llvm/llvm-project/commit/8d15b7dcbaa1469d7e147ebdce988cca861ace6d
DIFF: https://github.com/llvm/llvm-project/commit/8d15b7dcbaa1469d7e147ebdce988cca861ace6d.diff

LOG: [mlir] Improve Graphviz visualization in PrintOpPass

* Visualize blocks and regions as subgraphs.
* Generate DOT file directly instead of using `GraphTraits`. `GraphTraits` does not support subgraphs.

Differential Revision: https://reviews.llvm.org/D106253

Added: 
    

Modified: 
    mlir/include/mlir/Support/IndentedOstream.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/include/mlir/Transforms/ViewOpGraph.h
    mlir/lib/Transforms/CMakeLists.txt
    mlir/lib/Transforms/ViewOpGraph.cpp
    mlir/test/Transforms/print-op-graph.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/IndentedOstream.h b/mlir/include/mlir/Support/IndentedOstream.h
index 6d0701607f469..9a755bc7ebb08 100644
--- a/mlir/include/mlir/Support/IndentedOstream.h
+++ b/mlir/include/mlir/Support/IndentedOstream.h
@@ -45,6 +45,9 @@ class raw_indented_ostream : public raw_ostream {
     llvm::StringRef open, close;
   };
 
+  /// Returns the underlying (unindented) raw_ostream.
+  raw_ostream &getOStream() const { return os; }
+
   /// Returns DelimitedScope.
   DelimitedScope scope(StringRef open = "", StringRef close = "") {
     return DelimitedScope(*this, open, close);

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index ecd60de1104ed..324f06454a49d 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -688,21 +688,19 @@ def SymbolDCE : Pass<"symbol-dce"> {
   let constructor = "mlir::createSymbolDCEPass()";
 }
 
-def ViewOpGraphPass : Pass<"view-op-graph", "ModuleOp"> {
-  let summary = "Print graphviz view of module";
+def ViewOpGraphPass : Pass<"view-op-graph"> {
+  let summary = "Print Graphviz dataflow visualization of an operation";
   let description = [{
-    This pass prints a graphviz per block of a module.
+    This pass prints a Graphviz dataflow graph of a module.
 
-    - Op are represented as nodes;
+    - Operations are represented as nodes;
     - Uses as edges;
+    - Regions/blocks as subgraphs.
+
+    Note: See https://www.graphviz.org/doc/info/lang.html for more information
+    about the Graphviz DOT language.
   }];
   let constructor = "mlir::createPrintOpGraphPass()";
-  let options = [
-    Option<"title", "title", "std::string",
-           /*default=*/"", "The prefix of the title of the graph">,
-    Option<"shortNames", "short-names", "bool", /*default=*/"false",
-           "Use short names">
-  ];
 }
 
 #endif // MLIR_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/Transforms/ViewOpGraph.h b/mlir/include/mlir/Transforms/ViewOpGraph.h
index 61f40358fec21..ec1b6c281e877 100644
--- a/mlir/include/mlir/Transforms/ViewOpGraph.h
+++ b/mlir/include/mlir/Transforms/ViewOpGraph.h
@@ -14,27 +14,14 @@
 #define MLIR_TRANSFORMS_VIEWOPGRAPH_H_
 
 #include "mlir/Support/LLVM.h"
-#include "llvm/Support/GraphWriter.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
-class Block;
-class ModuleOp;
-template <typename T> class OperationPass;
-
-/// Displays the graph in a window. This is for use from the debugger and
-/// depends on Graphviz to generate the graph.
-void viewGraph(Block &block, const Twine &name, bool shortNames = false,
-               const Twine &title = "",
-               llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
-
-raw_ostream &writeGraph(raw_ostream &os, Block &block, bool shortNames = false,
-                        const Twine &title = "");
+class Pass;
 
 /// Creates a pass to print op graphs.
-std::unique_ptr<OperationPass<ModuleOp>>
-createPrintOpGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
-                       const Twine &title = "");
+std::unique_ptr<Pass>
+createPrintOpGraphPass(raw_ostream &os = llvm::errs());
 
 } // end namespace mlir
 

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index ee67953430d5e..5563b70f379b9 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -39,6 +39,7 @@ add_mlir_library(MLIRTransforms
   MLIRMemRef
   MLIRSCF
   MLIRPass
+  MLIRSupportIndentedOstream
   MLIRTransformUtils
   MLIRVector
   )

diff  --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 3d52d79b7ef7b..86c5725b8910d 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -9,12 +9,16 @@
 #include "mlir/Transforms/ViewOpGraph.h"
 #include "PassDetail.h"
 #include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
-#include "llvm/Support/CommandLine.h"
+#include "mlir/Support/IndentedOstream.h"
+#include "llvm/Support/Format.h"
 
 using namespace mlir;
 
+static const StringRef kLineStyleDataFlow = "solid";
+static const StringRef kShapeNode = "ellipse";
+static const StringRef kShapeNone = "plain";
+
 /// Return the size limits for eliding large attributes.
 static int64_t getLargeAttributeSizeLimit() {
   // Use the default from the printer flags if possible.
@@ -23,145 +27,251 @@ static int64_t getLargeAttributeSizeLimit() {
   return 16;
 }
 
-namespace llvm {
+/// Return all values printed onto a stream as a string.
+static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
+  std::string buf;
+  llvm::raw_string_ostream os(buf);
+  func(os);
+  return os.str();
+}
+
+/// Escape special characters such as '\n' and quotation marks.
+static std::string escapeString(std::string str) {
+  return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
+}
+
+/// Put quotation marks around a given string.
+static std::string quoteString(std::string str) { return "\"" + str + "\""; }
 
-// Specialize GraphTraits to treat Block as a graph of Operations as nodes and
-// uses as edges.
-template <> struct GraphTraits<Block *> {
-  using GraphType = Block *;
-  using NodeRef = Operation *;
+using AttributeMap = llvm::StringMap<std::string>;
 
-  using ChildIteratorType = Operation::user_iterator;
-  static ChildIteratorType child_begin(NodeRef n) { return n->user_begin(); }
-  static ChildIteratorType child_end(NodeRef n) { return n->user_end(); }
+namespace {
+
+/// This struct represents a node in the DOT language. Each node has an
+/// identifier and an optional identifier for the cluster (subgraph) that
+/// contains the node.
+/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
+/// not between clusters. However, edges can be clipped to the boundary of a
+/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
+/// cluster, an invisible "anchor" node is created.
+struct Node {
+public:
+  Node(int id = 0, Optional<int> clusterId = llvm::None)
+      : id(id), clusterId(clusterId) {}
 
-  // Operation's destructor is private so use Operation* instead and use
-  // mapped iterator.
-  static Operation *AddressOf(Operation &op) { return &op; }
-  using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>;
-  static nodes_iterator nodes_begin(Block *b) {
-    return nodes_iterator(b->begin(), &AddressOf);
+  int id;
+  Optional<int> clusterId;
+};
+
+/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
+/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
+/// about the Graphviz DOT language.
+class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
+public:
+  PrintOpPass(raw_ostream &os) : os(os) {}
+  PrintOpPass(const PrintOpPass &o) : os(o.os.getOStream()) {}
+
+  void runOnOperation() override {
+    emitGraph([&]() {
+      processOperation(getOperation());
+      emitAllEdgeStmts();
+    });
   }
-  static nodes_iterator nodes_end(Block *b) {
-    return nodes_iterator(b->end(), &AddressOf);
+
+private:
+  /// Emit all edges. This function should be called after all nodes have been
+  /// emitted.
+  void emitAllEdgeStmts() {
+    for (const std::string &edge : edges)
+      os << edge << ";\n";
+    edges.clear();
   }
-};
 
-// Specialize DOTGraphTraits to produce more readable output.
-template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits {
-  using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
-  static std::string getNodeLabel(Operation *op, Block *);
-};
+  /// Emit a cluster (subgraph). The specified builder generates the body of the
+  /// cluster. Return the anchor node of the cluster.
+  Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
+    int clusterId = ++counter;
+    os << "subgraph cluster_" << clusterId << " {\n";
+    os.indent();
+    // Emit invisible anchor node from/to which arrows can be drawn.
+    Node anchorNode = emitNodeStmt(" ", kShapeNone);
+    os << attrStmt("label", quoteString(escapeString(label))) << ";\n";
+    builder();
+    os.unindent();
+    os << "}\n";
+    return Node(anchorNode.id, clusterId);
+  }
 
-std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
-  // Reuse the print output for the node labels.
-  std::string ostr;
-  raw_string_ostream os(ostr);
-  os << op->getName() << "\n";
+  /// Generate an attribute statement.
+  std::string attrStmt(const Twine &key, const Twine &value) {
+    return (key + " = " + value).str();
+  }
 
-  if (!op->getLoc().isa<UnknownLoc>()) {
-    os << op->getLoc() << "\n";
+  /// Emit an attribute list.
+  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
+    os << "[";
+    interleaveComma(map, os, [&](const auto &it) {
+      os << attrStmt(it.getKey(), it.getValue());
+    });
+    os << "]";
   }
 
-  // Print resultant types
-  llvm::interleaveComma(op->getResultTypes(), os);
-  os << "\n";
+  // Print an MLIR attribute to `os`. Large attributes are truncated.
+  void emitMlirAttr(raw_ostream &os, Attribute attr) {
+    // A value used to elide large container attribute.
+    int64_t largeAttrLimit = getLargeAttributeSizeLimit();
 
-  // A value used to elide large container attribute.
-  int64_t largeAttrLimit = getLargeAttributeSizeLimit();
-  for (auto attr : op->getAttrs()) {
-    os << '\n' << attr.first << ": ";
     // Always emit splat attributes.
-    if (attr.second.isa<SplatElementsAttr>()) {
-      attr.second.print(os);
-      continue;
+    if (attr.isa<SplatElementsAttr>()) {
+      attr.print(os);
+      return;
     }
 
     // Elide "big" elements attributes.
-    auto elements = attr.second.dyn_cast<ElementsAttr>();
+    auto elements = attr.dyn_cast<ElementsAttr>();
     if (elements && elements.getNumElements() > largeAttrLimit) {
       os << std::string(elements.getType().getRank(), '[') << "..."
          << std::string(elements.getType().getRank(), ']') << " : "
          << elements.getType();
-      continue;
+      return;
     }
 
-    auto array = attr.second.dyn_cast<ArrayAttr>();
+    auto array = attr.dyn_cast<ArrayAttr>();
     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
       os << "[...]";
-      continue;
+      return;
     }
 
     // Print all other attributes.
-    attr.second.print(os);
+    attr.print(os);
   }
-  return os.str();
-}
 
-} // end namespace llvm
+  /// Append an edge to the list of edges.
+  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
+  void emitEdgeStmt(Node n1, Node n2, std::string label,
+                    StringRef style = kLineStyleDataFlow) {
+    AttributeMap attrs;
+    attrs["style"] = style.str();
+    // Do not label edges that start/end at a cluster boundary. Such edges are
+    // clipped at the boundary, but labels are not. This can lead to labels
+    // floating around without any edge next to them.
+    if (!n1.clusterId && !n2.clusterId)
+      attrs["label"] = quoteString(escapeString(label));
+    // Use `ltail` and `lhead` to draw edges between clusters.
+    if (n1.clusterId)
+      attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
+    if (n2.clusterId)
+      attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
 
-namespace {
-// PrintOpPass is simple pass to write graph per function.
-// Note: this is a module pass only to avoid interleaving on the same ostream
-// due to multi-threading over functions.
-class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
-public:
-  PrintOpPass(raw_ostream &os, bool shortNames, const Twine &title) : os(os) {
-    this->shortNames = shortNames;
-    this->title = title.str();
+    edges.push_back(strFromOs([&](raw_ostream &os) {
+      os << llvm::format("v%i -> v%i ", n1.id, n2.id);
+      emitAttrList(os, attrs);
+    }));
   }
 
-  std::string getOpName(Operation &op) {
-    auto symbolAttr =
-        op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
-    if (symbolAttr)
-      return std::string(symbolAttr.getValue());
-    ++unnamedOpCtr;
-    return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
+  /// Emit a graph. The specified builder generates the body of the graph.
+  void emitGraph(function_ref<void()> builder) {
+    os << "digraph G {\n";
+    os.indent();
+    // Edges between clusters are allowed only in compound mode.
+    os << attrStmt("compound", "true") << ";\n";
+    builder();
+    os.unindent();
+    os << "}\n";
   }
 
-  // Print all the ops in a module.
-  void processModule(ModuleOp module) {
-    for (Operation &op : module) {
-      // Modules may actually be nested, recurse on nesting.
-      if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
-        processModule(nestedModule);
-        continue;
-      }
-      auto opName = getOpName(op);
-      for (Region &region : op.getRegions()) {
-        for (auto indexed_block : llvm::enumerate(region)) {
-          // Suffix block number if there are more than 1 block.
-          auto blockName = llvm::hasSingleElement(region)
-                               ? ""
-                               : ("__" + llvm::utostr(indexed_block.index()));
-          llvm::WriteGraph(os, &indexed_block.value(), shortNames,
-                           Twine(title) + opName + blockName);
-        }
+  /// Emit a node statement.
+  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
+    int nodeId = ++counter;
+    AttributeMap attrs;
+    attrs["label"] = quoteString(escapeString(label));
+    attrs["shape"] = shape.str();
+    os << llvm::format("v%i ", nodeId);
+    emitAttrList(os, attrs);
+    os << ";\n";
+    return Node(nodeId);
+  }
+
+  /// Generate a label for an operation.
+  std::string getLabel(Operation *op) {
+    return strFromOs([&](raw_ostream &os) {
+      // Print operation name and type.
+      os << op->getName() << " : (";
+      interleaveComma(op->getResultTypes(), os);
+      os << ")\n";
+
+      // Print attributes.
+      for (const NamedAttribute &attr : op->getAttrs()) {
+        os << '\n' << attr.first << ": ";
+        emitMlirAttr(os, attr.second);
       }
+    });
+  }
+
+  /// Generate a label for a block argument.
+  std::string getLabel(BlockArgument arg) {
+    return "arg" + std::to_string(arg.getArgNumber());
+  }
+
+  /// Process a block. Emit a cluster and one node per block argument and
+  /// operation inside the cluster.
+  void processBlock(Block &block) {
+    emitClusterStmt([&]() {
+      for (BlockArgument &blockArg : block.getArguments())
+        valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
+
+      // Emit a node for each operation.
+      for (Operation &op : block)
+        processOperation(&op);
+    });
+  }
+
+  /// Process an operation. If the operation has regions, emit a cluster.
+  /// Otherwise, emit a node.
+  void processOperation(Operation *op) {
+    Node node;
+    if (op->getNumRegions() > 0) {
+      // Emit cluster for op with regions.
+      node = emitClusterStmt(
+          [&]() {
+            for (Region &region : op->getRegions())
+              processRegion(region);
+          },
+          getLabel(op));
+    } else {
+      node = emitNodeStmt(getLabel(op));
     }
+
+    // Insert edges originating from each operand.
+    unsigned numOperands = op->getNumOperands();
+    for (unsigned i = 0; i < numOperands; i++)
+      emitEdgeStmt(valueToNode[op->getOperand(i)], node,
+                   /*label=*/numOperands == 1 ? "" : std::to_string(i));
+
+    for (Value result : op->getResults())
+      valueToNode[result] = node;
   }
 
-  void runOnOperation() override { processModule(getOperation()); }
+  /// Process a region.
+  void processRegion(Region &region) {
+    for (Block &block : region.getBlocks())
+      processBlock(block);
+  }
 
-private:
-  raw_ostream &os;
-  int unnamedOpCtr = 0;
+  /// Output stream to write DOT file to.
+  raw_indented_ostream os;
+  /// A list of edges. For simplicity, should be emitted after all nodes were
+  /// emitted.
+  std::vector<std::string> edges;
+  /// Mapping of SSA values to Graphviz nodes/clusters.
+  DenseMap<Value, Node> valueToNode;
+  /// Counter for generating unique node/subgraph identifiers.
+  int counter = 0;
 };
-} // namespace
-
-void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
-                     const Twine &title, llvm::GraphProgram::Name program) {
-  llvm::ViewGraph(&block, name, shortNames, title, program);
-}
 
-raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
-                              const Twine &title) {
-  return llvm::WriteGraph(os, &block, shortNames, title);
-}
+} // namespace
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
-                             const Twine &title) {
-  return std::make_unique<PrintOpPass>(os, shortNames, title);
+std::unique_ptr<Pass>
+mlir::createPrintOpGraphPass(raw_ostream &os) {
+  return std::make_unique<PrintOpPass>(os);
 }

diff  --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
index 4a5ac380632e1..fb14cf333f287 100644
--- a/mlir/test/Transforms/print-op-graph.mlir
+++ b/mlir/test/Transforms/print-op-graph.mlir
@@ -1,18 +1,36 @@
 // RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
 
-// CHECK-LABEL: digraph "merge_blocks"
-// CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}
-// CHECK{LITERAL}: value: dense\<1\> : tensor\<5xi32\>}
-// CHECK{LITERAL}: value: dense\<[[0, 1]]\> : tensor\<1x2xi32\>}
+// CHECK-LABEL: digraph G {
+//       CHECK:   subgraph {{.*}} {
+//       CHECK:     subgraph {{.*}}
+//       CHECK:       label = "builtin.func{{.*}}merge_blocks
+//       CHECK:       subgraph {{.*}} {
+//       CHECK:         v[[ARG0:.*]] [label = "arg0"
+//       CHECK:         v[[CONST10:.*]] [label ={{.*}}10 : i32
+//       CHECK:         subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
+//       CHECK:           v[[ANCHOR:.*]] [label = " ", shape = plain]
+//       CHECK:           label = "test.merge_blocks
+//       CHECK:           subgraph {{.*}} {
+//       CHECK:             v[[TEST_BR:.*]] [label = "test.br
+//       CHECK:           }
+//       CHECK:           subgraph {{.*}} {
+//       CHECK:           }
+//       CHECK:         }
+//       CHECK:         v[[TEST_RET:.*]] [label = "test.return
+//       CHECK:   v[[ARG0]] -> v[[TEST_BR]]
+//       CHECK:   v[[CONST10]] -> v[[TEST_BR]]
+//       CHECK:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
+//       CHECK:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
 func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
   %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
   %1 = constant dense<1> : tensor<5xi32>
   %2 = constant dense<[[0, 1]]> : tensor<1x2xi32>
-
+  %a = constant 10 : i32
+  %b = "test.func"() : () -> i32
   %3:2 = "test.merge_blocks"() ({
   ^bb0:
-     "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> ()
-  ^bb1(%arg3 : i32, %arg4 : i32):
+     "test.br"(%arg0, %b, %a)[^bb1] : (i32, i32, i32) -> ()
+  ^bb1(%arg3 : i32, %arg4 : i32, %arg5: i32):
      "test.return"(%arg3, %arg4) : (i32, i32) -> ()
   }) : () -> (i32, i32)
   "test.return"(%3#0, %3#1) : (i32, i32) -> ()


        


More information about the Mlir-commits mailing list