[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)

Guray Ozen llvmlistbot at llvm.org
Sun Mar 3 00:24:19 PST 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/81478

>From 72e6a310edaa2e7d8e7ede385de09e52bd302f9b Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:33:26 +0000
Subject: [PATCH 1/7] [mlir] GEMM Hopper Tensor Core Integration Test

This test validates the correctness of the GEMM kernels supported by the
NVGPU dialect, currently the Multistage and Warp Specialization kernels.
The test constructs the IR by metaprogramming it through the MLIR Python
bindings, which keeps the IR building generic: the shape, tile size, and
data type of the GEMM can be changed freely for testing purposes.
The entry point is the `matmul` function, where one specifies the GEMM shape,
tile size, data type, GEMM algorithm (Multistage or Warp Specialization), and
the maximum number of stages.
Results are verified against numpy's matmul.

Example:
```
matmul(input_type=np.float16,                # input types
       output_type=np.float32,               # output type
       M=4096, N=4096, K=4096,               # Shape
       BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
       use_warp_specialization=True,         # Enable Warp Specialization
       max_num_stages=3)                     # Number of stages in shared memory
```
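
For reference, a minimal sketch of the verification step, with an f32 NumPy
matmul standing in for the device result; the tolerances mirror the ones used
in `matmul.py`:
```
import numpy as np

M, N, K = 128, 128, 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)

# Stand-in for the result produced by the MLIR kernel (f32 accumulation).
c = a.astype(np.float32) @ b.astype(np.float32)

# Reference and tolerances as used by the test.
ref = a @ b
np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
```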
### Parallelism Across CTAs

The GEMM consists of three loops that define its shape, as specified in the
`matmul` function.
The program builds IR with the following loop structure, tiling the loops
with the given tile sizes and parallelizing the two outermost loops across
the first and second CTA dimensions.
```
for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
    for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
        for(bk = 0; bk < K; bk += BLOCK_K)
            for(i = bi; i < (bi + BLOCK_M); ++i)
                for(j = bj; j < (bj + BLOCK_N); ++j)
                    for(k = bk; k < (bk + BLOCK_K); ++k)
```
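
For concreteness, a host-side sketch of how the launch configuration falls out
of this tiling (the builder asserts that the shape is divisible by the tile
size; `launch_config` is a hypothetical helper, not part of the test):
```
def launch_config(M, N, BLOCK_M=128, BLOCK_N=128, use_warp_specialization=True):
    # One CTA per BLOCK_M x BLOCK_N output tile; the K loop stays inside the kernel.
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0
    grid = (M // BLOCK_M, N // BLOCK_N, 1)
    # A warp group is 128 threads; warp specialization launches two of them per CTA.
    block = (256 if use_warp_specialization else 128, 1, 1)
    return grid, block

print(launch_config(4096, 4096))  # ((32, 32, 1), (256, 1, 1))
```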

## Multistage Kernel

This kernel launches a single warp group (128 threads). The primary thread
(pthread) issues the TMA load requests. Threads collectively wait for the
data and then perform the mma operations. After the main loop completes, the
threads cooperatively store the result: first from fragmented registers to
shared memory, then from shared memory to global memory; this part is called
the epilogue.
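
The number of stages also bounds the dynamic shared memory per CTA: each stage
holds one A tile and one B tile in f16, and the epilogue reuses the same buffer
for the f32 output tile. A sketch of the sizing computed by the builder
(`shared_memory_bytes` is a hypothetical name):
```
def shared_memory_bytes(BLOCK_M, BLOCK_N, BLOCK_K, num_stages):
    lhs_tile_bytes = BLOCK_M * BLOCK_K * 2  # f16 A tile
    rhs_tile_bytes = BLOCK_N * BLOCK_K * 2  # f16 B tile
    smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
    smem_size_output = BLOCK_M * BLOCK_N * 4  # f32 accumulator tile
    return max(smem_size_input, smem_size_output)

print(shared_memory_bytes(128, 128, 64, 3))  # 98304 bytes
```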

Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
```
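
The bookkeeping behind this timeline is a circular buffer over the
shared-memory stages plus an mbarrier phase-parity bit that flips whenever the
stage index wraps. A host-side sketch of the schedule, mirroring the
`stage`/`phaseParity` logic in `matmulBuilder.py`:
```
def pipeline_schedule(k_iterations, num_stages=3):
    parity = 0
    schedule = []
    for iv in range(k_iterations):
        stage = iv % num_stages        # which shared-memory slot to wait on
        schedule.append((iv, stage, parity))
        if stage == num_stages - 1:    # parity flips when the slot index wraps
            parity ^= 1
    return schedule

for iv, stage, parity in pipeline_schedule(6):
    print(f"iter {iv}: wait mbarTMA[{stage}] phase={parity}")
```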

## Warp Specialization Kernel

This kernel launches two warp groups (2x128 threads) per CTA, specializing
one as the `producer warp group` and the other as the `consumer warp group`.
The `producer warp group` issues the TMA load requests, while the
`consumer warp group` performs the mma operations. The epilogue is handled by
the `consumer warp group`, as its threads own the fragmented registers.

Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
```
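
The producer/consumer split is decided per warp group from the thread index,
and each side reallocates its register budget with `setmaxregister`. A CPU-side
sketch of that bootstrapping (constants taken from `matmulBuilder.py`;
`classify_thread` is a hypothetical helper):
```
WARP_GROUP_SIZE = 128
PRODUCER_REGISTER_SIZE = 40   # producer shrinks its register file
CONSUMER_REGISTER_SIZE = 232  # consumer grows it to hold the accumulator

def classify_thread(tidx):
    warpgroup_id = tidx // WARP_GROUP_SIZE
    role = "producer" if warpgroup_id == 1 else "consumer"
    registers = PRODUCER_REGISTER_SIZE if role == "producer" else CONSUMER_REGISTER_SIZE
    # Each warp group has one primary thread that issues TMA loads / mbarrier arrives.
    is_primary = tidx in (0, 128)
    return role, registers, is_primary

for t in (0, 127, 128, 255):
    print(t, classify_thread(t))
```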
---
 .../GPU/CUDA/sm90/python/lit.local.cfg        |   2 +
 .../GPU/CUDA/sm90/python/matmul.py            | 186 +++++
 .../GPU/CUDA/sm90/python/tools/lit.local.cfg  |   3 +
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 676 ++++++++++++++++++
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  44 ++
 5 files changed, 911 insertions(+)
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+    config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..e153fcb44b9860
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,186 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test validates the correctness of the GEMM kernels supported by the
+# NVGPU dialect, currently the Multistage and Warp Specialization kernels.
+# The test constructs the IR by metaprogramming it through the MLIR Python
+# bindings, which keeps the IR building generic: the shape, tile size, and
+# data type of the GEMM can be changed freely for testing purposes.
+# The entry point is the `matmul` function, where one specifies the GEMM
+# shape, tile size, data type, GEMM algorithm (Multistage or Warp
+# Specialization), and the maximum number of stages.
+# Results are verified against numpy's matmul.
+#
+# Example:
+# matmul(input_type=np.float16,                # input types
+#        output_type=np.float32,               # output type
+#        M=4096, N=4096, K=4096,               # Shape
+#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+#        use_warp_specialization=True,         # Enable Warp Specialization
+#        max_num_stages=3)                     # Number of stages in shared memory
+#
+# ===--- Parallelism Across CTAs  ---===
+#
+# The GEMM consists of three loops that define its shape, as specified in the
+# `matmul` function.
+# The program builds IR with the following loop structure, tiling the loops
+# with the given tile sizes and parallelizing the two outermost loops across
+# the first and second CTA dimensions.
+#
+# for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
+#     for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
+#         for(bk = 0; bk < K; bk += BLOCK_K)
+#             for(i = bi; i < (bi + BLOCK_M); ++i)
+#                 for(j = bj; j < (bj + BLOCK_N); ++j)
+#                     for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA load requests. Threads collectively wait for the
+# data and then perform the mma operations. After the main loop completes, the
+# threads cooperatively store the result: first from fragmented registers to
+# shared memory, then from shared memory to global memory; this part is called
+# the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+#
+# ===--- Warp Specialization Kernel  ---===
+#
+# This kernel launches two warp groups (2x128 threads) per CTA, specializing
+# one as the `producer warp group` and the other as the `consumer warp group`.
+# The `producer warp group` issues the TMA load requests, while the
+# `consumer warp group` performs the mma operations. The epilogue is handled by
+# the `consumer warp group`, as its threads own the fragmented registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
+# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+import ctypes
+from mlir import runtime as rt
+
+def generate_matmul(input_type=np.float16,
+                    output_type=np.float32,
+                    M=4096,
+                    N=4096,
+                    K=4096,
+                    BLOCK_M=128,
+                    BLOCK_N=128,
+                    BLOCK_K=64,
+                    use_warp_specilization=True,
+                    saveIR=False,
+                    max_num_stages=3):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+        if use_warp_specilization:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
+                                                                 BLOCK_K, max_num_stages)
+        else:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
+                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+
+        mlir_nvgpu_module.operation.verify()
+        
+        # Save generated IR
+        if saveIR:
+            # print(mlir_nvgpu_module)
+            original_stdout = sys.stdout
+            with open('gemm.mlir', 'w') as f:
+                sys.stdout = f
+                print(mlir_nvgpu_module)
+                sys.stdout = original_stdout
+
+        # Get compiler
+        options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+        support_lib = os.getenv("SUPPORT_LIB")
+        if not os.path.exists(support_lib):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+
+        # Compile
+        engine = compiler.compile_and_jit(mlir_nvgpu_module)
+        return engine
+
+
+def matmul(input_type=np.float16,
+           output_type=np.float32,
+           M=128,
+           N=128,
+           K=128,
+           BLOCK_M=128,
+           BLOCK_N=128,
+           BLOCK_K=64,
+           use_warp_specilization=True,
+           saveIR=False,
+           max_num_stages=3,
+           print_results=False,
+           no_verify=False):
+    # Print the configuration
+    ity = "f16" if input_type == np.float16 else "f32"
+    oty = "f16" if output_type == np.float16 else "f32"
+    gemmty = "Warp Specialization" if use_warp_specilization else "Multistage"
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
+          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
+          str(max_num_stages) + " --===")
+
+    # Build IR and compile
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
+                             saveIR, max_num_stages)
+
+    # Allocate matrices and invoke the matmul
+    c = np.zeros((M, N), output_type)
+    a = np.random.randn(M, K).astype(input_type)
+    b = np.random.randn(K, N).astype(input_type)
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+
+    # Launch the MLIR generated kernel
+    engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+    float_formatter = "{:.2f}".format
+    np.set_printoptions(formatter={'float_kind': float_formatter})
+
+    if print_results:
+        print(c)
+
+    # Verify the results
+    if not no_verify:
+        ref = a.astype(input_type) @ b.astype(input_type)
+        if print_results:
+            print(ref)
+        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+    print("PASS ")
+
+
+# GEMM Multistage       f32 += f16 * f16
+matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+# GEMM Warp Specialized f32 += f16 * f16
+matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..09ab35d4c5f15e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,676 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+
+TMA_LAST_DIM_F16 = 64  # 128B float16
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808
+f16_byte = 2
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty_format = None
+        if ir.IndexType.isinstance(arg.type):
+            ty_format = "%llu"
+        if ir.IntegerType.isinstance(arg.type):
+            width = ir.IntegerType(arg.type).width
+            if width == 64:
+                ty_format = "%llu"
+            elif width == 32:
+                ty_format = "%d"
+            elif width == 1:
+                ty_format = "%i"
+        if ir.F32Type.isinstance(arg.type):
+            ty_format = "%f"
+        if ty_format is None:
+            raise NotImplementedError(arg.type)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+        scf.yield_([])
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(fmt.format(*type_formats) + "\n", args)
+        scf.yield_([])
+
+
+def c(value, ty=None):
+    ty = ir.IndexType.get() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def generate_matmul_ws(input_type=np.float16,
+                       output_type=np.float32,
+                       M=4096,
+                       N=4096,
+                       K=4096,
+                       BLOCK_M=128,
+                       BLOCK_N=128,
+                       BLOCK_K=128,
+                       max_num_stages=3):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+            
+            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE * 2, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. This is needed for vectorized ld/st
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+
+                # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                warp_id = arith.divui(tidx, c(32))
+                warpgroup_id = arith.divui(warp_id, c(4))
+                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
+                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 2. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                gpu.barrier()
+
+                # GPU Step 3. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+                
+                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+                    
+                    # Step 5.1. Reduce register size
+                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+
+                    # Step 5.2. TMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    with ir.InsertionPoint(for_op.body):
+                        phaseParity = for_op.inner_iter_args[0]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+
+                        # Step 5.2.1. Wait mbarDONE
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=producerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=producerPrimaryThread)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                        # Step 5.2.2. Load TMA
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                                  dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset2, [])
+                        debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
+                                    a_offset,
+                                    b_offset,
+                                    b_offset2,
+                                    predicate=producerPrimaryThread)
+                        coord = arith.muli(c(64), iv)
+                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                             mbarTMA,
+                                             a_tma_desc,
+                                             coordinates=[coord, dimX],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY, coord],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+                        dimY2 = arith.addi(dimY, c(64))
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY2, coord],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+
+                        # Step 5.2.3. Arrive mbarTMA
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
+                                    iv,
+                                    stage,
+                                    predicate=producerPrimaryThread)
+                        scf.yield_([phaseParity])
+                    scf.yield_([])
+
+                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
+                if_op = scf.IfOp(is_consumer)
+                with ir.InsertionPoint(if_op.then_block):
+                    
+                    # Step 6.1. Increase register size
+                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
+                    
+                    # GPU Step 6.2. Initialize MMA registers
+                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+
+                    # Step 6.3. MMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                    with ir.InsertionPoint(for_op.body):
+                        # Step 6.3.1. Wait mbarTMA
+                        phaseParity = for_op.inner_iter_args[1]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=consumerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=consumerPrimaryThread)
+                        
+                        # Step 6.3.2. Create WGMMA Descriptors
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                                                   dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                                                   dynamic_smem, b_offset, [])
+                        debug_print("[cons] a_offset={} b_offset={}",
+                                    a_offset,
+                                    b_offset,
+                                    predicate=consumerPrimaryThread)
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+                        # Step 6.3.3. MMA
+                        carry_acc = for_op.inner_iter_args[0]
+                        new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+                        # Step 6.3.4. Arrive mbarDONE
+                        p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+                        p_arrive = arith.andi(consumerPrimaryThread, p1)
+                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
+                            barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+                            remoteCtaId = c(0)
+                            pred = consumerPrimaryThread
+                            debug_print("[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
+                                        iv,
+                                        barId,
+                                        remoteCtaId,
+                                        pred,
+                                        predicate=consumerPrimaryThread)
+                            nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                            debug_print("[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
+                                        iv,
+                                        barId,
+                                        remoteCtaId,
+                                        pred,
+                                        predicate=consumerPrimaryThread)
+                            scf.yield_([])
+
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                        # Step 6.3.5. Yield
+                        scf.yield_([new_acc, phaseParity])
+
+                    # Step 6.3. Wait All WGMMA
+                    nvvm.WgmmaWaitGroupSyncOp(0)
+
+                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+                        barId = c((K // BLOCK_K) % num_stages)
+                        nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                        scf.yield_([])
+
+                    # Step 6.4. Epilogue (registers --> shared memory)
+                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
+                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                    scf.yield_([])
+                gpu.barrier()
+
+                # GPU Step 9. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                vlen = 1
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                                        [for_op.induction_variable])
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.wait(token_ty, [t9])
+
+    mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
+
+
+def generate_matmul_multistage(input_type=np.float16,
+                               output_type=np.float32,
+                               M=4096,
+                               N=4096,
+                               K=4096,
+                               BLOCK_M=128,
+                               BLOCK_N=128,
+                               BLOCK_K=64,
+                               max_num_stages=3):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_multistage(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+
+            # Step 3. Launch Kernel with 1 Warpgroup
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. Bootstrapping
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
+                warpId = arith.divui(tidx, c(32))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 1. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+                gpu.barrier()
+
+                # GPU Step 2. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+
+                # GPU Step 3. Prologue (global memory --> shared memory)
+                for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+                with ir.InsertionPoint(for_op.body):
+                    iv = for_op.induction_variable
+
+                    # Step 3.1. Calculate offsets
+                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
+                    a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                              dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                    b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, b_offset, [])
+                    b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                    b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, b_offset2, [])
+
+                    # Step 3.2. TMA Load
+                    coord = arith.muli(c(64), iv)
+                    dimY2 = arith.addi(dimY, c(64))
+                    debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                                a_offset,
+                                b_offset,
+                                b_offset2,
+                                coord, dimX,
+                                dimY, coord,
+                                predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                         mbarTMA,
+                                         a_tma_desc,
+                                         coordinates=[coord, dimX],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                         mbarTMA,
+                                         b_tma_desc,
+                                         coordinates=[dimY, coord],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                         mbarTMA,
+                                         b_tma_desc,
+                                         coordinates=[dimY2, coord],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+
+                    # Step 3.3. mbarTMA arrive
+                    debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
+                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+                    scf.yield_([])
+
+                # GPU Step 4. Main Loop
+                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                with ir.InsertionPoint(for_op.body):
+                    # Step 4.1. Wait mbarTMA
+                    phaseParity = for_op.inner_iter_args[1]
+                    iv = for_op.induction_variable
+                    stage = arith.remui(iv, c(num_stages))
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}", stage, phaseParity, predicate=primaryThread)
+                    nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={} [done]", stage, phaseParity, predicate=primaryThread)
+
+                    # Step 4.2. Create WGMMA Descriptors
+                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                    a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                                               dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                    b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                                               dynamic_smem, b_offset, [])
+                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+                    # Step 4.3. MMA
+                    carry_acc = for_op.inner_iter_args[0]
+                    new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+                    # Step 4.4. Load TMA for next stage
+                    p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+                    p = arith.andi(primaryThread, p1)
+                    with ir.InsertionPoint(scf.IfOp(p).then_block):
+                        nextStage = arith.addi(iv, c(num_stages-1))
+                        nextSlot = arith.remui(nextStage, c(num_stages))
+                        a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset2, [])
+
+                        coord = arith.muli(c(64), nextStage)
+                        debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                                iv, a_offset,
+                                b_offset,
+                                b_offset2,
+                                coord, dimX,
+                                dimY, coord,
+                                predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                            mbarTMA,
+                                            a_tma_desc,
+                                            coordinates=[coord, dimX],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                            mbarTMA,
+                                            b_tma_desc,
+                                            coordinates=[dimY, coord],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+                        dimY2 = arith.addi(dimY, c(64))
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                            mbarTMA,
+                                            b_tma_desc,
+                                            coordinates=[dimY2, coord],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+
+                        debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
+                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+                        scf.yield_([])
+                    # Step 4.5. Change the phaseParity
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                    phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                    # Step 4.6. Yield
+                    scf.yield_([new_acc, phaseParity])
+
+                # Step 5. Wait All WGMMA groups
+                nvvm.WgmmaWaitGroupSyncOp(0)
+                gpu.barrier()
+
+                # Step 6. Epilogue (registers --> shared memory)
+                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                debug_print("Storing", predicate=primaryThread)
+                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                gpu.barrier()
+
+                # GPU Step 7. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                vlen = 1
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_M))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                                        [for_op.induction_variable])
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.wait(token_ty, [t9])
+
+    mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..2a4f67326ee718
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -0,0 +1,44 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the NvgpuCompiler class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+class NvgpuCompiler:
+    """Nvgpu class for compiling and building MLIR modules."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+        self.pipeline = pipeline
+        self.shared_libs = shared_libs
+        self.opt_level = opt_level
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the nvgpu pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+    
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)

>From 26da27e3c5b5198750b1f80fbd9d1e5d6dfe229d Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:39:08 +0000
Subject: [PATCH 2/7] format with yapf

---
 .../GPU/CUDA/sm90/python/matmul.py            |  51 +-
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 655 ++++++++++++------
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  15 +-
 3 files changed, 483 insertions(+), 238 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index e153fcb44b9860..00313270d723d6 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -85,6 +85,7 @@
 import ctypes
 from mlir import runtime as rt
 
+
 def generate_matmul(input_type=np.float16,
                     output_type=np.float32,
                     M=4096,
@@ -96,16 +97,19 @@ def generate_matmul(input_type=np.float16,
                     use_warp_specilization=True,
                     saveIR=False,
                     max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
+    ):
         if use_warp_specilization:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
-                                                                 BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
         else:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
-                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
 
         mlir_nvgpu_module.operation.verify()
-        
+
         # Save generated IR
         if saveIR:
             # print(mlir_nvgpu_module)
@@ -119,8 +123,11 @@ def generate_matmul(input_type=np.float16,
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
+                                    support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options,
+                                               opt_level=3,
+                                               shared_libs=[support_lib])
 
         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
@@ -144,13 +151,15 @@ def matmul(input_type=np.float16,
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
-          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
-          str(max_num_stages) + " --===")
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
+          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
+          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
+          ", stages " + str(max_num_stages) + " --===")
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
-                             saveIR, max_num_stages)
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
+                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
+                             max_num_stages)
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
@@ -181,6 +190,18 @@ def matmul(input_type=np.float16,
 
 
 # GEMM Multistage       f32 += f16 * f16
-matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+matmul(np.float16,
+       np.float32,
+       128,
+       128,
+       4096,
+       max_num_stages=3,
+       use_warp_specilization=False)
 # GEMM Warp Specialized  f32 += f16 * f16
-matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
+matmul(np.float16,
+       np.float32,
+       256,
+       1024,
+       512,
+       max_num_stages=3,
+       use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 09ab35d4c5f15e..0fda8698606cd5 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -11,7 +11,6 @@
 from mlir.dialects import scf
 from mlir.dialects import vector
 
-
 TMA_LAST_DIM_F16 = 64  # 128B float16
 WARP_SIZE = 32
 WARP_GROUP_SIZE = WARP_SIZE * 4
@@ -81,7 +80,8 @@ def generate_matmul_ws(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
+                                          BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -99,25 +99,36 @@ def generate_matmul_ws(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
+               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
+                            str(smem_space) + ", num_barriers = " +
                             str(num_stages) + ">")
-    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
+                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
                            str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
 
     with ir.InsertionPoint(module.body):
 
@@ -139,13 +150,18 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
+                         (b_device, b_tma_desc_ty, b_tma_shape)]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
-                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
+                    x_device)
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
+                                                map(c, tile_shape)).result)
             a_tma_desc, b_tma_desc = tma_descs
-            
+
             # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
             cta_m = M // BLOCK_M
             cta_n = N // BLOCK_N
@@ -155,26 +171,37 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             launch_op = gpu.LaunchOp(token_ty, [t7],
                                      *map(c, grid),
                                      *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+                                     dynamicSharedMemorySize=c(smem_size,
+                                                               ty=i32))
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. This is needed for vectorized ld/st
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC, ),
+                                      i8,
+                                      memory_space=smem_space))
                 ticks = c(10000000)
 
                 # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
                 tidx = gpu.thread_id(gpu.Dimension.x)
-                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                wgPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq,
+                    arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
                 warp_id = arith.divui(tidx, c(32))
                 warpgroup_id = arith.divui(warp_id, c(4))
-                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
-                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
-                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
-                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
-                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
-                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                is_producer = arith.cmpi(
+                    arith.CmpIPredicate.eq, warpgroup_id,
+                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                is_consumer = arith.cmpi(
+                    arith.CmpIPredicate.eq, warpgroup_id,
+                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
+                                                   tidx,
+                                                   c(PRODUCER_PRIMARY_THREAD))
+                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
+                                                   tidx,
+                                                   c(CONSUMER_PRIMARY_THREAD))
                 bidx = gpu.block_id(gpu.Dimension.x)
                 bidy = gpu.block_id(gpu.Dimension.y)
                 dimX = arith.muli(bidx, c(BLOCK_M))
@@ -184,57 +211,88 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
-                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarTMA,
+                                        c(1),
+                                        c(i),
+                                        predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE,
+                                        c(1),
+                                        c(i),
+                                        predicate=wgPrimaryThread)
                 gpu.barrier()
 
                 # GPU Step 3. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
-                
+                nvgpu.tma_prefetch_descriptor(a_tma_desc,
+                                              predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc,
+                                              predicate=wgPrimaryThread)
+
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                 with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-                    
+
                     # Step 5.1. Reduce register size
-                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE,
+                                        nvvm.SetMaxRegisterAction.decrease)
 
                     # Step 5.2. TMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                       [arith.constant(i1, 1)])
                     with ir.InsertionPoint(for_op.body):
                         phaseParity = for_op.inner_iter_args[0]
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
 
                         # Step 5.2.1. Wait mbarDONE
-                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=producerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
-                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=producerPrimaryThread)
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                        debug_print(
+                            "[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarDONE,
+                                                      phaseParity,
+                                                      ticks,
+                                                      mbarId=stage)
+                        debug_print(
+                            "[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                       c(num_stages - 1))
+                        phaseParity = arith.select(
+                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity)
 
                         # Step 5.2.2. Load TMA
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                                  dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset, [])
-                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset2, [])
-                        debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
-                                    a_offset,
-                                    b_offset,
-                                    b_offset2,
-                                    predicate=producerPrimaryThread)
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(a_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset2, [])
+                        debug_print(
+                            "[prod] a_offset={} b_offset={} b_offset2={}",
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            predicate=producerPrimaryThread)
                         coord = arith.muli(c(64), iv)
                         nvgpu.TmaAsyncLoadOp(a_tma_slice,
                                              mbarTMA,
@@ -257,8 +315,15 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                                              predicate=producerPrimaryThread)
 
                         # Step 5.2.3. Arrive mbarTMA
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive",
+                                    iv,
+                                    stage,
+                                    predicate=producerPrimaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA,
+                            c(txcount),
+                            stage,
+                            predicate=producerPrimaryThread)
                         debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
                                     iv,
                                     stage,
@@ -269,75 +334,104 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                 if_op = scf.IfOp(is_consumer)
                 with ir.InsertionPoint(if_op.then_block):
-                    
+
                     # Step 6.1. Increase register size
-                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
-                    
+                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE,
+                                        nvvm.SetMaxRegisterAction.increase)
+
                     # GPU Step 6.2. Initialize MMA registers
                     acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
 
                     # Step 6.3. MMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                       [acc, arith.constant(i1, 0)])
                     with ir.InsertionPoint(for_op.body):
                         # Step 6.3.1. Wait mbarTMA
                         phaseParity = for_op.inner_iter_args[1]
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
-                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={}",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=consumerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
-                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=consumerPrimaryThread)
-                        
+                        debug_print(
+                            "[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarTMA,
+                                                      phaseParity,
+                                                      ticks,
+                                                      mbarId=stage)
+                        debug_print(
+                            "[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread)
+
                         # Step 6.3.2. Create WGMMA Descriptors
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                        a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
-                                                   dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
-                                                   dynamic_smem, b_offset, [])
+                        a_tile_slice = memref.view(
+                            ir.MemRefType.get(a_tile_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tile_slice = memref.view(
+                            ir.MemRefType.get(b_tile_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
                         debug_print("[cons] a_offset={} b_offset={}",
                                     a_offset,
                                     b_offset,
                                     predicate=consumerPrimaryThread)
-                        da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
-                        db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(
+                            a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(
+                            b_wgmma_ty, b_tile_slice, b_tma_desc)
 
                         # Step 6.3.3. MMA
                         carry_acc = for_op.inner_iter_args[0]
-                        new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+                        new_acc = nvgpu.WarpgroupMmaOp(acc.type,
+                                                       da,
+                                                       db,
+                                                       carry_acc,
+                                                       transposeB=True)
 
                         # Step 6.3.4. Arrive mbarDONE
                         p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                         p_arrive = arith.andi(consumerPrimaryThread, p1)
                         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                             p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
-                            barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+                            barId = arith.select(p, c(num_stages - 1),
+                                                 arith.subi(stage, c(1)))
                             remoteCtaId = c(0)
                             pred = consumerPrimaryThread
-                            debug_print("[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
-                                        iv,
-                                        barId,
-                                        remoteCtaId,
-                                        pred,
-                                        predicate=consumerPrimaryThread)
-                            nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
-                            debug_print("[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
-                                        iv,
-                                        barId,
-                                        remoteCtaId,
-                                        pred,
-                                        predicate=consumerPrimaryThread)
+                            debug_print(
+                                "[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
+                                iv,
+                                barId,
+                                remoteCtaId,
+                                pred,
+                                predicate=consumerPrimaryThread)
+                            nvgpu.mbarrier_arrive(
+                                ir.Type.parse("!nvgpu.mbarrier.token"),
+                                mbarDONE, barId)
+                            debug_print(
+                                "[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
+                                iv,
+                                barId,
+                                remoteCtaId,
+                                pred,
+                                predicate=consumerPrimaryThread)
                             scf.yield_([])
 
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                       c(num_stages - 1))
+                        phaseParity = arith.select(
+                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity)
 
                         # Step 6.3.5. Yield
                         scf.yield_([new_acc, phaseParity])
@@ -345,32 +439,45 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                     # Step 6.3. Wait All WGMMA
                     nvvm.WgmmaWaitGroupSyncOp(0)
 
-                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+                    with ir.InsertionPoint(
+                            scf.IfOp(consumerPrimaryThread).then_block):
                         barId = c((K // BLOCK_K) % num_stages)
-                        nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                        nvgpu.mbarrier_arrive(
+                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE,
+                            barId)
                         scf.yield_([])
 
                     # Step 6.4. Epilogue (registers --> shared memory)
-                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
+                                                    c_elem_ty,
+                                                    memory_space=smem_space)
                     acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
-                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
+                    debug_print("[cons]  | Storing",
+                                predicate=consumerPrimaryThread)
                     nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                     scf.yield_([])
                 gpu.barrier()
 
                 # GPU Step 9. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
+                                       c_elem_ty,
+                                       memory_space=smem_space)
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
-                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
-                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
-                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
-                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) +
+                                       ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(
+                    rty, c_device, [dimX, dimY], [], [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
+                                   c(vlen * WARP_GROUP_SIZE * 2))
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                    vdata = vector.load(ir.VectorType.get(
+                        (vlen, ), c_elem_ty), collapsed_smem,
                                         [for_op.induction_variable])
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
@@ -382,7 +489,8 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.wait(token_ty, [t9])
 
-    mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    mlir_matmul_warpspecialized.func_op.attributes[
+        "llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module
 
@@ -403,7 +511,8 @@ def generate_matmul_multistage(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
+                                          BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -421,25 +530,36 @@ def generate_matmul_multistage(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
+               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
+                            str(smem_space) + ", num_barriers = " +
                             str(num_stages) + ">")
-    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
+                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
                            str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
 
     with ir.InsertionPoint(module.body):
 
@@ -461,11 +581,16 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
+                         (b_device, b_tma_desc_ty, b_tma_shape)]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
-                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
+                    x_device)
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
+                                                map(c, tile_shape)).result)
             a_tma_desc, b_tma_desc = tma_descs
 
             # Step 3. Launch Kernel with 1 Warpgroup
@@ -477,13 +602,16 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             launch_op = gpu.LaunchOp(token_ty, [t7],
                                      *map(c, grid),
                                      *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+                                     dynamicSharedMemorySize=c(smem_size,
+                                                               ty=i32))
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. Bootstrapping
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC, ),
+                                      i8,
+                                      memory_space=smem_space))
                 ticks = c(10000000)
                 tidx = gpu.thread_id(gpu.Dimension.x)
                 primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
@@ -496,39 +624,58 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 # GPU Step 1. Initialize mbarrier groups
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+                    nvgpu.mbarrier_init(mbarTMA,
+                                        c(1),
+                                        c(i),
+                                        predicate=primaryThread)
                 gpu.barrier()
 
                 # GPU Step 2. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc,
+                                              predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc,
+                                              predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
-                for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+                for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
                 with ir.InsertionPoint(for_op.body):
                     iv = for_op.induction_variable
 
                     # Step 3.1. Calculate offsets
                     a_offset = arith.muli(iv, c(lhs_tile_bytes))
-                    a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                              dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                    b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, b_offset, [])
-                    b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                    b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, b_offset2, [])
+                    a_tma_slice = memref.view(
+                        ir.MemRefType.get(a_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)),
+                                          c(lhs_tile_bytes * num_stages))
+                    b_tma_slice_1 = memref.view(
+                        ir.MemRefType.get(b_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset, [])
+                    b_offset2 = arith.addi(
+                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                    b_tma_slice_2 = memref.view(
+                        ir.MemRefType.get(b_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset2, [])
 
                     # Step 3.2. TMA Load
                     coord = arith.muli(c(64), iv)
                     dimY2 = arith.addi(dimY, c(64))
-                    debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                                a_offset,
-                                b_offset,
-                                b_offset2,
-                                coord, dimX,
-                                dimY, coord,
-                                predicate=primaryThread)
+                    debug_print(
+                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                        a_offset,
+                        b_offset,
+                        b_offset2,
+                        coord,
+                        dimX,
+                        dimY,
+                        coord,
+                        predicate=primaryThread)
                     nvgpu.TmaAsyncLoadOp(a_tma_slice,
                                          mbarTMA,
                                          a_tma_desc,
@@ -549,89 +696,153 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                                          predicate=primaryThread)
 
                     # Step 3.3. mbarTMA arrive
-                    debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
-                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
-                    debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive",
+                                iv,
+                                predicate=primaryThread)
+                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA,
+                                                    c(txcount),
+                                                    iv,
+                                                    predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive [done]",
+                                iv,
+                                predicate=primaryThread)
                     scf.yield_([])
 
                 # GPU Step 4. Main Loop
                 acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
-                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                   [acc, arith.constant(i1, 0)])
                 with ir.InsertionPoint(for_op.body):
                     # Step 4.1. Wait mbarTMA
                     phaseParity = for_op.inner_iter_args[1]
                     iv = for_op.induction_variable
                     stage = arith.remui(iv, c(num_stages))
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}", stage, phaseParity, predicate=primaryThread)
-                    nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={} [done]", stage, phaseParity, predicate=primaryThread)
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}",
+                                stage,
+                                phaseParity,
+                                predicate=primaryThread)
+                    nvgpu.MBarrierTryWaitParityOp(mbarTMA,
+                                                  phaseParity,
+                                                  ticks,
+                                                  mbarId=stage)
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread)
 
                     # Step 4.2. Create WGMMA Descriptors
                     a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                    a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
-                                               dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                    b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
-                                               dynamic_smem, b_offset, [])
-                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
-                    da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
-                    db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+                    a_tile_slice = memref.view(
+                        ir.MemRefType.get(a_tile_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)),
+                                          c(lhs_tile_bytes * num_stages))
+                    b_tile_slice = memref.view(
+                        ir.MemRefType.get(b_tile_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset, [])
+                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}",
+                                iv,
+                                a_offset,
+                                b_offset,
+                                predicate=primaryThread)
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(
+                        a_wgmma_ty, a_tile_slice, a_tma_desc)
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(
+                        b_wgmma_ty, b_tile_slice, b_tma_desc)
 
                     # Step 4.3. MMA
                     carry_acc = for_op.inner_iter_args[0]
-                    new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+                    new_acc = nvgpu.WarpgroupMmaOp(acc.type,
+                                                   da,
+                                                   db,
+                                                   carry_acc,
+                                                   transposeB=True)
 
                     # Step 4.4. Load TMA for next stage
-                    p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+                    p1 = arith.cmpi(arith.CmpIPredicate.ult,
+                                    arith.addi(iv, c(num_stages - 1)),
+                                    c(K // BLOCK_K))
                     p = arith.andi(primaryThread, p1)
                     with ir.InsertionPoint(scf.IfOp(p).then_block):
-                        nextStage = arith.addi(iv, c(num_stages-1))
+                        nextStage = arith.addi(iv, c(num_stages - 1))
                         nextSlot = arith.remui(nextStage, c(num_stages))
                         a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset, [])
-                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset2, [])
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(a_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(nextSlot, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset2, [])
 
                         coord = arith.muli(c(64), nextStage)
-                        debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                                iv, a_offset,
-                                b_offset,
-                                b_offset2,
-                                coord, dimX,
-                                dimY, coord,
-                                predicate=primaryThread)
+                        debug_print(
+                            "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                            iv,
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            coord,
+                            dimX,
+                            dimY,
+                            coord,
+                            predicate=primaryThread)
                         nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                            mbarTMA,
-                                            a_tma_desc,
-                                            coordinates=[coord, dimX],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
+                                             mbarTMA,
+                                             a_tma_desc,
+                                             coordinates=[coord, dimX],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
                         nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                            mbarTMA,
-                                            b_tma_desc,
-                                            coordinates=[dimY, coord],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY, coord],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
                         dimY2 = arith.addi(dimY, c(64))
                         nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                            mbarTMA,
-                                            b_tma_desc,
-                                            coordinates=[dimY2, coord],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
-
-                        debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
-                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY2, coord],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
+
+                        debug_print("[MainLoop] mbarTMA[{}] arrive",
+                                    nextSlot,
+                                    predicate=primaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA,
+                            c(txcount),
+                            nextSlot,
+                            predicate=primaryThread)
+                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]",
+                                    nextSlot,
+                                    predicate=primaryThread)
                         scf.yield_([])
                     # Step 4.5. Change the phaseParity
-                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                    phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                   c(num_stages - 1))
+                    phaseParity = arith.select(
+                        p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                        phaseParity)
 
                     # Step 4.5. Yield
                     scf.yield_([new_acc, phaseParity])
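
A note for readers tracing the generated IR: the phase-parity update in Step 4.5 treats the shared-memory slots as a ring of `num_stages` buffers, and the parity bit handed to the mbarrier try_wait flips each time the stage index wraps around the ring. A minimal host-side sketch of that bookkeeping, with plain Python integers standing in for the `arith` SSA values (an illustration only, not part of the patch):

```
# Sketch of the phase-parity update built with arith.cmpi / arith.xori /
# arith.select above. The parity bit flips once per full pass over the
# num_stages ring so try_wait can distinguish this round's data from the
# previous round's.
def simulate_phase_parity(num_iterations, num_stages):
    phase_parity = 0
    for iv in range(num_iterations):
        stage = iv % num_stages          # which mbarrier slot this iteration uses
        # ... wait on mbarTMA[stage] with phase_parity, then issue the MMA ...
        if stage == num_stages - 1:      # the ring wrapped around
            phase_parity ^= 1            # arith.xori(phaseParity, 1)
    return phase_parity

assert simulate_phase_parity(6, 3) == 0  # two full passes: parity returns to 0
assert simulate_phase_parity(4, 3) == 1  # one wrap: parity flipped once
```
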
@@ -641,25 +852,34 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 gpu.barrier()
 
                 # Step 6. Epilogue (registers --> shared memory)
-                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
+                                                c_elem_ty,
+                                                memory_space=smem_space)
                 acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                 debug_print("Storing", predicate=primaryThread)
                 nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                 gpu.barrier()
 
                 # GPU Step 7. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
+                                       c_elem_ty,
+                                       memory_space=smem_space)
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
-                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
-                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
-                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
-                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) +
+                                       ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(
+                    rty, c_device, [dimX, dimY], [], [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
+                                   c(vlen * WARP_GROUP_SIZE))
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                    vdata = vector.load(ir.VectorType.get(
+                        (vlen, ), c_elem_ty), collapsed_smem,
                                         [for_op.induction_variable])
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
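
The strided epilogue loop above distributes the BLOCK_M x BLOCK_N tile across the warp group: each thread copies `vlen` elements per trip, striding by the warp-group size, and uses `iv / BLOCK_M` and `iv mod BLOCK_N` as the destination coordinates. A NumPy sketch of the same index arithmetic (plain arrays stand in for the shared-memory view and the subview of `c_device`; the sizes are the defaults and are assumptions for illustration):

```
import numpy as np

# Sketch of the epilogue copy: a flat BLOCK_M*BLOCK_N buffer in shared memory
# is written back as a 2-D tile, one element (vlen=1) per thread per trip,
# threads striding by the warp-group size.
BLOCK_M = BLOCK_N = 128
WARP_GROUP_SIZE = 128
vlen = 1

collapsed_smem = np.arange(BLOCK_M * BLOCK_N, dtype=np.float32)  # stand-in for the smem view
tile = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)            # stand-in for c_device_per_block

for tidx in range(WARP_GROUP_SIZE):                               # each thread of the warp group
    for iv in range(tidx, BLOCK_M * BLOCK_N, vlen * WARP_GROUP_SIZE):
        x = iv // BLOCK_M                                          # arith.divui(iv, BLOCK_M)
        y = iv % BLOCK_N                                           # arith.remui(iv, BLOCK_N)
        tile[x, y] = collapsed_smem[iv]                            # vector.load / vector.store

assert np.array_equal(tile.ravel(), collapsed_smem)                # every element lands exactly once
```
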
@@ -671,6 +891,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.wait(token_ty, [t9])
 
-    mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    mlir_matmul_multistage.func_op.attributes[
+        "llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
index 2a4f67326ee718..ab218812bb4d96 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -15,10 +15,12 @@
 _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(_SCRIPT_PATH)
 
+
 class NvgpuCompiler:
     """Nvgpu class for compiling and building MLIR modules."""
 
-    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+    def __init__(self, options: str, opt_level: int,
+                 shared_libs: Sequence[str]):
         pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
         self.pipeline = pipeline
         self.shared_libs = shared_libs
@@ -31,14 +33,15 @@ def __call__(self, module: ir.Module):
     def compile(self, module: ir.Module):
         """Compiles the module by invoking the nvgpu pipeline."""
         passmanager.PassManager.parse(self.pipeline).run(module.operation)
-    
+
     def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Wraps the module in a JIT execution engine."""
-        return execution_engine.ExecutionEngine(
-            module, opt_level=self.opt_level, shared_libs=self.shared_libs
-        )
+        return execution_engine.ExecutionEngine(module,
+                                                opt_level=self.opt_level,
+                                                shared_libs=self.shared_libs)
 
-    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+    def compile_and_jit(self,
+                        module: ir.Module) -> execution_engine.ExecutionEngine:
         """Compiles and jits the module."""
         self.compile(module)
         return self.jit(module)
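
For context, the wrapper above is driven the same way from both test entry points: the generated module is lowered through the `gpu-lower-to-nvvm-pipeline` pass pipeline and then wrapped in an ExecutionEngine. A minimal usage sketch, assuming the `tools` directory is on `sys.path` and `SUPPORT_LIB` points at the CUDA runner library as in matmul.py (the empty module here is a stand-in, not the real GEMM IR):

```
import os
from mlir import ir
import nvgpucompiler  # assumes .../sm90/python/tools is on sys.path, as in matmul.py

support_lib = os.getenv("SUPPORT_LIB")  # CUDA runner shared library
options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"

with ir.Context(), ir.Location.unknown():
    module = ir.Module.parse("module {}")  # stand-in for the generated GEMM module
    compiler = nvgpucompiler.NvgpuCompiler(
        options, opt_level=3, shared_libs=[support_lib]
    )
    engine = compiler.compile_and_jit(module)  # run the pass pipeline, then JIT
    # matmul.py then calls engine.invoke(<kernel name>, <memref descriptors>)
```
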

>From 623e90f380ecfd1ef10ae8b7fe539e39dd269e52 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:59:39 +0000
Subject: [PATCH 3/7] format it with black

---
 .../GPU/CUDA/sm90/python/matmul.py            | 177 ++--
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 975 +++++++++++-------
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  12 +-
 3 files changed, 701 insertions(+), 463 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index 00313270d723d6..ff9c950193e695 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -86,27 +86,44 @@
 from mlir import runtime as rt
 
 
-def generate_matmul(input_type=np.float16,
-                    output_type=np.float32,
-                    M=4096,
-                    N=4096,
-                    K=4096,
-                    BLOCK_M=128,
-                    BLOCK_N=128,
-                    BLOCK_K=64,
-                    use_warp_specilization=True,
-                    saveIR=False,
-                    max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
-    ):
+def generate_matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
         if use_warp_specilization:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
-                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
-                max_num_stages)
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
         else:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
-                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
-                max_num_stages)
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
 
         mlir_nvgpu_module.operation.verify()
 
@@ -114,7 +131,7 @@ def generate_matmul(input_type=np.float16,
         if saveIR:
             # print(mlir_nvgpu_module)
             original_stdout = sys.stdout
-            with open('gemm.mlir', 'w') as f:
+            with open("gemm.mlir", "w") as f:
                 sys.stdout = f
                 print(mlir_nvgpu_module)
                 sys.stdout = original_stdout
@@ -123,43 +140,77 @@ def generate_matmul(input_type=np.float16,
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
-                                    support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options,
-                                               opt_level=3,
-                                               shared_libs=[support_lib])
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+            )
+        compiler = nvgpucompiler.NvgpuCompiler(
+            options, opt_level=3, shared_libs=[support_lib]
+        )
 
         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
         return engine
 
 
-def matmul(input_type=np.float16,
-           output_type=np.float32,
-           M=128,
-           N=128,
-           K=128,
-           BLOCK_M=128,
-           BLOCK_N=128,
-           BLOCK_K=64,
-           use_warp_specilization=True,
-           saveIR=False,
-           max_num_stages=3,
-           print_results=False,
-           no_verify=False):
+def matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=128,
+    N=128,
+    K=128,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+    print_results=False,
+    no_verify=False,
+):
     # Print the configuration
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
-          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
-          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
-          ", stages " + str(max_num_stages) + " --===")
+    print(
+        "===-- Running GEMM "
+        + gemmty
+        + " "
+        + oty
+        + " += "
+        + ity
+        + " * "
+        + ity
+        + ", Size "
+        + str(M)
+        + "x"
+        + str(N)
+        + "x"
+        + str(K)
+        + ", Tile "
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(BLOCK_K)
+        + ", stages "
+        + str(max_num_stages)
+        + " --==="
+    )
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
-                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
-                             max_num_stages)
+    engine = generate_matmul(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        use_warp_specilization,
+        saveIR,
+        max_num_stages,
+    )
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
@@ -168,13 +219,17 @@ def matmul(input_type=np.float16,
     mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
     mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+    kernelName = (
+        "mlir_matmul_warpspecialized"
+        if use_warp_specilization
+        else "mlir_matmul_multistage"
+    )
 
     # Launch the MLIR generated kernel
     engine.invoke(kernelName, mem_a, mem_b, mem_c)
 
     float_formatter = "{:.2f}".format
-    np.set_printoptions(formatter={'float_kind': float_formatter})
+    np.set_printoptions(formatter={"float_kind": float_formatter})
 
     if print_results:
         print(c)
@@ -190,18 +245,22 @@ def matmul(input_type=np.float16,
 
 
 # GEMM Multistage       f32 += f16 * f16
-matmul(np.float16,
-       np.float32,
-       128,
-       128,
-       4096,
-       max_num_stages=3,
-       use_warp_specilization=False)
+matmul(
+    np.float16,
+    np.float32,
+    128,
+    128,
+    4096,
+    max_num_stages=3,
+    use_warp_specilization=False,
+)
 # GEMM Warp Specialized  f32 += f16 * f16
-matmul(np.float16,
-       np.float32,
-       256,
-       1024,
-       512,
-       max_num_stages=3,
-       use_warp_specilization=True)
+matmul(
+    np.float16,
+    np.float32,
+    256,
+    1024,
+    512,
+    max_num_stages=3,
+    use_warp_specilization=True,
+)
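
The verification path referenced by the `no_verify` flag is not part of the hunks shown here; under the assumption that it compares the GPU result against a float32 NumPy reference, an equivalent host-side check would look roughly like the following (the function name and tolerances are hypothetical):

```
import numpy as np

# Hypothetical reference check mirroring a no_verify=False run: recompute the
# GEMM on the host in float32 and compare it with the GPU output `c`.
def verify_gemm(a, b, c, rtol=5e-3, atol=1e-2):
    ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
    assert np.allclose(c, ref, rtol=rtol, atol=atol), "GPU result mismatch"

M, N, K = 128, 128, 4096
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
c = np.matmul(a.astype(np.float32), b.astype(np.float32))  # stand-in for the GPU output
verify_gemm(a, b, c)
```
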
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 0fda8698606cd5..4c132ec992a461 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -64,15 +64,17 @@ def c(value, ty=None):
     return arith.constant(ty, value)
 
 
-def generate_matmul_ws(input_type=np.float16,
-                       output_type=np.float32,
-                       M=4096,
-                       N=4096,
-                       K=4096,
-                       BLOCK_M=128,
-                       BLOCK_N=128,
-                       BLOCK_K=128,
-                       max_num_stages=3):
+def generate_matmul_ws(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=128,
+    max_num_stages=3,
+):
     # Limitations for now
     assert input_type == np.float16
     assert output_type == np.float32
@@ -80,8 +82,7 @@ def generate_matmul_ws(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
-                                          BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -99,36 +100,73 @@ def generate_matmul_ws(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
-               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = (
+        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+    ) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
-                            str(smem_space) + ", num_barriers = " +
-                            str(num_stages) + ">")
+    mbar_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace = "
+        + str(smem_space)
+        + ", num_barriers = "
+        + str(num_stages)
+        + ">"
+    )
     a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
-        str(smem_space) +
-        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
     b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
-        str(smem_space) +
-        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
-                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
-                           str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str +
-                               ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str +
-                               ">>")
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    acc_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(output_type_str)
+        + ">>"
+    )
+    a_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_K)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    b_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
 
     with ir.InsertionPoint(module.body):
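
For a concrete picture of what the string concatenation above produces, the parsed types for the default f16 inputs with BLOCK_M = BLOCK_N = 128 and three stages are reconstructed below (assuming TMA_LAST_DIM_F16 == 64, which is consistent with the 64-element offsets used by the TMA loads; this is a reconstruction, not captured output):

```
# Reconstructed type strings for the default warp-specialized configuration.
BLOCK_M, BLOCK_N, num_stages = 128, 128, 3
smem = "#gpu.address_space<workgroup>"

mbar_ty = f"!nvgpu.mbarrier.group<memorySpace = {smem}, num_barriers = {num_stages}>"
a_tma_desc_ty = (
    f"!nvgpu.tensormap.descriptor<tensor = memref<{BLOCK_M}x64xf16, {smem}>, "
    "swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
)
acc_ty = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{BLOCK_M}x{BLOCK_N}xf32>>"

print(mbar_ty)
print(a_tma_desc_ty)
print(acc_ty)
```
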
 
@@ -150,16 +188,20 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
-                         (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [
+                (a_device, a_tma_desc_ty, a_tma_shape),
+                (b_device, b_tma_desc_ty, b_tma_shape),
+            ]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
                 x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
-                    x_device)
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                )
                 tma_descs.append(
-                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
-                                                map(c, tile_shape)).result)
+                    nvgpu.TmaCreateDescriptorOp(
+                        tensor_map_ty, x_unranked, map(c, tile_shape)
+                    ).result
+                )
             a_tma_desc, b_tma_desc = tma_descs
 
             # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
@@ -168,40 +210,45 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             assert M % BLOCK_M == 0 and N % BLOCK_N == 0
             grid = (cta_m, cta_n, 1)
             block = (WARP_GROUP_SIZE * 2, 1, 1)
-            launch_op = gpu.LaunchOp(token_ty, [t7],
-                                     *map(c, grid),
-                                     *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size,
-                                                               ty=i32))
+            launch_op = gpu.LaunchOp(
+                token_ty,
+                [t7],
+                *map(c, grid),
+                *map(c, block),
+                dynamicSharedMemorySize=c(smem_size, ty=i32)
+            )
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. This is needed for vectorized ld/st
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ),
-                                      i8,
-                                      memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                )
                 ticks = c(10000000)
 
                 # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
                 tidx = gpu.thread_id(gpu.Dimension.x)
                 wgPrimaryThread = arith.cmpi(
-                    arith.CmpIPredicate.eq,
-                    arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
+                )
                 warp_id = arith.divui(tidx, c(32))
                 warpgroup_id = arith.divui(warp_id, c(4))
                 is_producer = arith.cmpi(
-                    arith.CmpIPredicate.eq, warpgroup_id,
-                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                    arith.CmpIPredicate.eq,
+                    warpgroup_id,
+                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
+                )
                 is_consumer = arith.cmpi(
-                    arith.CmpIPredicate.eq, warpgroup_id,
-                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
-                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
-                                                   tidx,
-                                                   c(PRODUCER_PRIMARY_THREAD))
-                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
-                                                   tidx,
-                                                   c(CONSUMER_PRIMARY_THREAD))
+                    arith.CmpIPredicate.eq,
+                    warpgroup_id,
+                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
+                )
+                producerPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
+                )
+                consumerPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
+                )
                 bidx = gpu.block_id(gpu.Dimension.x)
                 bidy = gpu.block_id(gpu.Dimension.y)
                 dimX = arith.muli(bidx, c(BLOCK_M))
@@ -211,32 +258,25 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA,
-                                        c(1),
-                                        c(i),
-                                        predicate=wgPrimaryThread)
-                    nvgpu.mbarrier_init(mbarDONE,
-                                        c(1),
-                                        c(i),
-                                        predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
                 gpu.barrier()
 
                 # GPU Step 3. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc,
-                                              predicate=wgPrimaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc,
-                                              predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
 
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                 with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-
                     # Step 5.1. Reduce register size
-                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE,
-                                        nvvm.SetMaxRegisterAction.decrease)
+                    nvvm.setmaxregister(
+                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
+                    )
 
                     # Step 5.2. TMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
-                                       [arith.constant(i1, 1)])
+                    for_op = scf.ForOp(
+                        c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)]
+                    )
                     with ir.InsertionPoint(for_op.body):
                         phaseParity = for_op.inner_iter_args[0]
                         iv = for_op.induction_variable
@@ -248,103 +288,126 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             iv,
                             stage,
                             phaseParity,
-                            predicate=producerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarDONE,
-                                                      phaseParity,
-                                                      ticks,
-                                                      mbarId=stage)
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.MBarrierTryWaitParityOp(
+                            mbarDONE, phaseParity, ticks, mbarId=stage
+                        )
                         debug_print(
                             "[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
                             iv,
                             stage,
                             phaseParity,
-                            predicate=producerPrimaryThread)
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
-                                       c(num_stages - 1))
+                            predicate=producerPrimaryThread,
+                        )
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                         phaseParity = arith.select(
-                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
-                            phaseParity)
+                            p,
+                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity,
+                        )
 
                         # Step 5.2.2. Load TMA
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
                         a_tma_slice = memref.view(
-                            ir.MemRefType.get(a_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, a_offset, [])
+                            ir.MemRefType.get(
+                                a_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
                         b_offset = arith.addi(
                             arith.muli(stage, c(rhs_tile_bytes)),
-                            c(lhs_tile_bytes * num_stages))
+                            c(lhs_tile_bytes * num_stages),
+                        )
                         b_tma_slice_1 = memref.view(
-                            ir.MemRefType.get(b_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset, [])
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
                         b_offset2 = arith.addi(
-                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                        )
                         b_tma_slice_2 = memref.view(
-                            ir.MemRefType.get(b_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset2, [])
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset2,
+                            [],
+                        )
                         debug_print(
                             "[prod] a_offset={} b_offset={} b_offset2={}",
                             a_offset,
                             b_offset,
                             b_offset2,
-                            predicate=producerPrimaryThread)
+                            predicate=producerPrimaryThread,
+                        )
                         coord = arith.muli(c(64), iv)
-                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                             mbarTMA,
-                                             a_tma_desc,
-                                             coordinates=[coord, dimX],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY, coord],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(
+                            a_tma_slice,
+                            mbarTMA,
+                            a_tma_desc,
+                            coordinates=[coord, dimX],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_1,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY, coord],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
                         dimY2 = arith.addi(dimY, c(64))
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY2, coord],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_2,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY2, coord],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
 
                         # Step 5.2.3. Arrive mbarTMA
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive",
-                                    iv,
-                                    stage,
-                                    predicate=producerPrimaryThread)
+                        debug_print(
+                            "[prod] {}  | mbarTMA[{}] arrive",
+                            iv,
+                            stage,
+                            predicate=producerPrimaryThread,
+                        )
                         nvgpu.mbarrier_arrive_expect_tx(
-                            mbarTMA,
-                            c(txcount),
+                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
+                        )
+                        debug_print(
+                            "[prod] {}  | mbarTMA[{}] arrive [done]",
+                            iv,
                             stage,
-                            predicate=producerPrimaryThread)
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
-                                    iv,
-                                    stage,
-                                    predicate=producerPrimaryThread)
+                            predicate=producerPrimaryThread,
+                        )
                         scf.yield_([phaseParity])
                     scf.yield_([])
 
                 # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                 if_op = scf.IfOp(is_consumer)
                 with ir.InsertionPoint(if_op.then_block):
-
                     # Step 6.1. Increase register size
-                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE,
-                                        nvvm.SetMaxRegisterAction.increase)
+                    nvvm.setmaxregister(
+                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
+                    )
 
                     # GPU Step 6.2. Initialize MMA registers
                     acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
 
                     # Step 6.3. MMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
-                                       [acc, arith.constant(i1, 0)])
+                    for_op = scf.ForOp(
+                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                    )
                     with ir.InsertionPoint(for_op.body):
                         # Step 6.3.1. Wait mbar1
                         phaseParity = for_op.inner_iter_args[1]
@@ -355,57 +418,68 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             iv,
                             stage,
                             phaseParity,
-                            predicate=consumerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarTMA,
-                                                      phaseParity,
-                                                      ticks,
-                                                      mbarId=stage)
+                            predicate=consumerPrimaryThread,
+                        )
+                        nvgpu.MBarrierTryWaitParityOp(
+                            mbarTMA, phaseParity, ticks, mbarId=stage
+                        )
                         debug_print(
                             "[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
                             iv,
                             stage,
                             phaseParity,
-                            predicate=consumerPrimaryThread)
+                            predicate=consumerPrimaryThread,
+                        )
 
                         # Step 6.3.2. Create WGMMA Descriptors
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
                         a_tile_slice = memref.view(
-                            ir.MemRefType.get(a_tile_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, a_offset, [])
+                            ir.MemRefType.get(
+                                a_tile_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
                         b_offset = arith.addi(
                             arith.muli(stage, c(rhs_tile_bytes)),
-                            c(lhs_tile_bytes * num_stages))
+                            c(lhs_tile_bytes * num_stages),
+                        )
                         b_tile_slice = memref.view(
-                            ir.MemRefType.get(b_tile_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset, [])
-                        debug_print("[cons] a_offset={} b_offset={}",
-                                    a_offset,
-                                    b_offset,
-                                    predicate=consumerPrimaryThread)
+                            ir.MemRefType.get(
+                                b_tile_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
+                        debug_print(
+                            "[cons] a_offset={} b_offset={}",
+                            a_offset,
+                            b_offset,
+                            predicate=consumerPrimaryThread,
+                        )
                         da = nvgpu.WarpgroupGenerateDescriptorOp(
-                            a_wgmma_ty, a_tile_slice, a_tma_desc)
+                            a_wgmma_ty, a_tile_slice, a_tma_desc
+                        )
                         db = nvgpu.WarpgroupGenerateDescriptorOp(
-                            b_wgmma_ty, b_tile_slice, b_tma_desc)
+                            b_wgmma_ty, b_tile_slice, b_tma_desc
+                        )
 
                         # Step 6.3.3. MMA
                         carry_acc = for_op.inner_iter_args[0]
-                        new_acc = nvgpu.WarpgroupMmaOp(acc.type,
-                                                       da,
-                                                       db,
-                                                       carry_acc,
-                                                       transposeB=True)
+                        new_acc = nvgpu.WarpgroupMmaOp(
+                            acc.type, da, db, carry_acc, transposeB=True
+                        )
 
                         # Step 6.3.4. Arrive mbarDONE
                         p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                         p_arrive = arith.andi(consumerPrimaryThread, p1)
                         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                             p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
-                            barId = arith.select(p, c(num_stages - 1),
-                                                 arith.subi(stage, c(1)))
+                            barId = arith.select(
+                                p, c(num_stages - 1), arith.subi(stage, c(1))
+                            )
                             remoteCtaId = c(0)
                             pred = consumerPrimaryThread
                             debug_print(
@@ -414,24 +488,27 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                                 barId,
                                 remoteCtaId,
                                 pred,
-                                predicate=consumerPrimaryThread)
+                                predicate=consumerPrimaryThread,
+                            )
                             nvgpu.mbarrier_arrive(
-                                ir.Type.parse("!nvgpu.mbarrier.token"),
-                                mbarDONE, barId)
+                                ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+                            )
                             debug_print(
                                 "[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
                                 iv,
                                 barId,
                                 remoteCtaId,
                                 pred,
-                                predicate=consumerPrimaryThread)
+                                predicate=consumerPrimaryThread,
+                            )
                             scf.yield_([])
 
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
-                                       c(num_stages - 1))
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                         phaseParity = arith.select(
-                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
-                            phaseParity)
+                            p,
+                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity,
+                        )
 
                         # Step 6.3.5. Yield
                         scf.yield_([new_acc, phaseParity])
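
Stepping back from the IR details, the producer and consumer loops above implement a ring-buffer handshake: the producer may only refill slot `stage` after the consumer has signalled `mbarDONE` for it, and the consumer may only issue WGMMA on slot `stage` after the producer has signalled `mbarTMA`. A single-threaded Python sketch of that protocol (flags stand in for the mbarrier groups; this is an illustration of the dependency structure, not of the asynchronous execution on the GPU):

```
# Sequential model of the mbarTMA / mbarDONE handshake between the producer
# and consumer warp groups. Flags stand in for mbarrier arrive / try_wait.
num_stages = 3
k_iters = 8                              # K // BLOCK_K in the real kernel

tma_ready = [False] * num_stages         # mbarTMA[stage]: tiles for this slot are in smem
buf_free = [True] * num_stages           # mbarDONE[stage]: consumer is done with this slot

def producer_step(iv):
    stage = iv % num_stages
    assert buf_free[stage], "producer waits on mbarDONE before reusing a slot"
    buf_free[stage] = False
    # ... TMA loads for the A and B tiles of iteration iv land in slot `stage` ...
    tma_ready[stage] = True              # nvgpu.mbarrier_arrive_expect_tx(mbarTMA, ...)

def consumer_step(iv):
    stage = iv % num_stages
    assert tma_ready[stage], "consumer waits on mbarTMA before issuing WGMMA"
    tma_ready[stage] = False
    # ... WarpgroupMmaOp accumulates with the tiles in slot `stage` ...
    if iv > 0:                           # matches the iv > 0 guard before mbarrier_arrive
        prev = num_stages - 1 if stage == 0 else stage - 1
        buf_free[prev] = True            # nvgpu.mbarrier_arrive(mbarDONE, prev)

for iv in range(k_iters):                # the real warp groups run these concurrently
    producer_step(iv)
    consumer_step(iv)
```
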
@@ -439,46 +516,55 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                     # Step 6.3. Wait All WGMMA
                     nvvm.WgmmaWaitGroupSyncOp(0)
 
-                    with ir.InsertionPoint(
-                            scf.IfOp(consumerPrimaryThread).then_block):
+                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
                         barId = c((K // BLOCK_K) % num_stages)
                         nvgpu.mbarrier_arrive(
-                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE,
-                            barId)
+                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+                        )
                         scf.yield_([])
 
                     # Step 6.4. Epilogue (registers --> shared memory)
-                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
-                                                    c_elem_ty,
-                                                    memory_space=smem_space)
+                    acc_smem_ty = ir.MemRefType.get(
+                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+                    )
                     acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
-                    debug_print("[cons]  | Storing",
-                                predicate=consumerPrimaryThread)
+                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
                     nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                     scf.yield_([])
                 gpu.barrier()
 
                 # GPU Step 9. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
-                                       c_elem_ty,
-                                       memory_space=smem_space)
+                fd = ir.MemRefType.get(
+                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+                )
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                 rty = ir.MemRefType.get(
-                    (BLOCK_M, BLOCK_N), c_elem_ty,
-                    ir.Attribute.parse("strided<[" + str(N) +
-                                       ", 1], offset: ?>"))
+                    (BLOCK_M, BLOCK_N),
+                    c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+                )
                 c_device_per_block = memref.SubViewOp(
-                    rty, c_device, [dimX, dimY], [], [],
-                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
+                    rty,
+                    c_device,
+                    [dimX, dimY],
+                    [],
+                    [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                    [BLOCK_M, BLOCK_N],
+                    [1, 1],
+                )
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
-                                   c(vlen * WARP_GROUP_SIZE * 2))
+                for_op = scf.ForOp(
+                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
+                )
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get(
-                        (vlen, ), c_elem_ty), collapsed_smem,
-                                        [for_op.induction_variable])
+                    vdata = vector.load(
+                        ir.VectorType.get((vlen,), c_elem_ty),
+                        collapsed_smem,
+                        [for_op.induction_variable],
+                    )
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
 
@@ -490,20 +576,23 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             gpu.wait(token_ty, [t9])
 
     mlir_matmul_warpspecialized.func_op.attributes[
-        "llvm.emit_c_interface"] = ir.UnitAttr.get()
+        "llvm.emit_c_interface"
+    ] = ir.UnitAttr.get()
     module.operation.verify()
     return module
 
 
-def generate_matmul_multistage(input_type=np.float16,
-                               output_type=np.float32,
-                               M=4096,
-                               N=4096,
-                               K=4096,
-                               BLOCK_M=128,
-                               BLOCK_N=128,
-                               BLOCK_K=64,
-                               max_num_stages=3):
+def generate_matmul_multistage(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    max_num_stages=3,
+):
     # Limitations for now
     assert input_type == np.float16
     assert output_type == np.float32
@@ -511,8 +600,7 @@ def generate_matmul_multistage(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
-                                          BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -530,36 +618,73 @@ def generate_matmul_multistage(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
-               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = (
+        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+    ) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
-                            str(smem_space) + ", num_barriers = " +
-                            str(num_stages) + ">")
+    mbar_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace = "
+        + str(smem_space)
+        + ", num_barriers = "
+        + str(num_stages)
+        + ">"
+    )
     a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
-        str(smem_space) +
-        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
     b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
-        str(smem_space) +
-        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
-                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
-                           str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str +
-                               ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str +
-                               ">>")
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    acc_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(output_type_str)
+        + ">>"
+    )
+    a_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_K)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    b_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
 
     with ir.InsertionPoint(module.body):
 
@@ -581,16 +706,20 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
-                         (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [
+                (a_device, a_tma_desc_ty, a_tma_shape),
+                (b_device, b_tma_desc_ty, b_tma_shape),
+            ]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
                 x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
-                    x_device)
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                )
                 tma_descs.append(
-                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
-                                                map(c, tile_shape)).result)
+                    nvgpu.TmaCreateDescriptorOp(
+                        tensor_map_ty, x_unranked, map(c, tile_shape)
+                    ).result
+                )
             a_tma_desc, b_tma_desc = tma_descs
 
             # Step 3. Launch Kernel with 1 Warpgroup
@@ -599,19 +728,20 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             assert M % BLOCK_M == 0 and N % BLOCK_N == 0
             grid = (cta_m, cta_n, 1)
             block = (WARP_GROUP_SIZE, 1, 1)
-            launch_op = gpu.LaunchOp(token_ty, [t7],
-                                     *map(c, grid),
-                                     *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size,
-                                                               ty=i32))
+            launch_op = gpu.LaunchOp(
+                token_ty,
+                [t7],
+                *map(c, grid),
+                *map(c, block),
+                dynamicSharedMemorySize=c(smem_size, ty=i32)
+            )
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. Bootstrapping
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ),
-                                      i8,
-                                      memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                )
                 ticks = c(10000000)
                 tidx = gpu.thread_id(gpu.Dimension.x)
                 primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
@@ -624,17 +754,12 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 # GPU Step 1. Initialize mbarrier groups
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA,
-                                        c(1),
-                                        c(i),
-                                        predicate=primaryThread)
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
                 gpu.barrier()
 
                 # GPU Step 2. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc,
-                                              predicate=primaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc,
-                                              predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
                 for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
@@ -644,24 +769,30 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     # Step 3.1. Calculate offsets
                     a_offset = arith.muli(iv, c(lhs_tile_bytes))
                     a_tma_slice = memref.view(
-                        ir.MemRefType.get(a_tma_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)),
-                                          c(lhs_tile_bytes * num_stages))
+                        ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(iv, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
                     b_tma_slice_1 = memref.view(
-                        ir.MemRefType.get(b_tma_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, b_offset, [])
+                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
                     b_offset2 = arith.addi(
-                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                    )
                     b_tma_slice_2 = memref.view(
-                        ir.MemRefType.get(b_tma_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, b_offset2, [])
+                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset2,
+                        [],
+                    )
 
                     # Step 3.2. TMA Load
                     coord = arith.muli(c(64), iv)
@@ -675,123 +806,153 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         dimX,
                         dimY,
                         coord,
-                        predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                         mbarTMA,
-                                         a_tma_desc,
-                                         coordinates=[coord, dimX],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                         mbarTMA,
-                                         b_tma_desc,
-                                         coordinates=[dimY, coord],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                         mbarTMA,
-                                         b_tma_desc,
-                                         coordinates=[dimY2, coord],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        a_tma_slice,
+                        mbarTMA,
+                        a_tma_desc,
+                        coordinates=[coord, dimX],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_1,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY, coord],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_2,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY2, coord],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
 
                     # Step 3.2. mbarTMA arrive
-                    debug_print("[Prologue] mbarTMA[{}] arrive",
-                                iv,
-                                predicate=primaryThread)
-                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA,
-                                                    c(txcount),
-                                                    iv,
-                                                    predicate=primaryThread)
-                    debug_print("[Prologue] mbarTMA[{}] arrive [done]",
-                                iv,
-                                predicate=primaryThread)
+                    debug_print(
+                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
+                    )
+                    nvgpu.mbarrier_arrive_expect_tx(
+                        mbarTMA, c(txcount), iv, predicate=primaryThread
+                    )
+                    debug_print(
+                        "[Prologue] mbarTMA[{}] arrive [done]",
+                        iv,
+                        predicate=primaryThread,
+                    )
                     scf.yield_([])
 
                 # GPU Step 4. Main Loop
                 acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
-                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
-                                   [acc, arith.constant(i1, 0)])
+                for_op = scf.ForOp(
+                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                )
                 with ir.InsertionPoint(for_op.body):
                     # Step 4.1. Wait mbarTMA
                     phaseParity = for_op.inner_iter_args[1]
                     iv = for_op.induction_variable
                     stage = arith.remui(iv, c(num_stages))
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}",
-                                stage,
-                                phaseParity,
-                                predicate=primaryThread)
-                    nvgpu.MBarrierTryWaitParityOp(mbarTMA,
-                                                  phaseParity,
-                                                  ticks,
-                                                  mbarId=stage)
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={}",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.MBarrierTryWaitParityOp(
+                        mbarTMA, phaseParity, ticks, mbarId=stage
+                    )
                     debug_print(
                         "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
                         stage,
                         phaseParity,
-                        predicate=primaryThread)
+                        predicate=primaryThread,
+                    )
 
                     # Step 4.2. Create WGMMA Descriptors
                     a_offset = arith.muli(stage, c(lhs_tile_bytes))
                     a_tile_slice = memref.view(
-                        ir.MemRefType.get(a_tile_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)),
-                                          c(lhs_tile_bytes * num_stages))
+                        ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(stage, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
                     b_tile_slice = memref.view(
-                        ir.MemRefType.get(b_tile_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, b_offset, [])
-                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}",
-                                iv,
-                                a_offset,
-                                b_offset,
-                                predicate=primaryThread)
+                        ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    debug_print(
+                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
+                        iv,
+                        a_offset,
+                        b_offset,
+                        predicate=primaryThread,
+                    )
                     da = nvgpu.WarpgroupGenerateDescriptorOp(
-                        a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        a_wgmma_ty, a_tile_slice, a_tma_desc
+                    )
                     db = nvgpu.WarpgroupGenerateDescriptorOp(
-                        b_wgmma_ty, b_tile_slice, b_tma_desc)
+                        b_wgmma_ty, b_tile_slice, b_tma_desc
+                    )
 
                     # Step 4.3. MMA
                     carry_acc = for_op.inner_iter_args[0]
-                    new_acc = nvgpu.WarpgroupMmaOp(acc.type,
-                                                   da,
-                                                   db,
-                                                   carry_acc,
-                                                   transposeB=True)
+                    new_acc = nvgpu.WarpgroupMmaOp(
+                        acc.type, da, db, carry_acc, transposeB=True
+                    )
 
                     # Step 4.4. Load TMA for next stage
-                    p1 = arith.cmpi(arith.CmpIPredicate.ult,
-                                    arith.addi(iv, c(num_stages - 1)),
-                                    c(K // BLOCK_K))
+                    p1 = arith.cmpi(
+                        arith.CmpIPredicate.ult,
+                        arith.addi(iv, c(num_stages - 1)),
+                        c(K // BLOCK_K),
+                    )
                     p = arith.andi(primaryThread, p1)
                     with ir.InsertionPoint(scf.IfOp(p).then_block):
                         nextStage = arith.addi(iv, c(num_stages - 1))
                         nextSlot = arith.remui(nextStage, c(num_stages))
                         a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
                         a_tma_slice = memref.view(
-                            ir.MemRefType.get(a_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, a_offset, [])
+                            ir.MemRefType.get(
+                                a_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
                         b_offset = arith.addi(
                             arith.muli(nextSlot, c(rhs_tile_bytes)),
-                            c(lhs_tile_bytes * num_stages))
+                            c(lhs_tile_bytes * num_stages),
+                        )
                         b_tma_slice_1 = memref.view(
-                            ir.MemRefType.get(b_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset, [])
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
                         b_offset2 = arith.addi(
-                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                        )
                         b_tma_slice_2 = memref.view(
-                            ir.MemRefType.get(b_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset2, [])
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset2,
+                            [],
+                        )
 
                         coord = arith.muli(c(64), nextStage)
                         debug_print(
@@ -804,45 +965,53 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                             dimX,
                             dimY,
                             coord,
-                            predicate=primaryThread)
-                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                             mbarTMA,
-                                             a_tma_desc,
-                                             coordinates=[coord, dimX],
-                                             mbarId=nextSlot,
-                                             predicate=primaryThread)
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY, coord],
-                                             mbarId=nextSlot,
-                                             predicate=primaryThread)
+                            predicate=primaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            a_tma_slice,
+                            mbarTMA,
+                            a_tma_desc,
+                            coordinates=[coord, dimX],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_1,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY, coord],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
                         dimY2 = arith.addi(dimY, c(64))
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY2, coord],
-                                             mbarId=nextSlot,
-                                             predicate=primaryThread)
-
-                        debug_print("[MainLoop] mbarTMA[{}] arrive",
-                                    nextSlot,
-                                    predicate=primaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_2,
                             mbarTMA,
-                            c(txcount),
+                            b_tma_desc,
+                            coordinates=[dimY2, coord],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
+
+                        debug_print(
+                            "[MainLoop] mbarTMA[{}] arrive",
+                            nextSlot,
+                            predicate=primaryThread,
+                        )
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA, c(txcount), nextSlot, predicate=primaryThread
+                        )
+                        debug_print(
+                            "[MainLoop] mbarTMA[{}] arrive [done]",
                             nextSlot,
-                            predicate=primaryThread)
-                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]",
-                                    nextSlot,
-                                    predicate=primaryThread)
+                            predicate=primaryThread,
+                        )
                         scf.yield_([])
                     # Step 4.5. Change the phaseParity
-                    p = arith.cmpi(arith.CmpIPredicate.eq, stage,
-                                   c(num_stages - 1))
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                     phaseParity = arith.select(
-                        p, arith.xori(phaseParity, arith.constant(i1, 1)),
-                        phaseParity)
+                        p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity
+                    )
 
                     # Step 4.5. Yield
                     scf.yield_([new_acc, phaseParity])
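
The `phaseParity` update above is easier to follow in isolation: the parity bit passed to the mbarrier try_wait flips once per full sweep over the `num_stages` shared-memory slots. A minimal pure-Python sketch of the same bookkeeping (illustrative only, not part of the patch):

```
# Conceptual sketch of the phase-parity update in the main loop (illustrative).
# Each mbarrier slot is reused every num_stages iterations; the parity the
# waiter passes to try_wait must flip whenever the slots wrap around.
num_stages = 3
K, BLOCK_K = 512, 64
phase_parity = 0
for iv in range(K // BLOCK_K):
    stage = iv % num_stages
    # ... try_wait(mbarTMA[stage], phase_parity), then the WGMMA for this tile ...
    if stage == num_stages - 1:    # last slot of this sweep
        phase_parity ^= 1          # arith.select(p, xori(parity, 1), parity)
```
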
@@ -852,35 +1021,46 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 gpu.barrier()
 
                 # Step 6. Epilogue (registers --> shared memory)
-                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
-                                                c_elem_ty,
-                                                memory_space=smem_space)
+                acc_smem_ty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+                )
                 acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                 debug_print("Storing", predicate=primaryThread)
                 nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                 gpu.barrier()
 
                 # GPU Step 7. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
-                                       c_elem_ty,
-                                       memory_space=smem_space)
+                fd = ir.MemRefType.get(
+                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+                )
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                 rty = ir.MemRefType.get(
-                    (BLOCK_M, BLOCK_N), c_elem_ty,
-                    ir.Attribute.parse("strided<[" + str(N) +
-                                       ", 1], offset: ?>"))
+                    (BLOCK_M, BLOCK_N),
+                    c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+                )
                 c_device_per_block = memref.SubViewOp(
-                    rty, c_device, [dimX, dimY], [], [],
-                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
+                    rty,
+                    c_device,
+                    [dimX, dimY],
+                    [],
+                    [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                    [BLOCK_M, BLOCK_N],
+                    [1, 1],
+                )
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
-                                   c(vlen * WARP_GROUP_SIZE))
+                for_op = scf.ForOp(
+                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
+                )
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get(
-                        (vlen, ), c_elem_ty), collapsed_smem,
-                                        [for_op.induction_variable])
+                    vdata = vector.load(
+                        ir.VectorType.get((vlen,), c_elem_ty),
+                        collapsed_smem,
+                        [for_op.induction_variable],
+                    )
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
 
@@ -892,6 +1072,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             gpu.wait(token_ty, [t9])
 
     mlir_matmul_multistage.func_op.attributes[
-        "llvm.emit_c_interface"] = ir.UnitAttr.get()
+        "llvm.emit_c_interface"
+    ] = ir.UnitAttr.get()
     module.operation.verify()
     return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
index ab218812bb4d96..1c9cc74fcd169c 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -19,8 +19,7 @@
 class NvgpuCompiler:
     """Nvgpu class for compiling and building MLIR modules."""
 
-    def __init__(self, options: str, opt_level: int,
-                 shared_libs: Sequence[str]):
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
         pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
         self.pipeline = pipeline
         self.shared_libs = shared_libs
@@ -36,12 +35,11 @@ def compile(self, module: ir.Module):
 
     def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Wraps the module in a JIT execution engine."""
-        return execution_engine.ExecutionEngine(module,
-                                                opt_level=self.opt_level,
-                                                shared_libs=self.shared_libs)
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
 
-    def compile_and_jit(self,
-                        module: ir.Module) -> execution_engine.ExecutionEngine:
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Compiles and jits the module."""
         self.compile(module)
         return self.jit(module)
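
For context, the `NvgpuCompiler` wrapper above is driven from the test script roughly as follows. This is a minimal usage sketch only; the support-library path, the pipeline options, and the parsed module are placeholders rather than values taken from this patch.

```
# Hypothetical usage of NvgpuCompiler (paths, options, and module are placeholders).
from mlir import ir
from tools.nvgpucompiler import NvgpuCompiler  # assumed import path

support_lib = "/path/to/libmlir_cuda_runtime.so"                  # placeholder
options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"   # assumed options

with ir.Context(), ir.Location.unknown():
    module = ir.Module.parse("module {}")  # stand-in for the generated GEMM module
    compiler = NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
    # Runs the lowering pipeline, then wraps the module in a JIT execution engine.
    engine = compiler.compile_and_jit(module)
```
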

>From aa3e3486ef011cd4bc9bb784f81b73981e913151 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 13 Feb 2024 09:09:46 +0000
Subject: [PATCH 4/7] fix the spelling mistake

---
 .../Integration/GPU/CUDA/sm90/python/matmul.py   | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index ff9c950193e695..f5d5b30c70d1e1 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -95,12 +95,12 @@ def generate_matmul(
     BLOCK_M=128,
     BLOCK_N=128,
     BLOCK_K=64,
-    use_warp_specilization=True,
+    use_warp_specialization=True,
     saveIR=False,
     max_num_stages=3,
 ):
     with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
-        if use_warp_specilization:
+        if use_warp_specialization:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
                 input_type,
                 output_type,
@@ -161,7 +161,7 @@ def matmul(
     BLOCK_M=128,
     BLOCK_N=128,
     BLOCK_K=64,
-    use_warp_specilization=True,
+    use_warp_specialization=True,
     saveIR=False,
     max_num_stages=3,
     print_results=False,
@@ -170,7 +170,7 @@ def matmul(
     # Print the configuration
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
-    gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
+    gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
     print(
         "===-- Running GEMM "
         + gemmty
@@ -207,7 +207,7 @@ def matmul(
         BLOCK_M,
         BLOCK_N,
         BLOCK_K,
-        use_warp_specilization,
+        use_warp_specialization,
         saveIR,
         max_num_stages,
     )
@@ -221,7 +221,7 @@ def matmul(
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
     kernelName = (
         "mlir_matmul_warpspecialized"
-        if use_warp_specilization
+        if use_warp_specialization
         else "mlir_matmul_multistage"
     )
 
@@ -252,7 +252,7 @@ def matmul(
     128,
     4096,
     max_num_stages=3,
-    use_warp_specilization=False,
+    use_warp_specialization=False,
 )
 # GEMM Warp Specilized  f32 += f16 * f16
 matmul(
@@ -262,5 +262,5 @@ def matmul(
     1024,
     512,
     max_num_stages=3,
-    use_warp_specilization=True,
+    use_warp_specialization=True,
 )

>From dea779aae915d610b409660a2fc6d0b2a3697a43 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 20 Feb 2024 10:09:11 +0000
Subject: [PATCH 5/7] address comments

---
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 201 +++++++++++-------
 1 file changed, 121 insertions(+), 80 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 4c132ec992a461..434af5e0920466 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -10,6 +10,7 @@
 from mlir.dialects import builtin
 from mlir.dialects import scf
 from mlir.dialects import vector
+from mlir.extras import types as T
 
 TMA_LAST_DIM_F16 = 64  # 128B float16
 WARP_SIZE = 32
@@ -22,12 +23,8 @@
 CONSUMER_PRIMARY_THREAD = 0
 
 MLIR_DYNAMIC = -9223372036854775808
-f16_byte = 2
-f32_byte = 4
 
 DEBUG = False
-
-
 def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
     if not DEBUG and not forcePrint:
         return
@@ -59,11 +56,52 @@ def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
         scf.yield_([])
 
 
+def get_type_str(ty):
+    if ir.F16Type.isinstance(ty):
+        return "f16"
+    if ir.F32Type.isinstance(ty):
+        return "f32"
+    if ir.F64Type.isinstance(ty):
+        return "f64"
+    if ir.IntegerType.isinstance(ty):
+        return "i" + str(ir.IntegerType(ty).width)
+    if ir.IndexType.isinstance(ty):
+        return "T.index()"
+    raise NotImplementedError(ty)
+
+
+def get_type_size(ty):
+    if ir.F16Type.isinstance(ty):
+        return 2
+    if ir.F32Type.isinstance(ty):
+        return 4
+    if ir.F64Type.isinstance(ty):
+        return 8
+    if ir.IntegerType.isinstance(ty):
+        return ir.IntegerType(ty).width // 8
+    if ir.IndexType.isinstance(ty):
+        return 8
+    raise NotImplementedError(ty)
+
+
+def get_mlir_ty(dtype):
+    if dtype == np.float16:
+        return T.f16()
+    if dtype == np.float32:
+        return T.f32()
+    if dtype == np.float64:
+        return T.f64()
+    if dtype == np.int32:
+        return T.i32()
+    if dtype == np.int64:
+        return T.i64()
+    raise NotImplementedError(dtype)
+
+
 def c(value, ty=None):
-    ty = ir.IndexType.get() if ty is None else ty
+    ty = T.index() if ty is None else ty
     return arith.constant(ty, value)
 
-
 def generate_matmul_ws(
     input_type=np.float16,
     output_type=np.float32,
@@ -86,27 +124,21 @@ def generate_matmul_ws(
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
-    f16 = ir.F16Type.get()
-    f32 = ir.F32Type.get()
-    i1 = ir.IntegerType.get_signless(1)
-    i32 = ir.IntegerType.get_signless(32)
-    index = ir.IndexType.get()
-    i8 = ir.IntegerType.get_signless(8)
     token_ty = ir.Type.parse("!gpu.async.token")
-    a_ty = ir.MemRefType.get([M, K], f16)
-    b_ty = ir.MemRefType.get((K, N), f16)
-    c_elem_ty = f16 if output_type == np.float16 else f32
+    a_elem_ty = get_mlir_ty(input_type)
+    b_elem_ty = get_mlir_ty(input_type)
+    c_elem_ty = get_mlir_ty(output_type)
+    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
     c_ty = ir.MemRefType.get((M, N), c_elem_ty)
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = (
-        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
-    ) * f16_byte
+    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
+        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
+    )
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
-    input_type_str = "f16" if input_type == np.float16 else "f32"
-    output_type_str = "f16" if output_type == np.float16 else "f32"
     mbar_ty = ir.Type.parse(
         "!nvgpu.mbarrier.group<memorySpace = "
         + str(smem_space)
@@ -120,7 +152,7 @@ def generate_matmul_ws(
         + "x"
         + str(TMA_LAST_DIM_F16)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + str(smem_space)
         + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -131,7 +163,7 @@ def generate_matmul_ws(
         + "x"
         + str(TMA_LAST_DIM_F16)
         + "x"
-        + str(input_type_str)
+        + get_type_str(b_elem_ty)
         + ", "
         + str(smem_space)
         + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -142,7 +174,7 @@ def generate_matmul_ws(
         + "x"
         + str(BLOCK_N)
         + "x"
-        + str(output_type_str)
+        + get_type_str(c_elem_ty)
         + ">>"
     )
     a_wgmma_ty = ir.Type.parse(
@@ -151,7 +183,7 @@ def generate_matmul_ws(
         + "x"
         + str(BLOCK_K)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + smem_space_str
         + ">>"
@@ -162,7 +194,7 @@ def generate_matmul_ws(
         + "x"
         + str(BLOCK_N)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + smem_space_str
         + ">>"
@@ -172,10 +204,10 @@ def generate_matmul_ws(
 
         @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
         def mlir_matmul_warpspecialized(a_host, b_host, c_host):
-            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
-            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
             smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
-            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
             smem_size = max(smem_size_input, smem_size_output)
 
             # Step 1. Allocate device memory and memcpy
@@ -195,7 +227,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
                 x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
                 )
                 tma_descs.append(
                     nvgpu.TmaCreateDescriptorOp(
@@ -215,14 +247,14 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 [t7],
                 *map(c, grid),
                 *map(c, block),
-                dynamicSharedMemorySize=c(smem_size, ty=i32)
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
             )
-            launch_op.body.blocks.append(*([index] * 12))
+            launch_op.body.blocks.append(*([T.index()] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. This is needed for vectorized ld/st
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                 )
                 ticks = c(10000000)
 
@@ -275,7 +307,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
 
                     # Step 5.2. TMA Main Loop
                     for_op = scf.ForOp(
-                        c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)]
+                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
                     )
                     with ir.InsertionPoint(for_op.body):
                         phaseParity = for_op.inner_iter_args[0]
@@ -303,7 +335,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                         phaseParity = arith.select(
                             p,
-                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                             phaseParity,
                         )
 
@@ -311,7 +343,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
                         a_tma_slice = memref.view(
                             ir.MemRefType.get(
-                                a_tma_shape, f16, memory_space=smem_space
+                                a_tma_shape, a_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             a_offset,
@@ -323,18 +355,19 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         )
                         b_tma_slice_1 = memref.view(
                             ir.MemRefType.get(
-                                b_tma_shape, f16, memory_space=smem_space
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset,
                             [],
                         )
                         b_offset2 = arith.addi(
-                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                            b_offset,
+                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                         )
                         b_tma_slice_2 = memref.view(
                             ir.MemRefType.get(
-                                b_tma_shape, f16, memory_space=smem_space
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset2,
@@ -406,7 +439,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
 
                     # Step 6.3. MMA Main Loop
                     for_op = scf.ForOp(
-                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                     )
                     with ir.InsertionPoint(for_op.body):
                         # Step 6.3.1. Wait mbar1
@@ -435,7 +468,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
                         a_tile_slice = memref.view(
                             ir.MemRefType.get(
-                                a_tile_shape, f16, memory_space=smem_space
+                                a_tile_shape, a_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             a_offset,
@@ -447,7 +480,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         )
                         b_tile_slice = memref.view(
                             ir.MemRefType.get(
-                                b_tile_shape, f16, memory_space=smem_space
+                                b_tile_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset,
@@ -506,7 +539,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                         phaseParity = arith.select(
                             p,
-                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                             phaseParity,
                         )
 
@@ -604,27 +637,21 @@ def generate_matmul_multistage(
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
-    f16 = ir.F16Type.get()
-    f32 = ir.F32Type.get()
-    i1 = ir.IntegerType.get_signless(1)
-    i32 = ir.IntegerType.get_signless(32)
-    index = ir.IndexType.get()
-    i8 = ir.IntegerType.get_signless(8)
     token_ty = ir.Type.parse("!gpu.async.token")
-    a_ty = ir.MemRefType.get([M, K], f16)
-    b_ty = ir.MemRefType.get((K, N), f16)
-    c_elem_ty = f16 if output_type == np.float16 else f32
+    a_elem_ty = get_mlir_ty(input_type)
+    b_elem_ty = get_mlir_ty(input_type)
+    c_elem_ty = get_mlir_ty(output_type)
+    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
     c_ty = ir.MemRefType.get((M, N), c_elem_ty)
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = (
-        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
-    ) * f16_byte
+    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
+        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
+    )
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
-    input_type_str = "f16" if input_type == np.float16 else "f32"
-    output_type_str = "f16" if output_type == np.float16 else "f32"
     mbar_ty = ir.Type.parse(
         "!nvgpu.mbarrier.group<memorySpace = "
         + str(smem_space)
@@ -638,7 +665,7 @@ def generate_matmul_multistage(
         + "x"
         + str(TMA_LAST_DIM_F16)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + str(smem_space)
         + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -649,7 +676,7 @@ def generate_matmul_multistage(
         + "x"
         + str(TMA_LAST_DIM_F16)
         + "x"
-        + str(input_type_str)
+        + get_type_str(b_elem_ty)
         + ", "
         + str(smem_space)
         + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -660,7 +687,7 @@ def generate_matmul_multistage(
         + "x"
         + str(BLOCK_N)
         + "x"
-        + str(output_type_str)
+        + get_type_str(c_elem_ty)
         + ">>"
     )
     a_wgmma_ty = ir.Type.parse(
@@ -669,7 +696,7 @@ def generate_matmul_multistage(
         + "x"
         + str(BLOCK_K)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + smem_space_str
         + ">>"
@@ -680,7 +707,7 @@ def generate_matmul_multistage(
         + "x"
         + str(BLOCK_N)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + smem_space_str
         + ">>"
@@ -690,10 +717,10 @@ def generate_matmul_multistage(
 
         @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
         def mlir_matmul_multistage(a_host, b_host, c_host):
-            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
-            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
             smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
-            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
             smem_size = max(smem_size_input, smem_size_output)
 
             # Step 1. Allocate device memory and memcpy
@@ -713,7 +740,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
                 x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
                 )
                 tma_descs.append(
                     nvgpu.TmaCreateDescriptorOp(
@@ -733,14 +760,14 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 [t7],
                 *map(c, grid),
                 *map(c, block),
-                dynamicSharedMemorySize=c(smem_size, ty=i32)
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
             )
-            launch_op.body.blocks.append(*([index] * 12))
+            launch_op.body.blocks.append(*([T.index()] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. Bootstrapping
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                 )
                 ticks = c(10000000)
                 tidx = gpu.thread_id(gpu.Dimension.x)
@@ -769,7 +796,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     # Step 3.1. Calculate offsets
                     a_offset = arith.muli(iv, c(lhs_tile_bytes))
                     a_tma_slice = memref.view(
-                        ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            a_tma_shape, a_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         a_offset,
                         [],
@@ -779,16 +808,21 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         c(lhs_tile_bytes * num_stages),
                     )
                     b_tma_slice_1 = memref.view(
-                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         b_offset,
                         [],
                     )
                     b_offset2 = arith.addi(
-                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                        b_offset,
+                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                     )
                     b_tma_slice_2 = memref.view(
-                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         b_offset2,
                         [],
@@ -850,7 +884,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 # GPU Step 4. Main Loop
                 acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
                 for_op = scf.ForOp(
-                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                 )
                 with ir.InsertionPoint(for_op.body):
                     # Step 4.1. Wait mbarTMA
@@ -876,7 +910,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     # Step 4.2. Create WGMMA Descriptors
                     a_offset = arith.muli(stage, c(lhs_tile_bytes))
                     a_tile_slice = memref.view(
-                        ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            a_tile_shape, a_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         a_offset,
                         [],
@@ -886,7 +922,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         c(lhs_tile_bytes * num_stages),
                     )
                     b_tile_slice = memref.view(
-                        ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            b_tile_shape, b_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         b_offset,
                         [],
@@ -924,7 +962,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
                         a_tma_slice = memref.view(
                             ir.MemRefType.get(
-                                a_tma_shape, f16, memory_space=smem_space
+                                a_tma_shape, a_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             a_offset,
@@ -936,18 +974,19 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         )
                         b_tma_slice_1 = memref.view(
                             ir.MemRefType.get(
-                                b_tma_shape, f16, memory_space=smem_space
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset,
                             [],
                         )
                         b_offset2 = arith.addi(
-                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                            b_offset,
+                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                         )
                         b_tma_slice_2 = memref.view(
                             ir.MemRefType.get(
-                                b_tma_shape, f16, memory_space=smem_space
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset2,
@@ -1010,7 +1049,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     # Step 4.5. Change the phaseParity
                     p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                     phaseParity = arith.select(
-                        p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity
+                        p,
+                        arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+                        phaseParity,
                     )
 
                     # Step 4.5. Yield
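
As a quick sanity check on the type helpers introduced in this patch (`get_mlir_ty`, `get_type_str`, `get_type_size`) and the per-stage `txcount` they feed, the mapping can be exercised in isolation. A sketch only; it assumes the helpers are importable as shown and that an MLIR context is active.

```
import numpy as np
from mlir import ir
from tools import matmulBuilder as mb  # assumed import path for the helpers

with ir.Context(), ir.Location.unknown():
    f16 = mb.get_mlir_ty(np.float16)   # -> ir.F16Type
    f32 = mb.get_mlir_ty(np.float32)   # -> ir.F32Type
    assert mb.get_type_str(f16) == "f16"
    assert mb.get_type_size(f32) == 4  # bytes per element

    # Per-stage expected-transaction bytes, mirroring the txcount expression:
    BLOCK_M, BLOCK_N, BLOCK_K, TMA_LAST_DIM_F16 = 128, 128, 64, 64
    a_tile = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tile = (BLOCK_K, BLOCK_N)
    txcount = (b_tile[0] * b_tile[1] * mb.get_type_size(f16)
               + a_tile[0] * a_tile[1] * mb.get_type_size(f16))
    print(txcount)  # 32768 bytes for these f16 tile shapes
```
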

>From d4393a47f1d93253bac35c5ce784bc9b06c463f1 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 20 Feb 2024 10:23:23 +0000
Subject: [PATCH 6/7] format

---
 .../Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py    | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 434af5e0920466..52732b1c01426e 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -25,6 +25,8 @@
 MLIR_DYNAMIC = -9223372036854775808
 
 DEBUG = False
+
+
 def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
     if not DEBUG and not forcePrint:
         return
@@ -102,6 +104,7 @@ def c(value, ty=None):
     ty = T.index() if ty is None else ty
     return arith.constant(ty, value)
 
+
 def generate_matmul_ws(
     input_type=np.float16,
     output_type=np.float32,

>From 09fb9c87480422460f374be1133ebfbf13f8d6b1 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sun, 3 Mar 2024 08:23:20 +0000
Subject: [PATCH 7/7] Allow multiple stages, and fix the kernels. Add
 test_short that tests multiple cases.

---
 .../GPU/CUDA/sm90/python/matmul.py            | 157 +++++++--
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 306 ++++++++++--------
 2 files changed, 308 insertions(+), 155 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index f5d5b30c70d1e1..88609273a8b548 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -1,6 +1,6 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
 # RUN:   %PYTHON %s | FileCheck %s
-# CHECK: PASS
+
 
 # ===--- GEMM Hopper Tensor Core Integration Test ---===
 #
@@ -168,6 +168,8 @@ def matmul(
     no_verify=False,
 ):
     # Print the configuration
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
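
To make the stage clamping above concrete: with the default 128x128x64 tile, a 128x128x64 problem needs only a single stage, so `num_stages` ends up below `max_num_stages`. A worked example (not part of the patch):

```
# Worked example of the num_stages clamp (illustrative only).
M = N = 128
K = 64
BLOCK_M = BLOCK_N = 128
BLOCK_K = 64
max_num_stages = 3

required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
num_stages = min(required_stages, max_num_stages)
print(required_stages, num_stages)  # 1 1 -> matches "stages 1" in the CHECK lines below
```
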
@@ -193,7 +195,7 @@ def matmul(
         + "x"
         + str(BLOCK_K)
         + ", stages "
-        + str(max_num_stages)
+        + str(num_stages)
         + " --==="
     )
 
@@ -209,7 +211,7 @@ def matmul(
         BLOCK_K,
         use_warp_specialization,
         saveIR,
-        max_num_stages,
+        num_stages,
     )
 
     # Allocate matrices and invoke the matmul
@@ -219,10 +221,17 @@ def matmul(
     mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
     mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-    kernelName = (
-        "mlir_matmul_warpspecialized"
-        if use_warp_specialization
-        else "mlir_matmul_multistage"
+    kernelName = matmulBuilder.make_kernel_name(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        num_stages,
+        use_warp_specialization,
     )
 
     # Launch the MLIR generated kernel
@@ -243,24 +252,118 @@ def matmul(
 
     print("PASS ")
 
+# Takes a long time to run
+def test_long():
+    for stages in range(1,7):
+        for M in [128, 512, 1024, 4096, 8192]:
+            for N in [128, 512, 1024, 4096, 8192]:
+                for K in [64, 128, 512, 1024, 4096, 8192]:
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M,N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=False,
+                        no_verify=True,
+                        saveIR=True
+                    )
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M,N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=True,                    
+                    )
 
-# GEMM Multistage       f32 += f16 * f16
-matmul(
-    np.float16,
-    np.float32,
-    128,
-    128,
-    4096,
-    max_num_stages=3,
-    use_warp_specialization=False,
-)
-# GEMM Warp Specilized  f32 += f16 * f16
-matmul(
-    np.float16,
-    np.float32,
-    256,
-    1024,
-    512,
-    max_num_stages=3,
-    use_warp_specialization=True,
-)
+def test_short():
+    for stages in [1, 3]:
+        for M in [128, 512]:
+            for N in [128, 512]:
+                for K in [64, 512]:
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M,N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=False,
+                        no_verify=True,
+                        saveIR=True
+                    )
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M,N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=True,                    
+                    )
+
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+
+test_short()
\ No newline at end of file
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 52732b1c01426e..a541378359fe97 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -105,6 +105,41 @@ def c(value, ty=None):
     return arith.constant(ty, value)
 
 
+def make_kernel_name(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=128,
+    num_stages=3,
+    use_warp_specialization=False,
+):
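+    """Compose a kernel name from the algorithm, shape, tile size, and stage count."""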
+    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
+    return f"{kernelName}_{M}x{N}x{K}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_{num_stages}"
+
+
 def generate_matmul_ws(
     input_type=np.float16,
     output_type=np.float32,
@@ -114,7 +149,7 @@ def generate_matmul_ws(
     BLOCK_M=128,
     BLOCK_N=128,
     BLOCK_K=128,
-    max_num_stages=3,
+    num_stages=3,
 ):
     # Limitations for now
     assert input_type == np.float16
@@ -123,9 +158,6 @@ def generate_matmul_ws(
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
-    num_stages = min(required_stages, max_num_stages)
-
     module = ir.Module.create()
     token_ty = ir.Type.parse("!gpu.async.token")
     a_elem_ty = get_mlir_ty(input_type)
@@ -202,11 +234,15 @@ def generate_matmul_ws(
         + smem_space_str
         + ">>"
     )
-
+    kernelName = make_kernel_name(
+        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
+    )
     with ir.InsertionPoint(module.body):
-
-        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
-        def mlir_matmul_warpspecialized(a_host, b_host, c_host):
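+        # Build the function explicitly instead of using @func.FuncOp.from_py_func so
+        # the symbol carries the generated kernel name.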
+        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+        with ir.InsertionPoint(fop.add_entry_block()):
+            a_host = fop.arguments[0]
+            b_host = fop.arguments[1]
+            c_host = fop.arguments[2]
             lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
             rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
             smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
@@ -301,6 +337,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
                 nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
 
+                ns = num_stages if num_stages == 1 else num_stages - 1
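+                # Look-ahead depth used below: num_stages - 1, clamped to one for a
+                # single-stage pipeline.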
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                 with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
                     # Step 5.1. Reduce register size
@@ -319,7 +356,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
 
                         # Step 5.2.1. Wait mbarDONE
                         debug_print(
-                            "[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={}",
                             iv,
                             stage,
                             phaseParity,
@@ -329,7 +366,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             mbarDONE, phaseParity, ticks, mbarId=stage
                         )
                         debug_print(
-                            "[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={} [done]",
                             iv,
                             stage,
                             phaseParity,
@@ -412,7 +449,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
 
                         # Step 5.2.3. Arrive mbarTMA
                         debug_print(
-                            "[prod] {}  | mbarTMA[{}] arrive",
+                            "[prod] iv={}  | mbarTMA[{}] arrive",
                             iv,
                             stage,
                             predicate=producerPrimaryThread,
@@ -421,7 +458,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
                         )
                         debug_print(
-                            "[prod] {}  | mbarTMA[{}] arrive [done]",
+                            "[prod] iv={}  | mbarTMA[{}] arrive [done]",
                             iv,
                             stage,
                             predicate=producerPrimaryThread,
@@ -450,7 +487,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
                         debug_print(
-                            "[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={}",
                             iv,
                             stage,
                             phaseParity,
@@ -460,7 +497,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             mbarTMA, phaseParity, ticks, mbarId=stage
                         )
                         debug_print(
-                            "[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={} [done]",
                             iv,
                             stage,
                             phaseParity,
@@ -509,32 +546,29 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         )
 
                         # Step 6.3.4. Arrive mbarDONE
-                        p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
-                        p_arrive = arith.andi(consumerPrimaryThread, p1)
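+                        # With a single stage the consumer recycles the same buffer every
+                        # iteration, so it must signal mbarDONE from iv = 0 onwards.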
+                        if num_stages == 1:
+                            p_arrive = consumerPrimaryThread
+                        else:
+                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+                            p_arrive = arith.andi(consumerPrimaryThread, p1)
                         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                             p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
                             barId = arith.select(
                                 p, c(num_stages - 1), arith.subi(stage, c(1))
                             )
-                            remoteCtaId = c(0)
-                            pred = consumerPrimaryThread
                             debug_print(
-                                "[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
+                                "[cons] iv={}  | mbarDONE[{}] arrive ",
                                 iv,
                                 barId,
-                                remoteCtaId,
-                                pred,
                                 predicate=consumerPrimaryThread,
                             )
                             nvgpu.mbarrier_arrive(
                                 ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
                             )
                             debug_print(
-                                "[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
+                                "[cons] iv={}  | mbarDONE[{}] arrive [done]",
                                 iv,
                                 barId,
-                                remoteCtaId,
-                                pred,
                                 predicate=consumerPrimaryThread,
                             )
                             scf.yield_([])
@@ -609,11 +643,13 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             # Step 4. Copy back to host
             t8 = gpu.wait(token_ty, [launch_op])
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
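+            # Release device buffers: A and B once the kernel has finished, C after the
+            # result has been copied back to the host.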
+            gpu.dealloc(token_ty, [t8], a_device)
+            gpu.dealloc(token_ty, [t8], b_device)
             gpu.wait(token_ty, [t9])
+            gpu.dealloc(token_ty, [t8], c_device)
+            func.ReturnOp([])
 
-    mlir_matmul_warpspecialized.func_op.attributes[
-        "llvm.emit_c_interface"
-    ] = ir.UnitAttr.get()
+    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module
 
@@ -627,7 +663,7 @@ def generate_matmul_multistage(
     BLOCK_M=128,
     BLOCK_N=128,
     BLOCK_K=64,
-    max_num_stages=3,
+    num_stages=3,
 ):
     # Limitations for now
     assert input_type == np.float16
@@ -636,9 +672,6 @@ def generate_matmul_multistage(
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
-    num_stages = min(required_stages, max_num_stages)
-
     module = ir.Module.create()
     token_ty = ir.Type.parse("!gpu.async.token")
     a_elem_ty = get_mlir_ty(input_type)
@@ -717,9 +750,23 @@ def generate_matmul_multistage(
     )
 
     with ir.InsertionPoint(module.body):
-
-        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
-        def mlir_matmul_multistage(a_host, b_host, c_host):
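+        # Name and build the multistage kernel the same way as the warp-specialized one.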
+        kernelName = make_kernel_name(
+            input_type,
+            output_type,
+            M,
+            N,
+            K,
+            BLOCK_M,
+            BLOCK_N,
+            BLOCK_K,
+            num_stages,
+            False,
+        )
+        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+        with ir.InsertionPoint(fop.add_entry_block()):
+            a_host = fop.arguments[0]
+            b_host = fop.arguments[1]
+            c_host = fop.arguments[2]
             lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
             rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
             smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
@@ -792,7 +839,8 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
-                for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
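+                # Prologue depth: prefetch num_stages - 1 tiles up front, but at least
+                # one so a single-stage pipeline still issues its first load.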
+                ns = num_stages if num_stages == 1 else num_stages - 1
+                for_op = scf.ForOp(c(0), c(ns), c(1))
                 with ir.InsertionPoint(for_op.body):
                     iv = for_op.induction_variable
 
@@ -951,104 +999,105 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     new_acc = nvgpu.WarpgroupMmaOp(
                         acc.type, da, db, carry_acc, transposeB=True
                     )
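+                    # A single-stage pipeline has no spare buffer, so the WGMMA group must
+                    # finish before the next TMA load overwrites the current tile.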
+                    if num_stages == 1:
+                        nvvm.WgmmaWaitGroupSyncOp(0)
 
                     # Step 4.4. Load TMA for next stage
                     p1 = arith.cmpi(
                         arith.CmpIPredicate.ult,
-                        arith.addi(iv, c(num_stages - 1)),
+                        arith.addi(iv, c(ns)),
                         c(K // BLOCK_K),
                     )
                     p = arith.andi(primaryThread, p1)
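+                    # The next-stage TMA load is predicated on p (primary thread and
+                    # iv + ns < K / BLOCK_K) instead of being wrapped in an scf.if region.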
-                    with ir.InsertionPoint(scf.IfOp(p).then_block):
-                        nextStage = arith.addi(iv, c(num_stages - 1))
-                        nextSlot = arith.remui(nextStage, c(num_stages))
-                        a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(
-                            ir.MemRefType.get(
-                                a_tma_shape, a_elem_ty, memory_space=smem_space
-                            ),
-                            dynamic_smem,
-                            a_offset,
-                            [],
-                        )
-                        b_offset = arith.addi(
-                            arith.muli(nextSlot, c(rhs_tile_bytes)),
-                            c(lhs_tile_bytes * num_stages),
-                        )
-                        b_tma_slice_1 = memref.view(
-                            ir.MemRefType.get(
-                                b_tma_shape, b_elem_ty, memory_space=smem_space
-                            ),
-                            dynamic_smem,
-                            b_offset,
-                            [],
-                        )
-                        b_offset2 = arith.addi(
-                            b_offset,
-                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
-                        )
-                        b_tma_slice_2 = memref.view(
-                            ir.MemRefType.get(
-                                b_tma_shape, b_elem_ty, memory_space=smem_space
-                            ),
-                            dynamic_smem,
-                            b_offset2,
-                            [],
-                        )
+                    nextStage = arith.addi(iv, c(ns))
+                    nextSlot = arith.remui(nextStage, c(num_stages))
+                    a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
 
-                        coord = arith.muli(c(64), nextStage)
-                        debug_print(
-                            "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                            iv,
-                            a_offset,
-                            b_offset,
-                            b_offset2,
-                            coord,
-                            dimX,
-                            dimY,
-                            coord,
-                            predicate=primaryThread,
-                        )
-                        nvgpu.TmaAsyncLoadOp(
-                            a_tma_slice,
-                            mbarTMA,
-                            a_tma_desc,
-                            coordinates=[coord, dimX],
-                            mbarId=nextSlot,
-                            predicate=primaryThread,
-                        )
-                        nvgpu.TmaAsyncLoadOp(
-                            b_tma_slice_1,
-                            mbarTMA,
-                            b_tma_desc,
-                            coordinates=[dimY, coord],
-                            mbarId=nextSlot,
-                            predicate=primaryThread,
-                        )
-                        dimY2 = arith.addi(dimY, c(64))
-                        nvgpu.TmaAsyncLoadOp(
-                            b_tma_slice_2,
-                            mbarTMA,
-                            b_tma_desc,
-                            coordinates=[dimY2, coord],
-                            mbarId=nextSlot,
-                            predicate=primaryThread,
-                        )
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] arrive",
+                        nextSlot,
+                        predicate=p,
+                    )
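+                    # Register the expected transaction bytes on mbarTMA[nextSlot] before
+                    # issuing the TMA loads (previously done after the loads).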
+                    nvgpu.mbarrier_arrive_expect_tx(
+                        mbarTMA, c(txcount), nextSlot, predicate=p
+                    )
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] arrive [done]",
+                        nextSlot,
+                        predicate=p,
+                    )
 
-                        debug_print(
-                            "[MainLoop] mbarTMA[{}] arrive",
-                            nextSlot,
-                            predicate=primaryThread,
-                        )
-                        nvgpu.mbarrier_arrive_expect_tx(
-                            mbarTMA, c(txcount), nextSlot, predicate=primaryThread
-                        )
-                        debug_print(
-                            "[MainLoop] mbarTMA[{}] arrive [done]",
-                            nextSlot,
-                            predicate=primaryThread,
-                        )
-                        scf.yield_([])
+                    a_tma_slice = memref.view(
+                        ir.MemRefType.get(
+                            a_tma_shape, a_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(nextSlot, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
+                    b_tma_slice_1 = memref.view(
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    b_offset2 = arith.addi(
+                        b_offset,
+                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+                    )
+                    b_tma_slice_2 = memref.view(
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset2,
+                        [],
+                    )
+
+                    coord = arith.muli(c(64), nextStage)
+                    debug_print(
+                        "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                        iv,
+                        a_offset,
+                        b_offset,
+                        b_offset2,
+                        coord,
+                        dimX,
+                        dimY,
+                        coord,
+                        predicate=p,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        a_tma_slice,
+                        mbarTMA,
+                        a_tma_desc,
+                        coordinates=[coord, dimX],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_1,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY, coord],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
+                    dimY2 = arith.addi(dimY, c(64))
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_2,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY2, coord],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
                     # Step 4.5. Change the phaseParity
                     p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                     phaseParity = arith.select(
@@ -1062,7 +1111,6 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
 
                 # Step 5. Wait All WGMMA groups
                 nvvm.WgmmaWaitGroupSyncOp(0)
-                gpu.barrier()
 
                 # Step 6. Epilogue (registers --> shared memory)
                 acc_smem_ty = ir.MemRefType.get(
@@ -1113,10 +1161,12 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             # Step 4. Copy back to host
             t8 = gpu.wait(token_ty, [launch_op])
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
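+            # Release device buffers as in the warp-specialized path.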
+            gpu.dealloc(token_ty, [t8], a_device)
+            gpu.dealloc(token_ty, [t8], b_device)
             gpu.wait(token_ty, [t9])
+            gpu.dealloc(token_ty, [t8], c_device)
+            func.ReturnOp([])
 
-    mlir_matmul_multistage.func_op.attributes[
-        "llvm.emit_c_interface"
-    ] = ir.UnitAttr.get()
+    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module



More information about the Mlir-commits mailing list