[Mlir-commits] [mlir] [MLIR][TableGen] Add inheritableExtraClassDeclaration/Definition for Op and AttrOrTypeDef (PR #182265)

Henrich Lauko llvmlistbot at llvm.org
Thu Feb 19 06:17:46 PST 2026


https://github.com/xlauko updated https://github.com/llvm/llvm-project/pull/182265

>From 2a75e2da83bf8a43c5b5472664488e81cffaf870 Mon Sep 17 00:00:00 2001
From: xlauko <xlauko at mail.muni.cz>
Date: Thu, 19 Feb 2026 14:22:18 +0100
Subject: [PATCH] [MLIR][TableGen] Add
 inheritableExtraClassDeclaration/Definition for Op and AttrOrTypeDef

---
 mlir/include/mlir/IR/AttrTypeBase.td        |   9 ++
 mlir/include/mlir/IR/OpBase.td              |   9 ++
 mlir/include/mlir/TableGen/AttrOrTypeDef.h  |   8 ++
 mlir/include/mlir/TableGen/Operator.h       |  10 ++
 mlir/lib/TableGen/AttrOrTypeDef.cpp         |  41 ++++++++
 mlir/lib/TableGen/Operator.cpp              |  44 +++++++++
 mlir/test/mlir-tblgen/attrdefs.td           |  82 ++++++++++++++++
 mlir/test/mlir-tblgen/op-decl-and-defs.td   | 101 ++++++++++++++++++++
 mlir/test/mlir-tblgen/typedefs.td           |  82 ++++++++++++++++
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp |   8 +-
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp |   8 +-
 11 files changed, 398 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 16f7f8b532521..eee78eb8c2a94 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -250,6 +250,15 @@ class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
   // replaced by the class name.
   code extraClassDefinition = [{}];
 
+  // Extra class declarations/definitions that are inherited by all derived
+  // classes. These can be set at any level in the class hierarchy. Unlike
+  // extraClassDeclaration/extraClassDefinition, the inheritable values carry
+  // over to all derived classes — both the inheritable and the regular extra
+  // declarations are concatenated in the generated code. A derived class can
+  // discard inherited declarations by setting these to empty [{}].
+  code inheritableExtraClassDeclaration = [{}];
+  code inheritableExtraClassDefinition = [{}];
+
   // Generate a default 'getAlias' method for OpAsm{Type,Attr}Interface.
   bit genMnemonicAlias = 0;
 }
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7a667d701ab71..4d288c6fc09f4 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -440,6 +440,15 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
   // generated code is placed inside the op's C++ namespace. `$cppClass` is
   // replaced by the op's C++ class name.
   code extraClassDefinition = ?;
+
+  // Extra class declarations/definitions that are inherited by all derived op
+  // classes. These can be set at any level in the class hierarchy. Unlike
+  // extraClassDeclaration/extraClassDefinition, the inheritable values carry
+  // over to all derived ops — both the inheritable and the regular extra
+  // declarations are concatenated in the generated code. A derived class can
+  // discard inherited declarations by setting these to empty [{}].
+  code inheritableExtraClassDeclaration = [{}];
+  code inheritableExtraClassDefinition = [{}];
 }
 
 // The arguments of an op.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 65992f9fef5e9..23e3bceb757ee 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -216,6 +216,14 @@ class AttrOrTypeDef {
   /// Returns the def's extra class definition code.
   std::optional<StringRef> getExtraDefs() const;
 
+  /// Collects inheritable extra class declarations accumulated across the
+  /// class hierarchy into `result`.
+  void getInheritableExtraDecls(SmallVectorImpl<StringRef> &result) const;
+
+  /// Collects inheritable extra class definitions accumulated across the
+  /// class hierarchy into `result`.
+  void getInheritableExtraDefs(SmallVectorImpl<StringRef> &result) const;
+
   /// Returns true if we need to generate a default 'getAlias' implementation
   /// using the mnemonic.
   bool genMnemonicAlias() const;
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index f0514d8e61748..7821eccc927b2 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -303,6 +303,16 @@ class Operator {
   /// Returns this op's extra class definition code.
   StringRef getExtraClassDefinition() const;
 
+  /// Collects inheritable extra class declarations accumulated across the
+  /// class hierarchy into `result`.
+  void getInheritableExtraClassDeclarations(
+      SmallVectorImpl<StringRef> &result) const;
+
+  /// Collects inheritable extra class definitions accumulated across the
+  /// class hierarchy into `result`.
+  void getInheritableExtraClassDefinitions(
+      SmallVectorImpl<StringRef> &result) const;
+
   /// Returns the Tablegen definition this operator was constructed from.
   /// TODO: do not expose the TableGen record, this is a temporary solution to
   /// OpEmitter requiring a Record because Operator does not provide enough
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index bf835a860cd5b..9ca619c6aec1e 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -207,6 +207,47 @@ std::optional<StringRef> AttrOrTypeDef::getExtraDefs() const {
   return value.empty() ? std::optional<StringRef>() : value;
 }
 
+/// Walk the superclass chain of `def` (base classes first) and append the
+/// values of the code field `fieldName` to `result`. Each distinct non-empty
+/// value is appended at most once, so a value that is merely inherited, or set
+/// again after an intervening override, does not emit duplicate C++ members.
+/// An empty value discards what this walk has accumulated so far; entries
+/// already in `result` before the call are kept. Superclass values that are
+/// not plain string literals (e.g. built from template arguments) are only
+/// seen via the def's own value.
+/// Note: Operator.cpp defines a helper with the same name and purpose; keep
+/// the two consistent (or hoist them into a shared TableGen utility).
+static void accumulateInheritableField(const Record &def, StringRef fieldName,
+                                       SmallVectorImpl<StringRef> &result) {
+  const size_t initialSize = result.size();
+  auto accumulate = [&](StringRef value) {
+    if (value.empty())
+      result.truncate(initialSize);
+    else if (!llvm::is_contained(result, value))
+      result.push_back(value);
+  };
+  for (const Record *superClass : def.getSuperClasses()) {
+    auto *val = superClass->getValue(fieldName);
+    if (!val)
+      continue;
+    // Values that are not string literals are left to the def's own value.
+    if (auto *si = dyn_cast<StringInit>(val->getValue()))
+      accumulate(si->getValue());
+  }
+  // Apply the def's own value last, as if it were one more subclass.
+  accumulate(def.getValueAsString(fieldName));
+}
+
+void AttrOrTypeDef::getInheritableExtraDecls(
+    SmallVectorImpl<StringRef> &result) const {
+  accumulateInheritableField(*def, "inheritableExtraClassDeclaration", result);
+}
+
+void AttrOrTypeDef::getInheritableExtraDefs(
+    SmallVectorImpl<StringRef> &result) const {
+  accumulateInheritableField(*def, "inheritableExtraClassDefinition", result);
+}
+
 bool AttrOrTypeDef::genMnemonicAlias() const {
   return def->getValueAsBit("genMnemonicAlias");
 }
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 82dfbcbfa4d4f..2b53554d40a53 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -180,6 +180,50 @@ StringRef Operator::getExtraClassDefinition() const {
   return def.getValueAsString(attr);
 }
 
+/// Walk the superclass chain of `def` (base classes first) and append the
+/// values of the code field `fieldName` to `result`. Each distinct non-empty
+/// value is appended at most once, so a value that is merely inherited, or set
+/// again after an intervening override, does not emit duplicate C++ members.
+/// An empty value discards what this walk has accumulated so far; entries
+/// already in `result` before the call are kept. Superclass values that are
+/// not plain string literals (e.g. built from template arguments) are only
+/// seen via the def's own value.
+/// Note: AttrOrTypeDef.cpp defines a helper with the same name and purpose;
+/// keep the two consistent (or hoist them into a shared TableGen utility).
+static void accumulateInheritableField(const Record &def, StringRef fieldName,
+                                       SmallVectorImpl<StringRef> &result) {
+  const size_t initialSize = result.size();
+  auto accumulate = [&](StringRef value) {
+    // An empty value drops what this walk has accumulated; a value already
+    // collected (inherited unchanged, or re-set after an intervening override)
+    // is skipped so that no duplicate C++ members are emitted.
+    if (value.empty())
+      result.truncate(initialSize);
+    else if (!llvm::is_contained(result, value))
+      result.push_back(value);
+  };
+  for (const Record *superClass : def.getSuperClasses()) {
+    auto *val = superClass->getValue(fieldName);
+    if (!val)
+      continue;
+    // Values that are not string literals are left to the def's own value.
+    if (auto *si = dyn_cast<StringInit>(val->getValue()))
+      accumulate(si->getValue());
+  }
+  // Apply the def's own value last, as if it were one more subclass.
+  accumulate(def.getValueAsString(fieldName));
+}
+
+void Operator::getInheritableExtraClassDeclarations(
+    SmallVectorImpl<StringRef> &result) const {
+  accumulateInheritableField(def, "inheritableExtraClassDeclaration", result);
+}
+
+void Operator::getInheritableExtraClassDefinitions(
+    SmallVectorImpl<StringRef> &result) const {
+  accumulateInheritableField(def, "inheritableExtraClassDefinition", result);
+}
+
 const Record &Operator::getDef() const { return def; }
 
 bool Operator::skipDefaultBuilders() const {
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index a809611fd0aec..25b009f5ec14b 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -204,3 +204,85 @@ def J_CustomStorageCtorAttr : AttrDef<Test_Dialect, "CustomStorageCtorAttr"> {
 // DEF-LABEL: struct CustomStorageCtorAttrAttrStorage : public ::mlir::AttributeStorage
 // DEF: static CustomStorageCtorAttrAttrStorage *construct
 // DEF-SAME: (::mlir::AttributeStorageAllocator &allocator, KeyTy &&tblgenKey);
+
+// Tests for inheritableExtraClassDeclaration/Definition on attribute defs.
+
+class InheritableTestAttr<string name> : AttrDef<Test_Dialect, name> {
+  let inheritableExtraClassDeclaration = [{
+    int getInheritedHelper();
+  }];
+  let inheritableExtraClassDefinition = [{
+    int $cppClass::getInheritedHelper() { return 42; }
+  }];
+}
+
+// Both inheritable and regular extra members are emitted, inheritable first.
+def K_InheritableAttrA : InheritableTestAttr<"InheritableA"> {
+  let attrName = "test.inheritable_a";
+  let extraClassDeclaration = [{
+    void doA();
+  }];
+  let extraClassDefinition = [{
+    void $cppClass::doA() {}
+  }];
+}
+
+// DECL-LABEL: class InheritableAAttr
+// DECL: int getInheritedHelper();
+// DECL: void doA();
+
+// DEF-LABEL: int InheritableAAttr::getInheritedHelper()
+// DEF: return 42;
+// DEF-LABEL: void InheritableAAttr::doA()
+
+// Only the inherited extra members (no extraClassDeclaration).
+def L_InheritableAttrB : InheritableTestAttr<"InheritableB"> {
+  let attrName = "test.inheritable_b";
+}
+
+// DECL-LABEL: class InheritableBAttr
+// DECL: int getInheritedHelper();
+
+// Setting both inheritable fields to empty discards the inherited members.
+def M_InheritableAttrC : InheritableTestAttr<"InheritableC"> {
+  let attrName = "test.inheritable_c";
+  let inheritableExtraClassDeclaration = [{}];
+  let inheritableExtraClassDefinition = [{}];
+}
+
+// DECL-LABEL: class InheritableCAttr
+// DECL-NOT: int getInheritedHelper();
+
+// Middle of the hierarchy: adds its own members on top of the base ones.
+class InheritableMiddleAttr<string name> : InheritableTestAttr<name> {
+  let inheritableExtraClassDeclaration = [{
+    int getMiddleHelper();
+  }];
+  let inheritableExtraClassDefinition = [{
+    int $cppClass::getMiddleHelper() { return 1; }
+  }];
+}
+
+// A concrete attr below the middle class gets base members, then middle ones.
+def N_InheritableAttrD : InheritableMiddleAttr<"InheritableD"> {
+  let attrName = "test.inheritable_d";
+}
+
+// DECL-LABEL: class InheritableDAttr
+// DECL: int getInheritedHelper();
+// DECL: int getMiddleHelper();
+
+// DEF-LABEL: int InheritableDAttr::getInheritedHelper()
+// DEF: return 42;
+// DEF-LABEL: int InheritableDAttr::getMiddleHelper()
+// DEF: return 1;
+
+// Passthrough: a middle class with no value of its own keeps the base members.
+class InheritablePassthroughAttr<string name> : InheritableTestAttr<name> {}
+
+def O_InheritableAttrE : InheritablePassthroughAttr<"InheritableE"> {
+  let attrName = "test.inheritable_e";
+}
+
+// DECL-LABEL: class InheritableEAttr
+// DECL: int getInheritedHelper();
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 80dedb8475b9e..a960d82dd9090 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -360,6 +360,43 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands
 // CHECK: static IOp create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {});
 // CHECK: static IOp create(::mlir::ImplicitLocOpBuilder &builder, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {});
 
+// CHECK-LABEL: NS::InheritableOpA declarations
+// CHECK: int getInheritedHelper();
+// CHECK: void doA();
+
+// CHECK-LABEL: NS::InheritableOpB declarations
+// CHECK: int getInheritedHelper();
+
+// Discard: no inheritable declarations.
+// CHECK-LABEL: NS::InheritableOpC declarations
+// CHECK-NOT: int getInheritedHelper();
+
+// Middle-of-stack: gets both base and middle inheritable declarations.
+// CHECK-LABEL: NS::InheritableOpD declarations
+// CHECK: int getInheritedHelper();
+// CHECK: int getMiddleHelper();
+
+// Middle-of-stack with own extraClassDeclaration.
+// CHECK-LABEL: NS::InheritableOpE declarations
+// CHECK: int getInheritedHelper();
+// CHECK: int getMiddleHelper();
+// CHECK: void doE();
+
+// Passthrough: middle class doesn't set inheritable, base value passes through.
+// CHECK-LABEL: NS::InheritableOpF declarations
+// CHECK: int getInheritedHelper();
+
+// DEFS-LABEL: NS::InheritableOpA definitions
+// DEFS: int InheritableOpA::getInheritedHelper() { return 42; }
+// DEFS: void InheritableOpA::doA() {}
+
+// DEFS-LABEL: NS::InheritableOpD definitions
+// DEFS: int InheritableOpD::getInheritedHelper() { return 42; }
+// DEFS: int InheritableOpD::getMiddleHelper() { return 1; }
+
+// DEFS-LABEL: NS::InheritableOpF definitions
+// DEFS: int InheritableOpF::getInheritedHelper() { return 42; }
+
 // Check default value of `attributes` for the `genInferredTypeCollectiveParamBuilder` builder
 def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let arguments = (ins AnyType:$a, AnyType:$b);
@@ -551,3 +588,67 @@ def _TypeInferredPropOp : NS_Op<"type_inferred_prop_op_with_properties", [
   let results = (outs AnyType:$result);
   let hasCustomAssemblyFormat = 1;
 }
+
+// Tests for inheritableExtraClassDeclaration/Definition on ops.
+class NS_InheritableOp<string mnemonic, list<Trait> traits = []>
+    : NS_Op<mnemonic, traits> {
+  let inheritableExtraClassDeclaration = [{
+    int getInheritedHelper();
+  }];
+  let inheritableExtraClassDefinition = [{
+    int $cppClass::getInheritedHelper() { return 42; }
+  }];
+}
+
+// Both inheritable and regular extra members are emitted, inheritable first.
+def NS_InheritableOpA : NS_InheritableOp<"inheritable_op_a"> {
+  let extraClassDeclaration = [{
+    void doA();
+  }];
+  let extraClassDefinition = [{
+    void $cppClass::doA() {}
+  }];
+}
+
+// Only the inherited extra members (no extraClassDeclaration).
+def NS_InheritableOpB : NS_InheritableOp<"inheritable_op_b"> {}
+
+// Setting both inheritable fields to empty discards the inherited members.
+def NS_InheritableOpC : NS_InheritableOp<"inheritable_op_c"> {
+  let inheritableExtraClassDeclaration = [{}];
+  let inheritableExtraClassDefinition = [{}];
+}
+
+// Middle of the hierarchy: NS_Op -> NS_InheritableOp -> NS_InheritableMiddleOp.
+// The middle class adds its own inheritable members; concrete ops below it get
+// the base members followed by the middle ones (values accumulate).
+class NS_InheritableMiddleOp<string mnemonic, list<Trait> traits = []>
+    : NS_InheritableOp<mnemonic, traits> {
+  let inheritableExtraClassDeclaration = [{
+    int getMiddleHelper();
+  }];
+  let inheritableExtraClassDefinition = [{
+    int $cppClass::getMiddleHelper() { return 1; }
+  }];
+}
+
+// Concrete op below the middle class gets both the base and middle values.
+def NS_InheritableOpD : NS_InheritableMiddleOp<"inheritable_op_d"> {
+}
+
+// Same, plus its own extraClassDeclaration (emitted after inherited members).
+def NS_InheritableOpE : NS_InheritableMiddleOp<"inheritable_op_e"> {
+  let extraClassDeclaration = [{
+    void doE();
+  }];
+}
+
+// Middle class that does not set the inheritable fields: the base class values
+// pass through unchanged to concrete ops.
+class NS_InheritablePassthroughOp<string mnemonic, list<Trait> traits = []>
+    : NS_InheritableOp<mnemonic, traits> {
+}
+
+// Concrete op gets only the base class inheritable members.
+def NS_InheritableOpF : NS_InheritablePassthroughOp<"inheritable_op_f"> {
+}
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index b9e3a7954e361..74524e192493f 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -153,3 +153,85 @@ def E_IntegerType : TestType<"Integer"> {
 // DECL-NEXT: /// Return true if this is an unsigned integer type.
 // DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; }
 }
+
+// Tests for inheritableExtraClassDeclaration/Definition on type defs.
+
+class InheritableTestType<string name> : TypeDef<Test_Dialect, name> {
+  let inheritableExtraClassDeclaration = [{
+    int getInheritedHelper();
+  }];
+  let inheritableExtraClassDefinition = [{
+    int $cppClass::getInheritedHelper() { return 42; }
+  }];
+}
+
+// Both inheritable and regular extra members are emitted, inheritable first.
+def F_InheritableTypeA : InheritableTestType<"InheritableA"> {
+  let typeName = "test.inheritable_a";
+  let extraClassDeclaration = [{
+    void doA();
+  }];
+  let extraClassDefinition = [{
+    void $cppClass::doA() {}
+  }];
+}
+
+// DECL-LABEL: class InheritableAType
+// DECL: int getInheritedHelper();
+// DECL: void doA();
+
+// DEF-LABEL: int InheritableAType::getInheritedHelper()
+// DEF: return 42;
+// DEF-LABEL: void InheritableAType::doA()
+
+// Only the inherited extra members (no extraClassDeclaration).
+def G_InheritableTypeB : InheritableTestType<"InheritableB"> {
+  let typeName = "test.inheritable_b";
+}
+
+// DECL-LABEL: class InheritableBType
+// DECL: int getInheritedHelper();
+
+// Setting both inheritable fields to empty discards the inherited members.
+def H_InheritableTypeC : InheritableTestType<"InheritableC"> {
+  let typeName = "test.inheritable_c";
+  let inheritableExtraClassDeclaration = [{}];
+  let inheritableExtraClassDefinition = [{}];
+}
+
+// DECL-LABEL: class InheritableCType
+// DECL-NOT: int getInheritedHelper();
+
+// Middle of the hierarchy: adds its own members on top of the base ones.
+class InheritableMiddleType<string name> : InheritableTestType<name> {
+  let inheritableExtraClassDeclaration = [{
+    int getMiddleHelper();
+  }];
+  let inheritableExtraClassDefinition = [{
+    int $cppClass::getMiddleHelper() { return 1; }
+  }];
+}
+
+// A concrete type below the middle class gets base members, then middle ones.
+def I_InheritableTypeD : InheritableMiddleType<"InheritableD"> {
+  let typeName = "test.inheritable_d";
+}
+
+// DECL-LABEL: class InheritableDType
+// DECL: int getInheritedHelper();
+// DECL: int getMiddleHelper();
+
+// DEF-LABEL: int InheritableDType::getInheritedHelper()
+// DEF: return 42;
+// DEF-LABEL: int InheritableDType::getMiddleHelper()
+// DEF: return 1;
+
+// Passthrough: a middle class with no value of its own keeps the base members.
+class InheritablePassthroughType<string name> : InheritableTestType<name> {}
+
+def J_InheritableTypeE : InheritablePassthroughType<"InheritableE"> {
+  let typeName = "test.inheritable_e";
+}
+
+// DECL-LABEL: class InheritableEType
+// DECL: int getInheritedHelper();
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 031e03071842f..de4c3be8f2377 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -273,7 +273,9 @@ void DefGen::createParentWithTraits() {
 /// Include declarations specified on NativeTrait
 static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
   SmallVector<StringRef> extraDeclarations;
-  // Include extra class declarations from NativeTrait
+  // Include inheritable extra class declarations.
+  def.getInheritableExtraDecls(extraDeclarations);
+  // Include extra class declarations from NativeTrait.
   for (const auto &trait : def.getTraits()) {
     if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
       StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
@@ -292,7 +294,9 @@ static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
 /// replaced by the C++ class name.
 static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
   SmallVector<StringRef> extraDefinitions;
-  // Include extra class definitions from NativeTrait
+  // Include inheritable extra class definitions.
+  def.getInheritableExtraDefs(extraDefinitions);
+  // Include extra class definitions from NativeTrait.
   for (const auto &trait : def.getTraits()) {
     if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
       StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index d1f1e85371133..6c1b4d6cf3137 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1164,7 +1164,9 @@ static void genPropertyVerifier(
 /// Include declarations specified on NativeTrait
 static std::string formatExtraDeclarations(const Operator &op) {
   SmallVector<StringRef> extraDeclarations;
-  // Include extra class declarations from NativeTrait
+  // Include inheritable extra class declarations.
+  op.getInheritableExtraClassDeclarations(extraDeclarations);
+  // Include extra class declarations from NativeTrait.
   for (const auto &trait : op.getTraits()) {
     if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
       StringRef value = opTrait->getExtraConcreteClassDeclaration();
@@ -1182,7 +1184,9 @@ static std::string formatExtraDeclarations(const Operator &op) {
 /// Include declarations specified on NativeTrait
 static std::string formatExtraDefinitions(const Operator &op) {
   SmallVector<StringRef> extraDefinitions;
-  // Include extra class definitions from NativeTrait
+  // Include inheritable extra class definitions.
+  op.getInheritableExtraClassDefinitions(extraDefinitions);
+  // Include extra class definitions from NativeTrait.
   for (const auto &trait : op.getTraits()) {
     if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
       StringRef value = opTrait->getExtraConcreteClassDefinition();



More information about the Mlir-commits mailing list