[Mlir-commits] [mlir] ac2e2d6 - [mlir] Add Python bindings for StridedLayoutAttr
Denys Shabalin
llvmlistbot at llvm.org
Thu Sep 29 04:03:40 PDT 2022
Author: Denys Shabalin
Date: 2022-09-29T11:03:30Z
New Revision: ac2e2d6598191d6ffc31127b80d8cba10d00b765
URL: https://github.com/llvm/llvm-project/commit/ac2e2d6598191d6ffc31127b80d8cba10d00b765
DIFF: https://github.com/llvm/llvm-project/commit/ac2e2d6598191d6ffc31127b80d8cba10d00b765.diff
LOG: [mlir] Add Python bindings for StridedLayoutAttr
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D134869
Added:
Modified:
mlir/include/mlir-c/BuiltinAttributes.h
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/builtin_types.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index b2e32f6d58cb2..79f22376e003a 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -543,10 +543,9 @@ mlirSparseElementsAttrGetValues(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr);
// Creates a strided layout attribute from given strides and offset.
-MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx,
- int64_t offset,
- intptr_t numStrides,
- int64_t *strides);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides,
+ const int64_t *strides);
// Returns the offset in the given strided layout layout attribute.
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr);
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 8d8cea395e174..e62f1550c6dc6 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1031,6 +1031,45 @@ class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
}
};
+/// Python binding for StridedLayoutAttr (`get`, `offset`, `strides`).
+class PyStridedLayoutAttribute
+    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+  static constexpr const char *pyClassName = "StridedLayoutAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](int64_t offset, const std::vector<int64_t> &strides,
+           DefaultingPyMlirContext ctx) {
+          MlirAttribute attr = mlirStridedLayoutAttrGet(ctx->get(), offset,
+              static_cast<intptr_t>(strides.size()), strides.data());
+          return PyStridedLayoutAttribute(ctx->getRef(), attr);
+        },
+        py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
+        "Gets a strided layout attribute.");
+    c.def_property_readonly(
+        "offset",
+        [](PyStridedLayoutAttribute &self) {
+          return mlirStridedLayoutAttrGetOffset(self);
+        },
+        "Returns the offset of the strided layout attribute.");
+    c.def_property_readonly(
+        "strides",
+        [](PyStridedLayoutAttribute &self) {
+          intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+          std::vector<int64_t> strides(size);
+          for (intptr_t i = 0; i < size; i++) {
+            strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+          }
+          return strides;
+        },
+        "Returns the strides of the strided layout attribute.");
+  }
+};
+
} // namespace
void mlir::python::populateIRAttributes(py::module &m) {
@@ -1065,4 +1104,6 @@ void mlir::python::populateIRAttributes(py::module &m) {
PyStringAttribute::bind(m);
PyTypeAttribute::bind(m);
PyUnitAttribute::bind(m);
+
+ PyStridedLayoutAttribute::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 153664d0771dd..379510ce9654e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -302,11 +302,11 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
},
"Returns the shape of the ranked shaped type as a list of integers.");
c.def_static(
- "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
+ "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
"Returns the value used to indicate dynamic dimensions in shaped "
"types.");
c.def_static(
- "_get_dynamic_stride_or_offset",
+ "get_dynamic_stride_or_offset",
[]() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
"Returns the value used to indicate dynamic strides or offsets in "
"shaped types.");
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 1ae2a2b540d75..05ecb0fe80792 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -732,7 +732,8 @@ bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
}
MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
- intptr_t numStrides, int64_t *strides) {
+ intptr_t numStrides,
+ const int64_t *strides) {
return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
ArrayRef<int64_t>(strides, numStrides)));
}
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index eddc384f64aa2..527a8656f7e33 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -211,7 +211,7 @@ def __init__(self,
static_split_point = split_point
dynamic_split_point = None
else:
- static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
+ static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
dynamic_split_point = _get_op_result_or_value(split_point)
pdl_operation_type = pdl.OperationType.get()
@@ -255,7 +255,7 @@ def __init__(self,
static_sizes.append(size)
else:
static_sizes.append(
- IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
+ IntegerAttr.get(i64_type, ShapedType.get_dynamic_size()))
dynamic_sizes.append(_get_op_result_or_value(size))
sizes_attr = ArrayAttr.get(static_sizes)
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index a958abfc9e75e..e0960e3f1c456 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -523,3 +523,22 @@ def testArrayAttr():
array = array + [StringAttr.get("c")]
# CHECK: concat: ["a", "b", "c"]
print("concat: ", array)
+
+
+# CHECK-LABEL: TEST: testStridedLayoutAttr
+ at run
+def testStridedLayoutAttr():
+ with Context():
+ attr = StridedLayoutAttr.get(42, [5, 7, 13])
+ # CHECK: strided<[5, 7, 13], offset: 42>
+ print(attr)
+ # CHECK: 42
+ print(attr.offset)
+ # CHECK: 3
+ print(len(attr.strides))
+ # CHECK: 5
+ print(attr.strides[0])
+ # CHECK: 7
+ print(attr.strides[1])
+ # CHECK: 13
+ print(attr.strides[2])
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 945ed7e141f02..91c820f121b9f 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -487,3 +487,13 @@ def testOpaqueType():
print("dialect namespace:", opaque.dialect_namespace)
# CHECK: data: type
print("data:", opaque.data)
+
+
+# CHECK-LABEL: TEST: testShapedTypeConstants
+# Tests that ShapedType exposes magic value constants.
+ at run
+def testShapedTypeConstants():
+ # CHECK: <class 'int'>
+ print(type(ShapedType.get_dynamic_size()))
+ # CHECK: <class 'int'>
+ print(type(ShapedType.get_dynamic_stride_or_offset()))
More information about the Mlir-commits
mailing list