[Mlir-commits] [mlir] 88f25bd - [mlir] Allow for using interface class name in ODS interface definitions

Alex Zinenko llvmlistbot at llvm.org
Tue Nov 17 05:29:04 PST 2020


Author: Alex Zinenko
Date: 2020-11-17T14:28:55+01:00
New Revision: 88f25bda1376b68631106c0e1c5cbe3f385204e0

URL: https://github.com/llvm/llvm-project/commit/88f25bda1376b68631106c0e1c5cbe3f385204e0
DIFF: https://github.com/llvm/llvm-project/commit/88f25bda1376b68631106c0e1c5cbe3f385204e0.diff

LOG: [mlir] Allow for using interface class name in ODS interface definitions

It may be necessary for interface methods to take or return values of the
interface class type, in particular for attribute and type interfaces whose
methods can return modified attributes and types that implement the same
interface.
However, the code generated by ODS in this case would not compile: the method
signature (and the body, if provided) appears in the definition of the Model
class, which is emitted before the interface class that derives from the Model.
Change the ODS interface method generator to emit only method declarations in
the Model class itself, and emit the method definitions after the interface
class. Mark the methods as "inline" since their definitions are still emitted
in the header but are no longer implicitly inline. Add a forward declaration of
the interface class before the Concept and Model classes to make the class name
usable in declarations.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D91499

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestInterfaces.td
    mlir/test/lib/IR/TestInterfaces.cpp
    mlir/test/mlir-tblgen/interfaces.mlir
    mlir/test/mlir-tblgen/op-interface.td
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index 08f07cf3a11d..19a779d0a81c 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -28,6 +28,15 @@ def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
     InterfaceMethod<"Prints the type name.",
       "void", "printTypeC", (ins "Location":$loc)
     >,
+    // It should be possible to use the interface type name as result type
+    // as well as in the implementation.
+    InterfaceMethod<"Prints the type name and returns the type as interface.",
+      "TestTypeInterface", "printTypeRet", (ins "Location":$loc),
+      [{}], /*defaultImplementation=*/[{
+        emitRemark(loc) << $_type << " - TestRet";
+        return $_type;
+      }]
+    >,
   ];
   let extraClassDeclaration = [{
     /// Prints the type name.

diff  --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
index 369001f3540a..3a6e10d40e93 100644
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -25,6 +25,10 @@ struct TestTypeInterfaces
           testInterface.printTypeB(op->getLoc());
           testInterface.printTypeC(op->getLoc());
           testInterface.printTypeD(op->getLoc());
+          // Just check that we can assign the result to a variable of interface
+          // type.
+          TestTypeInterface result = testInterface.printTypeRet(op->getLoc());
+          (void)result;
         }
         if (auto testType = type.dyn_cast<TestType>())
           testType.printTypeE(op->getLoc());

diff  --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir
index 712d93499152..5c1ec613b387 100644
--- a/mlir/test/mlir-tblgen/interfaces.mlir
+++ b/mlir/test/mlir-tblgen/interfaces.mlir
@@ -4,6 +4,7 @@
 // expected-remark@below {{'!test.test_type' - TestB}}
 // expected-remark@below {{'!test.test_type' - TestC}}
 // expected-remark@below {{'!test.test_type' - TestD}}
+// expected-remark@below {{'!test.test_type' - TestRet}}
 // expected-remark@below {{'!test.test_type' - TestE}}
 %foo0 = "foo.test"() : () -> (!test.test_type)
 

diff  --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 4ca2b798c3c9..7f5ae6c5cb85 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -41,9 +41,11 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
 
 // DECL-LABEL: TestOpInterfaceInterfaceTraits
 // DECL: class TestOpInterface : public ::mlir::OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
+
 // DECL: int foo(int input);
 
-// DECL-NOT: TestOpInterface
+// DECL: template<typename ConcreteOp>
+// DECL: int detail::TestOpInterfaceInterfaceTraits::Model<ConcreteOp>::foo
 
 // OP_DECL-LABEL: class DeclareMethodsOp : public
 // OP_DECL: int foo(int input);

diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 1b0cb273638f..1a8f6b78575b 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -82,6 +82,7 @@ class InterfaceGenerator {
 
   void emitConceptDecl(Interface &interface);
   void emitModelDecl(Interface &interface);
+  void emitModelMethodsDef(Interface &interface);
   void emitTraitDecl(Interface &interface, StringRef interfaceName,
                      StringRef interfaceTraitsName);
   void emitInterfaceDecl(Interface interface);
@@ -217,11 +218,25 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) {
 
   // Insert each of the virtual method overrides.
   for (auto &method : interface.getMethods()) {
-    emitCPPType(method.getReturnType(), os << "    static ");
+    emitCPPType(method.getReturnType(), os << "    static inline ");
     emitMethodNameAndArgs(method, os, valueType,
                           /*addThisArg=*/!method.isStatic(),
                           /*addConst=*/false);
-    os << " {\n      ";
+    os << ";\n";
+  }
+  os << "  };\n";
+}
+
+void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
+  for (auto &method : interface.getMethods()) {
+    os << "template<typename " << valueTemplate << ">\n";
+    emitCPPType(method.getReturnType(), os);
+    os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
+       << valueTemplate << ">::";
+    emitMethodNameAndArgs(method, os, valueType,
+                          /*addThisArg=*/!method.isStatic(),
+                          /*addConst=*/false);
+    os << " {\n  ";
 
     // Check for a provided body to the function.
     if (Optional<StringRef> body = method.getBody()) {
@@ -229,7 +244,7 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) {
         os << body->trim();
       else
         os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
-      os << "\n    }\n";
+      os << "\n}\n";
       continue;
     }
 
@@ -244,9 +259,8 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) {
     llvm::interleaveComma(
         method.getArguments(), os,
         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
-    os << ");\n    }\n";
+    os << ");\n}\n";
   }
-  os << "  };\n";
 }
 
 void InterfaceGenerator::emitTraitDecl(Interface &interface,
@@ -308,6 +322,10 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
   StringRef interfaceName = interface.getName();
   auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
 
+  // Emit a forward declaration of the interface class so that it becomes usable
+  // in the signature of its methods.
+  os << "class " << interfaceName << ";\n";
+
   // Emit the traits struct containing the concept and model declarations.
   os << "namespace detail {\n"
      << "struct " << interfaceTraitsName << " {\n";
@@ -340,6 +358,8 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
 
   os << "};\n";
 
+  emitModelMethodsDef(interface);
+
   for (StringRef ns : llvm::reverse(namespaces))
     os << "} // namespace " << ns << "\n";
 }


        


More information about the Mlir-commits mailing list