[Mlir-commits] [mlir] [mlir][python] add type wrappers (PR #71218)

Maksim Levental llvmlistbot at llvm.org
Tue Nov 14 11:35:18 PST 2023


================
@@ -0,0 +1,207 @@
+from functools import partial
+from typing import Optional
+
+from .ir import (
+    Attribute,
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    Float8E4M3B11FNUZType,
+    Float8E4M3FNType,
+    Float8E5M2Type,
+    FunctionType,
+    IndexType,
+    IntegerType,
+    MemRefType,
+    NoneType,
+    OpaqueType,
+    RankedTensorType,
+    StridedLayoutAttr,
+    StringAttr,
+    TupleType,
+    Type,
+    UnrankedMemRefType,
+    UnrankedTensorType,
+    VectorType,
+)
+
+from .dialects import transform
+from .dialects import pdl
+
+
+_index = lambda: IndexType.get()
+_bool = lambda: IntegerType.get_signless(1)
+
+_i8 = lambda: IntegerType.get_signless(8)
+_i16 = lambda: IntegerType.get_signless(16)
+_i32 = lambda: IntegerType.get_signless(32)
+_i64 = lambda: IntegerType.get_signless(64)
+
+_si8 = lambda: IntegerType.get_signed(8)
+_si16 = lambda: IntegerType.get_signed(16)
+_si32 = lambda: IntegerType.get_signed(32)
+_si64 = lambda: IntegerType.get_signed(64)
+
+_ui8 = lambda: IntegerType.get_unsigned(8)
+_ui16 = lambda: IntegerType.get_unsigned(16)
+_ui32 = lambda: IntegerType.get_unsigned(32)
+_ui64 = lambda: IntegerType.get_unsigned(64)
+
+_f16 = lambda: F16Type.get()
+_f32 = lambda: F32Type.get()
+_f64 = lambda: F64Type.get()
+_bf16 = lambda: BF16Type.get()
+
+_f8e5m2 = lambda: Float8E5M2Type.get()
+_f8e4m3 = lambda: Float8E4M3FNType.get()
+_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()
+
+_cmp16 = lambda: ComplexType.get(_f16())
+_cmp32 = lambda: ComplexType.get(_f32())
+_cmp64 = lambda: ComplexType.get(_f64())
+
+_none = lambda: NoneType.get()
+
+_pdl_operation = lambda: pdl.OperationType.get()
+
+
+_transform_any_op = lambda: transform.AnyOpType.get()
+
+
+_name_to_type = {
+    "index": _index,
+    "bool": _bool,
+    "i8": _i8,
+    "i16": _i16,
+    "i32": _i32,
+    "i64": _i64,
+    "si8": _si8,
+    "si16": _si16,
+    "si32": _si32,
+    "si64": _si64,
+    "ui8": _ui8,
+    "ui16": _ui16,
+    "ui32": _ui32,
+    "ui64": _ui64,
+    "f16": _f16,
+    "f32": _f32,
+    "f64": _f64,
+    "bf16": _bf16,
+    "f8e5m2": _f8e5m2,
+    "f8e4m3": _f8e4m3,
+    "f8e4m3b11fnuz": _f8e4m3b11fnuz,
+    "cmp16": _cmp16,
+    "cmp32": _cmp32,
+    "cmp64": _cmp64,
+    "none": _none,
+    "pdl_operation": _pdl_operation,
+    "transform_any_op": _transform_any_op,
+}
+
+
+def __getattr__(name):
+    if name in _name_to_type:
+        return _name_to_type[name]()
+    # Module-level __getattr__ is only consulted after default attribute
+    # lookup fails (i.e., for names other than the functions defined below),
+    # so raise here instead of silently returning None for unknown names.
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
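+
+# Illustrative only (assuming this module is imported as `types`): with an
+# active Context, registered names resolve through __getattr__ above, e.g.
+# `types.i32` yields IntegerType.get_signless(32) and `types.cmp64` yields
+# ComplexType.get(F64Type.get()).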
+
+
+def transform_op(name):
+    return transform.OperationType.get(name)
+
+
+def opaque(dialect_namespace, type_data):
+    return OpaqueType.get(dialect_namespace, type_data)
+
+
+def _shaped(*shape, element_type: Optional[Type] = None, type_constructor=None):
+    if type_constructor is None:
+        raise ValueError("_shaped is a private helper; pass a concrete type_constructor.")
+    if not shape and element_type is None:
+        raise ValueError("An element type must be provided.")
+    if (element_type is None and shape and not isinstance(shape[-1], Type)) or (
+        shape and isinstance(shape[-1], Type) and element_type is not None
+    ):
+        raise ValueError(
+            "The element type must be given either via element_type or as the "
+            "last positional argument, but not both."
+        )
+    if element_type is not None:
+        type = element_type
+        sizes = shape
+    else:
+        type = shape[-1]
+        sizes = shape[:-1]
+    if sizes:
+        return type_constructor(sizes, type)
+    else:
+        return type_constructor(type)
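+
+# For illustration (not part of the API): the element type may be passed
+# trailing positionally or via the keyword, so both of these build
+# tensor<2x3xf32>:
+#   _shaped(2, 3, _f32(), type_constructor=RankedTensorType.get)
+#   _shaped(2, 3, element_type=_f32(), type_constructor=RankedTensorType.get)
+# while an element type alone builds the unranked variant:
+#   _shaped(_f32(), type_constructor=UnrankedTensorType.get)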
+
+
+def vector(
+    *shape,
+    element_type: Optional[Type] = None,
+    scalable: Optional[list[bool]] = None,
+    scalable_dims: Optional[list[int]] = None,
+):
+    return _shaped(
+        *shape,
+        element_type=element_type,
+        type_constructor=partial(
+            VectorType.get, scalable=scalable, scalable_dims=scalable_dims
+        ),
+    )
+
+
+def tensor(*shape, element_type: Optional[Type] = None, encoding: Optional[str] = None):
+    if encoding is not None:
+        encoding = StringAttr.get(encoding)
+    if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
+        if encoding is not None:
+            raise ValueError("UnrankedTensorType does not support encoding.")
+        return _shaped(
+            *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
+        )
+    else:
----------------
makslevental wrote:

err whoops - got it on `memref` but not `tensor`. thanks!
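
For posterity, the intended dispatch after the fix looks roughly like this (a sketch, assuming this module is imported as `T` and a Context is active; not the committed tests):

    from mlir.ir import Context

    # `T` stands for this wrappers module (its import path is not shown in
    # the quoted diff).
    with Context():
        T.tensor(2, 3, T.f32)               # ranked: tensor<2x3xf32>
        T.tensor(2, 3, element_type=T.f64)  # ranked: tensor<2x3xf64>
        T.tensor(T.f32)                     # unranked: tensor<*xf32>
        T.tensor(T.f32, encoding="foo")     # raises: unranked takes no encoding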

https://github.com/llvm/llvm-project/pull/71218

