[Mlir-commits] [mlir] 225648e - [mlir][python] add type wrappers (#71218)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 13:58:05 PST 2023


Author: Maksim Levental
Date: 2023-11-27T15:58:00-06:00
New Revision: 225648e91ccd951eab9a4ab3200248d5617df1cc

URL: https://github.com/llvm/llvm-project/commit/225648e91ccd951eab9a4ab3200248d5617df1cc
DIFF: https://github.com/llvm/llvm-project/commit/225648e91ccd951eab9a4ab3200248d5617df1cc.diff

LOG: [mlir][python] add type wrappers (#71218)

Added: 
    mlir/python/mlir/extras/__init__.py
    mlir/python/mlir/extras/types.py

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/python/CMakeLists.txt
    mlir/test/python/ir/builtin_types.py
    mlir/test/python/ir/context_managers.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5612f3b960ac7f8..5412c3dec4b1b68 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2558,8 +2558,8 @@ void mlir::python::populateIRCore(py::module &m) {
           [](py::object & /*class*/) {
             auto *context = PyThreadContextEntry::getDefaultContext();
             if (!context)
-              throw py::value_error("No current Context");
-            return context;
+              return py::none().cast<py::object>();
+            return py::cast(context);
           },
           "Gets the Context bound to the current thread or raises ValueError")
       .def_property_readonly(

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 483db673f989e6b..56e895d3053796e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -463,7 +463,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static("get", &PyVectorType::get, py::arg("shape"),
-                 py::arg("elementType"), py::kw_only(),
+                 py::arg("element_type"), py::kw_only(),
                  py::arg("scalable") = py::none(),
                  py::arg("scalable_dims") = py::none(),
                  py::arg("loc") = py::none(), "Create a vector type")
@@ -689,13 +689,9 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get_tuple",
-        [](py::list elementList, DefaultingPyMlirContext context) {
-          intptr_t num = py::len(elementList);
-          // Mapping py::list to SmallVector.
-          SmallVector<MlirType, 4> elements;
-          for (auto element : elementList)
-            elements.push_back(element.cast<PyType>());
-          MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
+        [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+                                        elements.data());
           return PyTupleType(context->getRef(), t);
         },
         py::arg("elements"), py::arg("context") = py::none(),
@@ -727,13 +723,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](std::vector<PyType> inputs, std::vector<PyType> results,
+        [](std::vector<MlirType> inputs, std::vector<MlirType> results,
            DefaultingPyMlirContext context) {
-          SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
-          SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
-          MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
-                                           inputsRaw.data(), resultsRaw.size(),
-                                           resultsRaw.data());
+          MlirType t =
+              mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+                                  results.size(), results.data());
           return PyFunctionType(context->getRef(), t);
         },
         py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
@@ -742,7 +736,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
         "inputs",
         [](PyFunctionType &self) {
           MlirType t = self;
-          auto contextRef = self.getContext();
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
                ++i) {
@@ -754,7 +747,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
     c.def_property_readonly(
         "results",
         [](PyFunctionType &self) {
-          auto contextRef = self.getContext();
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
                ++i) {

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 971ad2dd214a15f..55731943fb78de4 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
+    extras/types.py
     dialects/_ods_common.py
 
     # The main _mlir module has submodules: include stubs from each.

diff  --git a/mlir/python/mlir/extras/__init__.py b/mlir/python/mlir/extras/__init__.py
new file mode 100644
index 000000000000000..e69de29bb2d1d64

diff  --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
new file mode 100644
index 000000000000000..db9e8229fb2884e
--- /dev/null
+++ b/mlir/python/mlir/extras/types.py
@@ -0,0 +1,165 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from functools import partial
+from typing import Optional, List
+
+from ..ir import (
+    Attribute,
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    Float8E4M3B11FNUZType,
+    Float8E4M3FNType,
+    Float8E5M2Type,
+    FunctionType,
+    IndexType,
+    IntegerType,
+    MemRefType,
+    NoneType,
+    OpaqueType,
+    RankedTensorType,
+    StridedLayoutAttr,
+    StringAttr,
+    TupleType,
+    Type,
+    UnrankedMemRefType,
+    UnrankedTensorType,
+    VectorType,
+)
+
+index = lambda: IndexType.get()
+
+
+def i(width):
+    return IntegerType.get_signless(width)
+
+
+def si(width):
+    return IntegerType.get_signed(width)
+
+
+def ui(width):
+    return IntegerType.get_unsigned(width)
+
+
+bool = lambda: i(1)
+i8 = lambda: i(8)
+i16 = lambda: i(16)
+i32 = lambda: i(32)
+i64 = lambda: i(64)
+
+si8 = lambda: si(8)
+si16 = lambda: si(16)
+si32 = lambda: si(32)
+si64 = lambda: si(64)
+
+ui8 = lambda: ui(8)
+ui16 = lambda: ui(16)
+ui32 = lambda: ui(32)
+ui64 = lambda: ui(64)
+
+f16 = lambda: F16Type.get()
+f32 = lambda: F32Type.get()
+f64 = lambda: F64Type.get()
+bf16 = lambda: BF16Type.get()
+
+f8E5M2 = lambda: Float8E5M2Type.get()
+f8E4M3 = lambda: Float8E4M3FNType.get()
+f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
+
+none = lambda: NoneType.get()
+
+
+def complex(type):
+    return ComplexType.get(type)
+
+
+def opaque(dialect_namespace, type_data):
+    return OpaqueType.get(dialect_namespace, type_data)
+
+
+def _shaped(*shape, element_type: Type = None, type_constructor=None):
+    if type_constructor is None:
+        raise ValueError("shaped is an abstract base class - cannot be constructed.")
+    if (element_type is None and shape and not isinstance(shape[-1], Type)) or (
+        shape and isinstance(shape[-1], Type) and element_type is not None
+    ):
+        raise ValueError(
+            f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
+        )
+    if element_type is not None:
+        type = element_type
+        sizes = shape
+    else:
+        type = shape[-1]
+        sizes = shape[:-1]
+    if sizes:
+        return type_constructor(sizes, type)
+    else:
+        return type_constructor(type)
+
+
+def vector(
+    *shape,
+    element_type: Type = None,
+    scalable: Optional[List[bool]] = None,
+    scalable_dims: Optional[List[int]] = None,
+):
+    return _shaped(
+        *shape,
+        element_type=element_type,
+        type_constructor=partial(
+            VectorType.get, scalable=scalable, scalable_dims=scalable_dims
+        ),
+    )
+
+
+def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
+    if encoding is not None:
+        encoding = StringAttr.get(encoding)
+    if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
+        if encoding is not None:
+            raise ValueError("UnrankedTensorType does not support encoding.")
+        return _shaped(
+            *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
+        )
+    return _shaped(
+        *shape,
+        element_type=element_type,
+        type_constructor=partial(RankedTensorType.get, encoding=encoding),
+    )
+
+
+def memref(
+    *shape,
+    element_type: Type = None,
+    memory_space: Optional[int] = None,
+    layout: Optional[StridedLayoutAttr] = None,
+):
+    if memory_space is not None:
+        memory_space = Attribute.parse(str(memory_space))
+    if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
+        return _shaped(
+            *shape,
+            element_type=element_type,
+            type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
+        )
+    return _shaped(
+        *shape,
+        element_type=element_type,
+        type_constructor=partial(
+            MemRefType.get, memory_space=memory_space, layout=layout
+        ),
+    )
+
+
+def tuple(*elements):
+    return TupleType.get_tuple(elements)
+
+
+def function(*, inputs, results):
+    return FunctionType.get(inputs, results)

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index d4fed86b4f135ee..30a5054ada91ac7 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -3,6 +3,7 @@
 import gc
 from mlir.ir import *
 from mlir.dialects import arith, tensor, func, memref
+import mlir.extras.types as T
 
 
 def run(f):
@@ -772,3 +773,101 @@ def testCustomTypeTypeCaster():
         print(t)
         # CHECK: OperationType(!transform.op<"foo.bar">)
         print(repr(t))
+
+
+# CHECK-LABEL: TEST: testTypeWrappers
+@run
+def testTypeWrappers():
+    def stride(strides, offset=0):
+        return StridedLayoutAttr.get(offset, strides)
+
+    with Context(), Location.unknown():
+        ia = T.i(5)
+        sia = T.si(6)
+        uia = T.ui(7)
+        assert repr(ia) == "IntegerType(i5)"
+        assert repr(sia) == "IntegerType(si6)"
+        assert repr(uia) == "IntegerType(ui7)"
+
+        assert T.i(16) == T.i16()
+        assert T.si(16) == T.si16()
+        assert T.ui(16) == T.ui16()
+
+        c1 = T.complex(T.f16())
+        c2 = T.complex(T.i32())
+        assert repr(c1) == "ComplexType(complex<f16>)"
+        assert repr(c2) == "ComplexType(complex<i32>)"
+
+        vec_1 = T.vector(2, 3, T.f32())
+        vec_2 = T.vector(2, 3, 4, T.f32())
+        assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
+        assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"
+
+        m1 = T.memref(2, 3, 4, T.f64())
+        assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"
+
+        m2 = T.memref(2, 3, 4, T.f64(), memory_space=1)
+        assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"
+
+        m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13]))
+        assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"
+
+        m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42))
+        assert (
+            repr(m4)
+            == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
+        )
+
+        S = ShapedType.get_dynamic_size()
+
+        t1 = T.tensor(S, 3, S, T.f64())
+        assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
+        ut1 = T.tensor(T.f64())
+        assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
+        t2 = T.tensor(S, 3, S, element_type=T.f64())
+        assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
+        ut2 = T.tensor(element_type=T.f64())
+        assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"
+
+        t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding")
+        assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'
+
+        v = T.vector(3, 3, 3, T.f64())
+        assert repr(v) == "VectorType(vector<3x3x3xf64>)"
+
+        m5 = T.memref(S, 3, S, T.f64())
+        assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
+        um1 = T.memref(T.f64())
+        assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
+        m6 = T.memref(S, 3, S, element_type=T.f64())
+        assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
+        um2 = T.memref(element_type=T.f64())
+        assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"
+
+        m7 = T.memref(S, 3, S, T.f64())
+        assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
+        um3 = T.memref(T.f64())
+        assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"
+
+        scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True])
+        scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True])
+        assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
+        assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"
+
+        scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1])
+        scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2])
+        assert scalable_3 == scalable_1
+        assert scalable_4 == scalable_2
+
+        opaq = T.opaque("scf", "placeholder")
+        assert repr(opaq) == "OpaqueType(!scf.placeholder)"
+
+        tup1 = T.tuple(T.i16(), T.i32(), T.i64())
+        tup2 = T.tuple(T.f16(), T.f32(), T.f64())
+        assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
+        assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"
+
+        func = T.function(
+            inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64())
+        )
+        assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"

diff  --git a/mlir/test/python/ir/context_managers.py b/mlir/test/python/ir/context_managers.py
index 48d9e357324c9e5..8091687f7f082d6 100644
--- a/mlir/test/python/ir/context_managers.py
+++ b/mlir/test/python/ir/context_managers.py
@@ -15,13 +15,7 @@ def run(f):
 def testContextEnterExit():
     with Context() as ctx:
         assert Context.current is ctx
-    try:
-        _ = Context.current
-    except ValueError as e:
-        # CHECK: No current Context
-        print(e)
-    else:
-        assert False, "Expected exception"
+    assert Context.current is None
 
 
 run(testContextEnterExit)


        


More information about the Mlir-commits mailing list