[Mlir-commits] [mlir] [MLIR][Python] Python binding support for AffineIfOp (PR #108323)
Amy Wang
llvmlistbot at llvm.org
Wed Nov 13 11:54:19 PST 2024
https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/108323
From 6c36b3726ac79d8cf6b17b4d64abdcd4a5136cad Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Wed, 11 Sep 2024 23:03:55 -0400
Subject: [PATCH] [MLIR][Python] Python binding support for AffineIfOp
---
.../mlir/Dialect/Affine/IR/AffineOps.td | 3 +-
mlir/include/mlir/IR/CommonAttrConstraints.td | 9 +++
mlir/python/mlir/dialects/affine.py | 58 +++++++++++++++
mlir/test/python/dialects/affine.py | 70 +++++++++++++++++++
4 files changed, 139 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index c9d9202ae3cf1a..ac0cf36396fa8a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
}
```
}];
- let arguments = (ins Variadic<AnyType>);
+ let arguments = (ins Variadic<AnyType>,
+ IntegerSetAttr:$condition);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index de5f6797235e3c..17ca82c510f8a2 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -558,6 +558,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
}
+// Attributes containing integer sets.
+def IntegerSetAttr : Attr<
+CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
+ let storageType = [{::mlir::IntegerSetAttr }];
+ let returnType = [{ ::mlir::IntegerSet }];
+ let valueType = NoneType;
+ let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
+}
+
// Base class for array attributes.
class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
let storageType = [{ ::mlir::ArrayAttr }];
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 913cea61105cee..7641d36e39799f 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -156,3 +156,61 @@ def for_(
yield iv, iter_args[0]
else:
yield iv
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineIfOp(AffineIfOp):
+ """Specialization for the Affine if op class."""
+
+ def __init__(
+ self,
+ cond: IntegerSet,
+ results_: Optional[Type] = None,
+ *,
+ cond_operands: Optional[_VariadicResultValueT] = None,
+ has_else: bool = False,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an Affine `if` operation.
+
+ - `cond` is the integer set used to determine which regions of code
+ will be executed.
+ - `results_` is the list of types of the values yielded by the operation.
+ - `cond_operands` is the list of values substituted for the
+ dimensions, then the symbols, of the `cond` integer set to
+ determine whether they are in the set.
+ - `has_else` determines whether the affine if operation has the else
+ branch.
+ """
+ if results_ is None:
+ results_ = []
+ if cond_operands is None:
+ cond_operands = []
+
+ if cond.n_inputs != len(cond_operands):
+ raise ValueError(
+ f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}"
+ )
+
+ operands = []
+ operands.extend(cond_operands)
+ results = []
+ results.extend(results_)
+
+ super().__init__(results, cond_operands, cond)
+ self.regions[0].blocks.append(*[])
+ if has_else:
+ self.regions[1].blocks.append(*[])
+
+ @property
+ def then_block(self) -> Block:
+ """Returns the then block of the if operation."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def else_block(self) -> Optional[Block]:
+ """Returns the else block of the if operation."""
+ if len(self.regions[1].blocks) == 0:
+ return None
+ return self.regions[1].blocks[0]
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 0dc69d7ba522de..7faae6ccedc972 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -263,3 +263,73 @@ def range_loop_8(lb, ub, memref_v):
add = arith.addi(i, i)
memref.store(add, it, [i])
affine.yield_([it])
+
+
+# CHECK-LABEL: TEST: testAffineIfWithoutElse
+@constructAndPrintInModule
+def testAffineIfWithoutElse():
+ index = IndexType.get()
+ i32 = IntegerType.get_signless(32)
+ d0 = AffineDimExpr.get(0)
+
+ # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+ cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+ # CHECK-LABEL: func.func @simple_affine_if(
+ # CHECK-SAME: %[[VAL_0:.*]]: index) {
+ # CHECK: affine.if #[[$SET0]](%[[VAL_0]]) {
+ # CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
+ # CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index)
+ def simple_affine_if(cond_operands):
+ if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
+ with InsertionPoint(if_op.then_block):
+ one = arith.ConstantOp(i32, 1)
+ add = arith.AddIOp(one, one)
+ affine.AffineYieldOp([])
+ return
+
+
+# CHECK-LABEL: TEST: testAffineIfWithElse
+@constructAndPrintInModule
+def testAffineIfWithElse():
+ index = IndexType.get()
+ i32 = IntegerType.get_signless(32)
+ d0 = AffineDimExpr.get(0)
+
+ # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+ cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+ # CHECK-LABEL: func.func @simple_affine_if_else(
+ # CHECK-SAME: %[[VAL_0:.*]]: index) {
+ # CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
+ # CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32
+ # CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32
+ # CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
+ # CHECK: } else {
+ # CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32
+ # CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32
+ # CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
+ # CHECK: }
+ # CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
+ # CHECK: return
+ # CHECK: }
+
+ @func.FuncOp.from_py_func(index)
+ def simple_affine_if_else(cond_operands):
+ if_op = affine.AffineIfOp(
+ cond, [i32, i32], cond_operands=[cond_operands], has_else=True
+ )
+ with InsertionPoint(if_op.then_block):
+ x_true = arith.ConstantOp(i32, 0)
+ y_true = arith.ConstantOp(i32, 1)
+ affine.AffineYieldOp([x_true, y_true])
+ with InsertionPoint(if_op.else_block):
+ x_false = arith.ConstantOp(i32, 2)
+ y_false = arith.ConstantOp(i32, 3)
+ affine.AffineYieldOp([x_false, y_false])
+ add = arith.AddIOp(if_op.results[0], if_op.results[1])
+ return
More information about the Mlir-commits
mailing list