[Mlir-commits] [mlir] 3e8560f - [MLIR] Add option to print users of an operation as comment in the printer

Mehdi Amini llvmlistbot at llvm.org
Fri Apr 22 11:58:18 PDT 2022


Author: cpillmayer
Date: 2022-04-22T18:58:10Z
New Revision: 3e8560f890bb795fec6b86186239633f22a153d5

URL: https://github.com/llvm/llvm-project/commit/3e8560f890bb795fec6b86186239633f22a153d5
DIFF: https://github.com/llvm/llvm-project/commit/3e8560f890bb795fec6b86186239633f22a153d5.diff

LOG: [MLIR] Add option to print users of an operation as comment in the printer

This allows printing, as comments, the users of operation results and block
arguments, as proposed in GitHub issue #53286. To be able to refer to
operations with no result, these operations are assigned an ID in SSANameState.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D124048

Added: 
    mlir/test/IR/print-value-users.mlir

Modified: 
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 52138b79554eb..d995518282eee 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -752,6 +752,9 @@ class OpPrintingFlags {
   /// the full module.
   OpPrintingFlags &useLocalScope();
 
+  /// Print users of values as comments.
+  OpPrintingFlags &printValueUsers();
+
   /// Return if the given ElementsAttr should be elided.
   bool shouldElideElementsAttr(ElementsAttr attr) const;
 
@@ -773,6 +776,9 @@ class OpPrintingFlags {
   /// Return if the printer should use local scope when dumping the IR.
   bool shouldUseLocalScope() const;
 
+  /// Return if the printer should print users of values.
+  bool shouldPrintValueUsers() const;
+
 private:
   /// Elide large elements attributes if the number of elements is larger than
   /// the upper limit.
@@ -790,6 +796,9 @@ class OpPrintingFlags {
 
   /// Print operations with numberings local to the current operation.
   bool printLocalScope : 1;
+
+  /// Print users of values.
+  bool printValueUsersFlag : 1;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 980c886744ad5..013bf162211b8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -152,6 +152,11 @@ struct AsmPrinterOptions {
       "mlir-print-local-scope", llvm::cl::init(false),
       llvm::cl::desc("Print with local scope and inline information (eliding "
                      "aliases for attributes, types, and locations")};
+
+  llvm::cl::opt<bool> printValueUsers{
+      "mlir-print-value-users", llvm::cl::init(false),
+      llvm::cl::desc(
+          "Print users of operation results and block arguments as a comment")};
 };
 } // namespace
 
@@ -168,7 +173,7 @@ void mlir::registerAsmPrinterCLOptions() {
 OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
       printGenericOpFormFlag(false), assumeVerifiedFlag(false),
-      printLocalScope(false) {
+      printLocalScope(false), printValueUsersFlag(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -179,6 +184,7 @@ OpPrintingFlags::OpPrintingFlags()
   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
   assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
   printLocalScope = clOptions->printLocalScopeOpt;
+  printValueUsersFlag = clOptions->printValueUsers;
 }
 
 /// Enable the elision of large elements attributes, by printing a '...'
@@ -219,6 +225,12 @@ OpPrintingFlags &OpPrintingFlags::useLocalScope() {
   return *this;
 }
 
+/// Turn on printing of the users of operation results and block arguments as
+/// trailing comments. Returns this object so calls can be chained.
+OpPrintingFlags &OpPrintingFlags::printValueUsers() {
+  this->printValueUsersFlag = true;
+  return *this;
+}
+
 /// Return if the given ElementsAttr should be elided.
 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
   return elementsAttrElementLimit.hasValue() &&
@@ -254,6 +266,11 @@ bool OpPrintingFlags::shouldAssumeVerified() const {
 /// Return if the printer should use local scope when dumping the IR.
 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
 
+/// Query whether the users of values are to be printed as comments.
+bool OpPrintingFlags::shouldPrintValueUsers() const {
+  return this->printValueUsersFlag;
+}
+
 /// Returns true if an ElementsAttr with the given number of elements should be
 /// printed with hex.
 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
@@ -831,6 +848,9 @@ class SSANameState {
   /// of this value.
   void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
 
+  /// Print the operation identifier.
+  void printOperationID(Operation *op, raw_ostream &stream) const;
+
   /// Return the result indices for each of the result groups registered by this
   /// operation, or empty if none exist.
   ArrayRef<int> getOpResultGroups(Operation *op);
@@ -868,6 +888,10 @@ class SSANameState {
   DenseMap<Value, unsigned> valueIDs;
   DenseMap<Value, StringRef> valueNames;
 
+  /// When printing users of values, an operation without a result might
+  /// be the user. This map holds ids for such operations.
+  DenseMap<Operation *, unsigned> operationIDs;
+
   /// This is a map of operations that contain multiple named result groups,
   /// i.e. there may be multiple names for the results of the operation. The
   /// value of this map are the result numbers that start a result group.
@@ -990,6 +1014,15 @@ void SSANameState::printValueID(Value value, bool printResultNo,
     stream << '#' << resultNo;
 }
 
+/// Print the ID assigned to `op` as `%<id>` on `stream`. Operation IDs are
+/// assigned in numberValuesInOp to operations without results when value
+/// users are being printed; for an operation that was never assigned one, a
+/// placeholder is printed instead.
+void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
+  auto it = operationIDs.find(op);
+  if (it == operationIDs.end())
+    stream << "<<UNKNOWN OPERATION>>";
+  else
+    stream << '%' << it->second;
+}
+
 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
   auto it = opResultGroups.find(op);
   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
@@ -1121,8 +1154,14 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
 
   unsigned numResults = op.getNumResults();
-  if (numResults == 0)
+  if (numResults == 0) {
+    // If value users should be printed, operations with no result need an id.
+    if (printerFlags.shouldPrintValueUsers()) {
+      if (operationIDs.try_emplace(&op, nextValueID).second)
+        ++nextValueID;
+    }
     return;
+  }
   Value resultBegin = op.getResult(0);
 
   // If the first result wasn't numbered, give it a default number.
@@ -2481,6 +2520,10 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   void printValueID(Value value, bool printResultNo = true,
                     raw_ostream *streamOverride = nullptr) const;
 
+  /// Print the ID of the given operation.
+  void printOperationID(Operation *op,
+                        raw_ostream *streamOverride = nullptr) const;
+
   //===--------------------------------------------------------------------===//
   // OpAsmPrinter methods
   //===--------------------------------------------------------------------===//
@@ -2549,6 +2592,19 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                                ValueRange symOperands) override;
 
+  /// Print users of this operation or id of this operation if it has no result.
+  void printUsersComment(Operation *op);
+
+  /// Print users of this block arg.
+  void printUsersComment(BlockArgument arg);
+
+  /// Print the users of a value.
+  void printValueUsers(Value value);
+
+  /// Print either the ids of the result values or the id of the operation if
+  /// the operation has no results.
+  void printUserIDs(Operation *user, bool prefixComma = false);
+
 private:
   // Contains the stack of default dialects to use when printing regions.
   // A new dialect is pushed to the stack before parsing regions nested under an
@@ -2602,6 +2658,8 @@ void OperationPrinter::print(Operation *op) {
   os.indent(currentIndent);
   printOperation(op);
   printTrailingLocation(op->getLoc());
+  if (printerFlags.shouldPrintValueUsers())
+    printUsersComment(op);
 }
 
 void OperationPrinter::printOperation(Operation *op) {
@@ -2657,6 +2715,80 @@ void OperationPrinter::printOperation(Operation *op) {
   printGenericOp(op, /*printOpName=*/true);
 }
 
+/// Append a trailing comment to the just-printed operation `op`:
+///  - no results but some operands: " // id: %N" so users can refer to it;
+///  - results, none of them used:   " // unused";
+///  - otherwise:                    " // user(s): ..." listing the users, with
+///    one parenthesized group per result when there are several results.
+void OperationPrinter::printUsersComment(Operation *op) {
+  unsigned numResults = op->getNumResults();
+
+  // Result-less operations are identified by their operation ID; one without
+  // operands gets no comment at all.
+  if (numResults == 0) {
+    if (op->getNumOperands() != 0) {
+      os << " // id: ";
+      printOperationID(op);
+    }
+    return;
+  }
+
+  if (op->use_empty()) {
+    os << " // unused";
+    return;
+  }
+
+  // Count the distinct user operations and how many results they produce in
+  // total; the singular "user" is used only when a single-result op has one
+  // distinct user that itself has at most one result.
+  SmallPtrSet<Operation *, 1> distinctUsers;
+  unsigned userResultCount = 0;
+  for (Operation *user : op->getUsers())
+    if (distinctUsers.insert(user).second)
+      userResultCount += user->getNumResults();
+
+  bool singleUser = numResults == 1 && distinctUsers.size() <= 1 &&
+                    userResultCount <= 1;
+  os << " // " << (singleUser ? "user" : "users") << ": ";
+
+  // With several results, the users of each result are grouped in parens.
+  bool groupPerResult = numResults > 1;
+  interleaveComma(op->getResults(), [&](OpResult result) {
+    if (groupPerResult)
+      os << '(';
+    printValueUsers(result);
+    if (groupPerResult)
+      os << ')';
+  });
+}
+
+/// Emit a standalone comment line describing the users of block argument
+/// `arg`, e.g. "// %arg0 is used by %0, %1" or "// %arg0 is unused".
+void OperationPrinter::printUsersComment(BlockArgument arg) {
+  os << "// ";
+  printValueID(arg);
+  if (!arg.use_empty()) {
+    os << " is used by ";
+    printValueUsers(arg);
+  } else {
+    os << " is unused";
+  }
+  os << newLine;
+}
+
+/// Print a comma-separated list of the IDs of the operations using `value`,
+/// or "unused" if it has no uses. See printUserIDs for how each user is shown.
+void OperationPrinter::printValueUsers(Value value) {
+  if (value.use_empty()) {
+    os << "unused";
+    return;
+  }
+
+  // One value might be used as an operand of the same operation more than
+  // once. Only print that operation's IDs once in that case.
+  SmallPtrSet<Operation *, 1> userSet;
+  for (Operation *user : value.getUsers()) {
+    // Every user printed after the first one needs a separating comma. This
+    // avoids binding a mutable reference to llvm::enumerate's element and an
+    // implicit size_t -> bool conversion of its index.
+    bool prefixComma = !userSet.empty();
+    if (userSet.insert(user).second)
+      printUserIDs(user, prefixComma);
+  }
+}
+
+/// Print how `user` is referred to in a users comment, optionally preceded by
+/// ", ": its operation ID when it has no results, otherwise the IDs of all of
+/// its results separated by commas.
+void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
+  if (prefixComma)
+    os << ", ";
+
+  if (user->getNumResults() == 0) {
+    printOperationID(user);
+    return;
+  }
+  interleaveComma(user->getResults(),
+                  [&](Value result) { printValueID(result); });
+}
+
 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
   if (printOpName) {
     os << '"';
@@ -2745,6 +2877,14 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
   }
 
   currentIndent += indentWidth;
+
+  if (printerFlags.shouldPrintValueUsers()) {
+    for (BlockArgument arg : block->getArguments()) {
+      os.indent(currentIndent);
+      printUsersComment(arg);
+    }
+  }
+
   bool hasTerminator =
       !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
   auto range = llvm::make_range(
@@ -2764,6 +2904,12 @@ void OperationPrinter::printValueID(Value value, bool printResultNo,
                                         streamOverride ? *streamOverride : os);
 }
 
+/// Print the ID that the SSA name state assigned to `op`, writing to
+/// `streamOverride` when it is provided and to the printer's stream otherwise.
+void OperationPrinter::printOperationID(Operation *op,
+                                        raw_ostream *streamOverride) const {
+  raw_ostream &target = streamOverride ? *streamOverride : os;
+  state->getSSANameState().printOperationID(op, target);
+}
+
 void OperationPrinter::printSuccessor(Block *successor) {
   printBlockName(successor);
 }

diff  --git a/mlir/test/IR/print-value-users.mlir b/mlir/test/IR/print-value-users.mlir
new file mode 100644
index 0000000000000..042206a00757b
--- /dev/null
+++ b/mlir/test/IR/print-value-users.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-value-users -split-input-file %s | FileCheck %s
+
+module {
+    // CHECK: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32
+    // The third argument is spelled %arg3 in the input; its printed name is
+    // captured by the regex above, so the input spelling does not matter.
+    func @foo(%arg0: i32, %arg1: i32, %arg3: i32) -> i32 {
+        // CHECK-NEXT: // %[[ARG0]] is used by %[[ARG0U1:.+]], %[[ARG0U2:.+]], %[[ARG0U3:.+]]
+        // CHECK-NEXT: // %[[ARG1]] is used by %[[ARG1U1:.+]], %[[ARG1U2:.+]]
+        // CHECK-NEXT: // %[[ARG2]] is unused
+        // CHECK-NEXT: test.noop
+        // CHECK-NOT: // unused
+        // An operation with neither results nor operands gets no comment.
+        "test.noop"() : () -> ()
+        // When no result is produced, an id should be printed.
+        // CHECK-NEXT: // id: %[[ARG0U3]]
+        "test.no_result"(%arg0) {} : (i32) -> ()
+        // Check for unused result.
+        // CHECK-NEXT: %[[ARG0U2]] = 
+        // CHECK-SAME: // unused
+        %1 = "test.unused_result"(%arg0, %arg1) {} : (i32, i32) -> i32
+        // The single user has two results; both result IDs are printed, so
+        // the plural "users" is used.
+        // CHECK-NEXT: %[[ARG0U1]] = 
+        // CHECK-SAME: // users: %[[A:.+]]#0, %[[A]]#1
+        %2 = "test.one_result"(%arg0, %arg1) {} : (i32, i32) -> i32
+        // For multiple results, users should be grouped per result.
+        // CHECK-NEXT: %[[A]]:2 = 
+        // CHECK-SAME: // users: (%[[B:.+]], %[[C:.+]]), (%[[B]], %[[D:.+]])
+        %3:2 = "test.many_results"(%2) {} : (i32) -> (i32, i32)
+        // Two results are produced but there is only one user; the plural
+        // "users" is still printed because the op has more than one result.
+        // CHECK-NEXT: // users:
+        %7:2 = "test.many_results"() : () -> (i32, i32)
+        // CHECK-NEXT: %[[C]] =
+        // The result is used twice by the next operation, which itself produces
+        // a single result, so the singular "user" is printed.
+        // CHECK-SAME: // user:
+        %4 = "test.foo"(%3#0) {} : (i32) -> i32
+        // CHECK-NEXT: %[[D]] =
+        %5 = "test.foo"(%3#1, %4, %4) {} : (i32, i32, i32) -> i32
+        // CHECK-NEXT: %[[B]] =
+        // The result is used by two operations that have no results, so the
+        // plural "users" is printed.
+        // CHECK-SAME: // users:
+        %6 = "test.foo"(%3#0, %3#1) {} : (i32, i32) -> i32
+        "test.no_result"(%6) {} : (i32) -> ()
+        "test.no_result"(%7#0) : (i32) -> ()
+        return %6: i32
+    }
+}
+
+// -----
+
+module {
+    // Check a result-less operation nested in a region: it gets an id.
+    // CHECK: %[[CONSTNAME:.+]] = arith.constant
+    %0 = arith.constant 42 : i32
+    %test = "test.outerop"(%0) ({
+        // CHECK: "test.innerop"(%[[CONSTNAME]]) : (i32) -> () // id: %
+        "test.innerop"(%0) : (i32) -> ()
+    // CHECK: (i32) -> i32 // users: %r, %s, %p, %p_0, %q
+    }): (i32) -> i32
+
+    // Check named results. Names come from the `names` attribute; a repeated
+    // name is uniquified when printed (%p, %p_0 above).
+    // CHECK-NEXT: // users: (%u, %v), (unused), (%u, %v, %r, %s)
+    %p:2, %q = "test.custom_result_name"(%test) {names = ["p", "p", "q"]} : (i32) -> (i32, i32, i32)
+    // CHECK-NEXT: // users: (unused), (%u, %v)
+    %r, %s = "test.custom_result_name"(%q#0, %q#0, %test) {names = ["r", "s"]} : (i32, i32, i32) -> (i32, i32)
+    // CHECK-NEXT: // unused
+    %u, %v = "test.custom_result_name"(%s, %q#0, %p) {names = ["u", "v"]} : (i32, i32, i32) -> (i32, i32)
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index c7aa3f1b9c9aa..68668596d7cba 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1168,6 +1168,15 @@ void StringAttrPrettyNameOp::getAsmResultNames(
         setNameFn(getResult(i), str.getValue());
 }
 
+/// Name result `i` after the i-th entry of the `names` attribute. Entries that
+/// are not strings, or are empty strings, leave the corresponding result
+/// without a name from this hook.
+void CustomResultsNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  ArrayAttr names = getNames();
+  // The op definition declares no verifier tying the length of `names` to the
+  // number of results, so only visit indices valid for both; otherwise
+  // getResult(i) could be called with an out-of-range index.
+  size_t numNames = names.size();
+  size_t numResults = getNumResults();
+  for (size_t i = 0; i < numNames && i < numResults; ++i)
+    if (auto str = names[i].dyn_cast<StringAttr>())
+      if (!str.getValue().empty())
+        setNameFn(getResult(i), str.getValue());
+}
+
 //===----------------------------------------------------------------------===//
 // ResultTypeWithTraitOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f42fb8edcc188..bf41f5f34f68d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -732,6 +732,19 @@ def StringAttrPrettyNameOp
   let hasCustomAssemblyFormat = 1;
 }
 
+
+// Test op whose results take their pretty SSA names from the corresponding
+// entries of the `names` string array attribute (see getAsmResultNames). It is
+// used to test printing of value users when results carry custom names.
+def CustomResultsNameOp
+    : TEST_Op<"custom_result_name",
+              [DeclareOpInterfaceMethods<OpAsmOpInterface,
+                                         ["getAsmResultNames"]>]> {
+  let arguments = (ins
+    Variadic<AnyInteger>:$optional,
+    StrArrayAttr:$names
+  );
+  let results = (outs Variadic<AnyInteger>:$r);
+}
+
 // This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
 // operations nested in a region under this op will drop the "test." dialect
 // prefix.


        


More information about the Mlir-commits mailing list