[Mlir-commits] [mlir] 3e8560f - [MLIR] Add option to print users of an operation as comment in the printer
Mehdi Amini
llvmlistbot at llvm.org
Fri Apr 22 11:58:18 PDT 2022
Author: cpillmayer
Date: 2022-04-22T18:58:10Z
New Revision: 3e8560f890bb795fec6b86186239633f22a153d5
URL: https://github.com/llvm/llvm-project/commit/3e8560f890bb795fec6b86186239633f22a153d5
DIFF: https://github.com/llvm/llvm-project/commit/3e8560f890bb795fec6b86186239633f22a153d5.diff
LOG: [MLIR] Add option to print users of an operation as comment in the printer
This allows printing the users of an operation, as proposed in GitHub issue #53286.
To be able to refer to operations with no result, these operations are assigned an
ID in SSANameState.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D124048
Added:
mlir/test/IR/print-value-users.mlir
Modified:
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 52138b79554eb..d995518282eee 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -752,6 +752,9 @@ class OpPrintingFlags {
/// the full module.
OpPrintingFlags &useLocalScope();
+ /// Print users of values as comments.
+ OpPrintingFlags &printValueUsers();
+
/// Return if the given ElementsAttr should be elided.
bool shouldElideElementsAttr(ElementsAttr attr) const;
@@ -773,6 +776,9 @@ class OpPrintingFlags {
/// Return if the printer should use local scope when dumping the IR.
bool shouldUseLocalScope() const;
+ /// Return if the printer should print users of values.
+ bool shouldPrintValueUsers() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -790,6 +796,9 @@ class OpPrintingFlags {
/// Print operations with numberings local to the current operation.
bool printLocalScope : 1;
+
+ /// Print users of values.
+ bool printValueUsersFlag : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 980c886744ad5..013bf162211b8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -152,6 +152,11 @@ struct AsmPrinterOptions {
"mlir-print-local-scope", llvm::cl::init(false),
llvm::cl::desc("Print with local scope and inline information (eliding "
"aliases for attributes, types, and locations")};
+
+ llvm::cl::opt<bool> printValueUsers{
+ "mlir-print-value-users", llvm::cl::init(false),
+ llvm::cl::desc(
+ "Print users of operation results and block arguments as a comment")};
};
} // namespace
@@ -168,7 +173,7 @@ void mlir::registerAsmPrinterCLOptions() {
OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), assumeVerifiedFlag(false),
- printLocalScope(false) {
+ printLocalScope(false), printValueUsersFlag(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -179,6 +184,7 @@ OpPrintingFlags::OpPrintingFlags()
printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
printLocalScope = clOptions->printLocalScopeOpt;
+ printValueUsersFlag = clOptions->printValueUsers;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -219,6 +225,12 @@ OpPrintingFlags &OpPrintingFlags::useLocalScope() {
return *this;
}
+/// Enable emitting, as a trailing comment, the users of each printed value.
+/// Returns this object so that flag setters can be chained.
+OpPrintingFlags &OpPrintingFlags::printValueUsers() {
+  this->printValueUsersFlag = true;
+  return *this;
+}
+
/// Return if the given ElementsAttr should be elided.
bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
return elementsAttrElementLimit.hasValue() &&
@@ -254,6 +266,11 @@ bool OpPrintingFlags::shouldAssumeVerified() const {
/// Return if the printer should use local scope when dumping the IR.
bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
+/// Query whether the users of values should be emitted as comments.
+bool OpPrintingFlags::shouldPrintValueUsers() const {
+  const bool enabled = printValueUsersFlag;
+  return enabled;
+}
+
/// Returns true if an ElementsAttr with the given number of elements should be
/// printed with hex.
static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
@@ -831,6 +848,9 @@ class SSANameState {
/// of this value.
void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
+ /// Print the operation identifier.
+ void printOperationID(Operation *op, raw_ostream &stream) const;
+
/// Return the result indices for each of the result groups registered by this
/// operation, or empty if none exist.
ArrayRef<int> getOpResultGroups(Operation *op);
@@ -868,6 +888,10 @@ class SSANameState {
DenseMap<Value, unsigned> valueIDs;
DenseMap<Value, StringRef> valueNames;
+ /// When printing users of values, an operation without a result might
+ /// be the user. This map holds ids for such operations.
+ DenseMap<Operation *, unsigned> operationIDs;
+
/// This is a map of operations that contain multiple named result groups,
/// i.e. there may be multiple names for the results of the operation. The
/// value of this map are the result numbers that start a result group.
@@ -990,6 +1014,15 @@ void SSANameState::printValueID(Value value, bool printResultNo,
stream << '#' << resultNo;
}
+/// Print the ID assigned to `op` as `%<id>`. IDs are only assigned to
+/// result-less operations, and only when value users are being printed; any
+/// other operation is printed as a placeholder.
+void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
+  auto it = operationIDs.find(op);
+  if (it == operationIDs.end()) {
+    stream << "<<UNKNOWN OPERATION>>";
+    return;
+  }
+  stream << '%' << it->second;
+}
+
ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
auto it = opResultGroups.find(op);
return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
@@ -1121,8 +1154,14 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
unsigned numResults = op.getNumResults();
- if (numResults == 0)
+ if (numResults == 0) {
+ // If value users should be printed, operations with no result need an id.
+ if (printerFlags.shouldPrintValueUsers()) {
+ if (operationIDs.try_emplace(&op, nextValueID).second)
+ ++nextValueID;
+ }
return;
+ }
Value resultBegin = op.getResult(0);
// If the first result wasn't numbered, give it a default number.
@@ -2481,6 +2520,10 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
void printValueID(Value value, bool printResultNo = true,
raw_ostream *streamOverride = nullptr) const;
+ /// Print the ID of the given operation.
+ void printOperationID(Operation *op,
+ raw_ostream *streamOverride = nullptr) const;
+
//===--------------------------------------------------------------------===//
// OpAsmPrinter methods
//===--------------------------------------------------------------------===//
@@ -2549,6 +2592,19 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
ValueRange symOperands) override;
+ /// Print users of this operation or id of this operation if it has no result.
+ void printUsersComment(Operation *op);
+
+ /// Print users of this block arg.
+ void printUsersComment(BlockArgument arg);
+
+ /// Print the users of a value.
+ void printValueUsers(Value value);
+
+ /// Print either the ids of the result values or the id of the operation if
+ /// the operation has no results.
+ void printUserIDs(Operation *user, bool prefixComma = false);
+
private:
// Contains the stack of default dialects to use when printing regions.
// A new dialect is pushed to the stack before parsing regions nested under an
@@ -2602,6 +2658,8 @@ void OperationPrinter::print(Operation *op) {
os.indent(currentIndent);
printOperation(op);
printTrailingLocation(op->getLoc());
+ if (printerFlags.shouldPrintValueUsers())
+ printUsersComment(op);
}
void OperationPrinter::printOperation(Operation *op) {
@@ -2657,6 +2715,80 @@ void OperationPrinter::printOperation(Operation *op) {
printGenericOp(op, /*printOpName=*/true);
}
+/// Print a trailing comment describing how `op` is used:
+///  - result-less ops with operands: their operation ID (" // id: %N"),
+///  - ops whose results are all unused: " // unused",
+///  - otherwise the users of each result, wrapped in parentheses per result
+///    when the op has more than one result.
+void OperationPrinter::printUsersComment(Operation *op) {
+  const unsigned numResults = op->getNumResults();
+
+  // A result-less op cannot be referred to through a result, so print the ID
+  // it was assigned instead. Ops with neither results nor operands get no
+  // comment here.
+  if (numResults == 0) {
+    if (op->getNumOperands() != 0) {
+      os << " // id: ";
+      printOperationID(op);
+    }
+    return;
+  }
+
+  if (op->use_empty()) {
+    os << " // unused";
+    return;
+  }
+
+  // Count distinct user operations and the results they define, so that the
+  // singular "user" is only printed when a single result feeds at most one
+  // operation that itself defines at most one result.
+  SmallPtrSet<Operation *, 1> distinctUsers;
+  unsigned userOpCount = 0;
+  unsigned userResultCount = 0;
+  for (Operation *user : op->getUsers()) {
+    if (!distinctUsers.insert(user).second)
+      continue;
+    ++userOpCount;
+    userResultCount += user->getNumResults();
+  }
+
+  // At this point `op` is known to have at least one use.
+  const bool singleUser =
+      numResults == 1 && userOpCount <= 1 && userResultCount <= 1;
+  os << " // " << (singleUser ? "user" : "users") << ": ";
+
+  const bool groupPerResult = numResults > 1;
+  interleaveComma(op->getResults(), [&](OpResult result) {
+    if (groupPerResult)
+      os << "(";
+    printValueUsers(result);
+    if (groupPerResult)
+      os << ")";
+  });
+}
+
+/// Print a standalone comment line that names block argument `arg` and lists
+/// its users, or states that it is unused, followed by a newline.
+void OperationPrinter::printUsersComment(BlockArgument arg) {
+  os << "// ";
+  printValueID(arg);
+  const bool unused = arg.use_empty();
+  os << (unused ? " is unused" : " is used by ");
+  if (!unused)
+    printValueUsers(arg);
+  os << newLine;
+}
+
+/// Print the users of `value` as a comma-separated list. Each user appears as
+/// its result IDs, or as its operation ID if it has no results. Prints
+/// "unused" when `value` has no uses.
+void OperationPrinter::printValueUsers(Value value) {
+  if (value.use_empty()) {
+    os << "unused";
+    return;
+  }
+
+  // A value may be an operand of the same operation several times. Print
+  // each user only once, and separate distinct users with commas. Track the
+  // separator explicitly rather than deriving it from the use index.
+  SmallPtrSet<Operation *, 1> userSet;
+  bool isFirst = true;
+  for (Operation *user : value.getUsers()) {
+    if (!userSet.insert(user).second)
+      continue;
+    printUserIDs(user, /*prefixComma=*/!isFirst);
+    isFirst = false;
+  }
+}
+
+/// Print the identifiers by which `user` can be referred to. These are its
+/// result IDs (comma separated), or its operation ID when it has no results.
+/// When `prefixComma` is set, a ", " separator is printed first.
+void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
+  if (prefixComma)
+    os << ", ";
+
+  if (user->getNumResults() == 0) {
+    printOperationID(user);
+    return;
+  }
+
+  interleaveComma(user->getResults(),
+                  [this](Value result) { printValueID(result); });
+}
+
void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
if (printOpName) {
os << '"';
@@ -2745,6 +2877,14 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
}
currentIndent += indentWidth;
+
+ if (printerFlags.shouldPrintValueUsers()) {
+ for (BlockArgument arg : block->getArguments()) {
+ os.indent(currentIndent);
+ printUsersComment(arg);
+ }
+ }
+
bool hasTerminator =
!block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
auto range = llvm::make_range(
@@ -2764,6 +2904,12 @@ void OperationPrinter::printValueID(Value value, bool printResultNo,
streamOverride ? *streamOverride : os);
}
+/// Print the ID of `op` to `*streamOverride` when one is given, otherwise to
+/// the printer's own output stream.
+void OperationPrinter::printOperationID(Operation *op,
+                                        raw_ostream *streamOverride) const {
+  raw_ostream &target = streamOverride ? *streamOverride : os;
+  state->getSSANameState().printOperationID(op, target);
+}
+
void OperationPrinter::printSuccessor(Block *successor) {
printBlockName(successor);
}
diff --git a/mlir/test/IR/print-value-users.mlir b/mlir/test/IR/print-value-users.mlir
new file mode 100644
index 0000000000000..042206a00757b
--- /dev/null
+++ b/mlir/test/IR/print-value-users.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-value-users -split-input-file %s | FileCheck %s
+
+module {
+ // CHECK: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32
+ func @foo(%arg0: i32, %arg1: i32, %arg3: i32) -> i32 {
+ // CHECK-NEXT: // %[[ARG0]] is used by %[[ARG0U1:.+]], %[[ARG0U2:.+]], %[[ARG0U3:.+]]
+ // CHECK-NEXT: // %[[ARG1]] is used by %[[ARG1U1:.+]], %[[ARG1U2:.+]]
+ // CHECK-NEXT: // %[[ARG2]] is unused
+ // CHECK-NEXT: test.noop
+ // CHECK-NOT: // unused
+ "test.noop"() : () -> ()
+ // When no result is produced, an id should be printed.
+ // CHECK-NEXT: // id: %[[ARG0U3]]
+ "test.no_result"(%arg0) {} : (i32) -> ()
+ // Check for unused result.
+ // CHECK-NEXT: %[[ARG0U2]] =
+ // CHECK-SAME: // unused
+ %1 = "test.unused_result"(%arg0, %arg1) {} : (i32, i32) -> i32
+ // Check that both users are printed.
+ // CHECK-NEXT: %[[ARG0U1]] =
+ // CHECK-SAME: // users: %[[A:.+]]#0, %[[A]]#1
+ %2 = "test.one_result"(%arg0, %arg1) {} : (i32, i32) -> i32
+ // For multiple results, users should be grouped per result.
+ // CHECK-NEXT: %[[A]]:2 =
+ // CHECK-SAME: // users: (%[[B:.+]], %[[C:.+]]), (%[[B]], %[[D:.+]])
+ %3:2 = "test.many_results"(%2) {} : (i32) -> (i32, i32)
+ // Two results are produced, but there is only one user.
+ // CHECK-NEXT: // users:
+ %7:2 = "test.many_results"() : () -> (i32, i32)
+ // CHECK-NEXT: %[[C]] =
+ // Result is used twice in next operation but it produces only one result.
+ // CHECK-SAME: // user:
+ %4 = "test.foo"(%3#0) {} : (i32) -> i32
+ // CHECK-NEXT: %[[D]] =
+ %5 = "test.foo"(%3#1, %4, %4) {} : (i32, i32, i32) -> i32
+ // CHECK-NEXT: %[[B]] =
+ // Result is not used in any other result but in two operations.
+ // CHECK-SAME: // users:
+ %6 = "test.foo"(%3#0, %3#1) {} : (i32, i32) -> i32
+ "test.no_result"(%6) {} : (i32) -> ()
+ "test.no_result"(%7#0) : (i32) -> ()
+ return %6: i32
+ }
+}
+
+// -----
+
+module {
+ // Check with nested operation.
+ // CHECK: %[[CONSTNAME:.+]] = arith.constant
+ %0 = arith.constant 42 : i32
+ %test = "test.outerop"(%0) ({
+ // CHECK: "test.innerop"(%[[CONSTNAME]]) : (i32) -> () // id: %
+ "test.innerop"(%0) : (i32) -> ()
+ // CHECK: (i32) -> i32 // users: %r, %s, %p, %p_0, %q
+ }): (i32) -> i32
+
+ // Check named results.
+ // CHECK-NEXT: // users: (%u, %v), (unused), (%u, %v, %r, %s)
+ %p:2, %q = "test.custom_result_name"(%test) {names = ["p", "p", "q"]} : (i32) -> (i32, i32, i32)
+ // CHECK-NEXT: // users: (unused), (%u, %v)
+ %r, %s = "test.custom_result_name"(%q#0, %q#0, %test) {names = ["r", "s"]} : (i32, i32, i32) -> (i32, i32)
+ // CHECK-NEXT: // unused
+ %u, %v = "test.custom_result_name"(%s, %q#0, %p) {names = ["u", "v"]} : (i32, i32, i32) -> (i32, i32)
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index c7aa3f1b9c9aa..68668596d7cba 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1168,6 +1168,15 @@ void StringAttrPrettyNameOp::getAsmResultNames(
setNameFn(getResult(i), str.getValue());
}
+/// Name each result of this op after the corresponding entry of its `names`
+/// array attribute. Entries that are not strings, or are empty, are skipped.
+void CustomResultsNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  ArrayAttr value = getNames();
+  // The results are variadic and the op declares no constraint tying their
+  // count to `names`. Bound the loop by both counts so that a longer `names`
+  // array never indexes past the last result.
+  size_t numNamed = value.size();
+  if (numNamed > getNumResults())
+    numNamed = getNumResults();
+  for (size_t i = 0; i != numNamed; ++i)
+    if (auto str = value[i].dyn_cast<StringAttr>())
+      if (!str.getValue().empty())
+        setNameFn(getResult(i), str.getValue());
+}
+
//===----------------------------------------------------------------------===//
// ResultTypeWithTraitOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f42fb8edcc188..bf41f5f34f68d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -732,6 +732,19 @@ def StringAttrPrettyNameOp
let hasCustomAssemblyFormat = 1;
}
+
+// Test op whose results are given custom SSA names taken from the `names`
+// string array attribute (see CustomResultsNameOp::getAsmResultNames in
+// TestDialect.cpp; empty entries leave a result with its default name).
+// print-value-users.mlir uses it to check that user comments refer to values
+// by their custom names.
+def CustomResultsNameOp
+ : TEST_Op<"custom_result_name",
+ [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+ let arguments = (ins
+ // Arbitrary integer operands, so the op can be made a user of any values.
+ Variadic<AnyInteger>:$optional,
+ StrArrayAttr:$names
+ );
+ let results = (outs Variadic<AnyInteger>:$r);
+}
+
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.
More information about the Mlir-commits
mailing list