[Mlir-commits] [mlir] 47b0a9b - [ODS] Extra Concrete Declarations and Definitions under Traits

Jacques Pienaar llvmlistbot at llvm.org
Wed Jul 12 08:46:27 PDT 2023


Author: Amanda Tang
Date: 2023-07-12T08:46:19-07:00
New Revision: 47b0a9b9311ffdaab511c240de89b5da75b1252b

URL: https://github.com/llvm/llvm-project/commit/47b0a9b9311ffdaab511c240de89b5da75b1252b
DIFF: https://github.com/llvm/llvm-project/commit/47b0a9b9311ffdaab511c240de89b5da75b1252b.diff

LOG: [ODS] Extra Concrete Declarations and Definitions under Traits

Support extra concrete class declarations and definitions under NativeTrait that get injected into the class that specifies the trait. Extra declarations and definitions can be passed in as template arguments for NativeOpTrait, NativeAttrTrait and NativeTypeTrait.

Usage examples of this feature include:

- Creating a wrapper trait for authoring inferReturnTypes with the OpAdaptor, by specifying the necessary op-specific declarations and definitions directly in the trait
- Refactoring the InferTensorType trait

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D154731

Added: 
    

Modified: 
    mlir/docs/Traits.md
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/IR/AttrTypeBase.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/include/mlir/TableGen/Trait.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Interfaces/InferTypeOpInterface.cpp
    mlir/lib/TableGen/Trait.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/OpClass.cpp
    mlir/tools/mlir-tblgen/OpClass.h
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
index 27804194ffe8f9..74ab7784c9ab6d 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -100,6 +100,22 @@ Note: It is generally good practice to define the implementation of the
 `foldTrait` hook out-of-line as a free function when possible to avoid
 instantiating the implementation for every concrete operation type.
 
+### Extra Declarations and Definitions
+A trait may require additional declarations and definitions directly on
+the Operation, Attribute or Type instances which specify that trait.
+The `extraConcreteClassDeclaration` and `extraConcreteClassDefinition`
+fields under the `NativeTrait` class are mechanisms designed for injecting
+code directly into generated C++ Operation, Attribute or Type classes.
+
+Code within the `extraConcreteClassDeclaration` field will be formatted and copied
+into the generated C++ Operation, Attribute or Type class. Code within
+`extraConcreteClassDefinition` will be added to the generated source file inside
+the class’s C++ namespace. The substitution `$cppClass` is replaced by the C++ class
+name.
+
+The intention is to group trait specific logic together and reduce
+redundant extra declarations and definitions on the instances themselves.
+
 ### Parametric Traits
 
 The above demonstrates the definition of a simple self-contained trait. It is

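The `$cppClass` substitution described in the Traits.md hunk is illustrated below with a small standalone sketch. It assumes a plain left-to-right textual replacement. The real generator runs `tgfmt` with a `FmtContext`, which also handles other placeholders and escaping, so `substituteCppClass` is a hypothetical helper and not an mlir-tblgen API.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical helper (not part of mlir-tblgen): replaces every `$cppClass`
// in a trait's extraConcreteClassDefinition snippet with the C++ class name.
std::string substituteCppClass(std::string code, const std::string &className) {
  assert(!className.empty() && "generated classes always have a name");
  const std::string placeholder = "$cppClass";
  std::size_t pos = 0;
  while ((pos = code.find(placeholder, pos)) != std::string::npos) {
    code.replace(pos, placeholder.size(), className);
    pos += className.size(); // continue scanning after the inserted name
  }
  return code;
}
```

For example, a trait snippet starting with `LogicalResult $cppClass::inferReturnTypes(` becomes `LogicalResult ExtractStridedMetadataOp::inferReturnTypes(` when injected into that op.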
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 809f951be6025d..0c3f96ff70b9c4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -903,7 +903,7 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
     Pure,
     SameVariadicResultSize,
     ViewLikeOpInterface,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    InferTypeOpInterfaceAdaptor]> {
   let summary = "Extracts a buffer base with offset and strides";
   let description = [{
     Extracts a base buffer, offset and strides. This op allows additional layers

diff  --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index abefd4047b2ddd..d24a6b4df8c7bc 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -21,7 +21,14 @@ include "mlir/IR/OpBase.td"
 //===----------------------------------------------------------------------===//
 
 // These classes are used to define attribute specific traits.
-class NativeAttrTrait<string name> : NativeTrait<name, "Attribute">;
+
+// Specify attribute specific declarations and definitions in `extraAttrDeclaration`
+// and `extraAttrDefinition` template arguments.
+class NativeAttrTrait<string name,
+                      code extraAttrDeclaration = [{}],
+                      code extraAttrDefinition = [{}]>
+    : NativeTrait<name, "Attribute", extraAttrDeclaration, extraAttrDefinition>;
+
 class ParamNativeAttrTrait<string prop, string params>
     : ParamNativeTrait<prop, params, "Attribute">;
 class GenInternalAttrTrait<string prop> : GenInternalTrait<prop, "Attribute">;
@@ -32,7 +39,14 @@ class PredAttrTrait<string descr, Pred pred> : PredTrait<descr, pred>;
 //===----------------------------------------------------------------------===//
 
 // These classes are used to define type specific traits.
-class NativeTypeTrait<string name> : NativeTrait<name, "Type">;
+
+// Specify type specific declarations and definitions in `extraTypeDeclaration`
+// and `extraTypeDefinition` template arguments.
+class NativeTypeTrait<string name,
+                      code extraTypeDeclaration = [{}],
+                      code extraTypeDefinition = [{}]>
+    : NativeTrait<name, "Type", extraTypeDeclaration, extraTypeDefinition>;
+
 class ParamNativeTypeTrait<string prop, string params>
     : ParamNativeTrait<prop, params, "Type">;
 class GenInternalTypeTrait<string prop> : GenInternalTrait<prop, "Type">;

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 40674988114f2e..ddcf60c9071b8a 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1958,9 +1958,16 @@ class TraitList<list<Trait> props> : Trait {
 // NativeTrait corresponds to the MLIR C++ trait mechanism. The purpose to wrap
 // around C++ symbol string with this class is to make traits specified for
 // entities in TableGen less alien and more integrated.
-class NativeTrait<string name, string entityType> : Trait {
+// `extraConcreteClassDeclaration` and `extraConcreteClassDefinition` code
+// get injected into the entities in which the NativeTrait is specified for.
+class NativeTrait<string name, string entityType,
+                    code extraClassDeclaration = [{}],
+                    code extraClassDefinition = [{}]> : Trait {
   string trait = name;
   string cppNamespace = "::mlir::" # entityType # "Trait";
+
+  code extraConcreteClassDeclaration = extraClassDeclaration;
+  code extraConcreteClassDefinition = extraClassDefinition;
 }
 
 // ParamNativeTrait corresponds to the template-parameterized traits in the C++
@@ -1993,8 +2000,13 @@ class PredTrait<string descr, Pred pred> : Trait {
 class StructuralOpTrait;
 
 // These classes are used to define operation specific traits.
-class NativeOpTrait<string name, list<Trait> traits = []>
-    : NativeTrait<name, "Op"> {
+
+// Specify op specific declarations and definitions in `extraOpDeclaration`
+// and `extraOpDefinition` template arguments.
+class NativeOpTrait<string name, list<Trait> traits = [],
+                    code extraOpDeclaration = [{}],
+                    code extraOpDefinition = [{}]>
+    : NativeTrait<name, "Op", extraOpDeclaration, extraOpDefinition> {
   // Specify the list of traits that need to be verified before the verification
   // of this NativeOpTrait.
   list<Trait> dependentTraits = traits;

diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 4ead5ec40ed936..747ec0a76f3f1d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -237,19 +237,9 @@ class ValueShapeRange : public ValueRange::RangeBaseT {
 namespace detail {
 // Helper function to infer return tensor returns types given element and
 // shape inference function.
-//
-// TODO: Consider generating typedefs for trait member functions if this usage
-// becomes more common.
-LogicalResult inferReturnTensorTypes(
-    function_ref<
-        LogicalResult(MLIRContext *, std::optional<Location> location,
-                      ValueShapeRange operands, DictionaryAttr attributes,
-                      OpaqueProperties properties, RegionRange regions,
-                      SmallVectorImpl<ShapedTypeComponents> &retComponents)>
-        componentTypeFn,
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes);
+LogicalResult
+inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents,
+                       SmallVectorImpl<Type> &inferredReturnTypes);
 
 /// Verifies that the inferred result types match the actual result types for
 /// the op. Precondition: op implements InferTypeOpInterface.
@@ -268,6 +258,10 @@ class InferTensorType;
 namespace mlir {
 namespace OpTrait {
 
+template <typename ConcreteType>
+class InferTypeOpInterfaceAdaptor
+    : public TraitBase<ConcreteType, InferTypeOpInterfaceAdaptor> {};
+
 /// Tensor type inference trait that constructs a tensor from the inferred
 /// shape and elemental types.
 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
@@ -276,24 +270,7 @@ namespace OpTrait {
 ///   trait is currently only used where the interfaces are, so keep it
 ///   restricted for now).
 template <typename ConcreteType>
-class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
-public:
-  static LogicalResult
-  inferReturnTypes(MLIRContext *context, std::optional<Location> location,
-                   ValueRange operands, DictionaryAttr attributes,
-                   OpaqueProperties properties, RegionRange regions,
-                   SmallVectorImpl<Type> &inferredReturnTypes) {
-    static_assert(
-        ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
-        "requires InferShapedTypeOpInterface to ensure succesful invocation");
-    static_assert(
-        ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
-        "requires InferTypeOpInterface to ensure succesful invocation");
-    return ::mlir::detail::inferReturnTensorTypes(
-        ConcreteType::inferReturnTypeComponents, context, location, operands,
-        attributes, properties, regions, inferredReturnTypes);
-  }
-};
+class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {};
 
 } // namespace OpTrait
 } // namespace mlir

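The header change replaces the callback-based signature with one that takes precomputed components. The generated `inferReturnTypes` now calls `inferReturnTypeComponents` itself and only passes the results to the helper. The sketch below mimics that conversion step using hypothetical plain types (`ComponentsSketch`, with tensor types rendered as strings) in place of MLIR's `ShapedTypeComponents` and tensor types. It illustrates the split and is not the library code.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for ShapedTypeComponents: an element type name and,
// when the shape is ranked, its dimensions (-1 marks a dynamic dimension).
struct ComponentsSketch {
  std::optional<std::vector<int64_t>> dims; // std::nullopt means unranked
  std::string elementType;
};

// Same shape as the refactored helper: no callback and no operands,
// attributes or regions; it only turns already-computed components into
// tensor types (rendered here as MLIR-style type strings).
bool inferReturnTensorTypesSketch(const std::vector<ComponentsSketch> &retComponents,
                                  std::vector<std::string> &inferredReturnTypes) {
  for (const ComponentsSketch &c : retComponents) {
    assert(!c.elementType.empty() && "element type required to construct tensor");
    std::string type = "tensor<";
    if (c.dims) {
      for (int64_t d : *c.dims)
        type += (d < 0 ? std::string("?") : std::to_string(d)) + "x";
    } else {
      type += "*x"; // unranked tensor
    }
    type += c.elementType + ">";
    inferredReturnTypes.push_back(type);
  }
  return true; // stands in for success()
}
```

In this design, failure in the component computation is handled by the caller, which returns early, and the helper does not need to know the op's signature at all.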
diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 6925e39b0a9c6e..c9c1c6cc9ab01c 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -184,18 +184,69 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
   ];
 }
 
+// Convenient trait to define a wrapper to inferReturnTypes that passes in the
+// Op Adaptor directly
+def InferTypeOpInterfaceAdaptor : TraitList<
+  [
+    // Op implements infer type op interface.
+    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    NativeOpTrait<
+      /*name=*/"InferTypeOpInterfaceAdaptor",
+      /*traits=*/[],
+      /*extraOpDeclaration=*/[{
+        static LogicalResult
+        inferReturnTypesAdaptor(MLIRContext *context,
+                                std::optional<Location> location,
+                                Adaptor adaptor,
+                                SmallVectorImpl<Type> &inferredReturnTypes);
+      }],
+      /*extraOpDefinition=*/[{
+        LogicalResult
+        $cppClass::inferReturnTypes(MLIRContext *context,
+                          std::optional<Location> location,
+                          ValueRange operands, DictionaryAttr attributes,
+                          OpaqueProperties properties, RegionRange regions,
+                          SmallVectorImpl<Type> &inferredReturnTypes) {
+          $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
+          return $cppClass::inferReturnTypesAdaptor(context,
+            location, adaptor, inferredReturnTypes);
+        }
+      }]
+    >
+  ]>;
+
 // Convenience class grouping together type and shaped type op interfaces for
 // ops that have tensor return types.
 class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
   [
     // Op implements infer type op interface.
-    InferTypeOpInterface,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>,
     // The op will have methods implementing the ShapedType type inference
     // interface.
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
     // The op produces tensors and will use the ShapedType type infer interface
     // along with knowledge that it is producing Tensors to infer the type.
-    NativeOpTrait<"InferTensorType">
+    NativeOpTrait<
+      /*name=*/"InferTensorType",
+      /*traits=*/[],
+      /*extraOpDeclaration=*/[{}],
+      /*extraOpDefinition=*/[{
+        LogicalResult
+        $cppClass::inferReturnTypes(MLIRContext *context,
+                          std::optional<Location> location,
+                          ValueRange operands, DictionaryAttr attributes,
+                          OpaqueProperties properties, RegionRange regions,
+                          SmallVectorImpl<Type> &inferredReturnTypes) {
+          SmallVector<ShapedTypeComponents, 2> retComponents;
+          if (failed($cppClass::inferReturnTypeComponents(context, location,
+                                    operands, attributes, properties, regions,
+                                    retComponents)))
+            return failure();
+          return ::mlir::detail::inferReturnTensorTypes(retComponents,
+                                    inferredReturnTypes);
+        }
+      }]
+    >
   ]>;
 
 def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;

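`InferTypeOpInterfaceAdaptor` injects a generic `inferReturnTypes` that builds the op's `Adaptor` and forwards it to the op-authored `inferReturnTypesAdaptor`. Below is a minimal sketch of that forwarding shape. It uses hypothetical stand-ins: `ExtractOpSketch`, operands modeled as type-name strings, and `bool` in place of `LogicalResult`. The inference rule itself is made up. The real adaptor is generated by mlir-tblgen and also wraps attributes, properties and regions.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

struct ExtractOpSketch {
  // Hypothetical adaptor: named access to the raw operand list.
  struct Adaptor {
    explicit Adaptor(std::vector<std::string> ops) : operands(std::move(ops)) {}
    const std::string &getSource() const { return operands.at(0); }
    std::vector<std::string> operands;
  };

  // Written by the op author: logic phrased against named accessors.
  // Hypothetical rule: accept memref sources and return the source type.
  static bool inferReturnTypesAdaptor(const Adaptor &adaptor,
                                      std::vector<std::string> &inferred) {
    if (adaptor.getSource().rfind("memref<", 0) != 0)
      return false; // analogous to the dyn_cast<MemRefType> failure path
    inferred.push_back(adaptor.getSource());
    return true;
  }

  // Injected by the trait's extraOpDefinition: wrap the raw operands in an
  // Adaptor and forward to the op-authored hook.
  static bool inferReturnTypes(const std::vector<std::string> &operands,
                               std::vector<std::string> &inferred) {
    Adaptor adaptor(operands);
    return inferReturnTypesAdaptor(adaptor, inferred);
  }
};
```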
diff  --git a/mlir/include/mlir/TableGen/Trait.h b/mlir/include/mlir/TableGen/Trait.h
index 8da5303855feee..bebb5c8528a5ae 100644
--- a/mlir/include/mlir/TableGen/Trait.h
+++ b/mlir/include/mlir/TableGen/Trait.h
@@ -68,6 +68,14 @@ class NativeTrait : public Trait {
   // Returns if this is a structural op trait.
   bool isStructuralOpTrait() const;
 
+  // Returns extra class declaration code to be added to the concrete instance
+  // when the trait is specified
+  StringRef getExtraConcreteClassDeclaration() const;
+
+  // Returns extra class definition code to be added to the concrete instance
+  // when the trait is specified
+  StringRef getExtraConcreteClassDefinition() const;
+
   static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
 };
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 6f5b8693be9585..3a42c2fd280826 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1355,14 +1355,11 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
 
 /// The number and type of the results are inferred from the
 /// shape of the source.
-LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor(
+    MLIRContext *context, std::optional<Location> location,
+    ExtractStridedMetadataOp::Adaptor adaptor,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes,
-                                                 properties);
-  auto sourceType =
-      llvm::dyn_cast<MemRefType>(extractAdaptor.getSource().getType());
+  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
   if (!sourceType)
     return failure();
 

diff  --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 00d1c51b0348f9..3c50c4c37c6f59 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -217,19 +217,8 @@ ShapeAdaptor ValueShapeRange::getShape(int index) const {
 }
 
 LogicalResult mlir::detail::inferReturnTensorTypes(
-    function_ref<
-        LogicalResult(MLIRContext *, std::optional<Location> location,
-                      ValueShapeRange operands, DictionaryAttr attributes,
-                      OpaqueProperties properties, RegionRange regions,
-                      SmallVectorImpl<ShapedTypeComponents> &retComponents)>
-        componentTypeFn,
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    ArrayRef<ShapedTypeComponents> retComponents,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  SmallVector<ShapedTypeComponents, 2> retComponents;
-  if (failed(componentTypeFn(context, location, operands, attributes,
-                             properties, regions, retComponents)))
-    return failure();
   for (const auto &shapeAndType : retComponents) {
     Type elementTy = shapeAndType.getElementType();
     assert(elementTy && "element type required to construct tensor");

diff  --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp
index ee4b999c4bd470..6246ba959b00ab 100644
--- a/mlir/lib/TableGen/Trait.cpp
+++ b/mlir/lib/TableGen/Trait.cpp
@@ -54,6 +54,14 @@ bool NativeTrait::isStructuralOpTrait() const {
   return def->isSubClassOf("StructuralOpTrait");
 }
 
+StringRef NativeTrait::getExtraConcreteClassDeclaration() const {
+  return def->getValueAsString("extraConcreteClassDeclaration");
+}
+
+StringRef NativeTrait::getExtraConcreteClassDefinition() const {
+  return def->getValueAsString("extraConcreteClassDefinition");
+}
+
 //===----------------------------------------------------------------------===//
 // InternalTrait
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index c1b4f2a2f10b3b..4dbeb30bf77241 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -214,14 +214,42 @@ void DefGen::createParentWithTraits() {
   defCls.addParent(std::move(defParent));
 }
 
+/// Include declarations specified on NativeTrait
+static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
+  SmallVector<StringRef> extraDeclarations;
+  // Include extra class declarations from NativeTrait
+  for (const auto &trait : def.getTraits()) {
+    if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
+      if (value.empty())
+        continue;
+      extraDeclarations.push_back(value);
+    }
+  }
+  if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
+    extraDeclarations.push_back(*extraDecl);
+  }
+  return llvm::join(extraDeclarations, "\n");
+}
+
 /// Extra class definitions have a `$cppClass` substitution that is to be
 /// replaced by the C++ class name.
 static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
+  SmallVector<StringRef> extraDefinitions;
+  // Include extra class definitions from NativeTrait
+  for (const auto &trait : def.getTraits()) {
+    if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
+      if (value.empty())
+        continue;
+      extraDefinitions.push_back(value);
+    }
+  }
   if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
-    FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
-    return tgfmt(*extraDef, &ctx).str();
+    extraDefinitions.push_back(*extraDef);
   }
-  return "";
+  FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
+  return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
 }
 
 void DefGen::emitTopLevelDeclarations() {
@@ -230,9 +258,9 @@ void DefGen::emitTopLevelDeclarations() {
   defCls.declare<UsingDeclaration>("Base::Base");
 
   // Emit the extra declarations first in case there's a definition in there.
-  std::optional<StringRef> extraDecl = def.getExtraDecls();
+  std::string extraDecl = formatExtraDeclarations(def);
   std::string extraDef = formatExtraDefinitions(def);
-  defCls.declare<ExtraClassDeclaration>(extraDecl ? *extraDecl : "",
+  defCls.declare<ExtraClassDeclaration>(std::move(extraDecl),
                                         std::move(extraDef));
 }
 

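In `formatExtraDeclarations`, non-empty snippets from `NativeTrait`s are collected first, in trait-list order, followed by the entity's own extra declarations. The sketch below reproduces that collection step for the `AttrOrTypeDefGen.cpp` variant, where the entity's own text is optional. The `OpDefinitionsGen.cpp` variant instead appends `op.getExtraClassDeclaration()` unconditionally. `TraitSketch` and `formatExtraDeclarationsSketch` are hypothetical names, and the manual join stands in for `llvm::join`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a NativeTrait record: only the injected text.
struct TraitSketch {
  std::string extraConcreteClassDeclaration; // empty when nothing is injected
};

// Trait snippets first (skipping empty ones), then the entity's own
// declarations if present, joined with newlines.
std::string formatExtraDeclarationsSketch(const std::vector<TraitSketch> &traits,
                                          const std::optional<std::string> &ownDecl) {
  std::vector<std::string> pieces;
  for (const TraitSketch &t : traits)
    if (!t.extraConcreteClassDeclaration.empty())
      pieces.push_back(t.extraConcreteClassDeclaration);
  if (ownDecl)
    pieces.push_back(*ownDecl);
  std::string joined;
  for (std::size_t i = 0; i < pieces.size(); ++i) {
    if (i != 0)
      joined += "\n";
    joined += pieces[i];
  }
  return joined;
}
```

The definitions side follows the same pattern, except that `$cppClass` substitution is applied once to the joined text.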
diff  --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
index 698569c790e934..60fa1833ce625e 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -15,9 +15,10 @@ using namespace mlir::tblgen;
 // OpClass definitions
 //===----------------------------------------------------------------------===//
 
-OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
+OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
                  std::string extraClassDefinition)
-    : Class(name.str()), extraClassDeclaration(extraClassDeclaration),
+    : Class(name.str()),
+      extraClassDeclaration(std::move(extraClassDeclaration)),
       extraClassDefinition(std::move(extraClassDefinition)),
       parent(addParent("::mlir::Op")) {
   parent.addTemplateParam(getClassName().str());
@@ -37,6 +38,5 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
 void OpClass::finalize() {
   Class::finalize();
   declare<VisibilityDeclaration>(Visibility::Public);
-  declare<ExtraClassDeclaration>(extraClassDeclaration.str(),
-                                 extraClassDefinition);
+  declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
 }

diff  --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h
index 6b90dd2c3a3a3d..20b96baf868c77 100644
--- a/mlir/tools/mlir-tblgen/OpClass.h
+++ b/mlir/tools/mlir-tblgen/OpClass.h
@@ -25,7 +25,7 @@ class OpClass : public Class {
   /// - inheritance of `print`
   /// - a type alias for the associated adaptor class
   ///
-  OpClass(StringRef name, StringRef extraClassDeclaration,
+  OpClass(StringRef name, std::string extraClassDeclaration,
           std::string extraClassDefinition);
 
   /// Add an op trait.
@@ -39,7 +39,7 @@ class OpClass : public Class {
 
 private:
   /// Hand-written extra class declarations.
-  StringRef extraClassDeclaration;
+  std::string extraClassDeclaration;
   /// Hand-written extra class definitions.
   std::string extraClassDefinition;
   /// The parent class, which also contains the traits to be inherited.

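`OpClass` now stores `extraClassDeclaration` as `std::string` instead of `StringRef`. Previously it held the `StringRef` returned by `op.getExtraClassDeclaration()`. The text is now assembled by `formatExtraDeclarations`, which returns a temporary `std::string`, and a `StringRef` member pointing at that temporary would dangle after the constructor call's full-expression ends. Below is a minimal sketch of the take-by-value-then-move pattern the constructor uses. `OwningClassSketch` and `makeDeclsSketch` are hypothetical names.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for OpClass: owns the generated text, so it stays
// valid after the caller's temporary is destroyed.
class OwningClassSketch {
public:
  explicit OwningClassSketch(std::string extraDecl)
      : extraClassDeclaration(std::move(extraDecl)) {}
  const std::string &getExtraClassDeclaration() const {
    return extraClassDeclaration;
  }

private:
  std::string extraClassDeclaration;
};

// Hypothetical stand-in for formatExtraDeclarations: returns a fresh temporary.
std::string makeDeclsSketch() { return "static int traitHelper();\nint own();"; }
```

Taking the parameter by value and moving it costs at most one move when the argument is a temporary, which is why the same shape is used for both `extraClassDeclaration` and `extraClassDefinition`.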
diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index f26114cdf98f55..a935e21152b910 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -853,17 +853,45 @@ while (true) {{
       emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
 }
 
+/// Include declarations specified on NativeTrait
+static std::string formatExtraDeclarations(const Operator &op) {
+  SmallVector<StringRef> extraDeclarations;
+  // Include extra class declarations from NativeTrait
+  for (const auto &trait : op.getTraits()) {
+    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      StringRef value = opTrait->getExtraConcreteClassDeclaration();
+      if (value.empty())
+        continue;
+      extraDeclarations.push_back(value);
+    }
+  }
+  extraDeclarations.push_back(op.getExtraClassDeclaration());
+  return llvm::join(extraDeclarations, "\n");
+}
+
 /// Op extra class definitions have a `$cppClass` substitution that is to be
 /// replaced by the C++ class name.
+/// Include declarations specified on NativeTrait
 static std::string formatExtraDefinitions(const Operator &op) {
+  SmallVector<StringRef> extraDefinitions;
+  // Include extra class definitions from NativeTrait
+  for (const auto &trait : op.getTraits()) {
+    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      StringRef value = opTrait->getExtraConcreteClassDefinition();
+      if (value.empty())
+        continue;
+      extraDefinitions.push_back(value);
+    }
+  }
+  extraDefinitions.push_back(op.getExtraClassDefinition());
   FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
-  return tgfmt(op.getExtraClassDefinition(), &ctx).str();
+  return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
 }
 
 OpEmitter::OpEmitter(const Operator &op,
                      const StaticVerifierFunctionEmitter &staticVerifierEmitter)
     : def(op.getDef()), op(op),
-      opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
+      opClass(op.getCppClassName(), formatExtraDeclarations(op),
               formatExtraDefinitions(op)),
       staticVerifierEmitter(staticVerifierEmitter),
       emitHelper(op, /*emitForOp=*/true) {
