[Mlir-commits] [mlir] [mlir][python] add type wrappers (PR #71218)
Maksim Levental
llvmlistbot at llvm.org
Mon Nov 27 10:35:03 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/71218
From 8282a78de440cd98148b499e7a2270904e0f6e09 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 2 Nov 2023 11:15:45 -0500
Subject: [PATCH 1/2] [mlir][python] add type wrappers
---
mlir/lib/Bindings/Python/IRTypes.cpp | 24 ++--
mlir/python/CMakeLists.txt | 1 +
mlir/python/mlir/types.py | 189 +++++++++++++++++++++++++++
mlir/test/python/ir/builtin_types.py | 112 ++++++++++++++++
4 files changed, 310 insertions(+), 16 deletions(-)
create mode 100644 mlir/python/mlir/types.py
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 483db673f989e6b..56e895d3053796e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -463,7 +463,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_static("get", &PyVectorType::get, py::arg("shape"),
- py::arg("elementType"), py::kw_only(),
+ py::arg("element_type"), py::kw_only(),
py::arg("scalable") = py::none(),
py::arg("scalable_dims") = py::none(),
py::arg("loc") = py::none(), "Create a vector type")
@@ -689,13 +689,9 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get_tuple",
- [](py::list elementList, DefaultingPyMlirContext context) {
- intptr_t num = py::len(elementList);
- // Mapping py::list to SmallVector.
- SmallVector<MlirType, 4> elements;
- for (auto element : elementList)
- elements.push_back(element.cast<PyType>());
- MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
+ [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+ MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+ elements.data());
return PyTupleType(context->getRef(), t);
},
py::arg("elements"), py::arg("context") = py::none(),
@@ -727,13 +723,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::vector<PyType> inputs, std::vector<PyType> results,
+ [](std::vector<MlirType> inputs, std::vector<MlirType> results,
DefaultingPyMlirContext context) {
- SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
- SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
- MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
- inputsRaw.data(), resultsRaw.size(),
- resultsRaw.data());
+ MlirType t =
+ mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+ results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
@@ -742,7 +736,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
"inputs",
[](PyFunctionType &self) {
MlirType t = self;
- auto contextRef = self.getContext();
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
@@ -754,7 +747,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
c.def_property_readonly(
"results",
[](PyFunctionType &self) {
- auto contextRef = self.getContext();
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
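
The two bindings above now take std::vector<MlirType> directly and let
pybind11 handle the list-to-vector conversion, instead of copying through a
SmallVector by hand. A minimal sketch of the (unchanged) Python-side usage,
assuming an active context:

    from mlir.ir import Context, F32Type, FunctionType, IndexType, TupleType

    with Context():
        f32, index = F32Type.get(), IndexType.get()
        # Plain Python lists of types convert to std::vector<MlirType>.
        ft = FunctionType.get(inputs=[f32, index], results=[f32])
        tt = TupleType.get_tuple([f32, index])
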
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 971ad2dd214a15f..12e2dab60f3011b 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
+ types.py
dialects/_ods_common.py
# The main _mlir module has submodules: include stubs from each.
diff --git a/mlir/python/mlir/types.py b/mlir/python/mlir/types.py
new file mode 100644
index 000000000000000..aa8a2639ac980b2
--- /dev/null
+++ b/mlir/python/mlir/types.py
@@ -0,0 +1,189 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from functools import partial
+from typing import Optional, List
+
+from .ir import (
+ Attribute,
+ BF16Type,
+ ComplexType,
+ Context,
+ F16Type,
+ F32Type,
+ F64Type,
+ Float8E4M3B11FNUZType,
+ Float8E4M3FNType,
+ Float8E5M2Type,
+ FunctionType,
+ IndexType,
+ IntegerType,
+ MemRefType,
+ NoneType,
+ OpaqueType,
+ RankedTensorType,
+ StridedLayoutAttr,
+ StringAttr,
+ TupleType,
+ Type,
+ UnrankedMemRefType,
+ UnrankedTensorType,
+ VectorType,
+)
+
+__all__ = []
+
+_index = lambda: IndexType.get()
+_bool = lambda: IntegerType.get_signless(1)
+
+_i8 = lambda: IntegerType.get_signless(8)
+_i16 = lambda: IntegerType.get_signless(16)
+_i32 = lambda: IntegerType.get_signless(32)
+_i64 = lambda: IntegerType.get_signless(64)
+
+_si8 = lambda: IntegerType.get_signed(8)
+_si16 = lambda: IntegerType.get_signed(16)
+_si32 = lambda: IntegerType.get_signed(32)
+_si64 = lambda: IntegerType.get_signed(64)
+
+_ui8 = lambda: IntegerType.get_unsigned(8)
+_ui16 = lambda: IntegerType.get_unsigned(16)
+_ui32 = lambda: IntegerType.get_unsigned(32)
+_ui64 = lambda: IntegerType.get_unsigned(64)
+
+_f16 = lambda: F16Type.get()
+_f32 = lambda: F32Type.get()
+_f64 = lambda: F64Type.get()
+_bf16 = lambda: BF16Type.get()
+
+_f8e5m2 = lambda: Float8E5M2Type.get()
+_f8e4m3 = lambda: Float8E4M3FNType.get()
+_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()
+
+_none = lambda: NoneType.get()
+
+
+def _i(width):
+ return IntegerType.get_signless(width)
+
+
+def _si(width):
+ return IntegerType.get_signed(width)
+
+
+def _ui(width):
+ return IntegerType.get_unsigned(width)
+
+
+def _complex(type):
+ return ComplexType.get(type)
+
+
+def _opaque(dialect_namespace, type_data):
+ return OpaqueType.get(dialect_namespace, type_data)
+
+
+def _shaped(*shape, element_type: Type = None, type_constructor=None):
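+ # The element type may be given either as the trailing positional
+ # argument (e.g. _shaped(2, 3, f32)) or via the element_type keyword
+ # (e.g. _shaped(2, 3, element_type=f32)), but never both.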
+ if type_constructor is None:
+ raise ValueError("shaped is an abstract base class - cannot be constructed.")
+ if (element_type is None and shape and not isinstance(shape[-1], Type)) or (
+ shape and isinstance(shape[-1], Type) and element_type is not None
+ ):
+ raise ValueError(
+ f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
+ )
+ if element_type is not None:
+ type = element_type
+ sizes = shape
+ else:
+ type = shape[-1]
+ sizes = shape[:-1]
+ if sizes:
+ return type_constructor(sizes, type)
+ else:
+ return type_constructor(type)
+
+
+def _vector(
+ *shape,
+ element_type: Type = None,
+ scalable: Optional[List[bool]] = None,
+ scalable_dims: Optional[List[int]] = None,
+):
+ return _shaped(
+ *shape,
+ element_type=element_type,
+ type_constructor=partial(
+ VectorType.get, scalable=scalable, scalable_dims=scalable_dims
+ ),
+ )
+
+
+def _tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
+ if encoding is not None:
+ encoding = StringAttr.get(encoding)
+ if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
+ if encoding is not None:
+ raise ValueError("UnrankedTensorType does not support encoding.")
+ return _shaped(
+ *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
+ )
+ return _shaped(
+ *shape,
+ element_type=element_type,
+ type_constructor=partial(RankedTensorType.get, encoding=encoding),
+ )
+
+
+def _memref(
+ *shape,
+ element_type: Type = None,
+ memory_space: Optional[int] = None,
+ layout: Optional[StridedLayoutAttr] = None,
+):
+ if memory_space is not None:
+ memory_space = Attribute.parse(str(memory_space))
+ if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
+ return _shaped(
+ *shape,
+ element_type=element_type,
+ type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
+ )
+ return _shaped(
+ *shape,
+ element_type=element_type,
+ type_constructor=partial(
+ MemRefType.get, memory_space=memory_space, layout=layout
+ ),
+ )
+
+
+def _tuple(*elements):
+ return TupleType.get_tuple(elements)
+
+
+def _function(*, inputs, results):
+ return FunctionType.get(inputs, results)
+
+
+def __getattr__(name):
+ if name == "__path__":
+ # https://docs.python.org/3/reference/import.html#path__
+ # If a module is a package (either regular or namespace), the module object’s __path__ attribute must be set.
+ # This module is NOT a package and so this must be None (rather than throw the RuntimeError below).
+ return None
+ try:
+ Context.current
+ except ValueError:
+ raise RuntimeError("Types can only be instantiated under an active context.")
+
+ if f"_{name}" in globals():
+ builder = globals()[f"_{name}"]
+ if (
+ isinstance(builder, type(lambda: None))
+ and builder.__name__ == (lambda: None).__name__
+ ):
+ return builder()
+ return builder
+ raise RuntimeError(f"{name} is not a legal type.")
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index d4fed86b4f135ee..6dbb35d10aad48a 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -3,6 +3,7 @@
import gc
from mlir.ir import *
from mlir.dialects import arith, tensor, func, memref
+import mlir.types as T
def run(f):
@@ -772,3 +773,114 @@ def testCustomTypeTypeCaster():
print(t)
# CHECK: OperationType(!transform.op<"foo.bar">)
print(repr(t))
+
+
+# CHECK-LABEL: TEST: testTypeWrappers
+ at run
+def testTypeWrappers():
+ try:
+ from mlir.types import i32
+ except RuntimeError as e:
+ assert e.args[0] == "Types can only be instantiated under an active context."
+
+ try:
+ from mlir.types import tensor
+ except RuntimeError as e:
+ assert e.args[0] == "Types can only be instantiated under an active context."
+
+ def stride(strides, offset=0):
+ return StridedLayoutAttr.get(offset, strides)
+
+ with Context(), Location.unknown():
+ try:
+ from mlir.types import non_existent_type
+ except RuntimeError as e:
+ assert e.args[0] == "non_existent_type is not a legal type."
+
+ ia = T.i(5)
+ sia = T.si(6)
+ uia = T.ui(7)
+ assert repr(ia) == "IntegerType(i5)"
+ assert repr(sia) == "IntegerType(si6)"
+ assert repr(uia) == "IntegerType(ui7)"
+
+ assert T.i(16) == T.i16
+ assert T.si(16) == T.si16
+ assert T.ui(16) == T.ui16
+
+ c1 = T.complex(T.f16)
+ c2 = T.complex(T.i32)
+ assert repr(c1) == "ComplexType(complex<f16>)"
+ assert repr(c2) == "ComplexType(complex<i32>)"
+
+ vec_1 = T.vector(2, 3, T.f32)
+ vec_2 = T.vector(2, 3, 4, T.f32)
+ assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
+ assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"
+
+ m1 = T.memref(2, 3, 4, T.f64)
+ assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"
+
+ m2 = T.memref(2, 3, 4, T.f64, memory_space=1)
+ assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"
+
+ m3 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13]))
+ assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"
+
+ m4 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13], 42))
+ assert (
+ repr(m4)
+ == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
+ )
+
+ S = ShapedType.get_dynamic_size()
+
+ t1 = T.tensor(S, 3, S, T.f64)
+ assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
+ ut1 = T.tensor(T.f64)
+ assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
+ t2 = T.tensor(S, 3, S, element_type=T.f64)
+ assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
+ ut2 = T.tensor(element_type=T.f64)
+ assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"
+
+ t3 = T.tensor(S, 3, S, T.f64, encoding="encoding")
+ assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'
+
+ v = T.vector(3, 3, 3, T.f64)
+ assert repr(v) == "VectorType(vector<3x3x3xf64>)"
+
+ m5 = T.memref(S, 3, S, T.f64)
+ assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
+ um1 = T.memref(T.f64)
+ assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
+ m6 = T.memref(S, 3, S, element_type=T.f64)
+ assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
+ um2 = T.memref(element_type=T.f64)
+ assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"
+
+ m7 = T.memref(S, 3, S, T.f64)
+ assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
+ um3 = T.memref(T.f64)
+ assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"
+
+ scalable_1 = T.vector(2, 3, T.f32, scalable=[False, True])
+ scalable_2 = T.vector(2, 3, 4, T.f32, scalable=[True, False, True])
+ assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
+ assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"
+
+ scalable_3 = T.vector(2, 3, T.f32, scalable_dims=[1])
+ scalable_4 = T.vector(2, 3, 4, T.f32, scalable_dims=[0, 2])
+ assert scalable_3 == scalable_1
+ assert scalable_4 == scalable_2
+
+ opaq = T.opaque("scf", "placeholder")
+ assert repr(opaq) == "OpaqueType(!scf.placeholder)"
+
+ tup1 = T.tuple(T.i16, T.i32, T.i64)
+ tup2 = T.tuple(T.f16, T.f32, T.f64)
+ assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
+ assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"
+
+ func = T.function(inputs=(T.i16, T.i32, T.i64), results=(T.f16, T.f32, T.f64))
+ assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"
From d1a9c4ca181bacd6862ca741f2b553ed2bee5b40 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Sat, 25 Nov 2023 20:30:37 -0600
Subject: [PATCH 2/2] incorporate comments
---
mlir/lib/Bindings/Python/IRCore.cpp | 4 +-
mlir/python/CMakeLists.txt | 2 +-
mlir/python/mlir/extras/__init__.py | 0
mlir/python/mlir/{ => extras}/types.py | 96 ++++++++++---------------
mlir/test/python/ir/builtin_types.py | 79 +++++++++-----------
mlir/test/python/ir/context_managers.py | 8 +--
6 files changed, 73 insertions(+), 116 deletions(-)
create mode 100644 mlir/python/mlir/extras/__init__.py
rename mlir/python/mlir/{ => extras}/types.py (62%)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 745aa64e63b67d4..bda41a572c4d66a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2538,8 +2538,8 @@ void mlir::python::populateIRCore(py::module &m) {
[](py::object & /*class*/) {
auto *context = PyThreadContextEntry::getDefaultContext();
if (!context)
- throw py::value_error("No current Context");
- return context;
+ return py::none().cast<py::object>();
+ return py::cast(context);
},
"Gets the Context bound to the current thread or raises ValueError")
.def_property_readonly(
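
With this change Context.current no longer raises when no context is bound to
the thread; it returns None instead. A quick sketch of the new contract
(mirrored by the context_managers.py test below):

    from mlir.ir import Context

    with Context() as ctx:
        assert Context.current is ctx
    assert Context.current is None  # previously raised ValueError
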
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 12e2dab60f3011b..55731943fb78de4 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,7 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
- types.py
+ extras/types.py
dialects/_ods_common.py
# The main _mlir module has submodules: include stubs from each.
diff --git a/mlir/python/mlir/extras/__init__.py b/mlir/python/mlir/extras/__init__.py
new file mode 100644
index 000000000000000..e69de29bb2d1d64
diff --git a/mlir/python/mlir/types.py b/mlir/python/mlir/extras/types.py
similarity index 62%
rename from mlir/python/mlir/types.py
rename to mlir/python/mlir/extras/types.py
index aa8a2639ac980b2..db9e8229fb2884e 100644
--- a/mlir/python/mlir/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -5,11 +5,10 @@
from functools import partial
from typing import Optional, List
-from .ir import (
+from ..ir import (
Attribute,
BF16Type,
ComplexType,
- Context,
F16Type,
F32Type,
F64Type,
@@ -32,55 +31,54 @@
VectorType,
)
-__all__ = []
+index = lambda: IndexType.get()
-_index = lambda: IndexType.get()
-_bool = lambda: IntegerType.get_signless(1)
-_i8 = lambda: IntegerType.get_signless(8)
-_i16 = lambda: IntegerType.get_signless(16)
-_i32 = lambda: IntegerType.get_signless(32)
-_i64 = lambda: IntegerType.get_signless(64)
+def i(width):
+ return IntegerType.get_signless(width)
-_si8 = lambda: IntegerType.get_signed(8)
-_si16 = lambda: IntegerType.get_signed(16)
-_si32 = lambda: IntegerType.get_signed(32)
-_si64 = lambda: IntegerType.get_signed(64)
-_ui8 = lambda: IntegerType.get_unsigned(8)
-_ui16 = lambda: IntegerType.get_unsigned(16)
-_ui32 = lambda: IntegerType.get_unsigned(32)
-_ui64 = lambda: IntegerType.get_unsigned(64)
+def si(width):
+ return IntegerType.get_signed(width)
-_f16 = lambda: F16Type.get()
-_f32 = lambda: F32Type.get()
-_f64 = lambda: F64Type.get()
-_bf16 = lambda: BF16Type.get()
-_f8e5m2 = lambda: Float8E5M2Type.get()
-_f8e4m3 = lambda: Float8E4M3FNType.get()
-_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()
+def ui(width):
+ return IntegerType.get_unsigned(width)
-_none = lambda: NoneType.get()
+bool = lambda: i(1)
+i8 = lambda: i(8)
+i16 = lambda: i(16)
+i32 = lambda: i(32)
+i64 = lambda: i(64)
-def _i(width):
- return IntegerType.get_signless(width)
+si8 = lambda: si(8)
+si16 = lambda: si(16)
+si32 = lambda: si(32)
+si64 = lambda: si(64)
+ui8 = lambda: ui(8)
+ui16 = lambda: ui(16)
+ui32 = lambda: ui(32)
+ui64 = lambda: ui(64)
-def _si(width):
- return IntegerType.get_signed(width)
+f16 = lambda: F16Type.get()
+f32 = lambda: F32Type.get()
+f64 = lambda: F64Type.get()
+bf16 = lambda: BF16Type.get()
+f8E5M2 = lambda: Float8E5M2Type.get()
+f8E4M3 = lambda: Float8E4M3FNType.get()
+f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
-def _ui(width):
- return IntegerType.get_unsigned(width)
+none = lambda: NoneType.get()
-def _complex(type):
+def complex(type):
return ComplexType.get(type)
-def _opaque(dialect_namespace, type_data):
+def opaque(dialect_namespace, type_data):
return OpaqueType.get(dialect_namespace, type_data)
@@ -105,7 +103,7 @@ def _shaped(*shape, element_type: Type = None, type_constructor=None):
return type_constructor(type)
-def _vector(
+def vector(
*shape,
element_type: Type = None,
scalable: Optional[List[bool]] = None,
@@ -120,7 +118,7 @@ def _vector(
)
-def _tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
+def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
if encoding is not None:
encoding = StringAttr.get(encoding)
if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
@@ -136,7 +134,7 @@ def _tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
)
-def _memref(
+def memref(
*shape,
element_type: Type = None,
memory_space: Optional[int] = None,
@@ -159,31 +157,9 @@ def _memref(
)
-def _tuple(*elements):
+def tuple(*elements):
return TupleType.get_tuple(elements)
-def _function(*, inputs, results):
+def function(*, inputs, results):
return FunctionType.get(inputs, results)
-
-
-def __getattr__(name):
- if name == "__path__":
- # https://docs.python.org/3/reference/import.html#path__
- # If a module is a package (either regular or namespace), the module object’s __path__ attribute must be set.
- # This module is NOT a package and so this must be None (rather than throw the RuntimeError below).
- return None
- try:
- Context.current
- except ValueError:
- raise RuntimeError("Types can only be instantiated under an active context.")
-
- if f"_{name}" in globals():
- builder = globals()[f"_{name}"]
- if (
- isinstance(builder, type(lambda: None))
- and builder.__name__ == (lambda: None).__name__
- ):
- return builder()
- return builder
- raise RuntimeError(f"{name} is not a legal type.")
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 6dbb35d10aad48a..30a5054ada91ac7 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -3,7 +3,7 @@
import gc
from mlir.ir import *
from mlir.dialects import arith, tensor, func, memref
-import mlir.types as T
+import mlir.extras.types as T
def run(f):
@@ -778,25 +778,10 @@ def testCustomTypeTypeCaster():
# CHECK-LABEL: TEST: testTypeWrappers
@run
def testTypeWrappers():
- try:
- from mlir.types import i32
- except RuntimeError as e:
- assert e.args[0] == "Types can only be instantiated under an active context."
-
- try:
- from mlir.types import tensor
- except RuntimeError as e:
- assert e.args[0] == "Types can only be instantiated under an active context."
-
def stride(strides, offset=0):
return StridedLayoutAttr.get(offset, strides)
with Context(), Location.unknown():
- try:
- from mlir.types import non_existent_type
- except RuntimeError as e:
- assert e.args[0] == "non_existent_type is not a legal type."
-
ia = T.i(5)
sia = T.si(6)
uia = T.ui(7)
@@ -804,30 +789,30 @@ def stride(strides, offset=0):
assert repr(sia) == "IntegerType(si6)"
assert repr(uia) == "IntegerType(ui7)"
- assert T.i(16) == T.i16
- assert T.si(16) == T.si16
- assert T.ui(16) == T.ui16
+ assert T.i(16) == T.i16()
+ assert T.si(16) == T.si16()
+ assert T.ui(16) == T.ui16()
- c1 = T.complex(T.f16)
- c2 = T.complex(T.i32)
+ c1 = T.complex(T.f16())
+ c2 = T.complex(T.i32())
assert repr(c1) == "ComplexType(complex<f16>)"
assert repr(c2) == "ComplexType(complex<i32>)"
- vec_1 = T.vector(2, 3, T.f32)
- vec_2 = T.vector(2, 3, 4, T.f32)
+ vec_1 = T.vector(2, 3, T.f32())
+ vec_2 = T.vector(2, 3, 4, T.f32())
assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"
- m1 = T.memref(2, 3, 4, T.f64)
+ m1 = T.memref(2, 3, 4, T.f64())
assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"
- m2 = T.memref(2, 3, 4, T.f64, memory_space=1)
+ m2 = T.memref(2, 3, 4, T.f64(), memory_space=1)
assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"
- m3 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13]))
+ m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13]))
assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"
- m4 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13], 42))
+ m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42))
assert (
repr(m4)
== "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
@@ -835,52 +820,54 @@ def stride(strides, offset=0):
S = ShapedType.get_dynamic_size()
- t1 = T.tensor(S, 3, S, T.f64)
+ t1 = T.tensor(S, 3, S, T.f64())
assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
- ut1 = T.tensor(T.f64)
+ ut1 = T.tensor(T.f64())
assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
- t2 = T.tensor(S, 3, S, element_type=T.f64)
+ t2 = T.tensor(S, 3, S, element_type=T.f64())
assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
- ut2 = T.tensor(element_type=T.f64)
+ ut2 = T.tensor(element_type=T.f64())
assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"
- t3 = T.tensor(S, 3, S, T.f64, encoding="encoding")
+ t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding")
assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'
- v = T.vector(3, 3, 3, T.f64)
+ v = T.vector(3, 3, 3, T.f64())
assert repr(v) == "VectorType(vector<3x3x3xf64>)"
- m5 = T.memref(S, 3, S, T.f64)
+ m5 = T.memref(S, 3, S, T.f64())
assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
- um1 = T.memref(T.f64)
+ um1 = T.memref(T.f64())
assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
- m6 = T.memref(S, 3, S, element_type=T.f64)
+ m6 = T.memref(S, 3, S, element_type=T.f64())
assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
- um2 = T.memref(element_type=T.f64)
+ um2 = T.memref(element_type=T.f64())
assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"
- m7 = T.memref(S, 3, S, T.f64)
+ m7 = T.memref(S, 3, S, T.f64())
assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
- um3 = T.memref(T.f64)
+ um3 = T.memref(T.f64())
assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"
- scalable_1 = T.vector(2, 3, T.f32, scalable=[False, True])
- scalable_2 = T.vector(2, 3, 4, T.f32, scalable=[True, False, True])
+ scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True])
+ scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True])
assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"
- scalable_3 = T.vector(2, 3, T.f32, scalable_dims=[1])
- scalable_4 = T.vector(2, 3, 4, T.f32, scalable_dims=[0, 2])
+ scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1])
+ scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2])
assert scalable_3 == scalable_1
assert scalable_4 == scalable_2
opaq = T.opaque("scf", "placeholder")
assert repr(opaq) == "OpaqueType(!scf.placeholder)"
- tup1 = T.tuple(T.i16, T.i32, T.i64)
- tup2 = T.tuple(T.f16, T.f32, T.f64)
+ tup1 = T.tuple(T.i16(), T.i32(), T.i64())
+ tup2 = T.tuple(T.f16(), T.f32(), T.f64())
assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"
- func = T.function(inputs=(T.i16, T.i32, T.i64), results=(T.f16, T.f32, T.f64))
+ func = T.function(
+ inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64())
+ )
assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"
diff --git a/mlir/test/python/ir/context_managers.py b/mlir/test/python/ir/context_managers.py
index 48d9e357324c9e5..8091687f7f082d6 100644
--- a/mlir/test/python/ir/context_managers.py
+++ b/mlir/test/python/ir/context_managers.py
@@ -15,13 +15,7 @@ def run(f):
def testContextEnterExit():
with Context() as ctx:
assert Context.current is ctx
- try:
- _ = Context.current
- except ValueError as e:
- # CHECK: No current Context
- print(e)
- else:
- assert False, "Expected exception"
+ assert Context.current is None
run(testContextEnterExit)