[Mlir-commits] [mlir] d50fbe4 - [MLIR][Python] Python binding support for AffineIfOp (#108323)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 13 13:27:50 PST 2024


Author: Amy Wang
Date: 2024-11-13T16:27:46-05:00
New Revision: d50fbe43c9887e776cdfe95deaf312fb9cecfeaf

URL: https://github.com/llvm/llvm-project/commit/d50fbe43c9887e776cdfe95deaf312fb9cecfeaf
DIFF: https://github.com/llvm/llvm-project/commit/d50fbe43c9887e776cdfe95deaf312fb9cecfeaf.diff

LOG: [MLIR][Python] Python binding support for AffineIfOp (#108323)

Fix AffineIfOp's default builder so that it takes an IntegerSetAttr.
AffineIfOp has skipDefaultBuilders=1, which effectively skips the creation
of the default AffineIfOp::builder on the C++ side. (AffineIfOp has two
custom OpBuilders defined in the extraClassDeclaration.) However, on the
Python side, _affine_ops_gen.py shows that the default builder is still
being created, but it does not accept an IntegerSet and is thus useless.
The fix at line 411 of AffineOps.td makes the default Python AffineIfOp
builder take an IntegerSet input and does not impact the C++ side of
things.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/IR/CommonAttrConstraints.td
    mlir/python/mlir/dialects/affine.py
    mlir/test/python/dialects/affine.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 6a495e11ae1ad5..ea65911af43a1e 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
     }
     ```
   }];
-  let arguments = (ins Variadic<AnyType>);
+  let arguments = (ins Variadic<AnyType>,
+                       IntegerSetAttr:$condition);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
 

diff  --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index de5f6797235e3c..17ca82c510f8a2 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -558,6 +558,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
   let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
 }
 
+// Attributes containing integer sets.
+def IntegerSetAttr : Attr<
+CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
+  let storageType = [{::mlir::IntegerSetAttr }];
+  let returnType = [{ ::mlir::IntegerSet }];
+  let valueType = NoneType;
+  let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
+}
+
 // Base class for array attributes.
 class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
   let storageType = [{ ::mlir::ArrayAttr }];

diff  --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 913cea61105cee..7641d36e39799f 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -156,3 +156,61 @@ def for_(
             yield iv, iter_args[0]
         else:
             yield iv
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineIfOp(AffineIfOp):
+    """Specialization for the Affine if op class."""
+
+    def __init__(
+        self,
+        cond: IntegerSet,
+        results_: Optional[Type] = None,
+        *,
+        cond_operands: Optional[_VariadicResultValueT] = None,
+        has_else: bool = False,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an Affine `if` operation.
+
+        - `cond` is the integer set used to determine which regions of code
+          will be executed.
+        - `results_` is the list of types to be yielded by the operation.
+        - `cond_operands` is the list of arguments to substitute the
+          dimensions, then symbols in the `cond` integer set expression to
+          determine whether they are in the set.
+        - `has_else` determines whether the affine if operation has the else
+          branch.
+        """
+        if results_ is None:
+            results_ = []
+        if cond_operands is None:
+            cond_operands = []
+
+        if cond.n_inputs != len(cond_operands):
+            raise ValueError(
+                f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}"
+            )
+
+        operands = []
+        operands.extend(cond_operands)
+        results = []
+        results.extend(results_)
+
+        super().__init__(results, cond_operands, cond)
+        self.regions[0].blocks.append(*[])
+        if has_else:
+            self.regions[1].blocks.append(*[])
+
+    @property
+    def then_block(self) -> Block:
+        """Returns the then block of the if operation."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def else_block(self) -> Optional[Block]:
+        """Returns the else block of the if operation."""
+        if len(self.regions[1].blocks) == 0:
+            return None
+        return self.regions[1].blocks[0]

diff  --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 0dc69d7ba522de..7faae6ccedc972 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -263,3 +263,73 @@ def range_loop_8(lb, ub, memref_v):
             add = arith.addi(i, i)
             memref.store(add, it, [i])
             affine.yield_([it])
+
+
+# CHECK-LABEL: TEST: testAffineIfWithoutElse
+@constructAndPrintInModule
+def testAffineIfWithoutElse():
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL:  func.func @simple_affine_if(
+    # CHECK-SAME:                        %[[VAL_0:.*]]: index) {
+    # CHECK:            affine.if #[[$SET0]](%[[VAL_0]]) {
+    # CHECK:                %[[VAL_1:.*]] = arith.constant 1 : i32
+    # CHECK:                %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
+    # CHECK:            }
+    # CHECK:            return
+    # CHECK:        }
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if(cond_operands):
+        if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
+        with InsertionPoint(if_op.then_block):
+            one = arith.ConstantOp(i32, 1)
+            add = arith.AddIOp(one, one)
+            affine.AffineYieldOp([])
+        return
+
+
+# CHECK-LABEL: TEST: testAffineIfWithElse
+@constructAndPrintInModule
+def testAffineIfWithElse():
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL:  func.func @simple_affine_if_else(
+    # CHECK-SAME:                                    %[[VAL_0:.*]]: index) {
+    # CHECK:            %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
+    # CHECK:                %[[VAL_XT:.*]] = arith.constant 0 : i32
+    # CHECK:                %[[VAL_YT:.*]] = arith.constant 1 : i32
+    # CHECK:                affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
+    # CHECK:            } else {
+    # CHECK:                %[[VAL_XF:.*]] = arith.constant 2 : i32
+    # CHECK:                %[[VAL_YF:.*]] = arith.constant 3 : i32
+    # CHECK:                affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
+    # CHECK:            }
+    # CHECK:            %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
+    # CHECK:            return
+    # CHECK:        }
+
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if_else(cond_operands):
+        if_op = affine.AffineIfOp(
+            cond, [i32, i32], cond_operands=[cond_operands], has_else=True
+        )
+        with InsertionPoint(if_op.then_block):
+            x_true = arith.ConstantOp(i32, 0)
+            y_true = arith.ConstantOp(i32, 1)
+            affine.AffineYieldOp([x_true, y_true])
+        with InsertionPoint(if_op.else_block):
+            x_false = arith.ConstantOp(i32, 2)
+            y_false = arith.ConstantOp(i32, 3)
+            affine.AffineYieldOp([x_false, y_false])
+        add = arith.AddIOp(if_op.results[0], if_op.results[1])
+        return


        


More information about the Mlir-commits mailing list