[Mlir-commits] [mlir] d905c10 - Add a mechanism for Dialects to provide a fallback for OpInterface

Mehdi Amini llvmlistbot at llvm.org
Wed Mar 24 01:41:51 PDT 2021


Author: Mehdi Amini
Date: 2021-03-24T08:41:40Z
New Revision: d905c1035395fb27a5599f3da20ade6874718549

URL: https://github.com/llvm/llvm-project/commit/d905c1035395fb27a5599f3da20ade6874718549
DIFF: https://github.com/llvm/llvm-project/commit/d905c1035395fb27a5599f3da20ade6874718549.diff

LOG: Add a mechanism for Dialects to provide a fallback for OpInterface

This mechanism makes it possible for a dialect to not register all
operations but still answer interface-based queries.
This can be useful for dialects that are "open" or connected to an external
system and still interoperate with the compiler. It can also open up the
possibility to have a more extensible compiler at runtime: the compiler
does not need a pre-registration for each operation and the dialect can
inject behavior dynamically.

Reviewed By: rriddle, jpienaar

Differential Revision: https://reviews.llvm.org/D93085

Added: 
    

Modified: 
    mlir/docs/Interfaces.md
    mlir/include/mlir/IR/Dialect.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/SymbolInterfaces.td
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/include/mlir/TableGen/Dialect.h
    mlir/lib/TableGen/Dialect.cpp
    mlir/test/IR/test-side-effects.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-tblgen/DialectGen.cpp
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 7f2920bae775b..927e1d89a1f5f 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -207,6 +207,45 @@ if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
   llvm::errs() << "hook returned = " << example.exampleInterfaceHook() << "\n";
 ```
 
+#### Dialect Fallback for OpInterface
+
+Some dialects have an open ecosystem and don't register all of the possible
+operations. In such cases it is still possible to provide support for
+implementing an `OpInterface` for these operations. When an operation isn't
+registered or does not provide an implementation for an interface, the query
+will fall back to the dialect itself.
+
+A second model is used for such cases and is automatically generated when
+using ODS (see below) with the name `FallbackModel`. This model can be implemented
+for a particular dialect:
+
+```c++
+// This is the implementation of a dialect fallback for `ExampleOpInterface`.
+struct FallbackExampleOpInterface
+    : public ExampleOpInterface::FallbackModel<
+          FallbackExampleOpInterface> {
+  static bool classof(Operation *op) { return true; }
+
+  unsigned exampleInterfaceHook(Operation *op) const;
+  unsigned exampleStaticInterfaceHook() const;
+};
+```
+
+A dialect can then instantiate this implementation and return it for specific
+operations by overriding the `getRegisteredInterfaceForOp` method:
+
+```c++
+void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
+                                               Identifier opName) {
+  if (typeID == TypeID::get<ExampleOpInterface>()) {
+    if (isSupported(opName))
+      return fallbackExampleOpInterface;
+    return nullptr;
+  }
+  return nullptr;
+}
+```
+
 #### Utilizing the ODS Framework
 
 Note: Before reading this section, the reader should have some familiarity with

diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 2fdbdc4829833..d3af95521d74b 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -158,6 +158,19 @@ class Dialect {
         getRegisteredInterface(InterfaceT::getInterfaceID()));
   }
 
+  /// Lookup an op interface for the given ID if one is registered, otherwise
+  /// nullptr.
+  virtual void *getRegisteredInterfaceForOp(TypeID interfaceID,
+                                            OperationName opName) {
+    return nullptr;
+  }
+  template <typename InterfaceT>
+  typename InterfaceT::Concept *
+  getRegisteredInterfaceForOp(OperationName opName) {
+    return static_cast<typename InterfaceT::Concept *>(
+        getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
+  }
+
 protected:
   /// The constructor takes a unique namespace for this dialect as well as the
   /// context to bind to.

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 88f1427a19226..2785dff7f591f 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -272,6 +272,9 @@ class Dialect {
 
   // If this dialect overrides the hook for verifying region result attributes.
   bit hasRegionResultAttrVerify = 0;
+
+  // If this dialect overrides the hook for op interface fallback.
+  bit hasOperationInterfaceFallback = 0;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index ec3884e58fc37..b101370fca8e2 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -19,6 +19,7 @@
 #ifndef MLIR_IR_OPDEFINITION_H
 #define MLIR_IR_OPDEFINITION_H
 
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
@@ -1721,7 +1722,20 @@ class OpInterface
   static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
     // Access the raw interface from the abstract operation.
     auto *abstractOp = op->getAbstractOperation();
-    return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
+    if (abstractOp) {
+      if (auto *opIface = abstractOp->getInterface<ConcreteType>())
+        return opIface;
+      // Fallback to the dialect to provide it with a chance to implement this
+      // interface for this operation.
+      return abstractOp->dialect.getRegisteredInterfaceForOp<ConcreteType>(
+          op->getName());
+    }
+    // Fallback to the dialect to provide it with a chance to implement this
+    // interface for this operation.
+    Dialect *dialect = op->getName().getDialect();
+    return dialect ? dialect->getRegisteredInterfaceForOp<ConcreteType>(
+                         op->getName())
+                   : nullptr;
   }
 
   /// Allow access to `getInterfaceFor`.

diff  --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index a7b1fd8cfe643..68483c9677db0 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -181,7 +181,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
       auto *concept = getInterfaceFor(op);
       if (!concept)
         return false;
-      return !concept->isOptionalSymbol(op) ||
+      return !concept->isOptionalSymbol(concept, op) ||
              op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
     }
   }];

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index d7a455722c773..6fc61170a5923 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -71,6 +71,8 @@ class Interface : public BaseType {
 public:
   using Concept = typename Traits::Concept;
   template <typename T> using Model = typename Traits::template Model<T>;
+  template <typename T>
+  using FallbackModel = typename Traits::template FallbackModel<T>;
   using InterfaceBase =
       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
 

diff  --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index ee86a2504b3c9..35a9e7ba4c9b6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -63,6 +63,9 @@ class Dialect {
   /// Returns true if this dialect has a region result attribute verifier.
   bool hasRegionResultAttrVerify() const;
 
+  /// Returns true if this dialect has fallback interfaces for its operations.
+  bool hasOperationInterfaceFallback() const;
+
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;

diff  --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 0c1de78ce60e1..6b88f8a40bc1c 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -77,6 +77,10 @@ bool Dialect::hasRegionResultAttrVerify() const {
   return def->getValueAsBit("hasRegionResultAttrVerify");
 }
 
+bool Dialect::hasOperationInterfaceFallback() const {
+  return def->getValueAsBit("hasOperationInterfaceFallback");
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }

diff  --git a/mlir/test/IR/test-side-effects.mlir b/mlir/test/IR/test-side-effects.mlir
index db55414da03be..947faff99d265 100644
--- a/mlir/test/IR/test-side-effects.mlir
+++ b/mlir/test/IR/test-side-effects.mlir
@@ -30,3 +30,9 @@
 %4 = "test.side_effect_op"() {
   effect_parameter = affine_map<(i, j) -> (j, i)>
 } : () -> i32
+
+// Check with this unregistered operation to test the fallback on the dialect.
+// expected-remark@+1 {{found a parametric effect with affine_map<(d0, d1) -> (d1, d0)>}}
+%5 = "test.unregistered_side_effect_op"() {
+  effect_parameter = affine_map<(i, j) -> (j, i)>
+} : () -> i32

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 3bb4e8f4a6236..c7b4b582d9dd2 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -166,6 +166,29 @@ struct TestInlinerInterface : public DialectInlinerInterface {
 // TestDialect
 //===----------------------------------------------------------------------===//
 
+static void testSideEffectOpGetEffect(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
+
+// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
+struct TestOpEffectInterfaceFallback
+    : public TestEffectOpInterface::FallbackModel<
+          TestOpEffectInterfaceFallback> {
+  static bool classof(Operation *op) {
+    bool isSupportedOp =
+        op->getName().getStringRef() == "test.unregistered_side_effect_op";
+    assert(isSupportedOp && "Unexpected dispatch");
+    return isSupportedOp;
+  }
+
+  void
+  getEffects(Operation *op,
+             SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
+                 &effects) const {
+    testSideEffectOpGetEffect(op, effects);
+  }
+};
+
 void TestDialect::initialize() {
   registerAttributes();
   registerTypes();
@@ -176,6 +199,14 @@ void TestDialect::initialize() {
   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                 TestInlinerInterface>();
   allowUnknownOperations();
+
+  // Instantiate our fallback op interface that we'll use on specific
+  // unregistered op.
+  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
+}
+TestDialect::~TestDialect() {
+  delete static_cast<TestOpEffectInterfaceFallback *>(
+      fallbackEffectOpInterfaces);
 }
 
 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
@@ -183,6 +214,14 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return builder.create<TestOpConstant>(loc, type, value);
 }
 
+void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
+                                               OperationName opName) {
+  if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
+      typeID == TypeID::get<TestEffectOpInterface>())
+    return fallbackEffectOpInterfaces;
+  return nullptr;
+}
+
 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute namedAttr) {
   if (namedAttr.first == "test.invalid_attr")
@@ -716,6 +755,17 @@ struct TestResource : public SideEffects::Resource::Base<TestResource> {
 };
 } // end anonymous namespace
 
+static void testSideEffectOpGetEffect(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
+        &effects) {
+  auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
+  if (!effectsAttr)
+    return;
+
+  effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
+}
+
 void SideEffectOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   // Check for an effects attribute on the op instance.
@@ -754,11 +804,7 @@ void SideEffectOp::getEffects(
 
 void SideEffectOp::getEffects(
     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
-  auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
-  if (!effectsAttr)
-    return;
-
-  effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
+  testSideEffectOpGetEffect(getOperation(), effects);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8be84f2aacbc9..be53ee230f4aa 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -29,6 +29,7 @@ def Test_Dialect : Dialect {
   let hasOperationAttrVerify = 1;
   let hasRegionArgAttrVerify = 1;
   let hasRegionResultAttrVerify = 1;
+  let hasOperationInterfaceFallback = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
 
   let extraClassDeclaration = [{
@@ -45,6 +46,12 @@ def Test_Dialect : Dialect {
       getParseOperationHook(StringRef opName) const override;
     LogicalResult printOperation(Operation *op,
                                  OpAsmPrinter &printer) const override;
+
+    ~TestDialect();
+  private:
+    // Storage for a custom fallback interface.
+    void *fallbackEffectOpInterfaces;
+
   }];
 }
 

diff  --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 8eaf17cd7c3fd..151519519afda 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -143,6 +143,13 @@ static const char *const regionResultAttrVerifierDecl = R"(
         ::mlir::NamedAttribute attribute) override;
 )";
 
+/// The code block for the op interface fallback hook.
+static const char *const operationInterfaceFallbackDecl = R"(
+    /// Provides a hook for op interface.
+    void *getRegisteredInterfaceForOp(mlir::TypeID interfaceID,
+                                      mlir::OperationName opName) override;
+)";
+
 /// Generate the declaration for the given dialect class.
 static void emitDialectDecl(Dialect &dialect,
                             iterator_range<DialectFilterIterator> dialectAttrs,
@@ -181,6 +188,8 @@ static void emitDialectDecl(Dialect &dialect,
     os << regionArgAttrVerifierDecl;
   if (dialect.hasRegionResultAttrVerify())
     os << regionResultAttrVerifierDecl;
+  if (dialect.hasOperationInterfaceFallback())
+    os << operationInterfaceFallbackDecl;
   if (llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
     os << *extraDecl;
 

diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2f8f22ca3f225..d7692d2b6f5e1 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -43,9 +43,13 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
                                   raw_ostream &os, StringRef valueType,
                                   bool addThisArg, bool addConst) {
   os << method.getName() << '(';
-  if (addThisArg)
+  if (addThisArg) {
+    if (addConst)
+      os << "const ";
+    os << "const Concept *impl, ";
     emitCPPType(valueType, os)
         << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
+  }
   llvm::interleaveComma(method.getArguments(), os,
                         [&](const InterfaceMethod::Argument &arg) {
                           os << arg.type << " " << arg.name;
@@ -124,7 +128,9 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
     interfaceBaseType = "OpInterface";
     valueTemplate = "ConcreteOp";
     StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
-    nonStaticMethodFmt.withOp(castCode).withSelf(castCode);
+    nonStaticMethodFmt.addSubst("_this", "impl")
+        .withOp(castCode)
+        .withSelf(castCode);
     traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
   }
 };
@@ -167,6 +173,7 @@ static void emitInterfaceDef(Interface interface, StringRef valueType,
     // Forward to the method on the concrete operation type.
     os << " {\n      return getImpl()->" << method.getName() << '(';
     if (!method.isStatic()) {
+      os << "getImpl(), ";
       os << (isOpInterface ? "getOperation()" : "*this");
       os << (method.arg_empty() ? "" : ", ");
     }
@@ -197,8 +204,10 @@ void InterfaceGenerator::emitConceptDecl(Interface &interface) {
     os << "    ";
     emitCPPType(method.getReturnType(), os);
     os << "(*" << method.getName() << ")(";
-    if (!method.isStatic())
+    if (!method.isStatic()) {
+      os << "const Concept *impl, ";
       emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
+    }
     llvm::interleaveComma(
         method.getArguments(), os,
         [&](const InterfaceMethod::Argument &arg) { os << arg.type; });
@@ -208,23 +217,25 @@ void InterfaceGenerator::emitConceptDecl(Interface &interface) {
 }
 
 void InterfaceGenerator::emitModelDecl(Interface &interface) {
-  os << "  template<typename " << valueTemplate << ">\n";
-  os << "  class Model : public Concept {\n  public:\n";
-  os << "    Model() : Concept{";
-  llvm::interleaveComma(
-      interface.getMethods(), os,
-      [&](const InterfaceMethod &method) { os << method.getName(); });
-  os << "} {}\n\n";
-
-  // Insert each of the virtual method overrides.
-  for (auto &method : interface.getMethods()) {
-    emitCPPType(method.getReturnType(), os << "    static inline ");
-    emitMethodNameAndArgs(method, os, valueType,
-                          /*addThisArg=*/!method.isStatic(),
-                          /*addConst=*/false);
-    os << ";\n";
+  for (const char *modelClass : {"Model", "FallbackModel"}) {
+    os << "  template<typename " << valueTemplate << ">\n";
+    os << "  class " << modelClass << " : public Concept {\n  public:\n";
+    os << "    " << modelClass << "() : Concept{";
+    llvm::interleaveComma(
+        interface.getMethods(), os,
+        [&](const InterfaceMethod &method) { os << method.getName(); });
+    os << "} {}\n\n";
+
+    // Insert each of the virtual method overrides.
+    for (auto &method : interface.getMethods()) {
+      emitCPPType(method.getReturnType(), os << "    static inline ");
+      emitMethodNameAndArgs(method, os, valueType,
+                            /*addThisArg=*/!method.isStatic(),
+                            /*addConst=*/false);
+      os << ";\n";
+    }
+    os << "  };\n";
   }
-  os << "  };\n";
 }
 
 void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
@@ -261,6 +272,32 @@ void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
     os << ");\n}\n";
   }
+
+  for (auto &method : interface.getMethods()) {
+    os << "template<typename " << valueTemplate << ">\n";
+    emitCPPType(method.getReturnType(), os);
+    os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
+       << valueTemplate << ">::";
+    emitMethodNameAndArgs(method, os, valueType,
+                          /*addThisArg=*/!method.isStatic(),
+                          /*addConst=*/false);
+    os << " {\n  ";
+
+    // Forward to the method on the concrete Model implementation.
+    if (method.isStatic())
+      os << "return " << valueTemplate << "::";
+    else
+      os << "return static_cast<const " << valueTemplate << " *>(impl)->";
+
+    // Add the arguments to the call.
+    os << method.getName() << '(';
+    if (!method.isStatic())
+      os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
+    llvm::interleaveComma(
+        method.getArguments(), os,
+        [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
+    os << ");\n}\n";
+  }
 }
 
 void InterfaceGenerator::emitTraitDecl(Interface &interface,

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index bb02a7993e43b..294432a0600a9 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -171,7 +171,7 @@ static void emitInterfaceDef(const Availability &availability,
   StringRef methodName = availability.getQueryFnName();
   os << availability.getQueryFnRetType() << " "
      << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
-     << "  return getImpl()->" << methodName << "(getOperation());\n"
+     << "  return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
      << "}\n";
 }
 
@@ -208,23 +208,25 @@ static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
      << "    virtual ~Concept() = default;\n"
      << "    virtual " << availability.getQueryFnRetType() << " "
      << availability.getQueryFnName()
-     << "(Operation *tblgen_opaque_op) const = 0;\n"
+     << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
      << "  };\n";
 }
 
 static void emitModelDecl(const Availability &availability, raw_ostream &os) {
-  os << "  template<typename ConcreteOp>\n";
-  os << "  class Model : public Concept {\n"
-     << "  public:\n"
-     << "    " << availability.getQueryFnRetType() << " "
-     << availability.getQueryFnName()
-     << "(Operation *tblgen_opaque_op) const final {\n"
-     << "      auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
-     << "      (void)op;\n"
-     // Forward to the method on the concrete operation type.
-     << "      return op." << availability.getQueryFnName() << "();\n"
-     << "    }\n"
-     << "  };\n";
+  for (const char *modelClass : {"Model", "FallbackModel"}) {
+    os << "  template<typename ConcreteOp>\n";
+    os << "  class " << modelClass << " : public Concept {\n"
+       << "  public:\n"
+       << "    " << availability.getQueryFnRetType() << " "
+       << availability.getQueryFnName()
+       << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
+       << "      auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
+       << "      (void)op;\n"
+       // Forward to the method on the concrete operation type.
+       << "      return op." << availability.getQueryFnName() << "();\n"
+       << "    }\n"
+       << "  };\n";
+  }
 }
 
 static void emitInterfaceDecl(const Availability &availability,


        


More information about the Mlir-commits mailing list