[Mlir-commits] [mlir] f431d38 - Make Python MLIR Operation not iterable

Mehdi Amini llvmlistbot at llvm.org
Tue Oct 26 00:21:23 PDT 2021


Author: Mehdi Amini
Date: 2021-10-26T07:21:09Z
New Revision: f431d3878a07a67e544dacb98ad553b6a0b6b25a

URL: https://github.com/llvm/llvm-project/commit/f431d3878a07a67e544dacb98ad553b6a0b6b25a
DIFF: https://github.com/llvm/llvm-project/commit/f431d3878a07a67e544dacb98ad553b6a0b6b25a.diff

LOG: Make Python MLIR Operation not iterable

The current behavior conveniently allows iterating over the regions of an operation
implicitly by exposing an operation as Iterable. However, this is also error-prone:
code that intends to iterate over the results or the operands could end up apparently
"working" instead of throwing a runtime error.
The lack of static type checking in Python contributes to the ambiguity here; it seems
safer not to do this and instead require an explicit qualification to iterate (`op.results`, `op.regions`, ...).

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D111697

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/dialects/_builtin_ops_ext.py
    mlir/python/mlir/dialects/_ods_common.py
    mlir/test/python/dialects/builtin.py
    mlir/test/python/dialects/math.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4fc581b5dee77..7abd2a1f6b796 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2152,10 +2152,6 @@ void mlir::python::populateIRCore(py::module &m) {
           },
           "Returns the source location the operation was defined or derived "
           "from.")
-      .def("__iter__",
-           [](PyOperationBase &self) {
-             return PyRegionIterator(self.getOperation().getRef());
-           })
       .def(
           "__str__",
           [](PyOperationBase &self) {

diff  --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
index 462850d63156e..78f8c95c42da5 100644
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py
@@ -195,8 +195,17 @@ def decorator(f):
           # Coerce return values, add ReturnOp and rewrite func type.
           if return_values is None:
             return_values = []
+          elif isinstance(return_values, tuple):
+            return_values = list(return_values)
           elif isinstance(return_values, Value):
+            # Returning a single value is fine, coerce it into a list.
             return_values = [return_values]
+          elif isinstance(return_values, OpView):
+            # Returning a single operation is fine, coerce its results a list.
+            return_values = return_values.operation.results
+          elif isinstance(return_values, Operation):
+            # Returning a single operation is fine, coerce its results a list.
+            return_values = return_values.results
           else:
             return_values = list(return_values)
           std.ReturnOp(return_values)

diff  --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 95c44186533f1..6bb84e97800dd 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -124,7 +124,7 @@ def get_default_loc_context(location=None):
 
 
 def get_op_result_or_value(
-    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]
+    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList]
 ) -> _cext.ir.Value:
   """Returns the given value or the single result of the given op.
 
@@ -136,6 +136,8 @@ def get_op_result_or_value(
     return arg.operation.result
   elif isinstance(arg, _cext.ir.Operation):
     return arg.result
+  elif isinstance(arg, _cext.ir.OpResultList):
+    return arg[0]
   else:
     assert isinstance(arg, _cext.ir.Value)
     return arg

diff  --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 73f2b5bc9cf95..8f3a041937b9d 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -15,6 +15,7 @@ def run(f):
 @run
 def testFromPyFunc():
   with Context() as ctx, Location.unknown() as loc:
+    ctx.allow_unregistered_dialects = True
     m = builtin.ModuleOp()
     f32 = F32Type.get()
     f64 = F64Type.get()
@@ -51,6 +52,14 @@ def call_unary(a):
       def call_binary(a, b):
         return binary_return(a, b)
 
+      # We expect coercion of a single result operation to a returned value.
+      # CHECK-LABEL: func @single_result_op
+      # CHECK: %0 = "custom.op1"() : () -> f32
+      # CHECK: return %0 : f32
+      @builtin.FuncOp.from_py_func()
+      def single_result_op():
+        return Operation.create("custom.op1", results=[f32])
+
       # CHECK-LABEL: func @call_none
       # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
       # CHECK: return

diff  --git a/mlir/test/python/dialects/math.py b/mlir/test/python/dialects/math.py
index e3f8829b27783..c5985dbf36b32 100644
--- a/mlir/test/python/dialects/math.py
+++ b/mlir/test/python/dialects/math.py
@@ -19,7 +19,7 @@ def emit_sqrt(arg):
         return mlir_math.SqrtOp(arg)
 
     # CHECK-LABEL: func @emit_sqrt(
-    # CHECK-SAME:                  %[[ARG:.*]]: f32) {
+    # CHECK-SAME:                  %[[ARG:.*]]: f32) -> f32 {
     # CHECK:         math.sqrt %[[ARG]] : f32
     # CHECK:         return
     # CHECK:       }

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index e53be97cf25d9..9cd4824d68997 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -40,7 +40,7 @@ def testTraverseOpRegionBlockIterators():
   print(f".verify = {module.operation.verify()}")
 
   # Get the regions and blocks from the default collections.
-  default_regions = list(op)
+  default_regions = list(op.regions)
   default_blocks = list(default_regions[0])
   # They should compare equal regardless of how obtained.
   assert default_regions == regions
@@ -53,7 +53,7 @@ def testTraverseOpRegionBlockIterators():
   assert default_operations == operations
 
   def walk_operations(indent, op):
-    for i, region in enumerate(op):
+    for i, region in enumerate(op.regions):
       print(f"{indent}REGION {i}:")
       for j, block in enumerate(region):
         print(f"{indent}  BLOCK {j}:")


        


More information about the Mlir-commits mailing list