[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)

Vinod Grover llvmlistbot at llvm.org
Mon Feb 12 14:51:32 PST 2024


================
@@ -0,0 +1,266 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test validates the correctness of the supported GEMM kernels in the
+# NVGPU dialect; currently the Multistage and Warp Specialization kernels are
+# supported.
+# The test constructs and metaprograms IR using Python bindings, allowing
+# generic IR building. This flexibility enables changes to the shape,
+# tile size, or data type of the GEMM for testing purposes.
+# The entry function is `matmul`, where one can specify GEMM shape, tile size,
+# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
+# number of stages.
+# Verification is done via numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16,                # input types
+#        output_type=np.float32,               # output type
+#        M=4096, N=4096, K=4096,               # Shape
+#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+#        use_warp_specialization=True,         # Enable Warp Specialization
+#        max_num_stages=3)                     # Number of stages in shared memory
+#
+# ===--- Parallelism Across CTAs  ---===
+#
+# GEMM includes three loops defining the shape of the GEMM, specified in the
+# `matmul` function.
+# The program builds IR using the following loop structure, tiling the loops
+# with the given tile size and parallelizing the two outermost loops into the
+# first and second dimensions of CTAs.
+#
+# for(bi = 0; bi < M; bi += BLOCK_M)          # parallelize across blockIdx.x
+#     for(bj = 0; bj < N; bj += BLOCK_N)      # parallelize across blockIdx.y
+#         for(bk = 0; bk < K; bk += BLOCK_K)
+#             for(i = bi; i < (bi + BLOCK_M); ++i)
+#                 for(j = bj; j < (bj + BLOCK_N); ++j)
+#                     for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA load requests. The threads collectively wait for the
+# data and perform the mma operations. After the main loop completes, the
+# threads together first store the fragmented registers to shared memory, then
+# copy from shared memory to global memory; this part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+#
+# ===--- Warp Specialization Kernel  ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as the `producer warp group` and the other as the `consumer warp group`. The
+# `producer warp group` is responsible for issuing TMA loads, while the
+# `consumer warp group` performs the mma operations. The epilogue is handled
+# by the `consumer warp group`, since its threads own the fragmented registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
+# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+import ctypes
+from mlir import runtime as rt
+
+
+def generate_matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+        if use_warp_specilization:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
+        else:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
+
+        mlir_nvgpu_module.operation.verify()
+
+        # Save generated IR
+        if saveIR:
+            # print(mlir_nvgpu_module)
+            original_stdout = sys.stdout
+            with open("gemm.mlir", "w") as f:
+                sys.stdout = f
+                print(mlir_nvgpu_module)
+                sys.stdout = original_stdout
+
+        # Get compiler
+        options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+        support_lib = os.getenv("SUPPORT_LIB")
+        if not os.path.exists(support_lib):
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+            )
+        compiler = nvgpucompiler.NvgpuCompiler(
+            options, opt_level=3, shared_libs=[support_lib]
+        )
+
+        # Compile
+        engine = compiler.compile_and_jit(mlir_nvgpu_module)
+        return engine
+
+
+def matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=128,
+    N=128,
+    K=128,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+    print_results=False,
+    no_verify=False,
+):
+    # Print the configuration
+    ity = "f16" if input_type == np.float16 else "f32"
+    oty = "f16" if output_type == np.float16 else "f32"
+    gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
+    print(
+        "===-- Running GEMM "
+        + gemmty
+        + " "
+        + oty
+        + " += "
+        + ity
+        + " * "
+        + ity
+        + ", Size "
+        + str(M)
+        + "x"
+        + str(N)
+        + "x"
+        + str(K)
+        + ", Tile "
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(BLOCK_K)
+        + ", stages "
+        + str(max_num_stages)
+        + " --==="
+    )
+
+    # Build IR and compile
+    engine = generate_matmul(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        use_warp_specilization,
----------------
vinodgro wrote:

spelling? "specilization" ="specialization" ?

https://github.com/llvm/llvm-project/pull/81478


More information about the Mlir-commits mailing list