[Mlir-commits] [mlir] 91dae57 - [mlir][DeclareOpInterfaceMethods] Allow specifying a set of methods to force declaration generation for.

River Riddle llvmlistbot at llvm.org
Wed Apr 29 16:49:46 PDT 2020


Author: River Riddle
Date: 2020-04-29T16:48:15-07:00
New Revision: 91dae5708708c0c0b3e2383b419005bfe0402ae0

URL: https://github.com/llvm/llvm-project/commit/91dae5708708c0c0b3e2383b419005bfe0402ae0
DIFF: https://github.com/llvm/llvm-project/commit/91dae5708708c0c0b3e2383b419005bfe0402ae0.diff

LOG: [mlir][DeclareOpInterfaceMethods] Allow specifying a set of methods to force declaration generation for.

Currently a declaration won't be generated if the method has a default implementation, meaning that operations that want to override the default have to explicitly declare the method in the extraClassDeclarations. This revision adds an optional list parameter to DeclareOpInterfaceMethods to allow for specifying a set of methods that should always have their declarations generated, even if there is a default.

Differential Revision: https://reviews.llvm.org/D79030

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/OpTrait.h
    mlir/lib/TableGen/OpTrait.cpp
    mlir/test/mlir-tblgen/op-interface.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index d323121ac758..8c785f65ed97 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -442,7 +442,7 @@ def MyInterface : OpInterface<"MyInterface"> {
     // Provide only a default definition of the method.
     // Note: `ConcreteOp` corresponds to the derived operation typename.
     InterfaceMethod<"/*insert doc here*/",
-      "unsigned", "getNumInputsAndOutputs", (ins), /*methodBody=*/[{}], [{
+      "unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{
         ConcreteOp op = cast<ConcreteOp>(getOperation());
         return op.getNumInputs() + op.getNumOutputs();
     }]>,
@@ -455,6 +455,13 @@ def MyInterface : OpInterface<"MyInterface"> {
 // declaration but instead handled by the op interface trait directly.
 def OpWithInferTypeInterfaceOp : Op<...
     [DeclareOpInterfaceMethods<MyInterface>]> { ... }
+
+// Methods that have a default implementation do not have declarations
+// generated. If an operation wishes to override the default behavior, it can
+// explicitly specify the method that it wishes to override. This will force
+// the generation of a declaration for those methods.
+def OpWithOverrideInferTypeInterfaceOp : Op<...
+    [DeclareOpInterfaceMethods<MyInterface, ["getNumWithDefault"]>]> { ... }
 ```
 
 A verification method can also be specified on the `OpInterface` by setting

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 7db3dd849c81..48ed99051642 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -562,7 +562,8 @@ def AtomicYieldOp : Std_Op<"atomic_yield", [
 //===----------------------------------------------------------------------===//
 
 def BranchOp : Std_Op<"br",
-    [DeclareOpInterfaceMethods<BranchOpInterface>, NoSideEffect, Terminator]> {
+    [DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     NoSideEffect, Terminator]> {
   let summary = "branch operation";
   let description = [{
     The `br` operation represents a branch operation in a function.
@@ -598,10 +599,6 @@ def BranchOp : Std_Op<"br",
 
     /// Erase the operand at 'index' from the operand list.
     void eraseOperand(unsigned index);
-
-    /// Returns the successor that would be chosen with the given constant
-    /// operands. Returns nullptr if a single successor could not be chosen.
-    Block *getSuccessorForOperands(ArrayRef<Attribute>);
   }];
 
   let hasCanonicalizer = 1;
@@ -991,7 +988,8 @@ def CmpIOp : Std_Op<"cmpi",
 //===----------------------------------------------------------------------===//
 
 def CondBranchOp : Std_Op<"cond_br",
-    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
      NoSideEffect, Terminator]> {
   let summary = "conditional branch operation";
   let description = [{
@@ -1098,10 +1096,6 @@ def CondBranchOp : Std_Op<"cond_br",
       eraseSuccessorOperand(falseIndex, index);
     }
 
-    /// Returns the successor that would be chosen with the given constant
-    /// operands. Returns nullptr if a single successor could not be chosen.
-    Block *getSuccessorForOperands(ArrayRef<Attribute> operands);
-
   private:
     /// Get the index of the first true destination operand.
     unsigned getTrueDestOperandIndex() { return 1; }

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 087aecb6b35d..c6f144edb94d 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1780,12 +1780,20 @@ class OpInterface<string name> : OpInterfaceTrait<name> {
 
 // Whether to declare the op interface methods in the op's header. This class
 // simply wraps an OpInterface but is used to indicate that the method
-// declarations should be generated.
-class DeclareOpInterfaceMethods<OpInterface interface> :
-  OpInterface<interface.cppClassName> {
+// declarations should be generated. This class takes an optional set of methods
+// that should have declarations generated even if the method has a default
+// implementation.
+class DeclareOpInterfaceMethods<OpInterface interface,
+                                list<string> overridenMethods = []>
+      : OpInterface<interface.cppClassName> {
     let description = interface.description;
     let cppClassName = interface.cppClassName;
     let methods = interface.methods;
+
+    // This field contains a set of method names that should always have their
+    // declarations generated. This allows for generating declarations for
+    // methods with default implementations that need to be overridden.
+    list<string> alwaysOverriddenMethods = overridenMethods;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/OpTrait.h
index 2d212f7a9d7c..269c6393e434 100644
--- a/mlir/include/mlir/TableGen/OpTrait.h
+++ b/mlir/include/mlir/TableGen/OpTrait.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringRef.h"
+#include <vector>
 
 namespace llvm {
 class Init;
@@ -105,6 +106,10 @@ class InterfaceOpTrait : public OpTrait {
 
   // Whether the declaration of methods for this trait should be emitted.
   bool shouldDeclareMethods() const;
+
+  // Returns the methods that should always be declared if this interface is
+  // emitting declarations.
+  std::vector<StringRef> getAlwaysDeclaredMethods() const;
 };
 
 } // end namespace tblgen

diff  --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp
index f8257bb3315d..7a1e9cef0559 100644
--- a/mlir/lib/TableGen/OpTrait.cpp
+++ b/mlir/lib/TableGen/OpTrait.cpp
@@ -63,3 +63,7 @@ llvm::StringRef InterfaceOpTrait::getTrait() const {
 bool InterfaceOpTrait::shouldDeclareMethods() const {
   return def->isSubClassOf("DeclareOpInterfaceMethods");
 }
+
+std::vector<StringRef> InterfaceOpTrait::getAlwaysDeclaredMethods() const {
+  return def->getValueAsListOfStrings("alwaysOverriddenMethods");
+}

diff  --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 7cda61da08e0..cb53a77ac0cb 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -1,4 +1,5 @@
 // RUN: mlir-tblgen -gen-op-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL --dump-input-on-failure
+// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s --check-prefix=OP_DECL --dump-input-on-failure
 
 include "mlir/IR/OpBase.td"
 
@@ -12,6 +13,14 @@ def TestOpInterface : OpInterface<"TestOpInterface"> {
       /*methodName=*/"foo",
       /*args=*/(ins "int":$input)
     >,
+    InterfaceMethod<
+      /*desc=*/[{some function comment}],
+      /*retTy=*/"int",
+      /*methodName=*/"default_foo",
+      /*args=*/(ins "int":$input),
+      /*body=*/[{}],
+      /*defaultBody=*/[{ return 0; }]
+    >,
   ];
 }
 
@@ -27,8 +36,19 @@ def OpInterfaceOp : Op<TestDialect, "op_interface_op", [TestOpInterface]>;
 def DeclareMethodsOp : Op<TestDialect, "declare_methods_op",
                           [DeclareOpInterfaceMethods<TestOpInterface>]>;
 
+def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
+      [DeclareOpInterfaceMethods<TestOpInterface, ["default_foo"]>]>;
+
 // DECL-LABEL: TestOpInterfaceInterfaceTraits
 // DECL: class TestOpInterface : public OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
 // DECL: int foo(int input);
 
 // DECL-NOT: TestOpInterface
+
+// OP_DECL-LABEL: class DeclareMethodsOp : public
+// OP_DECL: int foo(int input);
+// OP_DECL-NOT: int default_foo(int input);
+
+// OP_DECL-LABEL: class DeclareMethodsWithDefaultOp : public
+// OP_DECL: int foo(int input);
+// OP_DECL: int default_foo(int input);

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 35fb291498bb..29df09b551b8 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1282,10 +1282,23 @@ void OpEmitter::genOpInterfaceMethods() {
     if (!opTrait || !opTrait->shouldDeclareMethods())
       continue;
     auto interface = opTrait->getOpInterface();
-    for (auto method : interface.getMethods()) {
-      // Don't declare if the method has a body or a default implementation.
-      if (method.getBody() || method.getDefaultImplementation())
+
+    // Get the set of methods that should always be declared.
+    auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
+    llvm::StringSet<> alwaysDeclaredMethods;
+    alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
+                                 alwaysDeclaredMethodsVec.end());
+
+    for (const OpInterfaceMethod &method : interface.getMethods()) {
+      // Don't declare if the method has a body.
+      if (method.getBody())
         continue;
+      // Don't declare if the method has a default implementation and the op
+      // didn't request that it always be declared.
+      if (method.getDefaultImplementation() &&
+          !alwaysDeclaredMethods.count(method.getName()))
+        continue;
+
       std::string args;
       llvm::raw_string_ostream os(args);
       interleaveComma(method.getArguments(), os,


        


More information about the Mlir-commits mailing list