[Mlir-commits] [mlir] 0845635 - [mlir][ir] Custom ops' parse/print fall back to dialect hooks
llvmlistbot at llvm.org
Fri Dec 10 11:34:29 PST 2021
Author: Mogball
Date: 2021-12-10T19:34:25Z
New Revision: 0845635eda428173861fa1a4dd3de014b595daf7
URL: https://github.com/llvm/llvm-project/commit/0845635eda428173861fa1a4dd3de014b595daf7
DIFF: https://github.com/llvm/llvm-project/commit/0845635eda428173861fa1a4dd3de014b595daf7.diff
LOG: [mlir][ir] Custom ops' parse/print fall back to dialect hooks
Custom ops that have no parser or printer should fall back to the dialect's parser and/or printer hooks. This avoids the need to define parsers and printers that simply dispatch to the dialect hook.
Reviewed By: mehdi_amini, rriddle
Differential Revision: https://reviews.llvm.org/D115481
Added:
Modified:
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Operation.cpp
mlir/test/IR/parser.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index ad461e8aef075..c68c6b12bf0a4 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -173,13 +173,19 @@ class OpState {
/// back to this one which accepts everything.
LogicalResult verify() { return success(); }
- /// Unless overridden, the custom assembly form of an op is always rejected.
- /// Op implementations should implement this to return failure.
- /// On success, they should fill in result with the fields to use.
+ /// Parse the custom form of an operation. Unless overridden, this method will
+ /// first try to get an operation parser from the op's dialect. Otherwise the
+ /// custom assembly form of an op is always rejected. Op implementations
+ /// should implement this to return failure. On success, they should fill in
+ /// result with the fields to use.
static ParseResult parse(OpAsmParser &parser, OperationState &result);
- // The fallback for the printer is to print it the generic assembly form.
- static void print(Operation *op, OpAsmPrinter &p);
+ /// Print the operation. Unless overridden, this method will first try to get
+ /// an operation printer from the dialect. Otherwise, it prints the operation
+ /// in generic form.
+ static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
+
+ /// Print an operation name, eliding the dialect prefix if necessary.
static void printOpName(Operation *op, OpAsmPrinter &p,
StringRef defaultDialect);
@@ -1781,7 +1787,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
OperationName::PrintAssemblyFn>
getPrintAssemblyFnImpl() {
return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
- return OpState::print(op, printer);
+ return OpState::print(op, printer, defaultDialect);
};
}
/// The internal implementation of `getPrintAssemblyFn` that is invoked when
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 164aeff04f677..dc224f21d8348 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -580,14 +580,27 @@ Operation *Operation::clone() {
// OpState trait class.
//===----------------------------------------------------------------------===//
-// The fallback for the parser is to reject the custom assembly form.
+// The fallback for the parser is to try for a dialect operation parser.
+// Otherwise, reject the custom assembly form.
ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
+ if (auto parseFn = result.name.getDialect()->getParseOperationHook(
+ result.name.getStringRef()))
+ return (*parseFn)(parser, result);
return parser.emitError(parser.getNameLoc(), "has no custom assembly form");
}
-// The fallback for the printer is to print in the generic assembly form.
-void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); }
-// The fallback for the printer is to print in the generic assembly form.
+// The fallback for the printer is to try for a dialect operation printer.
+// Otherwise, it prints the generic form.
+void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
+ if (auto printFn = op->getDialect()->getOperationPrinter(op)) {
+ printOpName(op, p, defaultDialect);
+ printFn(op, p);
+ } else {
+ p.printGenericOp(op);
+ }
+}
+
+/// Print an operation name, eliding the dialect prefix if necessary.
void OpState::printOpName(Operation *op, OpAsmPrinter &p,
StringRef defaultDialect) {
StringRef name = op->getName().getStringRef();
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 8f2f8706b1de2..30c273fd45de3 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1425,3 +1425,8 @@ test.graph_region {
// This is an unregister operation, the printing/parsing is handled by the dialect.
// CHECK: test.dialect_custom_printer custom_format
test.dialect_custom_printer custom_format
+
+// This is a registered operation with no custom parser and printer, and should
+// be handled by the dialect.
+// CHECK: test.dialect_custom_format_fallback custom_format_fallback
+test.dialect_custom_format_fallback custom_format_fallback
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 73d4243ad1088..a6b317db467cf 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -318,6 +318,11 @@ TestDialect::getParseOperationHook(StringRef opName) const {
return parser.parseKeyword("custom_format");
}};
}
+ if (opName == "test.dialect_custom_format_fallback") {
+ return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
+ return parser.parseKeyword("custom_format_fallback");
+ }};
+ }
return None;
}
@@ -329,6 +334,11 @@ TestDialect::getOperationPrinter(Operation *op) const {
printer.getStream() << " custom_format";
};
}
+ if (opName == "test.dialect_custom_format_fallback") {
+ return [](Operation *op, OpAsmPrinter &printer) {
+ printer.getStream() << " custom_format_fallback";
+ };
+ }
return {};
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 4f6abed2eda26..120749e78c83d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -597,6 +597,10 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
);
}
+// This is used to test that the fallback for a custom op's parser and printer
+// is the dialect parser and printer hooks.
+def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
+
// This is used to test encoding of a string attribute into an SSA name of a
// pretty printed value name.
def StringAttrPrettyNameOp
More information about the Mlir-commits mailing list