[Mlir-commits] [mlir] a0c776f - Add a mechanism for Dialects to customize printing/parsing operations when they are unregistered
Mehdi Amini
llvmlistbot at llvm.org
Mon Mar 22 17:40:17 PDT 2021
Author: Mehdi Amini
Date: 2021-03-23T00:40:03Z
New Revision: a0c776fc94d3179822c95dcb9f79b344e13f069b
URL: https://github.com/llvm/llvm-project/commit/a0c776fc94d3179822c95dcb9f79b344e13f069b
DIFF: https://github.com/llvm/llvm-project/commit/a0c776fc94d3179822c95dcb9f79b344e13f069b.diff
LOG: Add a mechanism for Dialects to customize printing/parsing operations when they are unregistered
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D99007
Added:
Modified:
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/parser.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 4a816ccb79c95..2fdbdc4829833 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -39,6 +39,12 @@ using InterfaceAllocatorFunction =
///
class Dialect {
public:
+ /// Type for a callback provided by a dialect to parse a custom operation,
+ /// including unregistered ones. This is a non-owning function_ref: the
+ /// referenced callable must outlive the parse (do not return temporaries).
+ using ParseOpHook =
+ function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>;
+
virtual ~Dialect();
/// Utility function that returns if the given string is a valid dialect
@@ -97,6 +103,18 @@ class Dialect {
llvm_unreachable("dialect has no registered type printing hook");
}
+ /// Return the hook to parse an operation that is not registered with this
+ /// dialect, if any. The parser only calls this after AbstractOperation lookup
+ /// fails, so registered operations keep using their own `parse()` method.
+ /// The default implementation returns None; dialects override it as needed.
+ virtual Optional<ParseOpHook> getParseOperationHook(StringRef opName) const;
+
+ /// Print an operation owned by this dialect that was not handled by an
+ /// operation-specific printer (e.g. an unregistered operation). Return failure
+ /// to fall back to the generic form; the default implementation always fails.
+ virtual LogicalResult printOperation(Operation *op,
+ OpAsmPrinter &printer) const;
+
//===--------------------------------------------------------------------===//
// Verification Hooks
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index c2241df46191d..8cc97d9c02ee8 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -88,6 +88,9 @@ class AbstractOperation {
/// Use the specified object to parse this ops custom assembly format.
ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
+ /// Return the static hook for parsing this operation assembly.
+ ParseAssemblyFn getParseAssemblyFn() const { return parseAssemblyFn; }
+
/// This hook implements the AsmPrinter for this operation.
void printAssembly(Operation *op, OpAsmPrinter &p) const {
return printAssemblyFn(op, p);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index ad2c3c6a9075c..8cd20c7777adf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2409,6 +2409,11 @@ void OperationPrinter::printOperation(Operation *op) {
opInfo->printAssembly(op, *this);
return;
}
+ // Otherwise try to dispatch to the dialect, if available.
+ if (Dialect *dialect = op->getDialect()) {
+ if (succeeded(dialect->printOperation(op, *this)))
+ return;
+ }
}
// Otherwise print with the generic assembly form.
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index fc21152878e36..612c902d47079 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -136,6 +136,18 @@ Type Dialect::parseType(DialectAsmParser &parser) const {
return Type();
}
+Optional<Dialect::ParseOpHook>
+Dialect::getParseOperationHook(StringRef opName) const {
+ return None;
+}
+
+LogicalResult Dialect::printOperation(Operation *op,
+ OpAsmPrinter &printer) const {
+ assert(op->getDialect() == this &&
+ "Dialect hook invoked on non-dialect owned operation");
+ return failure();
+}
+
/// Utility function that returns if the given string is a valid dialect
/// namespace.
bool Dialect::isValidNamespace(StringRef str) {
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 7b3cd158b2ada..ad80204ac496d 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -925,17 +925,18 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
namespace {
class CustomOpAsmParser : public OpAsmParser {
public:
- CustomOpAsmParser(SMLoc nameLoc,
- ArrayRef<OperationParser::ResultRecord> resultIDs,
- const AbstractOperation *opDefinition,
- OperationParser &parser)
- : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition),
+ CustomOpAsmParser(
+ SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
+ function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly,
+ bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
+ : nameLoc(nameLoc), resultIDs(resultIDs), parseAssembly(parseAssembly),
+ isIsolatedFromAbove(isIsolatedFromAbove), opName(opName),
parser(parser) {}
/// Parse an instance of the operation described by 'opDefinition' into the
/// provided operation state.
ParseResult parseOperation(OperationState &opState) {
- if (opDefinition->parseAssembly(*this, opState))
+ if (parseAssembly(*this, opState))
return failure();
// Verify that the parsed attributes does not have duplicate attributes.
// This can happen if an attribute set during parsing is also specified in
@@ -964,8 +965,7 @@ class CustomOpAsmParser : public OpAsmParser {
/// Emit a diagnostic at the specified location and return failure.
InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
emittedError = true;
- return parser.emitError(loc, "custom op '" + opDefinition->name.strref() +
- "' " + message);
+ return parser.emitError(loc, "custom op '" + opName + "' " + message);
}
llvm::SMLoc getCurrentLocation() override {
@@ -1490,8 +1490,7 @@ class CustomOpAsmParser : public OpAsmParser {
}
// Try to parse the region.
- assert((!enableNameShadowing ||
- opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>()) &&
+ assert((!enableNameShadowing || isIsolatedFromAbove) &&
"name shadowing is only allowed on isolated regions");
if (parser.parseRegion(region, regionArguments, enableNameShadowing))
return failure();
@@ -1656,7 +1655,9 @@ class CustomOpAsmParser : public OpAsmParser {
ArrayRef<OperationParser::ResultRecord> resultIDs;
/// The abstract information of the operation.
- const AbstractOperation *opDefinition;
+ function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly;
+ bool isIsolatedFromAbove;
+ StringRef opName;
/// The main operation parser.
OperationParser &parser;
@@ -1670,31 +1671,51 @@ Operation *
OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
llvm::SMLoc opLoc = getToken().getLoc();
StringRef opName = getTokenSpelling();
-
auto *opDefinition = AbstractOperation::lookup(opName, getContext());
- if (!opDefinition) {
+ Dialect *dialect = nullptr;
+ if (opDefinition) {
+ dialect = &opDefinition->dialect;
+ } else {
if (opName.contains('.')) {
// This op has a dialect, we try to check if we can register it in the
// context on the fly.
StringRef dialectName = opName.split('.').first;
- if (!getContext()->getLoadedDialect(dialectName) &&
- getContext()->getOrLoadDialect(dialectName)) {
+ dialect = getContext()->getLoadedDialect(dialectName);
+ if (!dialect && (dialect = getContext()->getOrLoadDialect(dialectName)))
opDefinition = AbstractOperation::lookup(opName, getContext());
- }
} else {
// If the operation name has no namespace prefix we treat it as a standard
// operation and prefix it with "std".
// TODO: Would it be better to just build a mapping of the registered
// operations in the standard dialect?
- if (getContext()->getOrLoadDialect("std"))
+ if (getContext()->getOrLoadDialect("std")) {
opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
getContext());
+ if (opDefinition)
+ opName = opDefinition->name.strref();
+ }
}
}
- if (!opDefinition) {
- emitError(opLoc) << "custom op '" << opName << "' is unknown";
- return nullptr;
+ // This is the actual hook for the custom op parsing, usually implemented by
+ // the op itself (`Op::parse()`). We retrieve it either from the
+ // AbstractOperation or from the Dialect.
+ std::function<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
+ bool isIsolatedFromAbove = false;
+
+ if (opDefinition) {
+ parseAssemblyFn = opDefinition->getParseAssemblyFn();
+ isIsolatedFromAbove =
+ opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>();
+ } else {
+ Optional<Dialect::ParseOpHook> dialectHook;
+ if (dialect)
+ dialectHook = dialect->getParseOperationHook(opName);
+ if (!dialectHook.hasValue()) {
+ emitError(opLoc) << "custom op '" << opName << "' is unknown";
+ return nullptr;
+ }
+ parseAssemblyFn = *dialectHook;
}
consumeToken();
@@ -1709,9 +1730,10 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
auto srcLocation = getEncodedSourceLocation(opLoc);
// Have the op implementation take a crack and parsing this.
- OperationState opState(srcLocation, opDefinition->name);
+ OperationState opState(srcLocation, opName);
CleanupOpStateRegions guard{opState};
- CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this);
+ CustomOpAsmParser opAsmParser(opLoc, resultIDs, parseAssemblyFn,
+ isIsolatedFromAbove, opName, *this);
if (opAsmParser.parseOperation(opState))
return nullptr;
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index d24f5f07637ec..df6f216108237 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1411,3 +1411,7 @@ test.graph_region {
%2 = "bar"(%1) : (i64) -> i64
"unregistered_terminator"() : () -> ()
}) {sym_name = "unregistered_op_dominance_violation_ok", type = () -> i1} : () -> ()
+
+// This is an unregistered operation; its printing/parsing is handled by the dialect.
+// CHECK: test.dialect_custom_printer custom_format
+test.dialect_custom_printer custom_format
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index eee0a9be75dbe..3bb4e8f4a6236 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -208,6 +208,26 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
return success();
}
+Optional<Dialect::ParseOpHook>
+TestDialect::getParseOperationHook(StringRef opName) const {
+ if (opName == "test.dialect_custom_printer") {
+ return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
+ return parser.parseKeyword("custom_format");
+ }};
+ }
+ return None;
+}
+
+LogicalResult TestDialect::printOperation(Operation *op,
+ OpAsmPrinter &printer) const {
+ StringRef opName = op->getName().getStringRef();
+ if (opName == "test.dialect_custom_printer") {
+ printer.getStream() << opName << " custom_format";
+ return success();
+ }
+ return failure();
+}
+
//===----------------------------------------------------------------------===//
// TestBranchOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b8956e48b33df..7d48f8d4547a9 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -39,6 +39,12 @@ def Test_Dialect : Dialect {
Type type) const override;
void printAttribute(Attribute attr,
DialectAsmPrinter &printer) const override;
+
+ // Provides custom printing/parsing for some operations.
+ Optional<ParseOpHook>
+ getParseOperationHook(StringRef opName) const override;
+ LogicalResult printOperation(Operation *op,
+ OpAsmPrinter &printer) const override;
}];
}
More information about the Mlir-commits
mailing list