[Mlir-commits] [mlir] [mlir][python] generate value builders (PR #68308)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 5 11:22:10 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-adt

<details>
<summary>Changes</summary>

This PR additionally generates what I'm calling "value builders" (a term I'm not married to), which look like this:

```python
def empty(sizes, element_type, *, loc=None, ip=None):
    return get_op_result_or_op_results(tensor.EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip))
```

which instantiates a `tensor.EmptyOp`, immediately grabs its result (`OpResult`), and returns that *instead of a handle to the op*.

What's the point of adding these when `EmptyOp.result` already exists? My claim/feeling/intuition is that eDSL users are more comfortable with a value-centric programming model (i.e., passing values as operands) than with an operator-instantiation programming model. Thus this change enables (or at least moves towards) the bindings supporting such users and use cases. For example,

```python
i32 = IntegerType.get_signless(32)
...
ten1 = tensor.empty((10, 10), i32)
ten2 = tensor.empty((10, 10), i32)
ten3 = arith.addi(ten1, ten2)
```

Note, in order to present a "pythonic" API and enable "pythonic" eDSLs, the generated identifiers (op names and operand names) are snake case instead of camel case, and thus `llvm::convertToSnakeFromCamelCase` needed a small fix: currently, runs of caps aren't handled correctly, so e.g. something like `Intel_OCL_BI` is snake-cased to `intel_o_c_l_b_i` (previously discussed on this [phabricator patch](https://reviews.llvm.org/rG92233062c17590d3157bdc6db430fcdfc54312fe)). The easiest fix was to use regexes instead of the existing loop. Full disclosure: the regexes were pulled from [inflection](https://github.com/jpvanhal/inflection/blob/88eefaacf7d0caaa701af7c8ab2d0ab3f17086f1/inflection/__init__.py#L416-L418), but it should be pretty clear they're correct.

---

Patch is 26.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68308.diff


7 Files Affected:

- (modified) llvm/lib/Support/StringExtras.cpp (+7-11) 
- (modified) llvm/unittests/ADT/StringExtrasTest.cpp (+5) 
- (modified) mlir/python/mlir/dialects/_ods_common.py (+15) 
- (modified) mlir/python/mlir/dialects/_scf_ops_ext.py (+38-2) 
- (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+69-1) 
- (modified) mlir/test/python/dialects/scf.py (+22-1) 
- (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+63-12) 


``````````diff
diff --git a/llvm/lib/Support/StringExtras.cpp b/llvm/lib/Support/StringExtras.cpp
index 5683d7005584eb2..fd5a34fb3d6e82c 100644
--- a/llvm/lib/Support/StringExtras.cpp
+++ b/llvm/lib/Support/StringExtras.cpp
@@ -12,6 +12,7 @@
 
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cctype>
 
@@ -96,18 +97,13 @@ std::string llvm::convertToSnakeFromCamelCase(StringRef input) {
   if (input.empty())
     return "";
 
-  std::string snakeCase;
-  snakeCase.reserve(input.size());
-  for (char c : input) {
-    if (!std::isupper(c)) {
-      snakeCase.push_back(c);
-      continue;
-    }
-
-    if (!snakeCase.empty() && snakeCase.back() != '_')
-      snakeCase.push_back('_');
-    snakeCase.push_back(llvm::toLower(c));
+  std::string snakeCase = input.str();
+  for (int i = 0; i < 10; ++i) {
+    snakeCase = llvm::Regex("([A-Z]+)([A-Z][a-z])").sub("\\1_\\2", snakeCase);
+    snakeCase = llvm::Regex("([a-z0-9])([A-Z])").sub("\\1_\\2", snakeCase);
   }
+  std::transform(snakeCase.begin(), snakeCase.end(), snakeCase.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
   return snakeCase;
 }
 
diff --git a/llvm/unittests/ADT/StringExtrasTest.cpp b/llvm/unittests/ADT/StringExtrasTest.cpp
index 3f69c91b270a355..fab562f1ed0d594 100644
--- a/llvm/unittests/ADT/StringExtrasTest.cpp
+++ b/llvm/unittests/ADT/StringExtrasTest.cpp
@@ -184,6 +184,11 @@ TEST(StringExtrasTest, ConvertToSnakeFromCamelCase) {
 
   testConvertToSnakeCase("OpName", "op_name");
   testConvertToSnakeCase("opName", "op_name");
+  testConvertToSnakeCase("OPName", "op_name");
+  testConvertToSnakeCase("opNAME", "op_name");
+  testConvertToSnakeCase("opNAMe", "op_na_me");
+  testConvertToSnakeCase("opnameE", "opname_e");
+  testConvertToSnakeCase("OPNameOPName", "op_name_op_name");
   testConvertToSnakeCase("_OpName", "_op_name");
   testConvertToSnakeCase("Op_Name", "op_name");
   testConvertToSnakeCase("", "");
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 7655629a5542520..895c3228139b392 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -13,6 +13,7 @@
     "get_default_loc_context",
     "get_op_result_or_value",
     "get_op_results_or_values",
+    "get_op_result_or_op_results",
     "segmented_accessor",
 ]
 
@@ -167,3 +168,17 @@ def get_op_results_or_values(
         return arg.results
     else:
         return [get_op_result_or_value(element) for element in arg]
+
+
+def get_op_result_or_op_results(
+    op: _Union[_cext.ir.OpView, _cext.ir.Operation],
+) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
+    if isinstance(op, _cext.ir.OpView):
+        op = op.operation
+    return (
+        list(get_op_results_or_values(op))
+        if len(op.results) > 1
+        else get_op_result_or_value(op)
+        if len(op.results) > 0
+        else op
+    )
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index 4b0a31327abb0ee..35bd247a0a1e7f7 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -7,11 +7,13 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Any, Optional, Sequence, Union
+from typing import Optional, Sequence, Union
 from ._ods_common import (
     get_op_result_or_value as _get_op_result_or_value,
     get_op_results_or_values as _get_op_results_or_values,
 )
+from .arith import constant
+from . import scf
 
 
 class ForOp:
@@ -25,7 +27,7 @@ def __init__(
         iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         """Creates an SCF `for` operation.
 
@@ -104,3 +106,37 @@ def then_block(self):
     def else_block(self):
         """Returns the else block of the if operation."""
         return self.regions[1].blocks[0]
+
+
+def range_(
+    start,
+    stop=None,
+    step=None,
+    iter_args: Optional[Sequence[Value]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    if step is None:
+        step = 1
+    if stop is None:
+        stop = start
+        start = 0
+    params = [start, stop, step]
+    for i, p in enumerate(params):
+        if isinstance(p, int):
+            p = constant(p)
+        elif isinstance(p, float):
+            raise ValueError(f"{p=} must be int.")
+        params[i] = p
+
+    for_op = scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
+    iv = for_op.induction_variable
+    iter_args = tuple(for_op.inner_iter_args)
+    with InsertionPoint(for_op.body):
+        if len(iter_args) > 1:
+            yield iv, iter_args
+        elif len(iter_args) == 1:
+            yield iv, iter_args[0]
+        else:
+            yield iv
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index a131209fa45cb6c..be2cfc8adc78179 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,6 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
                    Optional<AnyType>:$variadic2);
 }
 
+// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
@@ -104,6 +107,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
                  Variadic<AnyType>:$variadic2);
 }
 
+// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOp(_ods_ir.OpView):
@@ -151,6 +157,9 @@ def AttributedOp : TestOp<"attributed_op"> {
                    UnitAttr:$unitAttr, I32Attr:$in);
 }
 
+// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
@@ -184,6 +193,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
 }
 
+// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
@@ -205,6 +217,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
   let results = (outs);
 }
 
+// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
   // CHECK: def __init__(self, type_, *, loc=None, ip=None):
@@ -220,6 +235,9 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
   let results = (outs AnyType:$res, AnyType);
 }
 
+// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
   // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
@@ -227,6 +245,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
   let results = (outs AnyType:$res, Variadic<AnyType>);
 }
 
+// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.empty"
@@ -241,6 +262,8 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     attributes=attributes, results=results, operands=operands,
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
+// CHECK: def empty(*, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -252,6 +275,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
   let results = (outs I32:$i32, F32:$f32);
 }
 
+// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
   // CHECK:  def __init__(self, *, loc=None, ip=None):
@@ -262,6 +288,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
   let results = (outs AnyType, AnyType, AnyType);
 }
 
+// CHECK: def infer_result_types_op(*, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.missing_names"
@@ -297,6 +326,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
   let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
 
+// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
@@ -323,9 +355,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   // CHECK: @builtins.property
   // CHECK: def optional(self):
   // CHECK:   return None if len(self.operation.operands) < 2 else self.operation.operands[1]
-
 }
 
+// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
@@ -355,6 +389,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
 }
 
+// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
@@ -385,6 +422,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
 }
 
+// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class PythonKeywordOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
@@ -405,6 +445,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:   return self.operation.operands[0]
   let arguments = (ins AnyType:$in);
 }
+
+// CHECK: def python_keyword(in_, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
   // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
@@ -416,6 +460,9 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
   let results = (outs AnyType:$res);
 }
 
+// CHECK: def same_results(in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
   // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
@@ -423,6 +470,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
   let results = (outs Variadic<AnyType>:$res);
 }
 
+// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK:   return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
@@ -447,6 +497,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                    Variadic<AnyType>:$variadic2);
 }
 
+// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
@@ -470,6 +523,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                  Variadic<AnyType>:$variadic2);
 }
 
+// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SimpleOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.simple"
@@ -507,6 +563,9 @@ def SimpleOp : TestOp<"simple"> {
   let results = (outs I64:$i64, AnyFloat:$f64);
 }
 
+// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
 def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
@@ -531,6 +590,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   // CHECK:    return self.regions[2:]
 }
 
+// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
 def VariadicRegionOp : TestOp<"variadic_region"> {
@@ -551,6 +613,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
   // CHECK:    return self.regions[0:]
 }
 
+// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.with_successors"
@@ -562,3 +627,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
   let successors = (successor AnySuccessor:$successor,
                               VariadicSuccessor<AnySuccessor>:$successors);
 }
+
+// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 8cb55fdf6a1eb3b..843b87b66871761 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -4,7 +4,7 @@
 from mlir.dialects import arith
 from mlir.dialects import func
 from mlir.dialects import scf
-from mlir.dialects import builtin
+from mlir.dialects._scf_ops_ext import range_
 
 
 def constructAndPrintInModule(f):
@@ -54,6 +54,27 @@ def induction_var(lb, ub, step):
 # CHECK: scf.yield %[[IV]]
 
 
+# CHECK-LABEL: TEST: testForSugar
+@constructAndPrintInModule
+def testForSugar():
+    index_type = IndexType.get()
+
+    @func.FuncOp.from_py_func(index_type, index_type, index_type)
+    def range_loop(lb, ub, step):
+        for i in range_(lb, ub, step):
+            add = arith.addi(i, i)
+            scf.yield_([])
+        return
+
+
+# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
+# CHECK:   scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK:     %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK:   }
+# CHECK:   return
+# CHECK: }
+
+
 @constructAndPrintInModule
 def testOpsAsArguments():
     index_type = IndexType.get()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0b5df7ab70dddb2..1357c84099ccd5a 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,7 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
 _ods_ir = _ods_cext.ir
 
 try:
@@ -39,6 +39,7 @@ except ImportError:
   _ods_ext_module = None
 
 import builtins
+from typing import Sequence as _Sequence, Union as _Union
 
 )Py";
 
@@ -260,11 +261,16 @@ constexpr const char...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/68308


More information about the Mlir-commits mailing list