[Mlir-commits] [mlir] 9102a16 - [mlir] Support drawing control-flow graphs in ViewOpGraph.cpp

Matthias Springer llvmlistbot at llvm.org
Wed Aug 4 04:45:43 PDT 2021


Author: Matthias Springer
Date: 2021-08-04T20:45:15+09:00
New Revision: 9102a16bef1aa8c780f440f8ac7d71090d1a96c1

URL: https://github.com/llvm/llvm-project/commit/9102a16bef1aa8c780f440f8ac7d71090d1a96c1
DIFF: https://github.com/llvm/llvm-project/commit/9102a16bef1aa8c780f440f8ac7d71090d1a96c1.diff

LOG: [mlir] Support drawing control-flow graphs in ViewOpGraph.cpp

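* Add a new pass option `print-data-flow-edges` (default: `true`).
* Add a new pass option `print-control-flow-edges` (default: `false`).
* Remove `PrintCFGPass`. Control-flow graphs are now printed by `PrintOpPass`
  instead, e.g. `-view-op-graph='print-data-flow-edges=false print-control-flow-edges=true'`.
* `Region::viewGraph` is now implemented on top of `PrintOpPass`.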

Differential Revision: https://reviews.llvm.org/D106342

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/CMakeLists.txt
    mlir/lib/Transforms/NormalizeMemRefs.cpp
    mlir/lib/Transforms/ViewOpGraph.cpp
    mlir/test/Transforms/print-op-graph.mlir

Removed: 
    mlir/include/mlir/Transforms/ViewRegionGraph.h
    mlir/lib/Transforms/ViewRegionGraph.cpp


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index bda2c410223c..04283353a208 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -17,7 +17,6 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/ViewOpGraph.h"
-#include "mlir/Transforms/ViewRegionGraph.h"
 #include <limits>
 
 namespace mlir {

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 857b21ccddc4..45d72c061d6e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -612,11 +612,6 @@ def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> {
   ];
 }
 
-def PrintCFG : FunctionPass<"print-cfg-graph"> {
-  let summary = "Print CFG graph per-Region";
-  let constructor = "mlir::createPrintCFGGraphPass()";
-}
-
 def PrintOpStats : Pass<"print-op-stats"> {
   let summary = "Print statistics of operations";
   let constructor = "mlir::createPrintOpStatsPass()";
@@ -689,14 +684,17 @@ def SymbolDCE : Pass<"symbol-dce"> {
 }
 
 def ViewOpGraphPass : Pass<"view-op-graph"> {
-  let summary = "Print Graphviz dataflow visualization of an operation";
+  let summary = "Print Graphviz visualization of an operation";
   let description = [{
-    This pass prints a Graphviz dataflow graph of a module.
+    This pass prints a Graphviz graph of a module.
 
     - Operations are represented as nodes;
-    - Uses as edges;
+    - Uses (data flow) as edges;
+    - Control flow as dashed edges;
     - Regions/blocks as subgraphs.
 
+    By default, only data flow edges are printed.
+
     Note: See https://www.graphviz.org/doc/info/lang.html for more information
     about the Graphviz DOT language.
   }];
@@ -705,6 +703,10 @@ def ViewOpGraphPass : Pass<"view-op-graph"> {
             /*default=*/"20", "Limit attribute/type length to number of chars">,
     Option<"printAttrs", "print-attrs", "bool",
            /*default=*/"true", "Print attributes of operations">,
+    Option<"printControlFlowEdges", "print-control-flow-edges", "bool",
+           /*default=*/"false", "Print control flow edges">,
+    Option<"printDataFlowEdges", "print-data-flow-edges", "bool",
+           /*default=*/"true", "Print data flow edges">,
     Option<"printResultTypes", "print-result-types", "bool",
             /*default=*/"true", "Print result types of operations">
   ];

diff  --git a/mlir/include/mlir/Transforms/ViewRegionGraph.h b/mlir/include/mlir/Transforms/ViewRegionGraph.h
deleted file mode 100644
index 950f4c349bbf..000000000000
--- a/mlir/include/mlir/Transforms/ViewRegionGraph.h
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- ViewRegionGraph.h - View/write graphviz graphs -----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines interface to produce Graphviz outputs of MLIR Regions.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TRANSFORMS_VIEWFUNCTIONGRAPH_H_
-#define MLIR_TRANSFORMS_VIEWFUNCTIONGRAPH_H_
-
-#include "mlir/Support/LLVM.h"
-#include "llvm/Support/GraphWriter.h"
-#include "llvm/Support/raw_ostream.h"
-
-namespace mlir {
-class FuncOp;
-template <typename T> class OperationPass;
-class Region;
-
-/// Displays the CFG in a window. This is for use from the debugger and
-/// depends on Graphviz to generate the graph.
-void viewGraph(Region &region, const Twine &name, bool shortNames = false,
-               const Twine &title = "",
-               llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
-
-raw_ostream &writeGraph(raw_ostream &os, Region &region,
-                        bool shortNames = false, const Twine &title = "");
-
-/// Creates a pass to print CFG graphs.
-std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
-createPrintCFGGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
-                        const Twine &title = "");
-
-} // end namespace mlir
-
-#endif // MLIR_TRANSFORMS_VIEWFUNCTIONGRAPH_H_

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 3bf4221629de..99133af8b981 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -21,7 +21,6 @@ add_mlir_library(MLIRTransforms
   StripDebugInfo.cpp
   SymbolDCE.cpp
   ViewOpGraph.cpp
-  ViewRegionGraph.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms

diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index 831c613a4118..ff0fdc95f45e 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "normalize-memrefs"
 

diff  --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index d18de7c53c8e..4545725d8e2e 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -12,9 +12,11 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/IndentedOstream.h"
 #include "llvm/Support/Format.h"
+#include "llvm/Support/GraphWriter.h"
 
 using namespace mlir;
 
+static const StringRef kLineStyleControlFlow = "dashed";
 static const StringRef kLineStyleDataFlow = "solid";
 static const StringRef kShapeNode = "ellipse";
 static const StringRef kShapeNone = "plain";
@@ -78,6 +80,13 @@ class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
     });
   }
 
+  /// Create a CFG graph for a region. Used in `Region::viewGraph`.
+  void emitRegionCFG(Region &region) {
+    printControlFlowEdges = true;
+    printDataFlowEdges = false;
+    emitGraph([&]() { processRegion(region); });
+  }
+
 private:
   /// Emit all edges. This function should be called after all nodes have been
   /// emitted.
@@ -151,8 +160,7 @@ class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
 
   /// Append an edge to the list of edges.
   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
-  void emitEdgeStmt(Node n1, Node n2, std::string label,
-                    StringRef style = kLineStyleDataFlow) {
+  void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
     AttributeMap attrs;
     attrs["style"] = style.str();
     // Do not label edges that start/end at a cluster boundary. Such edges are
@@ -233,14 +241,20 @@ class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
 
       // Emit a node for each operation.
-      for (Operation &op : block)
-        processOperation(&op);
+      Optional<Node> prevNode;
+      for (Operation &op : block) {
+        Node nextNode = processOperation(&op);
+        if (printControlFlowEdges && prevNode)
+          emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
+                       kLineStyleControlFlow);
+        prevNode = nextNode;
+      }
     });
   }
 
   /// Process an operation. If the operation has regions, emit a cluster.
   /// Otherwise, emit a node.
-  void processOperation(Operation *op) {
+  Node processOperation(Operation *op) {
     Node node;
     if (op->getNumRegions() > 0) {
       // Emit cluster for op with regions.
@@ -254,14 +268,19 @@ class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
       node = emitNodeStmt(getLabel(op));
     }
 
-    // Insert edges originating from each operand.
-    unsigned numOperands = op->getNumOperands();
-    for (unsigned i = 0; i < numOperands; i++)
-      emitEdgeStmt(valueToNode[op->getOperand(i)], node,
-                   /*label=*/numOperands == 1 ? "" : std::to_string(i));
+    // Insert data flow edges originating from each operand.
+    if (printDataFlowEdges) {
+      unsigned numOperands = op->getNumOperands();
+      for (unsigned i = 0; i < numOperands; i++)
+        emitEdgeStmt(valueToNode[op->getOperand(i)], node,
+                     /*label=*/numOperands == 1 ? "" : std::to_string(i),
+                     kLineStyleDataFlow);
+    }
 
     for (Value result : op->getResults())
       valueToNode[result] = node;
+
+    return node;
   }
 
   /// Process a region.
@@ -294,3 +313,25 @@ std::unique_ptr<Pass>
 mlir::createPrintOpGraphPass(raw_ostream &os) {
   return std::make_unique<PrintOpPass>(os);
 }
+
+/// Generate a CFG for a region and show it in a window.
+static void llvmViewGraph(Region &region, const Twine &name) {
+  int fd;
+  std::string filename = llvm::createGraphFilename(name.str(), fd);
+  {
+    llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
+    if (fd == -1) {
+      llvm::errs() << "error opening file '" << filename << "' for writing\n";
+      return;
+    }
+    PrintOpPass pass(os);
+    pass.emitRegionCFG(region);
+  }
+  llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
+}
+
+void mlir::Region::viewGraph(const Twine &regionName) {
+  llvmViewGraph(*this, regionName);
+}
+
+void mlir::Region::viewGraph() { viewGraph("region"); }

diff  --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp
deleted file mode 100644
index 0c67f30c19cb..000000000000
--- a/mlir/lib/Transforms/ViewRegionGraph.cpp
+++ /dev/null
@@ -1,82 +0,0 @@
-//===- ViewRegionGraph.cpp - View/write graphviz graphs -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/ViewRegionGraph.h"
-#include "PassDetail.h"
-#include "mlir/IR/RegionGraphTraits.h"
-
-using namespace mlir;
-
-namespace llvm {
-
-// Specialize DOTGraphTraits to produce more readable output.
-template <> struct DOTGraphTraits<Region *> : public DefaultDOTGraphTraits {
-  using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
-
-  static std::string getNodeLabel(Block *Block, Region *);
-};
-
-std::string DOTGraphTraits<Region *>::getNodeLabel(Block *Block, Region *) {
-  // Reuse the print output for the node labels.
-  std::string outStreamStr;
-  raw_string_ostream os(outStreamStr);
-  Block->print(os);
-  std::string &outStr = os.str();
-
-  if (outStr[0] == '\n')
-    outStr.erase(outStr.begin());
-
-  // Process string output to left justify the block.
-  for (unsigned i = 0; i != outStr.length(); ++i) {
-    if (outStr[i] == '\n') {
-      outStr[i] = '\\';
-      outStr.insert(outStr.begin() + i + 1, 'l');
-    }
-  }
-
-  return outStr;
-}
-
-} // end namespace llvm
-
-void mlir::viewGraph(Region &region, const Twine &name, bool shortNames,
-                     const Twine &title, llvm::GraphProgram::Name program) {
-  llvm::ViewGraph(&region, name, shortNames, title, program);
-}
-
-raw_ostream &mlir::writeGraph(raw_ostream &os, Region &region, bool shortNames,
-                              const Twine &title) {
-  return llvm::WriteGraph(os, &region, shortNames, title);
-}
-
-void mlir::Region::viewGraph(const Twine &regionName) {
-  ::mlir::viewGraph(*this, regionName);
-}
-void mlir::Region::viewGraph() { viewGraph("region"); }
-
-namespace {
-struct PrintCFGPass : public PrintCFGBase<PrintCFGPass> {
-  PrintCFGPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
-               const Twine &title = "")
-      : os(os), shortNames(shortNames), title(title.str()) {}
-  void runOnFunction() override {
-    mlir::writeGraph(os, getFunction().getBody(), shortNames, title);
-  }
-
-private:
-  raw_ostream &os;
-  bool shortNames;
-  std::string title;
-};
-} // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
-mlir::createPrintCFGGraphPass(raw_ostream &os, bool shortNames,
-                              const Twine &title) {
-  return std::make_unique<PrintCFGPass>(os, shortNames, title);
-}

diff  --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
index fb14cf333f28..a056d9445edc 100644
--- a/mlir/test/Transforms/print-op-graph.mlir
+++ b/mlir/test/Transforms/print-op-graph.mlir
@@ -1,26 +1,55 @@
-// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck -check-prefix=DFG %s
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph='print-data-flow-edges=false print-control-flow-edges=true' %s -o %t 2>&1 | FileCheck -check-prefix=CFG %s
+
+// DFG-LABEL: digraph G {
+//       DFG:   subgraph {{.*}} {
+//       DFG:     subgraph {{.*}}
+//       DFG:       label = "builtin.func{{.*}}merge_blocks
+//       DFG:       subgraph {{.*}} {
+//       DFG:         v[[ARG0:.*]] [label = "arg0"
+//       DFG:         v[[CONST10:.*]] [label ={{.*}}10 : i32
+//       DFG:         subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
+//       DFG:           v[[ANCHOR:.*]] [label = " ", shape = plain]
+//       DFG:           label = "test.merge_blocks
+//       DFG:           subgraph {{.*}} {
+//       DFG:             v[[TEST_BR:.*]] [label = "test.br
+//       DFG:           }
+//       DFG:           subgraph {{.*}} {
+//       DFG:           }
+//       DFG:         }
+//       DFG:         v[[TEST_RET:.*]] [label = "test.return
+//       DFG:   v[[ARG0]] -> v[[TEST_BR]]
+//       DFG:   v[[CONST10]] -> v[[TEST_BR]]
+//       DFG:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
+//       DFG:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
+
+// CFG-LABEL: digraph G {
+//       CFG:   subgraph {{.*}} {
+//       CFG:     subgraph {{.*}}
+//       CFG:       label = "builtin.func{{.*}}merge_blocks
+//       CFG:       subgraph {{.*}} {
+//       CFG:         v[[C1:.*]] [label = "std.constant
+//       CFG:         v[[C2:.*]] [label = "std.constant
+//       CFG:         v[[C3:.*]] [label = "std.constant
+//       CFG:         v[[C4:.*]] [label = "std.constant
+//       CFG:         v[[TEST_FUNC:.*]] [label = "test.func
+//       CFG:         subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
+//       CFG:           v[[ANCHOR:.*]] [label = " ", shape = plain]
+//       CFG:           label = "test.merge_blocks
+//       CFG:           subgraph {{.*}} {
+//       CFG:             v[[TEST_BR:.*]] [label = "test.br
+//       CFG:           }
+//       CFG:           subgraph {{.*}} {
+//       CFG:           }
+//       CFG:         }
+//       CFG:         v[[TEST_RET:.*]] [label = "test.return
+//       CFG:   v[[C1]] -> v[[C2]]
+//       CFG:   v[[C2]] -> v[[C3]]
+//       CFG:   v[[C3]] -> v[[C4]]
+//       CFG:   v[[C4]] -> v[[TEST_FUNC]]
+//       CFG:   v[[TEST_FUNC]] -> v[[ANCHOR]] [{{.*}}, lhead = [[CLUSTER_MERGE_BLOCKS]]]
+//       CFG:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
 
-// CHECK-LABEL: digraph G {
-//       CHECK:   subgraph {{.*}} {
-//       CHECK:     subgraph {{.*}}
-//       CHECK:       label = "builtin.func{{.*}}merge_blocks
-//       CHECK:       subgraph {{.*}} {
-//       CHECK:         v[[ARG0:.*]] [label = "arg0"
-//       CHECK:         v[[CONST10:.*]] [label ={{.*}}10 : i32
-//       CHECK:         subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
-//       CHECK:           v[[ANCHOR:.*]] [label = " ", shape = plain]
-//       CHECK:           label = "test.merge_blocks
-//       CHECK:           subgraph {{.*}} {
-//       CHECK:             v[[TEST_BR:.*]] [label = "test.br
-//       CHECK:           }
-//       CHECK:           subgraph {{.*}} {
-//       CHECK:           }
-//       CHECK:         }
-//       CHECK:         v[[TEST_RET:.*]] [label = "test.return
-//       CHECK:   v[[ARG0]] -> v[[TEST_BR]]
-//       CHECK:   v[[CONST10]] -> v[[TEST_BR]]
-//       CHECK:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
-//       CHECK:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
 func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
   %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
   %1 = constant dense<1> : tensor<5xi32>


        


More information about the Mlir-commits mailing list