[Mlir-commits] [mlir] 44600ba - [mlir][python] Improve sanitization of python names (#68801)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 12 00:53:57 PDT 2023
Author: JoelWee
Date: 2023-10-12T08:53:53+01:00
New Revision: 44600bae893e126d91aee79e522749549caeec50
URL: https://github.com/llvm/llvm-project/commit/44600bae893e126d91aee79e522749549caeec50
DIFF: https://github.com/llvm/llvm-project/commit/44600bae893e126d91aee79e522749549caeec50.diff
LOG: [mlir][python] Improve sanitization of python names (#68801)
Follow-up to 7d4cd47e242c28c450c1e2a1a9f4bd4b7b5a01ab, which fixed
only the case of a dash. This change sanitizes names containing any
non-alphanumeric characters (names can include "-,." etc.) by
replacing each such character with an underscore.
This modifies some code written in
27c6d55cae74125b6381a647533090a72930ecda
It also handles the case of a leading digit, which is not valid in
Python names, by prepending an underscore.
Ref:
- https://llvm.org/docs/TableGen/ProgRef.html#literals
Added:
Modified:
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 8ca23fa9f45c4ab..63dad1cc901fe2b 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -616,6 +616,15 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.123with--special.characters"
+def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
+}
+
+// CHECK: def _123with__special_characters(*, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2c81538b7b40433..49f3a951426d0ee 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -302,8 +302,12 @@ static bool isODSReserved(StringRef str) {
/// modified version.
static std::string sanitizeName(StringRef name) {
std::string processed_str = name.str();
+ std::replace_if(
+ processed_str.begin(), processed_str.end(),
+ [](char c) { return !llvm::isAlnum(c); }, '_');
- std::replace(processed_str.begin(), processed_str.end(), '-', '_');
+ if (llvm::isDigit(*processed_str.begin()))
+ return "_" + processed_str;
if (isPythonReserved(processed_str) || isODSReserved(processed_str))
return processed_str + "_";
@@ -988,8 +992,6 @@ static void emitValueBuilder(const Operator &op,
// If we are asked to skip default builders, comply.
if (op.skipDefaultBuilders())
return;
- auto name = sanitizeName(op.getOperationName());
- iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
// Params with (possibly) default args.
auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -1008,16 +1010,16 @@ static void emitValueBuilder(const Operator &op,
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
- os << llvm::formatv(
- valueBuilderTemplate,
- // Drop dialect name and then sanitize again (to catch e.g. func.return).
- sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
- op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
- llvm::join(opBuilderArgs, ", "),
- (op.getNumResults() > 1
- ? "_Sequence[_ods_ir.OpResult]"
- : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
- : "_ods_ir.Operation")));
+ std::string name_without_dialect =
+ op.getOperationName().substr(op.getOperationName().find('.') + 1);
+ os << llvm::formatv(valueBuilderTemplate, sanitizeName(name_without_dialect),
+ op.getCppClassName(),
+ llvm::join(valueBuilderParams, ", "),
+ llvm::join(opBuilderArgs, ", "),
+ (op.getNumResults() > 1
+ ? "_Sequence[_ods_ir.OpResult]"
+ : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+ : "_ods_ir.Operation")));
}
/// Emits bindings for a specific Op to the given output stream.
More information about the Mlir-commits
mailing list