[Mlir-commits] [mlir] a0c776f - Add a mechanism for Dialects to customize printing/parsing operations when they are unregistered

Mehdi Amini llvmlistbot at llvm.org
Mon Mar 22 17:40:17 PDT 2021


Author: Mehdi Amini
Date: 2021-03-23T00:40:03Z
New Revision: a0c776fc94d3179822c95dcb9f79b344e13f069b

URL: https://github.com/llvm/llvm-project/commit/a0c776fc94d3179822c95dcb9f79b344e13f069b
DIFF: https://github.com/llvm/llvm-project/commit/a0c776fc94d3179822c95dcb9f79b344e13f069b.diff

LOG: Add a mechanism for Dialects to customize printing/parsing operations when they are unregistered

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D99007

Added: 
    

Modified: 
    mlir/include/mlir/IR/Dialect.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Dialect.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/parser.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 4a816ccb79c95..2fdbdc4829833 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -39,6 +39,12 @@ using InterfaceAllocatorFunction =
 ///
 class Dialect {
 public:
+  /// Type for a callback provided by the dialect to parse a custom operation,
+  /// including unregistered ones. This is a non-owning function_ref: the
+  /// callable it refers to must outlive every use of the returned hook.
+  using ParseOpHook =
+      function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>;
+
   virtual ~Dialect();
 
   /// Utility function that returns if the given string is a valid dialect
@@ -97,6 +103,18 @@ class Dialect {
     llvm_unreachable("dialect has no registered type printing hook");
   }
 
+  /// Return the hook to parse an operation of this dialect that has no
+  /// registered AbstractOperation, if any. Registered operations are parsed
+  /// through their AbstractOperation instead; the default implementation
+  /// returns None, which makes the parser report the op as unknown.
+  virtual Optional<ParseOpHook> getParseOperationHook(StringRef opName) const;
+
+  /// Print an operation of this dialect that has no registered
+  /// AbstractOperation, using a dialect-defined custom form. Return failure to
+  /// have the printer fall back to the generic assembly form.
+  virtual LogicalResult printOperation(Operation *op,
+                                       OpAsmPrinter &printer) const;
+
   //===--------------------------------------------------------------------===//
   // Verification Hooks
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index c2241df46191d..8cc97d9c02ee8 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -88,6 +88,9 @@ class AbstractOperation {
   /// Use the specified object to parse this ops custom assembly format.
   ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
 
+  /// Return the static hook used to parse this operation's custom assembly.
+  ParseAssemblyFn getParseAssemblyFn() const { return parseAssemblyFn; }
+
   /// This hook implements the AsmPrinter for this operation.
   void printAssembly(Operation *op, OpAsmPrinter &p) const {
     return printAssemblyFn(op, p);

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index ad2c3c6a9075c..8cd20c7777adf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2409,6 +2409,11 @@ void OperationPrinter::printOperation(Operation *op) {
       opInfo->printAssembly(op, *this);
       return;
     }
+    // Otherwise try to dispatch to the dialect, if available.
+    if (Dialect *dialect = op->getDialect()) {
+      if (succeeded(dialect->printOperation(op, *this)))
+        return;
+    }
   }
 
   // Otherwise print with the generic assembly form.

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index fc21152878e36..612c902d47079 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -136,6 +136,18 @@ Type Dialect::parseType(DialectAsmParser &parser) const {
   return Type();
 }
 
+Optional<Dialect::ParseOpHook>
+Dialect::getParseOperationHook(StringRef /*opName*/) const {
+  return llvm::None; // No dialect-level parser for unregistered operations.
+}
+
+LogicalResult Dialect::printOperation(Operation *op,
+                                      OpAsmPrinter & /*printer*/) const {
+  assert(this == op->getDialect() &&
+         "Dialect hook invoked on non-dialect owned operation");
+  return failure(); // No custom form: the caller prints the generic form.
+}
+
 /// Utility function that returns if the given string is a valid dialect
 /// namespace.
 bool Dialect::isValidNamespace(StringRef str) {

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 7b3cd158b2ada..ad80204ac496d 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -925,17 +925,18 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
 namespace {
 class CustomOpAsmParser : public OpAsmParser {
 public:
-  CustomOpAsmParser(SMLoc nameLoc,
-                    ArrayRef<OperationParser::ResultRecord> resultIDs,
-                    const AbstractOperation *opDefinition,
-                    OperationParser &parser)
-      : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition),
+  CustomOpAsmParser(
+      SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
+      function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly,
+      bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
+      : nameLoc(nameLoc), resultIDs(resultIDs), parseAssembly(parseAssembly),
+        isIsolatedFromAbove(isIsolatedFromAbove), opName(opName),
         parser(parser) {}
 
   /// Parse an instance of the operation described by 'opDefinition' into the
   /// provided operation state.
   ParseResult parseOperation(OperationState &opState) {
-    if (opDefinition->parseAssembly(*this, opState))
+    if (parseAssembly(*this, opState))
       return failure();
     // Verify that the parsed attributes does not have duplicate attributes.
     // This can happen if an attribute set during parsing is also specified in
@@ -964,8 +965,7 @@ class CustomOpAsmParser : public OpAsmParser {
   /// Emit a diagnostic at the specified location and return failure.
   InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
     emittedError = true;
-    return parser.emitError(loc, "custom op '" + opDefinition->name.strref() +
-                                     "' " + message);
+    return parser.emitError(loc, "custom op '" + opName + "' " + message);
   }
 
   llvm::SMLoc getCurrentLocation() override {
@@ -1490,8 +1490,7 @@ class CustomOpAsmParser : public OpAsmParser {
     }
 
     // Try to parse the region.
-    assert((!enableNameShadowing ||
-            opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>()) &&
+    assert((!enableNameShadowing || isIsolatedFromAbove) &&
            "name shadowing is only allowed on isolated regions");
     if (parser.parseRegion(region, regionArguments, enableNameShadowing))
       return failure();
@@ -1656,7 +1655,9 @@ class CustomOpAsmParser : public OpAsmParser {
   ArrayRef<OperationParser::ResultRecord> resultIDs;
 
   /// The abstract information of the operation.
-  const AbstractOperation *opDefinition;
+  function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly;
+  bool isIsolatedFromAbove;
+  StringRef opName;
 
   /// The main operation parser.
   OperationParser &parser;
@@ -1670,31 +1671,51 @@ Operation *
 OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   llvm::SMLoc opLoc = getToken().getLoc();
   StringRef opName = getTokenSpelling();
-
   auto *opDefinition = AbstractOperation::lookup(opName, getContext());
-  if (!opDefinition) {
+  Dialect *dialect = nullptr;
+  if (opDefinition) {
+    dialect = &opDefinition->dialect;
+  } else {
     if (opName.contains('.')) {
       // This op has a dialect, we try to check if we can register it in the
       // context on the fly.
       StringRef dialectName = opName.split('.').first;
-      if (!getContext()->getLoadedDialect(dialectName) &&
-          getContext()->getOrLoadDialect(dialectName)) {
+      dialect = getContext()->getLoadedDialect(dialectName);
+      if (!dialect && (dialect = getContext()->getOrLoadDialect(dialectName)))
         opDefinition = AbstractOperation::lookup(opName, getContext());
-      }
     } else {
       // If the operation name has no namespace prefix we treat it as a standard
       // operation and prefix it with "std".
       // TODO: Would it be better to just build a mapping of the registered
       // operations in the standard dialect?
-      if (getContext()->getOrLoadDialect("std"))
+      if (getContext()->getOrLoadDialect("std")) {
         opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
                                                  getContext());
+        if (opDefinition)
+          opName = opDefinition->name.strref();
+      }
     }
   }
 
-  if (!opDefinition) {
-    emitError(opLoc) << "custom op '" << opName << "' is unknown";
-    return nullptr;
+  // This is the actual hook for the custom op parsing, usually implemented by
+  // the op itself (`Op::parse()`). We retrieve it either from the
+  // AbstractOperation or from the Dialect.
+  std::function<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
+  bool isIsolatedFromAbove = false;
+
+  if (opDefinition) {
+    parseAssemblyFn = opDefinition->getParseAssemblyFn();
+    isIsolatedFromAbove =
+        opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>();
+  } else {
+    Optional<Dialect::ParseOpHook> dialectHook;
+    if (dialect)
+      dialectHook = dialect->getParseOperationHook(opName);
+    if (!dialectHook.hasValue()) {
+      emitError(opLoc) << "custom op '" << opName << "' is unknown";
+      return nullptr;
+    }
+    parseAssemblyFn = *dialectHook;
   }
 
   consumeToken();
@@ -1709,9 +1730,10 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   auto srcLocation = getEncodedSourceLocation(opLoc);
 
   // Have the op implementation take a crack and parsing this.
-  OperationState opState(srcLocation, opDefinition->name);
+  OperationState opState(srcLocation, opName);
   CleanupOpStateRegions guard{opState};
-  CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this);
+  CustomOpAsmParser opAsmParser(opLoc, resultIDs, parseAssemblyFn,
+                                isIsolatedFromAbove, opName, *this);
   if (opAsmParser.parseOperation(opState))
     return nullptr;
 

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index d24f5f07637ec..df6f216108237 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1411,3 +1411,7 @@ test.graph_region {
   %2 = "bar"(%1) : (i64) -> i64
   "unregistered_terminator"() : () -> ()
 }) {sym_name = "unregistered_op_dominance_violation_ok", type = () -> i1} : () -> ()
+
+// This is an unregistered operation; its printing/parsing is handled by the dialect.
+// CHECK: test.dialect_custom_printer custom_format
+test.dialect_custom_printer custom_format

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index eee0a9be75dbe..3bb4e8f4a6236 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -208,6 +208,26 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
   return success();
 }
 
+Optional<Dialect::ParseOpHook>
+TestDialect::getParseOperationHook(StringRef opName) const {
+  // ParseOpHook is a non-owning function_ref; it must not bind a temporary.
+  static auto parseCustomFormat = [](OpAsmParser &parser, OperationState &) {
+    return parser.parseKeyword("custom_format");
+  };
+  if (opName != "test.dialect_custom_printer") return None;
+  return ParseOpHook(parseCustomFormat);
+}
+
+LogicalResult TestDialect::printOperation(Operation *op,
+                                          OpAsmPrinter &printer) const {
+  // Only `test.dialect_custom_printer` gets a dialect-provided custom form.
+  StringRef name = op->getName().getStringRef();
+  if (name != "test.dialect_custom_printer")
+    return failure();
+  printer.getStream() << name << " custom_format";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TestBranchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b8956e48b33df..7d48f8d4547a9 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -39,6 +39,12 @@ def Test_Dialect : Dialect {
                              Type type) const override;
     void printAttribute(Attribute attr,
                         DialectAsmPrinter &printer) const override;
+
+    // Provide custom printing/parsing hooks for some operations.
+    Optional<ParseOpHook>
+      getParseOperationHook(StringRef opName) const override;
+    LogicalResult printOperation(Operation *op,
+                                 OpAsmPrinter &printer) const override;
   }];
 }
 


        


More information about the Mlir-commits mailing list