[Mlir-commits] [mlir] [mlir][python] generate value builders (PR #68308)

Maksim Levental llvmlistbot at llvm.org
Thu Oct 5 09:17:51 PDT 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/68308

From 032ec201605f4fefdaa8bfaab44e98097024d590 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 5 Oct 2023 01:35:05 -0500
Subject: [PATCH] [mlir][python] generate value builders

---
 llvm/lib/Support/StringExtras.cpp             | 18 ++---
 llvm/unittests/ADT/StringExtrasTest.cpp       |  5 ++
 mlir/python/mlir/dialects/_ods_common.py      | 15 ++++
 mlir/python/mlir/dialects/_scf_ops_ext.py     | 40 ++++++++++-
 mlir/test/python/dialects/scf.py              | 21 +++++-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 70 +++++++++++++++----
 6 files changed, 143 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Support/StringExtras.cpp b/llvm/lib/Support/StringExtras.cpp
index 5683d7005584eb2..b164e65dbdb3467 100644
--- a/llvm/lib/Support/StringExtras.cpp
+++ b/llvm/lib/Support/StringExtras.cpp
@@ -96,18 +96,21 @@ std::string llvm::convertToSnakeFromCamelCase(StringRef input) {
   if (input.empty())
     return "";
 
   std::string snakeCase;
   snakeCase.reserve(input.size());
-  for (char c : input) {
-    if (!std::isupper(c)) {
-      snakeCase.push_back(c);
-      continue;
-    }
-
-    if (!snakeCase.empty() && snakeCase.back() != '_')
-      snakeCase.push_back('_');
+  for (size_t i = 0, e = input.size(); i != e; ++i) {
+    char c = input[i];
+    if (i > 0 && llvm::isUpper(c)) {
+      char prev = input[i - 1];
+      bool nextIsLower = i + 1 != e && llvm::isLower(input[i + 1]);
+      // Start a new word after a lowercase letter ("opName" -> "op_name") and
+      // before the last capital of an acronym ("OPName" -> "op_name"). Other
+      // characters such as '_' or digits never trigger a separator.
+      if (llvm::isLower(prev) || (llvm::isUpper(prev) && nextIsLower))
+        snakeCase.push_back('_');
+    }
     snakeCase.push_back(llvm::toLower(c));
   }
   return snakeCase;
 }
 
diff --git a/llvm/unittests/ADT/StringExtrasTest.cpp b/llvm/unittests/ADT/StringExtrasTest.cpp
index 3f69c91b270a355..fab562f1ed0d594 100644
--- a/llvm/unittests/ADT/StringExtrasTest.cpp
+++ b/llvm/unittests/ADT/StringExtrasTest.cpp
@@ -184,6 +184,14 @@ TEST(StringExtrasTest, ConvertToSnakeFromCamelCase) {
 
   testConvertToSnakeCase("OpName", "op_name");
   testConvertToSnakeCase("opName", "op_name");
+  testConvertToSnakeCase("OPName", "op_name");
+  testConvertToSnakeCase("opNAME", "op_name");
+  testConvertToSnakeCase("opNAMe", "op_na_me");
+  testConvertToSnakeCase("opnameE", "opname_e");
+  testConvertToSnakeCase("OPNameOPName", "op_name_op_name");
+  // Many word boundaries in a single identifier.
+  testConvertToSnakeCase("aBcDeFgHiJkLmNoPqRsTuVwXyZ",
+                         "a_bc_de_fg_hi_jk_lm_no_pq_rs_tu_vw_xy_z");
   testConvertToSnakeCase("_OpName", "_op_name");
   testConvertToSnakeCase("Op_Name", "op_name");
   testConvertToSnakeCase("", "");
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 7655629a5542520..9c12b282517472c 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -13,6 +13,7 @@
     "get_default_loc_context",
     "get_op_result_or_value",
     "get_op_results_or_values",
+    "get_op_result_or_op_results",
     "segmented_accessor",
 ]
 
@@ -167,3 +168,22 @@ def get_op_results_or_values(
         return arg.results
     else:
         return [get_op_result_or_value(element) for element in arg]
+
+
+def get_op_result_or_op_results(
+    op: _Union[_cext.ir.OpView, _cext.ir.Operation],
+) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
+    """Returns the result(s) of `op` in the most convenient form.
+
+    An op with exactly one result yields that result, an op with several
+    results yields a list of them, and an op with no results yields the
+    underlying `Operation` (an `OpView` argument is unwrapped first).
+    """
+    if isinstance(op, _cext.ir.OpView):
+        op = op.operation
+    num_results = len(op.results)
+    if num_results > 1:
+        return list(get_op_results_or_values(op))
+    if num_results == 1:
+        return get_op_result_or_value(op)
+    return op
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index 4b0a31327abb0ee..35bd247a0a1e7f7 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -7,11 +7,13 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Any, Optional, Sequence, Union
+from typing import Optional, Sequence, Union
 from ._ods_common import (
     get_op_result_or_value as _get_op_result_or_value,
     get_op_results_or_values as _get_op_results_or_values,
 )
+from . import arith
+from . import scf
 
 
 class ForOp:
@@ -25,7 +27,7 @@ def __init__(
         iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         """Creates an SCF `for` operation.
 
@@ -104,3 +106,54 @@ def then_block(self):
     def else_block(self):
         """Returns the else block of the if operation."""
         return self.regions[1].blocks[0]
+
+
+def range_(
+    start,
+    stop=None,
+    step=None,
+    iter_args: Optional[Sequence[Value]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    """Builds an `scf.for` loop and yields its induction variable.
+
+    Mirrors Python's `range`: `range_(stop)` iterates from 0 to `stop` with
+    step 1, and `range_(start, stop[, step])` uses the given bounds. Bounds may
+    be `Value`s or Python ints (floats raise `ValueError`); ints are
+    materialized as index-typed `arith.constant` ops using the same `loc` and
+    `ip` as the loop.
+
+    While the generator is suspended the insertion point is the loop body, so
+    ops built inside the Python `for` land in the loop. The body is not
+    terminated automatically; the caller must emit `scf.yield_`.
+
+    Yields `iv` when there are no iteration arguments, `(iv, arg)` for a single
+    iteration argument, and `(iv, args_tuple)` otherwise.
+    """
+    if step is None:
+        step = 1
+    if stop is None:
+        stop = start
+        start = 0
+    params = [start, stop, step]
+    for i, p in enumerate(params):
+        if isinstance(p, int):
+            p = arith.ConstantOp.create_index(p, loc=loc, ip=ip).result
+        elif isinstance(p, float):
+            raise ValueError(f"{p=} must be int.")
+        params[i] = p
+    # Use the materialized bounds, not the raw Python arguments.
+    start, stop, step = params
+
+    for_op = scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
+    iv = for_op.induction_variable
+    iter_args = tuple(for_op.inner_iter_args)
+    with InsertionPoint(for_op.body):
+        if len(iter_args) > 1:
+            yield iv, iter_args
+        elif len(iter_args) == 1:
+            yield iv, iter_args[0]
+        else:
+            yield iv
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 8cb55fdf6a1eb3b..94d980d3a499789 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -4,7 +4,7 @@
 from mlir.dialects import arith
 from mlir.dialects import func
 from mlir.dialects import scf
-from mlir.dialects import builtin
+from mlir.dialects._scf_ops_ext import range_
 
 
 def constructAndPrintInModule(f):
@@ -54,6 +54,46 @@ def induction_var(lb, ub, step):
 # CHECK: scf.yield %[[IV]]
 
 
+# CHECK-LABEL: TEST: testForSugar
+@constructAndPrintInModule
+def testForSugar():
+    index_type = IndexType.get()
+
+    @func.FuncOp.from_py_func(index_type, index_type, index_type)
+    def range_loop(lb, ub, step):
+        for i in range_(lb, ub, step):
+            add = arith.addi(i, i)
+            scf.yield_([])
+        return
+
+# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
+# CHECK:   scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK:     %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK:   }
+# CHECK:   return
+# CHECK: }
+
+
+# CHECK-LABEL: TEST: testForSugarIntBounds
+@constructAndPrintInModule
+def testForSugarIntBounds():
+    @func.FuncOp.from_py_func()
+    def range_loop_int():
+        for i in range_(10):
+            add = arith.addi(i, i)
+            scf.yield_([])
+        return
+
+# CHECK: func.func @range_loop_int() {
+# CHECK:   %[[C0:.*]] = arith.constant 0 : index
+# CHECK:   %[[C10:.*]] = arith.constant 10 : index
+# CHECK:   %[[C1:.*]] = arith.constant 1 : index
+# CHECK:   scf.for %[[IV:.*]] = %[[C0]] to %[[C10]] step %[[C1]]
+# CHECK:     %{{.*}} = arith.addi %[[IV]], %[[IV]] : index
+# CHECK:   }
+# CHECK:   return
+# CHECK: }
+
 @constructAndPrintInModule
 def testOpsAsArguments():
     index_type = IndexType.get()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0b5df7ab70dddb2..5b10b3fc0e72e87 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,7 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
 _ods_ir = _ods_cext.ir
 
 try:
@@ -260,11 +260,22 @@ constexpr const char *attributeDeleterTemplate = R"Py(
     del self.operation.attributes["{1}"]
 )Py";
 
-constexpr const char *regionAccessorTemplate = R"PY(
+constexpr const char *regionAccessorTemplate = R"Py(
   @builtins.property
   def {0}(self):
     return self.regions[{1}]
-)PY";
+)Py";
+
+/// Template for a module-level "value builder" function that constructs an op
+/// and returns `_get_op_result_or_op_results` of it:
+///   {0} is the name of the function;
+///   {1} is the name of the op class to construct;
+///   {2} is the function's parameter list;
+///   {3} is the keyword-argument list forwarded to the op class constructor.
+constexpr const char *valueBuilderTemplate = R"Py(
+def {0}({2}):
+    return _get_op_result_or_op_results({1}({3}))
+)Py";
 
 static llvm::cl::OptionCategory
     clOpPythonBindingCat("Options for -gen-python-op-bindings");
@@ -609,9 +620,7 @@ populateBuilderArgsResults(const Operator &op,
 static void
 populateBuilderArgs(const Operator &op,
                     llvm::SmallVectorImpl<std::string> &builderArgs,
-                    llvm::SmallVectorImpl<std::string> &operandNames,
-                    llvm::SmallVectorImpl<std::string> &successorArgNames) {
-
+                    llvm::SmallVectorImpl<std::string> &operandNames) {
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     std::string name = op.getArgName(i).str();
     if (name.empty())
@@ -734,11 +743,11 @@ populateBuilderLinesOperand(const Operator &op,
 /// attribute:
 ///   - {0} is the name of the attribute from which to derive the types.
 constexpr const char *deriveTypeFromAttrTemplate =
-    R"PY(_ods_result_type_source_attr = attributes["{0}"]
+    R"Py(_ods_result_type_source_attr = attributes["{0}"]
 _ods_derived_result_type = (
     _ods_ir.TypeAttr(_ods_result_type_source_attr).value
     if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
-    _ods_result_type_source_attr.type))PY";
+    _ods_result_type_source_attr.type))Py";
 
 /// Python code template appending {0} type {1} times to the results list.
 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
@@ -837,10 +846,12 @@ populateBuilderRegions(const Operator &op,
 
 /// Emits a default builder constructing an operation from the list of its
-/// result types, followed by a list of its operands.
-static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
+/// result types, followed by a list of its operands. Returns the arguments
+/// spliced into the builder's signature (empty if no builder is emitted).
+static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
+                                                           raw_ostream &os) {
   // If we are asked to skip default builders, comply.
   if (op.skipDefaultBuilders())
-    return;
+    return {};
 
   llvm::SmallVector<std::string> builderArgs;
   llvm::SmallVector<std::string> builderLines;
@@ -850,7 +861,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
                       op.getNumNativeAttributes() + op.getNumSuccessors());
   populateBuilderArgsResults(op, builderArgs);
   size_t numResultArgs = builderArgs.size();
-  populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
+  populateBuilderArgs(op, builderArgs, operandArgNames);
   size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
   populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
 
@@ -921,6 +932,8 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
                       llvm::join(builderLines, "\n    "),
                       llvm::join(initArgs, ", "));
+  return llvm::to_vector<8>(
+      llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
 }
 
 static void emitSegmentSpec(
@@ -968,6 +981,50 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
   }
 }
 
+/// Emits a module-level "value builder" for `op`: a free function, named after
+/// the op without its dialect prefix, that constructs the op and returns
+/// `_get_op_result_or_op_results(...)` of it, i.e. its result(s) if it has
+/// any. `functionArgs` is the list returned by `emitDefaultOpBuilder`;
+/// argument names are snake_cased in the signature and forwarded by keyword
+/// (skipping the bare `*`) to the op class constructor.
+static void emitValueBuilder(const Operator &op,
+                             llvm::ArrayRef<std::string> functionArgs,
+                             raw_ostream &os) {
+  // If we are asked to skip default builders, comply.
+  if (op.skipDefaultBuilders())
+    return;
+  auto name = sanitizeName(op.getOperationName());
+  iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
+  os << llvm::formatv(
+      valueBuilderTemplate,
+      // Drop dialect name and then sanitize again (to catch e.g. func.return).
+      sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
+      op.getCppClassName(),
+      llvm::join(
+          llvm::map_range(functionArgs,
+                          [](const std::string &argAndMaybeDefault) {
+                            llvm::SmallVector<llvm::StringRef> argMaybeDefault =
+                                llvm::to_vector<2>(
+                                    llvm::split(argAndMaybeDefault, "="));
+                            auto arg = llvm::convertToSnakeFromCamelCase(
+                                argMaybeDefault[0]);
+                            if (argMaybeDefault.size() == 2)
+                              return arg + "=" + argMaybeDefault[1].str();
+                            return arg;
+                          }),
+          ", "),
+      llvm::join(
+          llvm::map_range(
+              llvm::make_filter_range(
+                  functionArgs, [](const std::string &s) { return s != "*"; }),
+              [](const std::string &arg) {
+                auto lhs = *llvm::split(arg, "=").begin();
+                return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs))
+                    .str();
+              }),
+          ", "));
+}
+
 /// Emits bindings for a specific Op to the given output stream.
 static void emitOpBindings(const Operator &op, raw_ostream &os) {
   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
@@ -982,11 +1039,12 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
   }
 
   emitRegionAttributes(op, os);
-  emitDefaultOpBuilder(op, os);
+  llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
   emitOperandAccessors(op, os);
   emitAttributeAccessors(op, os);
   emitResultAccessors(op, os);
   emitRegionAccessors(op, os);
+  emitValueBuilder(op, functionArgs, os);
 }
 
 /// Emits bindings for the dialect specified in the command line, including file



More information about the Mlir-commits mailing list