[Mlir-commits] [mlir] [mlir][NFC] Move and rename EnumAttrCase, EnumAttr C++ classes (PR #132650)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 23 17:11:45 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp to a new EnumInfo.h/.cpp and renames them to EnumCase and EnumInfo, respectively.

This doesn't change any of the TableGen files or any user-facing aspects of the enum attribute generation system; it just reorganizes code in order to make the main PR (#132148) shorter.

---

Patch is 74.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132650.diff


15 Files Affected:

- (modified) mlir/include/mlir/TableGen/Attribute.h (-74) 
- (added) mlir/include/mlir/TableGen/EnumInfo.h (+135) 
- (modified) mlir/include/mlir/TableGen/Pattern.h (+6-5) 
- (modified) mlir/lib/TableGen/Attribute.cpp (-94) 
- (modified) mlir/lib/TableGen/CMakeLists.txt (+1) 
- (added) mlir/lib/TableGen/EnumInfo.cpp (+130) 
- (modified) mlir/lib/TableGen/Pattern.cpp (+5-7) 
- (modified) mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp (+26-21) 
- (modified) mlir/tools/mlir-tblgen/EnumsGen.cpp (+85-81) 
- (modified) mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp (+38-37) 
- (modified) mlir/tools/mlir-tblgen/OpDocGen.cpp (+9-8) 
- (modified) mlir/tools/mlir-tblgen/OpFormatGen.cpp (+18-17) 
- (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+4-4) 
- (modified) mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (+46-45) 
- (modified) mlir/tools/mlir-tblgen/TosaUtilsGen.cpp (+3-2) 


``````````diff
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 62720e74849fc..dee81880bacab 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -16,7 +16,6 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Constraint.h"
-#include "llvm/ADT/StringRef.h"
 
 namespace llvm {
 class DefInit;
@@ -136,79 +135,6 @@ class ConstantAttr {
   const llvm::Record *def;
 };
 
-// Wrapper class providing helper methods for accessing enum attribute cases
-// defined in TableGen. This is used for enum attribute case backed by both
-// StringAttr and IntegerAttr.
-class EnumAttrCase : public Attribute {
-public:
-  explicit EnumAttrCase(const llvm::Record *record);
-  explicit EnumAttrCase(const llvm::DefInit *init);
-
-  // Returns the symbol of this enum attribute case.
-  StringRef getSymbol() const;
-
-  // Returns the textual representation of this enum attribute case.
-  StringRef getStr() const;
-
-  // Returns the value of this enum attribute case.
-  int64_t getValue() const;
-
-  // Returns the TableGen definition this EnumAttrCase was constructed from.
-  const llvm::Record &getDef() const;
-};
-
-// Wrapper class providing helper methods for accessing enum attributes defined
-// in TableGen.This is used for enum attribute case backed by both StringAttr
-// and IntegerAttr.
-class EnumAttr : public Attribute {
-public:
-  explicit EnumAttr(const llvm::Record *record);
-  explicit EnumAttr(const llvm::Record &record);
-  explicit EnumAttr(const llvm::DefInit *init);
-
-  static bool classof(const Attribute *attr);
-
-  // Returns true if this is a bit enum attribute.
-  bool isBitEnum() const;
-
-  // Returns the enum class name.
-  StringRef getEnumClassName() const;
-
-  // Returns the C++ namespaces this enum class should be placed in.
-  StringRef getCppNamespace() const;
-
-  // Returns the underlying type.
-  StringRef getUnderlyingType() const;
-
-  // Returns the name of the utility function that converts a value of the
-  // underlying type to the corresponding symbol.
-  StringRef getUnderlyingToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a string to the
-  // corresponding symbol.
-  StringRef getStringToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a symbol to the
-  // corresponding string.
-  StringRef getSymbolToStringFnName() const;
-
-  // Returns the return type of the utility function that converts a symbol to
-  // the corresponding string.
-  StringRef getSymbolToStringFnRetType() const;
-
-  // Returns the name of the utilit function that returns the max enum value
-  // used within the enum class.
-  StringRef getMaxEnumValFnName() const;
-
-  // Returns all allowed cases for this enum attribute.
-  std::vector<EnumAttrCase> getAllCases() const;
-
-  bool genSpecializedAttr() const;
-  const llvm::Record *getBaseAttrClass() const;
-  StringRef getSpecializedAttrClassName() const;
-  bool printBitEnumPrimaryGroups() const;
-};
-
 // Name of infer type op interface.
 extern const char *inferTypeOpInterface;
 
diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h
new file mode 100644
index 0000000000000..196267864f325
--- /dev/null
+++ b/mlir/include/mlir/TableGen/EnumInfo.h
@@ -0,0 +1,135 @@
+//===- EnumInfo.h - EnumInfo wrapper class --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// EnumInfo wrapper to simplify using a TableGen Record defining an enum
+// via EnumInfo and its `EnumCase`s.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ENUMINFO_H_
+#define MLIR_TABLEGEN_ENUMINFO_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DefInit;
+class Record;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class around enum cases defined in TableGen.
+class EnumCase {
+public:
+  explicit EnumCase(const llvm::Record *record);
+  explicit EnumCase(const llvm::DefInit *init);
+
+  // Returns the symbol of this enum case.
+  StringRef getSymbol() const;
+
+  // Returns the textual representation of this enum case.
+  StringRef getStr() const;
+
+  // Returns the value of this enum case.
+  int64_t getValue() const;
+
+  // Returns the TableGen definition this EnumCase was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this enum case.
+  const llvm::Record *def;
+};
+
+// Wrapper class providing helper methods for accessing enums defined
+// in TableGen using EnumInfo. Some methods are only applicable when
+// the enum is also an attribute, or only when it is a bit enum.
+class EnumInfo {
+public:
+  explicit EnumInfo(const llvm::Record *record);
+  explicit EnumInfo(const llvm::Record &record);
+  explicit EnumInfo(const llvm::DefInit *init);
+
+  // Returns true if the underlying TableGen record is a subclass of the named
+  // TableGen class.
+  bool isSubClassOf(StringRef className) const;
+
+  // Returns true if this enum is an EnumAttrInfo, thus making it define an
+  // attribute.
+  bool isEnumAttr() const;
+
+  // Returns an `Attribute` wrapper around this EnumInfo if it defines an
+  // attribute, or std::nullopt otherwise.
+  std::optional<Attribute> asEnumAttr() const;
+
+  // Returns true if this is a bit enum.
+  bool isBitEnum() const;
+
+  // Returns the enum class name.
+  StringRef getEnumClassName() const;
+
+  // Returns the C++ namespaces this enum class should be placed in.
+  StringRef getCppNamespace() const;
+
+  // Returns the summary of the enum.
+  StringRef getSummary() const;
+
+  // Returns the description of the enum.
+  StringRef getDescription() const;
+
+  // Returns the underlying type.
+  StringRef getUnderlyingType() const;
+
+  // Returns the name of the utility function that converts a value of the
+  // underlying type to the corresponding symbol.
+  StringRef getUnderlyingToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a string to the
+  // corresponding symbol.
+  StringRef getStringToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a symbol to the
+  // corresponding string.
+  StringRef getSymbolToStringFnName() const;
+
+  // Returns the return type of the utility function that converts a symbol to
+  // the corresponding string.
+  StringRef getSymbolToStringFnRetType() const;
+
+  // Returns the name of the utility function that returns the max enum value
+  // used within the enum class.
+  StringRef getMaxEnumValFnName() const;
+
+  // Returns all allowed cases for this enum.
+  std::vector<EnumCase> getAllCases() const;
+
+  // Only applicable for enum attributes.
+
+  bool genSpecializedAttr() const;
+  const llvm::Record *getBaseAttrClass() const;
+  StringRef getSpecializedAttrClassName() const;
+
+  // Only applicable for bit enums.
+
+  bool printBitEnumPrimaryGroups() const;
+
+  // Returns the TableGen definition this EnumInfo was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this enum.
+  const llvm::Record *def;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_ENUMINFO_H_
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 80f38fdeffee0..1c9e128f0a0fb 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Hashing.h"
@@ -78,8 +79,8 @@ class DagLeaf {
   // Returns true if this DAG leaf is specifying a constant attribute.
   bool isConstantAttr() const;
 
-  // Returns true if this DAG leaf is specifying an enum attribute case.
-  bool isEnumAttrCase() const;
+  // Returns true if this DAG leaf is specifying an enum case.
+  bool isEnumCase() const;
 
   // Returns true if this DAG leaf is specifying a string attribute.
   bool isStringAttr() const;
@@ -90,9 +91,9 @@ class DagLeaf {
   // Returns this DAG leaf as an constant attribute. Asserts if fails.
   ConstantAttr getAsConstantAttr() const;
 
-  // Returns this DAG leaf as an enum attribute case.
-  // Precondition: isEnumAttrCase()
-  EnumAttrCase getAsEnumAttrCase() const;
+  // Returns this DAG leaf as an enum case.
+  // Precondition: isEnumCase()
+  EnumCase getAsEnumCase() const;
 
   // Returns the matching condition template inside this DAG leaf. Assumes the
   // leaf is an operand/attribute matcher and asserts otherwise.
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index f9fc58a40f334..142d194260942 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -146,98 +146,4 @@ StringRef ConstantAttr::getConstantValue() const {
   return def->getValueAsString("value");
 }
 
-EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrCaseInfo") &&
-         "must be subclass of TableGen 'EnumAttrInfo' class");
-}
-
-EnumAttrCase::EnumAttrCase(const DefInit *init)
-    : EnumAttrCase(init->getDef()) {}
-
-StringRef EnumAttrCase::getSymbol() const {
-  return def->getValueAsString("symbol");
-}
-
-StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
-
-int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
-
-const Record &EnumAttrCase::getDef() const { return *def; }
-
-EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrInfo") &&
-         "must be subclass of TableGen 'EnumAttr' class");
-}
-
-EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}
-
-EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}
-
-bool EnumAttr::classof(const Attribute *attr) {
-  return attr->isSubClassOf("EnumAttrInfo");
-}
-
-bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
-
-StringRef EnumAttr::getEnumClassName() const {
-  return def->getValueAsString("className");
-}
-
-StringRef EnumAttr::getCppNamespace() const {
-  return def->getValueAsString("cppNamespace");
-}
-
-StringRef EnumAttr::getUnderlyingType() const {
-  return def->getValueAsString("underlyingType");
-}
-
-StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
-  return def->getValueAsString("underlyingToSymbolFnName");
-}
-
-StringRef EnumAttr::getStringToSymbolFnName() const {
-  return def->getValueAsString("stringToSymbolFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnName() const {
-  return def->getValueAsString("symbolToStringFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnRetType() const {
-  return def->getValueAsString("symbolToStringFnRetType");
-}
-
-StringRef EnumAttr::getMaxEnumValFnName() const {
-  return def->getValueAsString("maxEnumValFnName");
-}
-
-std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
-  const auto *inits = def->getValueAsListInit("enumerants");
-
-  std::vector<EnumAttrCase> cases;
-  cases.reserve(inits->size());
-
-  for (const Init *init : *inits) {
-    cases.emplace_back(cast<DefInit>(init));
-  }
-
-  return cases;
-}
-
-bool EnumAttr::genSpecializedAttr() const {
-  return def->getValueAsBit("genSpecializedAttr");
-}
-
-const Record *EnumAttr::getBaseAttrClass() const {
-  return def->getValueAsDef("baseAttrClass");
-}
-
-StringRef EnumAttr::getSpecializedAttrClassName() const {
-  return def->getValueAsString("specializedAttrClassName");
-}
-
-bool EnumAttr::printBitEnumPrimaryGroups() const {
-  return def->getValueAsBit("printBitEnumPrimaryGroups");
-}
-
 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index c4104e644147c..a90c55847718e 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -20,6 +20,7 @@ llvm_add_library(MLIRTableGen STATIC
   CodeGenHelpers.cpp
   Constraint.cpp
   Dialect.cpp
+  EnumInfo.cpp
   Format.cpp
   GenInfo.cpp
   Interfaces.cpp
diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp
new file mode 100644
index 0000000000000..9f491d30f0e7f
--- /dev/null
+++ b/mlir/lib/TableGen/EnumInfo.cpp
@@ -0,0 +1,130 @@
+//===- EnumInfo.cpp - EnumInfo wrapper class ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/EnumInfo.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::DefInit;
+using llvm::Init;
+using llvm::Record;
+
+EnumCase::EnumCase(const Record *record) : def(record) {
+  assert(def->isSubClassOf("EnumAttrCaseInfo") &&
+         "must be subclass of TableGen 'EnumAttrCaseInfo' class");
+}
+
+EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {}
+
+StringRef EnumCase::getSymbol() const {
+  return def->getValueAsString("symbol");
+}
+
+StringRef EnumCase::getStr() const { return def->getValueAsString("str"); }
+
+int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); }
+
+const Record &EnumCase::getDef() const { return *def; }
+
+EnumInfo::EnumInfo(const Record *record) : def(record) {
+  assert(isSubClassOf("EnumAttrInfo") &&
+         "must be subclass of TableGen 'EnumAttrInfo' class");
+}
+
+EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {}
+
+EnumInfo::EnumInfo(const DefInit *init) : EnumInfo(init->getDef()) {}
+
+bool EnumInfo::isSubClassOf(StringRef className) const {
+  return def->isSubClassOf(className);
+}
+
+bool EnumInfo::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
+
+std::optional<Attribute> EnumInfo::asEnumAttr() const {
+  if (isEnumAttr())
+    return Attribute(def);
+  return std::nullopt;
+}
+
+bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
+
+StringRef EnumInfo::getEnumClassName() const {
+  return def->getValueAsString("className");
+}
+
+StringRef EnumInfo::getSummary() const {
+  return def->getValueAsString("summary");
+}
+
+StringRef EnumInfo::getDescription() const {
+  return def->getValueAsString("description");
+}
+
+StringRef EnumInfo::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
+StringRef EnumInfo::getUnderlyingType() const {
+  return def->getValueAsString("underlyingType");
+}
+
+StringRef EnumInfo::getUnderlyingToSymbolFnName() const {
+  return def->getValueAsString("underlyingToSymbolFnName");
+}
+
+StringRef EnumInfo::getStringToSymbolFnName() const {
+  return def->getValueAsString("stringToSymbolFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnName() const {
+  return def->getValueAsString("symbolToStringFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnRetType() const {
+  return def->getValueAsString("symbolToStringFnRetType");
+}
+
+StringRef EnumInfo::getMaxEnumValFnName() const {
+  return def->getValueAsString("maxEnumValFnName");
+}
+
+std::vector<EnumCase> EnumInfo::getAllCases() const {
+  const auto *inits = def->getValueAsListInit("enumerants");
+
+  std::vector<EnumCase> cases;
+  cases.reserve(inits->size());
+
+  for (const Init *init : *inits) {
+    cases.emplace_back(cast<DefInit>(init));
+  }
+
+  return cases;
+}
+
+bool EnumInfo::genSpecializedAttr() const {
+  return isSubClassOf("EnumAttrInfo") &&
+         def->getValueAsBit("genSpecializedAttr");
+}
+
+const Record *EnumInfo::getBaseAttrClass() const {
+  return def->getValueAsDef("baseAttrClass");
+}
+
+StringRef EnumInfo::getSpecializedAttrClassName() const {
+  return def->getValueAsString("specializedAttrClassName");
+}
+
+bool EnumInfo::printBitEnumPrimaryGroups() const {
+  return def->getValueAsBit("printBitEnumPrimaryGroups");
+}
+
+const Record &EnumInfo::getDef() const { return *def; }
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index ac8c49c72d384..73e2803c21dae 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -57,9 +57,7 @@ bool DagLeaf::isNativeCodeCall() const {
 
 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
 
-bool DagLeaf::isEnumAttrCase() const {
-  return isSubClassOf("EnumAttrCaseInfo");
-}
+bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); }
 
 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
@@ -74,9 +72,9 @@ ConstantAttr DagLeaf::getAsConstantAttr() const {
   return ConstantAttr(cast<DefInit>(def));
 }
 
-EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
-  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
-  return EnumAttrCase(cast<DefInit>(def));
+EnumCase DagLeaf::getAsEnumCase() const {
+  assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
+  return EnumCase(cast<DefInit>(def));
 }
 
 std::string DagLeaf::getConditionTemplate() const {
@@ -776,7 +774,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
         } else {
           auto constraint = leaf.getAsConstraint();
-          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
                         leaf.isConstantAttr() ||
                         constraint.getKind() == Constraint::Kind::CK_Attr;
 
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 3f660ae151c74..5d4d9e90fff67 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -15,6 +15,7 @@
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
@@ -44,14 +45,14 @@ static std::string makePythonEnumCaseName(StringRef name) {
 }
 
 /// Emits the Python class for the given enum.
-static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
-  os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
-                enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
-  if (!enumAttr.getSummary().empty())
-    os << formatv("    \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
+static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
+  os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(),
+                enumInfo.isBitEnum() ? "IntFlag" : "IntEnum");
+  if (!enumInfo.getSummary().empty())
+    os << formatv("    \"\"\"{0}\"\"\"\n", enumInfo.getSummary());
   os << "\n";
 
-  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
+  for (const EnumCase &enumCase : enumInfo.getAllCases()) {
     os << formatv("    {0} = {1}\n",
                   makePythonEnumCaseName(enumCase.getSymbol()),
                   enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
@@ -60,7 +61,7 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
 
   os << "\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     os << formatv("    def __iter__(self):\n"
                   "        return iter([case for case in type(self) if "
                   "(self & case) is case])\n");
@@ -70,17 +71,17 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
   }
 
   os << formatv("    def __str__(self):\n");
-  if (enumAttr.isBitEnum())
+  if (enumInfo.isBitEnum())
     os << formatv("        if len(self) > 1:\n"
                   "            return \"{0}\".join(map(str, self))\n",
-                  enumAttr.getDef().getValueAs...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/132650


More information about the Mlir-commits mailing list