[Mlir-commits] [mlir] [MLIR][Python] Python binding support for IntegerSet attribute (PR #107640)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 6 13:54:31 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Amy Wang (kaitingwang)

<details>
<summary>Changes</summary>

Add Python binding support for the IntegerSet attribute.

---
Full diff: https://github.com/llvm/llvm-project/pull/107640.diff


6 Files Affected:

- (modified) mlir/include/mlir-c/BuiltinAttributes.h (+9) 
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+21-1) 
- (modified) mlir/lib/CAPI/IR/BuiltinAttributes.cpp (+9) 
- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+16) 
- (modified) mlir/python/mlir/ir.py (+5) 
- (modified) mlir/test/python/ir/attributes.py (+18) 


``````````diff
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 231eb83b5e2694..7c8c84e55b962f 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -16,6 +16,7 @@
 
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/IntegerSet.h"
 #include "mlir-c/Support.h"
 
 #ifdef __cplusplus
@@ -177,6 +178,14 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
 /// Checks whether the given attribute is an integer set attribute.
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
 
+/// Creates an integer set attribute wrapping the given set. The attribute
+/// belongs to the same context as the integer set.
+MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set);
+
+/// Returns the integer set wrapped in the given integer set attribute.
+MLIR_CAPI_EXPORTED MlirIntegerSet
+mlirIntegerSetAttrGetValue(MlirAttribute attr);
+
 /// Returns the typeID of an IntegerSet attribute.
 MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
 
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b4049bd7972d44..c557b9d4c5071c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -147,6 +147,26 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
   }
 };
 
+class PyIntegerSetAttribute
+    : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+  static constexpr const char *pyClassName = "IntegerSetAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerSetAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyIntegerSet &integerSet) {
+          MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+          return PyIntegerSetAttribute(integerSet.getContext(), attr);
+        },
+        py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+  }
+};
+
 template <typename T>
 static T pyTryCast(py::handle object) {
   try {
@@ -1426,7 +1446,7 @@ py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
 
 void mlir::python::populateIRAttributes(py::module &m) {
   PyAffineMapAttribute::bind(m);
-
+  PyIntegerSetAttribute::bind(m);
   PyDenseBoolArrayAttribute::bind(m);
   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
   PyDenseI8ArrayAttribute::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 726af884668b2d..11d1ade552f5a2 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
 #include "mlir-c/Support.h"
 #include "mlir/CAPI/AffineMap.h"
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/IntegerSet.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
@@ -192,6 +193,14 @@ MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
   return wrap(IntegerSetAttr::getTypeID());
 }
 
+MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
+  return wrap(IntegerSetAttr::get(unwrap(set)));
+}
+
+MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
+  return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue());
+}
+
 //===----------------------------------------------------------------------===//
 // Opaque attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 4a2d0e977ccf26..80ce0849f35140 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -137,6 +137,7 @@ __all__ = [
     "InsertionPoint",
     "IntegerAttr",
     "IntegerSet",
+    "IntegerSetAttr",
     "IntegerSetConstraint",
     "IntegerSetConstraintList",
     "IntegerType",
@@ -1891,6 +1892,21 @@ class IntegerSet:
     @property
     def n_symbols(self) -> int: ...
 
+class IntegerSetAttr(Attribute):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(integer_set) -> IntegerSetAttr:
+        """
+        Gets an attribute wrapping an IntegerSet.
+        """
+    @staticmethod
+    def isinstance(other: Attribute) -> bool: ...
+    def __init__(self, cast_from_attr: Attribute) -> None: ...
+    @property
+    def type(self) -> Type: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class IntegerSetConstraint:
     def __init__(self, *args, **kwargs) -> None: ...
     @property
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index a9ac765fe1c178..9a6ce462047ad2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -22,6 +22,11 @@ def _affineMapAttr(x, context):
     return AffineMapAttr.get(x)
 
 
+@register_attribute_builder("IntegerSetAttr")
+def _integerSetAttr(x, context):
+    return IntegerSetAttr.get(x)
+
+
 @register_attribute_builder("BoolAttr")
 def _boolAttr(x, context):
     return BoolAttr.get(x, context=context)
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 4b475db6346453..00c3e1b4decdb7 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -162,6 +162,24 @@ def testAffineMapAttr():
         assert attr_built == attr_parsed
 
 
+# CHECK-LABEL: TEST: testIntegerSetAttr
+@run
+def testIntegerSetAttr():
+    with Context() as ctx:
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        s0 = AffineSymbolExpr.get(0)
+        c42 = AffineConstantExpr.get(42)
+        set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
+
+        # CHECK: affine_set<(d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)>
+        attr_built = IntegerSetAttr.get(set0)
+        print(str(attr_built))
+
+        attr_parsed = Attribute.parse(str(attr_built))
+        assert attr_built == attr_parsed
+
+
 # CHECK-LABEL: TEST: testFloatAttr
 @run
 def testFloatAttr():

``````````

</details>


https://github.com/llvm/llvm-project/pull/107640


More information about the Mlir-commits mailing list