[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)

Guray Ozen llvmlistbot at llvm.org
Mon Feb 12 05:39:21 PST 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/81478

From a9b4035aa421506302e74c6f6ff19a0aec62b735 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:33:26 +0000
Subject: [PATCH 1/2] [mlir] GEMM Hopper Tensor Core Integration Test

This test validates the correctness of the GEMM kernels supported by the
NVGPU dialect; Multistage and Warp Specialization kernels are currently
supported.
The test builds the IR through the MLIR Python bindings, so the same
builder can generate kernels for a different GEMM shape, tile size, or
data type.
The entry function is `matmul`, where one can specify GEMM shape, tile size,
data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
number of stages.
Verification is done via numpy's matmul operation.

Example:
```
matmul(input_type=np.float16,                # input types
       output_type=np.float32,               # output type
       M=4096, N=4096, K=4096,               # Shape
       BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
       use_warp_specialization=True,         # Enable Warp Specialization
       max_num_stages=3)                     # Number of stages in shared memory
```
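The verification step can be sketched on the host alone. This is a minimal sketch, not the test's own code: `check_gemm` is a hypothetical helper name, the "device result" is simulated with numpy, and the reference is computed in float32 as a choice of this sketch. The tolerances mirror the ones `matmul.py` passes to `assert_allclose`.

```python
import numpy as np


def check_gemm(c, a, b, rtol=5e-03, atol=1e-01):
    """Compare a kernel result `c` against numpy's matmul of `a` and `b`.

    Hypothetical helper: raises AssertionError on mismatch, returns True otherwise.
    """
    ref = a.astype(np.float32) @ b.astype(np.float32)
    np.testing.assert_allclose(c, ref, rtol=rtol, atol=atol)
    return True


rng = np.random.default_rng(0)
a = rng.standard_normal((128, 64)).astype(np.float16)
b = rng.standard_normal((64, 128)).astype(np.float16)

# Stand-in for the device result (assumption): f32 accumulation of f16 inputs.
c = a.astype(np.float32) @ b.astype(np.float32)
ok = check_gemm(c, a, b)
```

The relatively loose absolute tolerance leaves room for the rounding differences that f16 inputs introduce.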

## Parallelism Across CTAs

GEMM includes three loops defining the shape of the GEMM, specified in the
`matmul` function.
The program builds IR using the following loop structure, tiling the loops
with the given tile size and parallelizing the two outermost loops into the
first and second dimensions of CTAs.
```
for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
    for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
        for(bk = 0; bk < K; bk += BLOCK_K)
            for(i = bi; i < (bi + BLOCK_M); ++i)
                for(j = bj; j < (bj + BLOCK_N); ++j)
                    for(k = bk; k < (bk + BLOCK_K); ++k)
```
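
The tiled loop nest can be checked on the host with numpy. This is an illustrative sketch only: `tiled_matmul` and `grid_shape` are hypothetical names, the three innermost loops are collapsed into one tile-level matmul, and the two outer loops run sequentially here instead of being mapped to CTAs.

```python
import numpy as np


def tiled_matmul(a, b, BLOCK_M, BLOCK_N, BLOCK_K):
    """Host-side model of the tiled GEMM loop nest with f32 accumulation."""
    M, K = a.shape
    _, N = b.shape
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0
    c = np.zeros((M, N), dtype=np.float32)
    for bi in range(0, M, BLOCK_M):          # would map to blockIdx.x
        for bj in range(0, N, BLOCK_N):      # would map to blockIdx.y
            for bk in range(0, K, BLOCK_K):  # sequential loop inside one CTA
                a_tile = a[bi:bi + BLOCK_M, bk:bk + BLOCK_K].astype(np.float32)
                b_tile = b[bk:bk + BLOCK_K, bj:bj + BLOCK_N].astype(np.float32)
                # Stands in for the three innermost i/j/k loops.
                c[bi:bi + BLOCK_M, bj:bj + BLOCK_N] += a_tile @ b_tile
    return c


def grid_shape(M, N, BLOCK_M, BLOCK_N):
    """Number of CTAs along x and y when each CTA owns one output tile."""
    return (M // BLOCK_M, N // BLOCK_N, 1)
```

For example, `grid_shape(256, 1024, 128, 128)` gives a 2x8x1 grid, one CTA per 128x128 output tile.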

## Multistage Kernel

This kernel launches a single warp group (128 threads). The primary thread
(pthread) issues the TMA load requests. All threads then wait for the data
and perform the mma operations. After the main loop, the threads store
their fragmented registers to shared memory and then copy the result from
shared memory to global memory; this part is called the epilogue.

Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|       |Prologue ---->  |MainLoop ---->                                                                   |Epilogue  |
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
```
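
The stage indices in the main-loop columns follow a simple modular pattern. The sketch below reproduces only those columns; `mainloop_schedule` is a hypothetical helper, the prologue and epilogue are omitted, and the refill rule `(i + num_stages - 1) % num_stages` is inferred from the timeline rather than taken from the kernel builder.

```python
def mainloop_schedule(num_iters, num_stages):
    """Return the pthread's main-loop events, one string per iteration."""
    events = []
    for i in range(num_iters):
        wait = i % num_stages                       # stage consumed by this mma
        refill = (i + num_stages - 1) % num_stages  # stage the next TMA load targets (assumed)
        events.append(f"[wait-{wait}][mma][tma-{refill}]")
    return events


print(mainloop_schedule(3, 3))
# ['[wait-0][mma][tma-2]', '[wait-1][mma][tma-0]', '[wait-2][mma][tma-1]']
```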

## Warp Specialization Kernel

This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
as the `producer warp group` and the other as the `consumer warp group`. The
`producer warp group` is responsible for requesting TMA loads, while the
`consumer warp group` performs the mma operations. The epilogue section is
handled by the `consumer warp group` as its threads own the fragmented registers.

Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
```
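
Both warp groups walk the same ring of stages and track an mbarrier phase parity that flips each time the ring wraps. The sketch below models only that bookkeeping; `stage_and_parity` is a hypothetical helper. The producer rule (initial parity 1, flip when `stage == num_stages - 1`) follows the builder code in this patch, while applying the same flip rule to the consumer (initial parity 0) is an assumption of this sketch.

```python
def stage_and_parity(num_iters, num_stages, initial_parity):
    """Return (iteration, stage, parity) tuples for one warp group's wait loop."""
    parity = initial_parity
    out = []
    for iv in range(num_iters):
        stage = iv % num_stages
        out.append((iv, stage, parity))
        # Flip the parity after the last stage of the ring has been visited.
        if stage == num_stages - 1:
            parity ^= 1
    return out


producer = stage_and_parity(4, 2, initial_parity=1)  # waits on mbarDONE[stage]
consumer = stage_and_parity(4, 2, initial_parity=0)  # waits on mbarTMA[stage]
print(producer)  # [(0, 0, 1), (1, 1, 1), (2, 0, 0), (3, 1, 0)]
print(consumer)  # [(0, 0, 0), (1, 1, 0), (2, 0, 1), (3, 1, 1)]
```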
---
 .../GPU/CUDA/sm90/python/lit.local.cfg        |   2 +
 .../GPU/CUDA/sm90/python/matmul.py            | 186 +++++
 .../GPU/CUDA/sm90/python/tools/lit.local.cfg  |   3 +
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 676 ++++++++++++++++++
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  44 ++
 5 files changed, 911 insertions(+)
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+    config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..e153fcb44b9860
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,186 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test validates the correctness of the GEMM kernels supported by the
+# NVGPU dialect; Multistage and Warp Specialization kernels are currently
+# supported.
+# The test builds the IR through the MLIR Python bindings, so the same
+# builder can generate kernels for a different GEMM shape, tile size, or
+# data type.
+# The entry function is `matmul`, where one can specify GEMM shape, tile size,
+# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
+# number of stages.
+# Verification is done via numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16,                # input types
+#        output_type=np.float32,               # output type
+#        M=4096, N=4096, K=4096,               # Shape
+#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+#        use_warp_specialization=True,         # Enable Warp Specialization
+#        max_num_stages=3)                     # Number of stages in shared memory
+#
+# ===--- Parallelism Across CTAs  ---===
+#
+# GEMM includes three loops defining the shape of the GEMM, specified in the
+# `matmul` function.
+# The program builds IR using the following loop structure, tiling the loops
+# with the given tile size and parallelizing the two outermost loops into the
+# first and second dimensions of CTAs.
+#
+# for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
+#     for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
+#         for(bk = 0; bk < K; bk += BLOCK_K)
+#             for(i = bi; i < (bi + BLOCK_M); ++i)
+#                 for(j = bj; j < (bj + BLOCK_N); ++j)
+#                     for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA load requests. All threads then wait for the data
+# and perform the mma operations. After the main loop, the threads store
+# their fragmented registers to shared memory and then copy the result from
+# shared memory to global memory; this part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |       |Prologue ---->  |MainLoop ---->                                                                   |Epilogue  |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+#
+# ===--- Warp Specialization Kernel  ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as the `producer warp group` and the other as the `consumer warp group`. The
+# `producer warp group` is responsible for requesting TMA loads, while the
+# `consumer warp group` performs the mma operations. The epilogue section is
+# handled by the `consumer warp group` as its threads own the fragmented registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
+# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+import ctypes
+from mlir import runtime as rt
+
+def generate_matmul(input_type=np.float16,
+                    output_type=np.float32,
+                    M=4096,
+                    N=4096,
+                    K=4096,
+                    BLOCK_M=128,
+                    BLOCK_N=128,
+                    BLOCK_K=64,
+                    use_warp_specilization=True,
+                    saveIR=False,
+                    max_num_stages=3):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+        if use_warp_specilization:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
+                                                                 BLOCK_K, max_num_stages)
+        else:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
+                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+
+        mlir_nvgpu_module.operation.verify()
+        
+        # Save generated IR
+        if saveIR:
+            # print(mlir_nvgpu_module)
+            original_stdout = sys.stdout
+            with open('gemm.mlir', 'w') as f:
+                sys.stdout = f
+                print(mlir_nvgpu_module)
+                sys.stdout = original_stdout
+
+        # Get compiler
+        options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+        support_lib = os.getenv("SUPPORT_LIB")
+        if support_lib is None or not os.path.exists(support_lib):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+
+        # Compile
+        engine = compiler.compile_and_jit(mlir_nvgpu_module)
+        return engine
+
+
+def matmul(input_type=np.float16,
+           output_type=np.float32,
+           M=128,
+           N=128,
+           K=128,
+           BLOCK_M=128,
+           BLOCK_N=128,
+           BLOCK_K=64,
+           use_warp_specilization=True,
+           saveIR=False,
+           max_num_stages=3,
+           print_results=False,
+           no_verify=False):
+    # Print the configuration
+    ity = "f16" if input_type == np.float16 else "f32"
+    oty = "f16" if output_type == np.float16 else "f32"
+    gemmty = "Warp Specialization" if use_warp_specilization else "Multistage"
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
+          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
+          str(max_num_stages) + " --===")
+
+    # Build IR and compile
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
+                             saveIR, max_num_stages)
+
+    # Allocate matrices and invoke the matmul
+    c = np.zeros((M, N), output_type)
+    a = np.random.randn(M, K).astype(input_type)
+    b = np.random.randn(K, N).astype(input_type)
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+
+    # Launch the MLIR generated kernel
+    engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+    float_formatter = "{:.2f}".format
+    np.set_printoptions(formatter={'float_kind': float_formatter})
+
+    if print_results:
+        print(c)
+
+    # Verify the results
+    if not no_verify:
+        ref = a.astype(input_type) @ b.astype(input_type)
+        if print_results:
+            print(ref)
+        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+    print("PASS ")
+
+
+# GEMM Multistage       f32 += f16 * f16
+matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+# GEMM Warp Specialized f32 += f16 * f16
+matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..09ab35d4c5f15e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,676 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+
+TMA_LAST_DIM_F16 = 64  # 128B float16
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808
+f16_byte = 2
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty_format = None
+        if ir.IndexType.isinstance(arg.type):
+            ty_format = "%llu"
+        if ir.IntegerType.isinstance(arg.type):
+            width = ir.IntegerType(arg.type).width
+            if width == 64:
+                ty_format = "%llu"
+            elif width == 32:
+                ty_format = "%d"
+            elif width == 1:
+                ty_format = "%i"
+        if ir.F32Type.isinstance(arg.type):
+            ty_format = "%f"
+        if ty_format is None:
+            raise NotImplementedError(arg.type)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+        scf.yield_([])
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(fmt.format(*type_formats) + "\n", args)
+        scf.yield_([])
+
+
+def c(value, ty=None):
+    ty = ir.IndexType.get() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def generate_matmul_ws(input_type=np.float16,
+                       output_type=np.float32,
+                       M=4096,
+                       N=4096,
+                       K=4096,
+                       BLOCK_M=128,
+                       BLOCK_N=128,
+                       BLOCK_K=128,
+                       max_num_stages=3):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+            
+            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE * 2, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. This is needed for vectorized ld/st
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+
+                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                warp_id = arith.divui(tidx, c(32))
+                warpgroup_id = arith.divui(warp_id, c(4))
+                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
+                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 2. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                gpu.barrier()
+
+                # GPU Step 3. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+                
+                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+                    
+                    # Step 5.1. Reduce register size
+                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+
+                    # Step 5.2. TMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    with ir.InsertionPoint(for_op.body):
+                        phaseParity = for_op.inner_iter_args[0]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+
+                        # Step 5.2.1. Wait mbarDONE
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=producerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=producerPrimaryThread)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                        # Step 5.2.2. Load TMA
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                                  dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset2, [])
+                        debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
+                                    a_offset,
+                                    b_offset,
+                                    b_offset2,
+                                    predicate=producerPrimaryThread)
+                        coord = arith.muli(c(64), iv)
+                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                             mbarTMA,
+                                             a_tma_desc,
+                                             coordinates=[coord, dimX],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY, coord],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+                        dimY2 = arith.addi(dimY, c(64))
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY2, coord],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+
+                        # Step 5.2.3. Arrive mbarTMA
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
+                                    iv,
+                                    stage,
+                                    predicate=producerPrimaryThread)
+                        scf.yield_([phaseParity])
+                    scf.yield_([])
+
+                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
+                if_op = scf.IfOp(is_consumer)
+                with ir.InsertionPoint(if_op.then_block):
+                    
+                    # Step 6.1. Increase register size
+                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
+                    
+                    # GPU Step 6.2. Initialize MMA registers
+                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+
+                    # Step 6.3. MMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                    with ir.InsertionPoint(for_op.body):
+                        # Step 6.3.1. Wait mbarTMA
+                        phaseParity = for_op.inner_iter_args[1]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=consumerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=consumerPrimaryThread)
+                        
+                        # Step 6.3.2. Create WGMMA Descriptors
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                                                   dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                                                   dynamic_smem, b_offset, [])
+                        debug_print("[cons] a_offset={} b_offset={}",
+                                    a_offset,
+                                    b_offset,
+                                    predicate=consumerPrimaryThread)
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+                        # Step 6.3.3. MMA
+                        carry_acc = for_op.inner_iter_args[0]
+                        new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+                        # Step 6.3.4. Arrive mbarDONE
+                        p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+                        p_arrive = arith.andi(consumerPrimaryThread, p1)
+                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
+                            barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+                            remoteCtaId = c(0)
+                            pred = consumerPrimaryThread
+                            debug_print("[cons] {}  | mbarDONE[{}] arrive  remoteCtaId={} pred={}",
+                                        iv,
+                                        barId,
+                                        remoteCtaId,
+                                        pred,
+                                        predicate=consumerPrimaryThread)
+                            nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                            debug_print("[cons] {}  | mbarDONE[{}] arrive  remoteCtaId={} pred={} [done]",
+                                        iv,
+                                        barId,
+                                        remoteCtaId,
+                                        pred,
+                                        predicate=consumerPrimaryThread)
+                            scf.yield_([])
+
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                        # Step 6.3.5. Yield
+                        scf.yield_([new_acc, phaseParity])
+
+                    # Step 6.4. Wait All WGMMA
+                    nvvm.WgmmaWaitGroupSyncOp(0)
+
+                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+                        barId = c((K // BLOCK_K) % num_stages)
+                        nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                        scf.yield_([])
+
+                    # Step 6.5. Epilogue (registers --> shared memory)
+                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
+                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                    scf.yield_([])
+                gpu.barrier()
+
+                # GPU Step 9. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                vlen = 1
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                                        [for_op.induction_variable])
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.wait(token_ty, [t9])
+
+    mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
+
+
+def generate_matmul_multistage(input_type=np.float16,
+                               output_type=np.float32,
+                               M=4096,
+                               N=4096,
+                               K=4096,
+                               BLOCK_M=128,
+                               BLOCK_N=128,
+                               BLOCK_K=64,
+                               max_num_stages=3):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_multistage(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+
+            # Step 3. Launch Kernel with 1 Warpgroup
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. Bootstrapping
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
+                warpId = arith.divui(tidx, c(32))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 1. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+                gpu.barrier()
+
+                # GPU Step 2. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+
+                # GPU Step 3. Prologue (global memory --> shared memory)
+                for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+                with ir.InsertionPoint(for_op.body):
+                    iv = for_op.induction_variable
+
+                    # Step 3.1. Calculate offsets
+                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
+                    a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                              dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                    b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, b_offset, [])
+                    b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                    b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, b_offset2, [])
+
+                    # Step 3.2. TMA Load
+                    coord = arith.muli(c(64), iv)
+                    dimY2 = arith.addi(dimY, c(64))
+                    debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                                a_offset,
+                                b_offset,
+                                b_offset2,
+                                coord, dimX,
+                                dimY, coord,
+                                predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                         mbarTMA,
+                                         a_tma_desc,
+                                         coordinates=[coord, dimX],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                         mbarTMA,
+                                         b_tma_desc,
+                                         coordinates=[dimY, coord],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                         mbarTMA,
+                                         b_tma_desc,
+                                         coordinates=[dimY2, coord],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+
+                    # Step 3.3. mbarTMA arrive
+                    debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
+                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+                    scf.yield_([])
+
+                # GPU Step 4. Main Loop
+                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                with ir.InsertionPoint(for_op.body):
+                    # Step 4.1. Wait mbarTMA
+                    phaseParity = for_op.inner_iter_args[1]
+                    iv = for_op.induction_variable
+                    stage = arith.remui(iv, c(num_stages))
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}", stage, phaseParity, predicate=primaryThread)
+                    nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={} [done]", stage, phaseParity, predicate=primaryThread)
+
+                    # Step 4.2. Create WGMMA Descriptors
+                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                    a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                                               dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                    b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                                               dynamic_smem, b_offset, [])
+                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+                    # Step 4.3. MMA
+                    carry_acc = for_op.inner_iter_args[0]
+                    new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+                    # Step 4.4. Load TMA for next stage
+                    p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+                    p = arith.andi(primaryThread, p1)
+                    with ir.InsertionPoint(scf.IfOp(p).then_block):
+                        nextStage = arith.addi(iv, c(num_stages-1))
+                        nextSlot = arith.remui(nextStage, c(num_stages))
+                        a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset2, [])
+
+                        coord = arith.muli(c(64), nextStage)
+                        debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                                iv, a_offset,
+                                b_offset,
+                                b_offset2,
+                                coord, dimX,
+                                dimY, coord,
+                                predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                            mbarTMA,
+                                            a_tma_desc,
+                                            coordinates=[coord, dimX],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                            mbarTMA,
+                                            b_tma_desc,
+                                            coordinates=[dimY, coord],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+                        dimY2 = arith.addi(dimY, c(64))
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                            mbarTMA,
+                                            b_tma_desc,
+                                            coordinates=[dimY2, coord],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+
+                        debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
+                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+                        scf.yield_([])
+                    # Step 4.5. Change the phaseParity
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                    phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                    # Step 4.6. Yield
+                    scf.yield_([new_acc, phaseParity])
+
+                # GPU Step 5. Wait All WGMMA groups
+                nvvm.WgmmaWaitGroupSyncOp(0)
+                gpu.barrier()
+
+                # GPU Step 6. Epilogue (registers --> shared memory)
+                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                debug_print("Storing", predicate=primaryThread)
+                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                gpu.barrier()
+
+                # GPU Step 7. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                vlen = 1
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                                        [for_op.induction_variable])
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.wait(token_ty, [t9])
+
+    mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..2a4f67326ee718
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -0,0 +1,44 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the NvgpuCompiler class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+class NvgpuCompiler:
+    """Nvgpu class for compiling and building MLIR modules."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+        self.pipeline = pipeline
+        self.shared_libs = shared_libs
+        self.opt_level = opt_level
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the nvgpu pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+    
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)
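
Reviewer note, not part of the patch: the main loops in matmulBuilder.py above rotate through `num_stages` shared-memory slots and flip the mbarrier phase parity each time the last slot is consumed. The plain-Python sketch below mirrors that bookkeeping on the host so the expected `stage`, `phaseParity`, `a_offset`, and `b_offset` values can be checked by hand. The helper names `pipeline_schedule` and `smem_offsets` are hypothetical and do not exist in the patch; the arithmetic follows the `arith.remui`, `arith.select`/`arith.xori`, and offset expressions shown in the diff, assuming f16 inputs (2 bytes per element).

```python
def pipeline_schedule(num_k_tiles, num_stages):
    """Return (iv, stage, parity) for every main-loop iteration.

    Hypothetical host-side mirror of the IR built in the main loop:
    stage = iv % num_stages, and the parity bit flips after the
    iteration that consumes the last slot (stage == num_stages - 1).
    """
    parity = 0
    schedule = []
    for iv in range(num_k_tiles):
        stage = iv % num_stages
        schedule.append((iv, stage, parity))
        if stage == num_stages - 1:
            parity ^= 1
    return schedule


def smem_offsets(stage, num_stages, block_m=128, block_n=128, block_k=64,
                 elem_bytes=2):
    """Return (a_offset, b_offset) in bytes for one pipeline slot.

    All A tiles are laid out first, followed by all B tiles, which is
    why b_offset adds lhs_tile_bytes * num_stages.
    """
    lhs_tile_bytes = block_m * block_k * elem_bytes
    rhs_tile_bytes = block_n * block_k * elem_bytes
    a_offset = stage * lhs_tile_bytes
    b_offset = stage * rhs_tile_bytes + lhs_tile_bytes * num_stages
    return a_offset, b_offset


if __name__ == "__main__":
    for iv, stage, parity in pipeline_schedule(num_k_tiles=7, num_stages=3):
        a_off, b_off = smem_offsets(stage, num_stages=3)
        print(f"iv={iv} stage={stage} parity={parity} "
              f"a_offset={a_off} b_offset={b_off}")
```

Comparing this output against the `debug_print` traces from the kernel may be a quick way to spot a slot or parity mismatch.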

>From eccc8ba445c5beaac8cdc3f108efc24b5692c423 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:39:08 +0000
Subject: [PATCH 2/2] format with yapf

---
 .../GPU/CUDA/sm90/python/matmul.py            |  51 +-
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 655 ++++++++++++------
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  15 +-
 3 files changed, 483 insertions(+), 238 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index e153fcb44b9860..00313270d723d6 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -85,6 +85,7 @@
 import ctypes
 from mlir import runtime as rt
 
+
 def generate_matmul(input_type=np.float16,
                     output_type=np.float32,
                     M=4096,
@@ -96,16 +97,19 @@ def generate_matmul(input_type=np.float16,
                     use_warp_specilization=True,
                     saveIR=False,
                     max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
+    ):
         if use_warp_specilization:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
-                                                                 BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
         else:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
-                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
 
         mlir_nvgpu_module.operation.verify()
-        
+
         # Save generated IR
         if saveIR:
             # print(mlir_nvgpu_module)
@@ -119,8 +123,11 @@ def generate_matmul(input_type=np.float16,
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
+                                    support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options,
+                                               opt_level=3,
+                                               shared_libs=[support_lib])
 
         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
@@ -144,13 +151,15 @@ def matmul(input_type=np.float16,
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
-          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
-          str(max_num_stages) + " --===")
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
+          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
+          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
+          ", stages " + str(max_num_stages) + " --===")
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
-                             saveIR, max_num_stages)
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
+                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
+                             max_num_stages)
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
@@ -181,6 +190,18 @@ def matmul(input_type=np.float16,
 
 
 # GEMM Multistage       f32 += f16 * f16
-matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+matmul(np.float16,
+       np.float32,
+       128,
+       128,
+       4096,
+       max_num_stages=3,
+       use_warp_specilization=False)
 # GEMM Warp Specilized  f32 += f16 * f16
-matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
+matmul(np.float16,
+       np.float32,
+       256,
+       1024,
+       512,
+       max_num_stages=3,
+       use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 09ab35d4c5f15e..0fda8698606cd5 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -11,7 +11,6 @@
 from mlir.dialects import scf
 from mlir.dialects import vector
 
-
 TMA_LAST_DIM_F16 = 64  # 128B flaot16
 WARP_SIZE = 32
 WARP_GROUP_SIZE = WARP_SIZE * 4
@@ -81,7 +80,8 @@ def generate_matmul_ws(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
+                                          BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -99,25 +99,36 @@ def generate_matmul_ws(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
+               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
+                            str(smem_space) + ", num_barriers = " +
                             str(num_stages) + ">")
-    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
+                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
                            str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
 
     with ir.InsertionPoint(module.body):
 
@@ -139,13 +150,18 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
+                         (b_device, b_tma_desc_ty, b_tma_shape)]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
-                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
+                    x_device)
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
+                                                map(c, tile_shape)).result)
             a_tma_desc, b_tma_desc = tma_descs
-            
+
             # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
             cta_m = M // BLOCK_M
             cta_n = N // BLOCK_N
@@ -155,26 +171,37 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             launch_op = gpu.LaunchOp(token_ty, [t7],
                                      *map(c, grid),
                                      *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+                                     dynamicSharedMemorySize=c(smem_size,
+                                                               ty=i32))
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. This is needed for vectorized ld/st
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC, ),
+                                      i8,
+                                      memory_space=smem_space))
                 ticks = c(10000000)
 
                 # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
                 tidx = gpu.thread_id(gpu.Dimension.x)
-                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                wgPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq,
+                    arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
                 warp_id = arith.divui(tidx, c(32))
                 warpgroup_id = arith.divui(warp_id, c(4))
-                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
-                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
-                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
-                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
-                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
-                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                is_producer = arith.cmpi(
+                    arith.CmpIPredicate.eq, warpgroup_id,
+                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                is_consumer = arith.cmpi(
+                    arith.CmpIPredicate.eq, warpgroup_id,
+                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
+                                                   tidx,
+                                                   c(PRODUCER_PRIMARY_THREAD))
+                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
+                                                   tidx,
+                                                   c(CONSUMER_PRIMARY_THREAD))
                 bidx = gpu.block_id(gpu.Dimension.x)
                 bidy = gpu.block_id(gpu.Dimension.y)
                 dimX = arith.muli(bidx, c(BLOCK_M))
@@ -184,57 +211,88 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
-                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarTMA,
+                                        c(1),
+                                        c(i),
+                                        predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE,
+                                        c(1),
+                                        c(i),
+                                        predicate=wgPrimaryThread)
                 gpu.barrier()
 
                 # GPU Step 3. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
-                
+                nvgpu.tma_prefetch_descriptor(a_tma_desc,
+                                              predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc,
+                                              predicate=wgPrimaryThread)
+
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                 with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-                    
+
                     # Step 5.1. Reduce register size
-                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE,
+                                        nvvm.SetMaxRegisterAction.decrease)
 
                     # Step 5.2. TMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                       [arith.constant(i1, 1)])
                     with ir.InsertionPoint(for_op.body):
                         phaseParity = for_op.inner_iter_args[0]
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
 
                         # Step 5.2.1. Wait mbarDONE
-                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=producerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
-                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=producerPrimaryThread)
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                        debug_print(
+                            "[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarDONE,
+                                                      phaseParity,
+                                                      ticks,
+                                                      mbarId=stage)
+                        debug_print(
+                            "[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                       c(num_stages - 1))
+                        phaseParity = arith.select(
+                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity)
 
                         # Step 5.2.2. Load TMA
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                                  dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset, [])
-                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset2, [])
-                        debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
-                                    a_offset,
-                                    b_offset,
-                                    b_offset2,
-                                    predicate=producerPrimaryThread)
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(a_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset2, [])
+                        debug_print(
+                            "[prod] a_offset={} b_offset={} b_offset2={}",
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            predicate=producerPrimaryThread)
                         coord = arith.muli(c(64), iv)
                         nvgpu.TmaAsyncLoadOp(a_tma_slice,
                                              mbarTMA,
@@ -257,8 +315,15 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                                              predicate=producerPrimaryThread)
 
                         # Step 5.2.3. Arrive mbarTMA
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive",
+                                    iv,
+                                    stage,
+                                    predicate=producerPrimaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA,
+                            c(txcount),
+                            stage,
+                            predicate=producerPrimaryThread)
                         debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
                                     iv,
                                     stage,
@@ -269,75 +334,104 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                 if_op = scf.IfOp(is_consumer)
                 with ir.InsertionPoint(if_op.then_block):
-                    
+
                     # Step 6.1. Increase register size
-                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
-                    
+                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE,
+                                        nvvm.SetMaxRegisterAction.increase)
+
                     # GPU Step 6.2. Initialize MMA registers
                     acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
 
                     # Step 6.3. MMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                       [acc, arith.constant(i1, 0)])
                     with ir.InsertionPoint(for_op.body):
                         # Step 6.3.1. Wait mbarTMA
                         phaseParity = for_op.inner_iter_args[1]
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
-                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={}",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=consumerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
-                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=consumerPrimaryThread)
-                        
+                        debug_print(
+                            "[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarTMA,
+                                                      phaseParity,
+                                                      ticks,
+                                                      mbarId=stage)
+                        debug_print(
+                            "[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread)
+
                         # Step 6.3.2. Create WGMMA Descriptors
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                        a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
-                                                   dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
-                                                   dynamic_smem, b_offset, [])
+                        a_tile_slice = memref.view(
+                            ir.MemRefType.get(a_tile_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tile_slice = memref.view(
+                            ir.MemRefType.get(b_tile_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
                         debug_print("[cons] a_offset={} b_offset={}",
                                     a_offset,
                                     b_offset,
                                     predicate=consumerPrimaryThread)
-                        da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
-                        db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(
+                            a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(
+                            b_wgmma_ty, b_tile_slice, b_tma_desc)
 
                         # Step 6.3.3. MMA
                         carry_acc = for_op.inner_iter_args[0]
-                        new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+                        new_acc = nvgpu.WarpgroupMmaOp(acc.type,
+                                                       da,
+                                                       db,
+                                                       carry_acc,
+                                                       transposeB=True)
 
                         # Step 6.3.4. Arrive mbarDONE
                         p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                         p_arrive = arith.andi(consumerPrimaryThread, p1)
                         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                             p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
-                            barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+                            barId = arith.select(p, c(num_stages - 1),
+                                                 arith.subi(stage, c(1)))
                             remoteCtaId = c(0)
                             pred = consumerPrimaryThread
-                            debug_print("[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
-                                        iv,
-                                        barId,
-                                        remoteCtaId,
-                                        pred,
-                                        predicate=consumerPrimaryThread)
-                            nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
-                            debug_print("[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
-                                        iv,
-                                        barId,
-                                        remoteCtaId,
-                                        pred,
-                                        predicate=consumerPrimaryThread)
+                            debug_print(
+                                "[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
+                                iv,
+                                barId,
+                                remoteCtaId,
+                                pred,
+                                predicate=consumerPrimaryThread)
+                            nvgpu.mbarrier_arrive(
+                                ir.Type.parse("!nvgpu.mbarrier.token"),
+                                mbarDONE, barId)
+                            debug_print(
+                                "[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
+                                iv,
+                                barId,
+                                remoteCtaId,
+                                pred,
+                                predicate=consumerPrimaryThread)
                             scf.yield_([])
 
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                       c(num_stages - 1))
+                        phaseParity = arith.select(
+                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity)
 
                         # Step 6.3.5. Yield
                         scf.yield_([new_acc, phaseParity])
@@ -345,32 +439,45 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                     # Step 6.3. Wait All WGMMA
                     nvvm.WgmmaWaitGroupSyncOp(0)
 
-                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+                    with ir.InsertionPoint(
+                            scf.IfOp(consumerPrimaryThread).then_block):
                         barId = c((K // BLOCK_K) % num_stages)
-                        nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                        nvgpu.mbarrier_arrive(
+                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE,
+                            barId)
                         scf.yield_([])
 
                     # Step 6.4. Epilogue (registers --> shared memory)
-                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
+                                                    c_elem_ty,
+                                                    memory_space=smem_space)
                     acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
-                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
+                    debug_print("[cons]  | Storing",
+                                predicate=consumerPrimaryThread)
                     nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                     scf.yield_([])
                 gpu.barrier()
 
                 # GPU Step 9. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
+                                       c_elem_ty,
+                                       memory_space=smem_space)
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
-                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
-                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
-                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
-                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) +
+                                       ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(
+                    rty, c_device, [dimX, dimY], [], [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
+                                   c(vlen * WARP_GROUP_SIZE * 2))
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                    vdata = vector.load(ir.VectorType.get(
+                        (vlen, ), c_elem_ty), collapsed_smem,
                                         [for_op.induction_variable])
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
@@ -382,7 +489,8 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.wait(token_ty, [t9])
 
-    mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    mlir_matmul_warpspecialized.func_op.attributes[
+        "llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module
 
@@ -403,7 +511,8 @@ def generate_matmul_multistage(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
+                                          BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -421,25 +530,36 @@ def generate_matmul_multistage(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
+               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
+                            str(smem_space) + ", num_barriers = " +
                             str(num_stages) + ">")
-    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
+                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
                            str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
 
     with ir.InsertionPoint(module.body):
 
@@ -461,11 +581,16 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
+                         (b_device, b_tma_desc_ty, b_tma_shape)]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
-                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
+                    x_device)
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
+                                                map(c, tile_shape)).result)
             a_tma_desc, b_tma_desc = tma_descs
 
             # Step 3. Launch Kernel with 1 Warpgroup
@@ -477,13 +602,16 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             launch_op = gpu.LaunchOp(token_ty, [t7],
                                      *map(c, grid),
                                      *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+                                     dynamicSharedMemorySize=c(smem_size,
+                                                               ty=i32))
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. Bootstrapping
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC, ),
+                                      i8,
+                                      memory_space=smem_space))
                 ticks = c(10000000)
                 tidx = gpu.thread_id(gpu.Dimension.x)
                 primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
@@ -496,39 +624,58 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 # GPU Step 1. Initialize mbarrier groups
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+                    nvgpu.mbarrier_init(mbarTMA,
+                                        c(1),
+                                        c(i),
+                                        predicate=primaryThread)
                 gpu.barrier()
 
                 # GPU Step 2. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc,
+                                              predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc,
+                                              predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
-                for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+                for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
                 with ir.InsertionPoint(for_op.body):
                     iv = for_op.induction_variable
 
                     # Step 3.1. Calculate offsets
                     a_offset = arith.muli(iv, c(lhs_tile_bytes))
-                    a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                              dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                    b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, b_offset, [])
-                    b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                    b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, b_offset2, [])
+                    a_tma_slice = memref.view(
+                        ir.MemRefType.get(a_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)),
+                                          c(lhs_tile_bytes * num_stages))
+                    b_tma_slice_1 = memref.view(
+                        ir.MemRefType.get(b_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset, [])
+                    b_offset2 = arith.addi(
+                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                    b_tma_slice_2 = memref.view(
+                        ir.MemRefType.get(b_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset2, [])
 
                     # Step 3.2. TMA Load
                     coord = arith.muli(c(64), iv)
                     dimY2 = arith.addi(dimY, c(64))
-                    debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                                a_offset,
-                                b_offset,
-                                b_offset2,
-                                coord, dimX,
-                                dimY, coord,
-                                predicate=primaryThread)
+                    debug_print(
+                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                        a_offset,
+                        b_offset,
+                        b_offset2,
+                        coord,
+                        dimX,
+                        dimY,
+                        coord,
+                        predicate=primaryThread)
                     nvgpu.TmaAsyncLoadOp(a_tma_slice,
                                          mbarTMA,
                                          a_tma_desc,
@@ -549,89 +696,153 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                                          predicate=primaryThread)
 
                     # Step 3.2. mbarTMA arrive
-                    debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
-                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
-                    debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive",
+                                iv,
+                                predicate=primaryThread)
+                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA,
+                                                    c(txcount),
+                                                    iv,
+                                                    predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive [done]",
+                                iv,
+                                predicate=primaryThread)
                     scf.yield_([])
 
                 # GPU Step 4. Main Loop
                 acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
-                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                   [acc, arith.constant(i1, 0)])
                 with ir.InsertionPoint(for_op.body):
                     # Step 4.1. Wait mbarTMA
                     phaseParity = for_op.inner_iter_args[1]
                     iv = for_op.induction_variable
                     stage = arith.remui(iv, c(num_stages))
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}", stage, phaseParity, predicate=primaryThread)
-                    nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={} [done]", stage, phaseParity, predicate=primaryThread)
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}",
+                                stage,
+                                phaseParity,
+                                predicate=primaryThread)
+                    nvgpu.MBarrierTryWaitParityOp(mbarTMA,
+                                                  phaseParity,
+                                                  ticks,
+                                                  mbarId=stage)
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread)
 
                     # Step 4.2. Create WGMMA Descriptors
                     a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                    a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
-                                               dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                    b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
-                                               dynamic_smem, b_offset, [])
-                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
-                    da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
-                    db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+                    a_tile_slice = memref.view(
+                        ir.MemRefType.get(a_tile_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)),
+                                          c(lhs_tile_bytes * num_stages))
+                    b_tile_slice = memref.view(
+                        ir.MemRefType.get(b_tile_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset, [])
+                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}",
+                                iv,
+                                a_offset,
+                                b_offset,
+                                predicate=primaryThread)
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(
+                        a_wgmma_ty, a_tile_slice, a_tma_desc)
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(
+                        b_wgmma_ty, b_tile_slice, b_tma_desc)
 
                     # Step 4.3. MMA
                     carry_acc = for_op.inner_iter_args[0]
-                    new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+                    new_acc = nvgpu.WarpgroupMmaOp(acc.type,
+                                                   da,
+                                                   db,
+                                                   carry_acc,
+                                                   transposeB=True)
 
                     # Step 4.4. Load TMA for next stage
-                    p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+                    p1 = arith.cmpi(arith.CmpIPredicate.ult,
+                                    arith.addi(iv, c(num_stages - 1)),
+                                    c(K // BLOCK_K))
                     p = arith.andi(primaryThread, p1)
                     with ir.InsertionPoint(scf.IfOp(p).then_block):
-                        nextStage = arith.addi(iv, c(num_stages-1))
+                        nextStage = arith.addi(iv, c(num_stages - 1))
                         nextSlot = arith.remui(nextStage, c(num_stages))
                         a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset, [])
-                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset2, [])
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(a_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(nextSlot, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset2, [])
 
                         coord = arith.muli(c(64), nextStage)
-                        debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                                iv, a_offset,
-                                b_offset,
-                                b_offset2,
-                                coord, dimX,
-                                dimY, coord,
-                                predicate=primaryThread)
+                        debug_print(
+                            "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                            iv,
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            coord,
+                            dimX,
+                            dimY,
+                            coord,
+                            predicate=primaryThread)
                         nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                            mbarTMA,
-                                            a_tma_desc,
-                                            coordinates=[coord, dimX],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
+                                             mbarTMA,
+                                             a_tma_desc,
+                                             coordinates=[coord, dimX],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
                         nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                            mbarTMA,
-                                            b_tma_desc,
-                                            coordinates=[dimY, coord],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY, coord],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
                         dimY2 = arith.addi(dimY, c(64))
                         nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                            mbarTMA,
-                                            b_tma_desc,
-                                            coordinates=[dimY2, coord],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
-
-                        debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
-                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY2, coord],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
+
+                        debug_print("[MainLoop] mbarTMA[{}] arrive",
+                                    nextSlot,
+                                    predicate=primaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA,
+                            c(txcount),
+                            nextSlot,
+                            predicate=primaryThread)
+                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]",
+                                    nextSlot,
+                                    predicate=primaryThread)
                         scf.yield_([])
                     # Step 4.5. Change the phaseParity
-                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                    phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                   c(num_stages - 1))
+                    phaseParity = arith.select(
+                        p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                        phaseParity)
 
                     # Step 4.5. Yield
                     scf.yield_([new_acc, phaseParity])
@@ -641,25 +852,34 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 gpu.barrier()
 
                 # Step 6. Epilogue (registers --> shared memory)
-                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
+                                                c_elem_ty,
+                                                memory_space=smem_space)
                 acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                 debug_print("Storing", predicate=primaryThread)
                 nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                 gpu.barrier()
 
                 # GPU Step 7. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
+                                       c_elem_ty,
+                                       memory_space=smem_space)
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
-                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
-                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
-                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
-                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) +
+                                       ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(
+                    rty, c_device, [dimX, dimY], [], [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
+                                   c(vlen * WARP_GROUP_SIZE))
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                    vdata = vector.load(ir.VectorType.get(
+                        (vlen, ), c_elem_ty), collapsed_smem,
                                         [for_op.induction_variable])
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
@@ -671,6 +891,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.wait(token_ty, [t9])
 
-    mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    mlir_matmul_multistage.func_op.attributes[
+        "llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
index 2a4f67326ee718..ab218812bb4d96 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -15,10 +15,12 @@
 _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(_SCRIPT_PATH)
 
+
 class NvgpuCompiler:
     """Nvgpu class for compiling and building MLIR modules."""
 
-    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+    def __init__(self, options: str, opt_level: int,
+                 shared_libs: Sequence[str]):
         pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
         self.pipeline = pipeline
         self.shared_libs = shared_libs
@@ -31,14 +33,15 @@ def __call__(self, module: ir.Module):
     def compile(self, module: ir.Module):
         """Compiles the module by invoking the nvgpu pipeline."""
         passmanager.PassManager.parse(self.pipeline).run(module.operation)
-    
+
     def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Wraps the module in a JIT execution engine."""
-        return execution_engine.ExecutionEngine(
-            module, opt_level=self.opt_level, shared_libs=self.shared_libs
-        )
+        return execution_engine.ExecutionEngine(module,
+                                                opt_level=self.opt_level,
+                                                shared_libs=self.shared_libs)
 
-    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+    def compile_and_jit(self,
+                        module: ir.Module) -> execution_engine.ExecutionEngine:
         """Compiles and jits the module."""
         self.compile(module)
         return self.jit(module)
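
Note on the multistage hunks above: the slot, shared-memory offset, and phase-parity bookkeeping in GPU Steps 3 and 4 can be modeled on the host in plain Python. The sketch below is illustrative only and is not part of the patch. The helper names `smem_offsets` and `simulate_main_loop` are hypothetical, and the tile byte sizes in the usage lines are assumed values for a 128x64 f16 A tile and a 64x128 f16 B tile, not values taken from the test.

```python
def smem_offsets(slot, lhs_tile_bytes, rhs_tile_bytes, num_stages):
    """Byte offsets of the A and B tiles for one pipeline slot.

    Follows the arithmetic in the kernel builder: all A stages are laid
    out first, and the B stages start after lhs_tile_bytes * num_stages.
    """
    a_offset = slot * lhs_tile_bytes
    b_offset = slot * rhs_tile_bytes + lhs_tile_bytes * num_stages
    return a_offset, b_offset


def simulate_main_loop(k_tiles, num_stages):
    """Trace (iv, stage, phase_parity, prefetched_tile) per main-loop iteration.

    prefetched_tile is None once iv + num_stages - 1 reaches k_tiles,
    matching the `p1` predicate that guards the next-stage TMA load.
    """
    phase_parity = 0  # the loop-carried i1 starts at 0
    trace = []
    for iv in range(k_tiles):
        stage = iv % num_stages
        next_tile = iv + num_stages - 1
        prefetched = next_tile if next_tile < k_tiles else None
        trace.append((iv, stage, phase_parity, prefetched))
        # The parity flips after the last slot of the ring has been consumed.
        if stage == num_stages - 1:
            phase_parity ^= 1
    return trace


# Assumed sizes: 128 * 64 * 2 bytes for A, 64 * 128 * 2 bytes for B.
offsets = smem_offsets(1, 16384, 16384, 3)
trace = simulate_main_loop(4, 3)
```

With four K tiles and three stages, the first three iterations wait on parity 0 and the fourth, which reuses slot 0, waits on parity 1.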


