[Mlir-commits] [mlir] [MLIR][Python] Python binding support for AffineIfOp (PR #107336)
Amy Wang
llvmlistbot at llvm.org
Wed Sep 4 18:00:15 PDT 2024
https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/107336
>From 728b82777dfcf360a25a3b36ee780093519eebaf Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Wed, 4 Sep 2024 20:10:40 -0400
Subject: [PATCH] [MLIR][Python] Python binding support for AffineIfOp
This PR also adds Python support for IntegerSetAttr, which is needed to
build the AffineIfOp condition.
---
mlir/include/mlir-c/BuiltinAttributes.h | 9 +++
.../mlir/Dialect/Affine/IR/AffineOps.td | 3 +-
mlir/include/mlir/IR/CommonAttrConstraints.td | 9 +++
mlir/lib/Bindings/Python/IRAttributes.cpp | 22 +++++-
mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 9 +++
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 16 +++++
mlir/python/mlir/dialects/affine.py | 58 +++++++++++++++
mlir/python/mlir/ir.py | 5 ++
mlir/test/python/dialects/affine.py | 70 +++++++++++++++++++
mlir/test/python/ir/attributes.py | 18 +++++
10 files changed, 217 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 231eb83b5e2694..7c8c84e55b962f 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -16,6 +16,7 @@
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
+#include "mlir-c/IntegerSet.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
@@ -177,6 +178,14 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
/// Checks whether the given attribute is an integer set attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
+/// Creates an integer set attribute wrapping the given set. The attribute
+/// belongs to the same context as the integer set.
+MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set);
+
+/// Returns the set wrapped in `attr`; `attr` must be an IntegerSet attribute.
+MLIR_CAPI_EXPORTED MlirIntegerSet
+mlirIntegerSetAttrGetValue(MlirAttribute attr);
+
/// Returns the typeID of an IntegerSet attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index dbec741cf1b1f3..a7fbebed9aa586 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
}
```
}];
- let arguments = (ins Variadic<AnyType>);
+ let arguments = (ins Variadic<AnyType>, // $condition's dims, then symbols.
+ IntegerSetAttr:$condition); // Set the operands are tested against.
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 6774a7c568315d..785b73322397eb 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -557,6 +557,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
}
+// Attributes containing integer sets (stored as ::mlir::IntegerSetAttr).
+def IntegerSetAttr : Attr<
+CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
+ let storageType = [{::mlir::IntegerSetAttr }];
+ let returnType = [{ ::mlir::IntegerSet }];
+ let valueType = NoneType;
+ let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
+}
+
// Base class for array attributes.
class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
let storageType = [{ ::mlir::ArrayAttr }];
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b4049bd7972d44..c557b9d4c5071c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -147,6 +147,26 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
   }
 };
 
+class PyIntegerSetAttribute
+    : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+  static constexpr const char *pyClassName = "IntegerSetAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerSetAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c) {
+    // The attribute is created in the same context as the wrapped set.
+    auto getFromIntegerSet = [](PyIntegerSet &integerSet) {
+      return PyIntegerSetAttribute(integerSet.getContext(),
+                                   mlirIntegerSetAttrGet(integerSet.get()));
+    };
+    c.def_static("get", getFromIntegerSet, py::arg("integer_set"),
+                 "Gets an attribute wrapping an IntegerSet.");
+  }
+};
+
 template <typename T>
 static T pyTryCast(py::handle object) {
   try {
@@ -1426,7 +1446,7 @@ py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
 
 void mlir::python::populateIRAttributes(py::module &m) {
   PyAffineMapAttribute::bind(m);
-
+  PyIntegerSetAttribute::bind(m);
   PyDenseBoolArrayAttribute::bind(m);
   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
   PyDenseI8ArrayAttribute::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 726af884668b2d..11d1ade552f5a2 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/IntegerSet.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
@@ -192,6 +193,14 @@ MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
return wrap(IntegerSetAttr::getTypeID());
}
+MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
+ return wrap(IntegerSetAttr::get(unwrap(set))); // Same context as `set`.
+}
+
+MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
+ return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue()); // attr must be an IntegerSetAttr.
+}
+
//===----------------------------------------------------------------------===//
// Opaque attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 4a2d0e977ccf26..80ce0849f35140 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -137,6 +137,7 @@ __all__ = [
     "InsertionPoint",
     "IntegerAttr",
     "IntegerSet",
+    "IntegerSetAttr",
     "IntegerSetConstraint",
     "IntegerSetConstraintList",
     "IntegerType",
@@ -1891,6 +1892,21 @@ class IntegerSet:
     @property
     def n_symbols(self) -> int: ...
 
+class IntegerSetAttr(Attribute):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(integer_set: IntegerSet) -> IntegerSetAttr:
+        """
+        Gets an attribute wrapping an IntegerSet.
+        """
+    @staticmethod
+    def isinstance(other: Attribute) -> bool: ...
+    def __init__(self, cast_from_attr: Attribute) -> None: ...
+    @property
+    def type(self) -> Type: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class IntegerSetConstraint:
     def __init__(self, *args, **kwargs) -> None: ...
     @property
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 913cea61105cee..dcd4e7d2c232c2 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -156,3 +156,61 @@ def for_(
             yield iv, iter_args[0]
         else:
             yield iv
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineIfOp(AffineIfOp):
+    """Specialization for the Affine if op class."""
+
+    def __init__(
+        self,
+        cond: IntegerSet,
+        results_: Optional[Type] = None,
+        *,
+        cond_operands: Optional[_VariadicResultValueT] = None,
+        hasElse: bool = False,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an Affine `if` operation.
+
+        - `cond` is the integer set used to determine which regions of code
+          will be executed.
+        - `results_` is the list of result types yielded by the regions.
+        - `cond_operands` is the list of arguments to substitute the
+          dimensions, then symbols in the `cond` integer set expression to
+          determine whether they are in the set.
+        - `hasElse` determines whether the affine if operation has the else
+          branch.
+        - `loc` and `ip` are forwarded to the generated op builder.
+
+        Raises `ValueError` if `len(cond_operands) != cond.n_inputs`.
+        """
+        if results_ is None:
+            results_ = []
+        if cond_operands is None:
+            cond_operands = []
+
+        actual_n_inputs = len(cond_operands)
+        exp_n_inputs = cond.n_inputs
+        if actual_n_inputs != exp_n_inputs:
+            raise ValueError(
+                f"expected {exp_n_inputs} condition operands, got {actual_n_inputs}"
+            )
+
+        results = list(results_)
+        super().__init__(results, cond_operands, cond, loc=loc, ip=ip)
+        # The then region always gets an empty entry block; else only on request.
+        self.regions[0].blocks.append()
+        if hasElse:
+            self.regions[1].blocks.append()
+
+    @property
+    def then_block(self):
+        """Returns the then block of the if operation."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def else_block(self):
+        """Returns the else block; only exists if created with `hasElse=True`."""
+        return self.regions[1].blocks[0]
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index a9ac765fe1c178..9a6ce462047ad2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -22,6 +22,11 @@ def _affineMapAttr(x, context):
     return AffineMapAttr.get(x)
 
 
+@register_attribute_builder("IntegerSetAttr")
+def _integerSetAttr(x, context):
+    return IntegerSetAttr.get(x)  # The set already carries its context.
+
+
 @register_attribute_builder("BoolAttr")
 def _boolAttr(x, context):
     return BoolAttr.get(x, context=context)
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 6f39e1348fcd57..24fe4f398f4e87 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -265,3 +265,73 @@ def range_loop_8(lb, ub, memref_v):
             add = arith.addi(i, i)
             memref.store(add, it, [i])
             affine.yield_([it])
+
+
+# CHECK-LABEL: TEST: testAffineIfWithoutElse
+@constructAndPrintInModule
+def testAffineIfWithoutElse():
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL: func.func @simple_affine_if(
+    # CHECK-SAME: %[[VAL_0:.*]]: index) {
+    # CHECK: affine.if #[[$SET0]](%[[VAL_0]]) {
+    # CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
+    # CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
+    # CHECK: }
+    # CHECK: return
+    # CHECK: }
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if(cond_operands):
+        if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
+        with InsertionPoint(if_op.then_block):
+            one = arith.ConstantOp(i32, 1)
+            add = arith.AddIOp(one, one)
+            affine.AffineYieldOp([])
+        return
+
+
+# CHECK-LABEL: TEST: testAffineIfWithElse
+@constructAndPrintInModule
+def testAffineIfWithElse():
+    index = IndexType.get()
+    i32 = IntegerType.get_signless(32)
+    d0 = AffineDimExpr.get(0)
+
+    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
+    cond = IntegerSet.get(1, 0, [d0 - 5], [False])
+
+    # CHECK-LABEL: func.func @simple_affine_if_else(
+    # CHECK-SAME: %[[VAL_0:.*]]: index) {
+    # CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
+    # CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32
+    # CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32
+    # CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
+    # CHECK: } else {
+    # CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32
+    # CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32
+    # CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
+    # CHECK: }
+    # CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
+    # CHECK: return
+    # CHECK: }
+
+    @func.FuncOp.from_py_func(index)
+    def simple_affine_if_else(cond_operands):
+        if_op = affine.AffineIfOp(
+            cond, [i32, i32], cond_operands=[cond_operands], hasElse=True
+        )
+        with InsertionPoint(if_op.then_block):
+            x_true = arith.ConstantOp(i32, 0)
+            y_true = arith.ConstantOp(i32, 1)
+            affine.AffineYieldOp([x_true, y_true])
+        with InsertionPoint(if_op.else_block):
+            x_false = arith.ConstantOp(i32, 2)
+            y_false = arith.ConstantOp(i32, 3)
+            affine.AffineYieldOp([x_false, y_false])
+        add = arith.AddIOp(if_op.results[0], if_op.results[1])
+        return
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 4b475db6346453..00c3e1b4decdb7 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -162,6 +162,24 @@ def testAffineMapAttr():
         assert attr_built == attr_parsed
 
 
+# CHECK-LABEL: TEST: testIntegerSetAttr
+@run
+def testIntegerSetAttr():
+    with Context() as ctx:
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        s0 = AffineSymbolExpr.get(0)
+        c42 = AffineConstantExpr.get(42)
+        set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
+
+        # CHECK: affine_set<(d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)>
+        attr_built = IntegerSetAttr.get(set0)
+        print(str(attr_built))
+        attr_parsed = Attribute.parse(str(attr_built))
+        assert attr_built == attr_parsed
+        assert IntegerSetAttr.isinstance(attr_parsed)
+
+
 # CHECK-LABEL: TEST: testFloatAttr
 @run
 def testFloatAttr():
More information about the Mlir-commits
mailing list