[Mlir-commits] [mlir] fd87963 - Change dialect `printOperation()` hook to `getOperationPrinter()`

Mehdi Amini llvmlistbot at llvm.org
Tue Aug 31 10:52:52 PDT 2021


Author: Mehdi Amini
Date: 2021-08-31T17:52:39Z
New Revision: fd87963eee23f6cf2aed97bf182a6b3f5e9450ed

URL: https://github.com/llvm/llvm-project/commit/fd87963eee23f6cf2aed97bf182a6b3f5e9450ed
DIFF: https://github.com/llvm/llvm-project/commit/fd87963eee23f6cf2aed97bf182a6b3f5e9450ed.diff

LOG: Change dialect `printOperation()` hook to `getOperationPrinter()`

This makes the hook return a printer if one is available, instead of using a LogicalResult to
indicate whether a printer was available (and invoked). This allows the caller to detect that
the dialect has a printer for a given operation without actually invoking the printer.
It will be leveraged in a future revision to move printing of the op name itself under the control
of the AsmPrinter.

Differential Revision: https://reviews.llvm.org/D108803

Added: 
    

Modified: 
    mlir/include/mlir/IR/Dialect.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Dialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index f615819fd16bb..14114e379f78b 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -121,8 +121,8 @@ class Dialect {
   /// Print an operation registered to this dialect.
   /// This hook is invoked for registered operation which don't override the
   /// `print()` method to define their own custom assembly.
-  virtual LogicalResult printOperation(Operation *op,
-                                       OpAsmPrinter &printer) const;
+  virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
+  getOperationPrinter(Operation *op) const;
 
   //===--------------------------------------------------------------------===//
   // Verification Hooks
@@ -297,8 +297,7 @@ class DialectRegistry {
 public:
   explicit DialectRegistry();
 
-  template <typename ConcreteDialect>
-  void insert() {
+  template <typename ConcreteDialect> void insert() {
     insert(TypeID::get<ConcreteDialect>(),
            ConcreteDialect::getDialectNamespace(),
            static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
@@ -364,8 +363,7 @@ class DialectRegistry {
   /// Add an external op interface model for an op that belongs to a dialect,
   /// both provided as template parameters. The dialect must be present in the
   /// registry.
-  template <typename OpTy, typename ModelTy>
-  void addOpInterface() {
+  template <typename OpTy, typename ModelTy> void addOpInterface() {
     StringRef opName = OpTy::getOperationName();
     StringRef dialectName = opName.split('.').first;
     addObjectInterface(dialectName, TypeID::get<OpTy>(),
@@ -426,8 +424,7 @@ class DialectRegistry {
 
 namespace llvm {
 /// Provide isa functionality for Dialects.
-template <typename T>
-struct isa_impl<T, ::mlir::Dialect> {
+template <typename T> struct isa_impl<T, ::mlir::Dialect> {
   static inline bool doit(const ::mlir::Dialect &dialect) {
     return mlir::TypeID::get<T>() == dialect.getTypeID();
   }

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 65cbc8a8c33eb..b0fafc3297635 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2508,8 +2508,10 @@ void OperationPrinter::printOperation(Operation *op) {
     }
     // Otherwise try to dispatch to the dialect, if available.
     if (Dialect *dialect = op->getDialect()) {
-      if (succeeded(dialect->printOperation(op, *this)))
+      if (auto opPrinter = dialect->getOperationPrinter(op)) {
+        opPrinter(op, *this);
         return;
+      }
     }
   }
 

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 80c8dabe1f3b9..2f2997e11fb83 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -172,11 +172,11 @@ Dialect::getParseOperationHook(StringRef opName) const {
   return None;
 }
 
-LogicalResult Dialect::printOperation(Operation *op,
-                                      OpAsmPrinter &printer) const {
+llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
+Dialect::getOperationPrinter(Operation *op) const {
   assert(op->getDialect() == this &&
          "Dialect hook invoked on non-dialect owned operation");
-  return failure();
+  return nullptr;
 }
 
 /// Utility function that returns if the given string is a valid dialect

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 2a1f37119e2d5..d12c61a124d14 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -313,14 +313,15 @@ TestDialect::getParseOperationHook(StringRef opName) const {
   return None;
 }
 
-LogicalResult TestDialect::printOperation(Operation *op,
-                                          OpAsmPrinter &printer) const {
+llvm::unique_function<void(Operation *, OpAsmPrinter &)>
+TestDialect::getOperationPrinter(Operation *op) const {
   StringRef opName = op->getName().getStringRef();
   if (opName == "test.dialect_custom_printer") {
-    printer.getStream() << opName << " custom_format";
-    return success();
+    return [](Operation *op, OpAsmPrinter &printer) {
+      printer.getStream() << op->getName().getStringRef() << " custom_format";
+    };
   }
-  return failure();
+  return {};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fbbc766839c40..4dee0a1b366ca 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -39,15 +39,17 @@ def Test_Dialect : Dialect {
     void registerTypes();
 
     ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
-                             ::mlir::Type type) const override;
+                                     ::mlir::Type type) const override;
     void printAttribute(::mlir::Attribute attr,
                         ::mlir::DialectAsmPrinter &printer) const override;
 
     // Provides a custom printing/parsing for some operations.
     ::llvm::Optional<ParseOpHook>
       getParseOperationHook(::llvm::StringRef opName) const override;
-    ::mlir::LogicalResult printOperation(::mlir::Operation *op,
-                                 ::mlir::OpAsmPrinter &printer) const override;
+    ::llvm::unique_function<void(::mlir::Operation *,
+                                 ::mlir::OpAsmPrinter &printer)>
+     getOperationPrinter(::mlir::Operation *op) const override;
+
   private:
     // Storage for a custom fallback interface.
     void *fallbackEffectOpInterfaces;


        


More information about the Mlir-commits mailing list