[Mlir-commits] [mlir] [mlir][python] add type wrappers (PR #71218)
Maksim Levental
llvmlistbot at llvm.org
Tue Nov 14 11:35:18 PST 2023
================
@@ -0,0 +1,207 @@
+from functools import partial
+from typing import Optional
+
+from .ir import (
+ Attribute,
+ BF16Type,
+ ComplexType,
+ F16Type,
+ F32Type,
+ F64Type,
+ Float8E4M3B11FNUZType,
+ Float8E4M3FNType,
+ Float8E5M2Type,
+ FunctionType,
+ IndexType,
+ IntegerType,
+ MemRefType,
+ NoneType,
+ OpaqueType,
+ RankedTensorType,
+ StridedLayoutAttr,
+ StringAttr,
+ TupleType,
+ Type,
+ UnrankedMemRefType,
+ UnrankedTensorType,
+ VectorType,
+)
+
+from .dialects import transform
+from .dialects import pdl
+
+
# Zero-argument factories for the named types registered in ``_name_to_type``.
# They are plain ``def``s (PEP 8 discourages binding a lambda to a name, and
# ``_transform_any_op`` below is already a ``def``) and build their type only
# when called, never at import time.
# NOTE(review): the deferral is presumably because these ``.get()``
# constructors need an active MLIR context — confirm.


def _index():
    return IndexType.get()


def _bool():
    # ``bool`` is spelled as a signless 1-bit integer.
    return IntegerType.get_signless(1)


# Signless integers.
def _i8():
    return IntegerType.get_signless(8)


def _i16():
    return IntegerType.get_signless(16)


def _i32():
    return IntegerType.get_signless(32)


def _i64():
    return IntegerType.get_signless(64)


# Signed integers.
def _si8():
    return IntegerType.get_signed(8)


def _si16():
    return IntegerType.get_signed(16)


def _si32():
    return IntegerType.get_signed(32)


def _si64():
    return IntegerType.get_signed(64)


# Unsigned integers.
def _ui8():
    return IntegerType.get_unsigned(8)


def _ui16():
    return IntegerType.get_unsigned(16)


def _ui32():
    return IntegerType.get_unsigned(32)


def _ui64():
    return IntegerType.get_unsigned(64)


# IEEE / brain floating point.
def _f16():
    return F16Type.get()


def _f32():
    return F32Type.get()


def _f64():
    return F64Type.get()


def _bf16():
    return BF16Type.get()


# 8-bit floating point variants.
def _f8e5m2():
    return Float8E5M2Type.get()


def _f8e4m3():
    # Note: the short name ``f8e4m3`` maps to the FN variant.
    return Float8E4M3FNType.get()


def _f8e4m3b11fnuz():
    return Float8E4M3B11FNUZType.get()


# Complex types, named after the width of their float element type.
def _cmp16():
    return ComplexType.get(_f16())


def _cmp32():
    return ComplexType.get(_f32())


def _cmp64():
    return ComplexType.get(_f64())


def _none():
    return NoneType.get()


def _pdl_operation():
    return pdl.OperationType.get()
+
+
def _transform_any_op():
    """Construct the transform dialect's ``AnyOpType`` via its ``get()``."""
    any_op_type = transform.AnyOpType
    return any_op_type.get()
+
+
# Registry used by the module-level ``__getattr__``: attribute name -> a
# zero-argument factory that builds the corresponding type when accessed.
_name_to_type = dict(
    index=_index,
    bool=_bool,
    i8=_i8,
    i16=_i16,
    i32=_i32,
    i64=_i64,
    si8=_si8,
    si16=_si16,
    si32=_si32,
    si64=_si64,
    ui8=_ui8,
    ui16=_ui16,
    ui32=_ui32,
    ui64=_ui64,
    f16=_f16,
    f32=_f32,
    f64=_f64,
    bf16=_bf16,
    f8e5m2=_f8e5m2,
    f8e4m3=_f8e4m3,
    f8e4m3b11fnuz=_f8e4m3b11fnuz,
    cmp16=_cmp16,
    cmp32=_cmp32,
    cmp64=_cmp64,
    none=_none,
    pdl_operation=_pdl_operation,
    transform_any_op=_transform_any_op,
)
+
+
def __getattr__(name):
    """Module-level attribute hook (PEP 562) serving the named type shortcuts.

    Python only calls this after normal module attribute lookup has failed,
    so functions defined in this module (``vector``, ``opaque``, ...) are
    found directly and never reach it. If ``name`` is a key of
    ``_name_to_type``, the registered factory is called on every access and
    its result is returned.

    Raises:
        AttributeError: if ``name`` is not a registered type name. Returning
            ``None`` instead would make every unknown name look defined:
            ``hasattr`` would always be true, ``getattr(mod, x, default)``
            would never fall back, and even ``__all__`` would resolve to
            ``None``, which breaks ``from <module> import *``.
    """
    try:
        factory = _name_to_type[name]
    except KeyError:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None
    # Called outside the ``try`` so a KeyError raised while building the type
    # is not misreported as a missing attribute.
    return factory()
+
+
def transform_op(name):
    """Build a transform-dialect ``OperationType`` for the operation ``name``.

    NOTE(review): ``name`` is presumably a fully-qualified op name such as
    ``"func.func"`` — confirm against ``transform.OperationType.get``.
    """
    operation_type = transform.OperationType
    return operation_type.get(name)
+
+
def opaque(dialect_namespace, type_data):
    """Build an ``OpaqueType`` from a dialect namespace and its type data."""
    make_opaque = OpaqueType.get
    return make_opaque(dialect_namespace, type_data)
+
+
def _shaped(*shape, element_type: Optional[Type] = None, type_constructor=None):
    """Shared argument handling for the shaped-type helpers.

    The element type is given in exactly one of two ways:
      * as the last positional argument, e.g. ``_shaped(2, 3, elt, ...)``; or
      * via ``element_type=``, with every positional argument a size.

    Args:
        *shape: Dimension sizes, optionally followed by the element type.
        element_type: Element type, when not passed positionally.
        type_constructor: Callable invoked as ``type_constructor(sizes, elt)``
            when there are sizes, or ``type_constructor(elt)`` when there are
            none.

    Returns:
        Whatever ``type_constructor`` returns.

    Raises:
        ValueError: if ``type_constructor`` is missing, or if the element type
            is supplied both ways or neither way (including an empty call).
    """
    if type_constructor is None:
        raise ValueError("shaped is an abstract base class - cannot be constructed.")
    trailing_is_type = bool(shape) and isinstance(shape[-1], Type)
    # Exactly one source of the element type is allowed. Equality here means
    # "both" or "neither"; "neither" also covers an empty ``shape``, which
    # previously crashed with IndexError on ``shape[-1]``.
    if trailing_is_type == (element_type is not None):
        raise ValueError(
            "Either element_type must be provided explicitly XOR last arg to "
            "tensor type constructor must be the element type."
        )
    if element_type is not None:
        elt_type, sizes = element_type, shape
    else:
        elt_type, sizes = shape[-1], shape[:-1]
    if sizes:
        return type_constructor(sizes, elt_type)
    return type_constructor(elt_type)
+
+
def vector(
    *shape,
    element_type: Type = None,
    scalable: Optional[list[bool]] = None,
    scalable_dims: Optional[list[int]] = None,
):
    """Build a ``VectorType`` from sizes and an element type.

    The element type goes either last among the positional arguments or in
    ``element_type=`` (exactly one of the two; otherwise ``ValueError``).
    ``scalable`` and ``scalable_dims`` are forwarded unchanged as keyword
    arguments to ``VectorType.get``.
    """
    make_vector = partial(VectorType.get, scalable=scalable, scalable_dims=scalable_dims)
    return _shaped(*shape, element_type=element_type, type_constructor=make_vector)
+
+
+def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
+ if encoding is not None:
+ encoding = StringAttr.get(encoding)
+ if not len(shape) or len(shape) == 1 and isinstance(shape[-1], Type):
+ if encoding is not None:
+ raise ValueError("UnrankedTensorType does not support encoding.")
+ return _shaped(
+ *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
+ )
+ else:
----------------
makslevental wrote:
err whoops - got it on `memref` but not `tensor`. thanks!
https://github.com/llvm/llvm-project/pull/71218
More information about the Mlir-commits
mailing list