[Mlir-commits] [mlir] 23cdf7b - [mlir] separable registration of operation interfaces

Alex Zinenko llvmlistbot at llvm.org
Thu Jun 17 03:00:40 PDT 2021


Author: Alex Zinenko
Date: 2021-06-17T12:00:31+02:00
New Revision: 23cdf7b6ed9781040ef7923372247ce30b250f29

URL: https://github.com/llvm/llvm-project/commit/23cdf7b6ed9781040ef7923372247ce30b250f29
DIFF: https://github.com/llvm/llvm-project/commit/23cdf7b6ed9781040ef7923372247ce30b250f29.diff

LOG: [mlir] separable registration of operation interfaces

This is similar to attribute and type interfaces and mostly the same mechanism
(FallbackModel / ExternalModel, ODS generation). There are minor differences in
how the concept-based polymorphism is implemented for operations that are
accounted for by ODS backends, and this essentially adds a test and exposes the
API.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D104294

Added: 
    

Modified: 
    mlir/docs/Interfaces.md
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/test/lib/Dialect/Test/TestInterfaces.td
    mlir/unittests/IR/InterfaceAttachmentTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 8e75146b2c17..bcc9938df8a1 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -207,12 +207,12 @@ if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
   llvm::errs() << "hook returned = " << example.exampleInterfaceHook() << "\n";
 ```
 
-#### External Models for Attribute/Type Interfaces
+#### External Models for Attribute, Operation and Type Interfaces
 
-It may be desirable to provide an interface implementation for an attribute or a
-type without modifying the definition of said attribute or type. Notably, this
-allows to implement interfaces for attributes and types outside of the dialect
-that defines them and, in particular, provide interfaces for built-in types.
+It may be desirable to provide an interface implementation for an IR object
+without modifying the definition of said object. Notably, this allows one to
+implement interfaces for attributes, operations and types outside of the dialect
+that defines them, for example, to provide interfaces for built-in types.
 
 This is achieved by extending the concept-based polymorphism model with two more
 classes derived from `Concept` as follows.
@@ -261,9 +261,9 @@ struct ExampleTypeInterfaceTraits {
 };
 ```
 
-External models can be provided for attribute and type interfaces by deriving
-either `FallbackModel` or `ExternalModel` and by registering the model class
-with the attribute or type class in a given context. Other contexts will not see
+External models can be provided for attribute, operation and type interfaces by
+deriving either `FallbackModel` or `ExternalModel` and by registering the model
+class with the relevant class in a given context. Other contexts will not see
 the interface unless registered.
 
 ```c++

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 8fdf29cbd492..e7a4794bd5f4 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1619,6 +1619,19 @@ class Op : public OpState, public Traits<ConcreteType>... {
         reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
   }
 
+  /// Attach the given models as implementations of the corresponding interfaces
+  /// for the concrete operation.
+  template <typename... Models>
+  static void attachInterface(MLIRContext &context) {
+    AbstractOperation *abstract = AbstractOperation::lookupMutable(
+        ConcreteType::getOperationName(), &context);
+    if (!abstract)
+      llvm::report_fatal_error(
+          "Attempting to attach an interface to an unregistered operation " +
+          ConcreteType::getOperationName() + ".");
+    abstract->interfaceMap.insert<Models...>();
+  }
+
 private:
   /// Trait to check if T provides a 'fold' method for a single result op.
   template <typename T, typename... Args>

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index c1f7a1ac6882..898454a67de7 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -162,7 +162,9 @@ class AbstractOperation {
   /// Look up the specified operation in the specified MLIRContext and return a
   /// pointer to it if present.  Otherwise, return a null pointer.
   static const AbstractOperation *lookup(StringRef opName,
-                                         MLIRContext *context);
+                                         MLIRContext *context) {
+    return lookupMutable(opName, context);
+  }
 
   /// This constructor is used by Dialect objects when they register the list of
   /// operations they contain.
@@ -194,6 +196,15 @@ class AbstractOperation {
                     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
                     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
 
+  /// Give Op access to lookupMutable.
+  template <typename ConcreteType, template <typename T> class... Traits>
+  friend class Op;
+
+  /// Look up the specified operation in the specified MLIRContext and return a
+  /// pointer to it if present.  Otherwise, return a null pointer.
+  static AbstractOperation *lookupMutable(StringRef opName,
+                                          MLIRContext *context);
+
   /// A map of interfaces that were registered to this operation.
   detail::InterfaceMap interfaceMap;
 

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 34551bbef9ee..da4b08ccd882 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -696,8 +696,8 @@ ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser,
 
 /// Look up the specified operation in the operation set and return a pointer
 /// to it if present. Otherwise, return a null pointer.
-const AbstractOperation *AbstractOperation::lookup(StringRef opName,
-                                                   MLIRContext *context) {
+AbstractOperation *AbstractOperation::lookupMutable(StringRef opName,
+                                                    MLIRContext *context) {
   auto &impl = context->getImpl();
   auto it = impl.registeredOperations.find(opName);
   if (it != impl.registeredOperations.end())

diff  --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index 3458abd7fe2d..817f2f78bc91 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -88,6 +88,24 @@ def TestExternalAttrInterface : AttrInterface<"TestExternalAttrInterface"> {
   ];
 }
 
+def TestExternalOpInterface : OpInterface<"TestExternalOpInterface"> {
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<"Returns the length of the operation name plus arg.",
+      "unsigned", "getNameLengthPlusArg", (ins "unsigned":$arg)>,
+    StaticInterfaceMethod<
+      "Returns the length of the operation name plus arg twice.", "unsigned",
+      "getNameLengthPlusArgTwice", (ins "unsigned":$arg)>,
+    InterfaceMethod<
+      "Returns the length of the product of the operation name and arg.",
+      "unsigned", "getNameLengthTimesArg", (ins "unsigned":$arg), "",
+      "return arg * $_op->getName().getStringRef().size();">,
+    StaticInterfaceMethod<"Returns the length of the operation name minus arg.",
+      "unsigned", "getNameLengthMinusArg", (ins "unsigned":$arg), "",
+      "return ConcreteOp::getOperationName().size() - arg;">,
+  ];
+}
+
 def TestEffectOpInterface
     : EffectOpInterfaceBase<"TestEffectOpInterface",
                             "::mlir::TestEffects::Effect"> {

diff  --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index e8b48aee98c6..6ad65438ae6c 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -12,10 +12,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "gtest/gtest.h"
 
 #include "../../test/lib/Dialect/Test/TestAttributes.h"
+#include "../../test/lib/Dialect/Test/TestDialect.h"
 #include "../../test/lib/Dialect/Test/TestTypes.h"
 
 using namespace mlir;
@@ -150,4 +152,72 @@ TEST(InterfaceAttachment, Attribute) {
   EXPECT_EQ(iface.getSomeNumber(), 42);
 }
 
+/// External interface model for the module operation. Only provides non-default
+/// methods.
+struct TestExternalOpModel
+    : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
+                                                    ModuleOp> {
+  unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
+    return op->getName().getStringRef().size() + arg;
+  }
+
+  static unsigned getNameLengthPlusArgTwice(unsigned arg) {
+    return ModuleOp::getOperationName().size() + 2 * arg;
+  }
+};
+
+/// External interface model for the func operation. Provides non-default and
+/// overrides default methods.
+struct TestExternalOpOverridingModel
+    : public TestExternalOpInterface::FallbackModel<
+          TestExternalOpOverridingModel> {
+  unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
+    return op->getName().getStringRef().size() + arg;
+  }
+
+  static unsigned getNameLengthPlusArgTwice(unsigned arg) {
+    return FuncOp::getOperationName().size() + 2 * arg;
+  }
+
+  unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
+    return 42;
+  }
+
+  static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
+};
+
+TEST(InterfaceAttachment, Operation) {
+  MLIRContext context;
+
+  // Initially, the operation doesn't have the interface.
+  auto moduleOp = ModuleOp::create(UnknownLoc::get(&context));
+  ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp.getOperation()));
+
+  // We can attach an external interface and now the operation has it.
+  ModuleOp::attachInterface<TestExternalOpModel>(context);
+  auto iface = dyn_cast<TestExternalOpInterface>(moduleOp.getOperation());
+  ASSERT_TRUE(iface != nullptr);
+  EXPECT_EQ(iface.getNameLengthPlusArg(10), 16u);
+  EXPECT_EQ(iface.getNameLengthTimesArg(3), 18u);
+  EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 42u);
+  EXPECT_EQ(iface.getNameLengthMinusArg(5), 1u);
+
+  // Default implementation can be overridden.
+  auto funcOp = FuncOp::create(UnknownLoc::get(&context), "function",
+                               FunctionType::get(&context, {}, {}));
+  ASSERT_FALSE(isa<TestExternalOpInterface>(funcOp.getOperation()));
+  FuncOp::attachInterface<TestExternalOpOverridingModel>(context);
+  iface = dyn_cast<TestExternalOpInterface>(funcOp.getOperation());
+  ASSERT_TRUE(iface != nullptr);
+  EXPECT_EQ(iface.getNameLengthPlusArg(10), 14u);
+  EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
+  EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 20u);
+  EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
+
+  // Another context doesn't have the interfaces registered.
+  MLIRContext other;
+  auto otherModuleOp = ModuleOp::create(UnknownLoc::get(&other));
+  ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list