[Mlir-commits] [mlir] [mlir][python] python binding wrapper for the affine.AffineForOp (PR #74408)

Amy Wang llvmlistbot at llvm.org
Mon Dec 4 22:51:23 PST 2023


https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/74408

>From 21e9a601a48a7d6365892b87f6e7d0be6ab6a8da Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Mon, 4 Dec 2023 22:31:23 -0500
Subject: [PATCH] [mlir][python] python binding for the affine.for op

This PR creates the wrapper class AffineForOp and adds a testcase
for it. A testcase for AffineLoadOp is also added, as well as
some syntactic sugar tests.
---
 mlir/python/mlir/dialects/affine.py | 138 +++++++++++++++++++++
 mlir/test/python/dialects/affine.py | 182 +++++++++++++++++++++++-----
 2 files changed, 293 insertions(+), 27 deletions(-)

diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 80d3873e19a05..26e827009bc04 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,3 +3,141 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._affine_ops_gen import *
+from ._affine_ops_gen import _Dialect, AffineForOp
+from .arith import constant
+
+try:
+    from ..ir import *
+    from ._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        get_op_results_or_values as _get_op_results_or_values,
+        _cext as _ods_cext,
+    )
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineForOp(AffineForOp):
+    """Specialization of the generated `affine.for` op class."""
+
+    def __init__(
+        self,
+        lower_bound,
+        upper_bound,
+        step,
+        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        lower_bound_operands=(),
+        upper_bound_operands=(),
+        loc=None,
+        ip=None,
+    ):
+        """Creates an Affine `for` operation.
+
+        - `lower_bound` is the affine map to use as lower bound of the loop.
+        - `upper_bound` is the affine map to use as upper bound of the loop.
+        - `step` is the Python integer loop step, stored as an index attribute.
+        - `iter_args` is a list of additional loop-carried arguments or an operation
+          producing them as results.
+        - `lower_bound_operands` is the list of arguments to substitute the dimensions,
+          then symbols in the `lower_bound` affine map, in an increasing order.
+        - `upper_bound_operands` is the list of arguments to substitute the dimensions,
+          then symbols in the `upper_bound` affine map, in an increasing order.
+        Raises `ValueError` if an operand list's length differs from its map's inputs.
+        """
+        if iter_args is None:
+            iter_args = []
+        iter_args = _get_op_results_or_values(iter_args)
+        if len(lower_bound_operands) != lower_bound.n_inputs:
+            raise ValueError(
+                "Wrong number of lower bound operands passed to AffineForOp. "
+                f"Expected {lower_bound.n_inputs}, got {len(lower_bound_operands)}."
+            )
+
+        if len(upper_bound_operands) != upper_bound.n_inputs:
+            raise ValueError(
+                "Wrong number of upper bound operands passed to AffineForOp. "
+                f"Expected {upper_bound.n_inputs}, got {len(upper_bound_operands)}."
+            )
+
+        # Results mirror the iter_args types; the body block takes an index IV first.
+        results = [arg.type for arg in iter_args]
+        super().__init__(
+            results_=results,
+            lowerBoundOperands=_get_op_results_or_values(lower_bound_operands),
+            upperBoundOperands=_get_op_results_or_values(upper_bound_operands),
+            inits=list(iter_args),
+            lowerBoundMap=AffineMapAttr.get(lower_bound),
+            upperBoundMap=AffineMapAttr.get(upper_bound),
+            step=IntegerAttr.get(IndexType.get(), step),
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(IndexType.get(), *results)
+
+    @property
+    def body(self):
+        """Returns the body (block) of the loop."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def induction_variable(self):
+        """Returns the induction variable of the loop."""
+        return self.body.arguments[0]
+
+    @property
+    def inner_iter_args(self):
+        """Returns the loop-carried arguments usable within the loop.
+
+        To obtain the loop-carried operands, use the `inits` operands.
+        """
+        return self.body.arguments[1:]
+
+
+def for_(
+    start,
+    stop=None,
+    step=None,
+    iter_args: Optional[Sequence[Value]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    """Generator building an `affine.for` loop with `range`-like arguments.
+
+    `for_(stop)` starts at 0 and `step` defaults to 1. Python int bounds become
+    index `arith.constant`s; float bounds raise `ValueError`. Both bounds use the
+    single-symbol identity map. Yields the induction variable (plus loop-carried
+    block arguments) inside the loop body; no `affine.yield` is emitted here.
+    """
+    if stop is None:
+        start, stop = 0, start
+    step = 1 if step is None else step
+
+    def _as_index(p):
+        if isinstance(p, int):
+            return constant(IntegerAttr.get(IndexType.get(), p))
+        if isinstance(p, float):
+            raise ValueError(f"{p=} must be int.")
+        return p
+
+    start, stop = _as_index(start), _as_index(stop)
+    bound_map = AffineMap.get(0, 1, [AffineSymbolExpr.get(0)])
+    loop = AffineForOp(
+        bound_map,
+        bound_map,
+        step,
+        iter_args=iter_args,
+        lower_bound_operands=[start],
+        upper_bound_operands=[stop],
+        loc=loc,
+        ip=ip,
+    )
+    inner = tuple(loop.inner_iter_args)
+    with InsertionPoint(loop.body):
+        if not inner:
+            yield loop.induction_variable
+        else:
+            yield loop.induction_variable, (inner[0] if len(inner) == 1 else inner)
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index c5ec85457493b..df42f8fcf1a57 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -1,44 +1,172 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-import mlir.dialects.func as func
-import mlir.dialects.arith as arith
-import mlir.dialects.affine as affine
-import mlir.dialects.memref as memref
+from mlir.dialects import func
+from mlir.dialects import arith
+from mlir.dialects import memref
+from mlir.dialects import affine
 
 
-def run(f):
+def constructAndPrintInModule(f):
     print("\nTEST:", f.__name__)
-    f()
+    with Context(), Location.unknown():
+        test_module = Module.create()
+        with InsertionPoint(test_module.body):
+            f()
+        print(test_module)
     return f
 
 
 # CHECK-LABEL: TEST: testAffineStoreOp
-@run
+@constructAndPrintInModule
 def testAffineStoreOp():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
-            f32 = F32Type.get()
-            index_type = IndexType.get()
-            memref_type_out = MemRefType.get([12, 12], f32)
+    f32 = F32Type.get()
+    index_type = IndexType.get()
+    memref_type_out = MemRefType.get([12, 12], f32)
 
-            # CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
-            @func.FuncOp.from_py_func(index_type)
-            def affine_store_test(arg0):
-                # CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
-                mem = memref.AllocOp(memref_type_out, [], []).result
+    # CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
+    @func.FuncOp.from_py_func(index_type)
+    def affine_store_test(arg0):
+        # CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
+        mem = memref.AllocOp(memref_type_out, [], []).result
 
-                d0 = AffineDimExpr.get(0)
-                s0 = AffineSymbolExpr.get(0)
-                map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
+        d0 = AffineDimExpr.get(0)
+        s0 = AffineSymbolExpr.get(0)
+        idx_map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
 
-                # CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
-                a1 = arith.ConstantOp(f32, 2.1)
+        # CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
+        a1 = arith.ConstantOp(f32, 2.1)
 
-                # CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
-                affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
+        # CHECK: affine.store %[[A1]], %[[O_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
+        affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=idx_map)
 
-                return mem
+        return mem
 
-        print(module)
+
+# CHECK-LABEL: TEST: testAffineLoadOp
+@constructAndPrintInModule
+def testAffineLoadOp():
+    f32 = F32Type.get()
+    index_type = IndexType.get()
+    memref_type_in = MemRefType.get([10, 10], f32)
+
+    # CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
+    @func.FuncOp.from_py_func(memref_type_in, index_type)
+    def affine_load_test(I, arg0):
+        d0 = AffineDimExpr.get(0)
+        s0 = AffineSymbolExpr.get(0)
+        idx_map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
+
+        # CHECK: %[[LOAD:.*]] = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
+        a1 = affine.AffineLoadOp(f32, I, indices=[arg0, arg0], map=idx_map)
+        # CHECK: return %[[LOAD]] : f32
+        return a1
+
+
+# CHECK-LABEL: TEST: testAffineForOp
+@constructAndPrintInModule
+def testAffineForOp():
+    f32 = F32Type.get()
+    index_type = IndexType.get()
+    memref_type = MemRefType.get([1024], f32)
+
+    # CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
+    # CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
+    # CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
+    @func.FuncOp.from_py_func(memref_type)
+    def affine_for_op_test(buffer):
+        # CHECK: %[[C1:.*]] = arith.constant 1 : index
+        c1 = arith.ConstantOp(index_type, 1)
+        # CHECK: %[[C2:.*]] = arith.constant 2 : index
+        c2 = arith.ConstantOp(index_type, 2)
+        # CHECK: %[[C3:.*]] = arith.constant 3 : index
+        c3 = arith.ConstantOp(index_type, 3)
+        # CHECK: %[[C9:.*]] = arith.constant 9 : index
+        c9 = arith.ConstantOp(index_type, 9)
+        ac0 = AffineConstantExpr.get(0)
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        s0 = AffineSymbolExpr.get(0)
+        # The loop runs from max(0, d0 + s0) to min(d0 - 2, 32 * d1).
+        lb = AffineMap.get(1, 1, [ac0, d0 + s0])
+        ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])
+        # CHECK: %[[SUM_INIT:.*]] = arith.constant 0.000000e+00 : f32
+        sum_init = arith.ConstantOp(f32, 0.0)
+
+        # CHECK: %{{.*}} = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[SUM_INIT]]) -> (f32) {
+        loop = affine.AffineForOp(
+            lb,
+            ub,
+            2,
+            iter_args=[sum_init],
+            lower_bound_operands=[c2, c3],
+            upper_bound_operands=[c9, c1],
+        )
+        with InsertionPoint(loop.body):
+            # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
+            tmp = memref.LoadOp(buffer, [loop.induction_variable])
+            # CHECK: %[[SUM1:.*]] = arith.addf %[[SUM0]], %[[TMP]] : f32
+            sum_next = arith.AddFOp(loop.inner_iter_args[0], tmp)
+            # CHECK: affine.yield %[[SUM1]] : f32
+            affine.AffineYieldOp([sum_next])
+
+        return
+
+
+# CHECK-LABEL: TEST: testForSugar
+@constructAndPrintInModule
+def testForSugar():
+    index_type = IndexType.get()
+    memref_t = MemRefType.get([10], index_type)
+    affine_range = affine.for_
+    # CHECK:  func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 10 : index
+    # CHECK:    affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
+    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:      affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_1(lb, ub, step, memref_v):
+        for i in affine_range(lb, 10, 2):
+            add = arith.addi(i, i)
+            s0 = AffineSymbolExpr.get(0)
+            sym_map = AffineMap.get(0, 1, [s0])
+            affine.store(add, memref_v, [i], map=sym_map)
+            affine.AffineYieldOp([])
+
+    # CHECK:  func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
+    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+    # CHECK:      affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_2(lb, ub, step, memref_v):
+        for i in affine_range(0, 10, 1):
+            add = arith.addi(i, i)
+            s0 = AffineSymbolExpr.get(0)
+            sym_map = AffineMap.get(0, 1, [s0])
+            affine.store(add, memref_v, [i], map=sym_map)
+            affine.AffineYieldOp([])
+
+    # CHECK:  func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
+    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:      affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_3(lb, ub, step, memref_v):
+        for i in affine_range(0, ub, 1):
+            add = arith.addi(i, i)
+            s0 = AffineSymbolExpr.get(0)
+            sym_map = AffineMap.get(0, 1, [s0])
+            affine.store(add, memref_v, [i], map=sym_map)
+            affine.AffineYieldOp([])



More information about the Mlir-commits mailing list