[Mlir-commits] [mlir] [mlir][python] namespace generated enums in python (PR #77830)

Maksim Levental llvmlistbot at llvm.org
Fri Jan 12 15:54:14 PST 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/77830

>From 184e9fe2d5fef086dc45483c0f3abc8eb31f98d2 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 11 Jan 2024 14:56:18 -0600
Subject: [PATCH 1/2] namespace generated enums in python

---
 .../mlir-tblgen/EnumPythonBindingGen.cpp      | 66 +++++++++++--------
 mlir/tools/mlir-tblgen/OpGenHelpers.cpp       | 12 ++++
 mlir/tools/mlir-tblgen/OpGenHelpers.h         |  3 +
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 26 ++++++--
 4 files changed, 72 insertions(+), 35 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772edb..4c6e04ad12362f6 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -105,7 +105,14 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
     return true;
   }
 
-  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
+  llvm::SmallVector<StringRef> namespaces;
+  enumAttr.getStorageType().ltrim("::").split(namespaces, "::");
+  namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
+  std::string namespace_ = getAttributeNameSpace(namespaces);
+  if (!namespace_.empty())
+    namespace_ += "_";
+
+  os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
                       enumAttr.getAttrDefName());
   os << llvm::formatv("def _{0}(x, context):\n",
                       enumAttr.getAttrDefName().lower());
@@ -120,11 +127,33 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
 /// Emits an attribute builder for the given dialect enum attribute to support
 /// automatic conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
-                                            StringRef formatString,
+static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
                                             raw_ostream &os) {
-  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
-  os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+  StringRef mnemonic = attr.getMnemonic().value();
+  std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+  StringRef dialect = attr.getDialect().getName();
+  std::string formatString;
+  if (assemblyFormat == "`<` $value `>`")
+    formatString =
+        llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
+  else if (assemblyFormat == "$value")
+    formatString =
+        llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
+  else {
+    llvm::errs()
+        << "unsupported assembly format for python enum bindings generation";
+    return true;
+  }
+
+  llvm::SmallVector<StringRef> namespaces;
+  attr.getStorageNamespace().ltrim("::").split(namespaces, "::");
+  std::string namespace_ = getAttributeNameSpace(namespaces);
+  if (!namespace_.empty())
+    namespace_ += "_";
+
+  os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
+                      attr.getName());
+  os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
   os << llvm::formatv("    return "
                       "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
                       formatString);
@@ -142,29 +171,10 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
     emitEnumClass(enumAttr, os);
     emitAttributeBuilder(enumAttr, os);
   }
-  for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
-    AttrOrTypeDef attr(&*it);
-    if (!attr.getMnemonic()) {
-      llvm::errs() << "enum case " << attr
-                   << " needs mnemonic for python enum bindings generation";
-      return true;
-    }
-    StringRef mnemonic = attr.getMnemonic().value();
-    std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
-    StringRef dialect = attr.getDialect().getName();
-    if (assemblyFormat == "`<` $value `>`") {
-      emitDialectEnumAttributeBuilder(
-          attr.getName(),
-          llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
-    } else if (assemblyFormat == "$value") {
-      emitDialectEnumAttributeBuilder(
-          attr.getName(),
-          llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
-    } else {
-      llvm::errs()
-          << "unsupported assembly format for python enum bindings generation";
-      return true;
-    }
+  for (const auto &it :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+    const AttrOrTypeDef attr(&*it);
+    return emitDialectEnumAttributeBuilder(attr, os);
   }
 
   return false;
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index 7fd34df8460d398..fd6250f05a451f5 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -79,4 +79,16 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
   reserved.insert("issubclass");
   reserved.insert("type");
   return reserved.contains(str);
+}
+
+std::string
+mlir::tblgen::getAttributeNameSpace(llvm::SmallVector<StringRef> namespaces) {
+  std::string namespace_;
+  if (namespaces[0] == "mlir")
+    namespace_ = llvm::join(llvm::drop_begin(namespaces), "_");
+  else
+    namespace_ = llvm::join(namespaces, "_");
+  std::transform(namespace_.begin(), namespace_.end(), namespace_.begin(),
+                 tolower);
+  return namespace_;
 }
\ No newline at end of file
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 3dcff14d1221ee3..447d48aeb7a6bb4 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -28,6 +28,9 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
 /// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
 bool isPythonReserved(llvm::StringRef str);
 
+std::string
+getAttributeNameSpace(llvm::SmallVector<llvm::StringRef> namespaces);
+
 } // namespace tblgen
 } // namespace mlir
 
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e73..37e40c1c66d9520 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -529,27 +529,31 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
 /// Template for attribute builder from raw input in the operation builder.
 ///   {0} is the builder argument name;
 ///   {1} is the attribute builder from raw;
-///   {2} is the attribute builder from raw.
+///   {2} is the attribute builder from raw;
+///   {3} is the attribute's dialect.
 /// Use the value the user passed in if either it is already an Attribute or
 /// there is no method registered to make it an Attribute.
 constexpr const char *initAttributeWithBuilderTemplate =
     R"Py(attributes["{1}"] = ({0} if (
     issubclass(type({0}), _ods_ir.Attribute) or
-    not _ods_ir.AttrBuilder.contains('{2}')) else
-      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+    not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
+      (_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
+       else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
 
 /// Template for attribute builder from raw input for optional attribute in the
 /// operation builder.
 ///   {0} is the builder argument name;
 ///   {1} is the attribute builder from raw;
-///   {2} is the attribute builder from raw.
+///   {2} is the attribute builder from raw;
+///   {3} is the attribute's dialect.
 /// Use the value the user passed in if either it is already an Attribute or
 /// there is no method registered to make it an Attribute.
 constexpr const char *initOptionalAttributeWithBuilderTemplate =
     R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
         issubclass(type({0}), _ods_ir.Attribute) or
-        not _ods_ir.AttrBuilder.contains('{2}')) else
-          _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+        not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
+          (_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
+           else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
 
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -677,11 +681,19 @@ populateBuilderLinesAttr(const Operator &op,
       continue;
     }
 
+    llvm::SmallVector<StringRef> namespaces;
+    attribute->attr.getStorageType().ltrim("::").split(namespaces, "::");
+    namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
+    std::string namespace_ = getAttributeNameSpace(namespaces);
+    if (!namespace_.empty())
+      namespace_ += "_";
+
     builderLines.push_back(llvm::formatv(
         attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
             ? initOptionalAttributeWithBuilderTemplate
             : initAttributeWithBuilderTemplate,
-        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+        argNames[i], attribute->name, namespace_,
+        attribute->attr.getAttrDefName()));
   }
 }
 

>From 33a26215e5d84aa7eb05dc53845a6425c5991a65 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 12 Jan 2024 17:53:59 -0600
Subject: [PATCH 2/2] working

---
 mlir/lib/TableGen/Attribute.cpp               |  4 +-
 mlir/python/CMakeLists.txt                    |  9 ++++-
 mlir/python/mlir/dialects/python_test.py      |  1 +
 mlir/test/python/CMakeLists.txt               |  2 +
 mlir/test/python/dialects/python_test.py      | 17 ++++++++
 mlir/test/python/lib/PythonTestDialect.cpp    |  2 +
 mlir/test/python/lib/PythonTestDialect.h      |  1 +
 mlir/test/python/lit.local.cfg                |  1 +
 mlir/test/python/python_test_ops.td           | 19 +++++++++
 .../mlir-tblgen/EnumPythonBindingGen.cpp      | 40 ++++++++++++-------
 mlir/tools/mlir-tblgen/OpGenHelpers.cpp       | 12 +++++-
 mlir/tools/mlir-tblgen/OpGenHelpers.h         |  4 +-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 10 ++---
 13 files changed, 95 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 57c77c74106b964..415d94a919e5ae1 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -53,7 +53,9 @@ bool Attribute::isSymbolRefAttr() const {
   return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
 }
 
-bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
+bool Attribute::isEnumAttr() const {
+  return isSubClassOf("EnumAttrInfo") || isSubClassOf("EnumAttr");
+}
 
 StringRef Attribute::getStorageType() const {
   const auto *init = def->getValueInit("storageType");
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 55c5973e40e525a..3239f69e39f9e12 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -606,12 +606,19 @@ if(MLIR_INCLUDE_TESTS)
     "dialects/_python_test_ops_gen.py"
     -gen-python-op-bindings
     -bind-dialect=python_test)
+  mlir_tablegen(
+    "dialects/_python_test_enums_gen.py"
+    -gen-python-enum-bindings
+    EXTRA_INCLUDES
+      "${MLIR_MAIN_SRC_DIR}/test/python")
   add_public_tablegen_target(PythonTestDialectPyIncGen)
   declare_mlir_python_sources(
     MLIRPythonTestSources.Dialects.PythonTest.ops_gen
     ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
     ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest
-    SOURCES "dialects/_python_test_ops_gen.py")
+    SOURCES
+      "dialects/_python_test_ops_gen.py"
+      "dialects/_python_test_enums_gen.py")
 
   declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension
     MODULE_NAME _mlirPythonTest
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index b5baa80bc767fb3..4c5f54e41b7be6c 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,6 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
+from ._python_test_enums_gen import *
 from .._mlir_libs._mlirPythonTest import (
     TestAttr,
     TestType,
diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt
index 1db957c86819d2c..ddc11236efe5b53 100644
--- a/mlir/test/python/CMakeLists.txt
+++ b/mlir/test/python/CMakeLists.txt
@@ -3,6 +3,8 @@ mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
 mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
 mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
 mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
+mlir_tablegen(lib/PythonTestEnums.h.inc -gen-enum-decls)
+mlir_tablegen(lib/PythonTestEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(lib/PythonTestAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs)
 mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 88761c9d08fe07c..4ba4ccb0be6c3a3 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -5,6 +5,13 @@
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
 import mlir.dialects.arith as arith
+from mlir.dialects import llvm
+from mlir.dialects._llvm_enum_gen import (
+    _llvm_integeroverflowflagsattr as llvm_integeroverflowflagsattr,
+)
+from mlir.dialects._python_test_enums_gen import (
+    _llvm_integeroverflowflagsattr as python_test_integeroverflowflagsattr,
+)
 
 test.register_python_test_dialect(get_dialect_registry())
 
@@ -543,3 +550,13 @@ def testInferTypeOpInterface():
             two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
             # CHECK: f32
             print(two_operands.result.type)
+
+
+# CHECK-LABEL: TEST: testEnumNamespacing
+@run
+def testEnumNamespacing():
+    with Context() as ctx, Location.unknown(ctx):
+        # CHECK: #llvm.overflow<none>
+        print(llvm_integeroverflowflagsattr(llvm.IntegerOverflowFlags.none, ctx))
+        # CHECK: #python_test.overflow<none>
+        print(python_test_integeroverflowflagsattr(test.IntegerOverflowFlags.none, ctx))
diff --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp
index a0ff31504c6918b..738d2c7de5c98ad 100644
--- a/mlir/test/python/lib/PythonTestDialect.cpp
+++ b/mlir/test/python/lib/PythonTestDialect.cpp
@@ -9,10 +9,12 @@
 #include "PythonTestDialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #include "PythonTestDialect.cpp.inc"
 
+#include "PythonTestEnums.cpp.inc"
 #define GET_ATTRDEF_CLASSES
 #include "PythonTestAttributes.cpp.inc"
 
diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
index 044381fcd4728d7..26204ca43c3fa26 100644
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -19,6 +19,7 @@
 #define GET_OP_CLASSES
 #include "PythonTestOps.h.inc"
 
+#include "PythonTestEnums.h.inc"
 #define GET_ATTRDEF_CLASSES
 #include "PythonTestAttributes.h.inc"
 
diff --git a/mlir/test/python/lit.local.cfg b/mlir/test/python/lit.local.cfg
index 12d6e1f22744a25..dccca70dda76d32 100644
--- a/mlir/test/python/lit.local.cfg
+++ b/mlir/test/python/lit.local.cfg
@@ -2,3 +2,4 @@ config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
 if not config.enable_bindings_python:
     config.unsupported = True
 config.excludes.add("python_test_ops.td")
+config.excludes.add("python_test_enums.td")
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 95301985e3fde03..189082f17344234 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -10,6 +10,7 @@
 #define PYTHON_TEST_OPS
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
@@ -48,6 +49,24 @@ def TestType : TestType<"TestType", "test_type">;
 
 def TestAttr : TestAttr<"TestAttr", "test_attr">;
 
+def IOFnone : I32BitEnumAttrCaseNone<"none">;
+def IOFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IOFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerOverflowFlags : I32BitEnumAttr<
+    "IntegerOverflowFlags", "",
+    [IOFnone, IOFnsw, IOFnuw]> {
+  let separator = ", ";
+  let cppNamespace = "python_test";
+  let genSpecializedAttr = 0;
+}
+
+// This is intentionally prefixed with LLVM to test for collision in AttrBuilder.
+def LLVM_IntegerOverflowFlagsAttr :
+    EnumAttr<Python_Test_Dialect, IntegerOverflowFlags, "overflow"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // Operation definitions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 4c6e04ad12362f6..39e8dff0f27bafa 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -97,7 +97,8 @@ static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
 /// Emits an attribute builder for the given enum attribute to support automatic
 /// conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
+static bool emitAttributeBuilderRegistration(const EnumAttr &enumAttr,
+                                             raw_ostream &os) {
   int64_t bitwidth;
   if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
     llvm::errs() << "failed to identify bitwidth of "
@@ -105,10 +106,7 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
     return true;
   }
 
-  llvm::SmallVector<StringRef> namespaces;
-  enumAttr.getStorageType().ltrim("::").split(namespaces, "::");
-  namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
-  std::string namespace_ = getAttributeNameSpace(namespaces);
+  std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
   if (!namespace_.empty())
     namespace_ += "_";
 
@@ -127,8 +125,9 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
 /// Emits an attribute builder for the given dialect enum attribute to support
 /// automatic conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
-                                            raw_ostream &os) {
+static bool emitDialectEnumAttributeBuilderRegistration(const llvm::Record &def,
+                                                        raw_ostream &os) {
+  const AttrOrTypeDef attr(&def);
   StringRef mnemonic = attr.getMnemonic().value();
   std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
   StringRef dialect = attr.getDialect().getName();
@@ -145,15 +144,15 @@ static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
     return true;
   }
 
-  llvm::SmallVector<StringRef> namespaces;
-  attr.getStorageNamespace().ltrim("::").split(namespaces, "::");
-  std::string namespace_ = getAttributeNameSpace(namespaces);
+  EnumAttr enumAttr(def);
+  std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
   if (!namespace_.empty())
     namespace_ += "_";
 
   os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
-                      attr.getName());
-  os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
+                      enumAttr.getAttrDefName());
+  os << llvm::formatv("def _{0}(x, context):\n",
+                      enumAttr.getAttrDefName().lower());
   os << llvm::formatv("    return "
                       "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
                       formatString);
@@ -165,16 +164,27 @@ static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
 static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
                             raw_ostream &os) {
   os << fileHeader;
+  for (const auto &it :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+    EnumAttr *enumAttr;
+    for (const auto &value : it->getValues())
+      if (value.getType()->getAsString() == "EnumAttrInfo")
+        enumAttr = new EnumAttr(value.getValue()->getRecordKeeper().getDef(
+            value.getValue()->getAsString()));
+    if (enumAttr)
+      emitEnumClass(*enumAttr, os);
+  }
   for (auto &it :
        recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
     EnumAttr enumAttr(*it);
     emitEnumClass(enumAttr, os);
-    emitAttributeBuilder(enumAttr, os);
+    if (emitAttributeBuilderRegistration(enumAttr, os))
+      return true;
   }
   for (const auto &it :
        recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
-    const AttrOrTypeDef attr(&*it);
-    return emitDialectEnumAttributeBuilder(attr, os);
+    if (emitDialectEnumAttributeBuilderRegistration(*it, os))
+      return true;
   }
 
   return false;
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index fd6250f05a451f5..84ef9e487f4c692 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -81,8 +81,16 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
   return reserved.contains(str);
 }
 
-std::string
-mlir::tblgen::getAttributeNameSpace(llvm::SmallVector<StringRef> namespaces) {
+std::string mlir::tblgen::getEnumAttributeNameSpace(const EnumAttr &enumAttr) {
+  llvm::SmallVector<StringRef> namespaces;
+  if (enumAttr.getCppNamespace().empty() &&
+      enumAttr.getBaseAttr().isEnumAttr()) {
+    EnumAttr(enumAttr.getBaseAttr().getDef())
+        .getCppNamespace()
+        .ltrim("::")
+        .split(namespaces, "::");
+  } else
+    enumAttr.getCppNamespace().ltrim("::").split(namespaces, "::");
   std::string namespace_;
   if (namespaces[0] == "mlir")
     namespace_ = llvm::join(llvm::drop_begin(namespaces), "_");
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 447d48aeb7a6bb4..23542bf3ce87a9b 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
 #define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
 
+#include "mlir/TableGen/Attribute.h"
 #include "llvm/TableGen/Record.h"
 #include <vector>
 
@@ -28,8 +29,7 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
 /// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
 bool isPythonReserved(llvm::StringRef str);
 
-std::string
-getAttributeNameSpace(llvm::SmallVector<llvm::StringRef> namespaces);
+std::string getEnumAttributeNameSpace(const EnumAttr &enumAttr);
 
 } // namespace tblgen
 } // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 37e40c1c66d9520..2448bbc468474f6 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -538,7 +538,7 @@ constexpr const char *initAttributeWithBuilderTemplate =
     issubclass(type({0}), _ods_ir.Attribute) or
     not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
       (_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
-       else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
+       else _ods_ir.AttrBuilder.get('{2}{3}')({0}, context=_ods_context))))Py";
 
 /// Template for attribute builder from raw input for optional attribute in the
 /// operation builder.
@@ -553,7 +553,7 @@ constexpr const char *initOptionalAttributeWithBuilderTemplate =
         issubclass(type({0}), _ods_ir.Attribute) or
         not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
           (_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
-           else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
+           else _ods_ir.AttrBuilder.get('{2}{3}')({0}, context=_ods_context))))Py";
 
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -681,10 +681,8 @@ populateBuilderLinesAttr(const Operator &op,
       continue;
     }
 
-    llvm::SmallVector<StringRef> namespaces;
-    attribute->attr.getStorageType().ltrim("::").split(namespaces, "::");
-    namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
-    std::string namespace_ = getAttributeNameSpace(namespaces);
+    EnumAttr enumAttr(attribute->attr.getDef());
+    std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
     if (!namespace_.empty())
       namespace_ += "_";
 



More information about the Mlir-commits mailing list