[Mlir-commits] [mlir] 0e6beb2 - [mlir][Python] Add python binding to create DenseElementsAttribute.
Stella Laurenzo
llvmlistbot at llvm.org
Mon Oct 19 22:37:25 PDT 2020
Author: Stella Laurenzo
Date: 2020-10-19T22:29:35-07:00
New Revision: 0e6beb29966abc6666e73ab5f151cb9754f04901
URL: https://github.com/llvm/llvm-project/commit/0e6beb29966abc6666e73ab5f151cb9754f04901
DIFF: https://github.com/llvm/llvm-project/commit/0e6beb29966abc6666e73ab5f151cb9754f04901.diff
LOG: [mlir][Python] Add python binding to create DenseElementsAttribute.
* Interops with Python buffers/numpy arrays to create.
* Also cleans up 'get' factory methods on some types to be consistent.
* Adds mlirAttributeGetType() to C-API to facilitate error handling and other uses.
* Punts on a lot of features of the ElementsAttribute hierarchy for now.
* Does not yet support bool or string attributes.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D89363
Added:
mlir/test/Bindings/Python/ir_array_attributes.py
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_attributes.py
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8f525e8b6239..2a768df0ffd9 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -912,6 +912,150 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
}
};
+// TODO: Support construction of bool elements.
+// TODO: Support construction of string elements.
+/// Python binding for the builtin DenseElementsAttr, exported to Python as
+/// `DenseElementsAttr` with the `get` / `get_splat` factories and the
+/// `is_splat` property.
+class PyDenseElementsAttribute
+    : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+  static constexpr const char *pyClassName = "DenseElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Creates a DenseElementsAttr from an object exposing the Python buffer
+  /// protocol (e.g. a numpy ndarray). The attribute type is a ranked tensor
+  /// with the buffer's shape. The element type is selected from the buffer
+  /// format code *and* its item size:
+  ///   'f' (4 bytes) -> f32, 'd' (8 bytes) -> f64,
+  ///   'i'/'l'/'q' (4 or 8 bytes) -> signed 32/64-bit integer,
+  ///   'I'/'L'/'Q' (4 or 8 bytes) -> unsigned 32/64-bit integer.
+  /// Integer element types are signless when `signless` is true, otherwise
+  /// signed/unsigned to match the format. Any other combination raises
+  /// ValueError.
+  static PyDenseElementsAttribute getFromBuffer(PyMlirContext &contextWrapper,
+                                                py::buffer array,
+                                                bool signless) {
+    // Request a contiguous view. In exotic cases, this will cause a copy.
+    int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
+    Py_buffer *view = new Py_buffer();
+    if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
+      delete view;
+      throw py::error_already_set();
+    }
+    // py::buffer_info (with its default ownview=true) releases and deletes
+    // `view` when it is destroyed.
+    py::buffer_info arrayInfo(view);
+
+    MlirContext context = contextWrapper.get();
+    const std::string &format = arrayInfo.format;
+    const auto itemSize = arrayInfo.itemsize;
+    // Switch on the types that can be bulk loaded between the Python and
+    // MLIR-C APIs. The item size is checked explicitly instead of being
+    // asserted: asserts are compiled out in release builds, and the width
+    // behind codes such as 'l'/'L' is platform-dependent (e.g. `long` is
+    // 4 bytes on 64-bit Windows). A mismatch would otherwise make bulkLoad
+    // read past the end of the buffer.
+    if (format == "f" && itemSize == 4) {
+      // f32
+      return PyDenseElementsAttribute(
+          contextWrapper.getRef(),
+          bulkLoad(mlirDenseElementsAttrFloatGet, mlirF32TypeGet(context),
+                   arrayInfo));
+    }
+    if (format == "d" && itemSize == 8) {
+      // f64
+      return PyDenseElementsAttribute(
+          contextWrapper.getRef(),
+          bulkLoad(mlirDenseElementsAttrDoubleGet, mlirF64TypeGet(context),
+                   arrayInfo));
+    }
+    if (isSignedIntegerFormat(format)) {
+      if (itemSize == 4) {
+        // i32
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
+                                        : mlirIntegerTypeSignedGet(context, 32);
+        return PyDenseElementsAttribute(
+            contextWrapper.getRef(),
+            bulkLoad(mlirDenseElementsAttrInt32Get, elementType, arrayInfo));
+      }
+      if (itemSize == 8) {
+        // i64
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
+                                        : mlirIntegerTypeSignedGet(context, 64);
+        return PyDenseElementsAttribute(
+            contextWrapper.getRef(),
+            bulkLoad(mlirDenseElementsAttrInt64Get, elementType, arrayInfo));
+      }
+    } else if (isUnsignedIntegerFormat(format)) {
+      if (itemSize == 4) {
+        // unsigned i32
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 32)
+                                   : mlirIntegerTypeUnsignedGet(context, 32);
+        return PyDenseElementsAttribute(
+            contextWrapper.getRef(),
+            bulkLoad(mlirDenseElementsAttrUInt32Get, elementType, arrayInfo));
+      }
+      if (itemSize == 8) {
+        // unsigned i64
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 64)
+                                   : mlirIntegerTypeUnsignedGet(context, 64);
+        return PyDenseElementsAttribute(
+            contextWrapper.getRef(),
+            bulkLoad(mlirDenseElementsAttrUInt64Get, elementType, arrayInfo));
+      }
+    }
+
+    // TODO: Fall back to string-based get.
+    std::string message = "unimplemented array format conversion from format: ";
+    message.append(format);
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
+  /// Creates a DenseElementsAttr of `shapedType` in which every element is
+  /// `elementAttr`. Raises ValueError unless `elementAttr` is an integer or
+  /// float attribute, `shapedType` is a ShapedType with a static shape, and
+  /// the shaped element type equals the attribute's type.
+  static PyDenseElementsAttribute getSplat(PyType shapedType,
+                                           PyAttribute &elementAttr) {
+    auto contextWrapper =
+        PyMlirContext::forContext(mlirTypeGetContext(shapedType.type));
+    if (!mlirAttributeIsAInteger(elementAttr.attr) &&
+        !mlirAttributeIsAFloat(elementAttr.attr)) {
+      std::string message = "Illegal element type for DenseElementsAttr: ";
+      message.append(py::repr(py::cast(elementAttr)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+    if (!mlirTypeIsAShaped(shapedType.type) ||
+        !mlirShapedTypeHasStaticShape(shapedType.type)) {
+      std::string message =
+          "Expected a static ShapedType for the shaped_type parameter: ";
+      message.append(py::repr(py::cast(shapedType)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+    MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType.type);
+    MlirType attrType = mlirAttributeGetType(elementAttr.attr);
+    if (!mlirTypeEqual(shapedElementType, attrType)) {
+      std::string message =
+          "Shaped element type and attribute type must be equal: shaped=";
+      message.append(py::repr(py::cast(shapedType)));
+      message.append(", element=");
+      message.append(py::repr(py::cast(elementAttr)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+
+    MlirAttribute elements =
+        mlirDenseElementsAttrSplatGet(shapedType.type, elementAttr.attr);
+    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
+                 py::arg("context"), py::arg("array"),
+                 py::arg("signless") = true, "Gets from a buffer or ndarray")
+        .def_static("get_splat", PyDenseElementsAttribute::getSplat,
+                    py::arg("shaped_type"), py::arg("element_attr"),
+                    "Gets a DenseElementsAttr where all values are the same")
+        .def_property_readonly("is_splat",
+                               [](PyDenseElementsAttribute &self) -> bool {
+                                 return mlirDenseElementsAttrIsSplat(self.attr);
+                               });
+  }
+
+private:
+  /// True for the native signed C-integer format codes whose width may be
+  /// 4 or 8 bytes depending on the platform ('i', 'l', 'q').
+  static bool isSignedIntegerFormat(const std::string &format) {
+    return format == "i" || format == "l" || format == "q";
+  }
+
+  /// Unsigned counterpart of isSignedIntegerFormat ('I', 'L', 'Q').
+  static bool isUnsignedIntegerFormat(const std::string &format) {
+    return format == "I" || format == "L" || format == "Q";
+  }
+
+  /// Builds an attribute with `ctor` over the contiguous buffer described by
+  /// `arrayInfo`, typed as a ranked tensor of `mlirElementType` with the
+  /// buffer's shape. The caller must guarantee that sizeof(ElementTy) equals
+  /// `arrayInfo.itemsize`.
+  template <typename ElementTy>
+  static MlirAttribute
+  bulkLoad(MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
+           MlirType mlirElementType, py::buffer_info &arrayInfo) {
+    SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
+                                  arrayInfo.shape.begin() + arrayInfo.ndim);
+    auto shapedType =
+        mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
+    intptr_t numElements = arrayInfo.size;
+    const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
+    return ctor(shapedType, numElements, contents);
+  }
+};
+
} // namespace
//------------------------------------------------------------------------------
@@ -1021,11 +1165,13 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirIndexTypeGet(context.get());
- return PyIndexType(context.getRef(), t);
- }),
- "Create a index type.");
+ // Exposes IndexType.get(context) as a static factory, replacing the former
+ // IndexType(context) constructor.
+ c.def_static(
+ "get",
+ [](PyMlirContext &context) {
+ MlirType t = mlirIndexTypeGet(context.get());
+ return PyIndexType(context.getRef(), t);
+ },
+ "Create an index type.");
}
};
@@ -1037,11 +1183,13 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirBF16TypeGet(context.get());
- return PyBF16Type(context.getRef(), t);
- }),
- "Create a bf16 type.");
+ // Exposes BF16Type.get(context) as a static factory, replacing the former
+ // BF16Type(context) constructor.
+ c.def_static(
+ "get",
+ [](PyMlirContext &context) {
+ MlirType t = mlirBF16TypeGet(context.get());
+ return PyBF16Type(context.getRef(), t);
+ },
+ "Create a bf16 type.");
}
};
@@ -1053,11 +1201,13 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirF16TypeGet(context.get());
- return PyF16Type(context.getRef(), t);
- }),
- "Create a f16 type.");
+ // Exposes F16Type.get(context) as a static factory, replacing the former
+ // F16Type(context) constructor.
+ c.def_static(
+ "get",
+ [](PyMlirContext &context) {
+ MlirType t = mlirF16TypeGet(context.get());
+ return PyF16Type(context.getRef(), t);
+ },
+ "Create a f16 type.");
}
};
@@ -1069,11 +1219,13 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirF32TypeGet(context.get());
- return PyF32Type(context.getRef(), t);
- }),
- "Create a f32 type.");
+ // Exposes F32Type.get(context) as a static factory, replacing the former
+ // F32Type(context) constructor.
+ c.def_static(
+ "get",
+ [](PyMlirContext &context) {
+ MlirType t = mlirF32TypeGet(context.get());
+ return PyF32Type(context.getRef(), t);
+ },
+ "Create a f32 type.");
}
};
@@ -1085,11 +1237,13 @@ class PyF64Type : public PyConcreteType<PyF64Type> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirF64TypeGet(context.get());
- return PyF64Type(context.getRef(), t);
- }),
- "Create a f64 type.");
+ // Exposes F64Type.get(context) as a static factory, replacing the former
+ // F64Type(context) constructor.
+ c.def_static(
+ "get",
+ [](PyMlirContext &context) {
+ MlirType t = mlirF64TypeGet(context.get());
+ return PyF64Type(context.getRef(), t);
+ },
+ "Create a f64 type.");
}
};
@@ -1101,11 +1255,13 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirNoneTypeGet(context.get());
- return PyNoneType(context.getRef(), t);
- }),
- "Create a none type.");
+ // Exposes NoneType.get(context) as a static factory, replacing the former
+ // NoneType(context) constructor.
+ c.def_static(
+ "get",
+ [](PyMlirContext &context) {
+ MlirType t = mlirNoneTypeGet(context.get());
+ return PyNoneType(context.getRef(), t);
+ },
+ "Create a none type.");
}
};
@@ -1118,7 +1274,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "get_complex",
+ "get",
[](PyType &elementType) {
// The element must be a floating point or integer scalar type.
if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
@@ -1224,7 +1380,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "get_vector",
+ "get",
// TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
@@ -1254,7 +1410,7 @@ class PyRankedTensorType
static void bindDerived(ClassTy &c) {
c.def_static(
- "get_ranked_tensor",
+ "get",
// TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
MlirType t = mlirRankedTensorTypeGetChecked(
@@ -1286,7 +1442,7 @@ class PyUnrankedTensorType
static void bindDerived(ClassTy &c) {
c.def_static(
- "get_unranked_tensor",
+ "get",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, PyLocation &loc) {
MlirType t =
@@ -1366,7 +1522,7 @@ class PyUnrankedMemRefType
static void bindDerived(ClassTy &c) {
c.def_static(
- "get_unranked_memref",
+ "get",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
@@ -1719,6 +1875,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"context",
[](PyAttribute &self) { return self.getContext().getObject(); },
"Context that owns the Attribute")
+ .def_property_readonly("type",
+ [](PyAttribute &self) {
+ return PyType(self.getContext()->getRef(),
+ mlirAttributeGetType(self.attr));
+ })
.def(
"get_named",
[](PyAttribute &self, std::string name) {
@@ -1796,6 +1957,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyIntegerAttribute::bind(m);
PyBoolAttribute::bind(m);
PyStringAttribute::bind(m);
+ PyDenseElementsAttribute::bind(m);
// Mapping of Type.
py::class_<PyType>(m, "Type")
diff --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py
new file mode 100644
index 000000000000..97a9802ae148
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_array_attributes.py
@@ -0,0 +1,213 @@
+# RUN: %PYTHON %s | FileCheck %s
+# Note that this is separate from ir_attributes.py since it depends on numpy,
+# and we may want to disable if not available.
+
+import gc
+import mlir
+import numpy as np
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert mlir.ir.Context._get_live_count() == 0
+
+################################################################################
+# Tests of the array/buffer .get() factory method on unsupported dtype.
+################################################################################
+
+def testGetDenseElementsUnsupported():
+ ctx = mlir.ir.Context()
+ array = np.array([["hello", "goodbye"]])
+ try:
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+ except ValueError as e:
+ # CHECK: unimplemented array format conversion from format:
+ print(e)
+
+run(testGetDenseElementsUnsupported)
+
+################################################################################
+# Splats.
+################################################################################
+
+# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
+def testGetDenseElementsSplatInt():
+ ctx = mlir.ir.Context()
+ loc = ctx.get_unknown_location()
+ t = mlir.ir.IntegerType.get_signless(ctx, 32)
+ element = mlir.ir.IntegerAttr.get(t, 555)
+ shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
+ attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element)
+ # CHECK: dense<555> : tensor<2x3x4xi32>
+ print(attr)
+ # CHECK: is_splat: True
+ print("is_splat:", attr.is_splat)
+
+run(testGetDenseElementsSplatInt)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
+def testGetDenseElementsSplatFloat():
+ ctx = mlir.ir.Context()
+ loc = ctx.get_unknown_location()
+ t = mlir.ir.F32Type.get(ctx)
+ element = mlir.ir.FloatAttr.get(t, 1.2, loc)
+ shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
+ attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element)
+ # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
+ print(attr)
+
+run(testGetDenseElementsSplatFloat)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
+def testGetDenseElementsSplatErrors():
+ ctx = mlir.ir.Context()
+ loc = ctx.get_unknown_location()
+ t = mlir.ir.F32Type.get(ctx)
+ other_t = mlir.ir.F64Type.get(ctx)
+ element = mlir.ir.FloatAttr.get(t, 1.2, loc)
+ other_element = mlir.ir.FloatAttr.get(other_t, 1.2, loc)
+ shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
+ dynamic_shaped_type = mlir.ir.UnrankedTensorType.get(t, loc)
+ non_shaped_type = t
+
+ try:
+ attr = mlir.ir.DenseElementsAttr.get_splat(non_shaped_type, element)
+ except ValueError as e:
+ # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
+ print(e)
+
+ try:
+ attr = mlir.ir.DenseElementsAttr.get_splat(dynamic_shaped_type, element)
+ except ValueError as e:
+ # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
+ print(e)
+
+ try:
+ attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, other_element)
+ except ValueError as e:
+ # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
+ print(e)
+
+run(testGetDenseElementsSplatErrors)
+
+
+################################################################################
+# Tests of the array/buffer .get() factory method, in all of its permutations.
+################################################################################
+
+### float and double arrays.
+
+# CHECK-LABEL: TEST: testGetDenseElementsF32
+def testGetDenseElementsF32():
+ ctx = mlir.ir.Context()
+ array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+ # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
+ print(attr)
+ # CHECK: is_splat: False
+ print("is_splat:", attr.is_splat)
+
+run(testGetDenseElementsF32)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsF64
+def testGetDenseElementsF64():
+ ctx = mlir.ir.Context()
+ array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+ # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
+ print(attr)
+
+run(testGetDenseElementsF64)
+
+
+### 32 bit integer arrays
+# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
+def testGetDenseElementsI32Signless():
+ ctx = mlir.ir.Context()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+ print(attr)
+
+run(testGetDenseElementsI32Signless)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
+def testGetDenseElementsUI32Signless():
+ ctx = mlir.ir.Context()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+ print(attr)
+
+run(testGetDenseElementsUI32Signless)
+
+# CHECK-LABEL: TEST: testGetDenseElementsI32
+def testGetDenseElementsI32():
+ ctx = mlir.ir.Context()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
+ print(attr)
+
+run(testGetDenseElementsI32)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI32
+def testGetDenseElementsUI32():
+ ctx = mlir.ir.Context()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
+ print(attr)
+
+run(testGetDenseElementsUI32)
+
+
+## 64bit integer arrays
+# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
+def testGetDenseElementsI64Signless():
+ ctx = mlir.ir.Context()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+ print(attr)
+
+run(testGetDenseElementsI64Signless)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
+def testGetDenseElementsUI64Signless():
+ ctx = mlir.ir.Context()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+ print(attr)
+
+run(testGetDenseElementsUI64Signless)
+
+# CHECK-LABEL: TEST: testGetDenseElementsI64
+def testGetDenseElementsI64():
+ ctx = mlir.ir.Context()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
+ print(attr)
+
+run(testGetDenseElementsI64)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI64
+def testGetDenseElementsUI64():
+ ctx = mlir.ir.Context()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+ attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
+ print(attr)
+
+run(testGetDenseElementsUI64)
+
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index bf99a7686b17..d1f3e6b4a61a 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -104,7 +104,7 @@ def testFloatAttr():
loc = ctx.get_unknown_location()
# CHECK: default_get: 4.200000e+01 : f32
print("default_get:", mlir.ir.FloatAttr.get(
- mlir.ir.F32Type(ctx), 42.0, loc))
+ mlir.ir.F32Type.get(ctx), 42.0, loc))
# CHECK: f32_get: 4.200000e+01 : f32
print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0))
# CHECK: f64_get: 4.200000e+01 : f64
@@ -127,6 +127,8 @@ def testIntegerAttr():
iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42"))
# CHECK: iattr value: 42
print("iattr value:", iattr.value)
+ # CHECK: iattr type: i64
+ print("iattr type:", iattr.type)
# Test factory methods.
# CHECK: default_get: 42 : i32
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 5a9c5a16bc92..151a4679bd8c 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -135,7 +135,7 @@ def testIntegerType():
def testIndexType():
ctx = mlir.ir.Context()
# CHECK: index type: index
- print("index type:", mlir.ir.IndexType(ctx))
+ # IndexType is now created through the static IndexType.get(ctx) factory.
+ print("index type:", mlir.ir.IndexType.get(ctx))
run(testIndexType)
@@ -143,13 +143,13 @@ def testIndexType():
def testFloatType():
ctx = mlir.ir.Context()
+ # Each float type is now created through its static .get(ctx) factory.
# CHECK: float: bf16
- print("float:", mlir.ir.BF16Type(ctx))
+ print("float:", mlir.ir.BF16Type.get(ctx))
# CHECK: float: f16
- print("float:", mlir.ir.F16Type(ctx))
+ print("float:", mlir.ir.F16Type.get(ctx))
# CHECK: float: f32
- print("float:", mlir.ir.F32Type(ctx))
+ print("float:", mlir.ir.F32Type.get(ctx))
# CHECK: float: f64
- print("float:", mlir.ir.F64Type(ctx))
+ print("float:", mlir.ir.F64Type.get(ctx))
run(testFloatType)
@@ -157,7 +157,7 @@ def testFloatType():
def testNoneType():
ctx = mlir.ir.Context()
# CHECK: none type: none
- print("none type:", mlir.ir.NoneType(ctx))
+ # NoneType is now created through the static NoneType.get(ctx) factory.
+ print("none type:", mlir.ir.NoneType.get(ctx))
run(testNoneType)
@@ -168,13 +168,13 @@ def testComplexType():
# CHECK: complex type element: i32
print("complex type element:", complex_i32.element_type)
- f32 = mlir.ir.F32Type(ctx)
+ f32 = mlir.ir.F32Type.get(ctx)
# CHECK: complex type: complex<f32>
- print("complex type:", mlir.ir.ComplexType.get_complex(f32))
+ print("complex type:", mlir.ir.ComplexType.get(f32))
- index = mlir.ir.IndexType(ctx)
+ index = mlir.ir.IndexType.get(ctx)
try:
- complex_invalid = mlir.ir.ComplexType.get_complex(index)
+ complex_invalid = mlir.ir.ComplexType.get(index)
except ValueError as e:
# CHECK: invalid 'Type(index)' and expected floating point or integer type.
print(e)
@@ -225,15 +225,15 @@ def testAbstractShapedType():
# CHECK-LABEL: TEST: testVectorType
def testVectorType():
ctx = mlir.ir.Context()
- f32 = mlir.ir.F32Type(ctx)
+ f32 = mlir.ir.F32Type.get(ctx)
shape = [2, 3]
loc = ctx.get_unknown_location()
# CHECK: vector type: vector<2x3xf32>
- print("vector type:", mlir.ir.VectorType.get_vector(shape, f32, loc))
+ print("vector type:", mlir.ir.VectorType.get(shape, f32, loc))
- none = mlir.ir.NoneType(ctx)
+ none = mlir.ir.NoneType.get(ctx)
try:
- vector_invalid = mlir.ir.VectorType.get_vector(shape, none, loc)
+ vector_invalid = mlir.ir.VectorType.get(shape, none, loc)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point or integer type.
print(e)
@@ -245,17 +245,16 @@ def testVectorType():
# CHECK-LABEL: TEST: testRankedTensorType
def testRankedTensorType():
ctx = mlir.ir.Context()
- f32 = mlir.ir.F32Type(ctx)
+ f32 = mlir.ir.F32Type.get(ctx)
shape = [2, 3]
loc = ctx.get_unknown_location()
# CHECK: ranked tensor type: tensor<2x3xf32>
print("ranked tensor type:",
- mlir.ir.RankedTensorType.get_ranked_tensor(shape, f32, loc))
+ mlir.ir.RankedTensorType.get(shape, f32, loc))
- none = mlir.ir.NoneType(ctx)
+ none = mlir.ir.NoneType.get(ctx)
try:
- tensor_invalid = mlir.ir.RankedTensorType.get_ranked_tensor(shape, none,
- loc)
+ tensor_invalid = mlir.ir.RankedTensorType.get(shape, none, loc)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
# CHECK: or complex type.
@@ -268,9 +267,9 @@ def testRankedTensorType():
# CHECK-LABEL: TEST: testUnrankedTensorType
def testUnrankedTensorType():
ctx = mlir.ir.Context()
- f32 = mlir.ir.F32Type(ctx)
+ f32 = mlir.ir.F32Type.get(ctx)
loc = ctx.get_unknown_location()
- unranked_tensor = mlir.ir.UnrankedTensorType.get_unranked_tensor(f32, loc)
+ unranked_tensor = mlir.ir.UnrankedTensorType.get(f32, loc)
# CHECK: unranked tensor type: tensor<*xf32>
print("unranked tensor type:", unranked_tensor)
try:
@@ -295,9 +294,9 @@ def testUnrankedTensorType():
else:
print("Exception not produced")
- none = mlir.ir.NoneType(ctx)
+ none = mlir.ir.NoneType.get(ctx)
try:
- tensor_invalid = mlir.ir.UnrankedTensorType.get_unranked_tensor(none, loc)
+ tensor_invalid = mlir.ir.UnrankedTensorType.get(none, loc)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
# CHECK: or complex type.
@@ -310,7 +309,7 @@ def testUnrankedTensorType():
# CHECK-LABEL: TEST: testMemRefType
def testMemRefType():
ctx = mlir.ir.Context()
- f32 = mlir.ir.F32Type(ctx)
+ f32 = mlir.ir.F32Type.get(ctx)
shape = [2, 3]
loc = ctx.get_unknown_location()
memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2, loc)
@@ -321,7 +320,7 @@ def testMemRefType():
# CHECK: memory space: 2
print("memory space:", memref.memory_space)
- none = mlir.ir.NoneType(ctx)
+ none = mlir.ir.NoneType.get(ctx)
try:
memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(none, shape, 2,
loc)
@@ -337,9 +336,9 @@ def testMemRefType():
# CHECK-LABEL: TEST: testUnrankedMemRefType
def testUnrankedMemRefType():
ctx = mlir.ir.Context()
- f32 = mlir.ir.F32Type(ctx)
+ f32 = mlir.ir.F32Type.get(ctx)
loc = ctx.get_unknown_location()
- unranked_memref = mlir.ir.UnrankedMemRefType.get_unranked_memref(f32, 2, loc)
+ unranked_memref = mlir.ir.UnrankedMemRefType.get(f32, 2, loc)
# CHECK: unranked memref type: memref<*xf32, 2>
print("unranked memref type:", unranked_memref)
try:
@@ -364,10 +363,9 @@ def testUnrankedMemRefType():
else:
print("Exception not produced")
- none = mlir.ir.NoneType(ctx)
+ none = mlir.ir.NoneType.get(ctx)
try:
- memref_invalid = mlir.ir.UnrankedMemRefType.get_unranked_memref(none, 2,
- loc)
+ memref_invalid = mlir.ir.UnrankedMemRefType.get(none, 2, loc)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
# CHECK: or complex type.
@@ -381,7 +379,7 @@ def testUnrankedMemRefType():
def testTupleType():
ctx = mlir.ir.Context()
i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
- f32 = mlir.ir.F32Type(ctx)
+ f32 = mlir.ir.F32Type.get(ctx)
vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
l = [i32, f32, vector]
tuple_type = mlir.ir.TupleType.get_tuple(ctx, l)
@@ -400,7 +398,7 @@ def testFunctionType():
ctx = mlir.ir.Context()
input_types = [mlir.ir.IntegerType.get_signless(ctx, 32),
mlir.ir.IntegerType.get_signless(ctx, 16)]
- result_types = [mlir.ir.IndexType(ctx)]
+ result_types = [mlir.ir.IndexType.get(ctx)]
func = mlir.ir.FunctionType.get(ctx, input_types, result_types)
# CHECK: INPUTS: [Type(i32), Type(i16)]
print("INPUTS:", func.inputs)
More information about the Mlir-commits
mailing list