[Mlir-commits] [mlir] WIP OpGraph fix (PR #67143)
Vladyslav Moisieienkov
llvmlistbot at llvm.org
Fri Sep 22 07:07:44 PDT 2023
https://github.com/VMois created https://github.com/llvm/llvm-project/pull/67143
(No pull-request description was provided.)
From ce090bcc1edd1dc7b942e47c28880e2e989e1ebc Mon Sep 17 00:00:00 2001
From: VMois <vmois at protonmail.com>
Date: Tue, 25 Apr 2023 20:45:15 +0200
Subject: [PATCH] WIP OpGraph fix
---
mlir/lib/Transforms/ViewOpGraph.cpp | 69 +++++++++++++++++++++++++++--
1 file changed, 66 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 4598b56a901d537..3e7296460d7dd5d 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -18,6 +18,7 @@
#include <map>
#include <optional>
#include <utility>
+#include <iostream>
namespace mlir {
#define GEN_PASS_DEF_VIEWOPGRAPH
@@ -250,10 +251,33 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
/// Process a block. Emit a cluster and one node per block argument and
/// operation inside the cluster.
void processBlock(Block &block) {
+ //std::cout << "Emit cluster process block" << std::endl;
emitClusterStmt([&]() {
for (BlockArgument &blockArg : block.getArguments())
valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
+ // for (Operation &op : block) {
+ // //if (op.getNumRegions() > 0) continue;
+ // //std::string label = getLabel(&op);
+ // Node node = emitNodeStmt(getLabel(&op));
+ // for (Value result : op.getResults())
+ // valueToNode[result] = node;
+
+ // for (Value result : op.getOperands())
+ // valueToNode[result] = node;
+ // //std::cout << "Creating node " << node.id << std::endl;
+ // opToNode[&op] = node;
+ // }
+
+ // for (Operation &op : block) {
+ // Node node = emitNodeStmt(getLabel(&op));
+ // for (Value result : op.getResults())
+ // valueToNode[result] = node;
+
+ // for (Value result : op.getOperands())
+ // valueToNode[result] = node;
+ // }
+
// Emit a node for each operation.
std::optional<Node> prevNode;
for (Operation &op : block) {
@@ -279,16 +303,31 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
},
getLabel(op));
} else {
- node = emitNodeStmt(getLabel(op));
+ //if (opToNode.count(op) == 1) {
+ //node = opToNode[op];
+ //} else {
+ node = emitNodeStmt(opToName[op]);
+ //std::cout << "Using node " << node.id << std::endl;
}
// Insert data flow edges originating from each operand.
if (printDataFlowEdges) {
unsigned numOperands = op->getNumOperands();
- for (unsigned i = 0; i < numOperands; i++)
- emitEdgeStmt(valueToNode[op->getOperand(i)], node,
+ for (unsigned i = 0; i < numOperands; i++) {
+ Value operand = op->getOperand(i);
+ Node operandNode;
+ if (valueToNode.count(operand) == 0) {
+ // no Node created, created one according to pre-specified name
+ Operation* definingOp = operand.getDefiningOp();
+ operandNode = emitNodeStmt(opToName[definingOp]);
+ valueToNode[operand] = operandNode;
+ } else {
+ operandNode = valueToNode[operand];
+ }
+ emitEdgeStmt(operandNode, node,
/*label=*/numOperands == 1 ? "" : std::to_string(i),
kLineStyleDataFlow);
+ }
}
for (Value result : op->getResults())
@@ -299,6 +338,27 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
/// Process a region.
void processRegion(Region &region) {
+ // for (Block &block : region.getBlocks()) {
+ // for (Operation &op : block) {
+ // if (op.getNumRegions() > 0) continue;
+ // //std::string label = getLabel(&op);
+ // Node node = emitNodeStmt(getLabel(&op));
+ // for (Value result : op.getResults())
+ // valueToNode[result] = node;
+
+ // for (Value result : op.getOperands())
+ // valueToNode[result] = node;
+ // std::cout << "Creating node " << node.id << std::endl;
+ // opToNode[&op] = node;
+ // }
+ // }
+
+ for (Block &block : region.getBlocks()) {
+ for (Operation &op : block) {
+ opToName[&op] = getLabel(&op);
+ }
+ }
+
for (Block &block : region.getBlocks())
processBlock(block);
}
@@ -317,6 +377,9 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
std::vector<std::string> edges;
/// Mapping of SSA values to Graphviz nodes/clusters.
DenseMap<Value, Node> valueToNode;
+ DenseMap<Operation *, Node> opToNode;
+ DenseMap<Operation *, std::string> opToName;
+ //DenseMap<std::string, Node> nameToNode;
/// Counter for generating unique node/subgraph identifiers.
int counter = 0;
};
More information about the Mlir-commits
mailing list