[Mlir-commits] [mlir] [MLIR][SCF] Add dedicated Python bindings for ForallOp (PR #149416)

Colin De Vlieghere llvmlistbot at llvm.org
Fri Jul 18 13:32:52 PDT 2025


https://github.com/Cubevoid updated https://github.com/llvm/llvm-project/pull/149416

>From cfae6fb9a7c5d24e2a28404bda5daa8a60baab21 Mon Sep 17 00:00:00 2001
From: Colin De Vlieghere <cdevlieghere at tesla.com>
Date: Thu, 17 Jul 2025 15:07:44 -0700
Subject: [PATCH 1/2] [MLIR][SCF] Add dedicated Python bindings for ForallOp

This patch specializes the Python bindings for ForallOp and InParallelOp,
similar to the existing one for ForOp.  These bindings create the regions and
blocks properly and expose some additional helpers.
---
 mlir/python/mlir/dialects/scf.py | 116 ++++++++++++++++++++++++++++++-
 mlir/test/python/dialects/scf.py |  20 ++++++
 2 files changed, 135 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 2d0047b76c702..63d48f4eab260 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -17,7 +17,7 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -71,6 +71,120 @@ def inner_iter_args(self):
         return self.body.arguments[1:]
 
 
+def dispatch_index_op_fold_results(
+    ofrs: Sequence[Union[int, Value]],
+) -> Tuple[List[Value], List[int]]:
+    """`mlir::dispatchIndexOpFoldResults`"""
+    dynamic_vals = []
+    static_vals = []
+    for ofr in ofrs:
+        if isinstance(ofr, Value):
+            dynamic_vals.append(ofr)
+            static_vals.append(ShapedType.get_dynamic_size())
+        else:
+            static_vals.append(ofr)
+    return dynamic_vals, static_vals
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForallOp(ForallOp):
+    """Specialization for the SCF forall op class."""
+
+    def __init__(
+        self,
+        lower_bounds: Sequence[Union[Value, int]],
+        upper_bounds: Sequence[Union[Value, int]],
+        steps: Sequence[Union[Value, int]],
+        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an SCF `forall` operation.
+
+        - `lower_bounds` are the values to use as lower bounds of the loop.
+        - `upper_bounds` are the values to use as upper bounds of the loop.
+        - `steps` are the values to use as loop steps.
+        - `iter_args` is a list of additional loop-carried arguments or an operation
+          producing them as results.
+        """
+        if iter_args is None:
+            iter_args = []
+        iter_args = _get_op_results_or_values(iter_args)
+
+        dynamic_lbs, static_lbs = dispatch_index_op_fold_results(lower_bounds)
+        dynamic_ubs, static_ubs = dispatch_index_op_fold_results(upper_bounds)
+        dynamic_steps, static_steps = dispatch_index_op_fold_results(steps)
+
+        results = [arg.type for arg in iter_args]
+        super().__init__(
+            results,
+            dynamic_lbs,
+            dynamic_ubs,
+            dynamic_steps,
+            static_lbs,
+            static_ubs,
+            static_steps,
+            iter_args,
+            mapping=mapping,
+            loc=loc,
+            ip=ip,
+        )
+        rank = len(static_lbs)
+        iv_types = [IndexType.get()] * rank
+        self.regions[0].blocks.append(*iv_types, *results)
+
+    @property
+    def body(self) -> Block:
+        """Returns the body (block) of the loop."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def rank(self) -> int:
+        """Returns the number of induction variables the loop has."""
+        return len(self.staticLowerBound)
+
+    @property
+    def induction_variables(self) -> BlockArgumentList:
+        """Returns the induction variables usable within the loop."""
+        return self.body.arguments[: self.rank]
+
+    @property
+    def inner_iter_args(self):
+        """Returns the loop-carried arguments usable within the loop.
+
+        To obtain the loop-carried operands, use `iter_args`.
+        """
+        return self.body.arguments[self.rank :]
+
+    @property
+    def terminator(self) -> InParallelOp:
+        """
+        Returns the loop terminator if it exists.
+        Otherwise, create a new one.
+        """
+        ops = self.body.operations
+        with InsertionPoint(self.body):
+            if not ops:
+                return InParallelOp()
+            last = ops[len(ops) - 1]
+            return last if isinstance(last, InParallelOp) else InParallelOp()
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InParallelOp(InParallelOp):
+    """Specialization of the SCF forall.in_parallel op class."""
+
+    def __init__(self, loc=None, ip=None):
+        super().__init__(loc=loc, ip=ip)
+        self.region.blocks.append()
+
+    @property
+    def block(self) -> Block:
+        return self.region.blocks[0]
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class IfOp(IfOp):
     """Specialization for the SCF if op class."""
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index de61f4613868f..2beff4f7703a3 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -18,6 +18,26 @@ def constructAndPrintInModule(f):
     return f
 
 
+# CHECK-LABEL: TEST: testSimpleForall
+# CHECK: scf.forall (%[[IV0:.*]], %[[IV1:.*]]) in (4, 8) shared_outs(%[[BOUND_ARG:.*]] = %{{.*}}) -> (tensor<4x8xf32>)
+# CHECK:   arith.addi %[[IV0]], %[[IV1]]
+# CHECK:   scf.forall.in_parallel
+@constructAndPrintInModule
+def testSimpleForall():
+    f32 = F32Type.get()
+    tensor_type = RankedTensorType.get([4, 8], f32)
+
+    @func.FuncOp.from_py_func(tensor_type)
+    def forall_loop(tensor):
+        loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])
+        with InsertionPoint(loop.body):
+            i, j = loop.induction_variables
+            arith.addi(i, j)
+            loop.terminator
+        # The verifier will check that the regions have been created properly.
+        assert loop.verify()
+
+
 # CHECK-LABEL: TEST: testSimpleLoop
 @constructAndPrintInModule
 def testSimpleLoop():

>From 6da3c2895c4a709bf567e3191b91f8a1e004cdf3 Mon Sep 17 00:00:00 2001
From: Colin De Vlieghere <cdevlieghere at tesla.com>
Date: Fri, 18 Jul 2025 11:13:33 -0700
Subject: [PATCH 2/2] Address comments

---
 mlir/python/mlir/dialects/scf.py | 27 +++++++++++++++------------
 mlir/test/python/dialects/scf.py |  2 +-
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 63d48f4eab260..1250bcbb18bd1 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -71,15 +71,16 @@ def inner_iter_args(self):
         return self.body.arguments[1:]
 
 
-def dispatch_index_op_fold_results(
-    ofrs: Sequence[Union[int, Value]],
+def _dispatch_index_op_fold_results(
+    ofrs: Sequence[Union[Operation, OpView, Value, int]],
 ) -> Tuple[List[Value], List[int]]:
     """`mlir::dispatchIndexOpFoldResults`"""
     dynamic_vals = []
     static_vals = []
     for ofr in ofrs:
-        if isinstance(ofr, Value):
-            dynamic_vals.append(ofr)
+        if isinstance(ofr, (Operation, OpView, Value)):
+            val = _get_op_result_or_value(ofr)
+            dynamic_vals.append(val)
             static_vals.append(ShapedType.get_dynamic_size())
         else:
             static_vals.append(ofr)
@@ -92,8 +93,8 @@ class ForallOp(ForallOp):
 
     def __init__(
         self,
-        lower_bounds: Sequence[Union[Value, int]],
-        upper_bounds: Sequence[Union[Value, int]],
+        lower_bounds: Sequence[Union[Operation, OpView, Value, int]],
+        upper_bounds: Sequence[Union[Operation, OpView, Value, int]],
         steps: Sequence[Union[Value, int]],
         iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
         *,
@@ -109,13 +110,16 @@ def __init__(
         - `iter_args` is a list of additional loop-carried arguments or an operation
           producing them as results.
         """
+        assert (
+            len(lower_bounds) == len(upper_bounds) == len(steps)
+        ), "Mismatch in length of lower bounds, upper bounds, and steps"
         if iter_args is None:
             iter_args = []
         iter_args = _get_op_results_or_values(iter_args)
 
-        dynamic_lbs, static_lbs = dispatch_index_op_fold_results(lower_bounds)
-        dynamic_ubs, static_ubs = dispatch_index_op_fold_results(upper_bounds)
-        dynamic_steps, static_steps = dispatch_index_op_fold_results(steps)
+        dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds)
+        dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds)
+        dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps)
 
         results = [arg.type for arg in iter_args]
         super().__init__(
@@ -151,18 +155,17 @@ def induction_variables(self) -> BlockArgumentList:
         return self.body.arguments[: self.rank]
 
     @property
-    def inner_iter_args(self):
+    def inner_iter_args(self) -> BlockArgumentList:
         """Returns the loop-carried arguments usable within the loop.
 
         To obtain the loop-carried operands, use `iter_args`.
         """
         return self.body.arguments[self.rank :]
 
-    @property
     def terminator(self) -> InParallelOp:
         """
         Returns the loop terminator if it exists.
-        Otherwise, create a new one.
+        Otherwise, creates a new one.
         """
         ops = self.body.operations
         with InsertionPoint(self.body):
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 2beff4f7703a3..62d11d5e189c8 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -33,7 +33,7 @@ def forall_loop(tensor):
         with InsertionPoint(loop.body):
             i, j = loop.induction_variables
             arith.addi(i, j)
-            loop.terminator
+            loop.terminator()
         # The verifier will check that the regions have been created properly.
         assert loop.verify()
 



More information about the Mlir-commits mailing list