[Mlir-commits] [mlir] fd87963 - Change dialect `printOperation()` hook to `getOperationPrinter()`
Mehdi Amini
llvmlistbot at llvm.org
Tue Aug 31 10:52:52 PDT 2021
Author: Mehdi Amini
Date: 2021-08-31T17:52:39Z
New Revision: fd87963eee23f6cf2aed97bf182a6b3f5e9450ed
URL: https://github.com/llvm/llvm-project/commit/fd87963eee23f6cf2aed97bf182a6b3f5e9450ed
DIFF: https://github.com/llvm/llvm-project/commit/fd87963eee23f6cf2aed97bf182a6b3f5e9450ed.diff
LOG: Change dialect `printOperation()` hook to `getOperationPrinter()`
This makes the hook return a printer if one is available, instead of returning a LogicalResult to
indicate whether a printer was available (and invoked). This allows the caller to detect that
the dialect has a printer for a given operation without actually invoking the printer.
This will be leveraged in a future revision to move printing of the op name itself under the
control of the AsmPrinter.
Differential Revision: https://reviews.llvm.org/D108803
Added:
Modified:
mlir/include/mlir/IR/Dialect.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Dialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index f615819fd16bb..14114e379f78b 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -121,8 +121,8 @@ class Dialect {
/// Print an operation registered to this dialect.
/// This hook is invoked for registered operation which don't override the
/// `print()` method to define their own custom assembly.
- virtual LogicalResult printOperation(Operation *op,
- OpAsmPrinter &printer) const;
+ virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
+ getOperationPrinter(Operation *op) const;
//===--------------------------------------------------------------------===//
// Verification Hooks
@@ -297,8 +297,7 @@ class DialectRegistry {
public:
explicit DialectRegistry();
- template <typename ConcreteDialect>
- void insert() {
+ template <typename ConcreteDialect> void insert() {
insert(TypeID::get<ConcreteDialect>(),
ConcreteDialect::getDialectNamespace(),
static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
@@ -364,8 +363,7 @@ class DialectRegistry {
/// Add an external op interface model for an op that belongs to a dialect,
/// both provided as template parameters. The dialect must be present in the
/// registry.
- template <typename OpTy, typename ModelTy>
- void addOpInterface() {
+ template <typename OpTy, typename ModelTy> void addOpInterface() {
StringRef opName = OpTy::getOperationName();
StringRef dialectName = opName.split('.').first;
addObjectInterface(dialectName, TypeID::get<OpTy>(),
@@ -426,8 +424,7 @@ class DialectRegistry {
namespace llvm {
/// Provide isa functionality for Dialects.
-template <typename T>
-struct isa_impl<T, ::mlir::Dialect> {
+template <typename T> struct isa_impl<T, ::mlir::Dialect> {
static inline bool doit(const ::mlir::Dialect &dialect) {
return mlir::TypeID::get<T>() == dialect.getTypeID();
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 65cbc8a8c33eb..b0fafc3297635 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2508,8 +2508,10 @@ void OperationPrinter::printOperation(Operation *op) {
}
// Otherwise try to dispatch to the dialect, if available.
if (Dialect *dialect = op->getDialect()) {
- if (succeeded(dialect->printOperation(op, *this)))
+ if (auto opPrinter = dialect->getOperationPrinter(op)) {
+ opPrinter(op, *this);
return;
+ }
}
}
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 80c8dabe1f3b9..2f2997e11fb83 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -172,11 +172,11 @@ Dialect::getParseOperationHook(StringRef opName) const {
return None;
}
-LogicalResult Dialect::printOperation(Operation *op,
- OpAsmPrinter &printer) const {
+llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
+Dialect::getOperationPrinter(Operation *op) const {
assert(op->getDialect() == this &&
"Dialect hook invoked on non-dialect owned operation");
- return failure();
+ return nullptr;
}
/// Utility function that returns if the given string is a valid dialect
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 2a1f37119e2d5..d12c61a124d14 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -313,14 +313,15 @@ TestDialect::getParseOperationHook(StringRef opName) const {
return None;
}
-LogicalResult TestDialect::printOperation(Operation *op,
- OpAsmPrinter &printer) const {
+llvm::unique_function<void(Operation *, OpAsmPrinter &)>
+TestDialect::getOperationPrinter(Operation *op) const {
StringRef opName = op->getName().getStringRef();
if (opName == "test.dialect_custom_printer") {
- printer.getStream() << opName << " custom_format";
- return success();
+ return [](Operation *op, OpAsmPrinter &printer) {
+ printer.getStream() << op->getName().getStringRef() << " custom_format";
+ };
}
- return failure();
+ return {};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fbbc766839c40..4dee0a1b366ca 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -39,15 +39,17 @@ def Test_Dialect : Dialect {
void registerTypes();
::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
- ::mlir::Type type) const override;
+ ::mlir::Type type) const override;
void printAttribute(::mlir::Attribute attr,
::mlir::DialectAsmPrinter &printer) const override;
// Provides a custom printing/parsing for some operations.
::llvm::Optional<ParseOpHook>
getParseOperationHook(::llvm::StringRef opName) const override;
- ::mlir::LogicalResult printOperation(::mlir::Operation *op,
- ::mlir::OpAsmPrinter &printer) const override;
+ ::llvm::unique_function<void(::mlir::Operation *,
+ ::mlir::OpAsmPrinter &printer)>
+ getOperationPrinter(::mlir::Operation *op) const override;
+
private:
// Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces;
More information about the Mlir-commits
mailing list