[Mlir-commits] [mlir] 572c290 - [mlir][ODS] Add support for specifying the namespace of an interface.

River Riddle llvmlistbot at llvm.org
Sun Jul 12 14:20:06 PDT 2020


Author: River Riddle
Date: 2020-07-12T14:18:32-07:00
New Revision: 572c2905aeaef00a6fedfc4c54f21856ba4cc34e

URL: https://github.com/llvm/llvm-project/commit/572c2905aeaef00a6fedfc4c54f21856ba4cc34e
DIFF: https://github.com/llvm/llvm-project/commit/572c2905aeaef00a6fedfc4c54f21856ba4cc34e.diff

LOG: [mlir][ODS] Add support for specifying the namespace of an interface.

The namespace can be specified using the `cppNamespace` field. This matches the functionality already present on dialects, enums, etc. This fixes problems with using interfaces on operations in a different namespace than the interface was defined in.

Differential Revision: https://reviews.llvm.org/D83604

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpAsmInterface.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/include/mlir/IR/SymbolInterfaces.td
    mlir/include/mlir/IR/SymbolTable.h
    mlir/include/mlir/Interfaces/CallInterfaces.h
    mlir/include/mlir/Interfaces/CallInterfaces.td
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/include/mlir/Interfaces/CopyOpInterface.h
    mlir/include/mlir/Interfaces/CopyOpInterface.td
    mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.h
    mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/include/mlir/Interfaces/LoopLikeInterface.h
    mlir/include/mlir/Interfaces/LoopLikeInterface.td
    mlir/include/mlir/Interfaces/SideEffectInterfaces.h
    mlir/include/mlir/Interfaces/SideEffectInterfaces.td
    mlir/include/mlir/Interfaces/VectorUnrollInterface.h
    mlir/include/mlir/Interfaces/VectorUnrollInterface.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.td
    mlir/include/mlir/TableGen/Interfaces.h
    mlir/include/mlir/TableGen/OpTrait.h
    mlir/include/mlir/TableGen/SideEffects.h
    mlir/lib/TableGen/Interfaces.cpp
    mlir/lib/TableGen/OpTrait.cpp
    mlir/lib/TableGen/Operator.cpp
    mlir/lib/TableGen/SideEffects.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 752536a9e9a1..ec50288348c4 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -22,6 +22,7 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
     This interface provides hooks to interact with the AsmPrinter and AsmParser
     classes.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<[{

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 4344d075bc34..9cc57a617289 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1803,6 +1803,12 @@ class Interface<string name> {
   // The name given to the c++ interface class.
   string cppClassName = name;
 
+  // The C++ namespace that this interface should be placed into.
+  //
+  // To specify nested namespaces, use "::" as the delimiter, e.g., given
+  // "A::B", the interface will be placed in `namespace A { namespace B { <def> } }`.
+  string cppNamespace = "";
+
   // The list of methods defined by this interface.
   list<InterfaceMethod> methods = [];
 
@@ -1838,6 +1844,7 @@ class DeclareOpInterfaceMethods<OpInterface interface,
       : OpInterface<interface.cppClassName> {
     let description = interface.description;
     let cppClassName = interface.cppClassName;
+    let cppNamespace = interface.cppNamespace;
     let methods = interface.methods;
 
     // This field contains a set of method names that should always have their

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 126d20eacbe4..20660be4347c 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -764,6 +764,7 @@ class OpAsmDialectInterface
   virtual void getAsmBlockArgumentNames(Block *block,
                                         OpAsmSetValueNameFn setNameFn) const {}
 };
+} // end namespace mlir
 
 //===--------------------------------------------------------------------===//
 // Operation OpAsm interface.
@@ -772,6 +773,4 @@ class OpAsmDialectInterface
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
 #include "mlir/IR/OpAsmInterface.h.inc"
 
-} // end namespace mlir
-
 #endif

diff  --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 86b33aa36a60..148551324868 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -27,6 +27,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     See [Symbols and SymbolTables](SymbolsAndSymbolTables.md) for more details
     and constraints on `Symbol` operations.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<"Returns the name of this symbol.",

diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 0b035836ec61..7e52011f81ff 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -252,10 +252,9 @@ class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
 };
 
 } // end namespace OpTrait
+} // end namespace mlir
 
 /// Include the generated symbol interfaces.
 #include "mlir/IR/SymbolInterfaces.h.inc"
 
-} // end namespace mlir
-
 #endif // MLIR_IR_SYMBOLTABLE_H

diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index ddfd5a942e49..cc8e26eceba3 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -23,8 +23,9 @@ namespace mlir {
 struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
   using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
 };
+} // end namespace mlir
 
+/// Include the generated interface declarations.
 #include "mlir/Interfaces/CallInterfaces.h.inc"
-} // end namespace mlir
 
 #endif // MLIR_INTERFACES_CALLINTERFACES_H

diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 18d927571d41..7db6730c5e99 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -29,6 +29,7 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
     indirect calls to other operations `call_indirect %foo`. An operation that
     uses this interface, must *not* also provide the `CallableOpInterface`.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<[{
@@ -70,6 +71,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
     `%foo = dialect.create_function(...)`. These operations may only contain a
     single region, or subroutine.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<[{

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index e18c46f745a2..7e609ca13a09 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -70,12 +70,6 @@ class RegionSuccessor {
   ValueRange inputs;
 };
 
-//===----------------------------------------------------------------------===//
-// ControlFlow Interfaces
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
-
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
@@ -101,4 +95,11 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
 
 } // end namespace mlir
 
+//===----------------------------------------------------------------------===//
+// ControlFlow Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
+
 #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 34c7bade6fe1..8b5a0b769ab1 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -25,6 +25,8 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
     This interface provides information for branching terminator operations,
     i.e. terminator operations with successors.
   }];
+  let cppNamespace = "::mlir";
+
   let methods = [
     InterfaceMethod<[{
         Returns a mutable range of operands that correspond to the arguments of
@@ -96,6 +98,8 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     branching behavior between held regions, i.e. this interface allows for
     expressing control flow information for region holding operations.
   }];
+  let cppNamespace = "::mlir";
+
   let methods = [
     InterfaceMethod<[{
         Returns the operands of this operation used as the entry arguments when

diff  --git a/mlir/include/mlir/Interfaces/CopyOpInterface.h b/mlir/include/mlir/Interfaces/CopyOpInterface.h
index d6dc409c2471..2f38eb326b53 100644
--- a/mlir/include/mlir/Interfaces/CopyOpInterface.h
+++ b/mlir/include/mlir/Interfaces/CopyOpInterface.h
@@ -15,10 +15,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
-
+/// Include the generated interface declarations.
 #include "mlir/Interfaces/CopyOpInterface.h.inc"
 
-} // namespace mlir
-
 #endif // MLIR_INTERFACES_COPYOPINTERFACE_H_

diff  --git a/mlir/include/mlir/Interfaces/CopyOpInterface.td b/mlir/include/mlir/Interfaces/CopyOpInterface.td
index 658474d70d86..a503abc185d9 100644
--- a/mlir/include/mlir/Interfaces/CopyOpInterface.td
+++ b/mlir/include/mlir/Interfaces/CopyOpInterface.td
@@ -19,6 +19,7 @@ def CopyOpInterface : OpInterface<"CopyOpInterface"> {
   let description = [{
     A copy-like operation is one that copies from source value to target value.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<

diff  --git a/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.h b/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.h
index debafc2438d2..63cd09f5bc42 100644
--- a/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.h
@@ -15,8 +15,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
+/// Include the generated interface declarations.
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h.inc"
-} // namespace mlir
 
 #endif // MLIR_INTERFACES_DERIVEDATTRIBUTEOPINTERFACE_H_

diff  --git a/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td b/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
index e6f370752bcf..92c901840790 100644
--- a/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
@@ -23,6 +23,7 @@ def DerivedAttributeOpInterface : OpInterface<"DerivedAttributeOpInterface"> {
     from information of the operation. ODS generates convenience accessors for
     derived attributes and can be used to simplify translations.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     StaticInterfaceMethod<

diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 67faeb56a51c..1ae4aa688c84 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -95,8 +95,6 @@ LogicalResult inferReturnTensorTypes(
 LogicalResult verifyInferredResultTypes(Operation *op);
 } // namespace detail
 
-#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
-
 namespace OpTrait {
 
 /// Tensor type inference trait that constructs a tensor from the inferred
@@ -119,4 +117,7 @@ class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
 } // namespace OpTrait
 } // namespace mlir
 
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
+
 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_

diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 723cf99d38b3..c5132986ec97 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -25,6 +25,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
     Interface to infer the return types for an operation that could be used
     during op construction, verification or type inference.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     StaticInterfaceMethod<
@@ -73,6 +74,7 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
 
     The components consists of element type, shape and raw attribute.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     StaticInterfaceMethod<

diff  --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 5891470c9c6e..48399ad0d53a 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -15,10 +15,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
-
+/// Include the generated interface declarations.
 #include "mlir/Interfaces/LoopLikeInterface.h.inc"
 
-} // namespace mlir
-
 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_

diff  --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index cc05030352e7..0e4191b97f97 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -20,6 +20,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     Encodes properties of a loop. Operations that implement this interface will
     be considered by loop-invariant code motion.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<[{

diff  --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index 76932e2ef529..181d218838ff 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -215,13 +215,6 @@ struct Read : public Effect::Base<Read> {};
 struct Write : public Effect::Base<Write> {};
 } // namespace MemoryEffects
 
-//===----------------------------------------------------------------------===//
-// SideEffect Interfaces
-//===----------------------------------------------------------------------===//
-
-/// Include the definitions of the side effect interfaces.
-#include "mlir/Interfaces/SideEffectInterfaces.h.inc"
-
 //===----------------------------------------------------------------------===//
 // SideEffect Utilities
 //===----------------------------------------------------------------------===//
@@ -237,4 +230,11 @@ bool wouldOpBeTriviallyDead(Operation *op);
 
 } // end namespace mlir
 
+//===----------------------------------------------------------------------===//
+// SideEffect Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the definitions of the side effect interfaces.
+#include "mlir/Interfaces/SideEffectInterfaces.h.inc"
+
 #endif // MLIR_INTERFACES_SIDEEFFECTS_H

diff  --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
index 26f2a9a7e455..2a4da16deec2 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
@@ -142,6 +142,9 @@ class SideEffect<EffectOpInterfaceBase interface, string effectName,
   /// The parent interface that the effect belongs to.
   string interfaceTrait = interface.trait;
 
+  /// The cpp namespace of the interface trait.
+  string cppNamespace = interface.cppNamespace;
+
   /// The derived effect that is being applied.
   string effect = effectName;
 
@@ -156,6 +159,9 @@ class SideEffectsTraitBase<EffectOpInterfaceBase parentInterface,
   /// The name of the interface trait to use.
   let trait = parentInterface.trait;
 
+  /// The cpp namespace of the interface trait.
+  string cppNamespace = parentInterface.cppNamespace;
+
   /// The name of the base effects class.
   string baseEffectName = parentInterface.baseEffectName;
 
@@ -177,6 +183,7 @@ def MemoryEffectsOpInterface
     An interface used to query information about the memory effects applied by
     an operation.
   }];
+  let cppNamespace = "::mlir";
 }
 
 // The base class for defining specific memory effects.

diff  --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
index a1cf39c17ebe..a68cc3411533 100644
--- a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
+++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
@@ -17,10 +17,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 
-namespace mlir {
-
+/// Include the generated interface declarations.
 #include "mlir/Interfaces/VectorUnrollInterface.h.inc"
 
-} // namespace mlir
-
 #endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H

diff  --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
index b9cff8bdab1d..166780b20e77 100644
--- a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
+++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
@@ -19,6 +19,7 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
   let description = [{
     Encodes properties of an operation on vectors that can be unrolled.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<[{

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index fe7dd803ccfb..8d319bbeee18 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -15,10 +15,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
-
+/// Include the generated interface declarations.
 #include "mlir/Interfaces/ViewLikeInterface.h.inc"
 
-} // namespace mlir
-
 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index 20b03b2315b1..bb00aff488b2 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -21,6 +21,7 @@ def ViewLikeOpInterface : OpInterface<"ViewLikeOpInterface"> {
     takes in a (view of) buffer (and potentially some other operands) and returns
     another view of buffer.
   }];
+  let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<

diff  --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 4e12ed81fca1..a3462097e480 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -76,6 +76,9 @@ class Interface {
   // Return the name of this interface.
   StringRef getName() const;
 
+  // Return the C++ namespace of this interface.
+  StringRef getCppNamespace() const;
+
   // Return the methods of this interface.
   ArrayRef<InterfaceMethod> getMethods() const;
 

diff  --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/OpTrait.h
index 69c09b600d38..cf8c506eb9f7 100644
--- a/mlir/include/mlir/TableGen/OpTrait.h
+++ b/mlir/include/mlir/TableGen/OpTrait.h
@@ -98,7 +98,7 @@ class InterfaceOpTrait : public OpTrait {
   OpInterface getOpInterface() const;
 
   // Returns the trait corresponding to a C++ trait class.
-  StringRef getTrait() const;
+  std::string getTrait() const;
 
   static bool classof(const OpTrait *t) {
     return t->getKind() == Kind::Interface;

diff  --git a/mlir/include/mlir/TableGen/SideEffects.h b/mlir/include/mlir/TableGen/SideEffects.h
index 468010515252..7e464476cea1 100644
--- a/mlir/include/mlir/TableGen/SideEffects.h
+++ b/mlir/include/mlir/TableGen/SideEffects.h
@@ -30,7 +30,7 @@ class SideEffect : public Operator::VariableDecorator {
   StringRef getBaseEffectName() const;
 
   // Return the name of the Interface that the effect belongs to.
-  StringRef getInterfaceTrait() const;
+  std::string getInterfaceTrait() const;
 
   // Return the name of the resource class.
   StringRef getResource() const;

diff  --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index 0a6dd5f6a642..1e6101f83cab 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -84,6 +84,11 @@ StringRef Interface::getName() const {
   return def->getValueAsString("cppClassName");
 }
 
+// Return the C++ namespace of this interface.
+StringRef Interface::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
 // Return the methods of this interface.
 ArrayRef<InterfaceMethod> Interface::getMethods() const { return methods; }
 

diff  --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp
index b32c647b2c95..dbfd0d374b83 100644
--- a/mlir/lib/TableGen/OpTrait.cpp
+++ b/mlir/lib/TableGen/OpTrait.cpp
@@ -27,7 +27,7 @@ OpTrait OpTrait::create(const llvm::Init *init) {
     return OpTrait(Kind::Pred, def);
   if (def->isSubClassOf("GenInternalOpTrait"))
     return OpTrait(Kind::Internal, def);
-  if (def->isSubClassOf("OpInterface"))
+  if (def->isSubClassOf("OpInterfaceTrait"))
     return OpTrait(Kind::Interface, def);
   assert(def->isSubClassOf("NativeOpTrait"));
   return OpTrait(Kind::Native, def);
@@ -56,8 +56,11 @@ OpInterface InterfaceOpTrait::getOpInterface() const {
   return OpInterface(def);
 }
 
-llvm::StringRef InterfaceOpTrait::getTrait() const {
-  return def->getValueAsString("trait");
+std::string InterfaceOpTrait::getTrait() const {
+  llvm::StringRef trait = def->getValueAsString("trait");
+  llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
+  return cppNamespace.empty() ? trait.str()
+                              : (cppNamespace + "::" + trait).str();
 }
 
 bool InterfaceOpTrait::shouldDeclareMethods() const {

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 7e8b4d816000..3dd924566a8f 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -336,7 +336,7 @@ void tblgen::Operator::populateTypeInferenceInfo(
             llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
       return;
     if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
-      if (opTrait->getTrait().startswith(inferTypeOpInterface))
+      if (&opTrait->getDef() == inferTrait)
         return;
 
     if (!def.isSubClassOf("AllTypesMatch"))

diff  --git a/mlir/lib/TableGen/SideEffects.cpp b/mlir/lib/TableGen/SideEffects.cpp
index a2116ba3c37b..286cacfdacf8 100644
--- a/mlir/lib/TableGen/SideEffects.cpp
+++ b/mlir/lib/TableGen/SideEffects.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/SideEffects.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/TableGen/Record.h"
 
 using namespace mlir;
@@ -24,8 +25,11 @@ StringRef SideEffect::getBaseEffectName() const {
   return def->getValueAsString("baseEffectName");
 }
 
-StringRef SideEffect::getInterfaceTrait() const {
-  return def->getValueAsString("interfaceTrait");
+std::string SideEffect::getInterfaceTrait() const {
+  StringRef trait = def->getValueAsString("interfaceTrait");
+  StringRef cppNamespace = def->getValueAsString("cppNamespace");
+  return cppNamespace.empty() ? trait.str()
+                              : (cppNamespace + "::" + trait).str();
 }
 
 StringRef SideEffect::getResource() const {

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index dcf40691e17f..b2b4245989b5 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -887,7 +887,8 @@ static bool canGenerateUnwrappedBuilder(Operator &op) {
 }
 
 static bool canInferType(Operator &op) {
-  return op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0;
+  return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
+         op.getNumRegions() == 0;
 }
 
 void OpEmitter::genSeparateArgParamBuilder() {
@@ -1917,7 +1918,7 @@ void OpEmitter::genOpAsmInterface() {
   // TODO: We could also add a flag to allow operations to opt in to this
   // generation, even if they only have a single operation.
   int numResults = op.getNumResults();
-  if (numResults <= 1 || op.getTrait("OpAsmOpInterface::Trait"))
+  if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
     return;
 
   SmallVector<StringRef, 4> resultNames(numResults);
@@ -1927,7 +1928,7 @@ void OpEmitter::genOpAsmInterface() {
   // Don't add the trait if none of the results have a valid name.
   if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
     return;
-  opClass.addTrait("OpAsmOpInterface::Trait");
+  opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
 
   // Generate the right accessor for the number of results.
   auto &method = opClass.newMethod("void", "getAsmResultNames",

diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 5a5501d42b7e..8b27bc6de7c5 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -150,11 +150,16 @@ struct TypeInterfaceGenerator : public InterfaceGenerator {
 static void emitInterfaceDef(Interface interface, StringRef valueType,
                              raw_ostream &os) {
   StringRef interfaceName = interface.getName();
+  StringRef cppNamespace = interface.getCppNamespace();
+  cppNamespace.consume_front("::");
 
   // Insert the method definitions.
   bool isOpInterface = isa<OpInterface>(interface);
   for (auto &method : interface.getMethods()) {
-    emitCPPType(method.getReturnType(), os) << interfaceName << "::";
+    emitCPPType(method.getReturnType(), os);
+    if (!cppNamespace.empty())
+      os << cppNamespace << "::";
+    os << interfaceName << "::";
     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface);
 
@@ -287,6 +292,11 @@ void InterfaceGenerator::emitTraitDecl(Interface &interface,
 }
 
 void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
+  llvm::SmallVector<StringRef, 2> namespaces;
+  llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
+  for (StringRef ns : namespaces)
+    os << "namespace " << ns << " {\n";
+
   StringRef interfaceName = interface.getName();
   auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
 
@@ -321,6 +331,9 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
     os << *extraDecls << "\n";
 
   os << "};\n";
+
+  for (StringRef ns : llvm::reverse(namespaces))
+    os << "} // namespace " << ns << "\n";
 }
 
 bool InterfaceGenerator::emitInterfaceDecls() {


        


More information about the Mlir-commits mailing list