[Mlir-commits] [mlir] d95e6d0 - [mlir] GEMM Hopper Tensor Core Integration Test (#81478)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 4 13:04:03 PST 2024


Author: Guray Ozen
Date: 2024-03-04T21:03:59Z
New Revision: d95e6d027486876559f1a2a96c33b8ad93cc0ae4

URL: https://github.com/llvm/llvm-project/commit/d95e6d027486876559f1a2a96c33b8ad93cc0ae4
DIFF: https://github.com/llvm/llvm-project/commit/d95e6d027486876559f1a2a96c33b8ad93cc0ae4.diff

LOG: [mlir] GEMM Hopper Tensor Core Integration Test (#81478)

Added: 
    mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
    mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
    mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
    mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
    mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py

Modified: 
    

Removed: 
    


################################################################################
diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+    config.unsupported = True

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..cb7248ef23cd9e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,341 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test validates the correctness of the GEMM kernels supported by the
+# NVGPU dialect; the Multistage and Warp Specialization kernels are currently
+# supported.
+# The test metaprograms the IR with the MLIR Python bindings, so the shape,
+# tile size, or data type of the GEMM can easily be changed for testing
+# purposes.
+# The entry function is `matmul`, where one can specify the GEMM shape, tile
+# size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
+# maximum number of stages.
+# Results are verified against numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16,                # input types
+#        output_type=np.float32,               # output type
+#        M=4096, N=4096, K=4096,               # Shape
+#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+#        use_warp_specialization=True,         # Enable Warp Specialization
+#        max_num_stages=3)                     # Maximum number of shared memory stages
+#
+# ===--- Parallelism Across CTAs  ---===
+#
+# The GEMM is defined by three loops whose trip counts (M, N, K) are passed to
+# the `matmul` function.
+# The program builds IR using the following loop structure: it tiles the loops
+# with the given tile sizes and parallelizes the two outermost loops across the
+# first and second dimensions of the CTA grid.
+#
+# for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
+#     for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
+#         for(bk = 0; bk < K; bk += BLOCK_K)
+#             for(i = bi; i < (bi + BLOCK_M); ++i)
+#                 for(j = bj; j < (bj + BLOCK_N); ++j)
+#                     for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA loads. All threads wait for the data and then
+# perform the MMA operations. After the main loop, the threads first store the
+# fragmented accumulator registers to shared memory and then copy the result
+# from shared memory to global memory; this part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+#
+# ===--- Warp Specialization Kernel  ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as the `producer warp group` and the other as the `consumer warp group`. The
+# `producer warp group` issues the TMA loads, while the `consumer warp group`
+# performs the MMA operations. The consumer writes its fragmented registers to
+# shared memory, since it owns them; both groups then copy shared to global.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
+# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+
+import errno
+import numpy as np
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import os
+import sys
+from mlir import runtime as rt
+
+
+def generate_matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specialization=True,
+    saveIR=False,
+    max_num_stages=3,
+    options="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
+):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+        if use_warp_specialization:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
+        else:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
+
+        mlir_nvgpu_module.operation.verify()
+
+        # Save the generated IR to gemm.mlir for inspection
+        if saveIR:
+            with open("gemm.mlir", "w") as f:
+                print(mlir_nvgpu_module, file=f)
+
+        # Get compiler
+        support_lib = os.getenv("SUPPORT_LIB")
+        if support_lib is None or not os.path.exists(support_lib):
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+            )
+        compiler = nvgpucompiler.NvgpuCompiler(
+            options, opt_level=3, shared_libs=[support_lib]
+        )
+
+        # Compile
+        engine = compiler.compile_and_jit(mlir_nvgpu_module)
+        return engine
+
+
+def matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=128,
+    N=128,
+    K=128,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specialization=True,
+    saveIR=False,
+    max_num_stages=3,
+    print_results=False,
+    no_verify=False,
+):
+    # Compute the number of pipeline stages and print the configuration
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+    ity = "f16" if input_type == np.float16 else "f32"
+    oty = "f16" if output_type == np.float16 else "f32"
+    gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
+    print(
+        f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
+        f"Size {M}x{N}x{K}, Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
+        f"stages {num_stages} --==="
+    )
+
+    # Build IR and compile
+    engine = generate_matmul(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        use_warp_specialization,
+        saveIR,
+        num_stages,
+    )
+
+    # Allocate matrices and invoke the matmul
+    c = np.zeros((M, N), output_type)
+    a = np.random.randn(M, K).astype(input_type)
+    b = np.random.randn(K, N).astype(input_type)
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+    kernelName = matmulBuilder.make_kernel_name(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        num_stages,
+        use_warp_specialization,
+    )
+
+    # Launch the MLIR generated kernel
+    engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+    float_formatter = "{:.2f}".format
+    np.set_printoptions(formatter={"float_kind": float_formatter})
+
+    if print_results:
+        print(c)
+
+    # Verify the results
+    if not no_verify:
+        # Compute the reference in f32 to match the kernel's f32 accumulator
+        ref = a.astype(np.float32) @ b.astype(np.float32)
+        if print_results:
+            print(ref)
+        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+    print("PASS ")
+
+
+# Exhaustive sweep over shapes and stages; takes a long time to run
+def test_long():
+    for stages in range(1, 7):
+        for M in [128, 512, 1024, 4096, 8192]:
+            for N in [128, 512, 1024, 4096, 8192]:
+                for K in [64, 128, 512, 1024, 4096, 8192]:
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M,
+                        N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=False,
+                        no_verify=True,
+                    )
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M,
+                        N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=True,
+                    )
+
+
+def test_short():
+    for stages in [1, 3]:
+        for M in [128, 512]:
+            for N in [128]:
+                for K in [64, 256]:
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M,
+                        N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=False,
+                    )
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M,
+                        N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=True,
+                    )
+
+
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+
+test_short()
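The `stages` value printed in each `Running GEMM` line comes from the heuristic at the top of `matmul()`. The minimal standalone sketch below copies that formula (it does not import the test; `pick_num_stages` is a hypothetical helper name) and reproduces the stage counts that the CHECK lines expect for the `test_short()` sweep.

```python
def pick_num_stages(M, N, K, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, max_num_stages=3):
    # Same formula as matmul(): elements of A and B divided by the elements of
    # one A tile plus one B tile, capped by the user-provided maximum.
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    return min(required_stages, max_num_stages)


# Same iteration order as test_short(): stages, then M, then N, then K.
configs = [
    (stages, M, N, K)
    for stages in [1, 3]
    for M in [128, 512]
    for N in [128]
    for K in [64, 256]
]
for stages, M, N, K in configs:
    num_stages = pick_num_stages(M, N, K, max_num_stages=stages)
    print(f"Size {M}x{N}x{K}, max {stages} -> stages {num_stages}")
```

Note that the heuristic is driven by the total size of A and B rather than by the number of K iterations (`K // BLOCK_K`): 512x128x64 has a single K iteration but is assigned 2 stages.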

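The loop nest in the `Parallelism Across CTAs` comment can be checked on the host. This pure-Python sketch (`tiled_matmul` and `naive_matmul` are illustrative names, not part of the test) runs the six loops literally on small integer matrices, where the result must match an untiled matmul exactly. Each `(bi, bj)` pair writes a disjoint output tile, which is why those two loops can be mapped to `blockIdx.x` and `blockIdx.y`, while the `bk` loop stays sequential inside each CTA.

```python
def tiled_matmul(a, b, BLOCK_M, BLOCK_N, BLOCK_K):
    """C = A @ B via the tiled six-loop nest; a is MxK, b is KxN (lists of lists)."""
    M, K, N = len(a), len(b), len(b[0])
    # Like the builders in this patch, assume the tile sizes divide the shape.
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0
    c = [[0] * N for _ in range(M)]
    for bi in range(0, M, BLOCK_M):  # one CTA per M tile: blockIdx.x
        for bj in range(0, N, BLOCK_N):  # one CTA per N tile: blockIdx.y
            for bk in range(0, K, BLOCK_K):  # sequential main loop over K tiles
                for i in range(bi, bi + BLOCK_M):
                    for j in range(bj, bj + BLOCK_N):
                        for k in range(bk, bk + BLOCK_K):
                            c[i][j] += a[i][k] * b[k][j]
    return c


def naive_matmul(a, b):
    # Untiled reference: dot product of each row of a with each column of b.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


M, N, K = 4, 6, 8
a = [[(i * K + k) % 7 - 3 for k in range(K)] for i in range(M)]
b = [[(k * N + j) % 5 - 2 for j in range(N)] for k in range(K)]
assert tiled_matmul(a, b, BLOCK_M=2, BLOCK_N=3, BLOCK_K=4) == naive_matmul(a, b)
print("grid:", (M // 2, N // 3))  # grid: (2, 2)
```

Integer inputs are used so that the comparison is exact; with floating point, the different summation order of the tiled loop would require a tolerance, as the test's `assert_allclose` does.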
diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
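In `generate_matmul_ws` (matmulBuilder.py), the dynamic shared memory holds `num_stages` A tiles followed by `num_stages` B tiles, and is sized to also fit the f32 output tile used by the epilogue; each B tile is loaded as two 64-column TMA boxes. The sketch below redoes that byte arithmetic in plain Python, assuming f16 inputs and f32 output as the builder's assertions require. `smem_plan` is a hypothetical helper; it also returns `txcount`, the byte count the producer passes to `mbarrier_arrive_expect_tx` for each stage.

```python
F16_BYTES = 2  # input element size (the builder asserts float16 inputs)
F32_BYTES = 4  # output element size (the builder asserts float32 output)
TMA_LAST_DIM_F16 = 64  # 64 x 2B = 128B per row of a TMA box


def smem_plan(BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, num_stages=3):
    """Dynamic shared memory size, per-stage TMA bytes, and per-stage offsets."""
    lhs_tile_bytes = BLOCK_M * BLOCK_K * F16_BYTES
    rhs_tile_bytes = BLOCK_N * BLOCK_K * F16_BYTES
    smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
    smem_size_output = BLOCK_M * BLOCK_N * F32_BYTES  # f32 tile for the epilogue
    # Bytes TMA delivers per stage: one B tile plus one A box, as in txcount.
    txcount = (BLOCK_K * BLOCK_N + BLOCK_M * TMA_LAST_DIM_F16) * F16_BYTES
    offsets = []
    for stage in range(num_stages):
        a_offset = stage * lhs_tile_bytes
        # All A tiles come first, then the B tiles.
        b_offset = stage * rhs_tile_bytes + lhs_tile_bytes * num_stages
        # Second 64-column half of the same B tile.
        b_offset2 = b_offset + BLOCK_K * TMA_LAST_DIM_F16 * F16_BYTES
        offsets.append((a_offset, b_offset, b_offset2))
    return max(smem_size_input, smem_size_output), txcount, offsets


smem_size, txcount, offsets = smem_plan(num_stages=3)
print(smem_size, txcount)  # 98304 32768
for stage, offs in enumerate(offsets):
    print(stage, offs)
```

With a single stage the f32 output tile (64 KiB) is larger than the input buffers (32 KiB), which is why the builder takes the maximum of the two sizes.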

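Both warp groups in the builder index their mbarrier ring with `stage = iv % num_stages`, and the producer loop (Step 5.2.1) starts with `phaseParity = 1` and flips it only after using the last slot. Because every pass over the ring reuses each barrier for one more phase, the parity passed to `try_wait` alternates once per pass. The pure-Python replay below (`replay_phase_parity` is a hypothetical name) mirrors that loop-carried update and matches the closed form `initial ^ ((iv // num_stages) & 1)`. Passing `initial=0` models the consumer loop, which starts with parity `0`; this assumes its flip logic, later in the file, mirrors the producer's.

```python
def replay_phase_parity(num_iters, num_stages, initial=1):
    """Replays the loop-carried (stage, parity) updates of the producer main loop."""
    schedule = []
    phase_parity = initial  # scf.for iter_arg initial value (1 for the producer)
    for iv in range(num_iters):
        stage = iv % num_stages
        # Parity used by this iteration's try_wait on barrier `stage`.
        schedule.append((iv, stage, phase_parity))
        # Equivalent of select(stage == num_stages - 1, xori(parity, 1), parity).
        if stage == num_stages - 1:
            phase_parity ^= 1
    return schedule


for iv, stage, parity in replay_phase_parity(num_iters=6, num_stages=3):
    print(f"iv={iv} stage={stage} parity={parity}")
```

A common reading of the PTX `mbarrier.try_wait.parity` semantics is that waiting on the parity opposite to a barrier's current, still-incomplete phase returns immediately, which would let the producer fill every slot on its first pass; check the PTX ISA before relying on this detail.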
diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..fac138dce605a7
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,1156 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+from mlir.extras import types as T
+
+TMA_LAST_DIM_F16 = 64  # 64 x float16 (2B) = 128B, matching swizzle_128b
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+# MLIR's C++ API uses this sentinel (INT64_MIN, ShapedType::kDynamic) to mark a dynamic dimension.
+MLIR_DYNAMIC = -9223372036854775808
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty_format = None
+        if ir.IndexType.isinstance(arg.type):
+            ty_format = "%llu"
+        if ir.IntegerType.isinstance(arg.type):
+            width = ir.IntegerType(arg.type).width
+            if width == 64:
+                ty_format = "%llu"
+            elif width == 32:
+                ty_format = "%d"
+            elif width == 1:
+                ty_format = "%i"
+        if ir.F32Type.isinstance(arg.type):
+            ty_format = "%f"
+        if ty_format is None:
+            raise NotImplementedError(arg.type)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(fmt.format(*type_formats) + "\n", args)
+        scf.yield_([])
+
+
+def get_type_size(ty):
+    if ir.FloatType.isinstance(ty):
+        return ir.FloatType(ty).width // 8
+    if ir.IntegerType.isinstance(ty):
+        return ir.IntegerType(ty).width // 8
+    raise NotImplementedError(ty)
+
+
+def get_mlir_ty(dtype):
+    if dtype == np.float16:
+        return T.f16()
+    if dtype == np.float32:
+        return T.f32()
+    if dtype == np.float64:
+        return T.f64()
+    if dtype == np.int32:
+        return T.i32()
+    if dtype == np.int64:
+        return T.i64()
+    raise NotImplementedError(dtype)
+
+
+def c(value, ty=None):
+    ty = T.index() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def make_kernel_name(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    num_stages=3,
+    use_warp_specialization=False,
+):
+    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
+    return (
+        f"{kernelName}_{M}x{N}x{K}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_{num_stages}"
+    )
+
+
+def generate_matmul_ws(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    num_stages=3,
+):
+    # Current limitations
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert BLOCK_M == 128
+    assert BLOCK_N == 128
+    assert BLOCK_K == 64
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    module = ir.Module.create()
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_elem_ty = get_mlir_ty(input_type)
+    b_elem_ty = get_mlir_ty(input_type)
+    c_elem_ty = get_mlir_ty(output_type)
+    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)) + (
+        a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)
+    )
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    mbar_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace = "
+        + str(smem_space)
+        + ", num_barriers = "
+        + str(num_stages)
+        + ">"
+    )
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(a_elem_ty)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(b_elem_ty)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    acc_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(c_elem_ty)
+        + ">>"
+    )
+    a_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_K)
+        + "x"
+        + str(a_elem_ty)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    b_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(b_elem_ty)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    kernelName = make_kernel_name(
+        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
+    )
+    with ir.InsertionPoint(module.body):
+        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+        with ir.InsertionPoint(fop.add_entry_block()):
+            a_host = fop.arguments[0]
+            b_host = fop.arguments[1]
+            c_host = fop.arguments[2]
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [
+                (a_device, a_tma_desc_ty, a_tma_shape),
+                (b_device, b_tma_desc_ty, b_tma_shape),
+            ]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
+                )
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(
+                        tensor_map_ty, x_unranked, map(c, tile_shape)
+                    ).result
+                )
+            a_tma_desc, b_tma_desc = tma_descs
+
+            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE * 2, 1, 1)
+            launch_op = gpu.LaunchOp(
+                token_ty,
+                [t7],
+                *map(c, grid),
+                *map(c, block),
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+            )
+            launch_op.body.blocks.append(*([T.index()] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. This is needed for vectorized loads/stores
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+                )
+                ticks = c(10000000)
+
+                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                wgPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
+                )
+                warp_id = arith.divui(tidx, c(WARP_SIZE))
+                warpgroup_id = arith.divui(warp_id, c(WARP_GROUP_SIZE // WARP_SIZE))
+                is_producer = arith.cmpi(
+                    arith.CmpIPredicate.eq,
+                    warpgroup_id,
+                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
+                )
+                is_consumer = arith.cmpi(
+                    arith.CmpIPredicate.eq,
+                    warpgroup_id,
+                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
+                )
+                producerPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
+                )
+                consumerPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
+                )
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 2. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                gpu.barrier()
+
+                # GPU Step 3. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+
+                ns = num_stages if num_stages == 1 else num_stages - 1
+                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+                    # Step 5.1. Lower the per-thread register limit
+                    nvvm.setmaxregister(
+                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
+                    )
+
+                    # Step 5.2. TMA Main Loop
+                    for_op = scf.ForOp(
+                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
+                    )
+                    with ir.InsertionPoint(for_op.body):
+                        phaseParity = for_op.inner_iter_args[0]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+
+                        # Step 5.2.1. Wait mbarDONE
+                        debug_print(
+                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.MBarrierTryWaitParityOp(
+                            mbarDONE, phaseParity, ticks, mbarId=stage
+                        )
+                        debug_print(
+                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread,
+                        )
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(
+                            p,
+                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+                            phaseParity,
+                        )
+
+                        # Step 5.2.2. Load TMA
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(
+                                a_tma_shape, a_elem_ty, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages),
+                        )
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
+                        b_offset2 = arith.addi(
+                            b_offset,
+                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+                        )
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset2,
+                            [],
+                        )
+                        debug_print(
+                            "[prod] a_offset={} b_offset={} b_offset2={}",
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            predicate=producerPrimaryThread,
+                        )
+                        coord = arith.muli(c(BLOCK_K), iv)
+                        nvgpu.TmaAsyncLoadOp(
+                            a_tma_slice,
+                            mbarTMA,
+                            a_tma_desc,
+                            coordinates=[coord, dimX],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_1,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY, coord],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
+                        dimY2 = arith.addi(dimY, c(TMA_LAST_DIM_F16))
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_2,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY2, coord],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
+
+                        # Step 5.2.3. Arrive mbarTMA
+                        debug_print(
+                            "[prod] iv={}  | mbarTMA[{}] arrive",
+                            iv,
+                            stage,
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
+                        )
+                        debug_print(
+                            "[prod] iv={}  | mbarTMA[{}] arrive [done]",
+                            iv,
+                            stage,
+                            predicate=producerPrimaryThread,
+                        )
+                        scf.yield_([phaseParity])
+                    scf.yield_([])
+
+                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
+                if_op = scf.IfOp(is_consumer)
+                with ir.InsertionPoint(if_op.then_block):
+                    # Step 6.1. Increase register size
+                    nvvm.setmaxregister(
+                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
+                    )
+
+                    # GPU Step 6.2. Initialize MMA registers
+                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+
+                    # Step 6.3. MMA Main Loop
+                    for_op = scf.ForOp(
+                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
+                    )
+                    with ir.InsertionPoint(for_op.body):
+                        # Step 6.3.1. Wait mbarTMA
+                        phaseParity = for_op.inner_iter_args[1]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+                        debug_print(
+                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread,
+                        )
+                        nvgpu.MBarrierTryWaitParityOp(
+                            mbarTMA, phaseParity, ticks, mbarId=stage
+                        )
+                        debug_print(
+                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread,
+                        )
+
+                        # Step 6.3.2. Create WGMMA Descriptors
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tile_slice = memref.view(
+                            ir.MemRefType.get(
+                                a_tile_shape, a_elem_ty, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages),
+                        )
+                        b_tile_slice = memref.view(
+                            ir.MemRefType.get(
+                                b_tile_shape, b_elem_ty, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
+                        debug_print(
+                            "[cons] a_offset={} b_offset={}",
+                            a_offset,
+                            b_offset,
+                            predicate=consumerPrimaryThread,
+                        )
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(
+                            a_wgmma_ty, a_tile_slice, a_tma_desc
+                        )
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(
+                            b_wgmma_ty, b_tile_slice, b_tma_desc
+                        )
+
+                        # Step 6.3.3. MMA
+                        carry_acc = for_op.inner_iter_args[0]
+                        new_acc = nvgpu.WarpgroupMmaOp(
+                            acc.type, da, db, carry_acc, transposeB=True
+                        )
+
+                        # Step 6.3.4. Arrive mbarDONE
+                        if num_stages == 1:
+                            p_arrive = consumerPrimaryThread
+                        else:
+                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+                            p_arrive = arith.andi(consumerPrimaryThread, p1)
+                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
+                            barId = arith.select(
+                                p, c(num_stages - 1), arith.subi(stage, c(1))
+                            )
+                            debug_print(
+                                "[cons] iv={}  | mbarDONE[{}] arrive ",
+                                iv,
+                                barId,
+                                predicate=consumerPrimaryThread,
+                            )
+                            nvgpu.mbarrier_arrive(
+                                ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+                            )
+                            debug_print(
+                                "[cons] iv={}  | mbarDONE[{}] arrive [done]",
+                                iv,
+                                barId,
+                                predicate=consumerPrimaryThread,
+                            )
+                            scf.yield_([])
+
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(
+                            p,
+                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+                            phaseParity,
+                        )
+
+                        # Step 6.3.5. Yield
+                        scf.yield_([new_acc, phaseParity])
+
+                    # Step 6.4. Wait for all WGMMA groups
+                    nvvm.WgmmaWaitGroupSyncOp(0)
+
+                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+                        barId = c((K // BLOCK_K) % num_stages)
+                        nvgpu.mbarrier_arrive(
+                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+                        )
+                        scf.yield_([])
+
+                    # Step 6.5. Epilogue (registers --> shared memory)
+                    acc_smem_ty = ir.MemRefType.get(
+                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+                    )
+                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
+                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                    scf.yield_([])
+                gpu.barrier()
+
+                # GPU Step 7. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get(
+                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+                )
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N),
+                    c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+                )
+                c_device_per_block = memref.SubViewOp(
+                    rty,
+                    c_device,
+                    [dimX, dimY],
+                    [],
+                    [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                    [BLOCK_M, BLOCK_N],
+                    [1, 1],
+                )
+                vlen = 1
+                for_op = scf.ForOp(
+                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
+                )
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(
+                        ir.VectorType.get((vlen,), c_elem_ty),
+                        collapsed_smem,
+                        [for_op.induction_variable],
+                    )
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.dealloc(token_ty, [t8], a_device)
+            gpu.dealloc(token_ty, [t8], b_device)
+            gpu.wait(token_ty, [t9])
+            gpu.dealloc(token_ty, [t8], c_device)
+            func.ReturnOp([])
+
+    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
+
+
+def generate_matmul_multistage(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    num_stages=3,
+):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert BLOCK_M == 128
+    assert BLOCK_N == 128
+    assert BLOCK_K == 64
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    module = ir.Module.create()
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_elem_ty = get_mlir_ty(input_type)
+    b_elem_ty = get_mlir_ty(input_type)
+    c_elem_ty = get_mlir_ty(output_type)
+    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)) + (
+        a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)
+    )
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    mbar_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace = "
+        + str(smem_space)
+        + ", num_barriers = "
+        + str(num_stages)
+        + ">"
+    )
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(a_elem_ty)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(b_elem_ty)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    acc_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(c_elem_ty)
+        + ">>"
+    )
+    a_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_K)
+        + "x"
+        + str(a_elem_ty)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    b_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(b_elem_ty)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+
+    with ir.InsertionPoint(module.body):
+        kernelName = make_kernel_name(
+            input_type,
+            output_type,
+            M,
+            N,
+            K,
+            BLOCK_M,
+            BLOCK_N,
+            BLOCK_K,
+            num_stages,
+            False,
+        )
+        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+        with ir.InsertionPoint(fop.add_entry_block()):
+            a_host = fop.arguments[0]
+            b_host = fop.arguments[1]
+            c_host = fop.arguments[2]
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [
+                (a_device, a_tma_desc_ty, a_tma_shape),
+                (b_device, b_tma_desc_ty, b_tma_shape),
+            ]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
+                )
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(
+                        tensor_map_ty, x_unranked, map(c, tile_shape)
+                    ).result
+                )
+            a_tma_desc, b_tma_desc = tma_descs
+
+            # Step 3. Launch Kernel with 1 Warpgroup
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE, 1, 1)
+            launch_op = gpu.LaunchOp(
+                token_ty,
+                [t7],
+                *map(c, grid),
+                *map(c, block),
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+            )
+            launch_op.body.blocks.append(*([T.index()] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. Bootstrapping
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+                )
+                ticks = c(10000000)
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
+                warpId = arith.divui(tidx, c(32))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 1. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+                gpu.barrier()
+
+                # GPU Step 2. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+
+                # GPU Step 3. Prologue (global memory --> shared memory)
+                ns = num_stages if num_stages == 1 else num_stages - 1
+                for_op = scf.ForOp(c(0), c(ns), c(1))
+                with ir.InsertionPoint(for_op.body):
+                    iv = for_op.induction_variable
+
+                    # Step 3.1. Calculate offsets
+                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
+                    a_tma_slice = memref.view(
+                        ir.MemRefType.get(
+                            a_tma_shape, a_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(iv, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
+                    b_tma_slice_1 = memref.view(
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    b_offset2 = arith.addi(
+                        b_offset,
+                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+                    )
+                    b_tma_slice_2 = memref.view(
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset2,
+                        [],
+                    )
+
+                    # Step 3.2. TMA Load
+                    coord = arith.muli(c(64), iv)
+                    dimY2 = arith.addi(dimY, c(64))
+                    debug_print(
+                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                        a_offset,
+                        b_offset,
+                        b_offset2,
+                        coord,
+                        dimX,
+                        dimY,
+                        coord,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        a_tma_slice,
+                        mbarTMA,
+                        a_tma_desc,
+                        coordinates=[coord, dimX],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_1,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY, coord],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_2,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY2, coord],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+
+                    # Step 3.3. mbarTMA arrive
+                    debug_print(
+                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
+                    )
+                    nvgpu.mbarrier_arrive_expect_tx(
+                        mbarTMA, c(txcount), iv, predicate=primaryThread
+                    )
+                    debug_print(
+                        "[Prologue] mbarTMA[{}] arrive [done]",
+                        iv,
+                        predicate=primaryThread,
+                    )
+                    scf.yield_([])
+
+                # GPU Step 4. Main Loop
+                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+                for_op = scf.ForOp(
+                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
+                )
+                with ir.InsertionPoint(for_op.body):
+                    # Step 4.1. Wait mbarTMA
+                    phaseParity = for_op.inner_iter_args[1]
+                    iv = for_op.induction_variable
+                    stage = arith.remui(iv, c(num_stages))
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={}",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.MBarrierTryWaitParityOp(
+                        mbarTMA, phaseParity, ticks, mbarId=stage
+                    )
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread,
+                    )
+
+                    # Step 4.2. Create WGMMA Descriptors
+                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                    a_tile_slice = memref.view(
+                        ir.MemRefType.get(
+                            a_tile_shape, a_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(stage, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
+                    b_tile_slice = memref.view(
+                        ir.MemRefType.get(
+                            b_tile_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    debug_print(
+                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
+                        iv,
+                        a_offset,
+                        b_offset,
+                        predicate=primaryThread,
+                    )
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(
+                        a_wgmma_ty, a_tile_slice, a_tma_desc
+                    )
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(
+                        b_wgmma_ty, b_tile_slice, b_tma_desc
+                    )
+
+                    # Step 4.3. MMA
+                    carry_acc = for_op.inner_iter_args[0]
+                    new_acc = nvgpu.WarpgroupMmaOp(
+                        acc.type, da, db, carry_acc, transposeB=True
+                    )
+                    if num_stages == 1:
+                        nvvm.WgmmaWaitGroupSyncOp(0)
+
+                    # Step 4.4. Load TMA for next stage
+                    p1 = arith.cmpi(
+                        arith.CmpIPredicate.ult,
+                        arith.addi(iv, c(ns)),
+                        c(K // BLOCK_K),
+                    )
+                    p = arith.andi(primaryThread, p1)
+                    nextStage = arith.addi(iv, c(ns))
+                    nextSlot = arith.remui(nextStage, c(num_stages))
+                    a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
+
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] arrive",
+                        nextSlot,
+                        predicate=p,
+                    )
+                    nvgpu.mbarrier_arrive_expect_tx(
+                        mbarTMA, c(txcount), nextSlot, predicate=p
+                    )
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] arrive [done]",
+                        nextSlot,
+                        predicate=p,
+                    )
+
+                    a_tma_slice = memref.view(
+                        ir.MemRefType.get(
+                            a_tma_shape, a_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(nextSlot, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
+                    b_tma_slice_1 = memref.view(
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    b_offset2 = arith.addi(
+                        b_offset,
+                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+                    )
+                    b_tma_slice_2 = memref.view(
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset2,
+                        [],
+                    )
+
+                    coord = arith.muli(c(64), nextStage)
+                    debug_print(
+                        "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                        iv,
+                        a_offset,
+                        b_offset,
+                        b_offset2,
+                        coord,
+                        dimX,
+                        dimY,
+                        coord,
+                        predicate=p,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        a_tma_slice,
+                        mbarTMA,
+                        a_tma_desc,
+                        coordinates=[coord, dimX],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_1,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY, coord],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
+                    dimY2 = arith.addi(dimY, c(64))
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_2,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY2, coord],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
+                    # Step 4.5. Change the phaseParity
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                    phaseParity = arith.select(
+                        p,
+                        arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+                        phaseParity,
+                    )
+
+                    # Step 4.6. Yield
+                    scf.yield_([new_acc, phaseParity])
+
+                # Step 5. Wait All WGMMA groups
+                nvvm.WgmmaWaitGroupSyncOp(0)
+
+                # Step 6. Epilogue (registers --> shared memory)
+                acc_smem_ty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+                )
+                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                debug_print("Storing", predicate=primaryThread)
+                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                gpu.barrier()
+
+                # GPU Step 7. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get(
+                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+                )
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N),
+                    c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+                )
+                c_device_per_block = memref.SubViewOp(
+                    rty,
+                    c_device,
+                    [dimX, dimY],
+                    [],
+                    [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                    [BLOCK_M, BLOCK_N],
+                    [1, 1],
+                )
+                vlen = 1
+                for_op = scf.ForOp(
+                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
+                )
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(
+                        ir.VectorType.get((vlen,), c_elem_ty),
+                        collapsed_smem,
+                        [for_op.induction_variable],
+                    )
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.dealloc(token_ty, [t8], a_device)
+            gpu.dealloc(token_ty, [t8], b_device)
+            gpu.wait(token_ty, [t9])
+            gpu.dealloc(token_ty, [t8], c_device)
+            func.ReturnOp([])
+
+    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
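The shared-memory bookkeeping of `generate_matmul_multistage` can be checked without a GPU. The sketch below is a hypothetical pure-Python model (the helpers `layout`, `slot_offsets` and `schedule` are illustrative and not part of the test) of the per-slot byte offsets, the per-stage `txcount`, and the phase-parity and prefetch schedule the builder emits. It assumes f16 inputs, f32 output, and `TMA_LAST_DIM_F16 == 64`, a constant defined outside this excerpt.

```python
# Hypothetical pure-Python model of the ring-buffer bookkeeping emitted by
# generate_matmul_multistage. Names are illustrative, not MLIR APIs.
# Assumes f16 inputs (2 bytes), f32 output (4 bytes), TMA_LAST_DIM_F16 == 64.

F16_BYTES, F32_BYTES = 2, 4
TMA_LAST_DIM_F16 = 64  # assumption: 128-byte swizzle / 2-byte elements


def layout(num_stages, block_m=128, block_n=128, block_k=64):
    """Byte sizes used to carve up dynamic shared memory."""
    lhs_tile_bytes = block_m * block_k * F16_BYTES
    rhs_tile_bytes = block_n * block_k * F16_BYTES
    # Bytes delivered by the TMA loads of one stage (expected by the mbarrier).
    txcount = lhs_tile_bytes + rhs_tile_bytes
    smem_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
    smem_output = block_m * block_n * F32_BYTES  # epilogue reuses the buffer
    return lhs_tile_bytes, rhs_tile_bytes, txcount, max(smem_input, smem_output)


def slot_offsets(slot, num_stages, block_k=64):
    """(a_offset, b_offset, b_offset2): all A tiles first, then all B tiles."""
    lhs, rhs, _, _ = layout(num_stages, block_k=block_k)
    a_offset = slot * lhs
    b_offset = slot * rhs + lhs * num_stages
    # The BLOCK_K x BLOCK_N B tile arrives as two BLOCK_K x 64 TMA boxes.
    b_offset2 = b_offset + block_k * TMA_LAST_DIM_F16 * F16_BYTES
    return a_offset, b_offset, b_offset2


def schedule(k_iters, num_stages):
    """(iv, stage, wait_parity, prefetch_slot) for each main-loop iteration."""
    ns = num_stages if num_stages == 1 else num_stages - 1  # prologue depth
    parity = False
    events = []
    for iv in range(k_iters):
        stage = iv % num_stages
        nxt = iv + ns
        prefetch = nxt % num_stages if nxt < k_iters else None
        events.append((iv, stage, parity, prefetch))
        if stage == num_stages - 1:  # every slot used once: flip the parity
            parity = not parity
    return events


if __name__ == "__main__":
    print(layout(3))
    print(slot_offsets(1, 3))
    for event in schedule(6, 3):
        print(event)
```

With `num_stages=3` the model gives 32768 transaction bytes per stage and a 98304-byte dynamic shared memory request; the prefetch slot `(iv + 2) % 3` is always the slot consumed in the previous iteration, never the one currently being read.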

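Both epilogues copy the f32 tile from shared memory to global memory with a strided loop: thread `tidx` starts at linear index `tidx` and advances by the block's thread count (one warpgroup in the multistage kernel, two in the warp-specialized one). The minimal pure-Python sketch below, built around a hypothetical `epilogue_coords` helper and assuming `WARP_GROUP_SIZE` is 128 (it is defined outside this excerpt), checks that the loop visits every element of a row-major `BLOCK_M x BLOCK_N` tile exactly once. The row index is `i // BLOCK_N`; dividing by `BLOCK_M` yields the same value only while the tile is square, which the current asserts guarantee.

```python
# Hypothetical model of the shared-memory -> global-memory epilogue loop.
# Assumes WARP_GROUP_SIZE == 128 (defined outside the excerpt).

WARP_GROUP_SIZE = 128


def epilogue_coords(tid, num_threads, block_m=128, block_n=128):
    """(row, col) pairs stored by thread `tid` for a row-major tile."""
    coords = []
    for i in range(tid, block_m * block_n, num_threads):
        # Linear index -> (row, col): divide by the row length BLOCK_N.
        coords.append((i // block_n, i % block_n))
    return coords


def covers_tile_once(num_threads, block_m=128, block_n=128):
    """True if all threads together write each tile element exactly once."""
    seen = []
    for tid in range(num_threads):
        seen.extend(epilogue_coords(tid, num_threads, block_m, block_n))
    expected = [(r, c) for r in range(block_m) for c in range(block_n)]
    return sorted(seen) == expected


if __name__ == "__main__":
    print(covers_tile_once(WARP_GROUP_SIZE))      # multistage: one warpgroup
    print(covers_tile_once(2 * WARP_GROUP_SIZE))  # warp-specialized: two warpgroups
    print(epilogue_coords(0, WARP_GROUP_SIZE)[:2])  # [(0, 0), (1, 0)]
```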
diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..1c9cc74fcd169c
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -0,0 +1,45 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the NvgpuCompiler class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+
+class NvgpuCompiler:
+    """Lowers MLIR modules through gpu-lower-to-nvvm-pipeline and JIT-compiles them."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+        self.pipeline = pipeline
+        self.shared_libs = shared_libs
+        self.opt_level = opt_level
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Lowers the module in place by running the gpu-lower-to-nvvm pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)
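`NvgpuCompiler` builds its pass pipeline with an f-string in which `{{` and `}}` are escaped braces, so the user options end up inside literal braces after the pipeline name. The sketch below reproduces only that string construction, so it runs without the MLIR bindings; the helper name `nvvm_pipeline` is hypothetical, and the option values are an assumption for an sm_90a target rather than something taken from this excerpt.

```python
# Reproduces the pipeline string built in NvgpuCompiler.__init__.
# The option values are assumptions for an sm_90a target.


def nvvm_pipeline(options: str) -> str:
    # "{{" and "}}" emit literal braces; "{options}" is the interpolated field.
    return f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"


if __name__ == "__main__":
    print(nvvm_pipeline("cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"))
    # prints: builtin.module(gpu-lower-to-nvvm-pipeline{cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3})
```

With the bindings available, the class itself would be used roughly as `NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib]).compile_and_jit(module)`, where `support_lib` stands for the CUDA runtime library passed in through `SUPPORT_LIB` in the RUN line.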
