[llvm] [mlir][python] generate value builders (PR #68308)
Maksim Levental via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 5 06:28:23 PDT 2023
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/68308
None
>From f5f02d3564802c49093fa39d221cf1929b6a3893 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 5 Oct 2023 01:35:05 -0500
Subject: [PATCH] [mlir][python] generate value builders
---
llvm/lib/Support/StringExtras.cpp | 17 ++---
llvm/unittests/ADT/StringExtrasTest.cpp | 5 ++
mlir/python/mlir/dialects/_ods_common.py | 17 +++++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 68 +++++++++++++++----
4 files changed, 84 insertions(+), 23 deletions(-)
diff --git a/llvm/lib/Support/StringExtras.cpp b/llvm/lib/Support/StringExtras.cpp
index 5683d7005584eb2..a418360fae6f6bb 100644
--- a/llvm/lib/Support/StringExtras.cpp
+++ b/llvm/lib/Support/StringExtras.cpp
@@ -12,6 +12,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include <cctype>
@@ -96,18 +97,12 @@ std::string llvm::convertToSnakeFromCamelCase(StringRef input) {
if (input.empty())
return "";
- std::string snakeCase;
- snakeCase.reserve(input.size());
- for (char c : input) {
- if (!std::isupper(c)) {
- snakeCase.push_back(c);
- continue;
- }
-
- if (!snakeCase.empty() && snakeCase.back() != '_')
- snakeCase.push_back('_');
- snakeCase.push_back(llvm::toLower(c));
+  // Single left-to-right pass: lower-case every character and emit '_' at
+  // each word boundary. Repeated Regex::sub calls would rewrite only the
+  // first match per call and therefore need an arbitrary iteration cap; this
+  // loop handles any number of boundaries. Only ASCII letters change case.
+  std::string snakeCase;
+  snakeCase.reserve(input.size());
+  for (size_t i = 0, e = input.size(); i != e; ++i) {
+    snakeCase.push_back(llvm::toLower(input[i]));
+    // A boundary can only sit directly in front of a capital letter.
+    if (i + 1 == e || !llvm::isUpper(input[i + 1]))
+      continue;
+    // "aB" -> "a_b": a lower-case letter followed by a capital.
+    bool afterLower = llvm::isLower(input[i]);
+    // "ABc" -> "a_bc": the last capital of a run starts the next word.
+    bool endOfCapitalRun = llvm::isUpper(input[i]) && i + 2 < e &&
+                           llvm::isLower(input[i + 2]);
+    if (afterLower || endOfCapitalRun)
+      snakeCase.push_back('_');
}
return snakeCase;
}
diff --git a/llvm/unittests/ADT/StringExtrasTest.cpp b/llvm/unittests/ADT/StringExtrasTest.cpp
index 3f69c91b270a355..fab562f1ed0d594 100644
--- a/llvm/unittests/ADT/StringExtrasTest.cpp
+++ b/llvm/unittests/ADT/StringExtrasTest.cpp
@@ -184,6 +184,11 @@ TEST(StringExtrasTest, ConvertToSnakeFromCamelCase) {
testConvertToSnakeCase("OpName", "op_name");
testConvertToSnakeCase("opName", "op_name");
+ testConvertToSnakeCase("OPName", "op_name");
+ testConvertToSnakeCase("opNAME", "op_name");
+ testConvertToSnakeCase("opNAMe", "op_na_me");
+ testConvertToSnakeCase("opnameE", "opname_e");
+ testConvertToSnakeCase("OPNameOPName", "op_name_op_name");
testConvertToSnakeCase("_OpName", "_op_name");
testConvertToSnakeCase("Op_Name", "op_name");
testConvertToSnakeCase("", "");
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 7655629a5542520..b4535fbb8437c19 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -13,6 +13,7 @@
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
+ "get_op_result_or_op_results",
"segmented_accessor",
]
@@ -167,3 +168,19 @@ def get_op_results_or_values(
return arg.results
else:
return [get_op_result_or_value(element) for element in arg]
+
+
+def get_op_result_or_op_results(
+    op: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value],
+) -> _Union[_cext.ir.Operation, _cext.ir.Value, _Sequence[_cext.ir.Value]]:
+    """Returns a Value or zero-result op as is (an OpView as its Operation); else
+    defers to `get_op_result_or_value` (one result) or lists all the results."""
+    if isinstance(op, _cext.ir.Value):
+        return op
+    if isinstance(op, _cext.ir.OpView):
+        op = op.operation
+    if len(op.results) == 0:
+        return op
+    if len(op.results) == 1:
+        return get_op_result_or_value(op)
+    return list(get_op_results_or_values(op))
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0b5df7ab70dddb2..e4fd4d3b6221f74 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,7 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
_ods_ir = _ods_cext.ir
try:
@@ -260,11 +260,16 @@ constexpr const char *attributeDeleterTemplate = R"Py(
del self.operation.attributes["{1}"]
)Py";
-constexpr const char *regionAccessorTemplate = R"PY(
+constexpr const char *regionAccessorTemplate = R"Py(
@builtins.property
def {0}(self):
return self.regions[{1}]
-)PY";
+)Py";
+
+/// Template for a Python function that constructs an op and passes it through
+/// `_get_op_result_or_op_results`:
+///   {0} is the name of the generated function;
+///   {1} is the name of the op class being constructed;
+///   {2} is the comma-separated parameter list of the generated function;
+///   {3} is the comma-separated argument list forwarded to the op class.
+constexpr const char *valueBuilderTemplate = R"Py(
+def {0}({2}):
+ return _get_op_result_or_op_results({1}({3}))
+)Py";
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
@@ -609,9 +614,7 @@ populateBuilderArgsResults(const Operator &op,
static void
populateBuilderArgs(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs,
- llvm::SmallVectorImpl<std::string> &operandNames,
- llvm::SmallVectorImpl<std::string> &successorArgNames) {
-
+ llvm::SmallVectorImpl<std::string> &operandNames) {
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
std::string name = op.getArgName(i).str();
if (name.empty())
@@ -734,11 +737,11 @@ populateBuilderLinesOperand(const Operator &op,
/// attribute:
/// - {0} is the name of the attribute from which to derive the types.
constexpr const char *deriveTypeFromAttrTemplate =
- R"PY(_ods_result_type_source_attr = attributes["{0}"]
+ R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
- _ods_result_type_source_attr.type))PY";
+ _ods_result_type_source_attr.type))Py";
/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
@@ -837,10 +840,11 @@ populateBuilderRegions(const Operator &op,
/// Emits a default builder constructing an operation from the list of its
/// result types, followed by a list of its operands.
-static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
+static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
+ raw_ostream &os) {
// If we are asked to skip default builders, comply.
if (op.skipDefaultBuilders())
- return;
+ return {};
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
@@ -850,7 +854,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
op.getNumNativeAttributes() + op.getNumSuccessors());
populateBuilderArgsResults(op, builderArgs);
size_t numResultArgs = builderArgs.size();
- populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
+ populateBuilderArgs(op, builderArgs, operandArgNames);
size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
@@ -921,6 +925,8 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
llvm::join(builderLines, "\n "),
llvm::join(initArgs, ", "));
+ return llvm::to_vector<8>(
+ llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
}
static void emitSegmentSpec(
@@ -968,6 +974,43 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
}
}
+/// Emits a Python function that forwards its arguments to the op class and
+/// unwraps the outcome through `_get_op_result_or_op_results`. The function
+/// name is the sanitized operation name with every '.' replaced by '_'.
+///
+/// `functionArgs` holds the parameters of the default builder, each spelled
+/// `name` or `name=default`, and may contain the bare `*` keyword-only marker.
+/// Parameter names are converted to snake_case in the emitted signature and
+/// passed on under their original names. Nothing is emitted for ops that skip
+/// default builders.
+static void emitValueBuilder(const Operator &op,
+                             const llvm::SmallVector<std::string> &functionArgs,
+                             raw_ostream &os) {
+  // If we are asked to skip default builders, comply.
+  if (op.skipDefaultBuilders())
+    return;
+
+  std::string name =
+      llvm::join(llvm::split(sanitizeName(op.getOperationName()), "."), "_");
+
+  llvm::SmallVector<std::string> params;
+  llvm::SmallVector<std::string> callArgs;
+  for (llvm::StringRef functionArg : functionArgs) {
+    // Split at the first '=' only, so that a default value which itself
+    // contains '=' is carried over intact instead of being dropped.
+    std::pair<llvm::StringRef, llvm::StringRef> nameAndDefault =
+        functionArg.split('=');
+    llvm::StringRef argName = nameAndDefault.first;
+    std::string snakeName = llvm::convertToSnakeFromCamelCase(argName);
+
+    if (functionArg.contains('='))
+      params.push_back(snakeName + "=" + nameAndDefault.second.str());
+    else
+      params.push_back(snakeName);
+
+    // The keyword-only marker belongs in the signature but not in the call.
+    if (functionArg != "*")
+      callArgs.push_back((argName + "=" + snakeName).str());
+  }
+
+  os << llvm::formatv(valueBuilderTemplate, name, op.getCppClassName(),
+                      llvm::join(params, ", "), llvm::join(callArgs, ", "));
+}
+
+
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
@@ -982,11 +1025,12 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
}
emitRegionAttributes(op, os);
- emitDefaultOpBuilder(op, os);
+ llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os);
emitAttributeAccessors(op, os);
emitResultAccessors(op, os);
emitRegionAccessors(op, os);
+ emitValueBuilder(op, functionArgs, os);
}
/// Emits bindings for the dialect specified in the command line, including file
More information about the llvm-commits
mailing list