[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

Jacques Pienaar llvmlistbot at llvm.org
Wed Apr 10 00:59:27 PDT 2024


================
@@ -0,0 +1,462 @@
+from enum import Enum
+import functools, sys, ctypes, os, errno
+import numpy as np
+from functools import partialmethod
+from mlir import ir
+from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
+from mlir.extras import types as T
+from mlir import runtime as rt
+from tools import nvgpucompiler
+
+# Sentinel dimension size, equal to -2**63 (the minimum signed 64-bit integer).
+# It is used in this file as the only dimension of the i8 memref type returned
+# for dynamic shared memory.
+# NOTE(review): presumably mirrors MLIR's "dynamic size" marker -- confirm.
+MLIR_DYNAMIC = -9223372036854775808
+
+
+def const(value: int, ty=None):
+    """Return `value` as an MLIR value, creating an `arith.constant` if needed.
+
+    Args:
+        value: A Python scalar, or an existing `ir.Value`.
+        ty: Type of the constant to create; defaults to `index`.
+
+    Returns:
+        `value` unchanged when it is an `ir.Value` that passes the type check
+        below, otherwise the result of `arith.constant(ty, value)`.
+    """
+    ty = T.index() if ty is None else ty
+    # NOTE(review): the first operand compares `value.type` with itself, so it
+    # does not depend on `ty`; presumably `ty.isinstance(value.type)` was
+    # intended -- confirm before relying on `ty` for already-built values.
+    if isinstance(value, ir.Value) and (
+        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
+    ):
+        return value
+    return arith.constant(ty, value)
+
+
+def get_type_size(ty):
+    if ir.MemRefType.isinstance(ty):
+        size = get_type_size(ty.element_type)
+        for sz in ty.shape:
+            size *= sz
+        return size
+    if ir.FloatType.isinstance(ty):
+        return ir.FloatType(ty).width // 8
+    if ir.IntegerType.isinstance(ty):
+        return ir.IntegerType(ty).width // 8
+    raise NotImplementedError(ty)
+
+
+def get_mlir_func_obj_ty(inputArgs):
+    args = []
+    c_int_p = ctypes.c_int * 1
+    c_float_p = ctypes.c_float * 1
+    c_bool_p = ctypes.c_bool * 1
+    for arg in inputArgs:
+        if isinstance(arg, bool):
+            args.append(c_bool_p(arg))
+        elif isinstance(arg, int):
+            args.append(c_int_p(arg))
+        elif isinstance(arg, float):
+            args.append(c_float_p(arg))
+        elif isinstance(arg, np.ndarray):
+            args.append(
+                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
+            )
+        else:
+            raise NotImplementedError(arg)
+    return args
+
+
+class Mbarriers:
+    def __init__(self, number_of_barriers=1):
+        self.mbar_ty = ir.Type.parse(
+            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
+            + str(number_of_barriers)
+            + ">"
+        )
+        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
+        self.number_of_barriers = number_of_barriers
+
+    def __getitem__(self, key):
+        self.id_op = const(key)
+        return self
+
+    def init(self, count: int, predicate=None):
+        count_op = const(count)
+        if predicate is None:
+            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
+        else:
+            nvgpu.mbarrier_init(
+                self.mbar_group_op, count_op, self.id_op, predicate=predicate
+            )
+
+    def arrive(self, txcount: int = 0, predicate=None):
+        if txcount != 0:
+            txcount_op = const(txcount)
+            nvgpu.mbarrier_arrive_expect_tx(
+                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
+            )
+        else:
+            nvgpu.mbarrier_arrive(
+                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
+            )
+
+    def try_wait(self, phase: bool = False, ticks: int = 10000000):
+        ticks_op = const(ticks)
+        phase_op = const(phase, T.bool())
+        nvgpu.MBarrierTryWaitParityOp(
+            self.mbar_group_op,
+            phase_op,
+            ticks_op,
+            mbarId=self.id_op,
+        )
+
+
+class TMA:
+    """A class that builds a TMA descriptor."""
+
+    def __init__(
+        self,
+        shape,
+        memref_ty,
+        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
+        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
+        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+    ):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.shape = shape
+        self.memref_ty = memref_ty  # MemRefType
+        self.lastDim = 64
+        self.requiredLoad = 1
+        self.tma_shape = shape
+        self.tma_memref = ir.MemRefType.get(shape, memref_ty.element_type)
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        memref_str = (
+            "memref<"
+            + "x".join(map(str, self.tma_shape))
+            + "x"
+            + str(self.memref_ty.element_type)
+            + ", 3>"
+        )
+        parse_str = (
+            f"!nvgpu.tensormap.descriptor<tensor = {memref_str}, "
+            f"swizzle = {self.swizzle}, "
+            f"l2promo = {self.l2promo}, "
+            f"oob = {self.oob}, "
+            f"interleave = {self.interleave}>"
+        )
+        return ir.Type.parse(parse_str)
+
+    def create_descriptor(self, device_ptr):
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
+            device_ptr,
+        )
+        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_shape)
+        )
+        return self.tma_descriptor.result
+
+    def prefetch(self, predicate=None):
+        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
+
+    def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
+        nvgpu.TmaAsyncLoadOp(
+            dest,
+            mbarrier.mbar_group_op,
+            self.tma_descriptor,
+            coordinates=map(const, coords),
+            mbarId=mbarrier.id_op,
+            predicate=predicate,
+        )
+
+
+WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup
+
+
+class Warpgroup:
+    def __init__(self, primaryThread, registerSize):
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        self.primary_thread = primaryThread
+        self.register_size = registerSize
+        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
+        self.wg_id = tidx / WARP_GROUP_SIZE
+        self.is_me = self.wg_id == (primaryThread // WARP_GROUP_SIZE)
+
+    def __enter__(self):
+        if_op = scf.IfOp(self.is_me)
+        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
+        self.ipoint_op.__enter__()
+        if self.register_size < 64:
+            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
+        else:
+            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        scf.yield_([])
+        self.ipoint_op.__exit__(exc_type, exc_value, traceback)
+        return True
+
+
+class WGMMAType(Enum):
+    """Kind tag stored on a `WGMMAMatrix`.
+
+    `Accumulator` makes the `WGMMAMatrix` constructor create an accumulator op.
+    """
+
+    Accumulator = 1
+    Descriptor = 2
+
+
+class WGMMAMatrix:
+    def __init__(
+        self,
+        matrix_type: WGMMAType,
+        shape: list = None,
+        desc: TMA = None,
+        smem=None,
+        ty=None,
+        acc_op=None,
+    ):
+        if acc_op is None:
+            self.M = shape[0]
+            self.N = shape[1]
+            self.ty = ty
+            self.matrix_type = matrix_type
+            self.desc = desc
+            self.smem = smem
+            if matrix_type is WGMMAType.Accumulator:
+                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+        elif acc_op:
+            self.acc_op = acc_op
+            self.matrix_type = WGMMAType.Accumulator
+
+    @property
+    def acc_ty(self):
+        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
+        return ir.Type.parse(parse_str)
+
+    @property
+    def wgmma_ty(self):
+        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+        return ir.Type.parse(parse_str)
+
+    def store_accumulator(self, dest):
+        assert self.matrix_type == WGMMAType.Accumulator
+        nvgpu.warpgroup_mma_store(self.acc_op, dest)
+
+    def update_smem(self, smem):
+        self.smem = smem
+
+    def update_accumulator(self, acc_op):
+        self.acc_op = acc_op
+
+    def __matmul__(self, rhs):
+        lhs = nvgpu.warpgroup_generate_descriptor(
+            self.wgmma_ty, self.smem, self.desc.tma_descriptor
+        )
+        rhs = nvgpu.warpgroup_generate_descriptor(
+            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
+        )
+        return [lhs, rhs]
+
+    def __iadd__(self, matmulResult):
+        lhs = matmulResult[0]
+        rhs = matmulResult[1]
+        acc_op = nvgpu.WarpgroupMmaOp(
+            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
+        )
+        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
+
+
+def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    dynamic_smem = gpu.dynamic_shared_memory(
+        ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+    )
+    if shape is None:
+        return dynamic_smem
+    memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
+    return memref.view(
+        ir.MemRefType.get(
+            memref_ty.shape, memref_ty.element_type, memory_space=smem_space
+        ),
+        dynamic_smem,
+        const(offset),
+        [],
+    )
+
+
+ at staticmethod
+def get_mlir_ty(arg):
+    def get_mlir_ty_from_np(dtype):
+        if dtype == np.float16:
+            return T.f16()
+        if dtype == np.float32:
+            return T.f32()
+        if dtype == np.float64:
+            return T.f64()
+        if dtype == np.int32:
+            return T.i32()
+        if dtype == np.int64:
+            return T.i64()
+        raise NotImplementedError(dtype)
+
+    if isinstance(arg, bool):
+        return T.bool()
+    elif isinstance(arg, int):
+        return T.index()
+    elif isinstance(arg, float):
+        return T.f32()
+    elif isinstance(arg, np.ndarray):
+        descriptor = rt.get_ranked_memref_descriptor(arg)
+        dtype = get_mlir_ty_from_np(arg.dtype)
+        shape = descriptor.shape
+        return memref.MemRefType.get(shape, dtype)
+    raise NotImplementedError(arg)
+
+
+class NVDSL:
+    @staticmethod
+    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
+        def decorator(func):
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                launch_op = gpu.LaunchOp(
+                    None,
+                    [],
+                    *map(const, grid),
+                    *map(const, block),
+                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+                )
+                launch_op.body.blocks.append(*([T.index()] * 12))
+                with ir.InsertionPoint(launch_op.body.blocks[0]):
+                    result = func(*args, **kwargs)
+                    gpu.terminator()
+                    return result
+
+            return wrapper
+
+        return decorator
+
+    @staticmethod
+    def mlir_func(funcBody):
+        @functools.wraps(funcBody)
+        def wrapper(*args, **kwargs):
+            function_name = funcBody.__name__
+
+            def saveIR(module):
+                """Save generated IR"""
+                if True:  # self.saveIR:
+                    # print(mlir_nvgpu_module)
+                    original_stdout = sys.stdout
+                    with open("nvdsl.mlir", "w") as f:
+                        sys.stdout = f
+                        print(module)
+                        sys.stdout = original_stdout
+
+            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
+                """Generate MLIR's Arith dialects binary operations."""
+                rhs = const(rhs)
+                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+                    op += "F"
+                    if op.startswith("Cmp"):
+                        predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
+                            predAtt
+                        ]
+                elif arith._is_integer_like_type(
+                    lhs.type
+                ) and arith._is_integer_like_type(lhs.type):
+                    if op == "Div" or op == "Rem":
+                        op += "U"
+                    op += "I"
+                    if op.startswith("Cmp"):
+                        predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
+                            predAtt
+                        ]
+                else:
+                    raise NotImplementedError(
+                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
+                    )
+
+                if op.startswith("Cmp"):
+                    op = getattr(arith, f"{op}Op")
+
+                    return op(predicateAttr, lhs, rhs).result
+                else:
+                    op = getattr(arith, f"{op}Op")
+                    return op(lhs, rhs).result
+
+            @ir.register_value_caster(ir.IndexType.static_typeid)
+            @ir.register_value_caster(ir.F32Type.static_typeid)
+            @ir.register_value_caster(ir.F16Type.static_typeid)
+            @ir.register_value_caster(ir.F64Type.static_typeid)
+            @ir.register_value_caster(ir.IntegerType.static_typeid)
+            class ArithValue(ir.Value):
+                """Overloads operators for MLIR's Arith dialects binary operations."""
+
+                def __init__(self, v):
+                    super().__init__(v)
+
+                __add__ = partialmethod(_binary_op, op="Add")
+                __sub__ = partialmethod(_binary_op, op="Sub")
+                __mul__ = partialmethod(_binary_op, op="Mul")
+                __truediv__ = partialmethod(_binary_op, op="Div")
+                __mod__ = partialmethod(_binary_op, op="Rem")
+                __xor__ = partialmethod(_binary_op, op="XOr")
+                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
+                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
+                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
+                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
+                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
+                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
+                __and__ = partialmethod(_binary_op, op="And")
+                __or__ = partialmethod(_binary_op, op="Or")
+
+                def __str__(self):
+                    return (
+                        super()
+                        .__str__()
+                        .replace(ir.Value.__name__, ArithValue.__name__)
+                    )
+
+            # Generate MLIR Context and start generating IR
+            with ir.Context(), ir.Location.unknown():
+                types = []
+                for arg in args:
+                    types.append(get_mlir_ty(arg))
+
+                # Build IR
+                module = ir.Module.create()
+                with ir.InsertionPoint(module.body):
+                    fop = func.FuncOp(function_name, (types, []))
+                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+                    with ir.InsertionPoint(fop.add_entry_block()):
+                        fargs = []
+                        for i, a in enumerate(types):
+                            fargs.append(fop.arguments[i])
+
+                        # Call user function body
+                        result = funcBody(*fargs, **kwargs)
+                        func.ReturnOp([])
+
+                # Verify the module
+                module.operation.verify()
+
+                # Save IR in a file
+                # saveIR(module)
+
+                # Compile and JIT MLIR module
+                options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
----------------
jpienaar wrote:

This would be a surprise: this is coupling what seems like a general nvgpu tool to a specific set of versions here.

https://github.com/llvm/llvm-project/pull/87065


More information about the Mlir-commits mailing list