[Mlir-commits] [mlir] 2cdd246 - [mlir][NFC] Make 'printOp' public in AsmPrinter

Diego Caballero llvmlistbot at llvm.org
Wed Oct 5 12:01:47 PDT 2022


Author: Diego Caballero
Date: 2022-10-05T19:00:53Z
New Revision: 2cdd246a39076379b1678782f99872f918fa358c

URL: https://github.com/llvm/llvm-project/commit/2cdd246a39076379b1678782f99872f918fa358c
DIFF: https://github.com/llvm/llvm-project/commit/2cdd246a39076379b1678782f99872f918fa358c.diff

LOG: [mlir][NFC] Make 'printOp' public in AsmPrinter

This patch moves the 'printOp' functionality to the public API of
AsmPrinter and renames it to 'printCustomOrGenericOp'. No 'parseOp'
is needed at this time as existing APIs are able to parse operations
producing results where results are omitted in the textual form
(the LHS of an operation is redundant when it comes to building the
operation itself as it only contains the result names).

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D135006

Added: 
    mlir/test/IR/print-op-custom-or-generic.mlir

Modified: 
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/IR/AsmPrinter.cpp

Removed: 
    mlir/test/IR/print-op-generic.mlir


################################################################################
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 78843accdf45a..0251394acba4b 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -380,6 +380,10 @@ class OpAsmPrinter : public AsmPrinter {
   printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
                                    ArrayRef<StringRef> elidedAttrs = {}) = 0;
 
+  /// Prints the entire operation with the custom assembly form, if available,
+  /// or the generic assembly form, otherwise.
+  virtual void printCustomOrGenericOp(Operation *op) = 0;
+
   /// Print the entire operation with the default generic assembly form.
   /// If `printOpName` is true, then the operation name is printed (the default)
   /// otherwise it is omitted and the print will start with the operand list.

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f51ea60de523f..53da51cf10862 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -421,8 +421,9 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
                                       AliasInitializer &initializer)
       : printerFlags(printerFlags), initializer(initializer) {}
 
-  /// Print the given operation.
-  void print(Operation *op) {
+  /// Prints the entire operation with the custom assembly form, if available,
+  /// or the generic assembly form, otherwise.
+  void printCustomOrGenericOp(Operation *op) override {
     // Visit the operation location.
     if (printerFlags.shouldPrintDebugInfo())
       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
@@ -489,7 +490,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
         std::prev(block->end(),
                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
     for (Operation &op : range)
-      print(&op);
+      printCustomOrGenericOp(&op);
   }
 
   /// Print the given region.
@@ -680,7 +681,7 @@ void AliasInitializer::initialize(
   // attributes/types that will actually be used during printing when
   // considering aliases.
   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
-  aliasPrinter.print(op);
+  aliasPrinter.printCustomOrGenericOp(op);
 
   // Initialize the aliases sorted by name.
   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
@@ -2660,11 +2661,16 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   /// Print the given top-level operation.
   void printTopLevelOperation(Operation *op);
 
-  /// Print the given operation with its indent and location.
-  void print(Operation *op);
-  /// Print the bare location, not including indentation/location/etc.
-  void printOperation(Operation *op);
-  /// Print the given operation in the generic form.
+  /// Print the given operation, including its left-hand side and its right-hand
+  /// side, with its indent and location.
+  void printFullOpWithIndentAndLoc(Operation *op);
+  /// Print the given operation, including its left-hand side and its right-hand
+  /// side, but not including indentation and location.
+  void printFullOp(Operation *op);
+  /// Print the right-hand side of the given operation in the custom or generic
+  /// form.
+  void printCustomOrGenericOp(Operation *op) override;
+  /// Print the right-hand side of the given operation in the generic form.
   void printGenericOp(Operation *op, bool printOpName) override;
 
   /// Print the name of the given block.
@@ -2838,7 +2844,7 @@ void OperationPrinter::printTopLevelOperation(Operation *op) {
   state.getAliasState().printNonDeferredAliases(os, newLine);
 
   // Print the module.
-  print(op);
+  printFullOpWithIndentAndLoc(op);
   os << newLine;
 
   // Output the aliases at the top level that can be deferred.
@@ -2934,18 +2940,18 @@ void OperationPrinter::printRegionArgument(BlockArgument arg,
   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
 }
 
-void OperationPrinter::print(Operation *op) {
+void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
   // Track the location of this operation.
   state.registerOperationLocation(op, newLine.curLine, currentIndent);
 
   os.indent(currentIndent);
-  printOperation(op);
+  printFullOp(op);
   printTrailingLocation(op->getLoc());
   if (printerFlags.shouldPrintValueUsers())
     printUsersComment(op);
 }
 
-void OperationPrinter::printOperation(Operation *op) {
+void OperationPrinter::printFullOp(Operation *op) {
   if (size_t numResults = op->getNumResults()) {
     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
@@ -2972,34 +2978,7 @@ void OperationPrinter::printOperation(Operation *op) {
     os << " = ";
   }
 
-  // If requested, always print the generic form.
-  if (!printerFlags.shouldPrintGenericOpForm()) {
-    // Check to see if this is a known operation. If so, use the registered
-    // custom printer hook.
-    if (auto opInfo = op->getRegisteredInfo()) {
-      opInfo->printAssembly(op, *this, defaultDialectStack.back());
-      return;
-    }
-    // Otherwise try to dispatch to the dialect, if available.
-    if (Dialect *dialect = op->getDialect()) {
-      if (auto opPrinter = dialect->getOperationPrinter(op)) {
-        // Print the op name first.
-        StringRef name = op->getName().getStringRef();
-        // Only drop the default dialect prefix when it cannot lead to
-        // ambiguities.
-        if (name.count('.') == 1)
-          name.consume_front((defaultDialectStack.back() + ".").str());
-        os << name;
-
-        // Print the rest of the op now.
-        opPrinter(op, *this);
-        return;
-      }
-    }
-  }
-
-  // Otherwise print with the generic assembly form.
-  printGenericOp(op, /*printOpName=*/true);
+  printCustomOrGenericOp(op);
 }
 
 void OperationPrinter::printUsersComment(Operation *op) {
@@ -3076,6 +3055,37 @@ void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
   }
 }
 
+void OperationPrinter::printCustomOrGenericOp(Operation *op) {
+  // If requested, always print the generic form.
+  if (!printerFlags.shouldPrintGenericOpForm()) {
+    // Check to see if this is a known operation. If so, use the registered
+    // custom printer hook.
+    if (auto opInfo = op->getRegisteredInfo()) {
+      opInfo->printAssembly(op, *this, defaultDialectStack.back());
+      return;
+    }
+    // Otherwise try to dispatch to the dialect, if available.
+    if (Dialect *dialect = op->getDialect()) {
+      if (auto opPrinter = dialect->getOperationPrinter(op)) {
+        // Print the op name first.
+        StringRef name = op->getName().getStringRef();
+        // Only drop the default dialect prefix when it cannot lead to
+        // ambiguities.
+        if (name.count('.') == 1)
+          name.consume_front((defaultDialectStack.back() + ".").str());
+        os << name;
+
+        // Print the rest of the op now.
+        opPrinter(op, *this);
+        return;
+      }
+    }
+  }
+
+  // Otherwise print with the generic assembly form.
+  printGenericOp(op, /*printOpName=*/true);
+}
+
 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
   if (printOpName)
     printEscapedString(op->getName().getStringRef());
@@ -3176,7 +3186,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
       std::prev(block->end(),
                 (!hasTerminator || printBlockTerminator) ? 0 : 1));
   for (auto &op : range) {
-    print(&op);
+    printFullOpWithIndentAndLoc(&op);
     os << newLine;
   }
   currentIndent -= indentWidth;
@@ -3418,7 +3428,7 @@ void Operation::print(raw_ostream &os, AsmState &state) {
     state.getImpl().initializeAliases(this);
     printer.printTopLevelOperation(this);
   } else {
-    printer.print(this);
+    printer.printFullOpWithIndentAndLoc(this);
   }
 }
 

diff  --git a/mlir/test/IR/print-op-custom-or-generic.mlir b/mlir/test/IR/print-op-custom-or-generic.mlir
new file mode 100644
index 0000000000000..a82089b36ae94
--- /dev/null
+++ b/mlir/test/IR/print-op-custom-or-generic.mlir
@@ -0,0 +1,28 @@
+// # RUN: mlir-opt %s -split-input-file | FileCheck %s
+// # RUN: mlir-opt %s -mlir-print-op-generic -split-input-file  | FileCheck %s --check-prefix=GENERIC
+
+// Check that `printCustomOrGenericOp` and `printGenericOp` print the right
+// assembly format. For operations without custom format, both should print the
+// generic format.
+
+// CHECK-LABEL: func @op_with_custom_printer
+// CHECK-GENERIC-LABEL: "func"()
+func.func @op_with_custom_printer() {
+  %x = test.string_attr_pretty_name
+  // CHECK: %x = test.string_attr_pretty_name
+  // GENERIC: %0 = "test.string_attr_pretty_name"()
+  return
+  // CHECK: return
+  // GENERIC: "func.return"()
+}
+
+// -----
+
+// CHECK-LABEL: func @op_without_custom_printer
+// CHECK-GENERIC: "func"()
+func.func @op_without_custom_printer() {
+  // CHECK: "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+  // GENERIC: "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+  "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+  return
+}

diff  --git a/mlir/test/IR/print-op-generic.mlir b/mlir/test/IR/print-op-generic.mlir
deleted file mode 100644
index ed34f8406a781..0000000000000
--- a/mlir/test/IR/print-op-generic.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-// # RUN: mlir-opt %s | FileCheck %s
-// # RUN: mlir-opt %s --mlir-print-op-generic  | FileCheck %s --check-prefix=GENERIC
-
-// CHECK-LABEL: func @pretty_names
-// CHECK-GENERIC: "func"()
-func.func @pretty_names() {
-  %x = test.string_attr_pretty_name
-  // CHECK: %x = test.string_attr_pretty_name
-  // GENERIC: %0 = "test.string_attr_pretty_name"()
-  return
-  // CHECK: return
-  // GENERIC: "func.return"()
-}


        


More information about the Mlir-commits mailing list