[Mlir-commits] [mlir] [mlir][python] namespace generated enums in python (PR #77830)
Maksim Levental
llvmlistbot at llvm.org
Thu Jan 11 19:58:25 PST 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/77830
>From 184e9fe2d5fef086dc45483c0f3abc8eb31f98d2 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 11 Jan 2024 14:56:18 -0600
Subject: [PATCH] namespace generated enums in python
---
.../mlir-tblgen/EnumPythonBindingGen.cpp | 66 +++++++++++--------
mlir/tools/mlir-tblgen/OpGenHelpers.cpp | 12 ++++
mlir/tools/mlir-tblgen/OpGenHelpers.h | 3 +
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 26 ++++++--
4 files changed, 72 insertions(+), 35 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772ed..4c6e04ad12362f 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -105,7 +105,15 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true;
}
- os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
+ llvm::SmallVector<StringRef> namespaces;
+ enumAttr.getStorageType().ltrim("::").split(namespaces, "::");
+ // Keep only the namespace components; drop the storage type's own name.
+ namespaces.pop_back();
+ std::string namespace_ = getAttributeNameSpace(namespaces);
+ if (!namespace_.empty())
+ namespace_ += "_";
+
+ os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n",
enumAttr.getAttrDefName().lower());
@@ -120,11 +128,41 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
- StringRef formatString,
+static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
raw_ostream &os) {
- os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
- os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+ // A mnemonic is required to spell the attribute in its textual form.
+ std::optional<StringRef> mnemonic = attr.getMnemonic();
+ if (!mnemonic) {
+ llvm::errs() << "enum case " << attr.getName()
+ << " needs mnemonic for python enum bindings generation\n";
+ return true;
+ }
+ std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+ StringRef dialect = attr.getDialect().getName();
+ std::string formatString;
+ if (assemblyFormat == "`<` $value `>`") {
+ formatString =
+ llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, *mnemonic).str();
+ } else if (assemblyFormat == "$value") {
+ formatString =
+ llvm::formatv("#{0}<{1} {{str(x)}>", dialect, *mnemonic).str();
+ } else {
+ llvm::errs()
+ << "unsupported assembly format for python enum bindings generation\n";
+ return true;
+ }
+
+ // Register under "<ns>_<AttrName>", where <ns> is getAttributeNameSpace of
+ // the storage namespace; an empty <ns> leaves just "<AttrName>".
+ llvm::SmallVector<StringRef> namespaces;
+ attr.getStorageNamespace().ltrim("::").split(namespaces, "::");
+ std::string namespace_ = getAttributeNameSpace(namespaces);
+ if (!namespace_.empty())
+ namespace_ += "_";
+
+ os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
+ attr.getName());
+ os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
os << llvm::formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
@@ -142,29 +180,12 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os);
}
- for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
- AttrOrTypeDef attr(&*it);
- if (!attr.getMnemonic()) {
- llvm::errs() << "enum case " << attr
- << " needs mnemonic for python enum bindings generation";
- return true;
- }
- StringRef mnemonic = attr.getMnemonic().value();
- std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
- StringRef dialect = attr.getDialect().getName();
- if (assemblyFormat == "`<` $value `>`") {
- emitDialectEnumAttributeBuilder(
- attr.getName(),
- llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
- } else if (assemblyFormat == "$value") {
- emitDialectEnumAttributeBuilder(
- attr.getName(),
- llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
- } else {
- llvm::errs()
- << "unsupported assembly format for python enum bindings generation";
- return true;
- }
+ for (const auto &it :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+ const AttrOrTypeDef attr(&*it);
+ // Emit a builder for every dialect EnumAttr; stop at the first failure.
+ if (emitDialectEnumAttributeBuilder(attr, os))
+ return true;
}
return false;
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index 7fd34df8460d39..fd6250f05a451f 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -79,4 +79,14 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
reserved.insert("issubclass");
reserved.insert("type");
return reserved.contains(str);
-}
\ No newline at end of file
+}
+
+std::string
+mlir::tblgen::getAttributeNameSpace(llvm::ArrayRef<StringRef> namespaces) {
+ // Drop a leading "mlir" component, e.g. {"mlir", "arith"} -> "arith".
+ if (!namespaces.empty() && namespaces.front() == "mlir")
+ namespaces = namespaces.drop_front();
+ // StringRef::lower() is ASCII-only: locale-independent and well-defined for
+ // bytes >= 0x80 (unlike ::tolower applied to a possibly negative char).
+ return StringRef(llvm::join(namespaces, "_")).lower();
+}
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 3dcff14d1221ee..447d48aeb7a6bb 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -28,6 +28,11 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
bool isPythonReserved(llvm::StringRef str);
+/// Joins `namespaces` with "_" and lowercases the result, dropping a leading
+/// "mlir" component (e.g. {"mlir", "arith"} -> "arith"). Returns an empty
+/// string when no components remain.
+std::string getAttributeNameSpace(llvm::ArrayRef<llvm::StringRef> namespaces);
+
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..37e40c1c66d952 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -529,27 +529,33 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
-/// {1} is the attribute builder from raw;
-/// {2} is the attribute builder from raw.
+/// {1} is the attribute name (the key in `attributes`);
+/// {2} is the attribute's namespace prefix ending in "_", or empty;
+/// {3} is the attribute's def name.
+/// The builder registered as '{3}' is tried first, then '{2}{3}'.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
- not _ods_ir.AttrBuilder.contains('{2}')) else
- _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+ not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
+ (_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
+ else _ods_ir.AttrBuilder.get('{2}{3}')({0}, context=_ods_context))))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
/// {0} is the builder argument name;
-/// {1} is the attribute builder from raw;
-/// {2} is the attribute builder from raw.
+/// {1} is the attribute name (the key in `attributes`);
+/// {2} is the attribute's namespace prefix ending in "_", or empty;
+/// {3} is the attribute's def name.
+/// The builder registered as '{3}' is tried first, then '{2}{3}'.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
- not _ods_ir.AttrBuilder.contains('{2}')) else
- _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+ not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
+ (_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
+ else _ods_ir.AttrBuilder.get('{2}{3}')({0}, context=_ods_context))))Py";
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -677,11 +683,21 @@ populateBuilderLinesAttr(const Operator &op,
continue;
}
+ // Derive the AttrBuilder key prefix from the storage type's C++ namespace
+ // (the type name itself is dropped), e.g. "::mlir::foo::BarAttr" -> "foo_".
+ llvm::SmallVector<StringRef> namespaces;
+ attribute->attr.getStorageType().ltrim("::").split(namespaces, "::");
+ namespaces.pop_back();
+ std::string namespace_ = getAttributeNameSpace(namespaces);
+ if (!namespace_.empty())
+ namespace_ += "_";
+
builderLines.push_back(llvm::formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
- argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+ argNames[i], attribute->name, namespace_,
+ attribute->attr.getAttrDefName()));
}
}
More information about the Mlir-commits
mailing list