[llvm-branch-commits] [mlir] 2452334 - [MLIR] Generate inferReturnTypes declaration using InferTypeOpInterface trait.

Rahul Joshi via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Dec 4 09:28:11 PST 2020


Author: Rahul Joshi
Date: 2020-12-04T09:05:53-08:00
New Revision: 245233423e466979e11b39cbed676903892d07f8

URL: https://github.com/llvm/llvm-project/commit/245233423e466979e11b39cbed676903892d07f8
DIFF: https://github.com/llvm/llvm-project/commit/245233423e466979e11b39cbed676903892d07f8.diff

LOG: [MLIR] Generate inferReturnTypes declaration using InferTypeOpInterface trait.

- Instead of hardcoding the parameters and return type of 'inferReturnTypes', use the
  InferTypeOpInterface trait to generate the method signature.
- Fix InferTypeOpInterface to use a fully qualified return type (::mlir::LogicalResult) for inferReturnTypes.

Differential Revision: https://reviews.llvm.org/D92585

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index ed62e9015bde..9de087b1b4ca 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -36,7 +36,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
       which an Operation would be created (e.g., as used in Operation::create)
       and the regions of the op.
       }],
-      /*retTy=*/"LogicalResult",
+      /*retTy=*/"::mlir::LogicalResult",
       /*methodName=*/"inferReturnTypes",
       /*args=*/(ins "::mlir::MLIRContext *":$context,
                     "::llvm::Optional<::mlir::Location>":$location,

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index c96fde648eb2..ccfb13fa3436 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -290,11 +290,16 @@ class OpEmitter {
   // Generates the traits used by the object.
   void genTraits();
 
-  // Generate the OpInterface methods.
+  // Generate the OpInterface methods for all interfaces.
   void genOpInterfaceMethods();
 
-  // Generate op interface method.
-  void genOpInterfaceMethod(const tblgen::InterfaceOpTrait *trait);
+  // Generate op interface methods for the given interface.
+  void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
+
+  // Generate op interface method for the given interface method. If
+  // 'declaration' is true, generates a declaration, else a definition.
+  OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
+                                 bool declaration = true);
 
   // Generate the side effect interface methods.
   void genSideEffectInterfaceMethods();
@@ -1588,7 +1593,7 @@ void OpEmitter::genFolderDecls() {
   }
 }
 
-void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
+void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
   auto interface = opTrait->getOpInterface();
 
   // Get the set of methods that should always be declared.
@@ -1606,23 +1611,29 @@ void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
     if (method.getDefaultImplementation() &&
         !alwaysDeclaredMethods.count(method.getName()))
       continue;
-
-    SmallVector<OpMethodParameter, 4> paramList;
-    for (const InterfaceMethod::Argument &arg : method.getArguments())
-      paramList.emplace_back(arg.type, arg.name);
-
-    auto properties = method.isStatic() ? OpMethod::MP_StaticDeclaration
-                                        : OpMethod::MP_Declaration;
-    opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
-                              properties, std::move(paramList));
+    genOpInterfaceMethod(method);
   }
 }
 
+OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
+                                          bool declaration) {
+  SmallVector<OpMethodParameter, 4> paramList;
+  for (const InterfaceMethod::Argument &arg : method.getArguments())
+    paramList.emplace_back(arg.type, arg.name);
+
+  auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
+  if (declaration)
+    properties =
+        static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
+  return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
+                                   properties, std::move(paramList));
+}
+
 void OpEmitter::genOpInterfaceMethods() {
   for (const auto &trait : op.getTraits()) {
     if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
       if (opTrait->shouldDeclareMethods())
-        genOpInterfaceMethod(opTrait);
+        genOpInterfaceMethods(opTrait);
   }
 }
 
@@ -1727,18 +1738,20 @@ void OpEmitter::genSideEffectInterfaceMethods() {
 void OpEmitter::genTypeInterfaceMethods() {
   if (!op.allResultTypesKnown())
     return;
-
-  SmallVector<OpMethodParameter, 4> paramList;
-  paramList.emplace_back("::mlir::MLIRContext *", "context");
-  paramList.emplace_back("::llvm::Optional<::mlir::Location>", "location");
-  paramList.emplace_back("::mlir::ValueRange", "operands");
-  paramList.emplace_back("::mlir::DictionaryAttr", "attributes");
-  paramList.emplace_back("::mlir::RegionRange", "regions");
-  paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::Type>&",
-                         "inferredReturnTypes");
-  auto *method =
-      opClass.addMethodAndPrune("::mlir::LogicalResult", "inferReturnTypes",
-                                OpMethod::MP_Static, std::move(paramList));
+  // Generate 'inferReturnTypes' method declaration using the interface method
+  // declared in 'InferTypeOpInterface' op interface.
+  const auto *trait = dyn_cast<InterfaceOpTrait>(
+      op.getTrait("::mlir::InferTypeOpInterface::Trait"));
+  auto interface = trait->getOpInterface();
+  OpMethod *method = [&]() -> OpMethod * {
+    for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
+      if (interfaceMethod.getName() == "inferReturnTypes") {
+        return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
+      }
+    }
+    assert(0 && "unable to find inferReturnTypes interface method");
+    return nullptr;
+  }();
   auto &body = method->body();
   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
 


        


More information about the llvm-branch-commits mailing list