[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 12 05:34:46 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

This test aims to validate the correctness of the supported GEMM kernels in NVGPU dialects, with current support for Multistage and Warp Specialization kernels.
The test constructs and metaprograms IR using Python bindings, allowing generic IR building. This flexibility enables changes to the shape, tile size, or data type of the GEMM for testing purposes. The entry function is `matmul`, where one can specify GEMM shape, tile size, data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum number of stages.
Verification is done via numpy's matmul operation.

Example:
```
matmul(input_type=np.float16,                # input types
       output_type=np.float32,               # output type
       M=4096, N=4096, K=4096,               # Shape
       BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
       use_warp_specialization=True,         # Enable Warp Specialization
       max_num_stages=3)                     # Number of stages in shared memory
```
### Parallelism Across CTAs

GEMM includes three loops defining the shape of the GEMM, specified in the `matmul` function.
The program builds IR using the following loop structure, tiling the loops with the given tile size and parallelizing the two outermost loops into the first and second dimensions of CTAs.
```
for(bi = 0; i < M; i += BLOCK_M)          # parallelize across blockIdx.x
    for(bj = 0; j < N; j += BLOCK_N)      # parallelize across blockIdx.y
        for(bk = 0; k < K; K += BLOCK_K)
            for(i = bi; i < (bi + BLOCK_M); ++i)
                for(j = bj; j < (bj + BLOCK_N); ++j)
                    for(k = bk; k < (bk + BLOCK_K); ++k)
```

## Multistage Kernel

This kernel launches a single warp group (128 threads). The primary thread (pthread) requests load from TMA. Threads collectively wait for the data and perform mma operations. After completing the shape, threads together store first fragmented registers to shared memory, then from shared memory to global memory; this part is called the epilogue.

Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
```

## Warp Specialization Kernel

This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one as `producer warp group` and another as `consumer warp group`. The `producer warp group` is responsible for requesting TMA load, while the `consumer warp group` performs the mma operation. The epilogue section is handled by the `consumer warp group` as its threads own the fragmented registers.

Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
```

---

Patch is 49.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81478.diff


5 Files Affected:

- (added) mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg (+2) 
- (added) mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py (+186) 
- (added) mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg (+3) 
- (added) mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py (+676) 
- (added) mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py (+44) 


``````````diff
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+    config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..e153fcb44b9860
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,186 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test aims to validate the correctness of the supported GEMM kernels in
+# NVGPU dialects, with current support for Multistage and Warp Specialization
+# kernels.
+# The test constructs and metaprograms IR using Python bindings, allowing
+# generic IR building. This flexibility enables changes to the shape,
+# tile size, or data type of the GEMM for testing purposes.
+# The entry function is `matmul`, where one can specify GEMM shape, tile size,
+# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
+# number of stages.
+# Verification is done via numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16,                # input types
+#        output_type=np.float32,               # output type
+#        M=4096, N=4096, K=4096,               # Shape
+#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+#        use_warp_specialization=True,         # Enable Warp Specialization
+#        max_num_stages=3)                     # Number of stages in shared memory
+#
+# ===--- Parallelism Across CTAs  ---===
+#
+# GEMM includes three loops defining the shape of the GEMM, specified in the
+# `matmul` function.
+# The program builds IR using the following loop structure, tiling the loops
+# with the given tile size and parallelizing the two outermost loops into the
+# first and second dimensions of CTAs.
+#
+# for(bi = 0; i < M; i += BLOCK_M)          # parallelize across blockIdx.x
+#     for(bj = 0; j < N; j += BLOCK_N)      # parallelize across blockIdx.y
+#         for(bk = 0; k < K; K += BLOCK_K)
+#             for(i = bi; i < (bi + BLOCK_M); ++i)
+#                 for(j = bj; j < (bj + BLOCK_N); ++j)
+#                     for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) requests load from TMA. Threads collectively wait for the data and
+# perform mma operations. After completing the shape, threads together store
+# first fragmented registers to shared memory, then from shared memory to global
+# memory; this part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+#
+# ===--- Warp Specialization Kernel  ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as `producer warp group` and another as `consumer warp group`. The
+# `producer warp group` is responsible for requesting TMA load, while the
+# `consumer warp group` performs the mma operation. The epilogue section is
+# handled by the `consumer warp group` as its threads own the fragmented registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
+# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+import ctypes
+from mlir import runtime as rt
+
+def generate_matmul(input_type=np.float16,
+                    output_type=np.float32,
+                    M=4096,
+                    N=4096,
+                    K=4096,
+                    BLOCK_M=128,
+                    BLOCK_N=128,
+                    BLOCK_K=64,
+                    use_warp_specilization=True,
+                    saveIR=False,
+                    max_num_stages=3):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+        if use_warp_specilization:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
+                                                                 BLOCK_K, max_num_stages)
+        else:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
+                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+
+        mlir_nvgpu_module.operation.verify()
+        
+        # Save generated IR
+        if saveIR:
+            # print(mlir_nvgpu_module)
+            original_stdout = sys.stdout
+            with open('gemm.mlir', 'w') as f:
+                sys.stdout = f
+                print(mlir_nvgpu_module)
+                sys.stdout = original_stdout
+
+        # Get compiler
+        options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+        support_lib = os.getenv("SUPPORT_LIB")
+        if not os.path.exists(support_lib):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+
+        # Compile
+        engine = compiler.compile_and_jit(mlir_nvgpu_module)
+        return engine
+
+
+def matmul(input_type=np.float16,
+           output_type=np.float32,
+           M=128,
+           N=128,
+           K=128,
+           BLOCK_M=128,
+           BLOCK_N=128,
+           BLOCK_K=64,
+           use_warp_specilization=True,
+           saveIR=False,
+           max_num_stages=3,
+           print_results=False,
+           no_verify=False):
+    # Print the configuration
+    ity = "f16" if input_type == np.float16 else "f32"
+    oty = "f16" if output_type == np.float16 else "f32"
+    gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
+          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
+          str(max_num_stages) + " --===")
+
+    # Build IR and compile
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
+                             saveIR, max_num_stages)
+
+    # Allocate matrices and invoke the matmul
+    c = np.zeros((M, N), output_type)
+    a = np.random.randn(M, K).astype(input_type)
+    b = np.random.randn(K, N).astype(input_type)
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+
+    # Launch the MLIR generated kernel
+    engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+    float_formatter = "{:.2f}".format
+    np.set_printoptions(formatter={'float_kind': float_formatter})
+
+    if print_results:
+        print(c)
+
+    # Verify the results
+    if not no_verify:
+        ref = a.astype(input_type) @ b.astype(input_type)
+        if print_results:
+            print(ref)
+        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+    print("PASS ")
+
+
+# GEMM Multistage       f32 += f16 * f16
+matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+# GEMM Warp Specilized  f32 += f16 * f16
+matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..09ab35d4c5f15e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,676 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+
+TMA_LAST_DIM_F16 = 64  # 128B flaot16
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808
+f16_byte = 2
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty_format = None
+        if ir.IndexType.isinstance(arg.type):
+            ty_format = "%llu"
+        if ir.IntegerType.isinstance(arg.type):
+            width = ir.IntegerType(arg.type).width
+            if width == 64:
+                ty_format = "%llu"
+            elif width == 32:
+                ty_format = "%d"
+            elif width == 1:
+                ty_format = "%i"
+        if ir.F32Type.isinstance(arg.type):
+            ty_format = "%f"
+        if ty_format is None:
+            raise NotImplementedError(arg.type)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+        scf.yield_([])
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(fmt.format(*type_formats) + "\n", args)
+        scf.yield_([])
+
+
+def c(value, ty=None):
+    ty = ir.IndexType.get() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def generate_matmul_ws(input_type=np.float16,
+                       output_type=np.float32,
+                       M=4096,
+                       N=4096,
+                       K=4096,
+                       BLOCK_M=128,
+                       BLOCK_N=128,
+                       BLOCK_K=128,
+                       max_num_stages=3):
+    # Limitaitons for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+            
+            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE * 2, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. This is need for vectorized ld/st
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+
+                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups and etc.
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                warp_id = arith.divui(tidx, c(32))
+                warpgroup_id = arith.divui(warp_id, c(4))
+                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
+                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 2. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                gpu.barrier()
+
+                # GPU Step 3. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+                
+                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+                    
+                    # Step 5.1. Reduce register size
+                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+
+                    # Step 5.2. TMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    with ir.InsertionPoint(for_op.body):
+                        phaseParity = for_op.inner_iter_args[0]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+
+                        # Step 5.2.1. Wait mbarDONE
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                                    iv,
+                                    stage,
+            ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81478


More information about the Mlir-commits mailing list