[Mlir-commits] [mlir] [mlir][python] Improve sanitization of python names (PR #68801)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 11 09:31:09 PDT 2023


https://github.com/JoelWee updated https://github.com/llvm/llvm-project/pull/68801

From 1b2df7c110ca1e695bbc24ce17aeceefa9899ec9 Mon Sep 17 00:00:00 2001
From: Joel Wee <joelwee at google.com>
Date: Wed, 11 Oct 2023 12:42:24 +0000
Subject: [PATCH] [mlir][python] Improve sanitization of python names

---
 mlir/test/mlir-tblgen/op-python-bindings.td   |  9 ++++++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 28 ++++++++++---------
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 8ca23fa9f45c4ab..586efd2d5f1829c 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -616,6 +616,15 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
 // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
 // CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
 
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.123with- special.characters"
+def WithSpecialCharactersOp : TestOp<"123with- special.characters"> {
+}
+
+// CHECK: def _123with__special_characters(*, loc=None, ip=None)
+// CHECK:   return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.with_successors"
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2c81538b7b40433..996cdf53ac8cf53 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -302,8 +302,12 @@ static bool isODSReserved(StringRef str) {
 /// modified version.
 static std::string sanitizeName(StringRef name) {
   std::string processed_str = name.str();
+  std::replace_if(
+      processed_str.begin(), processed_str.end(),
+      [](char c) { return !llvm::isAlnum(c); }, '_');
 
-  std::replace(processed_str.begin(), processed_str.end(), '-', '_');
+  if (llvm::isDigit(*processed_str.begin()))
+    return "_" + processed_str;
 
   if (isPythonReserved(processed_str) || isODSReserved(processed_str))
     return processed_str + "_";
@@ -988,8 +992,6 @@ static void emitValueBuilder(const Operator &op,
   // If we are asked to skip default builders, comply.
   if (op.skipDefaultBuilders())
     return;
-  auto name = sanitizeName(op.getOperationName());
-  iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
   // Params with (possibly) default args.
   auto valueBuilderParams =
       llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -1008,16 +1010,16 @@ static void emitValueBuilder(const Operator &op,
         auto lhs = *llvm::split(arg, "=").begin();
         return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
       });
-  os << llvm::formatv(
-      valueBuilderTemplate,
-      // Drop dialect name and then sanitize again (to catch e.g. func.return).
-      sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
-      op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
-      llvm::join(opBuilderArgs, ", "),
-      (op.getNumResults() > 1
-           ? "_Sequence[_ods_ir.OpResult]"
-           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
-                                     : "_ods_ir.Operation")));
+  auto name_without_dialect =
+      op.getOperationName().substr(op.getOperationName().find('.') + 1);
+  os << llvm::formatv(valueBuilderTemplate, sanitizeName(name_without_dialect),
+                      op.getCppClassName(),
+                      llvm::join(valueBuilderParams, ", "),
+                      llvm::join(opBuilderArgs, ", "),
+                      (op.getNumResults() > 1
+                           ? "_Sequence[_ods_ir.OpResult]"
+                           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+                                                     : "_ods_ir.Operation")));
 }
 
 /// Emits bindings for a specific Op to the given output stream.



More information about the Mlir-commits mailing list