[Mlir-commits] [mlir] [mlir][python] namespace generated enums in python (PR #77830)
Maksim Levental
llvmlistbot at llvm.org
Fri Jan 12 18:22:14 PST 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/77830
>From 184e9fe2d5fef086dc45483c0f3abc8eb31f98d2 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 11 Jan 2024 14:56:18 -0600
Subject: [PATCH 1/2] namespace generated enums in python
---
.../mlir-tblgen/EnumPythonBindingGen.cpp | 66 +++++++++++--------
mlir/tools/mlir-tblgen/OpGenHelpers.cpp | 12 ++++
mlir/tools/mlir-tblgen/OpGenHelpers.h | 3 +
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 26 ++++++--
4 files changed, 72 insertions(+), 35 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772ed..4c6e04ad12362f 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -105,7 +105,14 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true;
}
- os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
+ llvm::SmallVector<StringRef> namespaces;
+ enumAttr.getStorageType().ltrim("::").split(namespaces, "::");
+ namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
+ std::string namespace_ = getAttributeNameSpace(namespaces);
+ if (!namespace_.empty())
+ namespace_ += "_";
+
+ os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n",
enumAttr.getAttrDefName().lower());
@@ -120,11 +127,33 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
- StringRef formatString,
+static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
raw_ostream &os) {
- os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
- os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+ StringRef mnemonic = attr.getMnemonic().value();
+ std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+ StringRef dialect = attr.getDialect().getName();
+ std::string formatString;
+ if (assemblyFormat == "`<` $value `>`")
+ formatString =
+ llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
+ else if (assemblyFormat == "$value")
+ formatString =
+ llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
+ else {
+ llvm::errs()
+ << "unsupported assembly format for python enum bindings generation";
+ return true;
+ }
+
+ llvm::SmallVector<StringRef> namespaces;
+ attr.getStorageNamespace().ltrim("::").split(namespaces, "::");
+ std::string namespace_ = getAttributeNameSpace(namespaces);
+ if (!namespace_.empty())
+ namespace_ += "_";
+
+ os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
+ attr.getName());
+ os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
os << llvm::formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
@@ -142,29 +171,10 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os);
}
- for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
- AttrOrTypeDef attr(&*it);
- if (!attr.getMnemonic()) {
- llvm::errs() << "enum case " << attr
- << " needs mnemonic for python enum bindings generation";
- return true;
- }
- StringRef mnemonic = attr.getMnemonic().value();
- std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
- StringRef dialect = attr.getDialect().getName();
- if (assemblyFormat == "`<` $value `>`") {
- emitDialectEnumAttributeBuilder(
- attr.getName(),
- llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
- } else if (assemblyFormat == "$value") {
- emitDialectEnumAttributeBuilder(
- attr.getName(),
- llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
- } else {
- llvm::errs()
- << "unsupported assembly format for python enum bindings generation";
- return true;
- }
+ for (const auto &it :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+ const AttrOrTypeDef attr(&*it);
+ return emitDialectEnumAttributeBuilder(attr, os);
}
return false;
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index 7fd34df8460d39..fd6250f05a451f 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -79,4 +79,16 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
reserved.insert("issubclass");
reserved.insert("type");
return reserved.contains(str);
+}
+
+std::string
+mlir::tblgen::getAttributeNameSpace(llvm::SmallVector<StringRef> namespaces) {
+ std::string namespace_;
+ if (namespaces[0] == "mlir")
+ namespace_ = llvm::join(llvm::drop_begin(namespaces), "_");
+ else
+ namespace_ = llvm::join(namespaces, "_");
+ std::transform(namespace_.begin(), namespace_.end(), namespace_.begin(),
+ tolower);
+ return namespace_;
}
\ No newline at end of file
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 3dcff14d1221ee..447d48aeb7a6bb 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -28,6 +28,9 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
bool isPythonReserved(llvm::StringRef str);
+std::string
+getAttributeNameSpace(llvm::SmallVector<llvm::StringRef> namespaces);
+
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..37e40c1c66d952 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -529,27 +529,31 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
-/// {2} is the attribute builder from raw.
+/// {2} is the attribute builder from raw;
+/// {3} is the attribute's dialect.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
- not _ods_ir.AttrBuilder.contains('{2}')) else
- _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+ not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
+ (_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
+ else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
-/// {2} is the attribute builder from raw.
+/// {2} is the attribute builder from raw;
+/// {3} is the attribute's dialect.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
- not _ods_ir.AttrBuilder.contains('{2}')) else
- _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+ not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
+ (_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
+ else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -677,11 +681,19 @@ populateBuilderLinesAttr(const Operator &op,
continue;
}
+ llvm::SmallVector<StringRef> namespaces;
+ attribute->attr.getStorageType().ltrim("::").split(namespaces, "::");
+ namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
+ std::string namespace_ = getAttributeNameSpace(namespaces);
+ if (!namespace_.empty())
+ namespace_ += "_";
+
builderLines.push_back(llvm::formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
- argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+ argNames[i], attribute->name, namespace_,
+ attribute->attr.getAttrDefName()));
}
}
>From a2288e5f94cc16263f19590e22c18a34dff2c8d6 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 12 Jan 2024 17:53:59 -0600
Subject: [PATCH 2/2] derive enum builder namespace from cppNamespace; add python_test enums
---
mlir/lib/TableGen/Attribute.cpp | 4 ++-
mlir/python/CMakeLists.txt | 9 +++++-
mlir/python/mlir/dialects/python_test.py | 1 +
.../test/mlir-tblgen/enums-python-bindings.td | 5 ++--
mlir/test/mlir-tblgen/op-python-bindings.td | 7 +++--
mlir/test/python/CMakeLists.txt | 2 ++
mlir/test/python/dialects/python_test.py | 17 +++++++++++
mlir/test/python/lib/PythonTestDialect.cpp | 2 ++
mlir/test/python/lib/PythonTestDialect.h | 1 +
mlir/test/python/lit.local.cfg | 1 +
mlir/test/python/python_test_ops.td | 19 ++++++++++++
.../mlir-tblgen/EnumPythonBindingGen.cpp | 30 +++++++++----------
mlir/tools/mlir-tblgen/OpGenHelpers.cpp | 7 +++--
mlir/tools/mlir-tblgen/OpGenHelpers.h | 4 +--
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 20 ++++++++-----
15 files changed, 95 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 57c77c74106b96..415d94a919e5ae 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -53,7 +53,9 @@ bool Attribute::isSymbolRefAttr() const {
return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
}
-bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
+bool Attribute::isEnumAttr() const {
+ return isSubClassOf("EnumAttrInfo") || isSubClassOf("EnumAttr");
+}
StringRef Attribute::getStorageType() const {
const auto *init = def->getValueInit("storageType");
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 55c5973e40e525..3239f69e39f9e1 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -606,12 +606,19 @@ if(MLIR_INCLUDE_TESTS)
"dialects/_python_test_ops_gen.py"
-gen-python-op-bindings
-bind-dialect=python_test)
+ mlir_tablegen(
+ "dialects/_python_test_enums_gen.py"
+ -gen-python-enum-bindings
+ EXTRA_INCLUDES
+ "${MLIR_MAIN_SRC_DIR}/test/python")
add_public_tablegen_target(PythonTestDialectPyIncGen)
declare_mlir_python_sources(
MLIRPythonTestSources.Dialects.PythonTest.ops_gen
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest
- SOURCES "dialects/_python_test_ops_gen.py")
+ SOURCES
+ "dialects/_python_test_ops_gen.py"
+ "dialects/_python_test_enums_gen.py")
declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension
MODULE_NAME _mlirPythonTest
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index b5baa80bc767fb..4c5f54e41b7be6 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
+from ._python_test_enums_gen import *
from .._mlir_libs._mlirPythonTest import (
TestAttr,
TestType,
diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index 1c5567f54a5f4b..51de9b537fdf7b 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -70,6 +70,7 @@ def TestBitEnum
]> {
let genSpecializedAttr = 0;
let separator = " | ";
+ let cppNamespace = "test";
}
def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
@@ -96,11 +97,11 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
// CHECK: return "other"
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")
-// CHECK: @register_attribute_builder("TestBitEnum")
+// CHECK: @register_attribute_builder("test_TestBitEnum")
// CHECK: def _testbitenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
-// CHECK: @register_attribute_builder("TestBitEnum_Attr")
+// CHECK: @register_attribute_builder("test_TestBitEnum_Attr")
// CHECK: def _testbitenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index f7df8ba2df0ae2..a9b652f036cd84 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -123,9 +123,10 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: attributes["i32attr"] = (i32attr if (
- // CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
- // CHECK-NEXT: not _ods_ir.AttrBuilder.contains('I32Attr')
- // CHECK-NEXT: _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
+ // CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
+ // CHECK-NEXT: not (_ods_ir.AttrBuilder.contains('I32Attr') or _ods_ir.AttrBuilder.contains('I32Attr'))) else
+ // CHECK-NEXT: (_ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context) if _ods_ir.AttrBuilder.contains('I32Attr')
+ // CHECK-NEXT: else _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)))
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt
index 1db957c86819d2..ddc11236efe5b5 100644
--- a/mlir/test/python/CMakeLists.txt
+++ b/mlir/test/python/CMakeLists.txt
@@ -3,6 +3,8 @@ mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
+mlir_tablegen(lib/PythonTestEnums.h.inc -gen-enum-decls)
+mlir_tablegen(lib/PythonTestEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(lib/PythonTestAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs)
mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 88761c9d08fe07..4ba4ccb0be6c3a 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -5,6 +5,13 @@
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
+from mlir.dialects import llvm
+from mlir.dialects._llvm_enum_gen import (
+ _llvm_integeroverflowflagsattr as llvm_integeroverflowflagsattr,
+)
+from mlir.dialects._python_test_enums_gen import (
+ _llvm_integeroverflowflagsattr as python_test_integeroverflowflagsattr,
+)
test.register_python_test_dialect(get_dialect_registry())
@@ -543,3 +550,13 @@ def testInferTypeOpInterface():
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
# CHECK: f32
print(two_operands.result.type)
+
+
+# CHECK-LABEL: TEST: testEnumNamespacing
+@run
+def testEnumNamespacing():
+ with Context() as ctx, Location.unknown(ctx):
+ # CHECK: #llvm.overflow<none>
+ print(llvm_integeroverflowflagsattr(llvm.IntegerOverflowFlags.none, ctx))
+ # CHECK: #python_test.overflow<none>
+ print(python_test_integeroverflowflagsattr(test.IntegerOverflowFlags.none, ctx))
diff --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp
index a0ff31504c6918..738d2c7de5c98a 100644
--- a/mlir/test/python/lib/PythonTestDialect.cpp
+++ b/mlir/test/python/lib/PythonTestDialect.cpp
@@ -9,10 +9,12 @@
#include "PythonTestDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "PythonTestDialect.cpp.inc"
+#include "PythonTestEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "PythonTestAttributes.cpp.inc"
diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
index 044381fcd4728d..26204ca43c3fa2 100644
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -19,6 +19,7 @@
#define GET_OP_CLASSES
#include "PythonTestOps.h.inc"
+#include "PythonTestEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "PythonTestAttributes.h.inc"
diff --git a/mlir/test/python/lit.local.cfg b/mlir/test/python/lit.local.cfg
index 12d6e1f22744a2..dccca70dda76d3 100644
--- a/mlir/test/python/lit.local.cfg
+++ b/mlir/test/python/lit.local.cfg
@@ -2,3 +2,4 @@ config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
if not config.enable_bindings_python:
config.unsupported = True
config.excludes.add("python_test_ops.td")
+config.excludes.add("python_test_enums.td")
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 95301985e3fde0..189082f1734423 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -10,6 +10,7 @@
#define PYTHON_TEST_OPS
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -48,6 +49,24 @@ def TestType : TestType<"TestType", "test_type">;
def TestAttr : TestAttr<"TestAttr", "test_attr">;
+def IOFnone : I32BitEnumAttrCaseNone<"none">;
+def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerOverflowFlags : I32BitEnumAttr<
+ "IntegerOverflowFlags", "",
+ [IOFnone, IOFnsw, IOFnuw]> {
+ let separator = ", ";
+ let cppNamespace = "python_test";
+ let genSpecializedAttr = 0;
+}
+
+// This is intentionally prefixed with LLVM to test for collision in AttrBuilder.
+def LLVM_IntegerOverflowFlagsAttr :
+ EnumAttr<Python_Test_Dialect, IntegerOverflowFlags, "overflow"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// Operation definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 4c6e04ad12362f..9d17faca067691 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -97,7 +97,8 @@ static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
/// Emits an attribute builder for the given enum attribute to support automatic
/// conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
-static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
+static bool emitAttributeBuilderRegistration(const EnumAttr &enumAttr,
+ raw_ostream &os) {
int64_t bitwidth;
if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
llvm::errs() << "failed to identify bitwidth of "
@@ -105,10 +106,7 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true;
}
- llvm::SmallVector<StringRef> namespaces;
- enumAttr.getStorageType().ltrim("::").split(namespaces, "::");
- namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
- std::string namespace_ = getAttributeNameSpace(namespaces);
+ std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
if (!namespace_.empty())
namespace_ += "_";
@@ -127,8 +125,9 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
- raw_ostream &os) {
+static bool emitDialectEnumAttributeBuilderRegistration(const llvm::Record &def,
+ raw_ostream &os) {
+ const AttrOrTypeDef attr(&def);
StringRef mnemonic = attr.getMnemonic().value();
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
StringRef dialect = attr.getDialect().getName();
@@ -145,15 +144,15 @@ static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
return true;
}
- llvm::SmallVector<StringRef> namespaces;
- attr.getStorageNamespace().ltrim("::").split(namespaces, "::");
- std::string namespace_ = getAttributeNameSpace(namespaces);
+ EnumAttr enumAttr(def);
+ std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
if (!namespace_.empty())
namespace_ += "_";
os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
- attr.getName());
- os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
+ enumAttr.getAttrDefName());
+ os << llvm::formatv("def _{0}(x, context):\n",
+ enumAttr.getAttrDefName().lower());
os << llvm::formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
@@ -169,12 +168,13 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
EnumAttr enumAttr(*it);
emitEnumClass(enumAttr, os);
- emitAttributeBuilder(enumAttr, os);
+ if (emitAttributeBuilderRegistration(enumAttr, os))
+ return true;
}
for (const auto &it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
- const AttrOrTypeDef attr(&*it);
- return emitDialectEnumAttributeBuilder(attr, os);
+ if (emitDialectEnumAttributeBuilderRegistration(*it, os))
+ return true;
}
return false;
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index fd6250f05a451f..374af0fa26666f 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -17,6 +17,8 @@
#include "llvm/Support/Regex.h"
#include "llvm/TableGen/Error.h"
+#include <iostream>
+
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
@@ -81,9 +83,10 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
return reserved.contains(str);
}
-std::string
-mlir::tblgen::getAttributeNameSpace(llvm::SmallVector<StringRef> namespaces) {
+std::string mlir::tblgen::getEnumAttributeNameSpace(const EnumAttr &enumAttr) {
std::string namespace_;
+ llvm::SmallVector<StringRef> namespaces;
+ enumAttr.getCppNamespace().ltrim("::").split(namespaces, "::");
if (namespaces[0] == "mlir")
namespace_ = llvm::join(llvm::drop_begin(namespaces), "_");
else
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 447d48aeb7a6bb..23542bf3ce87a9 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -13,6 +13,7 @@
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
+#include "mlir/TableGen/Attribute.h"
#include "llvm/TableGen/Record.h"
#include <vector>
@@ -28,8 +29,7 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
bool isPythonReserved(llvm::StringRef str);
-std::string
-getAttributeNameSpace(llvm::SmallVector<llvm::StringRef> namespaces);
+std::string getEnumAttributeNameSpace(const EnumAttr &enumAttr);
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 37e40c1c66d952..71ed85905ee294 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -530,7 +530,7 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
/// {2} is the attribute builder from raw;
-/// {3} is the attribute's dialect.
+/// {3} is the attribute's def name; {2} is its namespace prefix or empty.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
@@ -538,14 +538,14 @@ constexpr const char *initAttributeWithBuilderTemplate =
issubclass(type({0}), _ods_ir.Attribute) or
not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
(_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
- else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
+ else _ods_ir.AttrBuilder.get('{2}{3}')({0}, context=_ods_context))))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
/// {2} is the attribute builder from raw;
-/// {3} is the attribute's dialect.
+/// {3} is the attribute's def name; {2} is its namespace prefix or empty.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
@@ -553,7 +553,7 @@ constexpr const char *initOptionalAttributeWithBuilderTemplate =
issubclass(type({0}), _ods_ir.Attribute) or
not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
(_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
- else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
+ else _ods_ir.AttrBuilder.get('{2}{3}')({0}, context=_ods_context))))Py";
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -681,10 +681,14 @@ populateBuilderLinesAttr(const Operator &op,
continue;
}
- llvm::SmallVector<StringRef> namespaces;
- attribute->attr.getStorageType().ltrim("::").split(namespaces, "::");
- namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
- std::string namespace_ = getAttributeNameSpace(namespaces);
+ std::string namespace_;
+ if (attribute->attr.isEnumAttr()) {
+ EnumAttr enumAttr(attribute->attr.getDef());
+ namespace_ = getEnumAttributeNameSpace(enumAttr);
+ } else if (attribute->attr.getBaseAttr().isEnumAttr()) {
+ EnumAttr enumAttr(attribute->attr.getBaseAttr().getDef());
+ namespace_ = getEnumAttributeNameSpace(enumAttr);
+ }
if (!namespace_.empty())
namespace_ += "_";
More information about the Mlir-commits
mailing list