[Mlir-commits] [mlir] 9b2a1bc - [mlir] separable registration of attribute and type interfaces

Alex Zinenko llvmlistbot at llvm.org
Tue Jun 15 06:20:35 PDT 2021


Author: Alex Zinenko
Date: 2021-06-15T15:20:27+02:00
New Revision: 9b2a1bcf6fbe79cfa3725cfd654f45cce0375d3b

URL: https://github.com/llvm/llvm-project/commit/9b2a1bcf6fbe79cfa3725cfd654f45cce0375d3b
DIFF: https://github.com/llvm/llvm-project/commit/9b2a1bcf6fbe79cfa3725cfd654f45cce0375d3b.diff

LOG: [mlir] separable registration of attribute and type interfaces

It may be desirable to provide an interface implementation for an attribute or
a type without modifying the definition of said attribute or type. Notably,
this makes it possible to implement interfaces for attributes and types outside
of the dialect that defines them and, in particular, to provide interfaces for
built-in types. Provide the mechanism to do so.

Currently, separable registration requires the attribute or type to have been
registered with the context, i.e. for the dialect containing the attribute or
type to be loaded. This can be relaxed in the future using a mechanism similar
to delayed dialect interface registration.
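
The registration requirement and the per-context scope of attachments can be modelled in plain C++. The sketch below is a hypothetical stand-in, not MLIR code: `Context`, `AbstractTypeInfo`, `registerType`, `hasInterface` and the tag types are invented names, interfaces are identified by strings, and a failed attachment returns `false` where MLIR calls `llvm::report_fatal_error`.

```c++
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <typeindex>
#include <typeinfo>

// Hypothetical per-type record: the interfaces attached in one context.
struct AbstractTypeInfo {
  std::set<std::string> interfaces;
};

// Hypothetical stand-in for MLIRContext: each context owns its own table of
// registered types, so attachments never leak between contexts.
class Context {
public:
  // Models loading the dialect that defines T into this context.
  template <typename T>
  void registerType() {
    registered.emplace(std::type_index(typeid(T)), AbstractTypeInfo{});
  }

  // Attaches an interface to T in this context only. Returns false when T is
  // not registered here (MLIR aborts with report_fatal_error instead).
  template <typename T>
  bool attachInterface(const std::string &iface) {
    auto it = registered.find(std::type_index(typeid(T)));
    if (it == registered.end())
      return false;
    it->second.interfaces.insert(iface);
    return true;
  }

  template <typename T>
  bool hasInterface(const std::string &iface) const {
    auto it = registered.find(std::type_index(typeid(T)));
    return it != registered.end() && it->second.interfaces.count(iface) != 0;
  }

private:
  std::map<std::type_index, AbstractTypeInfo> registered;
};

// Hypothetical type tags.
struct IntegerTypeTag {};
struct FloatTypeTag {};
```

Attaching to a type whose "dialect" was never loaded fails, and a second context that registered the same type does not see the attachment, mirroring the `other` context check in the new unit test.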

See https://llvm.discourse.group/t/rfc-separable-attribute-type-interfaces/3637

Depends On D104233

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D104234

Added: 
    mlir/unittests/IR/InterfaceAttachmentTest.cpp

Modified: 
    mlir/docs/Interfaces.md
    mlir/include/mlir/IR/AttributeSupport.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/TypeSupport.h
    mlir/include/mlir/IR/Types.h
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/test/lib/Dialect/Test/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestAttributes.h
    mlir/test/lib/Dialect/Test/TestInterfaces.td
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
    mlir/unittests/IR/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index b444f65f81482..633e4924ebfd6 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -207,6 +207,91 @@ if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
   llvm::errs() << "hook returned = " << example.exampleInterfaceHook() << "\n";
 ```
 
+#### External Models for Attribute/Type Interfaces
+
+It may be desirable to provide an interface implementation for an attribute or a
+type without modifying the definition of said attribute or type. Notably, this
+makes it possible to implement interfaces for attributes and types outside of
+the dialect that defines them and, in particular, to provide interfaces for
+built-in types.
+
+This is achieved by extending the concept-based polymorphism model with two more
+classes, both derived (directly or indirectly) from `Concept`, as follows.
+
+```c++
+struct ExampleTypeInterfaceTraits {
+  struct Concept {
+    virtual unsigned exampleInterfaceHook(Type type) const = 0;
+    virtual unsigned exampleStaticInterfaceHook() const = 0;
+  };
+
+  template <typename ConcreteType>
+  struct Model : public Concept { /*...*/ };
+
+  /// Unlike `Model`, `FallbackModel` passes the type object through to the
+  /// hook, making it accessible in the method body even if the method is not
+  /// defined in the class itself and thus has no `this` access. ODS
+  /// automatically generates this class for all interfaces.
+  template <typename ConcreteType>
+  struct FallbackModel : public Concept {
+    unsigned exampleInterfaceHook(Type type) const override {
+      return getImpl()->exampleInterfaceHook(type);
+    }
+    unsigned exampleStaticInterfaceHook() const override {
+      return ConcreteType::exampleStaticInterfaceHook();
+    }
+  };
+
+  /// `ExternalModel` provides a place for default implementations of interface
+  /// methods by explicitly separating the model class, which implements the
+  /// interface, from the type class, for which the interface is being
+  /// implemented. Default implementations can then be defined generically,
+  /// making use of `cast<ConcreteType>`. If `ConcreteType` does not provide
+  /// the APIs required by the default implementation, custom implementations
+  /// may use `FallbackModel` directly to override the default implementation.
+  /// Because `ExternalModel` is a class template, it is then never instantiated
+  /// for that `ConcreteType` and does not lead to compilation errors. ODS
+  /// automatically generates this class and places default method
+  /// implementations in it.
+  template <typename ConcreteModel, typename ConcreteType>
+  struct ExternalModel : public FallbackModel<ConcreteModel> {
+    unsigned exampleInterfaceHook(Type type) const override {
+      // Default implementation can be provided here.
+      return type.cast<ConcreteType>().callSomeTypeSpecificMethod();
+    }
+  };
+};
+```
+
+External models can be provided for attribute and type interfaces by deriving
+from either `FallbackModel` or `ExternalModel` and by registering the model
+class with the attribute or type class in a given context. Other contexts will
+not see the interface unless it is registered there as well.
+
+```c++
+/// External interface implementation for a concrete class. This does not
+/// require modifying the definition of the type class itself.
+struct ExternalModelExample
+    : public ExampleTypeInterface::ExternalModel<ExternalModelExample,
+                                                 IntegerType> {
+  static unsigned exampleStaticInterfaceHook() {
+    // Implementation is provided here.
+    return IntegerType::someStaticMethod();
+  }
+
+  // No need to define `exampleInterfaceHook` that has a default implementation
+  // in `ExternalModel`. But it can be overridden if desired.
+};
+
+int main() {
+  MLIRContext context;
+  /* ... */;
+
+  // Register the interface model with the type in the given context before
+  // using it. The dialect containing the type is expected to have been loaded
+  // at this point.
+  IntegerType::attachInterface<ExternalModelExample>(context);
+}
+```
+
 #### Dialect Fallback for OpInterface
 
 Some dialects have an open ecosystem and don't register all of the possible
@@ -215,9 +300,9 @@ implementing an `OpInterface` for these operation. When an operation isn't
 registered or does not provide an implementation for an interface, the query
 will fallback to the dialect itself.
 
-A second model is used for such cases and automatically generated when
-using ODS (see below) with the name `FallbackModel`. This model can be implemented
-for a particular dialect:
+A second model is used for such cases and automatically generated when using ODS
+(see below) with the name `FallbackModel`. This model can be implemented for a
+particular dialect:
 
 ```c++
 // This is the implementation of a dialect fallback for `ExampleOpInterface`.
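
The model hierarchy documented in this hunk can be reproduced outside MLIR to see why `ExternalModel` defaults do not break types that lack the required API. This is a minimal sketch under stated assumptions: `Type`, `IntegerType`, `VectorType`, `IntegerModel`, `VectorModel` and `callGetBitwidth` are hypothetical stand-ins, and `Concept` holds function pointers (as the generated `Concept{...}` initializer in `OpInterfacesGen.cpp` suggests) rather than the virtual functions of the simplified documentation listing.

```c++
#include <cassert>

// Hypothetical stand-ins for a type-erased Type handle and two concrete
// type classes; only IntegerType offers getWidth().
struct Type {
  unsigned width;
};
struct IntegerType {
  explicit IntegerType(Type type) : type(type) {}
  unsigned getWidth() const { return type.width; }
  Type type;
};
struct VectorType {
  explicit VectorType(Type type) : type(type) {}
  unsigned getElementWidth() const { return type.width; }
  Type type;
};

// The concept is a table of function pointers, each receiving the model
// instance and the type object explicitly.
struct Concept {
  unsigned (*getBitwidthFn)(const Concept *impl, Type type);
};

// FallbackModel forwards to a member function of the derived model class,
// passing the type object through.
template <typename ConcreteModel>
struct FallbackModel : public Concept {
  FallbackModel() : Concept{getBitwidthImpl} {}
  static unsigned getBitwidthImpl(const Concept *impl, Type type) {
    return static_cast<const ConcreteModel *>(impl)->getBitwidth(type);
  }
};

// ExternalModel hosts a default implementation written against
// ConcreteType's API. It is only instantiated if some model derives from it.
template <typename ConcreteModel, typename ConcreteType>
struct ExternalModel : public FallbackModel<ConcreteModel> {
  unsigned getBitwidth(Type type) const {
    return ConcreteType(type).getWidth();
  }
};

// IntegerType satisfies the default, so its model adds nothing.
struct IntegerModel : public ExternalModel<IntegerModel, IntegerType> {};

// VectorType has no getWidth(), so its model derives from FallbackModel
// directly; ExternalModel<..., VectorType> is never instantiated.
struct VectorModel : public FallbackModel<VectorModel> {
  unsigned getBitwidth(Type type) const {
    return VectorType(type).getElementWidth();
  }
};

// Calls through the type-erased concept, as an interface wrapper would.
unsigned callGetBitwidth(const Concept &iface, Type type) {
  return iface.getBitwidthFn(&iface, type);
}
```

The same split appears in `InterfaceAttachmentTest.cpp`, where the `VectorType` model derives from `FallbackModel` precisely so that the `getWidth()`-based default is never compiled for it.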

diff  --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index c228d1e867d2d..b10b54c6d3ca0 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -59,14 +59,24 @@ class AbstractAttribute {
       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
         typeID(typeID) {}
 
+  /// Give StorageUserBase access to the mutable lookup.
+  template <typename ConcreteT, typename BaseT, typename StorageT,
+            typename UniquerT, template <typename T> class... Traits>
+  friend class detail::StorageUserBase;
+
+  /// Look up the specified abstract attribute in the MLIRContext and return a
+  /// (mutable) pointer to it. Return a null pointer if the attribute could not
+  /// be found in the context.
+  static AbstractAttribute *lookupMutable(TypeID typeID, MLIRContext *context);
+
   /// This is the dialect that this attribute was registered to.
-  Dialect &dialect;
+  const Dialect &dialect;
 
   /// This is a collection of the interfaces registered to this attribute.
   detail::InterfaceMap interfaceMap;
 
   /// The unique identifier of the derived Attribute class.
-  TypeID typeID;
+  const TypeID typeID;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 04e2dccc66411..7a3adf21063f1 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -31,6 +31,7 @@ class Attribute {
 
   using ImplType = AttributeStorage;
   using ValueType = void;
+  using AbstractType = AbstractAttribute;
 
   constexpr Attribute() : impl(nullptr) {}
   /* implicit */ Attribute(const ImplType *impl)

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 71c6703a40551..02498b75f2d54 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -87,6 +87,21 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
     return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
   }
 
+  /// Attach the given models as implementations of the corresponding interfaces
+  /// for the concrete storage user class. The type must be registered with the
+  /// context, i.e. the dialect to which the type belongs must be loaded. The
+  /// call will abort otherwise.
+  template <typename... IfaceModels>
+  static void attachInterface(MLIRContext &context) {
+    typename ConcreteT::AbstractType *abstract =
+        ConcreteT::AbstractType::lookupMutable(TypeID::get<ConcreteT>(),
+                                               &context);
+    if (!abstract)
+      llvm::report_fatal_error("Registering an interface for an attribute/type "
+                               "that is not itself registered.");
+    abstract->interfaceMap.template insert<IfaceModels...>();
+  }
+
   /// Get or create a new ConcreteT instance within the ctx. This
   /// function is guaranteed to return a non null object and will assert if
   /// the arguments provided are invalid.
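
`attachInterface` lives in the CRTP base `StorageUserBase`, so `IntegerType::attachInterface<Model>(context)` can key the lookup on `TypeID::get<ConcreteT>()` and expand the whole pack of models in one call; the `using Interface = ...` alias that ODS now emits in each model (see the `OpInterfacesGen.cpp` hunk) is what tells the map which interface a model implements. A hypothetical sketch: `Registry`, `std::type_index` in place of `TypeID`, an unsorted vector in place of `InterfaceMap`, and `std::logic_error` in place of `report_fatal_error` are all substitutions for illustration.

```c++
#include <cassert>
#include <map>
#include <stdexcept>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Hypothetical stand-in for the per-context table of registered
// attributes/types; each entry lists the interfaces attached to it.
struct Registry {
  std::map<std::type_index, std::vector<std::type_index>> interfacesOf;
};

// CRTP base: ConcreteT is the attribute/type class itself, so the static
// attachInterface can derive the lookup key without extra arguments.
template <typename ConcreteT>
struct StorageUserBase {
  template <typename... IfaceModels>
  static void attachInterface(Registry &registry) {
    auto it = registry.interfacesOf.find(std::type_index(typeid(ConcreteT)));
    if (it == registry.interfacesOf.end())
      throw std::logic_error("Registering an interface for an attribute/type "
                             "that is not itself registered.");
    // Each model names its interface through a nested `Interface` alias.
    (it->second.push_back(
         std::type_index(typeid(typename IfaceModels::Interface))),
     ...);
  }
};

// Hypothetical interfaces, models and concrete types.
struct BitwidthInterface {};
struct PrintInterface {};
struct BitwidthModel { using Interface = BitwidthInterface; };
struct PrintModel { using Interface = PrintInterface; };
struct MyIntegerType : StorageUserBase<MyIntegerType> {};
struct MyFloatType : StorageUserBase<MyFloatType> {};
```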

diff  --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 898c26dc42c27..df05e9ab198b7 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -67,14 +67,24 @@ class AbstractType {
       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
         typeID(typeID) {}
 
+  /// Give StorageUserBase access to the mutable lookup.
+  template <typename ConcreteT, typename BaseT, typename StorageT,
+            typename UniquerT, template <typename T> class... Traits>
+  friend class detail::StorageUserBase;
+
+  /// Look up the specified abstract type in the MLIRContext and return a
+  /// (mutable) pointer to it. Return a null pointer if the type could not
+  /// be found in the context.
+  static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context);
+
   /// This is the dialect that this type was registered to.
-  Dialect &dialect;
+  const Dialect &dialect;
 
   /// This is a collection of the interfaces registered to this type.
   detail::InterfaceMap interfaceMap;
 
   /// The unique identifier of the derived Type class.
-  TypeID typeID;
+  const TypeID typeID;
 };
 
 //===----------------------------------------------------------------------===//
@@ -105,11 +115,11 @@ class TypeStorage : public StorageUniquer::BaseStorage {
   /// Set the abstract type for this storage instance. This is used by the
   /// TypeUniquer when initializing a newly constructed type storage object.
   void initialize(const AbstractType &abstractTy) {
-    abstractType = &abstractTy;
+    abstractType = const_cast<AbstractType *>(&abstractTy);
   }
 
   /// The abstract description for this type.
-  const AbstractType *abstractType;
+  AbstractType *abstractType;
 };
 
 /// Default storage type for types that require no additional initialization or
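
Because every instance of a type points at the same `AbstractType`, attaching a model to that descriptor also affects instances created earlier, which is what the new unit test relies on when it creates `i8` before calling `attachInterface`. A hypothetical sketch with simplified stand-ins (`interfaceMap` as a set of names, `hasInterface` invented for illustration) of why `TypeStorage` now keeps a non-const pointer:

```c++
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-in for AbstractType: one per registered type class,
// owned by the context as a non-const object.
struct AbstractType {
  std::set<std::string> interfaceMap;
};

// Hypothetical stand-in for TypeStorage after this patch: it keeps a
// non-const pointer to the shared descriptor.
struct TypeStorage {
  void initialize(const AbstractType &abstractTy) {
    // Well-defined only because the referenced object is not itself const.
    abstractType = const_cast<AbstractType *>(&abstractTy);
  }
  AbstractType *abstractType = nullptr;
};

// Interface queries go through the shared descriptor, so they observe
// interfaces attached after the storage was created.
bool hasInterface(const TypeStorage &storage, const std::string &name) {
  return storage.abstractType->interfaceMap.count(name) != 0;
}
```

The `const_cast` is safe here for the same reason it is in the patch: the context stores the descriptors as non-const objects (`DenseMap<TypeID, AbstractType *>`), and only the uniquer's `initialize` signature passes them by const reference.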

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index a95447d530972..ccebce20c2c8e 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -79,6 +79,8 @@ class Type {
 
   using ImplType = TypeStorage;
 
+  using AbstractType = AbstractType;
+
   constexpr Type() : impl(nullptr) {}
   /* implicit */ Type(const ImplType *impl)
       : impl(const_cast<ImplType *>(impl)) {}

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index 1eed2bbf32477..a49d96d576278 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -76,6 +76,8 @@ class Interface : public BaseType {
   using FallbackModel = typename Traits::template FallbackModel<T>;
   using InterfaceBase =
       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
+  template <typename T, typename U>
+  using ExternalModel = typename Traits::template ExternalModel<T, U>;
 
   /// This is a special trait that registers a given interface with an object.
   template <typename ConcreteT>
@@ -140,6 +142,25 @@ struct FilterTypes {
           typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
 };
 
+namespace {
+/// Type trait indicating whether all template arguments are
+/// trivially-destructible.
+template <typename... Args>
+struct all_trivially_destructible;
+
+template <typename Arg, typename... Args>
+struct all_trivially_destructible<Arg, Args...> {
+  static constexpr const bool value =
+      std::is_trivially_destructible<Arg>::value &&
+      all_trivially_destructible<Args...>::value;
+};
+
+template <>
+struct all_trivially_destructible<> {
+  static constexpr const bool value = true;
+};
+} // namespace
+
 /// This class provides an efficient mapping between a given `Interface` type,
 /// and a particular implementation of its concept.
 class InterfaceMap {
@@ -199,6 +220,28 @@ class InterfaceMap {
     });
   }
 
+  /// Insert the given models as implementations of the corresponding interfaces
+  /// for the concrete attribute or type class.
+  template <typename... IfaceModels>
+  void insert() {
+    static_assert(all_trivially_destructible<IfaceModels...>::value,
+                  "interface models must be trivially destructible");
+    std::pair<TypeID, void *> elements[] = {
+        std::make_pair(IfaceModels::Interface::getInterfaceID(),
+                       new (malloc(sizeof(IfaceModels))) IfaceModels())...};
+    // Insert directly into the right position to keep the interfaces sorted.
+    for (auto &element : elements) {
+      TypeID id = element.first;
+      auto it =
+          llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
+            return compare(it.first, id);
+          });
+      if (it != interfaces.end() && it->first == id)
+        llvm::report_fatal_error("Interface already registered");
+      interfaces.insert(it, element);
+    }
+  }
+
 private:
   /// Compare two TypeID instances by comparing the underlying pointer.
   static bool compare(TypeID lhs, TypeID rhs) {
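
`InterfaceMap::insert` keeps the `(TypeID, model)` pairs sorted so lookups can binary-search, rejects interfaces that are already present, and requires trivially destructible models because they are placement-new'ed into `malloc`'d memory (the `static_assert` suggests they are released without running destructors). The standalone sketch below mirrors that logic under several assumptions: `TypeID` is modelled as an address from a hypothetical `getTypeID<T>()`, pointers are ordered with `std::less` (plain `<` on unrelated pointers is unspecified), duplicates raise `std::logic_error` instead of `report_fatal_error` and are checked before any model is allocated, and the recursive `all_trivially_destructible` trait is replaced by an equivalent C++17 fold expression.

```c++
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical TypeID: the address of a per-type static object.
using TypeID = const void *;
template <typename T>
TypeID getTypeID() {
  static char id;
  return &id;
}

// C++17 fold-expression equivalent of all_trivially_destructible.
template <typename... Args>
constexpr bool allTriviallyDestructible =
    (std::is_trivially_destructible<Args>::value && ...);

class InterfaceMap {
public:
  using Entry = std::pair<TypeID, void *>;

  InterfaceMap() = default;
  InterfaceMap(const InterfaceMap &) = delete;
  InterfaceMap &operator=(const InterfaceMap &) = delete;
  // Models are trivially destructible, so releasing the raw memory suffices.
  ~InterfaceMap() {
    for (Entry &entry : interfaces)
      std::free(entry.second);
  }

  template <typename... IfaceModels>
  void insert() {
    static_assert(allTriviallyDestructible<IfaceModels...>,
                  "interface models must be trivially destructible");
    TypeID ids[] = {getTypeID<typename IfaceModels::Interface>()...};
    // Reject duplicates (already present or repeated in the pack) up front.
    for (std::size_t i = 0; i < sizeof...(IfaceModels); ++i) {
      bool repeated = std::find(ids, ids + i, ids[i]) != ids + i;
      if (repeated || lookup(ids[i]))
        throw std::logic_error("Interface already registered");
    }
    void *models[] = {allocate<IfaceModels>()...};
    // Insert each entry at its sorted position.
    for (std::size_t i = 0; i < sizeof...(IfaceModels); ++i) {
      auto it = std::lower_bound(interfaces.begin(), interfaces.end(), ids[i],
                                 compare);
      interfaces.insert(it, Entry(ids[i], models[i]));
    }
  }

  void *lookup(TypeID id) const {
    auto it =
        std::lower_bound(interfaces.begin(), interfaces.end(), id, compare);
    return (it != interfaces.end() && it->first == id) ? it->second : nullptr;
  }

  std::size_t size() const { return interfaces.size(); }
  bool isSorted() const {
    return std::is_sorted(interfaces.begin(), interfaces.end(),
                          [](const Entry &a, const Entry &b) {
                            return std::less<TypeID>()(a.first, b.first);
                          });
  }

private:
  static bool compare(const Entry &entry, TypeID id) {
    return std::less<TypeID>()(entry.first, id);
  }

  template <typename Model>
  static void *allocate() {
    void *memory = std::malloc(sizeof(Model));
    if (!memory)
      std::abort(); // Out of memory is fatal in this sketch.
    return new (memory) Model();
  }

  std::vector<Entry> interfaces; // Sorted by TypeID.
};

// Hypothetical interfaces and models used by the example.
struct IfaceA {};
struct IfaceB {};
struct IfaceC {};
struct ModelA { using Interface = IfaceA; int value = 1; };
struct ModelB { using Interface = IfaceB; int value = 2; };
struct ModelC { using Interface = IfaceC; int value = 3; };
```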

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 8b189f8a57474..34551bbef9ee5 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -306,7 +306,7 @@ class MLIRContextImpl {
   // Type uniquing
   //===--------------------------------------------------------------------===//
 
-  DenseMap<TypeID, const AbstractType *> registeredTypes;
+  DenseMap<TypeID, AbstractType *> registeredTypes;
   StorageUniquer typeUniquer;
 
   /// Cached Type Instances.
@@ -324,7 +324,7 @@ class MLIRContextImpl {
   // Attribute uniquing
   //===--------------------------------------------------------------------===//
 
-  DenseMap<TypeID, const AbstractAttribute *> registeredAttributes;
+  DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
   StorageUniquer attributeUniquer;
 
   /// Cached Attribute Instances.
@@ -669,12 +669,20 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
 /// Get the dialect that registered the attribute with the provided typeid.
 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
                                                    MLIRContext *context) {
+  const AbstractAttribute *abstract = lookupMutable(typeID, context);
+  if (!abstract)
+    llvm::report_fatal_error("Trying to create an Attribute that was not "
+                             "registered in this MLIRContext.");
+  return *abstract;
+}
+
+AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
+                                                    MLIRContext *context) {
   auto &impl = context->getImpl();
   auto it = impl.registeredAttributes.find(typeID);
   if (it == impl.registeredAttributes.end())
-    llvm::report_fatal_error("Trying to create an Attribute that was not "
-                             "registered in this MLIRContext.");
-  return *it->second;
+    return nullptr;
+  return it->second;
 }
 
 //===----------------------------------------------------------------------===//
@@ -740,12 +748,19 @@ AbstractOperation::AbstractOperation(
 //===----------------------------------------------------------------------===//
 
 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
+  const AbstractType *type = lookupMutable(typeID, context);
+  if (!type)
+    llvm::report_fatal_error(
+        "Trying to create a Type that was not registered in this MLIRContext.");
+  return *type;
+}
+
+AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
   auto &impl = context->getImpl();
   auto it = impl.registeredTypes.find(typeID);
   if (it == impl.registeredTypes.end())
-    llvm::report_fatal_error(
-        "Trying to create a Type that was not registered in this MLIRContext.");
-  return *it->second;
+    return nullptr;
+  return it->second;
 }
 
 //===----------------------------------------------------------------------===//
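
The `MLIRContext.cpp` change splits each lookup into a nullable `lookupMutable` and a checked `lookup` layered on top of it, so `attachInterface` can report its own error while existing callers keep the old abort-on-missing behaviour. A hypothetical sketch of that layering: an `int` stands in for `TypeID`, `Context` for `MLIRContextImpl`, and `std::runtime_error` replaces `llvm::report_fatal_error` so the failure path can be tested.

```c++
#include <cassert>
#include <stdexcept>
#include <unordered_map>

// Hypothetical stand-ins: an integer plays the role of TypeID, and
// AbstractAttribute only counts its attached interfaces.
using TypeID = int;
struct AbstractAttribute {
  int numInterfaces = 0;
};
struct Context {
  std::unordered_map<TypeID, AbstractAttribute *> registeredAttributes;
};

// Nullable, mutable lookup: callers such as attachInterface decide how to
// report a missing entry.
AbstractAttribute *lookupMutable(TypeID typeID, Context *context) {
  auto it = context->registeredAttributes.find(typeID);
  if (it == context->registeredAttributes.end())
    return nullptr;
  return it->second;
}

// Checked, read-only lookup built on top of lookupMutable. MLIR calls
// llvm::report_fatal_error here; the sketch throws so the failure is testable.
const AbstractAttribute &lookup(TypeID typeID, Context *context) {
  const AbstractAttribute *abstract = lookupMutable(typeID, context);
  if (!abstract)
    throw std::runtime_error("Trying to create an Attribute that was not "
                             "registered in this MLIRContext.");
  return *abstract;
}
```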

diff  --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 30fe52e150790..c5bda7ba72c8c 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -5,6 +5,8 @@ set(LLVM_OPTIONAL_SOURCES
 )
 
 set(LLVM_TARGET_DEFINITIONS TestInterfaces.td)
+mlir_tablegen(TestAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(TestAttrInterfaces.cpp.inc -gen-attr-interface-defs)
 mlir_tablegen(TestTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs)
 mlir_tablegen(TestOpInterfaces.h.inc -gen-op-interface-decls)

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 8dcc3498c9647..b6a2ca6145e27 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -93,6 +93,8 @@ void CompoundAAttr::print(DialectAsmPrinter &printer) const {
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//
 
+#include "TestAttrInterfaces.cpp.inc"
+
 #define GET_ATTRDEF_CLASSES
 #include "TestAttrDefs.cpp.inc"
 

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index 0eaa78eae5902..f1021141729d7 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -21,6 +21,8 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 
+#include "TestAttrInterfaces.h.inc"
+
 #define GET_ATTRDEF_CLASSES
 #include "TestAttrDefs.h.inc"
 

diff  --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index ef21d9f9413ec..3458abd7fe2d9 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -54,6 +54,40 @@ def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
   }];
 }
 
+def TestExternalTypeInterface : TypeInterface<"TestExternalTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<"Returns the bitwidth of the type plus 'arg'.",
+      "unsigned", "getBitwidthPlusArg", (ins "unsigned":$arg)>,
+    StaticInterfaceMethod<"Returns some value plus 'arg'.",
+      "unsigned", "staticGetSomeValuePlusArg", (ins "unsigned":$arg)>,
+    InterfaceMethod<"Returns the argument doubled.",
+      "unsigned", "getBitwidthPlusDoubleArgument", (ins "unsigned":$arg), "",
+      "return $_type.getIntOrFloatBitWidth() + 2 * arg;">,
+    StaticInterfaceMethod<"Returns the argument.",
+      "unsigned", "staticGetArgument", (ins "unsigned":$arg), "",
+      "return arg;">,
+  ];
+}
+
+def TestExternalFallbackTypeInterface
+    : TypeInterface<"TestExternalFallbackTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<"Returns the bitwidth of the given integer type.",
+      "unsigned", "getBitwidth", (ins), "", "return $_type.getWidth();">,
+  ];
+}
+
+def TestExternalAttrInterface : AttrInterface<"TestExternalAttrInterface"> {
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<"Gets the dialect pointer.", "const ::mlir::Dialect *",
+      "getDialectPtr">,
+    StaticInterfaceMethod<"Returns some number.", "int", "getSomeNumber">,
+  ];
+}
+
 def TestEffectOpInterface
     : EffectOpInterfaceBase<"TestEffectOpInterface",
                             "::mlir::TestEffects::Effect"> {

diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 11531957b955f..3fdaf6d3ebf9c 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -217,9 +217,13 @@ void InterfaceGenerator::emitConceptDecl(Interface &interface) {
 }
 
 void InterfaceGenerator::emitModelDecl(Interface &interface) {
+  // Emit the basic model and the fallback model.
   for (const char *modelClass : {"Model", "FallbackModel"}) {
     os << "  template<typename " << valueTemplate << ">\n";
     os << "  class " << modelClass << " : public Concept {\n  public:\n";
+    os << "    using Interface = " << interface.getCppNamespace()
+       << (interface.getCppNamespace().empty() ? "" : "::")
+       << interface.getName() << ";\n";
     os << "    " << modelClass << "() : Concept{";
     llvm::interleaveComma(
         interface.getMethods(), os,
@@ -236,6 +240,40 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) {
     }
     os << "  };\n";
   }
+
+  // Emit the template for the external model.
+  os << "  template<typename ConcreteModel, typename " << valueTemplate
+     << ">\n";
+  os << "  class ExternalModel : public FallbackModel<ConcreteModel> {\n";
+  os << "  public:\n";
+
+  // Emit declarations for methods that have default implementations. Other
+  // methods are expected to be implemented by the concrete derived model.
+  for (auto &method : interface.getMethods()) {
+    if (!method.getDefaultImplementation())
+      continue;
+    os << "    ";
+    if (method.isStatic())
+      os << "static ";
+    emitCPPType(method.getReturnType(), os);
+    os << method.getName() << "(";
+    if (!method.isStatic()) {
+      emitCPPType(valueType, os);
+      os << "tablegen_opaque_val";
+      if (!method.arg_empty())
+        os << ", ";
+    }
+    llvm::interleaveComma(method.getArguments(), os,
+                          [&](const InterfaceMethod::Argument &arg) {
+                            emitCPPType(arg.type, os);
+                            os << arg.name;
+                          });
+    os << ")";
+    if (!method.isStatic())
+      os << " const";
+    os << ";\n";
+  }
+  os << "  };\n";
 }
 
 void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
@@ -298,6 +336,42 @@ void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
     os << ");\n}\n";
   }
+
+  // Emit default implementations for the external model.
+  for (auto &method : interface.getMethods()) {
+    if (!method.getDefaultImplementation())
+      continue;
+    os << "template<typename ConcreteModel, typename " << valueTemplate
+       << ">\n";
+    emitCPPType(method.getReturnType(), os);
+    os << "detail::" << interface.getName()
+       << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
+       << ">::";
+
+    os << method.getName() << "(";
+    if (!method.isStatic()) {
+      emitCPPType(valueType, os);
+      os << "tablegen_opaque_val";
+      if (!method.arg_empty())
+        os << ", ";
+    }
+    llvm::interleaveComma(method.getArguments(), os,
+                          [&](const InterfaceMethod::Argument &arg) {
+                            emitCPPType(arg.type, os);
+                            os << arg.name;
+                          });
+    os << ")";
+    if (!method.isStatic())
+      os << " const";
+
+    os << " {\n";
+
+    // Use the empty context for static methods.
+    tblgen::FmtContext ctx;
+    os << tblgen::tgfmt(method.getDefaultImplementation()->trim(),
+                        method.isStatic() ? &ctx : &nonStaticMethodFmt);
+    os << "\n}\n";
+  }
 }
 
 void InterfaceGenerator::emitTraitDecl(Interface &interface,
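
The new `emitModelDecl` code declares an `ExternalModel` member only for methods with a default implementation; non-static ones take the opaque value as their first parameter and are `const`. A simplified, hypothetical re-creation of that loop follows: `MethodSpec`, `Argument` and `emitExternalModelDecl` are invented names standing in for the TableGen records and `emitCPPType`, and the explicit spaces replace the trailing space that `emitCPPType` appears to emit.

```c++
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical, simplified method description; the real generator reads
// tblgen::InterfaceMethod records.
struct Argument {
  std::string type, name;
};
struct MethodSpec {
  std::string returnType, name;
  std::vector<Argument> args;
  bool isStatic = false;
  bool hasDefault = false;
};

// Emits the ExternalModel declaration in the same shape as the patch.
std::string emitExternalModelDecl(const std::vector<MethodSpec> &methods,
                                  const std::string &valueType,
                                  const std::string &valueTemplate) {
  std::ostringstream os;
  os << "  template<typename ConcreteModel, typename " << valueTemplate
     << ">\n";
  os << "  class ExternalModel : public FallbackModel<ConcreteModel> {\n";
  os << "  public:\n";
  for (const MethodSpec &method : methods) {
    // Methods without a default are left to the concrete derived model.
    if (!method.hasDefault)
      continue;
    os << "    ";
    if (method.isStatic)
      os << "static ";
    os << method.returnType << " " << method.name << "(";
    bool first = true;
    if (!method.isStatic) {
      os << valueType << " tablegen_opaque_val";
      first = false;
    }
    for (const Argument &arg : method.args) {
      if (!first)
        os << ", ";
      os << arg.type << " " << arg.name;
      first = false;
    }
    os << ")";
    if (!method.isStatic)
      os << " const";
    os << ";\n";
  }
  os << "  };\n";
  return os.str();
}
```

Applied to `TestExternalTypeInterface`, only `getBitwidthPlusDoubleArgument` and `staticGetArgument` would be declared, matching the two methods that have default bodies in `TestInterfaces.td`.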

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 294432a0600a9..10c00fea9b8be 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -227,6 +227,8 @@ static void emitModelDecl(const Availability &availability, raw_ostream &os) {
        << "    }\n"
        << "  };\n";
   }
+  os << "  template<typename ConcreteModel, typename ConcreteOp>\n";
+  os << "  class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
 }
 
 static void emitInterfaceDecl(const Availability &availability,

diff  --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index f9cfed62714e1..44da5109a71af 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -1,11 +1,17 @@
 add_mlir_unittest(MLIRIRTests
   AttributeTest.cpp
   DialectTest.cpp
+  InterfaceAttachmentTest.cpp
   MemRefTypeTest.cpp
   OperationSupportTest.cpp
   ShapedTypeTest.cpp
   SubElementInterfaceTest.cpp
+
+  DEPENDS
+  MLIRTestInterfaceIncGen
 )
+target_include_directories(MLIRIRTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")
 target_link_libraries(MLIRIRTests
   PRIVATE
-  MLIRIR)
+  MLIRIR
+  MLIRTestDialect)

diff  --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
new file mode 100644
index 0000000000000..e8b48aee98c6e
--- /dev/null
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -0,0 +1,153 @@
+//===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This implements the tests for attaching interfaces to attributes and types
+// without having to specify them on the attribute or type class directly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+#include "../../test/lib/Dialect/Test/TestAttributes.h"
+#include "../../test/lib/Dialect/Test/TestTypes.h"
+
+using namespace mlir;
+using namespace mlir::test;
+
+namespace {
+
+/// External interface model for the integer type. Only provides non-default
+/// methods.
+struct Model
+    : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
+  unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
+    return type.getIntOrFloatBitWidth() + arg;
+  }
+
+  static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
+};
+
+/// External interface model for the float type. Provides non-default and
+/// overrides default methods.
+struct OverridingModel
+    : public TestExternalTypeInterface::ExternalModel<OverridingModel,
+                                                      FloatType> {
+  unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
+    return type.getIntOrFloatBitWidth() + arg;
+  }
+
+  static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
+
+  unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
+    return 128;
+  }
+
+  static unsigned staticGetArgument(unsigned arg) { return 420; }
+};
+
+TEST(InterfaceAttachment, Type) {
+  MLIRContext context;
+
+  // Check that the type has no interface.
+  IntegerType i8 = IntegerType::get(&context, 8);
+  ASSERT_FALSE(i8.isa<TestExternalTypeInterface>());
+
+  // Attach an interface and check that the type now has the interface.
+  IntegerType::attachInterface<Model>(context);
+  TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>();
+  ASSERT_TRUE(iface != nullptr);
+  EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
+  EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
+  EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
+  EXPECT_EQ(iface.staticGetArgument(17), 17u);
+
+  // Same, but with the default implementation overridden.
+  FloatType flt = Float32Type::get(&context);
+  ASSERT_FALSE(flt.isa<TestExternalTypeInterface>());
+  Float32Type::attachInterface<OverridingModel>(context);
+  iface = flt.dyn_cast<TestExternalTypeInterface>();
+  ASSERT_TRUE(iface != nullptr);
+  EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
+  EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
+  EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
+  EXPECT_EQ(iface.staticGetArgument(17), 420u);
+
+  // Other contexts shouldn't have the interface attached.
+  MLIRContext other;
+  IntegerType i8other = IntegerType::get(&other, 8);
+  EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
+}
+
+/// The interface provides a default implementation that expects
+/// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
+/// just derives from the ExternalModel.
+struct TestExternalFallbackTypeIntegerModel
+    : public TestExternalFallbackTypeInterface::ExternalModel<
+          TestExternalFallbackTypeIntegerModel, IntegerType> {};
+
+/// The interface provides a default implementation that expects
+/// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
+/// FallbackModel instead to override this and make sure the code still compiles
+/// because we never instantiate the ExternalModel class template with a
+/// template argument that would have led to compilation failures.
+struct TestExternalFallbackTypeVectorModel
+    : public TestExternalFallbackTypeInterface::FallbackModel<
+          TestExternalFallbackTypeVectorModel> {
+  unsigned getBitwidth(Type type) const {
+    IntegerType elementType = type.cast<VectorType>()
+                                  .getElementType()
+                                  .dyn_cast_or_null<IntegerType>();
+    return elementType ? elementType.getWidth() : 0;
+  }
+};
+
+TEST(InterfaceAttachment, Fallback) {
+  MLIRContext context;
+
+  // Just check that we can attach the interface.
+  IntegerType i8 = IntegerType::get(&context, 8);
+  ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>());
+  IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
+  ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>());
+
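+  // Call the method to make sure the FallbackModel override compiles and is
+  // used; the ExternalModel default, which needs getWidth(), is never
+  // instantiated for VectorType.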
+  VectorType vec = VectorType::get({42}, i8);
+  ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>());
+  VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
+  ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>());
+  EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u);
+}
+
+/// External model for attribute interfaces.
+struct TextExternalIntegerAttrModel
+    : public TestExternalAttrInterface::ExternalModel<
+          TextExternalIntegerAttrModel, IntegerAttr> {
+  const Dialect *getDialectPtr(Attribute attr) const {
+    return &attr.cast<IntegerAttr>().getDialect();
+  }
+
+  static int getSomeNumber() { return 42; }
+};
+
+TEST(InterfaceAttachment, Attribute) {
+  MLIRContext context;
+
+  // Attribute interfaces use the exact same mechanism as types, so just check
+  // that the basics work for attributes.
+  IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
+  ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
+  IntegerAttr::attachInterface<TextExternalIntegerAttrModel>(context);
+  auto iface = attr.dyn_cast<TestExternalAttrInterface>();
+  ASSERT_TRUE(iface != nullptr);
+  EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
+  EXPECT_EQ(iface.getSomeNumber(), 42);
+}
+
+} // end namespace


        

