[Mlir-commits] [mlir] 8d15b7d - [mlir] Improve Graphviz visualization in PrintOpPass
Matthias Springer
llvmlistbot at llvm.org
Tue Aug 3 19:56:53 PDT 2021
Author: Matthias Springer
Date: 2021-08-04T11:56:26+09:00
New Revision: 8d15b7dcbaa1469d7e147ebdce988cca861ace6d
URL: https://github.com/llvm/llvm-project/commit/8d15b7dcbaa1469d7e147ebdce988cca861ace6d
DIFF: https://github.com/llvm/llvm-project/commit/8d15b7dcbaa1469d7e147ebdce988cca861ace6d.diff
LOG: [mlir] Improve Graphviz visualization in PrintOpPass
* Visualize blocks and regions as subgraphs.
* Generate DOT file directly instead of using `GraphTraits`. `GraphTraits` does not support subgraphs.
Differential Revision: https://reviews.llvm.org/D106253
Added:
Modified:
mlir/include/mlir/Support/IndentedOstream.h
mlir/include/mlir/Transforms/Passes.td
mlir/include/mlir/Transforms/ViewOpGraph.h
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/ViewOpGraph.cpp
mlir/test/Transforms/print-op-graph.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Support/IndentedOstream.h b/mlir/include/mlir/Support/IndentedOstream.h
index 6d0701607f469..9a755bc7ebb08 100644
--- a/mlir/include/mlir/Support/IndentedOstream.h
+++ b/mlir/include/mlir/Support/IndentedOstream.h
@@ -45,6 +45,9 @@ class raw_indented_ostream : public raw_ostream {
llvm::StringRef open, close;
};
+ /// Returns the underlying (unindented) raw_ostream.
+ raw_ostream &getOStream() const { return os; }
+
/// Returns DelimitedScope.
DelimitedScope scope(StringRef open = "", StringRef close = "") {
return DelimitedScope(*this, open, close);
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index ecd60de1104ed..324f06454a49d 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -688,21 +688,19 @@ def SymbolDCE : Pass<"symbol-dce"> {
let constructor = "mlir::createSymbolDCEPass()";
}
-def ViewOpGraphPass : Pass<"view-op-graph", "ModuleOp"> {
- let summary = "Print graphviz view of module";
+def ViewOpGraphPass : Pass<"view-op-graph"> {
+ let summary = "Print Graphviz dataflow visualization of an operation";
let description = [{
- This pass prints a graphviz per block of a module.
+ This pass prints a Graphviz dataflow graph of a module.
- - Op are represented as nodes;
+ - Operations are represented as nodes;
- Uses as edges;
+ - Regions/blocks as subgraphs.
+
+ Note: See https://www.graphviz.org/doc/info/lang.html for more information
+ about the Graphviz DOT language.
}];
let constructor = "mlir::createPrintOpGraphPass()";
- let options = [
- Option<"title", "title", "std::string",
- /*default=*/"", "The prefix of the title of the graph">,
- Option<"shortNames", "short-names", "bool", /*default=*/"false",
- "Use short names">
- ];
}
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/ViewOpGraph.h b/mlir/include/mlir/Transforms/ViewOpGraph.h
index 61f40358fec21..ec1b6c281e877 100644
--- a/mlir/include/mlir/Transforms/ViewOpGraph.h
+++ b/mlir/include/mlir/Transforms/ViewOpGraph.h
@@ -14,27 +14,14 @@
#define MLIR_TRANSFORMS_VIEWOPGRAPH_H_
#include "mlir/Support/LLVM.h"
-#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
-class Block;
-class ModuleOp;
-template <typename T> class OperationPass;
-
-/// Displays the graph in a window. This is for use from the debugger and
-/// depends on Graphviz to generate the graph.
-void viewGraph(Block &block, const Twine &name, bool shortNames = false,
- const Twine &title = "",
- llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
-
-raw_ostream &writeGraph(raw_ostream &os, Block &block, bool shortNames = false,
- const Twine &title = "");
+class Pass;
/// Creates a pass to print op graphs.
-std::unique_ptr<OperationPass<ModuleOp>>
-createPrintOpGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
- const Twine &title = "");
+std::unique_ptr<Pass>
+createPrintOpGraphPass(raw_ostream &os = llvm::errs());
} // end namespace mlir
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index ee67953430d5e..5563b70f379b9 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -39,6 +39,7 @@ add_mlir_library(MLIRTransforms
MLIRMemRef
MLIRSCF
MLIRPass
+ MLIRSupportIndentedOstream
MLIRTransformUtils
MLIRVector
)
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 3d52d79b7ef7b..86c5725b8910d 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -9,12 +9,16 @@
#include "mlir/Transforms/ViewOpGraph.h"
#include "PassDetail.h"
#include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
-#include "llvm/Support/CommandLine.h"
+#include "mlir/Support/IndentedOstream.h"
+#include "llvm/Support/Format.h"
using namespace mlir;
+static const StringRef kLineStyleDataFlow = "solid";
+static const StringRef kShapeNode = "ellipse";
+static const StringRef kShapeNone = "plain";
+
/// Return the size limits for eliding large attributes.
static int64_t getLargeAttributeSizeLimit() {
// Use the default from the printer flags if possible.
@@ -23,145 +27,251 @@ static int64_t getLargeAttributeSizeLimit() {
return 16;
}
-namespace llvm {
+/// Return all values printed onto a stream as a string.
+static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
+ std::string buf;
+ llvm::raw_string_ostream os(buf);
+ func(os);
+ return os.str();
+}
+
+/// Escape special characters such as '\n' and quotation marks.
+static std::string escapeString(std::string str) {
+ return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
+}
+
+/// Put quotation marks around a given string.
+static std::string quoteString(std::string str) { return "\"" + str + "\""; }
-// Specialize GraphTraits to treat Block as a graph of Operations as nodes and
-// uses as edges.
-template <> struct GraphTraits<Block *> {
- using GraphType = Block *;
- using NodeRef = Operation *;
+using AttributeMap = llvm::StringMap<std::string>;
- using ChildIteratorType = Operation::user_iterator;
- static ChildIteratorType child_begin(NodeRef n) { return n->user_begin(); }
- static ChildIteratorType child_end(NodeRef n) { return n->user_end(); }
+namespace {
+
+/// This struct represents a node in the DOT language. Each node has an
+/// identifier and an optional identifier for the cluster (subgraph) that
+/// contains the node.
+/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
+/// not between clusters. However, edges can be clipped to the boundary of a
+/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
+/// cluster, an invisible "anchor" node is created.
+struct Node {
+public:
+ Node(int id = 0, Optional<int> clusterId = llvm::None)
+ : id(id), clusterId(clusterId) {}
- // Operation's destructor is private so use Operation* instead and use
- // mapped iterator.
- static Operation *AddressOf(Operation &op) { return &op; }
- using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>;
- static nodes_iterator nodes_begin(Block *b) {
- return nodes_iterator(b->begin(), &AddressOf);
+ int id;
+ Optional<int> clusterId;
+};
+
+/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
+/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
+/// about the Graphviz DOT language.
+class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
+public:
+ PrintOpPass(raw_ostream &os) : os(os) {}
+ PrintOpPass(const PrintOpPass &o) : os(o.os.getOStream()) {}
+
+ void runOnOperation() override {
+ emitGraph([&]() {
+ processOperation(getOperation());
+ emitAllEdgeStmts();
+ });
}
- static nodes_iterator nodes_end(Block *b) {
- return nodes_iterator(b->end(), &AddressOf);
+
+private:
+ /// Emit all edges. This function should be called after all nodes have been
+ /// emitted.
+ void emitAllEdgeStmts() {
+ for (const std::string &edge : edges)
+ os << edge << ";\n";
+ edges.clear();
}
-};
-// Specialize DOTGraphTraits to produce more readable output.
-template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits {
- using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
- static std::string getNodeLabel(Operation *op, Block *);
-};
+ /// Emit a cluster (subgraph). The specified builder generates the body of the
+ /// cluster. Return the anchor node of the cluster.
+ Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
+ int clusterId = ++counter;
+ os << "subgraph cluster_" << clusterId << " {\n";
+ os.indent();
+ // Emit invisible anchor node from/to which arrows can be drawn.
+ Node anchorNode = emitNodeStmt(" ", kShapeNone);
+ os << attrStmt("label", quoteString(escapeString(label))) << ";\n";
+ builder();
+ os.unindent();
+ os << "}\n";
+ return Node(anchorNode.id, clusterId);
+ }
-std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
- // Reuse the print output for the node labels.
- std::string ostr;
- raw_string_ostream os(ostr);
- os << op->getName() << "\n";
+ /// Generate an attribute statement.
+ std::string attrStmt(const Twine &key, const Twine &value) {
+ return (key + " = " + value).str();
+ }
- if (!op->getLoc().isa<UnknownLoc>()) {
- os << op->getLoc() << "\n";
+ /// Emit an attribute list.
+ void emitAttrList(raw_ostream &os, const AttributeMap &map) {
+ os << "[";
+ interleaveComma(map, os, [&](const auto &it) {
+ os << attrStmt(it.getKey(), it.getValue());
+ });
+ os << "]";
}
- // Print resultant types
- llvm::interleaveComma(op->getResultTypes(), os);
- os << "\n";
+ // Print an MLIR attribute to `os`. Large attributes are truncated.
+ void emitMlirAttr(raw_ostream &os, Attribute attr) {
+ // A value used to elide large container attribute.
+ int64_t largeAttrLimit = getLargeAttributeSizeLimit();
- // A value used to elide large container attribute.
- int64_t largeAttrLimit = getLargeAttributeSizeLimit();
- for (auto attr : op->getAttrs()) {
- os << '\n' << attr.first << ": ";
// Always emit splat attributes.
- if (attr.second.isa<SplatElementsAttr>()) {
- attr.second.print(os);
- continue;
+ if (attr.isa<SplatElementsAttr>()) {
+ attr.print(os);
+ return;
}
// Elide "big" elements attributes.
- auto elements = attr.second.dyn_cast<ElementsAttr>();
+ auto elements = attr.dyn_cast<ElementsAttr>();
if (elements && elements.getNumElements() > largeAttrLimit) {
os << std::string(elements.getType().getRank(), '[') << "..."
<< std::string(elements.getType().getRank(), ']') << " : "
<< elements.getType();
- continue;
+ return;
}
- auto array = attr.second.dyn_cast<ArrayAttr>();
+ auto array = attr.dyn_cast<ArrayAttr>();
if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
os << "[...]";
- continue;
+ return;
}
// Print all other attributes.
- attr.second.print(os);
+ attr.print(os);
}
- return os.str();
-}
-} // end namespace llvm
+ /// Append an edge to the list of edges.
+ /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
+ void emitEdgeStmt(Node n1, Node n2, std::string label,
+ StringRef style = kLineStyleDataFlow) {
+ AttributeMap attrs;
+ attrs["style"] = style.str();
+ // Do not label edges that start/end at a cluster boundary. Such edges are
+ // clipped at the boundary, but labels are not. This can lead to labels
+ // floating around without any edge next to them.
+ if (!n1.clusterId && !n2.clusterId)
+ attrs["label"] = quoteString(escapeString(label));
+ // Use `ltail` and `lhead` to draw edges between clusters.
+ if (n1.clusterId)
+ attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
+ if (n2.clusterId)
+ attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
-namespace {
-// PrintOpPass is simple pass to write graph per function.
-// Note: this is a module pass only to avoid interleaving on the same ostream
-// due to multi-threading over functions.
-class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
-public:
- PrintOpPass(raw_ostream &os, bool shortNames, const Twine &title) : os(os) {
- this->shortNames = shortNames;
- this->title = title.str();
+ edges.push_back(strFromOs([&](raw_ostream &os) {
+ os << llvm::format("v%i -> v%i ", n1.id, n2.id);
+ emitAttrList(os, attrs);
+ }));
}
- std::string getOpName(Operation &op) {
- auto symbolAttr =
- op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
- if (symbolAttr)
- return std::string(symbolAttr.getValue());
- ++unnamedOpCtr;
- return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
+ /// Emit a graph. The specified builder generates the body of the graph.
+ void emitGraph(function_ref<void()> builder) {
+ os << "digraph G {\n";
+ os.indent();
+ // Edges between clusters are allowed only in compound mode.
+ os << attrStmt("compound", "true") << ";\n";
+ builder();
+ os.unindent();
+ os << "}\n";
}
- // Print all the ops in a module.
- void processModule(ModuleOp module) {
- for (Operation &op : module) {
- // Modules may actually be nested, recurse on nesting.
- if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
- processModule(nestedModule);
- continue;
- }
- auto opName = getOpName(op);
- for (Region &region : op.getRegions()) {
- for (auto indexed_block : llvm::enumerate(region)) {
- // Suffix block number if there are more than 1 block.
- auto blockName = llvm::hasSingleElement(region)
- ? ""
- : ("__" + llvm::utostr(indexed_block.index()));
- llvm::WriteGraph(os, &indexed_block.value(), shortNames,
- Twine(title) + opName + blockName);
- }
+ /// Emit a node statement.
+ Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
+ int nodeId = ++counter;
+ AttributeMap attrs;
+ attrs["label"] = quoteString(escapeString(label));
+ attrs["shape"] = shape.str();
+ os << llvm::format("v%i ", nodeId);
+ emitAttrList(os, attrs);
+ os << ";\n";
+ return Node(nodeId);
+ }
+
+ /// Generate a label for an operation.
+ std::string getLabel(Operation *op) {
+ return strFromOs([&](raw_ostream &os) {
+ // Print operation name and type.
+ os << op->getName() << " : (";
+ interleaveComma(op->getResultTypes(), os);
+ os << ")\n";
+
+ // Print attributes.
+ for (const NamedAttribute &attr : op->getAttrs()) {
+ os << '\n' << attr.first << ": ";
+ emitMlirAttr(os, attr.second);
}
+ });
+ }
+
+ /// Generate a label for a block argument.
+ std::string getLabel(BlockArgument arg) {
+ return "arg" + std::to_string(arg.getArgNumber());
+ }
+
+ /// Process a block. Emit a cluster and one node per block argument and
+ /// operation inside the cluster.
+ void processBlock(Block &block) {
+ emitClusterStmt([&]() {
+ for (BlockArgument &blockArg : block.getArguments())
+ valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
+
+ // Emit a node for each operation.
+ for (Operation &op : block)
+ processOperation(&op);
+ });
+ }
+
+ /// Process an operation. If the operation has regions, emit a cluster.
+ /// Otherwise, emit a node.
+ void processOperation(Operation *op) {
+ Node node;
+ if (op->getNumRegions() > 0) {
+ // Emit cluster for op with regions.
+ node = emitClusterStmt(
+ [&]() {
+ for (Region &region : op->getRegions())
+ processRegion(region);
+ },
+ getLabel(op));
+ } else {
+ node = emitNodeStmt(getLabel(op));
}
+
+ // Insert edges originating from each operand.
+ unsigned numOperands = op->getNumOperands();
+ for (unsigned i = 0; i < numOperands; i++)
+ emitEdgeStmt(valueToNode[op->getOperand(i)], node,
+ /*label=*/numOperands == 1 ? "" : std::to_string(i));
+
+ for (Value result : op->getResults())
+ valueToNode[result] = node;
}
- void runOnOperation() override { processModule(getOperation()); }
+ /// Process a region.
+ void processRegion(Region &region) {
+ for (Block &block : region.getBlocks())
+ processBlock(block);
+ }
-private:
- raw_ostream &os;
- int unnamedOpCtr = 0;
+ /// Output stream to write DOT file to.
+ raw_indented_ostream os;
+ /// A list of edges. For simplicity, should be emitted after all nodes were
+ /// emitted.
+ std::vector<std::string> edges;
+ /// Mapping of SSA values to Graphviz nodes/clusters.
+ DenseMap<Value, Node> valueToNode;
+ /// Counter for generating unique node/subgraph identifiers.
+ int counter = 0;
};
-} // namespace
-
-void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
- const Twine &title, llvm::GraphProgram::Name program) {
- llvm::ViewGraph(&block, name, shortNames, title, program);
-}
-raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
- const Twine &title) {
- return llvm::WriteGraph(os, &block, shortNames, title);
-}
+} // namespace
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
- const Twine &title) {
- return std::make_unique<PrintOpPass>(os, shortNames, title);
+std::unique_ptr<Pass>
+mlir::createPrintOpGraphPass(raw_ostream &os) {
+ return std::make_unique<PrintOpPass>(os);
}
diff --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
index 4a5ac380632e1..fb14cf333f287 100644
--- a/mlir/test/Transforms/print-op-graph.mlir
+++ b/mlir/test/Transforms/print-op-graph.mlir
@@ -1,18 +1,36 @@
// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
-// CHECK-LABEL: digraph "merge_blocks"
-// CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}
-// CHECK{LITERAL}: value: dense\<1\> : tensor\<5xi32\>}
-// CHECK{LITERAL}: value: dense\<[[0, 1]]\> : tensor\<1x2xi32\>}
+// CHECK-LABEL: digraph G {
+// CHECK: subgraph {{.*}} {
+// CHECK: subgraph {{.*}}
+// CHECK: label = "builtin.func{{.*}}merge_blocks
+// CHECK: subgraph {{.*}} {
+// CHECK: v[[ARG0:.*]] [label = "arg0"
+// CHECK: v[[CONST10:.*]] [label ={{.*}}10 : i32
+// CHECK: subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
+// CHECK: v[[ANCHOR:.*]] [label = " ", shape = plain]
+// CHECK: label = "test.merge_blocks
+// CHECK: subgraph {{.*}} {
+// CHECK: v[[TEST_BR:.*]] [label = "test.br
+// CHECK: }
+// CHECK: subgraph {{.*}} {
+// CHECK: }
+// CHECK: }
+// CHECK: v[[TEST_RET:.*]] [label = "test.return
+// CHECK: v[[ARG0]] -> v[[TEST_BR]]
+// CHECK: v[[CONST10]] -> v[[TEST_BR]]
+// CHECK: v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
+// CHECK: v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
%0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%1 = constant dense<1> : tensor<5xi32>
%2 = constant dense<[[0, 1]]> : tensor<1x2xi32>
-
+ %a = constant 10 : i32
+ %b = "test.func"() : () -> i32
%3:2 = "test.merge_blocks"() ({
^bb0:
- "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> ()
- ^bb1(%arg3 : i32, %arg4 : i32):
+ "test.br"(%arg0, %b, %a)[^bb1] : (i32, i32, i32) -> ()
+ ^bb1(%arg3 : i32, %arg4 : i32, %arg5: i32):
"test.return"(%arg3, %arg4) : (i32, i32) -> ()
}) : () -> (i32, i32)
"test.return"(%3#0, %3#1) : (i32, i32) -> ()
More information about the Mlir-commits
mailing list