[Mlir-commits] [mlir] 0e6beb2 - [mlir][Python] Add python binding to create DenseElementsAttribute.

Stella Laurenzo llvmlistbot at llvm.org
Mon Oct 19 22:37:25 PDT 2020


Author: Stella Laurenzo
Date: 2020-10-19T22:29:35-07:00
New Revision: 0e6beb29966abc6666e73ab5f151cb9754f04901

URL: https://github.com/llvm/llvm-project/commit/0e6beb29966abc6666e73ab5f151cb9754f04901
DIFF: https://github.com/llvm/llvm-project/commit/0e6beb29966abc6666e73ab5f151cb9754f04901.diff

LOG: [mlir][Python] Add python binding to create DenseElementsAttribute.

* Interoperates with Python buffers/numpy arrays to create the attribute.
* Also cleans up 'get' factory methods on some types to be consistent.
* Adds mlirAttributeGetType() to C-API to facilitate error handling and other uses.
* Punts on a lot of features of the ElementsAttribute hierarchy for now.
* Does not yet support bool or string attributes.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D89363

Added: 
    mlir/test/Bindings/Python/ir_array_attributes.py

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_attributes.py
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8f525e8b6239..2a768df0ffd9 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -912,6 +912,150 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
   }
 };
 
+// TODO: Support construction of bool elements.
+// TODO: Support construction of string elements.
+/// Python wrapper for MLIR's DenseElementsAttr, constructible from a Python
+/// buffer/numpy array or as a splat of a single element attribute.
+class PyDenseElementsAttribute
+    : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+  static constexpr const char *pyClassName = "DenseElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Builds a DenseElementsAttr of ranked tensor type from a Python buffer
+  /// (e.g. a numpy ndarray); the shape comes from the buffer's shape.
+  /// Integer elements are signless when `signless` is true, otherwise they
+  /// are signed/unsigned per the buffer format. Throws a Python ValueError
+  /// when the format/itemsize combination cannot be bulk loaded.
+  static PyDenseElementsAttribute getFromBuffer(PyMlirContext &contextWrapper,
+                                                py::buffer array,
+                                                bool signless) {
+    // Request a contiguous view. In exotic cases, this will cause a copy.
+    int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
+    Py_buffer *view = new Py_buffer();
+    if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
+      delete view;
+      throw py::error_already_set();
+    }
+    py::buffer_info arrayInfo(view);
+
+    MlirContext context = contextWrapper.get();
+    auto wrap = [&](MlirAttribute attr) {
+      return PyDenseElementsAttribute(contextWrapper.getRef(), attr);
+    };
+    auto intType = [&](unsigned width, bool isSigned) {
+      if (signless)
+        return mlirIntegerTypeGet(context, width);
+      return isSigned ? mlirIntegerTypeSignedGet(context, width)
+                      : mlirIntegerTypeUnsignedGet(context, width);
+    };
+
+    // Switch on the types that can be bulk loaded between the Python and
+    // MLIR-C APIs. The element width is taken from the buffer's itemsize and
+    // checked at runtime (not only by assert), because the width of format
+    // codes such as 'l' is platform dependent and a mismatch would make the
+    // bulk load read past the end of the buffer in release builds.
+    const std::string &format = arrayInfo.format;
+    const bool isSignedInt = format == "i" || format == "l" || format == "q";
+    const bool isUnsignedInt = format == "I" || format == "L" || format == "Q";
+    if (format == "f" && arrayInfo.itemsize == 4) // f32
+      return wrap(bulkLoad(context, mlirDenseElementsAttrFloatGet,
+                           mlirF32TypeGet(context), arrayInfo));
+    if (format == "d" && arrayInfo.itemsize == 8) // f64
+      return wrap(bulkLoad(context, mlirDenseElementsAttrDoubleGet,
+                           mlirF64TypeGet(context), arrayInfo));
+    if (isSignedInt && arrayInfo.itemsize == 4) // i32
+      return wrap(bulkLoad(context, mlirDenseElementsAttrInt32Get,
+                           intType(32, /*isSigned=*/true), arrayInfo));
+    if (isUnsignedInt && arrayInfo.itemsize == 4) // unsigned i32
+      return wrap(bulkLoad(context, mlirDenseElementsAttrUInt32Get,
+                           intType(32, /*isSigned=*/false), arrayInfo));
+    if (isSignedInt && arrayInfo.itemsize == 8) // i64
+      return wrap(bulkLoad(context, mlirDenseElementsAttrInt64Get,
+                           intType(64, /*isSigned=*/true), arrayInfo));
+    if (isUnsignedInt && arrayInfo.itemsize == 8) // unsigned i64
+      return wrap(bulkLoad(context, mlirDenseElementsAttrUInt64Get,
+                           intType(64, /*isSigned=*/false), arrayInfo));
+
+    // TODO: Fall back to string-based get.
+    std::string message = "unimplemented array format conversion from format: ";
+    message.append(arrayInfo.format);
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
+  /// Builds a splat DenseElementsAttr of type `shapedType` in which every
+  /// element equals `elementAttr`. Throws a Python ValueError unless
+  /// `elementAttr` is an integer or float attribute, `shapedType` is a
+  /// statically shaped ShapedType, and its element type equals the
+  /// attribute's type.
+  static PyDenseElementsAttribute getSplat(PyType shapedType,
+                                           PyAttribute &elementAttr) {
+    auto contextWrapper =
+        PyMlirContext::forContext(mlirTypeGetContext(shapedType));
+    if (!mlirAttributeIsAInteger(elementAttr.attr) &&
+        !mlirAttributeIsAFloat(elementAttr.attr)) {
+      std::string message = "Illegal element type for DenseElementsAttr: ";
+      message.append(py::repr(py::cast(elementAttr)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+    if (!mlirTypeIsAShaped(shapedType) ||
+        !mlirShapedTypeHasStaticShape(shapedType)) {
+      std::string message =
+          "Expected a static ShapedType for the shaped_type parameter: ";
+      message.append(py::repr(py::cast(shapedType)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+    MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType.type);
+    MlirType attrType = mlirAttributeGetType(elementAttr.attr);
+    if (!mlirTypeEqual(shapedElementType, attrType)) {
+      std::string message =
+          "Shaped element type and attribute type must be equal: shaped=";
+      message.append(py::repr(py::cast(shapedType)));
+      message.append(", element=");
+      message.append(py::repr(py::cast(elementAttr)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+
+    MlirAttribute elements =
+        mlirDenseElementsAttrSplatGet(shapedType.type, elementAttr.attr);
+    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  }
+
+  /// Registers the Python-visible 'get', 'get_splat' and 'is_splat' members.
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
+                 py::arg("context"), py::arg("array"),
+                 py::arg("signless") = true, "Gets from a buffer or ndarray")
+        .def_static("get_splat", PyDenseElementsAttribute::getSplat,
+                    py::arg("shaped_type"), py::arg("element_attr"),
+                    "Gets a DenseElementsAttr where all values are the same")
+        .def_property_readonly("is_splat",
+                               [](PyDenseElementsAttribute &self) -> bool {
+                                 return mlirDenseElementsAttrIsSplat(self.attr);
+                               });
+  }
+
+private:
+  /// Creates the attribute by calling `ctor` on the raw buffer contents.
+  /// The attribute's type is a ranked tensor of `mlirElementType` with the
+  /// buffer's shape. The caller must ensure `arrayInfo` holds contiguous
+  /// elements whose in-memory type matches `ElementTy`.
+  template <typename ElementTy>
+  static MlirAttribute
+  bulkLoad(MlirContext context,
+           MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
+           MlirType mlirElementType, py::buffer_info &arrayInfo) {
+    SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
+                                  arrayInfo.shape.begin() + arrayInfo.ndim);
+    auto shapedType =
+        mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
+    intptr_t numElements = arrayInfo.size;
+    const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
+    return ctor(shapedType, numElements, contents);
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -1021,11 +1165,13 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirIndexTypeGet(context.get());
-            return PyIndexType(context.getRef(), t);
-          }),
-          "Create a index type.");
+    c.def_static(
+        "get",
+        [](PyMlirContext &context) {
+          MlirType t = mlirIndexTypeGet(context.get());
+          return PyIndexType(context.getRef(), t);
+        },
+        "Create a index type.");
   }
 };
 
@@ -1037,11 +1183,13 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirBF16TypeGet(context.get());
-            return PyBF16Type(context.getRef(), t);
-          }),
-          "Create a bf16 type.");
+    c.def_static(
+        "get",
+        [](PyMlirContext &context) {
+          MlirType t = mlirBF16TypeGet(context.get());
+          return PyBF16Type(context.getRef(), t);
+        },
+        "Create a bf16 type.");
   }
 };
 
@@ -1053,11 +1201,13 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirF16TypeGet(context.get());
-            return PyF16Type(context.getRef(), t);
-          }),
-          "Create a f16 type.");
+    c.def_static(
+        "get",
+        [](PyMlirContext &context) {
+          MlirType t = mlirF16TypeGet(context.get());
+          return PyF16Type(context.getRef(), t);
+        },
+        "Create a f16 type.");
   }
 };
 
@@ -1069,11 +1219,13 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirF32TypeGet(context.get());
-            return PyF32Type(context.getRef(), t);
-          }),
-          "Create a f32 type.");
+    c.def_static(
+        "get",
+        [](PyMlirContext &context) {
+          MlirType t = mlirF32TypeGet(context.get());
+          return PyF32Type(context.getRef(), t);
+        },
+        "Create a f32 type.");
   }
 };
 
@@ -1085,11 +1237,13 @@ class PyF64Type : public PyConcreteType<PyF64Type> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirF64TypeGet(context.get());
-            return PyF64Type(context.getRef(), t);
-          }),
-          "Create a f64 type.");
+    c.def_static(
+        "get",
+        [](PyMlirContext &context) {
+          MlirType t = mlirF64TypeGet(context.get());
+          return PyF64Type(context.getRef(), t);
+        },
+        "Create a f64 type.");
   }
 };
 
@@ -1101,11 +1255,13 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirNoneTypeGet(context.get());
-            return PyNoneType(context.getRef(), t);
-          }),
-          "Create a none type.");
+    c.def_static(
+        "get",
+        [](PyMlirContext &context) {
+          MlirType t = mlirNoneTypeGet(context.get());
+          return PyNoneType(context.getRef(), t);
+        },
+        "Create a none type.");
   }
 };
 
@@ -1118,7 +1274,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get_complex",
+        "get",
         [](PyType &elementType) {
           // The element must be a floating point or integer scalar type.
           if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
@@ -1224,7 +1380,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get_vector",
+        "get",
         // TODO: Make the location optional and create a default location.
         [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
           MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
@@ -1254,7 +1410,7 @@ class PyRankedTensorType
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get_ranked_tensor",
+        "get",
         // TODO: Make the location optional and create a default location.
         [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
           MlirType t = mlirRankedTensorTypeGetChecked(
@@ -1286,7 +1442,7 @@ class PyUnrankedTensorType
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get_unranked_tensor",
+        "get",
         // TODO: Make the location optional and create a default location.
         [](PyType &elementType, PyLocation &loc) {
           MlirType t =
@@ -1366,7 +1522,7 @@ class PyUnrankedMemRefType
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-         "get_unranked_memref",
+         "get",
          // TODO: Make the location optional and create a default location.
          [](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
            MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
@@ -1719,6 +1875,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           "context",
           [](PyAttribute &self) { return self.getContext().getObject(); },
           "Context that owns the Attribute")
+      .def_property_readonly("type",
+                             [](PyAttribute &self) {
+                               return PyType(self.getContext()->getRef(),
+                                             mlirAttributeGetType(self.attr));
+                             })
       .def(
           "get_named",
           [](PyAttribute &self, std::string name) {
@@ -1796,6 +1957,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyIntegerAttribute::bind(m);
   PyBoolAttribute::bind(m);
   PyStringAttribute::bind(m);
+  PyDenseElementsAttribute::bind(m);
 
   // Mapping of Type.
   py::class_<PyType>(m, "Type")

diff  --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py
new file mode 100644
index 000000000000..97a9802ae148
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_array_attributes.py
@@ -0,0 +1,213 @@
+# RUN: %PYTHON %s | FileCheck %s
+# Note that this is separate from ir_attributes.py since it depends on numpy,
+# and we may want to disable if not available.
+
+import gc
+import mlir
+import numpy as np
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert mlir.ir.Context._get_live_count() == 0
+
+################################################################################
+# Tests of the array/buffer .get() factory method on unsupported dtype.
+################################################################################
+
+def testGetDenseElementsUnsupported():
+  ctx = mlir.ir.Context()
+  array = np.array([["hello", "goodbye"]])
+  try:
+    attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+  except ValueError as e:
+    # CHECK: unimplemented array format conversion from format:
+    print(e)
+
+run(testGetDenseElementsUnsupported)
+
+################################################################################
+# Splats.
+################################################################################
+
+# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
+def testGetDenseElementsSplatInt():
+  ctx = mlir.ir.Context()
+  loc = ctx.get_unknown_location()
+  t = mlir.ir.IntegerType.get_signless(ctx, 32)
+  element = mlir.ir.IntegerAttr.get(t, 555)
+  shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
+  attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element)
+  # CHECK: dense<555> : tensor<2x3x4xi32>
+  print(attr)
+  # CHECK: is_splat: True
+  print("is_splat:", attr.is_splat)
+
+run(testGetDenseElementsSplatInt)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
+def testGetDenseElementsSplatFloat():
+  ctx = mlir.ir.Context()
+  loc = ctx.get_unknown_location()
+  t = mlir.ir.F32Type.get(ctx)
+  element = mlir.ir.FloatAttr.get(t, 1.2, loc)
+  shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
+  attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element)
+  # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
+  print(attr)
+
+run(testGetDenseElementsSplatFloat)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
+def testGetDenseElementsSplatErrors():
+  ctx = mlir.ir.Context()
+  loc = ctx.get_unknown_location()
+  t = mlir.ir.F32Type.get(ctx)
+  other_t = mlir.ir.F64Type.get(ctx)
+  element = mlir.ir.FloatAttr.get(t, 1.2, loc)
+  other_element = mlir.ir.FloatAttr.get(other_t, 1.2, loc)
+  shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
+  dynamic_shaped_type = mlir.ir.UnrankedTensorType.get(t, loc)
+  non_shaped_type = t
+
+  try:
+    attr = mlir.ir.DenseElementsAttr.get_splat(non_shaped_type, element)
+  except ValueError as e:
+    # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
+    print(e)
+
+  try:
+    attr = mlir.ir.DenseElementsAttr.get_splat(dynamic_shaped_type, element)
+  except ValueError as e:
+    # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
+    print(e)
+
+  try:
+    attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, other_element)
+  except ValueError as e:
+    # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
+    print(e)
+
+run(testGetDenseElementsSplatErrors)
+
+
+################################################################################
+# Tests of the array/buffer .get() factory method, in all of its permutations.
+################################################################################
+
+### float and double arrays.
+
+# CHECK-LABEL: TEST: testGetDenseElementsF32
+def testGetDenseElementsF32():
+  ctx = mlir.ir.Context()
+  array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+  # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
+  print(attr)
+  # CHECK: is_splat: False
+  print("is_splat:", attr.is_splat)
+
+run(testGetDenseElementsF32)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsF64
+def testGetDenseElementsF64():
+  ctx = mlir.ir.Context()
+  array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+  # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
+  print(attr)
+
+run(testGetDenseElementsF64)
+
+
+### 32 bit integer arrays
+# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
+def testGetDenseElementsI32Signless():
+  ctx = mlir.ir.Context()
+  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+  print(attr)
+
+run(testGetDenseElementsI32Signless)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
+def testGetDenseElementsUI32Signless():
+  ctx = mlir.ir.Context()
+  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+  print(attr)
+
+run(testGetDenseElementsUI32Signless)
+
+# CHECK-LABEL: TEST: testGetDenseElementsI32
+def testGetDenseElementsI32():
+  ctx = mlir.ir.Context()
+  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
+  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
+  print(attr)
+
+run(testGetDenseElementsI32)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI32
+def testGetDenseElementsUI32():
+  ctx = mlir.ir.Context()
+  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
+  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
+  print(attr)
+
+run(testGetDenseElementsUI32)
+
+
+## 64bit integer arrays
+# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
+def testGetDenseElementsI64Signless():
+  ctx = mlir.ir.Context()
+  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+  print(attr)
+
+run(testGetDenseElementsI64Signless)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
+def testGetDenseElementsUI64Signless():
+  ctx = mlir.ir.Context()
+  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
+  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+  print(attr)
+
+run(testGetDenseElementsUI64Signless)
+
+# CHECK-LABEL: TEST: testGetDenseElementsI64
+def testGetDenseElementsI64():
+  ctx = mlir.ir.Context()
+  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
+  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
+  print(attr)
+
+run(testGetDenseElementsI64)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI64
+def testGetDenseElementsUI64():
+  ctx = mlir.ir.Context()
+  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+  attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
+  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
+  print(attr)
+
+run(testGetDenseElementsUI64)
+

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index bf99a7686b17..d1f3e6b4a61a 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -104,7 +104,7 @@ def testFloatAttr():
   loc = ctx.get_unknown_location()
   # CHECK: default_get: 4.200000e+01 : f32
   print("default_get:", mlir.ir.FloatAttr.get(
-      mlir.ir.F32Type(ctx), 42.0, loc))
+      mlir.ir.F32Type.get(ctx), 42.0, loc))
   # CHECK: f32_get: 4.200000e+01 : f32
   print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0))
   # CHECK: f64_get: 4.200000e+01 : f64
@@ -127,6 +127,8 @@ def testIntegerAttr():
   iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42"))
   # CHECK: iattr value: 42
   print("iattr value:", iattr.value)
+  # CHECK: iattr type: i64
+  print("iattr type:", iattr.type)
 
   # Test factory methods.
   # CHECK: default_get: 42 : i32

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 5a9c5a16bc92..151a4679bd8c 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -135,7 +135,7 @@ def testIntegerType():
 def testIndexType():
   ctx = mlir.ir.Context()
   # CHECK: index type: index
-  print("index type:", mlir.ir.IndexType(ctx))
+  print("index type:", mlir.ir.IndexType.get(ctx))
 
 run(testIndexType)
 
@@ -143,13 +143,13 @@ def testIndexType():
 def testFloatType():
   ctx = mlir.ir.Context()
   # CHECK: float: bf16
-  print("float:", mlir.ir.BF16Type(ctx))
+  print("float:", mlir.ir.BF16Type.get(ctx))
   # CHECK: float: f16
-  print("float:", mlir.ir.F16Type(ctx))
+  print("float:", mlir.ir.F16Type.get(ctx))
   # CHECK: float: f32
-  print("float:", mlir.ir.F32Type(ctx))
+  print("float:", mlir.ir.F32Type.get(ctx))
   # CHECK: float: f64
-  print("float:", mlir.ir.F64Type(ctx))
+  print("float:", mlir.ir.F64Type.get(ctx))
 
 run(testFloatType)
 
@@ -157,7 +157,7 @@ def testFloatType():
 def testNoneType():
   ctx = mlir.ir.Context()
   # CHECK: none type: none
-  print("none type:", mlir.ir.NoneType(ctx))
+  print("none type:", mlir.ir.NoneType.get(ctx))
 
 run(testNoneType)
 
@@ -168,13 +168,13 @@ def testComplexType():
   # CHECK: complex type element: i32
   print("complex type element:", complex_i32.element_type)
 
-  f32 = mlir.ir.F32Type(ctx)
+  f32 = mlir.ir.F32Type.get(ctx)
   # CHECK: complex type: complex<f32>
-  print("complex type:", mlir.ir.ComplexType.get_complex(f32))
+  print("complex type:", mlir.ir.ComplexType.get(f32))
 
-  index = mlir.ir.IndexType(ctx)
+  index = mlir.ir.IndexType.get(ctx)
   try:
-    complex_invalid = mlir.ir.ComplexType.get_complex(index)
+    complex_invalid = mlir.ir.ComplexType.get(index)
   except ValueError as e:
     # CHECK: invalid 'Type(index)' and expected floating point or integer type.
     print(e)
@@ -225,15 +225,15 @@ def testAbstractShapedType():
 # CHECK-LABEL: TEST: testVectorType
 def testVectorType():
   ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type(ctx)
+  f32 = mlir.ir.F32Type.get(ctx)
   shape = [2, 3]
   loc = ctx.get_unknown_location()
   # CHECK: vector type: vector<2x3xf32>
-  print("vector type:", mlir.ir.VectorType.get_vector(shape, f32, loc))
+  print("vector type:", mlir.ir.VectorType.get(shape, f32, loc))
 
-  none = mlir.ir.NoneType(ctx)
+  none = mlir.ir.NoneType.get(ctx)
   try:
-    vector_invalid = mlir.ir.VectorType.get_vector(shape, none, loc)
+    vector_invalid = mlir.ir.VectorType.get(shape, none, loc)
   except ValueError as e:
     # CHECK: invalid 'Type(none)' and expected floating point or integer type.
     print(e)
@@ -245,17 +245,16 @@ def testVectorType():
 # CHECK-LABEL: TEST: testRankedTensorType
 def testRankedTensorType():
   ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type(ctx)
+  f32 = mlir.ir.F32Type.get(ctx)
   shape = [2, 3]
   loc = ctx.get_unknown_location()
   # CHECK: ranked tensor type: tensor<2x3xf32>
   print("ranked tensor type:",
-        mlir.ir.RankedTensorType.get_ranked_tensor(shape, f32, loc))
+        mlir.ir.RankedTensorType.get(shape, f32, loc))
 
-  none = mlir.ir.NoneType(ctx)
+  none = mlir.ir.NoneType.get(ctx)
   try:
-    tensor_invalid = mlir.ir.RankedTensorType.get_ranked_tensor(shape, none,
-                                                                loc)
+    tensor_invalid = mlir.ir.RankedTensorType.get(shape, none, loc)
   except ValueError as e:
     # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
     # CHECK: or complex type.
@@ -268,9 +267,9 @@ def testRankedTensorType():
 # CHECK-LABEL: TEST: testUnrankedTensorType
 def testUnrankedTensorType():
   ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type(ctx)
+  f32 = mlir.ir.F32Type.get(ctx)
   loc = ctx.get_unknown_location()
-  unranked_tensor = mlir.ir.UnrankedTensorType.get_unranked_tensor(f32, loc)
+  unranked_tensor = mlir.ir.UnrankedTensorType.get(f32, loc)
   # CHECK: unranked tensor type: tensor<*xf32>
   print("unranked tensor type:", unranked_tensor)
   try:
@@ -295,9 +294,9 @@ def testUnrankedTensorType():
   else:
     print("Exception not produced")
 
-  none = mlir.ir.NoneType(ctx)
+  none = mlir.ir.NoneType.get(ctx)
   try:
-    tensor_invalid = mlir.ir.UnrankedTensorType.get_unranked_tensor(none, loc)
+    tensor_invalid = mlir.ir.UnrankedTensorType.get(none, loc)
   except ValueError as e:
     # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
     # CHECK: or complex type.
@@ -310,7 +309,7 @@ def testUnrankedTensorType():
 # CHECK-LABEL: TEST: testMemRefType
 def testMemRefType():
   ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type(ctx)
+  f32 = mlir.ir.F32Type.get(ctx)
   shape = [2, 3]
   loc = ctx.get_unknown_location()
   memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2, loc)
@@ -321,7 +320,7 @@ def testMemRefType():
   # CHECK: memory space: 2
   print("memory space:", memref.memory_space)
 
-  none = mlir.ir.NoneType(ctx)
+  none = mlir.ir.NoneType.get(ctx)
   try:
     memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(none, shape, 2,
                                                               loc)
@@ -337,9 +336,9 @@ def testMemRefType():
 # CHECK-LABEL: TEST: testUnrankedMemRefType
 def testUnrankedMemRefType():
   ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type(ctx)
+  f32 = mlir.ir.F32Type.get(ctx)
   loc = ctx.get_unknown_location()
-  unranked_memref = mlir.ir.UnrankedMemRefType.get_unranked_memref(f32, 2, loc)
+  unranked_memref = mlir.ir.UnrankedMemRefType.get(f32, 2, loc)
   # CHECK: unranked memref type: memref<*xf32, 2>
   print("unranked memref type:", unranked_memref)
   try:
@@ -364,10 +363,9 @@ def testUnrankedMemRefType():
   else:
     print("Exception not produced")
 
-  none = mlir.ir.NoneType(ctx)
+  none = mlir.ir.NoneType.get(ctx)
   try:
-    memref_invalid = mlir.ir.UnrankedMemRefType.get_unranked_memref(none, 2,
-                                                                    loc)
+    memref_invalid = mlir.ir.UnrankedMemRefType.get(none, 2, loc)
   except ValueError as e:
     # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
     # CHECK: or complex type.
@@ -381,7 +379,7 @@ def testUnrankedMemRefType():
 def testTupleType():
   ctx = mlir.ir.Context()
   i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
-  f32 = mlir.ir.F32Type(ctx)
+  f32 = mlir.ir.F32Type.get(ctx)
   vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
   l = [i32, f32, vector]
   tuple_type = mlir.ir.TupleType.get_tuple(ctx, l)
@@ -400,7 +398,7 @@ def testFunctionType():
   ctx = mlir.ir.Context()
   input_types = [mlir.ir.IntegerType.get_signless(ctx, 32),
                  mlir.ir.IntegerType.get_signless(ctx, 16)]
-  result_types = [mlir.ir.IndexType(ctx)]
+  result_types = [mlir.ir.IndexType.get(ctx)]
   func = mlir.ir.FunctionType.get(ctx, input_types, result_types)
   # CHECK: INPUTS: [Type(i32), Type(i16)]
   print("INPUTS:", func.inputs)


        


More information about the Mlir-commits mailing list