[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)
Guray Ozen
llvmlistbot at llvm.org
Sat Mar 30 05:17:30 PDT 2024
================
@@ -0,0 +1,415 @@
+from enum import Enum
+import functools, sys, ctypes, os, errno
+import numpy as np
+from functools import partialmethod
+from mlir import ir
+from mlir.dialects import arith, func, gpu, memref, nvgpu
+from mlir.extras import types as T
+from mlir import runtime as rt
+from tools import nvgpucompiler
+
# Module-wide debug switch; the code that reads it is outside this section.
DEBUG = True
# Smallest signed 64-bit integer (-9223372036854775808).
# NOTE(review): presumably mirrors MLIR's "dynamic dimension" sentinel - confirm.
MLIR_DYNAMIC = -(2**63)
+
+
def const(value: int, ty=None):
    """Return `value` as an MLIR value, creating a constant when required.

    Args:
        value: A Python scalar to materialize through `arith.constant`, or an
            already-built `ir.Value`.
        ty: Type of the constant to create. Defaults to `T.index()`. It is
            not consulted when `value` is already an `ir.Value`.

    Returns:
        `value` itself when it is an `ir.Value`; otherwise the result of
        `arith.constant(ty, value)`.
    """
    # Existing SSA values are forwarded untouched. No check against `ty` is
    # performed: a value of any type is accepted and returned as-is.
    if isinstance(value, ir.Value):
        return value
    ty = T.index() if ty is None else ty
    return arith.constant(ty, value)
+
+
def get_type_size(ty):
    """Return the size of `ty` in bytes.

    Float and integer types report `width // 8`. A memref type reports the
    size of its element type multiplied by every entry of `ty.shape`.

    Raises:
        NotImplementedError: if `ty` is none of memref, float or integer.
    """
    if ir.MemRefType.isinstance(ty):
        # Start from the element size and scale by each dimension in order.
        return functools.reduce(
            lambda total, dim: total * dim,
            ty.shape,
            get_type_size(ty.element_type),
        )
    for scalar_kind in (ir.FloatType, ir.IntegerType):
        if scalar_kind.isinstance(ty):
            return scalar_kind(ty).width // 8
    raise NotImplementedError(ty)
+
+
def get_mlir_func_obj_ty(inputArgs):
    """Wrap Python arguments into ctypes objects, one per input argument.

    Conversion rules, applied in this order:
      * `bool`       -> one-element `ctypes.c_bool` array
      * `int`        -> one-element `ctypes.c_int` array
      * `float`      -> one-element `ctypes.c_float` array
      * `np.ndarray` -> pointer to a pointer to the ranked memref descriptor
                        returned by `rt.get_ranked_memref_descriptor`

    Returns:
        A list with the converted arguments, in input order.

    Raises:
        NotImplementedError: for an argument of any other type.
    """

    def _wrap(arg):
        # `bool` is a subclass of `int`, so it has to be tested first.
        if isinstance(arg, bool):
            return (ctypes.c_bool * 1)(arg)
        if isinstance(arg, int):
            return (ctypes.c_int * 1)(arg)
        if isinstance(arg, float):
            return (ctypes.c_float * 1)(arg)
        if isinstance(arg, np.ndarray):
            descriptor = rt.get_ranked_memref_descriptor(arg)
            return ctypes.pointer(ctypes.pointer(descriptor))
        raise NotImplementedError(arg)

    return [_wrap(arg) for arg in inputArgs]
+
+
class Mbarriers:
    """Helper that creates an `nvgpu.mbarrier.group` and emits ops on it.

    Indexing the object (`mbar[i]`) records which barrier of the group the
    next `init` / `arrive` / `try_wait` call addresses. The index is stored
    on the object itself and the same object is returned, so it stays in
    effect until the object is indexed again.
    """

    def __init__(self, number_of_barriers=1):
        """Parse the group type and create the barrier group op."""
        group_ty_str = (
            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, "
            f"num_barriers = {number_of_barriers!s}>"
        )
        self.mbar_ty = ir.Type.parse(group_ty_str)
        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
        self.number_of_barriers = number_of_barriers

    def __getitem__(self, key):
        """Select barrier `key` (converted through `const`) and return self."""
        self.id_op = const(key)
        return self

    def init(self, count: int, predicate=None):
        """Emit `nvgpu.mbarrier_init` for the selected barrier with `count`."""
        arrival_count = const(count)
        # Only forward `predicate` when the caller actually supplied one.
        optional = {} if predicate is None else {"predicate": predicate}
        nvgpu.mbarrier_init(
            self.mbar_group_op, arrival_count, self.id_op, **optional
        )

    def arrive(self, txcount: int = 0, predicate=None):
        """Emit an arrive on the selected barrier.

        A non-zero `txcount` emits `mbarrier_arrive_expect_tx` with that
        count; otherwise a plain `mbarrier_arrive` is emitted.
        """
        if txcount != 0:
            expected_tx = const(txcount)
            nvgpu.mbarrier_arrive_expect_tx(
                self.mbar_group_op, expected_tx, self.id_op, predicate=predicate
            )
            return
        nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op, predicate=predicate)

    def try_wait(self, phase: bool = False, ticks: int = 10000000):
        """Emit `MBarrierTryWaitParityOp` on the selected barrier.

        `ticks` is materialized as an index constant and `phase` as a
        `T.bool()` constant before the op is built.
        """
        tick_limit = const(ticks)
        parity = const(phase, T.bool())
        nvgpu.MBarrierTryWaitParityOp(
            self.mbar_group_op,
            parity,
            tick_limit,
            mbarId=self.id_op,
        )
+
+
class TMA:
    """A class that builds a TMA descriptor.

    Typical use: construct the object, call `create_descriptor` once to emit
    the descriptor-creation op, then call `prefetch` / `load`, which read the
    op stored by `create_descriptor`.
    """

    def __init__(
        self,
        shape,
        memref_ty,
        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
    ):
        """Record the descriptor configuration.

        Args:
            shape: Shape of the TMA box; also used as `tma_shape`.
            memref_ty: MemRefType of the source buffer. Its element type and
                memory space are reused when building types.
            swizzle, l2promo, oob, interleave: nvgpu tensormap enum values
                interpolated into the descriptor type string.
        """
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.shape = shape
        self.memref_ty = memref_ty  # MemRefType
        # Not read by any method of this class.
        self.lastDim = 64
        self.requiredLoad = 1
        self.tma_shape = shape
        self.tma_memref = ir.MemRefType.get(shape, memref_ty.element_type)

    @property
    def tensormap_descriptor_ty(self):
        """Returns a tensormap descriptor type."""
        # NOTE(review): only the first two entries of `tma_shape` are used and
        # the memory space is hard-coded to 3, so the type is always 2-D.
        memref_str = f"memref<{self.tma_shape[0]}x{self.tma_shape[1]}x{self.memref_ty.element_type}, 3>"
        parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
                                              swizzle = {self.swizzle},\
                                              l2promo = {self.l2promo},\
                                              oob = {self.oob},\
                                              interleave = {self.interleave}>"

        return ir.Type.parse(parse_str)

    def create_descriptor(self, device_ptr):
        """Emit the descriptor-creation op for `device_ptr`.

        `device_ptr` is first cast to an unranked memref with the element
        type and memory space of `memref_ty`. The created op is stored in
        `self.tma_descriptor` and its result is returned.
        """
        tma_descriptor_ty = self.tensormap_descriptor_ty
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        # Materialize the box dimensions as a list: a bare `map` object is a
        # single-use iterator and has no length.
        box_dims = [const(dim) for dim in self.tma_shape]
        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
            tma_descriptor_ty, device_unranked_memref, box_dims
        )
        return self.tma_descriptor.result

    def prefetch(self, predicate=None):
        """Emit a prefetch of the descriptor built by `create_descriptor`."""
        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)

    def load(self, dest, mbarrier: Mbarriers, coords=(0, 0), predicate=None):
        """Emit an asynchronous TMA load into `dest`.

        Args:
            dest: Destination value passed to `nvgpu.TmaAsyncLoadOp`.
            mbarrier: Barrier helper; its group op and currently selected
                barrier id are forwarded to the load.
            coords: Coordinates, each converted through `const`. The default
                is an immutable tuple so it cannot be shared and mutated
                across calls.
            predicate: Optional predicate forwarded to the op.
        """
        coord_ops = [const(c) for c in coords]
        nvgpu.TmaAsyncLoadOp(
            dest,
            mbarrier.mbar_group_op,
            self.tma_descriptor,
            coordinates=coord_ops,
            mbarId=mbarrier.id_op,
            predicate=predicate,
        )
+
+
class MatrixAccumulator:
    """Describes an M x N warpgroup accumulator with element type `ty`."""

    def __init__(self, M, N, ty):
        """Store the accumulator dimensions and element type."""
        self.M, self.N, self.ty = M, N, ty

    @property
    def acc_ty(self):
        """Parse and return the `!nvgpu.warpgroup.accumulator` type."""
        fragment = f"vector<{self.M!s}x{self.N!s}x{self.ty!s}>"
        return ir.Type.parse(
            f"!nvgpu.warpgroup.accumulator<fragmented={fragment}>"
        )

    def op(self):
        """Emit the accumulator-initialization op and return what it yields."""
        accumulator_ty = self.acc_ty
        return nvgpu.warpgroup_mma_init_accumulator(accumulator_ty)
+
+
class Matrix:
    """Pairs a shared-memory buffer with its TMA helper for warpgroup MMA."""

    def __init__(self, smem, tma_descriptor: TMA, M, N):
        """Store the buffer, the TMA helper and the M x N tile dimensions."""
        self.tma_descriptor = tma_descriptor
        self.smem = smem
        self.M, self.N = M, N

    @property
    def wgmma_ty(self):
        """Parse and return the `!nvgpu.warpgroup.descriptor` type.

        The element type is taken from the TMA helper's `memref_ty`.
        """
        element_ty = self.tma_descriptor.memref_ty.element_type
        tensor = (
            f"memref<{self.M!s}x{self.N!s}x{element_ty!s}"
            ", #gpu.address_space<workgroup>>"
        )
        return ir.Type.parse(f"!nvgpu.warpgroup.descriptor<tensor={tensor}>")

    def matmul(lhs, rhs, acc):
        """Emit a warpgroup MMA of `lhs` and `rhs` accumulating into `acc`.

        `lhs` takes the place of `self`. A warpgroup descriptor is generated
        for each operand (left first, then right) and the resulting
        `nvgpu.WarpgroupMmaOp`, built with `transposeB=True`, is returned.
        """
        wgmma_descs = [
            nvgpu.warpgroup_generate_descriptor(
                operand.wgmma_ty,
                operand.smem,
                operand.tma_descriptor.tma_descriptor,
            )
            for operand in (lhs, rhs)
        ]
        desc_lhs, desc_rhs = wgmma_descs
        return nvgpu.WarpgroupMmaOp(
            acc.type, desc_lhs, desc_rhs, acc, transposeB=True
        )
----------------
grypp wrote:
> We could integrate all of this code into `mlir/python/mlir/dialects/nvgpu.py`, if it aligns with the overall structure and purpose.
If it is possible to do this, we can also use these classes in integration tests. (An example PR is here https://github.com/llvm/llvm-project/pull/87153)
https://github.com/llvm/llvm-project/pull/87065
More information about the Mlir-commits
mailing list