[Mlir-commits] [mlir] [MLIR][Python] Python binding support for AffineIfOp (PR #108323)

Amy Wang llvmlistbot at llvm.org
Wed Nov 13 11:54:19 PST 2024


https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/108323

>From 6c36b3726ac79d8cf6b17b4d64abdcd4a5136cad Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Wed, 11 Sep 2024 23:03:55 -0400
Subject: [PATCH] [MLIR][Python] Python binding support for AffineIfOp

---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  3 +-
 mlir/include/mlir/IR/CommonAttrConstraints.td |  9 +++
 mlir/python/mlir/dialects/affine.py           | 58 +++++++++++++++
 mlir/test/python/dialects/affine.py           | 70 +++++++++++++++++++
 4 files changed, 139 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index c9d9202ae3cf1a..ac0cf36396fa8a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
     }
     ```
   }];
-  let arguments = (ins Variadic<AnyType>);
+  let arguments = (ins Variadic<AnyType>,
+                       IntegerSetAttr:$condition);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
 
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index de5f6797235e3c..17ca82c510f8a2 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -558,6 +558,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
   let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
 }
 
+// Attribute holding an IntegerSet (e.g. the condition of `affine.if`).
+def IntegerSetAttr : Attr<
+CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
+  let storageType = [{ ::mlir::IntegerSetAttr }];
+  let returnType = [{ ::mlir::IntegerSet }];
+  let valueType = NoneType;
+  let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
+}
+
 // Base class for array attributes.
 class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
   let storageType = [{ ::mlir::ArrayAttr }];
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 913cea61105cee..7641d36e39799f 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -156,3 +156,61 @@ def for_(
             yield iv, iter_args[0]
         else:
             yield iv
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineIfOp(AffineIfOp):
+    """Specialization for the Affine if op class."""
+
+    def __init__(
+        self,
+        cond: IntegerSet,
+        results_: Optional[Type] = None,
+        *,
+        cond_operands: Optional[_VariadicResultValueT] = None,
+        has_else: bool = False,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an Affine `if` operation.
+
+        - `cond` is the integer set used to determine which regions of code
+          will be executed.
+        - `results_` is the list of result types yielded by the operation.
+        - `cond_operands` is the list of arguments to substitute the
+          dimensions, then symbols in the `cond` integer set expression to
+          determine whether they are in the set.
+        - `has_else` adds an empty block to the else region when True.
+        - `loc` / `ip` are forwarded to the generated builder.
+        """
+        if results_ is None:
+            results_ = []
+        if cond_operands is None:
+            cond_operands = []
+
+        if cond.n_inputs != len(cond_operands):
+            raise ValueError(
+                f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}"
+            )
+
+        results = list(results_)
+        operands = list(cond_operands)
+
+        # Forward loc/ip so callers control the location and insertion point.
+        super().__init__(results, operands, cond, loc=loc, ip=ip)
+        # The then region always gets one block; the else region is optional.
+        self.regions[0].blocks.append()
+        if has_else:
+            self.regions[1].blocks.append()
+
+    @property
+    def then_block(self) -> Block:
+        """Returns the then block of the if operation."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def else_block(self) -> Optional[Block]:
+        """Returns the else block, or None if it was built without one."""
+        if len(self.regions[1].blocks) == 0:
+            return None
+        return self.regions[1].blocks[0]
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 0dc69d7ba522de..7faae6ccedc972 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -263,3 +263,73 @@ def range_loop_8(lb, ub, memref_v):
             add = arith.addi(i, i)
             memref.store(add, it, [i])
             affine.yield_([it])
+
+
+# CHECK-LABEL: TEST: testAffineIfWithoutElse
+@constructAndPrintInModule
+def testAffineIfWithoutElse():
+    """Builds an `affine.if` that has only a then region."""
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL:  func.func @simple_affine_if(
+    # CHECK-SAME:                        %[[VAL_0:.*]]: index) {
+    # CHECK:            affine.if #[[$SET0]](%[[VAL_0]]) {
+    # CHECK:                %[[VAL_1:.*]] = arith.constant 1 : i32
+    # CHECK:                %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
+    # CHECK:            }
+    # CHECK:            return
+    # CHECK:        }
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if(cond_operands):
+        if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
+        with InsertionPoint(if_op.then_block):
+            one = arith.ConstantOp(i32, 1)
+            arith.AddIOp(one, one)
+            affine.AffineYieldOp([])
+        return
+
+
+# CHECK-LABEL: TEST: testAffineIfWithElse
+@constructAndPrintInModule
+def testAffineIfWithElse():
+    """Builds an `affine.if` with two i32 results and an else region."""
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL:  func.func @simple_affine_if_else(
+    # CHECK-SAME:                                    %[[VAL_0:.*]]: index) {
+    # CHECK:            %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
+    # CHECK:                %[[VAL_XT:.*]] = arith.constant 0 : i32
+    # CHECK:                %[[VAL_YT:.*]] = arith.constant 1 : i32
+    # CHECK:                affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
+    # CHECK:            } else {
+    # CHECK:                %[[VAL_XF:.*]] = arith.constant 2 : i32
+    # CHECK:                %[[VAL_YF:.*]] = arith.constant 3 : i32
+    # CHECK:                affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
+    # CHECK:            }
+    # CHECK:            %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
+    # CHECK:            return
+    # CHECK:        }
+
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if_else(cond_operands):
+        if_op = affine.AffineIfOp(
+            cond, [i32, i32], cond_operands=[cond_operands], has_else=True
+        )
+        with InsertionPoint(if_op.then_block):
+            x_true = arith.ConstantOp(i32, 0)
+            y_true = arith.ConstantOp(i32, 1)
+            affine.AffineYieldOp([x_true, y_true])
+        with InsertionPoint(if_op.else_block):
+            x_false = arith.ConstantOp(i32, 2)
+            y_false = arith.ConstantOp(i32, 3)
+            affine.AffineYieldOp([x_false, y_false])
+        arith.AddIOp(if_op.results[0], if_op.results[1])
+        return



More information about the Mlir-commits mailing list