[Mlir-commits] [mlir] 334873f - [MLIR][Python] Python binding support for IntegerSet attribute (#107640)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 11 04:37:38 PDT 2024


Author: Amy Wang
Date: 2024-09-11T07:37:35-04:00
New Revision: 334873fe2df27a4fa613e8744f29e502d3358397

URL: https://github.com/llvm/llvm-project/commit/334873fe2df27a4fa613e8744f29e502d3358397
DIFF: https://github.com/llvm/llvm-project/commit/334873fe2df27a4fa613e8744f29e502d3358397.diff

LOG: [MLIR][Python] Python binding support for IntegerSet attribute (#107640)

Support IntegerSet attribute python binding.

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/python/mlir/ir.py
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 231eb83b5e2694..7c8c84e55b962f 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -16,6 +16,7 @@
 
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/IntegerSet.h"
 #include "mlir-c/Support.h"
 
 #ifdef __cplusplus
@@ -177,6 +178,14 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
 /// Checks whether the given attribute is an integer set attribute.
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
 
+/// Creates an integer set attribute wrapping the given set. The attribute
+/// belongs to the same context as the integer set.
+MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set);
+
+/// Returns the integer set wrapped in the given integer set attribute.
+MLIR_CAPI_EXPORTED MlirIntegerSet
+mlirIntegerSetAttrGetValue(MlirAttribute attr);
+
 /// Returns the typeID of an IntegerSet attribute.
 MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
 

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b4049bd7972d44..bfdd4a520af275 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -147,6 +147,26 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
   }
 };
 
+class PyIntegerSetAttribute
+    : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+  static constexpr const char *pyClassName = "IntegerSetAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerSetAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyIntegerSet &integerSet) {
+          MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+          return PyIntegerSetAttribute(integerSet.getContext(), attr);
+        },
+        py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+  }
+};
+
 template <typename T>
 static T pyTryCast(py::handle object) {
   try {
@@ -1426,7 +1446,6 @@ py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
 
 void mlir::python::populateIRAttributes(py::module &m) {
   PyAffineMapAttribute::bind(m);
-
   PyDenseBoolArrayAttribute::bind(m);
   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
   PyDenseI8ArrayAttribute::bind(m);
@@ -1466,6 +1485,7 @@ void mlir::python::populateIRAttributes(py::module &m) {
   PyOpaqueAttribute::bind(m);
   PyFloatAttribute::bind(m);
   PyIntegerAttribute::bind(m);
+  PyIntegerSetAttribute::bind(m);
   PyStringAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyGlobals::get().registerTypeCaster(

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 726af884668b2d..11d1ade552f5a2 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
 #include "mlir-c/Support.h"
 #include "mlir/CAPI/AffineMap.h"
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/IntegerSet.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
@@ -192,6 +193,14 @@ MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
   return wrap(IntegerSetAttr::getTypeID());
 }
 
+MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
+  return wrap(IntegerSetAttr::get(unwrap(set)));
+}
+
+MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
+  return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue());
+}
+
 //===----------------------------------------------------------------------===//
 // Opaque attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 7b4fac7275bfc6..a3d3a926186966 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -138,6 +138,7 @@ __all__ = [
     "InsertionPoint",
     "IntegerAttr",
     "IntegerSet",
+    "IntegerSetAttr",
     "IntegerSetConstraint",
     "IntegerSetConstraintList",
     "IntegerType",
@@ -1905,6 +1906,21 @@ class IntegerSet:
     @property
     def n_symbols(self) -> int: ...
 
+class IntegerSetAttr(Attribute):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(integer_set) -> IntegerSetAttr:
+        """
+        Gets an attribute wrapping an IntegerSet.
+        """
+    @staticmethod
+    def isinstance(other: Attribute) -> bool: ...
+    def __init__(self, cast_from_attr: Attribute) -> None: ...
+    @property
+    def type(self) -> Type: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class IntegerSetConstraint:
     def __init__(self, *args, **kwargs) -> None: ...
     @property

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index a9ac765fe1c178..9a6ce462047ad2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -22,6 +22,11 @@ def _affineMapAttr(x, context):
     return AffineMapAttr.get(x)
 
 
+@register_attribute_builder("IntegerSetAttr")
+def _integerSetAttr(x, context):
+    return IntegerSetAttr.get(x)
+
+
 @register_attribute_builder("BoolAttr")
 def _boolAttr(x, context):
     return BoolAttr.get(x, context=context)

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 4b475db6346453..00c3e1b4decdb7 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -162,6 +162,24 @@ def testAffineMapAttr():
         assert attr_built == attr_parsed
 
 
+# CHECK-LABEL: TEST: testIntegerSetAttr
+@run
+def testIntegerSetAttr():
+    with Context() as ctx:
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        s0 = AffineSymbolExpr.get(0)
+        c42 = AffineConstantExpr.get(42)
+        set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
+
+        # CHECK: affine_set<(d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)>
+        attr_built = IntegerSetAttr.get(set0)
+        print(str(attr_built))
+
+        attr_parsed = Attribute.parse(str(attr_built))
+        assert attr_built == attr_parsed
+
+
 # CHECK-LABEL: TEST: testFloatAttr
 @run
 def testFloatAttr():


        


More information about the Mlir-commits mailing list