[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)
Vinod Grover
llvmlistbot at llvm.org
Mon Feb 12 14:51:32 PST 2024
================
@@ -0,0 +1,266 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test aims to validate the correctness of the GEMM kernels supported by
+# the NVGPU dialect, currently covering the Multistage and Warp Specialization
+# kernels.
+# The test constructs and metaprograms the IR through the MLIR Python bindings,
+# allowing generic IR building. This flexibility makes it easy to change the
+# shape, tile size, or data type of the GEMM for testing purposes.
+# The entry function is `matmul`, where one can specify the GEMM shape, tile
+# size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
+# maximum number of stages.
+# The results are verified against numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16, # input types
+# output_type=np.float32, # output type
+# M=4096, N=4096, K=4096, # Shape
+# BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+# use_warp_specialization=True, # Enable Warp Specialization
+# max_num_stages=3) # Number of stages in shared memory
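+#
+# Conceptually, the numpy verification amounts to the check sketched below
+# (buffer names and tolerances here are illustrative, not the test's actual
+# values):
+#
+#   ref = np.matmul(a.astype(output_type), b.astype(output_type))
+#   assert np.allclose(c, ref, rtol=1e-2, atol=1e-2)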
+#
+# ===--- Parallelism Across CTAs ---===
+#
+# The GEMM is defined by three loops over M, N, and K, whose extents are
+# specified in the `matmul` function.
+# The program builds the IR with the loop structure below, tiling the loops
+# with the given tile sizes and parallelizing the two outermost loops across
+# the first and second dimensions of the CTA grid.
+#
+#   for(bi = 0; bi < M; bi += BLOCK_M)          # parallelize across blockIdx.x
+#     for(bj = 0; bj < N; bj += BLOCK_N)        # parallelize across blockIdx.y
+#       for(bk = 0; bk < K; bk += BLOCK_K)
+#         for(i = bi; i < (bi + BLOCK_M); ++i)
+#           for(j = bj; j < (bj + BLOCK_N); ++j)
+#             for(k = bk; k < (bk + BLOCK_K); ++k)
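+#
+# The CTA grid implied by this tiling can be sketched as follows (illustrative
+# pseudocode, not the builder's actual variable names):
+#
+#   grid_x = (M + BLOCK_M - 1) // BLOCK_M       # one CTA per BLOCK_M rows
+#   grid_y = (N + BLOCK_N - 1) // BLOCK_N       # one CTA per BLOCK_N columns
+#   grid   = (grid_x, grid_y, 1)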
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA load requests. All threads collectively wait for the
+# data and perform the mma operations. Once the full GEMM shape has been
+# computed, the threads together store the fragmented accumulator registers
+# first to shared memory and then from shared memory to global memory; this
+# part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# | |Prologue ----> |MainLoop ----> |Epilogue |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2] |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........ |[wait-0][mma] |[wait-1][mma] |[wait-2][mma] | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
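+#
+# A rough pseudocode view of this pipeline, as a sketch only (the real kernel
+# tracks stages with mbarriers and the details are more involved):
+#
+#   num_tiles = K // BLOCK_K
+#   for s in range(num_stages):                  # prologue: prefetch stages
+#     tma_load(stage=s, tile=s)
+#   for t in range(num_tiles):                   # main loop
+#     s = t % num_stages
+#     wait(stage=s)                              # wait for the stage's data
+#     mma(stage=s)
+#     if t + num_stages < num_tiles:             # refill the freed stage
+#       tma_load(stage=s, tile=t + num_stages)
+#   store_regs_to_shared_then_global()           # epilogue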
+#
+# ===--- Warp Specialization Kernel ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as the `producer warp group` and the other as the `consumer warp group`. The
+# `producer warp group` is responsible for issuing the TMA loads, while the
+# `consumer warp group` performs the mma operations. The epilogue section is
+# handled by the `consumer warp group`, as its threads own the fragmented
+# registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# | |MainLoop ----> | 1st Epilogue | 2nd Epilogue |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ........... | [shmem->global] |
+# |wgroup1 | .......| | | | | | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
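+#
+# A rough pseudocode view of the two warp groups, as a sketch only (the real
+# kernel synchronizes the groups with mbarriers rather than these helpers):
+#
+#   num_tiles = K // BLOCK_K
+#   if warp_group == PRODUCER:                   # 128 threads
+#     for t in range(num_tiles):
+#       wait_for_free_stage(t % num_stages)
+#       tma_load(stage=t % num_stages, tile=t)
+#   else:                                        # consumer, 128 threads
+#     for t in range(num_tiles):
+#       wait_for_data(t % num_stages)
+#       mma(stage=t % num_stages)
+#     store_regs_to_shared()                     # 1st epilogue
+#     store_shared_to_global()                   # 2nd epilogue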
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+from mlir import runtime as rt
+
+
+def generate_matmul(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3,
+):
+ with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+ if use_warp_specilization:
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ max_num_stages,
+ )
+ else:
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ max_num_stages,
+ )
+
+ mlir_nvgpu_module.operation.verify()
+
+        # Save the generated IR to a file.
+        if saveIR:
+            with open("gemm.mlir", "w") as f:
+                print(mlir_nvgpu_module, file=f)
+
+ # Get compiler
+        options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+ support_lib = os.getenv("SUPPORT_LIB")
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+ )
+ compiler = nvgpucompiler.NvgpuCompiler(
+ options, opt_level=3, shared_libs=[support_lib]
+ )
+
+ # Compile
+ engine = compiler.compile_and_jit(mlir_nvgpu_module)
+ return engine
+
+
+def matmul(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=128,
+ N=128,
+ K=128,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3,
+ print_results=False,
+ no_verify=False,
+):
+ # Print the configuration
+ ity = "f16" if input_type == np.float16 else "f32"
+ oty = "f16" if output_type == np.float16 else "f32"
+ gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
+    print(
+        f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
+        f"Size {M}x{N}x{K}, Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
+        f"stages {max_num_stages} --==="
+    )
+
+ # Build IR and compile
+ engine = generate_matmul(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ use_warp_specilization,
----------------
vinodgro wrote:
spelling? "specilization" ="specialization" ?
https://github.com/llvm/llvm-project/pull/81478
More information about the Mlir-commits mailing list