[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)

Guray Ozen llvmlistbot at llvm.org
Mon Feb 12 23:42:19 PST 2024


================
@@ -0,0 +1,1078 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+TMA_LAST_DIM_F16 = 64  # 64 float16 elements = 128 bytes
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808  # ShapedType::kDynamic (INT64_MIN)
+f16_byte = 2
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty_format = None
+        if ir.IndexType.isinstance(arg.type):
+            ty_format = "%llu"
+        if ir.IntegerType.isinstance(arg.type):
+            width = ir.IntegerType(arg.type).width
+            if width == 64:
+                ty_format = "%llu"
+            elif width == 32:
+                ty_format = "%d"
+            elif width == 1:
+                ty_format = "%i"
+        if ir.F32Type.isinstance(arg.type):
+            ty_format = "%f"
+        if ty_format is None:
+            raise NotImplementedError(arg.type)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(fmt.format(*type_formats) + "\n", args)
+        scf.yield_([])
+
+
+def c(value, ty=None):
+    ty = ir.IndexType.get() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def generate_matmul_ws(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=128,
+    max_num_stages=3,
+):
+    # Limitations for now
----------------
grypp wrote:

Thanks for your review and the catch!

I need to call the driver for that (each GPU has a different shared memory size these days), but I'm still figuring out how to do it without adding extra dependencies to this code.
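
A minimal sketch of what that driver query could look like without pulling in a new Python dependency, using `ctypes` against `libcuda` directly. The library name `libcuda.so` (Linux) and the helper name `max_dynamic_smem` are assumptions for illustration, and error checking is elided; the attribute value 97 corresponds to `CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN` in `cuda.h`:

```python
import ctypes

# CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN from cuda.h.
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97


def max_dynamic_smem(device_ordinal=0):
    # Hypothetical helper: load and initialize the CUDA driver.
    cuda = ctypes.CDLL("libcuda.so")
    cuda.cuInit(0)
    # CUdevice is an int handle in the driver API.
    dev = ctypes.c_int()
    cuda.cuDeviceGet(ctypes.byref(dev), device_ordinal)
    # Query the opt-in maximum dynamic shared memory per block, in bytes
    # (e.g. 232448 = 227 KiB on H100 / sm_90).
    value = ctypes.c_int()
    cuda.cuDeviceGetAttribute(
        ctypes.byref(value),
        CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        dev,
    )
    return value.value
```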

In any case, we check the maximum shared memory at kernel launch in MLIR's runtime, so one will receive an error message there at the latest:
https://github.com/llvm/llvm-project/blob/main/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp#L176-L183

https://github.com/llvm/llvm-project/pull/81478

