[Mlir-commits] [mlir] 1bf3fd9 - [mlir] Use unique_function in AbstractOperation fields

River Riddle llvmlistbot at llvm.org
Tue May 25 11:40:58 PDT 2021


Author: Mathieu Fehr
Date: 2021-05-25T11:36:12-07:00
New Revision: 1bf3fd9bb55a8e9c8a2f6e446e956951d1715cf7

URL: https://github.com/llvm/llvm-project/commit/1bf3fd9bb55a8e9c8a2f6e446e956951d1715cf7
DIFF: https://github.com/llvm/llvm-project/commit/1bf3fd9bb55a8e9c8a2f6e446e956951d1715cf7.diff

LOG: [mlir] Use unique_function in AbstractOperation fields

Currently, AbstractOperation fields are function pointers.
Modifying them to unique_function allows them to contain
runtime information.

For instance, this allows operations to be defined at runtime.

Differential Revision: https://reviews.llvm.org/D103031

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Parser/Parser.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 735d59703029d..b2e50324f75df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1671,7 +1671,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
                               detect_has_single_result_fold<ConcreteOpT>::value,
                           AbstractOperation::FoldHookFn>
   getFoldHookFnImpl() {
-    return &foldSingleResultHook<ConcreteOpT>;
+    return [](Operation *op, ArrayRef<Attribute> operands,
+              SmallVectorImpl<OpFoldResult> &results) {
+      return foldSingleResultHook<ConcreteOpT>(op, operands, results);
+    };
   }
   /// The internal implementation of `getFoldHookFn` above that is invoked if
   /// the operation is not single result and defines a `fold` method.
@@ -1681,7 +1684,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
                               detect_has_fold<ConcreteOpT>::value,
                           AbstractOperation::FoldHookFn>
   getFoldHookFnImpl() {
-    return &foldHook<ConcreteOpT>;
+    return [](Operation *op, ArrayRef<Attribute> operands,
+              SmallVectorImpl<OpFoldResult> &results) {
+      return foldHook<ConcreteOpT>(op, operands, results);
+    };
   }
   /// The internal implementation of `getFoldHookFn` above that is invoked if
   /// the operation does not define a `fold` method.
@@ -1690,8 +1696,12 @@ class Op : public OpState, public Traits<ConcreteType>... {
                               !detect_has_fold<ConcreteOpT>::value,
                           AbstractOperation::FoldHookFn>
   getFoldHookFnImpl() {
-    // In this case, we only need to fold the traits of the operation.
-    return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
+    return [](Operation *op, ArrayRef<Attribute> operands,
+              SmallVectorImpl<OpFoldResult> &results) {
+      // In this case, we only need to fold the traits of the operation.
+      return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
+                                                                  results);
+    };
   }
   /// Return the result of folding a single result operation that defines a
   /// `fold` method.
@@ -1735,7 +1745,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
   }
   /// Implementation of `GetHasTraitFn`
   static AbstractOperation::HasTraitFn getHasTraitFn() {
-    return &op_definition_impl::hasTrait<Traits...>;
+    return
+        [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
   }
   /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
   static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
@@ -1751,7 +1762,9 @@ class Op : public OpState, public Traits<ConcreteType>... {
   static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
                           AbstractOperation::PrintAssemblyFn>
   getPrintAssemblyFnImpl() {
-    return &OpState::print;
+    return [](Operation *op, OpAsmPrinter &parser) {
+      return OpState::print(op, parser);
+    };
   }
   /// The internal implementation of `getPrintAssemblyFn` that is invoked when
   /// the concrete operation defines a `print` method.

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index e7bcacf79a0c6..20d73cc20d127 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -67,14 +67,17 @@ using OwningRewritePatternList = RewritePatternSet;
 /// the concrete operation types.
 class AbstractOperation {
 public:
-  using GetCanonicalizationPatternsFn = void (*)(RewritePatternSet &,
-                                                 MLIRContext *);
-  using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
-                                       SmallVectorImpl<OpFoldResult> &);
-  using HasTraitFn = bool (*)(TypeID);
-  using ParseAssemblyFn = ParseResult (*)(OpAsmParser &, OperationState &);
-  using PrintAssemblyFn = void (*)(Operation *, OpAsmPrinter &);
-  using VerifyInvariantsFn = LogicalResult (*)(Operation *);
+  using GetCanonicalizationPatternsFn =
+      llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
+  using FoldHookFn = llvm::unique_function<LogicalResult(
+      Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>;
+  using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+  using ParseAssemblyFn =
+      llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
+  using PrintAssemblyFn =
+      llvm::unique_function<void(Operation *, OpAsmPrinter &) const>;
+  using VerifyInvariantsFn =
+      llvm::unique_function<LogicalResult(Operation *) const>;
 
   /// This is the name of the operation.
   const Identifier name;
@@ -89,7 +92,7 @@ class AbstractOperation {
   ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
 
   /// Return the static hook for parsing this operation assembly.
-  ParseAssemblyFn getParseAssemblyFn() const { return parseAssemblyFn; }
+  const ParseAssemblyFn &getParseAssemblyFn() const { return parseAssemblyFn; }
 
   /// This hook implements the AsmPrinter for this operation.
   void printAssembly(Operation *op, OpAsmPrinter &p) const {
@@ -175,20 +178,21 @@ class AbstractOperation {
   /// Register a new operation in a Dialect object.
   /// The use of this method is in general discouraged in favor of
   /// 'insert<CustomOp>(dialect)'.
-  static void insert(StringRef name, Dialect &dialect, TypeID typeID,
-                     ParseAssemblyFn parseAssembly,
-                     PrintAssemblyFn printAssembly,
-                     VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
-                     GetCanonicalizationPatternsFn getCanonicalizationPatterns,
-                     detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
+  static void
+  insert(StringRef name, Dialect &dialect, TypeID typeID,
+         ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+         VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+         GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+         detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
 
 private:
   AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID,
-                    ParseAssemblyFn parseAssembly,
-                    PrintAssemblyFn printAssembly,
-                    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
-                    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
-                    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
+                    ParseAssemblyFn &&parseAssembly,
+                    PrintAssemblyFn &&printAssembly,
+                    VerifyInvariantsFn &&verifyInvariants,
+                    FoldHookFn &&foldHook,
+                    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+                    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
 
   /// A map of interfaces that were registered to this operation.
   detail::InterfaceMap interfaceMap;

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f1825a481fb70..f438c00085516 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -696,13 +696,15 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
 
 void AbstractOperation::insert(
     StringRef name, Dialect &dialect, TypeID typeID,
-    ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
-    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
-    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
-    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) {
-  AbstractOperation opInfo(
-      name, dialect, typeID, parseAssembly, printAssembly, verifyInvariants,
-      foldHook, getCanonicalizationPatterns, std::move(interfaceMap), hasTrait);
+    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) {
+  AbstractOperation opInfo(name, dialect, typeID, std::move(parseAssembly),
+                           std::move(printAssembly),
+                           std::move(verifyInvariants), std::move(foldHook),
+                           std::move(getCanonicalizationPatterns),
+                           std::move(interfaceMap), std::move(hasTrait));
 
   auto &impl = dialect.getContext()->getImpl();
   assert(impl.multiThreadedExecutionContext == 0 &&
@@ -717,16 +719,18 @@ void AbstractOperation::insert(
 
 AbstractOperation::AbstractOperation(
     StringRef name, Dialect &dialect, TypeID typeID,
-    ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
-    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
-    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
-    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait)
+    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait)
     : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
       typeID(typeID), interfaceMap(std::move(interfaceMap)),
-      foldHookFn(foldHook),
-      getCanonicalizationPatternsFn(getCanonicalizationPatterns),
-      hasTraitFn(hasTrait), parseAssemblyFn(parseAssembly),
-      printAssemblyFn(printAssembly), verifyInvariantsFn(verifyInvariants) {}
+      foldHookFn(std::move(foldHook)),
+      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
+      hasTraitFn(std::move(hasTrait)),
+      parseAssemblyFn(std::move(parseAssembly)),
+      printAssemblyFn(std::move(printAssembly)),
+      verifyInvariantsFn(std::move(verifyInvariants)) {}
 
 //===----------------------------------------------------------------------===//
 // AbstractType

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index f88f94cc57d91..aa0d12f7568fe 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1830,7 +1830,7 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   // This is the actual hook for the custom op parsing, usually implemented by
   // the op itself (`Op::parse()`). We retrieve it either from the
   // AbstractOperation or from the Dialect.
-  std::function<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
+  function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
   bool isIsolatedFromAbove = false;
 
   if (opDefinition) {


        


More information about the Mlir-commits mailing list