[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

Guray Ozen llvmlistbot at llvm.org
Fri Apr 12 11:14:46 PDT 2024


================
@@ -0,0 +1,462 @@
+from enum import Enum
+import functools, sys, ctypes, os, errno
+import numpy as np
+from functools import partialmethod
+from mlir import ir
+from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
+from mlir.extras import types as T
+from mlir import runtime as rt
+from tools import nvgpucompiler
+
+MLIR_DYNAMIC = -9223372036854775808
+
+
+def const(value: int, ty=None):
+    ty = T.index() if ty is None else ty
+    if isinstance(value, ir.Value) and (
+        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
+    ):
+        return value
+    return arith.constant(ty, value)
+
+
+def get_type_size(ty):
+    if ir.MemRefType.isinstance(ty):
+        size = get_type_size(ty.element_type)
+        for sz in ty.shape:
+            size *= sz
+        return size
+    if ir.FloatType.isinstance(ty):
+        return ir.FloatType(ty).width // 8
+    if ir.IntegerType.isinstance(ty):
+        return ir.IntegerType(ty).width // 8
+    raise NotImplementedError(ty)
+
+
+def get_mlir_func_obj_ty(inputArgs):
+    args = []
+    c_int_p = ctypes.c_int * 1
+    c_float_p = ctypes.c_float * 1
+    c_bool_p = ctypes.c_bool * 1
+    for arg in inputArgs:
+        if isinstance(arg, bool):
+            args.append(c_bool_p(arg))
+        elif isinstance(arg, int):
+            args.append(c_int_p(arg))
+        elif isinstance(arg, float):
+            args.append(c_float_p(arg))
+        elif isinstance(arg, np.ndarray):
+            args.append(
+                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
+            )
+        else:
+            raise NotImplementedError(arg)
+    return args
+
+
+class Mbarriers:
+    def __init__(self, number_of_barriers=1):
+        self.mbar_ty = ir.Type.parse(
+            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
+            + str(number_of_barriers)
+            + ">"
+        )
+        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
+        self.number_of_barriers = number_of_barriers
+
+    def __getitem__(self, key):
+        self.id_op = const(key)
+        return self
+
+    def init(self, count: int, predicate=None):
+        count_op = const(count)
+        if predicate is None:
+            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
+        else:
+            nvgpu.mbarrier_init(
+                self.mbar_group_op, count_op, self.id_op, predicate=predicate
+            )
+
+    def arrive(self, txcount: int = 0, predicate=None):
+        if txcount != 0:
+            txcount_op = const(txcount)
+            nvgpu.mbarrier_arrive_expect_tx(
+                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
+            )
+        else:
+            nvgpu.mbarrier_arrive(
+                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
+            )
+
+    def try_wait(self, phase: bool = False, ticks: int = 10000000):
+        ticks_op = const(ticks)
+        phase_op = const(phase, T.bool())
+        nvgpu.MBarrierTryWaitParityOp(
+            self.mbar_group_op,
+            phase_op,
+            ticks_op,
+            mbarId=self.id_op,
+        )
+
+
+class TMA:
+    """A class that builds a TMA descriptor."""
+
+    def __init__(
+        self,
+        shape,
+        memref_ty,
+        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
+        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
+        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+    ):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.shape = shape
+        self.memref_ty = memref_ty  # MemRefType
+        self.lastDim = 64
+        self.requiredLoad = 1
+        self.tma_shape = shape
+        self.tma_memref = ir.MemRefType.get(shape, memref_ty.element_type)
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        memref_str = (
+            "memref<"
+            + "x".join(map(str, self.tma_shape))
+            + "x"
+            + str(self.memref_ty.element_type)
+            + ", 3>"
+        )
+        parse_str = (
+            f"!nvgpu.tensormap.descriptor<tensor = {memref_str}, "
+            f"swizzle = {self.swizzle}, "
+            f"l2promo = {self.l2promo}, "
+            f"oob = {self.oob}, "
+            f"interleave = {self.interleave}>"
+        )
+        return ir.Type.parse(parse_str)
+
+    def create_descriptor(self, device_ptr):
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
+            device_ptr,
+        )
+        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_shape)
+        )
+        return self.tma_descriptor.result
+
+    def prefetch(self, predicate=None):
+        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
+
+    def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
+        nvgpu.TmaAsyncLoadOp(
+            dest,
+            mbarrier.mbar_group_op,
+            self.tma_descriptor,
+            coordinates=map(const, coords),
+            mbarId=mbarrier.id_op,
+            predicate=predicate,
+        )
+
+
+WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup
+
+
+class Warpgroup:
+    def __init__(self, primaryThread, registerSize):
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        self.primary_thread = primaryThread
+        self.register_size = registerSize
+        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
+        self.wg_id = tidx / WARP_GROUP_SIZE
+        self.is_me = self.wg_id == (primaryThread // WARP_GROUP_SIZE)
+
+    def __enter__(self):
+        if_op = scf.IfOp(self.is_me)
+        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
+        self.ipoint_op.__enter__()
+        if self.register_size < 64:
+            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
+        else:
+            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        scf.yield_([])
+        self.ipoint_op.__exit__(exc_type, exc_value, traceback)
+        return True
+
+
+class WGMMAType(Enum):
+    Accumulator = 1
+    Descriptor = 2
+
+
+class WGMMAMatrix:
+    def __init__(
+        self,
+        matrix_type: WGMMAType,
+        shape: list = None,
+        desc: TMA = None,
+        smem=None,
+        ty=None,
+        acc_op=None,
+    ):
+        if acc_op is None:
+            self.M = shape[0]
+            self.N = shape[1]
+            self.ty = ty
+            self.matrix_type = matrix_type
+            self.desc = desc
+            self.smem = smem
+            if matrix_type is WGMMAType.Accumulator:
+                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+        elif acc_op:
+            self.acc_op = acc_op
+            self.matrix_type = WGMMAType.Accumulator
+
+    @property
+    def acc_ty(self):
+        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
+        return ir.Type.parse(parse_str)
+
+    @property
+    def wgmma_ty(self):
+        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+        return ir.Type.parse(parse_str)
+
+    def store_accumulator(self, dest):
+        assert self.matrix_type == WGMMAType.Accumulator
+        nvgpu.warpgroup_mma_store(self.acc_op, dest)
+
+    def update_smem(self, smem):
+        self.smem = smem
+
+    def update_accumulator(self, acc_op):
+        self.acc_op = acc_op
+
+    def __matmul__(self, rhs):
+        lhs = nvgpu.warpgroup_generate_descriptor(
+            self.wgmma_ty, self.smem, self.desc.tma_descriptor
+        )
+        rhs = nvgpu.warpgroup_generate_descriptor(
+            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
+        )
+        return [lhs, rhs]
+
+    def __iadd__(self, matmulResult):
+        lhs = matmulResult[0]
+        rhs = matmulResult[1]
+        acc_op = nvgpu.WarpgroupMmaOp(
+            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
+        )
+        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
+
+
+def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    dynamic_smem = gpu.dynamic_shared_memory(
+        ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+    )
+    if shape is None:
+        return dynamic_smem
+    memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
+    return memref.view(
+        ir.MemRefType.get(
+            memref_ty.shape, memref_ty.element_type, memory_space=smem_space
+        ),
+        dynamic_smem,
+        const(offset),
+        [],
+    )
+
+
+ at staticmethod
+def get_mlir_ty(arg):
+    def get_mlir_ty_from_np(dtype):
+        if dtype == np.float16:
+            return T.f16()
+        if dtype == np.float32:
+            return T.f32()
+        if dtype == np.float64:
+            return T.f64()
+        if dtype == np.int32:
+            return T.i32()
+        if dtype == np.int64:
+            return T.i64()
+        raise NotImplementedError(dtype)
+
+    if isinstance(arg, bool):
+        return T.bool()
+    elif isinstance(arg, int):
+        return T.index()
+    elif isinstance(arg, float):
+        return T.f32()
+    elif isinstance(arg, np.ndarray):
+        descriptor = rt.get_ranked_memref_descriptor(arg)
+        dtype = get_mlir_ty_from_np(arg.dtype)
+        shape = descriptor.shape
+        return memref.MemRefType.get(shape, dtype)
+    raise NotImplementedError(arg)
+
+
+class NVDSL:
+    @staticmethod
+    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
+        def decorator(func):
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                launch_op = gpu.LaunchOp(
+                    None,
+                    [],
+                    *map(const, grid),
+                    *map(const, block),
+                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+                )
+                launch_op.body.blocks.append(*([T.index()] * 12))
+                with ir.InsertionPoint(launch_op.body.blocks[0]):
+                    result = func(*args, **kwargs)
+                    gpu.terminator()
+                    return result
+
+            return wrapper
+
+        return decorator
+
+    @staticmethod
+    def mlir_func(funcBody):
+        @functools.wraps(funcBody)
+        def wrapper(*args, **kwargs):
+            function_name = funcBody.__name__
+
+            def saveIR(module):
+                """Save generated IR"""
+                if True:  # self.saveIR:
+                    # print(mlir_nvgpu_module)
+                    original_stdout = sys.stdout
+                    with open("nvdsl.mlir", "w") as f:
+                        sys.stdout = f
+                        print(module)
+                        sys.stdout = original_stdout
+
+            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
----------------
grypp wrote:

It doesn't look great to me neither, I'll figure out as a follow-up PR

https://github.com/llvm/llvm-project/pull/87065


More information about the Mlir-commits mailing list