[Mlir-commits] [mlir] 1bf3fd9 - [mlir] Use unique_function in AbstractOperation fields
River Riddle
llvmlistbot at llvm.org
Tue May 25 11:40:58 PDT 2021
Author: Mathieu Fehr
Date: 2021-05-25T11:36:12-07:00
New Revision: 1bf3fd9bb55a8e9c8a2f6e446e956951d1715cf7
URL: https://github.com/llvm/llvm-project/commit/1bf3fd9bb55a8e9c8a2f6e446e956951d1715cf7
DIFF: https://github.com/llvm/llvm-project/commit/1bf3fd9bb55a8e9c8a2f6e446e956951d1715cf7.diff
LOG: [mlir] Use unique_function in AbstractOperation fields
Currently, AbstractOperation fields are plain function pointers.
Changing them to llvm::unique_function allows them to capture
runtime information (i.e. stateful callables such as capturing lambdas).
For instance, this allows operations to be defined at runtime.
Differential Revision: https://reviews.llvm.org/D103031
Added:
Modified:
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Parser/Parser.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 735d59703029d..b2e50324f75df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1671,7 +1671,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
detect_has_single_result_fold<ConcreteOpT>::value,
AbstractOperation::FoldHookFn>
getFoldHookFnImpl() {
- return &foldSingleResultHook<ConcreteOpT>;
+ return [](Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ return foldSingleResultHook<ConcreteOpT>(op, operands, results);
+ };
}
/// The internal implementation of `getFoldHookFn` above that is invoked if
/// the operation is not single result and defines a `fold` method.
@@ -1681,7 +1684,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
detect_has_fold<ConcreteOpT>::value,
AbstractOperation::FoldHookFn>
getFoldHookFnImpl() {
- return &foldHook<ConcreteOpT>;
+ return [](Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ return foldHook<ConcreteOpT>(op, operands, results);
+ };
}
/// The internal implementation of `getFoldHookFn` above that is invoked if
/// the operation does not define a `fold` method.
@@ -1690,8 +1696,12 @@ class Op : public OpState, public Traits<ConcreteType>... {
!detect_has_fold<ConcreteOpT>::value,
AbstractOperation::FoldHookFn>
getFoldHookFnImpl() {
- // In this case, we only need to fold the traits of the operation.
- return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
+ return [](Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // In this case, we only need to fold the traits of the operation.
+ return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
+ results);
+ };
}
/// Return the result of folding a single result operation that defines a
/// `fold` method.
@@ -1735,7 +1745,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
}
/// Implementation of `GetHasTraitFn`
static AbstractOperation::HasTraitFn getHasTraitFn() {
- return &op_definition_impl::hasTrait<Traits...>;
+ return
+ [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
}
/// Implementation of `ParseAssemblyFn` AbstractOperation hook.
static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
@@ -1751,7 +1762,9 @@ class Op : public OpState, public Traits<ConcreteType>... {
static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
AbstractOperation::PrintAssemblyFn>
getPrintAssemblyFnImpl() {
- return &OpState::print;
+ return [](Operation *op, OpAsmPrinter &parser) {
+ return OpState::print(op, parser);
+ };
}
/// The internal implementation of `getPrintAssemblyFn` that is invoked when
/// the concrete operation defines a `print` method.
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index e7bcacf79a0c6..20d73cc20d127 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -67,14 +67,17 @@ using OwningRewritePatternList = RewritePatternSet;
/// the concrete operation types.
class AbstractOperation {
public:
- using GetCanonicalizationPatternsFn = void (*)(RewritePatternSet &,
- MLIRContext *);
- using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &);
- using HasTraitFn = bool (*)(TypeID);
- using ParseAssemblyFn = ParseResult (*)(OpAsmParser &, OperationState &);
- using PrintAssemblyFn = void (*)(Operation *, OpAsmPrinter &);
- using VerifyInvariantsFn = LogicalResult (*)(Operation *);
+ using GetCanonicalizationPatternsFn =
+ llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
+ using FoldHookFn = llvm::unique_function<LogicalResult(
+ Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>;
+ using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+ using ParseAssemblyFn =
+ llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
+ using PrintAssemblyFn =
+ llvm::unique_function<void(Operation *, OpAsmPrinter &) const>;
+ using VerifyInvariantsFn =
+ llvm::unique_function<LogicalResult(Operation *) const>;
/// This is the name of the operation.
const Identifier name;
@@ -89,7 +92,7 @@ class AbstractOperation {
ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
/// Return the static hook for parsing this operation assembly.
- ParseAssemblyFn getParseAssemblyFn() const { return parseAssemblyFn; }
+ const ParseAssemblyFn &getParseAssemblyFn() const { return parseAssemblyFn; }
/// This hook implements the AsmPrinter for this operation.
void printAssembly(Operation *op, OpAsmPrinter &p) const {
@@ -175,20 +178,21 @@ class AbstractOperation {
/// Register a new operation in a Dialect object.
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
- static void insert(StringRef name, Dialect &dialect, TypeID typeID,
- ParseAssemblyFn parseAssembly,
- PrintAssemblyFn printAssembly,
- VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
- GetCanonicalizationPatternsFn getCanonicalizationPatterns,
- detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
+ static void
+ insert(StringRef name, Dialect &dialect, TypeID typeID,
+ ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+ VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+ detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
private:
AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID,
- ParseAssemblyFn parseAssembly,
- PrintAssemblyFn printAssembly,
- VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
- GetCanonicalizationPatternsFn getCanonicalizationPatterns,
- detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
+ ParseAssemblyFn &&parseAssembly,
+ PrintAssemblyFn &&printAssembly,
+ VerifyInvariantsFn &&verifyInvariants,
+ FoldHookFn &&foldHook,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+ detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
/// A map of interfaces that were registered to this operation.
detail::InterfaceMap interfaceMap;
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f1825a481fb70..f438c00085516 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -696,13 +696,15 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
void AbstractOperation::insert(
StringRef name, Dialect &dialect, TypeID typeID,
- ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
- VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
- GetCanonicalizationPatternsFn getCanonicalizationPatterns,
- detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) {
- AbstractOperation opInfo(
- name, dialect, typeID, parseAssembly, printAssembly, verifyInvariants,
- foldHook, getCanonicalizationPatterns, std::move(interfaceMap), hasTrait);
+ ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+ VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+ detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) {
+ AbstractOperation opInfo(name, dialect, typeID, std::move(parseAssembly),
+ std::move(printAssembly),
+ std::move(verifyInvariants), std::move(foldHook),
+ std::move(getCanonicalizationPatterns),
+ std::move(interfaceMap), std::move(hasTrait));
auto &impl = dialect.getContext()->getImpl();
assert(impl.multiThreadedExecutionContext == 0 &&
@@ -717,16 +719,18 @@ void AbstractOperation::insert(
AbstractOperation::AbstractOperation(
StringRef name, Dialect &dialect, TypeID typeID,
- ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
- VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
- GetCanonicalizationPatternsFn getCanonicalizationPatterns,
- detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait)
+ ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+ VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+ detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait)
: name(Identifier::get(name, dialect.getContext())), dialect(dialect),
typeID(typeID), interfaceMap(std::move(interfaceMap)),
- foldHookFn(foldHook),
- getCanonicalizationPatternsFn(getCanonicalizationPatterns),
- hasTraitFn(hasTrait), parseAssemblyFn(parseAssembly),
- printAssemblyFn(printAssembly), verifyInvariantsFn(verifyInvariants) {}
+ foldHookFn(std::move(foldHook)),
+ getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
+ hasTraitFn(std::move(hasTrait)),
+ parseAssemblyFn(std::move(parseAssembly)),
+ printAssemblyFn(std::move(printAssembly)),
+ verifyInvariantsFn(std::move(verifyInvariants)) {}
//===----------------------------------------------------------------------===//
// AbstractType
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index f88f94cc57d91..aa0d12f7568fe 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1830,7 +1830,7 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// This is the actual hook for the custom op parsing, usually implemented by
// the op itself (`Op::parse()`). We retrieve it either from the
// AbstractOperation or from the Dialect.
- std::function<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
+ function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
bool isIsolatedFromAbove = false;
if (opDefinition) {
More information about the Mlir-commits
mailing list