[Mlir-commits] [mlir] ac2e2d6 - [mlir] Add Python bindings for StridedLayoutAttr

Denys Shabalin llvmlistbot at llvm.org
Thu Sep 29 04:03:40 PDT 2022


Author: Denys Shabalin
Date: 2022-09-29T11:03:30Z
New Revision: ac2e2d6598191d6ffc31127b80d8cba10d00b765

URL: https://github.com/llvm/llvm-project/commit/ac2e2d6598191d6ffc31127b80d8cba10d00b765
DIFF: https://github.com/llvm/llvm-project/commit/ac2e2d6598191d6ffc31127b80d8cba10d00b765.diff

LOG: [mlir] Add Python bindings for StridedLayoutAttr

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D134869

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/ir/attributes.py
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index b2e32f6d58cb2..79f22376e003a 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -543,10 +543,9 @@ mlirSparseElementsAttrGetValues(MlirAttribute attr);
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr);
 
 // Creates a strided layout attribute from given strides and offset.
-MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx,
-                                                          int64_t offset,
-                                                          intptr_t numStrides,
-                                                          int64_t *strides);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides,
+                         const int64_t *strides);
 
 // Returns the offset in the given strided layout layout attribute.
 MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr);

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 8d8cea395e174..e62f1550c6dc6 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1031,6 +1031,45 @@ class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
   }
 };
 
+/// Python binding for the builtin StridedLayoutAttr (an offset plus strides).
+class PyStridedLayoutAttribute
+    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+  static constexpr const char *pyClassName = "StridedLayoutAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](int64_t offset, const std::vector<int64_t> &strides,
+           DefaultingPyMlirContext ctx) {
+          MlirAttribute attr = mlirStridedLayoutAttrGet(
+              ctx->get(), offset, strides.size(), strides.data());
+          return PyStridedLayoutAttribute(ctx->getRef(), attr);
+        },
+        py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
+        "Gets a strided layout attribute from the given offset and strides.");
+    c.def_property_readonly(
+        "offset",
+        [](PyStridedLayoutAttribute &self) {
+          return mlirStridedLayoutAttrGetOffset(self);
+        },
+        "Returns the offset of the strided layout attribute");
+    c.def_property_readonly(
+        "strides",
+        [](PyStridedLayoutAttribute &self) {
+          intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+          std::vector<int64_t> strides(size);
+          for (intptr_t i = 0; i < size; i++) {
+            strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+          }
+          return strides;
+        },
+        "Returns the list of strides of the strided layout attribute");
+  }
+};
+
 } // namespace
 
 void mlir::python::populateIRAttributes(py::module &m) {
@@ -1065,4 +1104,6 @@ void mlir::python::populateIRAttributes(py::module &m) {
   PyStringAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
+
+  PyStridedLayoutAttribute::bind(m);
 }

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 153664d0771dd..379510ce9654e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -302,11 +302,11 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
         },
         "Returns the shape of the ranked shaped type as a list of integers.");
     c.def_static(
-        "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
+        "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
         "Returns the value used to indicate dynamic dimensions in shaped "
         "types.");
     c.def_static(
-        "_get_dynamic_stride_or_offset",
+        "get_dynamic_stride_or_offset",
         []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
         "Returns the value used to indicate dynamic strides or offsets in "
         "shaped types.");

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 1ae2a2b540d75..05ecb0fe80792 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -732,7 +732,8 @@ bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
 }
 
 MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
-                                       intptr_t numStrides, int64_t *strides) {
+                                       intptr_t numStrides,
+                                       const int64_t *strides) {
   return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
                                      ArrayRef<int64_t>(strides, numStrides)));
 }

diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index eddc384f64aa2..527a8656f7e33 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -211,7 +211,7 @@ def __init__(self,
       static_split_point = split_point
       dynamic_split_point = None
     else:
-      static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
+      static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
       dynamic_split_point = _get_op_result_or_value(split_point)
 
     pdl_operation_type = pdl.OperationType.get()
@@ -255,7 +255,7 @@ def __init__(self,
           static_sizes.append(size)
         else:
           static_sizes.append(
-              IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
+              IntegerAttr.get(i64_type, ShapedType.get_dynamic_size()))
           dynamic_sizes.append(_get_op_result_or_value(size))
       sizes_attr = ArrayAttr.get(static_sizes)
 

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index a958abfc9e75e..e0960e3f1c456 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -523,3 +523,22 @@ def testArrayAttr():
     array = array + [StringAttr.get("c")]
     # CHECK: concat: ["a", "b", "c"]
     print("concat: ", array)
+
+
+# CHECK-LABEL: TEST: testStridedLayoutAttr
+@run
+def testStridedLayoutAttr():
+  with Context():
+    attr = StridedLayoutAttr.get(42, [5, 7, 13])
+    # CHECK: strided<[5, 7, 13], offset: 42>
+    print(attr)
+    # CHECK: 42
+    print(attr.offset)
+    # CHECK: 3
+    print(len(attr.strides))
+    # CHECK: 5
+    print(attr.strides[0])
+    # CHECK: 7
+    print(attr.strides[1])
+    # CHECK: 13
+    print(attr.strides[2])

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 945ed7e141f02..91c820f121b9f 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -487,3 +487,13 @@ def testOpaqueType():
     print("dialect namespace:", opaque.dialect_namespace)
     # CHECK: data: type
     print("data:", opaque.data)
+
+
+# CHECK-LABEL: TEST: testShapedTypeConstants
+# Tests that ShapedType exposes magic value constants.
+@run
+def testShapedTypeConstants():
+  # CHECK: <class 'int'>
+  print(type(ShapedType.get_dynamic_size()))
+  # CHECK: <class 'int'>
+  print(type(ShapedType.get_dynamic_stride_or_offset()))


        


More information about the Mlir-commits mailing list