[Mlir-commits] [mlir] 6eca120 - Improve MLIR "view-op-graph" to color operations according to their name

Mehdi Amini llvmlistbot at llvm.org
Tue Jun 20 13:01:57 PDT 2023


Author: Mehdi Amini
Date: 2023-06-20T22:00:48+02:00
New Revision: 6eca120dd8d3ec55bf8dc00b45366e8596359be3

URL: https://github.com/llvm/llvm-project/commit/6eca120dd8d3ec55bf8dc00b45366e8596359be3
DIFF: https://github.com/llvm/llvm-project/commit/6eca120dd8d3ec55bf8dc00b45366e8596359be3.diff

LOG: Improve MLIR "view-op-graph" to color operations according to their name

Differential Revision: https://reviews.llvm.org/D153290

Added: 
    

Modified: 
    mlir/lib/Transforms/ViewOpGraph.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 7689aa061a09d..3d2723839957c 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -87,6 +87,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
 
   void runOnOperation() override {
+    initColorMapping(*getOperation());
     emitGraph([&]() {
       processOperation(getOperation());
       emitAllEdgeStmts();
@@ -97,10 +98,31 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   void emitRegionCFG(Region &region) {
     printControlFlowEdges = true;
     printDataFlowEdges = false;
+    initColorMapping(region); // colors come from this region's ops only
     emitGraph([&]() { processRegion(region); });
   }
 
 private:
+  /// Generate a color mapping that colors every operation with the same name
+  /// the same way, with evenly spaced HSV hues meant to suit black text.
+  /// NOTE(review): S=V=1.0 may give low contrast on dark hues (e.g. blue).
+  template <typename T> // e.g. Operation or Region; must provide walk()
+  void initColorMapping(T &irEntity) {
+    backgroundColors.clear(); // drop any mapping from a previous call
+    SmallVector<Operation *> ops; // first op per distinct name, in walk order
+    irEntity.walk([&](Operation *op) {
+      auto &entry = backgroundColors[op->getName()];
+      if (entry.first == 0) // first time this name is seen
+        ops.push_back(op);
+      ++entry.first; // occurrence count; only its zero-ness is read here
+    });
+    for (auto indexedOps : llvm::enumerate(ops)) {
+      double hue = ((double)indexedOps.index()) / ops.size(); // in [0, 1)
+      backgroundColors[indexedOps.value()->getName()].second =
+          std::to_string(hue) + " 1.0 1.0"; // "H S V" fill color string
+    }
+  }
+
   /// Emit all edges. This function should be called after all nodes have been
   /// emitted.
   void emitAllEdgeStmts() {
@@ -206,11 +228,16 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   }
 
   /// Emit a node statement.
-  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
+  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
+                    StringRef background = "") {
     int nodeId = ++counter;
     AttributeMap attrs;
     attrs["label"] = quoteString(escapeString(std::move(label)));
     attrs["shape"] = shape.str();
+    if (!background.empty()) {
+      attrs["style"] = "filled";
+      attrs["fillcolor"] = ("\"" + background + "\"").str();
+    }
     os << llvm::format("v%i ", nodeId);
     emitAttrList(os, attrs);
     os << ";\n";
@@ -278,7 +305,8 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
           },
           getLabel(op));
     } else {
-      node = emitNodeStmt(getLabel(op));
+      node = emitNodeStmt(getLabel(op), kShapeNode,
+                          backgroundColors[op->getName()].second);
     }
 
     // Insert data flow edges originating from each operand.
@@ -318,6 +346,8 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
   DenseMap<Value, Node> valueToNode;
   /// Counter for generating unique node/subgraph identifiers.
   int counter = 0;
+  /// Operation name -> (occurrence count, "H S V" fill color string).
+  DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
 };
 
 } // namespace


        


More information about the Mlir-commits mailing list