[Mlir-commits] [mlir] 036088f - [MLIR][Python] Add SCFIfOp Python binding
Mehdi Amini
llvmlistbot at llvm.org
Sat Mar 12 21:24:33 PST 2022
Author: chhzh123
Date: 2022-03-13T05:24:10Z
New Revision: 036088fd6ea271c2b8fe8deffcb98509e9fd166d
URL: https://github.com/llvm/llvm-project/commit/036088fd6ea271c2b8fe8deffcb98509e9fd166d
DIFF: https://github.com/llvm/llvm-project/commit/036088fd6ea271c2b8fe8deffcb98509e9fd166d.diff
LOG: [MLIR][Python] Add SCFIfOp Python binding
The currently generated Python binding for the SCF dialect does not allow
users to call IfOp to create if-else branches on their own.
This patch sets up the default binding generation for the scf.if operation
to address this problem.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D121076
Added:
Modified:
mlir/python/mlir/dialects/_scf_ops_ext.py
mlir/test/python/dialects/scf.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index a8924a7507a42..3c3e673021585 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -64,3 +64,44 @@ def inner_iter_args(self):
To obtain the loop-carried operands, use `iter_args`.
"""
return self.body.arguments[1:]
+
+
class IfOp:
  """Specialization for the SCF if op class."""

  def __init__(self,
               cond,
               results_=None,
               *,
               hasElse=False,
               loc=None,
               ip=None):
    """Creates an SCF `if` operation.

    - `cond` is an MLIR value of 'i1' type that determines which region of
      code will be executed.
    - `results_` is an optional sequence of result types produced by the
      operation; when omitted (or `None`) the operation has no results.
    - `hasElse` determines whether the if operation has an else branch.
    - `loc` and `ip` are forwarded unchanged to `build_generic`.
    """
    # `None` instead of a mutable `[]` default so no list object is shared
    # between calls; the caller's sequence is always copied into a new list.
    results = [] if results_ is None else list(results_)
    super().__init__(
        self.build_generic(
            regions=2,
            results=results,
            operands=[cond],
            loc=loc,
            ip=ip))
    # The "then" region always gets an argument-less block; the "else" region
    # only gets one when requested and is otherwise left empty.
    self.regions[0].blocks.append()
    if hasElse:
      self.regions[1].blocks.append()

  @property
  def then_block(self):
    """Returns the then block of the if operation."""
    return self.regions[0].blocks[0]

  @property
  def else_block(self):
    """Returns the else block of the if operation.

    Raises:
      IndexError: if the else region has no block, e.g. because the operation
        was created with `hasElse=False`.
    """
    try:
      return self.regions[1].blocks[0]
    except IndexError:
      raise IndexError(
          "scf.if operation has no else block; create it with hasElse=True"
      ) from None
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index f434e806ed47f..c45931c7e76cd 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -82,3 +82,58 @@ def testOpsAsArguments():
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK: scf.yield %{{.*}}, %{{.*}}
# CHECK: return
+
+
@constructAndPrintInModule
def testIfWithoutElse():
  """Builds a result-less `scf.if` that only has a then-region."""
  i1 = IntegerType.get_signless(1)
  i32 = IntegerType.get_signless(32)

  @builtin.FuncOp.from_py_func(i1)
  def simple_if(cond):
    if_op = scf.IfOp(cond)
    with InsertionPoint(if_op.then_block):
      one = arith.ConstantOp(i32, 1)
      # Only the op's presence in the IR matters; its result is not used.
      arith.AddIOp(one, one)
      scf.YieldOp([])
    return


# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return
+
+
@constructAndPrintInModule
def testIfWithElse():
  """Builds an `scf.if` with an else-region yielding two i32 results."""
  i1 = IntegerType.get_signless(1)
  i32 = IntegerType.get_signless(32)

  @builtin.FuncOp.from_py_func(i1)
  def simple_if_else(cond):
    if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
    with InsertionPoint(if_op.then_block):
      x_true = arith.ConstantOp(i32, 0)
      y_true = arith.ConstantOp(i32, 1)
      scf.YieldOp([x_true, y_true])
    with InsertionPoint(if_op.else_block):
      x_false = arith.ConstantOp(i32, 2)
      y_false = arith.ConstantOp(i32, 3)
      scf.YieldOp([x_false, y_false])
    # Consume both results so the CHECK lines can verify they are usable.
    arith.AddIOp(if_op.results[0], if_op.results[1])
    return


# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
# CHECK: %[[ZERO:.*]] = arith.constant 0
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK: %[[TWO:.*]] = arith.constant 2
# CHECK: %[[THREE:.*]] = arith.constant 3
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return
More information about the Mlir-commits
mailing list