[Mlir-commits] [mlir] [MLIR][Python] Python binding support for IntegerSet attribute (PR #107640)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 6 13:54:31 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Amy Wang (kaitingwang)
<details>
<summary>Changes</summary>
Support IntegerSet attribute python binding.
---
Full diff: https://github.com/llvm/llvm-project/pull/107640.diff
6 Files Affected:
- (modified) mlir/include/mlir-c/BuiltinAttributes.h (+9)
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+21-1)
- (modified) mlir/lib/CAPI/IR/BuiltinAttributes.cpp (+9)
- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+16)
- (modified) mlir/python/mlir/ir.py (+5)
- (modified) mlir/test/python/ir/attributes.py (+18)
``````````diff
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 231eb83b5e2694..7c8c84e55b962f 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -16,6 +16,7 @@
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
+#include "mlir-c/IntegerSet.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
@@ -177,6 +178,14 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
/// Checks whether the given attribute is an integer set attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
+/// Creates an integer set attribute wrapping the given set. The attribute
+/// belongs to the same context as the integer set.
+MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set);
+
+/// Returns the integer set wrapped in the given integer set attribute.
+MLIR_CAPI_EXPORTED MlirIntegerSet
+mlirIntegerSetAttrGetValue(MlirAttribute attr);
+
/// Returns the typeID of an IntegerSet attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b4049bd7972d44..c557b9d4c5071c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -147,6 +147,26 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
}
};
+class PyIntegerSetAttribute
+ : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+ static constexpr const char *pyClassName = "IntegerSetAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIntegerSetAttrGetTypeID;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyIntegerSet &integerSet) {
+ MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+ return PyIntegerSetAttribute(integerSet.getContext(), attr);
+ },
+ py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+ }
+};
+
template <typename T>
static T pyTryCast(py::handle object) {
try {
@@ -1426,7 +1446,7 @@ py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
void mlir::python::populateIRAttributes(py::module &m) {
PyAffineMapAttribute::bind(m);
-
+ PyIntegerSetAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI8ArrayAttribute::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 726af884668b2d..11d1ade552f5a2 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/IntegerSet.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
@@ -192,6 +193,14 @@ MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
return wrap(IntegerSetAttr::getTypeID());
}
+MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
+ return wrap(IntegerSetAttr::get(unwrap(set)));
+}
+
+MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
+ return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue());
+}
+
//===----------------------------------------------------------------------===//
// Opaque attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 4a2d0e977ccf26..80ce0849f35140 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -137,6 +137,7 @@ __all__ = [
"InsertionPoint",
"IntegerAttr",
"IntegerSet",
+ "IntegerSetAttr",
"IntegerSetConstraint",
"IntegerSetConstraintList",
"IntegerType",
@@ -1891,6 +1892,21 @@ class IntegerSet:
@property
def n_symbols(self) -> int: ...
+class IntegerSetAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(integer_set) -> IntegerSetAttr:
+ """
+ Gets an attribute wrapping an IntegerSet.
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
class IntegerSetConstraint:
def __init__(self, *args, **kwargs) -> None: ...
@property
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index a9ac765fe1c178..9a6ce462047ad2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -22,6 +22,11 @@ def _affineMapAttr(x, context):
return AffineMapAttr.get(x)
+@register_attribute_builder("IntegerSetAttr")
+def _integerSetAttr(x, context):
+ return IntegerSetAttr.get(x)
+
+
@register_attribute_builder("BoolAttr")
def _boolAttr(x, context):
return BoolAttr.get(x, context=context)
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 4b475db6346453..00c3e1b4decdb7 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -162,6 +162,24 @@ def testAffineMapAttr():
assert attr_built == attr_parsed
+# CHECK-LABEL: TEST: testIntegerSetAttr
+@run
+def testIntegerSetAttr():
+ with Context() as ctx:
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ s0 = AffineSymbolExpr.get(0)
+ c42 = AffineConstantExpr.get(42)
+ set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
+
+ # CHECK: affine_set<(d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)>
+ attr_built = IntegerSetAttr.get(set0)
+ print(str(attr_built))
+
+ attr_parsed = Attribute.parse(str(attr_built))
+ assert attr_built == attr_parsed
+
+
# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
``````````
</details>
https://github.com/llvm/llvm-project/pull/107640
More information about the Mlir-commits
mailing list