[Mlir-commits] [mlir] [mlir][python] namespace generated enums in python (PR #77830)
Maksim Levental
llvmlistbot at llvm.org
Thu Jan 11 14:37:30 PST 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/77830
>From c9ffacc2871e5901fe55bab46cda0bd724241eaa Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 11 Jan 2024 14:56:18 -0600
Subject: [PATCH] namespace generated enums in python
---
.../mlir-tblgen/EnumPythonBindingGen.cpp | 53 +++++++++----------
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 17 +++---
2 files changed, 35 insertions(+), 35 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772ed..d5f36ff3bc0fd2 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -105,7 +105,8 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true;
}
- os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
+ os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+ enumAttr.getDialect().getName(),
enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n",
enumAttr.getAttrDefName().lower());
@@ -120,11 +121,26 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
- StringRef formatString,
+static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
raw_ostream &os) {
- os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
- os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+ StringRef mnemonic = attr.getMnemonic().value();
+ std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+ StringRef dialect = attr.getDialect().getName();
+ std::string formatString;
+ if (assemblyFormat == "`<` $value `>`") {
+ formatString =
+ llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
+ } else if (assemblyFormat == "$value") {
+ formatString =
+ llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
+ } else {
+ llvm::errs()
+ << "unsupported assembly format for python enum bindings generation";
+ return true;
+ }
+ os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+ attr.getDialect().getName(), attr.getName());
+ os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
os << llvm::formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
@@ -142,29 +158,10 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os);
}
- for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
- AttrOrTypeDef attr(&*it);
- if (!attr.getMnemonic()) {
- llvm::errs() << "enum case " << attr
- << " needs mnemonic for python enum bindings generation";
- return true;
- }
- StringRef mnemonic = attr.getMnemonic().value();
- std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
- StringRef dialect = attr.getDialect().getName();
- if (assemblyFormat == "`<` $value `>`") {
- emitDialectEnumAttributeBuilder(
- attr.getName(),
- llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
- } else if (assemblyFormat == "$value") {
- emitDialectEnumAttributeBuilder(
- attr.getName(),
- llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
- } else {
- llvm::errs()
- << "unsupported assembly format for python enum bindings generation";
- return true;
- }
+ for (const auto &it :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+ const AttrOrTypeDef attr(&*it);
+ return emitDialectEnumAttributeBuilder(attr, os);
}
return false;
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..de343df1c434fa 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -529,27 +529,29 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
-/// {2} is the attribute builder from raw.
+/// {2} is the attribute's definition name (from getAttrDefName());
+/// {3} is the name of the attribute's dialect.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
- not _ods_ir.AttrBuilder.contains('{2}')) else
- _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+ not _ods_ir.AttrBuilder.contains('{2}_{3}')) else
+ _ods_ir.AttrBuilder.get('{2}_{3}')({0}, context=_ods_context)))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
-/// {2} is the attribute builder from raw.
+/// {2} is the attribute's definition name (from getAttrDefName());
+/// {3} is the name of the attribute's dialect.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
- not _ods_ir.AttrBuilder.contains('{2}')) else
- _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+ not _ods_ir.AttrBuilder.contains('{2}_{3}')) else
+ _ods_ir.AttrBuilder.get('{2}_{3}')({0}, context=_ods_context)))Py";
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -681,7 +683,8 @@ populateBuilderLinesAttr(const Operator &op,
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
- argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+ argNames[i], attribute->name, attribute->attr.getAttrDefName(),
+ attribute->attr.getDialect().getName()));
}
}
More information about the Mlir-commits
mailing list