[Mlir-commits] [mlir] [mlir][python] add type wrappers (PR #71218)
Maksim Levental
llvmlistbot at llvm.org
Mon Nov 13 09:03:26 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/71218
From 34c7cf40b42264f1610ec64ff0f96018d9e4cf9f Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 2 Nov 2023 11:15:45 -0500
Subject: [PATCH] [mlir][python] add type wrappers
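
This adds mlir/python/mlir/types.py with lazily constructed shorthands for the builtin types (index, bool, the i/si/ui integer widths, f16/f32/f64/bf16, the f8 variants, complex, none, and a few dialect types), resolved through a module-level __getattr__ so each access builds the type in the current Context, plus shaped-type helpers vector, tensor, and memref that take the element type either positionally or via element_type=. A rough usage sketch, mirroring the new test added below (it must run inside a Context/Location, as the test does):

    import mlir.types as T
    from mlir.ir import Context, Location, ShapedType

    with Context(), Location.unknown():
        S = ShapedType.get_dynamic_size()
        T.i32                         # i32
        T.tensor(S, 3, S, T.f64)      # tensor<?x3x?xf64>
        T.tensor(T.f64)               # tensor<*xf64> (unranked)
        T.vector(3, 3, 3, T.f64)      # vector<3x3x3xf64>
        T.memref(S, 3, S, T.f64)      # memref<?x3x?xf64>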
---
mlir/python/CMakeLists.txt | 1 +
mlir/python/mlir/types.py | 184 +++++++++++++++++++++++++++
mlir/test/python/ir/builtin_types.py | 39 +++++-
3 files changed, 223 insertions(+), 1 deletion(-)
create mode 100644 mlir/python/mlir/types.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 88e6e13602d291a..1de5d039030c606 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -19,6 +19,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
+    types.py
     dialects/_ods_common.py
 # The main _mlir module has submodules: include stubs from each.
diff --git a/mlir/python/mlir/types.py b/mlir/python/mlir/types.py
new file mode 100644
index 000000000000000..dd249d16bf62046
--- /dev/null
+++ b/mlir/python/mlir/types.py
@@ -0,0 +1,184 @@
+from functools import partial
+
+from .ir import (
+    Attribute,
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    Float8E4M3B11FNUZType,
+    Float8E4M3FNType,
+    Float8E5M2Type,
+    IndexType,
+    IntegerType,
+    MemRefType,
+    NoneType,
+    OpaqueType,
+    RankedTensorType,
+    StridedLayoutAttr,
+    Type,
+    UnrankedMemRefType,
+    UnrankedTensorType,
+    VectorType,
+)
+
+from .dialects import transform
+from .dialects import pdl
+
+
+_index = lambda: IndexType.get()
+_bool = lambda: IntegerType.get_signless(1)
+
+_i8 = lambda: IntegerType.get_signless(8)
+_i16 = lambda: IntegerType.get_signless(16)
+_i32 = lambda: IntegerType.get_signless(32)
+_i64 = lambda: IntegerType.get_signless(64)
+
+_si8 = lambda: IntegerType.get_signed(8)
+_si16 = lambda: IntegerType.get_signed(16)
+_si32 = lambda: IntegerType.get_signed(32)
+_si64 = lambda: IntegerType.get_signed(64)
+
+_ui8 = lambda: IntegerType.get_unsigned(8)
+_ui16 = lambda: IntegerType.get_unsigned(16)
+_ui32 = lambda: IntegerType.get_unsigned(32)
+_ui64 = lambda: IntegerType.get_unsigned(64)
+
+_f16 = lambda: F16Type.get()
+_f32 = lambda: F32Type.get()
+_f64 = lambda: F64Type.get()
+_bf16 = lambda: BF16Type.get()
+
+_f8e5m2 = lambda: Float8E5M2Type.get()
+_f8e4m3 = lambda: Float8E4M3FNType.get()
+_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()
+
+_cmp16 = lambda: ComplexType.get(_f16())
+_cmp32 = lambda: ComplexType.get(_f32())
+_cmp64 = lambda: ComplexType.get(_f64())
+
+_none = lambda: NoneType.get()
+
+_pdl_operation = lambda: pdl.OperationType.get()
+
+opaque = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)
+
+
+def _transform_any_op():
+    return transform.AnyOpType.get()
+
+
+def _llvm_ptr():
+    return Type.parse("!llvm.ptr")
+
+
+def placeholder_opaque():
+    return opaque("scf", "placeholder")
+
+
+_name_to_type = {
+ "index": _index,
+ "bool": _bool,
+ "i8": _i8,
+ "i16": _i16,
+ "i32": _i32,
+ "i64": _i64,
+ "si8": _si8,
+ "si16": _si16,
+ "si32": _si32,
+ "si64": _si64,
+ "ui8": _ui8,
+ "ui16": _ui16,
+ "ui32": _ui32,
+ "ui64": _ui64,
+ "f16": _f16,
+ "f32": _f32,
+ "f64": _f64,
+ "bf16": _bf16,
+ "f8e5m2": _f8e5m2,
+ "f8e4m3": _f8e4m3,
+ "f8e4m3b11fnuz": _f8e4m3b11fnuz,
+ "cmp16": _cmp16,
+ "cmp32": _cmp32,
+ "cmp64": _cmp64,
+ "none": _none,
+ "pdl_operation": _pdl_operation,
+ "transform_any_op": _transform_any_op,
+ "llvm_ptr": _llvm_ptr,
+}
+
+
+def __getattr__(name):
+    if name in _name_to_type:
+        return _name_to_type[name]()
+    # Fall back to the normal missing-attribute behavior for unknown names.
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def shaped(*args, element_type: Type = None, type_constructor=None):
+    if type_constructor is None:
+        raise ValueError("shaped is a generic helper; a concrete type_constructor must be provided")
+    if (element_type is None and args and not isinstance(args[-1], Type)) or (
+        args and isinstance(args[-1], Type) and element_type is not None
+    ):
+        raise ValueError(
+            "The element type must be passed either via element_type or as the last positional arg, but not both."
+        )
+    if element_type is not None:
+        type = element_type
+        sizes = args
+    else:
+        type = args[-1]
+        sizes = args[:-1]
+    if sizes:
+        return type_constructor(sizes, type)
+    else:
+        return type_constructor(type)
+
+
+def vector(*args, element_type: Type = None):
+    return shaped(*args, element_type=element_type, type_constructor=VectorType.get)
+
+
+def tensor(*args, element_type: Type = None):
+    if not args or (len(args) == 1 and isinstance(args[-1], Type)):
+        return shaped(
+            *args, element_type=element_type, type_constructor=UnrankedTensorType.get
+        )
+    else:
+        return shaped(
+            *args, element_type=element_type, type_constructor=RankedTensorType.get
+        )
+
+
+def memref(
+    *args,
+    element_type: Type = None,
+    memory_space: int = None,
+    layout: tuple[tuple[int, ...], int] = None,
+):
+    if memory_space is None:
+        memory_space = 0
+    if layout is not None:
+        strides, offset = layout
+        layout = StridedLayoutAttr.get(offset, strides)
+    memory_space = Attribute.parse(str(memory_space))
+    if not args or (len(args) == 1 and isinstance(args[-1], Type)):
+        return shaped(
+            *args,
+            element_type=element_type,
+            type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
+        )
+    else:
+        return shaped(
+            *args,
+            element_type=element_type,
+            type_constructor=partial(
+                MemRefType.get, memory_space=memory_space, layout=layout
+            ),
+        )
+
+
+def transform_op(name):
+    return transform.OperationType.get(name)
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..197432b0010b575 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -337,7 +337,6 @@ def testRankedTensorType():
         assert RankedTensorType.get(shape, f32).encoding is None
-
 # CHECK-LABEL: TEST: testUnrankedTensorType
 @run
 def testUnrankedTensorType():
@@ -733,3 +732,41 @@ def testCustomTypeTypeCaster():
         print(t)
         # CHECK: OperationType(!transform.op<"foo.bar">)
         print(repr(t))
+
+
+# CHECK-LABEL: TEST: testTypeWrappers
+@run
+def testTypeWrappers():
+
+    try:
+        from mlir.types import i32
+    except RuntimeError as e:
+        # CHECK: RuntimeError: An MLIR function requires a Context
+        print(f"{e.__class__.__name__}: {e}")
+
+    import mlir.types as T
+    from mlir.types import vector, tensor
+
+    with Context(), Location.unknown():
+        S = ShapedType.get_dynamic_size()
+        t = T.tensor(S, 3, S, T.f64)
+
+        assert repr(t) == "RankedTensorType(tensor<?x3x?xf64>)"
+        ut = tensor(T.f64)
+        assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)"
+        t = tensor(S, 3, S, element_type=T.f64)
+        assert repr(t) == "RankedTensorType(tensor<?x3x?xf64>)"
+        ut = tensor(element_type=T.f64)
+        assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)"
+
+        v = vector(3, 3, 3, T.f64)
+        assert repr(v) == "VectorType(vector<3x3x3xf64>)"
+
+        m = T.memref(S, 3, S, T.f64)
+        assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
+        um = T.memref(T.f64)
+        assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
+        m = T.memref(S, 3, S, element_type=T.f64)
+        assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
+        um = T.memref(element_type=T.f64)
+        assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
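
The memory_space and layout keywords of memref() are not exercised by the tests above. A minimal sketch of that path, with illustrative values (strides (3, 1, 1), offset 0, and memory space 1 are arbitrary choices, not taken from the patch): the (strides, offset) pair is wrapped in a StridedLayoutAttr and the integer memory space is parsed into an Attribute before being forwarded to MemRefType.get.

    import mlir.types as T
    from mlir.ir import Context, Location, ShapedType

    with Context(), Location.unknown():
        S = ShapedType.get_dynamic_size()
        # layout is a (strides, offset) pair; memory_space is parsed into an Attribute
        m = T.memref(S, 3, S, T.f64, memory_space=1, layout=((3, 1, 1), 0))
        print(m)  # a memref with a strided<[3, 1, 1]> layout in memory space 1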