[Mlir-commits] [mlir] [MLIR][Python] Python binding support for AffineIfOp (PR #107336)

Amy Wang llvmlistbot at llvm.org
Wed Sep 4 18:00:15 PDT 2024


https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/107336

From 728b82777dfcf360a25a3b36ee780093519eebaf Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Wed, 4 Sep 2024 20:10:40 -0400
Subject: [PATCH] [MLIR][Python] Python binding support for AffineIfOp

This PR adds Python bindings for AffineIfOp, together with the Python
support for IntegerSetAttr that the AffineIfOp builder requires.
---
 mlir/include/mlir-c/BuiltinAttributes.h       |  9 +++
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  3 +-
 mlir/include/mlir/IR/CommonAttrConstraints.td |  9 +++
 mlir/lib/Bindings/Python/IRAttributes.cpp     | 22 +++++-
 mlir/lib/CAPI/IR/BuiltinAttributes.cpp        |  9 +++
 mlir/python/mlir/_mlir_libs/_mlir/ir.pyi      | 16 +++++
 mlir/python/mlir/dialects/affine.py           | 58 +++++++++++++++
 mlir/python/mlir/ir.py                        |  5 ++
 mlir/test/python/dialects/affine.py           | 70 +++++++++++++++++++
 mlir/test/python/ir/attributes.py             | 18 +++++
 10 files changed, 217 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 231eb83b5e2694..7c8c84e55b962f 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -16,6 +16,7 @@
 
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/IntegerSet.h"
 #include "mlir-c/Support.h"
 
 #ifdef __cplusplus
@@ -177,6 +178,14 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
 /// Checks whether the given attribute is an integer set attribute.
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
 
+/// Creates an integer set attribute wrapping the given set. The attribute
+/// belongs to the same context as the integer set.
+MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set);
+
+/// Returns the set wrapped in `attr`, which must be an IntegerSet attribute.
+MLIR_CAPI_EXPORTED MlirIntegerSet
+mlirIntegerSetAttrGetValue(MlirAttribute attr);
+
 /// Returns the typeID of an IntegerSet attribute.
 MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index dbec741cf1b1f3..a7fbebed9aa586 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
     }
     ```
   }];
-  let arguments = (ins Variadic<AnyType>);
+  let arguments = (ins Variadic<AnyType>,  // Dim, then symbol operands.
+                       IntegerSetAttr:$condition); // Selects then/else region.
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
 
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 6774a7c568315d..785b73322397eb 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -557,6 +557,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
   let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
 }
 
+// Attributes containing integer sets.
+def IntegerSetAttr : Attr<
+CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
+  let storageType = [{ ::mlir::IntegerSetAttr }];
+  let returnType = [{ ::mlir::IntegerSet }];
+  let valueType = NoneType;
+  let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
+}
+
 // Base class for array attributes.
 class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
   let storageType = [{ ::mlir::ArrayAttr }];
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b4049bd7972d44..c557b9d4c5071c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -147,6 +147,26 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
   }
 };
 
+class PyIntegerSetAttribute
+    : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+  static constexpr const char *pyClassName = "IntegerSetAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerSetAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyIntegerSet &integerSet) { // Attr shares the set's context.
+          MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+          return PyIntegerSetAttribute(integerSet.getContext(), attr);
+        },
+        py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+  }
+};
+
 template <typename T>
 static T pyTryCast(py::handle object) {
   try {
@@ -1426,7 +1446,7 @@ py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
 
 void mlir::python::populateIRAttributes(py::module &m) {
   PyAffineMapAttribute::bind(m);
-
+  PyIntegerSetAttribute::bind(m);
   PyDenseBoolArrayAttribute::bind(m);
   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
   PyDenseI8ArrayAttribute::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 726af884668b2d..11d1ade552f5a2 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
 #include "mlir-c/Support.h"
 #include "mlir/CAPI/AffineMap.h"
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/IntegerSet.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
@@ -192,6 +193,14 @@ MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
   return wrap(IntegerSetAttr::getTypeID());
 }
 
+MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
+  return wrap(IntegerSetAttr::get(unwrap(set))); // Context comes from `set`.
+}
+
+MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
+  return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue());
+}
+
 //===----------------------------------------------------------------------===//
 // Opaque attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 4a2d0e977ccf26..80ce0849f35140 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -137,6 +137,7 @@ __all__ = [
     "InsertionPoint",
     "IntegerAttr",
     "IntegerSet",
+    "IntegerSetAttr",
     "IntegerSetConstraint",
     "IntegerSetConstraintList",
     "IntegerType",
@@ -1891,6 +1892,21 @@ class IntegerSet:
     @property
     def n_symbols(self) -> int: ...
 
+class IntegerSetAttr(Attribute):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(integer_set: IntegerSet) -> IntegerSetAttr:
+        """
+        Gets an attribute wrapping an IntegerSet.
+        """
+    @staticmethod
+    def isinstance(other: Attribute) -> bool: ...
+    def __init__(self, cast_from_attr: Attribute) -> None: ...
+    @property
+    def type(self) -> Type: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class IntegerSetConstraint:
     def __init__(self, *args, **kwargs) -> None: ...
     @property
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 913cea61105cee..dcd4e7d2c232c2 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -156,3 +156,61 @@ def for_(
             yield iv, iter_args[0]
         else:
             yield iv
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineIfOp(AffineIfOp):
+    """Specialization for the Affine if op class."""
+
+    def __init__(
+        self,
+        cond: IntegerSet,
+        results_: "Optional[Sequence[Type]]" = None,
+        *,
+        cond_operands: Optional[_VariadicResultValueT] = None,
+        hasElse: bool = False,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an Affine `if` operation.
+
+        - `cond` is the integer set used to determine which regions of code
+          will be executed.
+        - `results_` is the list of types of the values yielded by the op.
+        - `cond_operands` is the list of arguments to substitute the
+          dimensions, then symbols in the `cond` integer set expression to
+          determine whether they are in the set.
+        - `hasElse` determines whether the affine if operation has the else
+          branch.
+        """
+        if results_ is None:
+            results_ = []
+        if cond_operands is None:
+            cond_operands = []
+
+        if (actual_n_inputs := len(cond_operands)) != (
+            exp_n_inputs := cond.n_inputs
+        ):
+            raise ValueError(
+                f"expected {exp_n_inputs} condition operands, got {actual_n_inputs}"
+            )
+
+        # Normalize to fresh lists (any sequence, e.g. a tuple, is accepted)
+        # before handing them to the base-class builder.
+        operands = list(cond_operands)
+        results = list(results_)
+
+        super().__init__(results, operands, cond, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+        if hasElse:
+            self.regions[1].blocks.append()
+
+    @property
+    def then_block(self):
+        """Returns the then block of the if operation."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def else_block(self):
+        """Returns the else block of the if operation."""
+        return self.regions[1].blocks[0]
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index a9ac765fe1c178..9a6ce462047ad2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -22,6 +22,11 @@ def _affineMapAttr(x, context):
     return AffineMapAttr.get(x)
 
 
+@register_attribute_builder("IntegerSetAttr")
+def _integerSetAttr(x, context):
+    return IntegerSetAttr.get(x)
+
+
 @register_attribute_builder("BoolAttr")
 def _boolAttr(x, context):
     return BoolAttr.get(x, context=context)
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 6f39e1348fcd57..24fe4f398f4e87 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -265,3 +265,73 @@ def range_loop_8(lb, ub, memref_v):
             add = arith.addi(i, i)
             memref.store(add, it, [i])
             affine.yield_([it])
+
+
+# CHECK-LABEL: TEST: testAffineIfWithoutElse
+@constructAndPrintInModule
+def testAffineIfWithoutElse():
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL:  func.func @simple_affine_if(
+    # CHECK-SAME:                        %[[VAL_0:.*]]: index) {
+    # CHECK:            affine.if #[[$SET0]](%[[VAL_0]]) {
+    # CHECK:                %[[VAL_1:.*]] = arith.constant 1 : i32
+    # CHECK:                %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
+    # CHECK:            }
+    # CHECK:            return
+    # CHECK:        }
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if(cond_operands):
+        if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
+        with InsertionPoint(if_op.then_block):
+            one = arith.ConstantOp(i32, 1)
+            add = arith.AddIOp(one, one)
+            affine.AffineYieldOp([])
+        return
+
+
+# CHECK-LABEL: TEST: testAffineIfWithElse
+@constructAndPrintInModule
+def testAffineIfWithElse():
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL:  func.func @simple_affine_if_else(
+    # CHECK-SAME:                                    %[[VAL_0:.*]]: index) {
+    # CHECK:            %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
+    # CHECK:                %[[VAL_XT:.*]] = arith.constant 0 : i32
+    # CHECK:                %[[VAL_YT:.*]] = arith.constant 1 : i32
+    # CHECK:                affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
+    # CHECK:            } else {
+    # CHECK:                %[[VAL_XF:.*]] = arith.constant 2 : i32
+    # CHECK:                %[[VAL_YF:.*]] = arith.constant 3 : i32
+    # CHECK:                affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
+    # CHECK:            }
+    # CHECK:            %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
+    # CHECK:            return
+    # CHECK:        }
+
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if_else(cond_operands):
+        if_op = affine.AffineIfOp(
+            cond, [i32, i32], cond_operands=[cond_operands], hasElse=True
+        )
+        with InsertionPoint(if_op.then_block):
+            x_true = arith.ConstantOp(i32, 0)
+            y_true = arith.ConstantOp(i32, 1)
+            affine.AffineYieldOp([x_true, y_true])
+        with InsertionPoint(if_op.else_block):
+            x_false = arith.ConstantOp(i32, 2)
+            y_false = arith.ConstantOp(i32, 3)
+            affine.AffineYieldOp([x_false, y_false])
+        add = arith.AddIOp(if_op.results[0], if_op.results[1])
+        return
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 4b475db6346453..00c3e1b4decdb7 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -162,6 +162,24 @@ def testAffineMapAttr():
         assert attr_built == attr_parsed
 
 
+# CHECK-LABEL: TEST: testIntegerSetAttr
+@run
+def testIntegerSetAttr():
+    with Context() as ctx:
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        s0 = AffineSymbolExpr.get(0)
+        c42 = AffineConstantExpr.get(42)
+        set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
+
+        # CHECK: affine_set<(d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)>
+        attr_built = IntegerSetAttr.get(set0)
+        print(str(attr_built))
+
+        attr_parsed = Attribute.parse(str(attr_built))
+        assert attr_built == attr_parsed
+
+
 # CHECK-LABEL: TEST: testFloatAttr
 @run
 def testFloatAttr():



More information about the Mlir-commits mailing list