[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)

Guray Ozen llvmlistbot at llvm.org
Mon Feb 12 05:34:18 PST 2024


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/81478

This test aims to validate the correctness of the supported GEMM kernels in NVGPU dialects, with current support for Multistage and Warp Specialization kernels.
The test constructs and metaprograms IR using Python bindings, allowing generic IR building. This flexibility enables changes to the shape, tile size, or data type of the GEMM for testing purposes. The entry function is `matmul`, where one can specify GEMM shape, tile size, data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum number of stages.
Verification is done via numpy's matmul operation.

Example:
```
matmul(input_type=np.float16,                # input types
       output_type=np.float32,               # output type
       M=4096, N=4096, K=4096,               # Shape
       BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
       use_warp_specialization=True,         # Enable Warp Specialization
       max_num_stages=3)                     # Number of stages in shared memory
```
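
The verification step can be illustrated on the host alone. The snippet below is a minimal sketch, not the test itself: `verify` is a hypothetical helper name, the tolerances are the ones `matmul.py` passes to `np.testing.assert_allclose`, and the "kernel result" is simulated with a float32 numpy matmul instead of a real GPU launch.

```python
import numpy as np


def verify(c, a, b, rtol=5e-03, atol=1e-01):
    """Hypothetical helper: compare a kernel result `c` with numpy's matmul."""
    ref = a @ b  # reference product computed by numpy
    np.testing.assert_allclose(c, ref, rtol=rtol, atol=atol)
    return True


rng = np.random.default_rng(0)
a = rng.standard_normal((128, 64)).astype(np.float16)
b = rng.standard_normal((64, 128)).astype(np.float16)

# Stand-in for the GPU output: accumulate in float32, as an
# f32 += f16 * f16 kernel would.
c = a.astype(np.float32) @ b.astype(np.float32)
ok = verify(c, a, b)
```

The loose absolute tolerance absorbs the rounding difference between a float16 reference and a float32 accumulator.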
## Parallelism Across CTAs

GEMM includes three loops defining the shape of the GEMM, specified in the `matmul` function.
The program builds IR using the following loop structure, tiling the loops with the given tile size and parallelizing the two outermost loops into the first and second dimensions of CTAs.
```
for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
    for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
        for(bk = 0; bk < K; bk += BLOCK_K)
            for(i = bi; i < (bi + BLOCK_M); ++i)
                for(j = bj; j < (bj + BLOCK_N); ++j)
                    for(k = bk; k < (bk + BLOCK_K); ++k)
```
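
The loop structure above can be emulated on the host to see how the two outer loops map to a CTA grid. This is only a sketch under stated assumptions: `tiled_gemm` is a hypothetical name, each `(bx, by)` pair stands in for one CTA, and the three innermost loops are collapsed into one numpy tile product.

```python
import numpy as np


def tiled_gemm(a, b, BLOCK_M, BLOCK_N, BLOCK_K):
    """Hypothetical host-side emulation of the tiled GEMM loop nest."""
    M, K = a.shape
    _, N = b.shape
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0
    c = np.zeros((M, N), dtype=np.float32)
    grid = (M // BLOCK_M, N // BLOCK_N, 1)  # one CTA per output tile
    for bx in range(grid[0]):               # blockIdx.x
        for by in range(grid[1]):           # blockIdx.y
            bi, bj = bx * BLOCK_M, by * BLOCK_N
            for bk in range(0, K, BLOCK_K):  # sequential loop inside a CTA
                a_tile = a[bi:bi + BLOCK_M, bk:bk + BLOCK_K].astype(np.float32)
                b_tile = b[bk:bk + BLOCK_K, bj:bj + BLOCK_N].astype(np.float32)
                # The three innermost loops, collapsed into one tile product.
                c[bi:bi + BLOCK_M, bj:bj + BLOCK_N] += a_tile @ b_tile
    return c, grid


rng = np.random.default_rng(0)
a = rng.integers(-2, 3, size=(8, 8)).astype(np.float16)
b = rng.integers(-2, 3, size=(8, 8)).astype(np.float16)
c, grid = tiled_gemm(a, b, BLOCK_M=4, BLOCK_N=4, BLOCK_K=2)
```

Only the `bk` loop stays sequential inside a CTA, which is why the kernels below pipeline it across shared-memory stages.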

## Multistage Kernel

This kernel launches a single warp group (128 threads). The primary thread (pthread) issues the TMA load requests. All threads then wait collectively for the data and perform the mma operations. Once the loop over K finishes, the threads jointly store the fragmented registers to shared memory and then copy from shared memory to global memory; this part is called the epilogue.

Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
```
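
The slot rotation in the main-loop cells of the timeline can be written out as a small sketch. The function name and the string format are hypothetical; the slot indices are simply read off the diagram (wait on slot `i % num_stages`, then refill the slot that was consumed one iteration earlier). The prologue and epilogue are not modeled.

```python
def multistage_main_loop(num_iters, num_stages):
    """Hypothetical trace of the primary thread's main-loop events."""
    events = []
    for i in range(num_iters):
        wait_slot = i % num_stages
        # Refill the slot that was consumed in the previous iteration.
        refill_slot = (i + num_stages - 1) % num_stages
        events.append(f"[wait-{wait_slot}][mma][tma-{refill_slot}]")
    return events


trace = multistage_main_loop(num_iters=3, num_stages=3)
print(" ".join(trace))
```

With 3 stages, the first three iterations reproduce the three main-loop cells of the pthread row.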

## Warp Specialization Kernel

This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one as the `producer warp group` and the other as the `consumer warp group`. The `producer warp group` is responsible for requesting TMA loads, while the `consumer warp group` performs the mma operations. The epilogue is handled by the `consumer warp group`, as its threads own the fragmented registers.

Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
```
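
Two pieces of host-side arithmetic behind this kernel can be sketched in plain Python. The formulas mirror the ones in `matmulBuilder.py` (shared-memory size, stage index, and phase-parity flip in the producer loop); the function names `smem_bytes` and `producer_schedule` are hypothetical and exist only for this illustration.

```python
F16_BYTES = 2
F32_BYTES = 4


def smem_bytes(BLOCK_M, BLOCK_N, BLOCK_K, num_stages):
    """Dynamic shared memory: staged input tiles or the f32 output tile."""
    lhs_tile_bytes = BLOCK_M * BLOCK_K * F16_BYTES
    rhs_tile_bytes = BLOCK_N * BLOCK_K * F16_BYTES
    smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
    smem_size_output = BLOCK_M * BLOCK_N * F32_BYTES
    # The epilogue reuses the input buffers, so the larger size is enough.
    return max(smem_size_input, smem_size_output)


def producer_schedule(num_iters, num_stages, initial_parity=1):
    """Return (iteration, stage, phase parity) for each producer iteration."""
    schedule = []
    parity = initial_parity
    for iv in range(num_iters):
        stage = iv % num_stages
        schedule.append((iv, stage, parity))
        # The parity flips after the last stage of the ring has been used.
        if stage == num_stages - 1:
            parity ^= 1
    return schedule


size = smem_bytes(BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, num_stages=3)
schedule = producer_schedule(num_iters=4, num_stages=2)
```

For the 2-stage timeline above, the stage column of `schedule` cycles 0, 1, 0, 1, matching the `tma-0`, `tma-1` sequence of the producer row.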

From a9b4035aa421506302e74c6f6ff19a0aec62b735 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:33:26 +0000
Subject: [PATCH] [mlir] GEMM Hopper Tensor Core Integration Test

This test aims to validate the correctness of the supported GEMM kernels in
NVGPU dialects, with current support for Multistage and Warp Specialization
kernels.
The test constructs and metaprograms IR using Python bindings, allowing
generic IR building. This flexibility enables changes to the shape,
tile size, or data type of the GEMM for testing purposes.
The entry function is `matmul`, where one can specify GEMM shape, tile size,
data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
number of stages.
Verification is done via numpy's matmul operation.

Example:
```
matmul(input_type=np.float16,                # input types
       output_type=np.float32,               # output type
       M=4096, N=4096, K=4096,               # Shape
       BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
       use_warp_specialization=True,         # Enable Warp Specialization
       max_num_stages=3)                     # Number of stages in shared memory
```
### Parallelism Across CTAs

GEMM includes three loops defining the shape of the GEMM, specified in the
`matmul` function.
The program builds IR using the following loop structure, tiling the loops
with the given tile size and parallelizing the two outermost loops into the
first and second dimensions of CTAs.
```
for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
    for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
        for(bk = 0; bk < K; bk += BLOCK_K)
            for(i = bi; i < (bi + BLOCK_M); ++i)
                for(j = bj; j < (bj + BLOCK_N); ++j)
                    for(k = bk; k < (bk + BLOCK_K); ++k)
```

## Multistage Kernel

This kernel launches a single warp group (128 threads). The primary thread
(pthread) issues the TMA loads. All threads wait collectively for the data
and perform the mma operations. Once the K loop finishes, the threads store
the fragmented registers to shared memory and then copy shared memory to
global memory; this part is called the epilogue.

Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
```

## Warp Specialization Kernel

This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
as the `producer warp group` and the other as the `consumer warp group`. The
`producer warp group` is responsible for requesting TMA loads, while the
`consumer warp group` performs the mma operations. The epilogue is handled
by the `consumer warp group`, as its threads own the fragmented registers.

Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
```
---
 .../GPU/CUDA/sm90/python/lit.local.cfg        |   2 +
 .../GPU/CUDA/sm90/python/matmul.py            | 186 +++++
 .../GPU/CUDA/sm90/python/tools/lit.local.cfg  |   3 +
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 676 ++++++++++++++++++
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  44 ++
 5 files changed, 911 insertions(+)
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+    config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..e153fcb44b9860
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,186 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test aims to validate the correctness of the supported GEMM kernels in
+# NVGPU dialects, with current support for Multistage and Warp Specialization
+# kernels.
+# The test constructs and metaprograms IR using Python bindings, allowing
+# generic IR building. This flexibility enables changes to the shape,
+# tile size, or data type of the GEMM for testing purposes.
+# The entry function is `matmul`, where one can specify GEMM shape, tile size,
+# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
+# number of stages.
+# Verification is done via numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16,                # input types
+#        output_type=np.float32,               # output type
+#        M=4096, N=4096, K=4096,               # Shape
+#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+#        use_warp_specialization=True,         # Enable Warp Specialization
+#        max_num_stages=3)                     # Number of stages in shared memory
+#
+# ===--- Parallelism Across CTAs  ---===
+#
+# GEMM includes three loops defining the shape of the GEMM, specified in the
+# `matmul` function.
+# The program builds IR using the following loop structure, tiling the loops
+# with the given tile size and parallelizing the two outermost loops into the
+# first and second dimensions of CTAs.
+#
+# for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
+#     for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
+#         for(bk = 0; bk < K; bk += BLOCK_K)
+#             for(i = bi; i < (bi + BLOCK_M); ++i)
+#                 for(j = bj; j < (bj + BLOCK_N); ++j)
+#                     for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA loads. All threads wait collectively for the data
+# and perform the mma operations. Once the K loop finishes, the threads store
+# the fragmented registers to shared memory and then copy shared memory to
+# global memory; this part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+#
+# ===--- Warp Specialization Kernel  ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as the `producer warp group` and the other as the `consumer warp group`. The
+# `producer warp group` is responsible for requesting TMA loads, while the
+# `consumer warp group` performs the mma operations. The epilogue is handled
+# by the `consumer warp group`, as its threads own the fragmented registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
+# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+import ctypes
+from mlir import runtime as rt
+
+def generate_matmul(input_type=np.float16,
+                    output_type=np.float32,
+                    M=4096,
+                    N=4096,
+                    K=4096,
+                    BLOCK_M=128,
+                    BLOCK_N=128,
+                    BLOCK_K=64,
+                    use_warp_specilization=True,
+                    saveIR=False,
+                    max_num_stages=3):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+        if use_warp_specilization:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
+                                                                 BLOCK_K, max_num_stages)
+        else:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
+                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+
+        mlir_nvgpu_module.operation.verify()
+        
+        # Save generated IR
+        if saveIR:
+            # print(mlir_nvgpu_module)
+            original_stdout = sys.stdout
+            with open('gemm.mlir', 'w') as f:
+                sys.stdout = f
+                print(mlir_nvgpu_module)
+                sys.stdout = original_stdout
+
+        # Get compiler
+        options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+        support_lib = os.getenv("SUPPORT_LIB")
+        if support_lib is None or not os.path.exists(support_lib):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+
+        # Compile
+        engine = compiler.compile_and_jit(mlir_nvgpu_module)
+        return engine
+
+
+def matmul(input_type=np.float16,
+           output_type=np.float32,
+           M=128,
+           N=128,
+           K=128,
+           BLOCK_M=128,
+           BLOCK_N=128,
+           BLOCK_K=64,
+           use_warp_specilization=True,
+           saveIR=False,
+           max_num_stages=3,
+           print_results=False,
+           no_verify=False):
+    # Print the configuration
+    ity = "f16" if input_type == np.float16 else "f32"
+    oty = "f16" if output_type == np.float16 else "f32"
+    gemmty = "Warp Specialization" if use_warp_specilization else "Multistage"
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
+          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
+          str(max_num_stages) + " --===")
+
+    # Build IR and compile
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
+                             saveIR, max_num_stages)
+
+    # Allocate matrices and invoke the matmul
+    c = np.zeros((M, N), output_type)
+    a = np.random.randn(M, K).astype(input_type)
+    b = np.random.randn(K, N).astype(input_type)
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+
+    # Launch the MLIR generated kernel
+    engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+    float_formatter = "{:.2f}".format
+    np.set_printoptions(formatter={'float_kind': float_formatter})
+
+    if print_results:
+        print(c)
+
+    # Verify the results
+    if not no_verify:
+        ref = a.astype(input_type) @ b.astype(input_type)
+        if print_results:
+            print(ref)
+        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+    print("PASS ")
+
+
+# GEMM Multistage       f32 += f16 * f16
+matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+# GEMM Warp Specialized f32 += f16 * f16
+matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..09ab35d4c5f15e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,676 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+
+TMA_LAST_DIM_F16 = 64  # 128B of float16
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808
+f16_byte = 2
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty_format = None
+        if ir.IndexType.isinstance(arg.type):
+            ty_format = "%llu"
+        if ir.IntegerType.isinstance(arg.type):
+            width = ir.IntegerType(arg.type).width
+            if width == 64:
+                ty_format = "%llu"
+            elif width == 32:
+                ty_format = "%d"
+            elif width == 1:
+                ty_format = "%i"
+        if ir.F32Type.isinstance(arg.type):
+            ty_format = "%f"
+        if ty_format is None:
+            raise NotImplementedError(arg.type)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        # Restrict printing to the requested thread.
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(fmt.format(*type_formats) + "\n", args)
+        scf.yield_([])
+
+
+def c(value, ty=None):
+    ty = ir.IndexType.get() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def generate_matmul_ws(input_type=np.float16,
+                       output_type=np.float32,
+                       M=4096,
+                       N=4096,
+                       K=4096,
+                       BLOCK_M=128,
+                       BLOCK_N=128,
+                       BLOCK_K=128,
+                       max_num_stages=3):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+            
+            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE * 2, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. This is needed for vectorized ld/st
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+
+                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                warp_id = arith.divui(tidx, c(32))
+                warpgroup_id = arith.divui(warp_id, c(4))
+                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
+                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 2. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                gpu.barrier()
+
+                # GPU Step 3. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+                
+                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+                    
+                    # Step 5.1. Reduce register size
+                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+
+                    # Step 5.2. TMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    with ir.InsertionPoint(for_op.body):
+                        phaseParity = for_op.inner_iter_args[0]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+
+                        # Step 5.2.1. Wait mbarDONE
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=producerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=producerPrimaryThread)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                        # Step 5.2.2. Load TMA
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                                  dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset2, [])
+                        debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
+                                    a_offset,
+                                    b_offset,
+                                    b_offset2,
+                                    predicate=producerPrimaryThread)
+                        coord = arith.muli(c(64), iv)
+                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                             mbarTMA,
+                                             a_tma_desc,
+                                             coordinates=[coord, dimX],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY, coord],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+                        dimY2 = arith.addi(dimY, c(64))
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY2, coord],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+
+                        # Step 5.2.3. Arrive mbarTMA
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
+                                    iv,
+                                    stage,
+                                    predicate=producerPrimaryThread)
+                        scf.yield_([phaseParity])
+                    scf.yield_([])
+
+                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
+                if_op = scf.IfOp(is_consumer)
+                with ir.InsertionPoint(if_op.then_block):
+                    
+                    # Step 6.1. Increase register size
+                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
+                    
+                    # GPU Step 6.2. Initialize MMA registers
+                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+
+                    # Step 6.3. MMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                    with ir.InsertionPoint(for_op.body):
+                        # Step 6.3.1. Wait mbarTMA
+                        phaseParity = for_op.inner_iter_args[1]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=consumerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=consumerPrimaryThread)
+                        
+                        # Step 6.3.2. Create WGMMA Descriptors
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                                                   dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                                                   dynamic_smem, b_offset, [])
+                        debug_print("[cons] a_offset={} b_offset={}",
+                                    a_offset,
+                                    b_offset,
+                                    predicate=consumerPrimaryThread)
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+                        # Step 6.3.3. MMA
+                        carry_acc = for_op.inner_iter_args[0]
+                        new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+                        # Step 6.3.4. Arrive mbarDONE
+                        p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+                        p_arrive = arith.andi(consumerPrimaryThread, p1)
+                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
+                            barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+                            remoteCtaId = c(0)
+                            pred = consumerPrimaryThread
+                            debug_print("[cons] {}  | mbarDONE[{}] arrive.mapa  remoteCtaId={} pred={}",
+                                        iv,
+                                        barId,
+                                        remoteCtaId,
+                                        pred,
+                                        predicate=consumerPrimaryThread)
+                            nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                            debug_print("[cons] {}  | mbarDONE[{}] arrive  remoteCtaId={} pred={} [done]",
+                                        iv,
+                                        barId,
+                                        remoteCtaId,
+                                        pred,
+                                        predicate=consumerPrimaryThread)
+                            scf.yield_([])
+
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                        # Step 6.3.5. Yield
+                        scf.yield_([new_acc, phaseParity])
+
+                    # Step 6.3. Wait All WGMMA
+                    nvvm.WgmmaWaitGroupSyncOp(0)
+
+                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+                        barId = c((K // BLOCK_K) % num_stages)
+                        nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                        scf.yield_([])
+
+                    # Step 6.4. Epilogue (registers --> shared memory)
+                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
+                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                    scf.yield_([])
+                gpu.barrier()
+
+                # GPU Step 7. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                vlen = 1
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                                        [for_op.induction_variable])
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.wait(token_ty, [t9])
+
+    mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
+
+
+def generate_matmul_multistage(input_type=np.float16,
+                               output_type=np.float32,
+                               M=4096,
+                               N=4096,
+                               K=4096,
+                               BLOCK_M=128,
+                               BLOCK_N=128,
+                               BLOCK_K=64,
+                               max_num_stages=3):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_multistage(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+
+            # Step 3. Launch Kernel with 1 Warpgroup
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. Bootstrapping
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
+                warpId = arith.divui(tidx, c(32))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 1. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+                gpu.barrier()
+
+                # GPU Step 2. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+
+                # GPU Step 3. Prologue (global memory --> shared memory)
+                for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+                with ir.InsertionPoint(for_op.body):
+                    iv = for_op.induction_variable
+
+                    # Step 3.1. Calculate offsets
+                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
+                    a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                              dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                    b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, b_offset, [])
+                    b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                    b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, b_offset2, [])
+
+                    # Step 3.2. TMA Load
+                    coord = arith.muli(c(64), iv)
+                    dimY2 = arith.addi(dimY, c(64))
+                    debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                                a_offset,
+                                b_offset,
+                                b_offset2,
+                                coord, dimX,
+                                dimY, coord,
+                                predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                         mbarTMA,
+                                         a_tma_desc,
+                                         coordinates=[coord, dimX],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                         mbarTMA,
+                                         b_tma_desc,
+                                         coordinates=[dimY, coord],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                         mbarTMA,
+                                         b_tma_desc,
+                                         coordinates=[dimY2, coord],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+
+                    # Step 3.3. mbarTMA arrive
+                    debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
+                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+                    scf.yield_([])
+
+                # GPU Step 4. Main Loop
+                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                with ir.InsertionPoint(for_op.body):
+                    # Step 4.1. Wait mbarTMA
+                    phaseParity = for_op.inner_iter_args[1]
+                    iv = for_op.induction_variable
+                    stage = arith.remui(iv, c(num_stages))
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}", stage, phaseParity, predicate=primaryThread)
+                    nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={} [done]", stage, phaseParity, predicate=primaryThread)
+
+                    # Step 4.2. Create WGMMA Descriptors
+                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                    a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                                               dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                    b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                                               dynamic_smem, b_offset, [])
+                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+                    # Step 4.3. MMA
+                    carry_acc = for_op.inner_iter_args[0]
+                    new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+                    # Step 4.4. Load TMA for next stage
+                    p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+                    p = arith.andi(primaryThread, p1)
+                    with ir.InsertionPoint(scf.IfOp(p).then_block):
+                        nextStage = arith.addi(iv, c(num_stages-1))
+                        nextSlot = arith.remui(nextStage, c(num_stages))
+                        a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset2, [])
+
+                        coord = arith.muli(c(64), nextStage)
+                        debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                                iv, a_offset,
+                                b_offset,
+                                b_offset2,
+                                coord, dimX,
+                                dimY, coord,
+                                predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                            mbarTMA,
+                                            a_tma_desc,
+                                            coordinates=[coord, dimX],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                            mbarTMA,
+                                            b_tma_desc,
+                                            coordinates=[dimY, coord],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+                        dimY2 = arith.addi(dimY, c(64))
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                            mbarTMA,
+                                            b_tma_desc,
+                                            coordinates=[dimY2, coord],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+
+                        debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
+                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+                        scf.yield_([])
+                    # Step 4.5. Change the phaseParity
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                    phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                    # Step 4.6. Yield
+                    scf.yield_([new_acc, phaseParity])
+
+                # GPU Step 5. Wait all WGMMA groups
+                nvvm.WgmmaWaitGroupSyncOp(0)
+                gpu.barrier()
+
+                # GPU Step 6. Epilogue (registers --> shared memory)
+                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                debug_print("Storing", predicate=primaryThread)
+                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                gpu.barrier()
+
+                # GPU Step 7. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                vlen = 1
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                                        [for_op.induction_variable])
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.wait(token_ty, [t9])
+
+    mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..2a4f67326ee718
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -0,0 +1,44 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the NvgpuCompiler class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+class NvgpuCompiler:
+    """Nvgpu class for compiling and building MLIR modules."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+        self.pipeline = pipeline
+        self.shared_libs = shared_libs
+        self.opt_level = opt_level
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the nvgpu pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)
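
The stage, phase-parity, and shared-memory offset bookkeeping used by both kernels above can be easier to follow outside the IR builders. Below is a minimal host-side sketch in plain Python that mirrors that arithmetic: `stage = iv % num_stages`, the parity flipping after the last stage, and the A/B tile byte offsets into dynamic shared memory. The function name `pipeline_schedule` is hypothetical and not part of the patch, and `f16_byte = 2` is an assumption, since that constant is defined outside the lines shown here.

```python
def pipeline_schedule(num_k_tiles, num_stages, lhs_tile_bytes, rhs_tile_bytes):
    """Model the per-iteration (iv, stage, parity, a_offset, b_offset) tuples.

    Hypothetical helper for illustration only; it mirrors the index
    arithmetic the kernels emit and does not build any MLIR.
    """
    parity = 0  # the main loops start with phaseParity = 0
    schedule = []
    for iv in range(num_k_tiles):
        stage = iv % num_stages
        # A tiles occupy the first num_stages * lhs_tile_bytes bytes.
        a_offset = stage * lhs_tile_bytes
        # B tiles are placed after all of the A tiles.
        b_offset = stage * rhs_tile_bytes + lhs_tile_bytes * num_stages
        schedule.append((iv, stage, parity, a_offset, b_offset))
        # The parity flips once every barrier in the group has been used.
        if stage == num_stages - 1:
            parity ^= 1
    return schedule


if __name__ == "__main__":
    f16_byte = 2  # assumption: size of an f16 element in bytes
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
    lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
    rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
    for entry in pipeline_schedule(7, 3, lhs_tile_bytes, rhs_tile_bytes):
        print(entry)
```

Printing the schedule for a small number of K tiles shows which mbarrier slot and parity each iteration waits on, which can help when reading the debug_print traces.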


