[Mlir-commits] [mlir] 0845635 - [mlir][ir] Custom ops' parse/print fall back to dialect hooks

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 10 11:34:29 PST 2021


Author: Mogball
Date: 2021-12-10T19:34:25Z
New Revision: 0845635eda428173861fa1a4dd3de014b595daf7

URL: https://github.com/llvm/llvm-project/commit/0845635eda428173861fa1a4dd3de014b595daf7
DIFF: https://github.com/llvm/llvm-project/commit/0845635eda428173861fa1a4dd3de014b595daf7.diff

LOG: [mlir][ir] Custom ops' parse/print fall back to dialect hooks

Custom ops that have no parser or printer should fall back to the dialect's parser and/or printer hooks. This avoids the need to define parsers and printers that simply dispatch to the dialect hook.

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D115481

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/IR/Operation.cpp
    mlir/test/IR/parser.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index ad461e8aef075..c68c6b12bf0a4 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -173,13 +173,19 @@ class OpState {
   /// back to this one which accepts everything.
   LogicalResult verify() { return success(); }
 
-  /// Unless overridden, the custom assembly form of an op is always rejected.
-  /// Op implementations should implement this to return failure.
-  /// On success, they should fill in result with the fields to use.
+  /// Parse the custom form of an operation. Unless overridden, this method will
+  /// first try to get an operation parser from the op's dialect; if the dialect
+  /// provides none, the custom assembly form is rejected. Op implementations
+  /// that override this should return failure on a parse error and, on success,
+  /// fill in `result` with the fields to use.
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
 
-  // The fallback for the printer is to print it the generic assembly form.
-  static void print(Operation *op, OpAsmPrinter &p);
+  /// Print the operation. Unless overridden, this method will first try to get
+  /// an operation printer from the dialect. Otherwise, it prints the operation
+  /// in generic form.
+  static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
+
+  /// Print an operation name, eliding the dialect prefix if necessary.
   static void printOpName(Operation *op, OpAsmPrinter &p,
                           StringRef defaultDialect);
 
@@ -1781,7 +1787,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
                           OperationName::PrintAssemblyFn>
   getPrintAssemblyFnImpl() {
     return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
-      return OpState::print(op, printer);
+      return OpState::print(op, printer, defaultDialect);
     };
   }
   /// The internal implementation of `getPrintAssemblyFn` that is invoked when

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 164aeff04f677..dc224f21d8348 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -580,14 +580,27 @@ Operation *Operation::clone() {
 // OpState trait class.
 //===----------------------------------------------------------------------===//
 
-// The fallback for the parser is to reject the custom assembly form.
+// The fallback for the parser is to try for a dialect operation parser.
+// Otherwise, reject the custom assembly form.
 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
+  if (auto parseFn = result.name.getDialect()->getParseOperationHook(
+          result.name.getStringRef()))
+    return (*parseFn)(parser, result);
   return parser.emitError(parser.getNameLoc(), "has no custom assembly form");
 }
 
-// The fallback for the printer is to print in the generic assembly form.
-void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); }
-// The fallback for the printer is to print in the generic assembly form.
+// The fallback for the printer is to try for a dialect operation printer.
+// Otherwise, it prints the generic form.
+void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
+  if (auto printFn = op->getDialect()->getOperationPrinter(op)) {
+    printOpName(op, p, defaultDialect);
+    printFn(op, p);
+  } else {
+    p.printGenericOp(op);
+  }
+}
+
+/// Print an operation name, eliding the dialect prefix if necessary.
 void OpState::printOpName(Operation *op, OpAsmPrinter &p,
                           StringRef defaultDialect) {
   StringRef name = op->getName().getStringRef();

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 8f2f8706b1de2..30c273fd45de3 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1425,3 +1425,8 @@ test.graph_region {
 // This is an unregister operation, the printing/parsing is handled by the dialect.
 // CHECK: test.dialect_custom_printer custom_format
 test.dialect_custom_printer custom_format
+
+// This is a registered operation with no custom parser and printer, and should
+// be handled by the dialect.
+// CHECK: test.dialect_custom_format_fallback custom_format_fallback
+test.dialect_custom_format_fallback custom_format_fallback

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 73d4243ad1088..a6b317db467cf 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -318,6 +318,11 @@ TestDialect::getParseOperationHook(StringRef opName) const {
       return parser.parseKeyword("custom_format");
     }};
   }
+  if (opName == "test.dialect_custom_format_fallback") {
+    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
+      return parser.parseKeyword("custom_format_fallback");
+    }};
+  }
   return None;
 }
 
@@ -329,6 +334,11 @@ TestDialect::getOperationPrinter(Operation *op) const {
       printer.getStream() << " custom_format";
     };
   }
+  if (opName == "test.dialect_custom_format_fallback") {
+    return [](Operation *op, OpAsmPrinter &printer) {
+      printer.getStream() << " custom_format_fallback";
+    };
+  }
   return {};
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 4f6abed2eda26..120749e78c83d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -597,6 +597,10 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
   );
 }
 
+// This is used to test that the fallback for a custom op's parser and printer
+// is the dialect parser and printer hooks.
+def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
+
 // This is used to test encoding of a string attribute into an SSA name of a
 // pretty printed value name.
 def StringAttrPrettyNameOp


        


More information about the Mlir-commits mailing list