[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)

Guray Ozen llvmlistbot at llvm.org
Sun Mar 3 00:42:15 PST 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/81478

>From 72e6a310edaa2e7d8e7ede385de09e52bd302f9b Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:33:26 +0000
Subject: [PATCH 1/8] [mlir] GEMM Hopper Tensor Core Integration Test

This test validates the correctness of the GEMM kernels supported by the
NVGPU dialect; it currently covers the Multistage and Warp Specialization
kernels.
The test metaprograms the IR with the MLIR Python bindings, so the shape,
tile size, and data type of the GEMM can be changed for testing purposes
without rewriting the IR by hand.
The entry function is `matmul`, where one can specify the GEMM shape, tile
size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
maximum number of stages.
Results are verified against NumPy's matmul.
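The verification step can be sketched on the host as follows. This is an illustration, not the test's code. `check_gemm` is a hypothetical helper, and the tolerances are the ones `matmul.py` passes to `np.testing.assert_allclose`. The reference is accumulated in float32 to mirror the kernel's f32 accumulator.

```python
import numpy as np


def check_gemm(a, b, c, rtol=5e-03, atol=1e-01):
    """Assert that `c` matches a @ b (hypothetical helper, not part of the test)."""
    # Widen the f16 inputs first so the reference accumulates in float32,
    # like an f32 += f16 * f16 kernel does.
    ref = a.astype(np.float32) @ b.astype(np.float32)
    np.testing.assert_allclose(c, ref, rtol=rtol, atol=atol)
    return ref


rng = np.random.default_rng(0)
a = rng.standard_normal((128, 256)).astype(np.float16)
b = rng.standard_normal((256, 128)).astype(np.float16)
# Stand-in for the f32 buffer that the GPU kernel would write.
c = a.astype(np.float32) @ b.astype(np.float32)
ref = check_gemm(a, b, c)
```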

Example:
```
matmul(input_type=np.float16,                # input types
       output_type=np.float32,               # output type
       M=4096, N=4096, K=4096,               # Shape
       BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
       use_warp_specialization=True,         # Enable Warp Specialization
       max_num_stages=3)                     # Number of stages in shared memory
```
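The number of stages and the dynamic shared-memory request both follow from the shape and tile sizes. The sketch below restates the arithmetic that `generate_matmul_ws` in `matmulBuilder.py` performs for f16 inputs and an f32 output. The helper name `smem_plan` is hypothetical. The builder sizes the buffer as the larger of the input and output footprints, which suggests that the epilogue reuses the input staging area.

```python
F16_BYTES = 2
F32_BYTES = 4


def smem_plan(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, max_num_stages):
    # Same formula as the builder: total input elements divided by the input
    # elements consumed per stage, capped by max_num_stages.
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    num_stages = min(required_stages, max_num_stages)
    # Each stage holds one A tile and one B tile in shared memory.
    lhs_tile_bytes = BLOCK_M * BLOCK_K * F16_BYTES
    rhs_tile_bytes = BLOCK_N * BLOCK_K * F16_BYTES
    smem_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
    # The f32 accumulator tile staged during the epilogue.
    smem_output = BLOCK_M * BLOCK_N * F32_BYTES
    return num_stages, max(smem_input, smem_output)


print(smem_plan(4096, 4096, 4096, 128, 128, 64, 3))  # (3, 98304)
```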
## Parallelism Across CTAs

A GEMM consists of three loops whose bounds are the GEMM shape (M, N, K)
passed to the `matmul` function. The builder tiles these loops with the
given tile sizes and parallelizes the two outermost tile loops across the
first (x) and second (y) dimensions of the CTA grid, giving the following
loop structure:
```
for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
    for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
        for(bk = 0; bk < K; bk += BLOCK_K)
            for(i = bi; i < (bi + BLOCK_M); ++i)
                for(j = bj; j < (bj + BLOCK_N); ++j)
                    for(k = bk; k < (bk + BLOCK_K); ++k)
```
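The loop nest can be emulated on the host with NumPy to check the tiling logic. This sketch is not part of the test. Each `(bi, bj)` pair plays the role of one CTA, and the `bk` loop accumulates `BLOCK_K`-wide slices into a per-CTA float32 accumulator, as the main loop does with each pipeline stage. The innermost `i`/`j`/`k` loops collapse into one tile matmul. `tiled_gemm` is a hypothetical name.

```python
import numpy as np


def tiled_gemm(a, b, BLOCK_M, BLOCK_N, BLOCK_K):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    # The builder likewise requires the shape to divide evenly into tiles.
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0
    c = np.zeros((M, N), np.float32)
    for bi in range(0, M, BLOCK_M):          # one CTA row per bi (blockIdx.x)
        for bj in range(0, N, BLOCK_N):      # one CTA column per bj (blockIdx.y)
            acc = np.zeros((BLOCK_M, BLOCK_N), np.float32)  # per-CTA accumulator
            for bk in range(0, K, BLOCK_K):  # main loop over K tiles
                a_tile = a[bi:bi + BLOCK_M, bk:bk + BLOCK_K].astype(np.float32)
                b_tile = b[bk:bk + BLOCK_K, bj:bj + BLOCK_N].astype(np.float32)
                acc += a_tile @ b_tile       # the i/j/k loops of one tile
            c[bi:bi + BLOCK_M, bj:bj + BLOCK_N] = acc  # epilogue: write the tile
    return c


rng = np.random.default_rng(0)
a = rng.standard_normal((256, 128)).astype(np.float16)
b = rng.standard_normal((128, 256)).astype(np.float16)
c = tiled_gemm(a, b, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64)
```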

## Multistage Kernel

This kernel launches a single warp group (128 threads). The primary thread
(pthread) issues the TMA loads. All threads wait for the data and perform
the mma operations collectively. After the main loop, the threads store the
fragmented accumulator registers to shared memory and then copy the tile
from shared memory to global memory; this part is called the epilogue.
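Each stage's loads cover one A tile and one B tile. In the warp-specialized builder of this patch, the A tile is a single `BLOCK_M x 64` TMA box. The `BLOCK_K x BLOCK_N` B tile is fetched as two `BLOCK_K x 64` boxes, because the inner box dimension is kept at 64 f16 elements (128 bytes, the span of the 128-byte swizzle mode). The sketch below restates that builder's shared-memory offsets and the byte count passed to `mbarrier_arrive_expect_tx`. It assumes `BLOCK_K == 64` and `BLOCK_N == 128`, as in the test. It also assumes that the multistage kernel stages its tiles the same way. `stage_layout` is a hypothetical name.

```python
F16_BYTES = 2
TMA_LAST_DIM_F16 = 64  # 64 f16 elements = 128 bytes, one 128B swizzle row


def stage_layout(stage, num_stages, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64):
    # Assumption: layout taken from generate_matmul_ws; the multistage
    # builder may differ. Only BLOCK_K == 64 and BLOCK_N == 128 are modeled.
    assert BLOCK_K == TMA_LAST_DIM_F16 and BLOCK_N == 2 * TMA_LAST_DIM_F16
    lhs_tile_bytes = BLOCK_M * BLOCK_K * F16_BYTES
    rhs_tile_bytes = BLOCK_N * BLOCK_K * F16_BYTES
    # All A tiles come first in dynamic shared memory, followed by all B tiles.
    a_offset = stage * lhs_tile_bytes
    b_offset = lhs_tile_bytes * num_stages + stage * rhs_tile_bytes
    # The second BLOCK_K x 64 box of B sits right after the first one.
    b_offset2 = b_offset + BLOCK_K * TMA_LAST_DIM_F16 * F16_BYTES
    # Bytes the mbarrier must observe before the stage counts as loaded.
    txcount = (BLOCK_K * BLOCK_N + BLOCK_M * TMA_LAST_DIM_F16) * F16_BYTES
    return a_offset, b_offset, b_offset2, txcount


for s in range(3):
    print(s, stage_layout(s, num_stages=3))
```

With three stages, the last B box ends at byte 98304, which is exactly the input footprint `(16384 + 16384) * 3`.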

Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+------------+----------+
|       |Prologue ---->  |MainLoop ---->                                                                   |Epilogue  |
+-------+----------------+--------------------+--------------------+--------------------+-----+------------+----------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+------------+----------+
```

## Warp Specialization Kernel

This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
as the `producer warp group` and the other as the `consumer warp group`. The
`producer warp group` issues the TMA loads, while the `consumer warp group`
performs the mma operations. The `consumer warp group` also handles the
epilogue, because its threads own the fragmented accumulator registers.
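The role of each thread is decided from its thread index alone. The sketch below mirrors the bootstrapping step of `generate_matmul_ws` in plain Python. Warp group 1 is the producer and warp group 0 the consumer, with threads 128 and 0 as their primary threads. `thread_role` is a hypothetical helper. In the builder, the producer also lowers its register budget to 40 per thread with `nvvm.setmaxregister`, and the consumer raises its budget to 232.

```python
WARP_SIZE = 32
WARP_GROUP_SIZE = WARP_SIZE * 4   # 128 threads; the CTA has 2 * 128 threads
PRODUCER_PRIMARY_THREAD = 128     # first thread of warp group 1
CONSUMER_PRIMARY_THREAD = 0       # first thread of warp group 0


def thread_role(tidx):
    warp_id = tidx // WARP_SIZE
    warpgroup_id = warp_id // 4
    # Warp group 1 issues TMA loads; warp group 0 runs wgmma and the epilogue.
    role = "producer" if warpgroup_id == 1 else "consumer"
    # Only one thread per warp group issues TMA loads and mbarrier arrives.
    is_primary = tidx in (PRODUCER_PRIMARY_THREAD, CONSUMER_PRIMARY_THREAD)
    return warpgroup_id, role, is_primary


for tidx in (0, 127, 128, 255):
    print(tidx, thread_role(tidx))
```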

Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+---------------------------+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+---------------------------+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+---------------------------+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
+--------+--------+---------+---------+---------+---------------------------+--------------+-----------------+
```
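The warp-specialized kernel cycles through its shared-memory stages as a ring buffer guarded by two mbarrier groups (`mbarTMA`, `mbarDONE`), with one barrier per stage. In `generate_matmul_ws`, the slot used in iteration `iv` is `iv % num_stages`. The phase bit passed to `MBarrierTryWaitParityOp` flips each time the loop wraps past the last slot. The producer starts at parity 1 so that its first pass over the slots does not block on freshly initialized `mbarDONE` barriers. From the second iteration on, the consumer arrives on `mbarDONE` for the slot it used in the previous iteration. The plain-Python model below reproduces that bookkeeping. It runs no GPU code, and the helper names are hypothetical.

```python
def producer_schedule(num_iters, num_stages):
    """(iteration, stage, parity waited on mbarDONE[stage]) for the producer."""
    parity = 1  # initial value of the producer's i1 loop-carried phase bit
    events = []
    for iv in range(num_iters):
        stage = iv % num_stages
        events.append((iv, stage, parity))
        # After the last slot is used, the ring buffer wraps and the phase flips.
        if stage == num_stages - 1:
            parity ^= 1
    return events


def consumer_release(iv, num_stages):
    """mbarDONE slot the consumer arrives on in iteration iv (None when iv == 0)."""
    if iv == 0:
        return None
    stage = iv % num_stages
    # Release the slot consumed in the previous iteration.
    return num_stages - 1 if stage == 0 else stage - 1


for event in producer_schedule(num_iters=6, num_stages=2):
    print(event)
```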
---
 .../GPU/CUDA/sm90/python/lit.local.cfg        |   2 +
 .../GPU/CUDA/sm90/python/matmul.py            | 186 +++++
 .../GPU/CUDA/sm90/python/tools/lit.local.cfg  |   3 +
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 676 ++++++++++++++++++
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  44 ++
 5 files changed, 911 insertions(+)
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+    config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..e153fcb44b9860
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,186 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test validates the correctness of the GEMM kernels supported by the
+# NVGPU dialect; it currently covers the Multistage and Warp Specialization
+# kernels.
+# The test metaprograms the IR with the MLIR Python bindings, so the shape,
+# tile size, and data type of the GEMM can be changed for testing purposes
+# without rewriting the IR by hand.
+# The entry function is `matmul`, where one can specify the GEMM shape, tile
+# size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
+# maximum number of stages.
+# Results are verified against NumPy's matmul.
+#
+# Example:
+# matmul(input_type=np.float16,                # input types
+#        output_type=np.float32,               # output type
+#        M=4096, N=4096, K=4096,               # Shape
+#        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+#        use_warp_specialization=True,         # Enable Warp Specialization
+#        max_num_stages=3)                     # Number of stages in shared memory
+#
+# ===--- Parallelism Across CTAs  ---===
+#
+# A GEMM consists of three loops whose bounds are the GEMM shape (M, N, K)
+# passed to the `matmul` function. The builder tiles these loops with the
+# given tile sizes and parallelizes the two outermost tile loops across the
+# first (x) and second (y) dimensions of the CTA grid, giving the following
+# loop structure:
+#
+# for(bi = 0; bi < M; bi += BLOCK_M)        # parallelize across blockIdx.x
+#     for(bj = 0; bj < N; bj += BLOCK_N)    # parallelize across blockIdx.y
+#         for(bk = 0; bk < K; bk += BLOCK_K)
+#             for(i = bi; i < (bi + BLOCK_M); ++i)
+#                 for(j = bj; j < (bj + BLOCK_N); ++j)
+#                     for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA loads. All threads wait for the data and perform
+# the mma operations collectively. After the main loop, the threads store the
+# fragmented accumulator registers to shared memory and then copy the tile
+# from shared memory to global memory; this part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+------------+----------+
+# |       |Prologue ---->  |MainLoop ---->                                                                   |Epilogue  |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+------------+----------+
+# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+------------+----------+
+#
+# ===--- Warp Specialization Kernel  ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as the `producer warp group` and the other as the `consumer warp group`. The
+# `producer warp group` issues the TMA loads, while the `consumer warp group`
+# performs the mma operations. The `consumer warp group` also handles the
+# epilogue, because its threads own the fragmented accumulator registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+---------------------------+--------------+-----------------+
+# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+# +--------+--------+---------+---------+---------+---------------------------+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
+# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+# +--------+--------+---------+---------+---------+---------------------------+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
+# +--------+--------+---------+---------+---------+---------------------------+--------------+-----------------+
+
+# Standard library
+import ctypes
+import errno
+import os
+import sys
+
+# NumPy and the MLIR Python runtime helpers
+import numpy as np
+from mlir import runtime as rt
+
+# Local helpers (see tools/)
+from tools import matmulBuilder, nvgpucompiler
+
+def generate_matmul(input_type=np.float16,
+                    output_type=np.float32,
+                    M=4096,
+                    N=4096,
+                    K=4096,
+                    BLOCK_M=128,
+                    BLOCK_N=128,
+                    BLOCK_K=64,
+                    use_warp_specialization=True,
+                    saveIR=False,
+                    max_num_stages=3):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+        if use_warp_specialization:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
+                                                                 BLOCK_K, max_num_stages)
+        else:
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
+                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+
+        mlir_nvgpu_module.operation.verify()
+        
+        # Save generated IR
+        if saveIR:
+            # print(mlir_nvgpu_module)
+            original_stdout = sys.stdout
+            with open('gemm.mlir', 'w') as f:
+                sys.stdout = f
+                print(mlir_nvgpu_module)
+                sys.stdout = original_stdout
+
+        # Get compiler
+        options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+        support_lib = os.getenv("SUPPORT_LIB")
+        if support_lib is None or not os.path.exists(support_lib):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+
+        # Compile
+        engine = compiler.compile_and_jit(mlir_nvgpu_module)
+        return engine
+
+
+def matmul(input_type=np.float16,
+           output_type=np.float32,
+           M=128,
+           N=128,
+           K=128,
+           BLOCK_M=128,
+           BLOCK_N=128,
+           BLOCK_K=64,
+           use_warp_specialization=True,
+           saveIR=False,
+           max_num_stages=3,
+           print_results=False,
+           no_verify=False):
+    # Print the configuration
+    ity = "f16" if input_type == np.float16 else "f32"
+    oty = "f16" if output_type == np.float16 else "f32"
+    gemmty = "Warp Specialization" if use_warp_specialization else "Multistage"
+    print(f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
+          f"Size {M}x{N}x{K}, Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
+          f"stages {max_num_stages} --===")
+
+    # Build IR and compile
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specialization,
+                             saveIR, max_num_stages)
+
+    # Allocate matrices and invoke the matmul
+    c = np.zeros((M, N), output_type)
+    a = np.random.randn(M, K).astype(input_type)
+    b = np.random.randn(K, N).astype(input_type)
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+    kernelName = "mlir_matmul_warpspecialized" if use_warp_specialization else "mlir_matmul_multistage"
+
+    # Launch the MLIR-generated kernel
+    engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+    float_formatter = "{:.2f}".format
+    np.set_printoptions(formatter={'float_kind': float_formatter})
+
+    if print_results:
+        print(c)
+
+    # Verify against a NumPy reference computed in the output (accumulator) type
+    if not no_verify:
+        ref = a.astype(output_type) @ b.astype(output_type)
+        if print_results:
+            print(ref)
+        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+    print("PASS")
+
+
+# GEMM Multistage        f32 += f16 * f16
+matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specialization=False)
+# GEMM Warp Specialized  f32 += f16 * f16
+matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specialization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..09ab35d4c5f15e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,676 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+
+TMA_LAST_DIM_F16 = 64  # 64 x f16 = 128 bytes
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808  # ShapedType::kDynamic (INT64_MIN)
+f16_byte = 2
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty_format = None
+        if ir.IndexType.isinstance(arg.type):
+            ty_format = "%llu"
+        if ir.IntegerType.isinstance(arg.type):
+            width = ir.IntegerType(arg.type).width
+            if width == 64:
+                ty_format = "%llu"
+            elif width == 32:
+                ty_format = "%d"
+            elif width == 1:
+                ty_format = "%i"
+        if ir.F32Type.isinstance(arg.type):
+            ty_format = "%f"
+        if ty_format is None:
+            raise NotImplementedError(arg.type)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        # Print only from the given thread
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(fmt.format(*type_formats) + "\n", args)
+        scf.yield_([])
+
+
+def c(value, ty=None):
+    ty = ir.IndexType.get() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def generate_matmul_ws(input_type=np.float16,
+                       output_type=np.float32,
+                       M=4096,
+                       N=4096,
+                       K=4096,
+                       BLOCK_M=128,
+                       BLOCK_N=128,
+                       BLOCK_K=128,
+                       max_num_stages=3):
+    # Current limitations
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+            
+            # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE * 2, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. This is needed for vectorized ld/st
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+
+                # GPU Step 1. Bootstrapping: identify the primary threads, warps, and warp groups
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                warp_id = arith.divui(tidx, c(32))
+                warpgroup_id = arith.divui(warp_id, c(4))
+                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
+                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 2. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                gpu.barrier()
+
+                # GPU Step 3. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+                
+                # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+                with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+                    
+                    # Step 5.1. Reduce register size
+                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+
+                    # Step 5.2. TMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    with ir.InsertionPoint(for_op.body):
+                        phaseParity = for_op.inner_iter_args[0]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+
+                        # Step 5.2.1. Wait mbarDONE
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=producerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
+                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=producerPrimaryThread)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                        # Step 5.2.2. Load TMA
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                                  dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset2, [])
+                        debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
+                                    a_offset,
+                                    b_offset,
+                                    b_offset2,
+                                    predicate=producerPrimaryThread)
+                        coord = arith.muli(c(BLOCK_K), iv)
+                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                             mbarTMA,
+                                             a_tma_desc,
+                                             coordinates=[coord, dimX],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY, coord],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+                        dimY2 = arith.addi(dimY, c(TMA_LAST_DIM_F16))
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY2, coord],
+                                             mbarId=stage,
+                                             predicate=producerPrimaryThread)
+
+                        # Step 5.2.3. Arrive mbarTMA
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
+                                    iv,
+                                    stage,
+                                    predicate=producerPrimaryThread)
+                        scf.yield_([phaseParity])
+                    scf.yield_([])
+
+                # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
+                if_op = scf.IfOp(is_consumer)
+                with ir.InsertionPoint(if_op.then_block):
+                    
+                    # Step 6.1. Increase register size
+                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
+                    
+                    # GPU Step 6.2. Initialize MMA registers
+                    acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+
+                    # Step 6.3. MMA Main Loop
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                    with ir.InsertionPoint(for_op.body):
+                        # Step 6.3.1. Wait mbarTMA
+                        phaseParity = for_op.inner_iter_args[1]
+                        iv = for_op.induction_variable
+                        stage = arith.remui(iv, c(num_stages))
+                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=consumerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                                    iv,
+                                    stage,
+                                    phaseParity,
+                                    predicate=consumerPrimaryThread)
+                        
+                        # Step 6.3.2. Create WGMMA Descriptors
+                        a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                        a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                                                   dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                                                   dynamic_smem, b_offset, [])
+                        debug_print("[cons] a_offset={} b_offset={}",
+                                    a_offset,
+                                    b_offset,
+                                    predicate=consumerPrimaryThread)
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+                        # Step 6.3.3. MMA
+                        carry_acc = for_op.inner_iter_args[0]
+                        new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+                        # Step 6.3.4. Arrive mbarDONE
+                        p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+                        p_arrive = arith.andi(consumerPrimaryThread, p1)
+                        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+                            p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
+                            barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+                            remoteCtaId = c(0)
+                            pred = consumerPrimaryThread
+                            debug_print("[cons] {}  | mbarDONE[{}] arrive  remoteCtaId={} pred={}",
+                                        iv,
+                                        barId,
+                                        remoteCtaId,
+                                        pred,
+                                        predicate=consumerPrimaryThread)
+                            nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                            debug_print("[cons] {}  | mbarDONE[{}] arrive  remoteCtaId={} pred={} [done]",
+                                        iv,
+                                        barId,
+                                        remoteCtaId,
+                                        pred,
+                                        predicate=consumerPrimaryThread)
+                            scf.yield_([])
+
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                        # Step 6.3.5. Yield
+                        scf.yield_([new_acc, phaseParity])
+
+                    # Step 6.4. Wait for all WGMMA groups
+                    nvvm.WgmmaWaitGroupSyncOp(0)
+
+                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+                        barId = c((K // BLOCK_K) % num_stages)
+                        nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                        scf.yield_([])
+
+                    # Step 6.5. Epilogue (registers --> shared memory)
+                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                    acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
+                    nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                    scf.yield_([])
+                gpu.barrier()
+
+                # GPU Step 9. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                vlen = 1
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                                        [for_op.induction_variable])
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.wait(token_ty, [t9])
+
+    mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
+
+
+def generate_matmul_multistage(input_type=np.float16,
+                               output_type=np.float32,
+                               M=4096,
+                               N=4096,
+                               K=4096,
+                               BLOCK_M=128,
+                               BLOCK_N=128,
+                               BLOCK_K=64,
+                               max_num_stages=3):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    input_type_str = "f16" if input_type == np.float16 else "f32"
+    output_type_str = "f16" if output_type == np.float16 else "f32"
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+                            str(num_stages) + ">")
+    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+                           str(output_type_str) + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str + ">>")
+
+    with ir.InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+        def mlir_matmul_multistage(a_host, b_host, c_host):
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size = max(smem_size_input, smem_size_output)
+
+            # Step 1. Allocate device memory and memcpy
+            t1 = gpu.wait(token_ty, [])
+            a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+            b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+            c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+            t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+            t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+            t7 = gpu.wait(token_ty, [t6])
+
+            # Step 2. Create TMA Descriptors
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_descs = []
+            for x_device, tensor_map_ty, tile_shape in tma_specs:
+                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+            a_tma_desc, b_tma_desc = tma_descs
+
+            # Step 3. Launch Kernel with 1 Warpgroup
+            cta_m = M // BLOCK_M
+            cta_n = N // BLOCK_N
+            assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+            grid = (cta_m, cta_n, 1)
+            block = (WARP_GROUP_SIZE, 1, 1)
+            launch_op = gpu.LaunchOp(token_ty, [t7],
+                                     *map(c, grid),
+                                     *map(c, block),
+                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op.body.blocks.append(*([index] * 12))
+            with ir.InsertionPoint(launch_op.body.blocks[0]):
+                # GPU Step 0. Bootstrapping
+                memref.assume_alignment(c_device, 16)
+                dynamic_smem = gpu.dynamic_shared_memory(
+                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                ticks = c(10000000)
+                tidx = gpu.thread_id(gpu.Dimension.x)
+                primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
+                warpId = arith.divui(tidx, c(32))
+                bidx = gpu.block_id(gpu.Dimension.x)
+                bidy = gpu.block_id(gpu.Dimension.y)
+                dimX = arith.muli(bidx, c(BLOCK_M))
+                dimY = arith.muli(bidy, c(BLOCK_N))
+
+                # GPU Step 1. Initialize mbarrier groups
+                mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+                for i in range(num_stages):
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+                gpu.barrier()
+
+                # GPU Step 2. Prefetch TMA descriptors
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+
+                # GPU Step 3. Prologue (global memory --> shared memory)
+                for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+                with ir.InsertionPoint(for_op.body):
+                    iv = for_op.induction_variable
+
+                    # Step 3.1. Calculate offsets
+                    a_offset = arith.muli(iv, c(lhs_tile_bytes))
+                    a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                              dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                    b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, b_offset, [])
+                    b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                    b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, b_offset2, [])
+
+                    # Step 3.2. TMA Load
+                    coord = arith.muli(c(64), iv)
+                    dimY2 = arith.addi(dimY, c(64))
+                    debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                                a_offset,
+                                b_offset,
+                                b_offset2,
+                                coord, dimX,
+                                dimY, coord,
+                                predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                         mbarTMA,
+                                         a_tma_desc,
+                                         coordinates=[coord, dimX],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                         mbarTMA,
+                                         b_tma_desc,
+                                         coordinates=[dimY, coord],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+                    nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                         mbarTMA,
+                                         b_tma_desc,
+                                         coordinates=[dimY2, coord],
+                                         mbarId=iv,
+                                         predicate=primaryThread)
+
+                    # Step 3.3. mbarTMA arrive
+                    debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
+                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+                    scf.yield_([])
+
+                # GPU Step 4. Main Loop
+                acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                with ir.InsertionPoint(for_op.body):
+                    # Step 4.1. Wait mbarTMA
+                    phaseParity = for_op.inner_iter_args[1]
+                    iv = for_op.induction_variable
+                    stage = arith.remui(iv, c(num_stages))
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}", stage, phaseParity, predicate=primaryThread)
+                    nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={} [done]", stage, phaseParity, predicate=primaryThread)
+
+                    # Step 4.2. Create WGMMA Descriptors
+                    a_offset = arith.muli(stage, c(lhs_tile_bytes))
+                    a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                                               dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                    b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                                               dynamic_smem, b_offset, [])
+                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+                    # Step 4.3. MMA
+                    carry_acc = for_op.inner_iter_args[0]
+                    new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+                    # Step 4.4. Load TMA for next stage
+                    p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+                    p = arith.andi(primaryThread, p1)
+                    with ir.InsertionPoint(scf.IfOp(p).then_block):
+                        nextStage = arith.addi(iv, c(num_stages-1))
+                        nextSlot = arith.remui(nextStage, c(num_stages))
+                        a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
+                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                                                dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                                                    dynamic_smem, b_offset2, [])
+
+                        coord = arith.muli(c(64), nextStage)
+                        debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                                iv, a_offset,
+                                b_offset,
+                                b_offset2,
+                                coord, dimX,
+                                dimY, coord,
+                                predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
+                                            mbarTMA,
+                                            a_tma_desc,
+                                            coordinates=[coord, dimX],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+                                            mbarTMA,
+                                            b_tma_desc,
+                                            coordinates=[dimY, coord],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+                        dimY2 = arith.addi(dimY, c(64))
+                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+                                            mbarTMA,
+                                            b_tma_desc,
+                                            coordinates=[dimY2, coord],
+                                            mbarId=nextSlot,
+                                            predicate=primaryThread)
+
+                        debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
+                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+                        scf.yield_([])
+                    # Step 4.5. Change the phaseParity
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+                    phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+                    # Step 4.6. Yield
+                    scf.yield_([new_acc, phaseParity])
+
+                # Step 5. Wait All WGMMA groups
+                nvvm.WgmmaWaitGroupSyncOp(0)
+                gpu.barrier()
+
+                # Step 6. Epilogue (registers --> shared memory)
+                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+                debug_print("Storing", predicate=primaryThread)
+                nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+                gpu.barrier()
+
+                # GPU Step 7. Epilogue (shared memory --> global memory)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                vlen = 1
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+                with ir.InsertionPoint(for_op.body):
+                    x = arith.divui(for_op.induction_variable, c(BLOCK_N))
+                    y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                                        [for_op.induction_variable])
+                    vector.store(vdata, c_device_per_block, [x, y])
+                    scf.yield_([])
+
+                gpu.terminator()
+
+            # Step 4. Copy back to host
+            t8 = gpu.wait(token_ty, [launch_op])
+            t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.wait(token_ty, [t9])
+
+    mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    module.operation.verify()
+    return module
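A minimal host-side sketch of the multistage schedule above, assuming each TMA load completes one mbarrier phase immediately (real loads are asynchronous, so this only checks the index and parity bookkeeping). The main loop reads slot `iv % num_stages`, waits on `mbarTMA` with a parity bit that flips after the last slot of each round, and refills slot `(iv + num_stages - 1) % num_stages`, which is the slot read in the previous iteration. The helper name `simulate_multistage` is hypothetical.

```python
def simulate_multistage(num_k_tiles: int, num_stages: int):
    """Replay the TMA-load / MMA order of the multistage kernel (sketch).

    Assumes each TMA load completes one mbarrier phase immediately; real
    loads are asynchronous. Returns the ordered list of events.
    """
    # With one stage the prologue issues no load, so the first wait could not succeed.
    assert num_stages >= 2
    slot_tile = [None] * num_stages   # K-tile currently held by each smem slot
    phases_done = [0] * num_stages    # completed mbarrier phases per slot
    events = []

    def tma_load(tile: int) -> None:
        slot = tile % num_stages
        slot_tile[slot] = tile
        phases_done[slot] += 1        # expected bytes arrived -> phase completes
        events.append(("load", tile, slot))

    # Prologue: scf.for over [0, num_stages - 1), as in GPU Step 3.
    for tile in range(num_stages - 1):
        tma_load(tile)

    parity = 0
    for iv in range(num_k_tiles):
        stage = iv % num_stages
        # try_wait.parity(p) passes once the phase with parity p has completed,
        # i.e. when the slot's current phase parity differs from p.
        assert phases_done[stage] % 2 != parity, (iv, stage, parity)
        assert slot_tile[stage] == iv, (iv, stage, slot_tile[stage])
        events.append(("mma", iv, stage, parity))
        # Step 4.4: prefetch the tile num_stages - 1 iterations ahead.
        next_tile = iv + num_stages - 1
        if next_tile < num_k_tiles:
            tma_load(next_tile)
        # Step 4.5: flip the parity after the last slot of each round.
        if stage == num_stages - 1:
            parity ^= 1
    return events


# K = 256 with BLOCK_K = 64 gives 4 K-tiles; use 3 stages.
for event in simulate_multistage(num_k_tiles=4, num_stages=3):
    print(event)
```

For 4 K-tiles and 3 stages, tile 3 is loaded into slot 0 during iteration 1 and consumed in iteration 3 with parity 1.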
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..2a4f67326ee718
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -0,0 +1,44 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the NvgpuCompiler class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+class NvgpuCompiler:
+    """Lowers MLIR modules with gpu-lower-to-nvvm-pipeline and JITs them."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+        self.pipeline = pipeline
+        self.shared_libs = shared_libs
+        self.opt_level = opt_level
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the nvgpu pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+    
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)
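The dynamic shared-memory size requested by both kernels follows from a few products: each stage holds one f16 A tile (`BLOCK_M x BLOCK_K`) and one f16 B tile (`BLOCK_K x BLOCK_N`), all A slots come first, the B slots start at `lhs_tile_bytes * num_stages`, and the f32 accumulator tile later reuses the buffer from offset 0, so the launch asks for `max(input, output)` bytes. The sketch below recomputes those numbers on the host and checks that the slots do not overlap. It assumes 2-byte f16 inputs, 4-byte f32 outputs, `BLOCK_K = 64` and `BLOCK_N = 128`, which the builder's hard-coded 64-element TMA boxes appear to require; the helper name `smem_plan` is hypothetical.

```python
F16_BYTES = 2
F32_BYTES = 4
TMA_LAST_DIM_F16 = 64  # 128-byte TMA box width in f16 elements


def smem_plan(block_m=128, block_n=128, block_k=64, num_stages=3):
    """Recompute the dynamic shared-memory layout of the GEMM kernels (sketch)."""
    lhs_tile_bytes = block_m * block_k * F16_BYTES
    rhs_tile_bytes = block_n * block_k * F16_BYTES
    regions = []  # (name, start, end) byte ranges
    for stage in range(num_stages):
        a_off = stage * lhs_tile_bytes
        b_off = stage * rhs_tile_bytes + lhs_tile_bytes * num_stages
        regions.append((f"A[{stage}]", a_off, a_off + lhs_tile_bytes))
        regions.append((f"B[{stage}]", b_off, b_off + rhs_tile_bytes))
    # Slots must not overlap once sorted by start offset.
    regions.sort(key=lambda r: r[1])
    for (_, _, end), (_, start, _) in zip(regions, regions[1:]):
        assert end <= start
    smem_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
    smem_output = block_m * block_n * F32_BYTES  # accumulator tile reuses offset 0
    # Bytes the mbarrier expects per stage: one A box plus the B tile
    # (two BLOCK_K x 64 boxes when BLOCK_N == 128).
    txcount = (block_m * TMA_LAST_DIM_F16 + block_k * block_n) * F16_BYTES
    return {
        "lhs_tile_bytes": lhs_tile_bytes,
        "rhs_tile_bytes": rhs_tile_bytes,
        "smem_size": max(smem_input, smem_output),
        "txcount": txcount,
        "end_of_last_slot": regions[-1][2],
    }


print(smem_plan())
```

With the default 128x128x64 tile and three stages the input slots fill 96 KiB, which exceeds the 64 KiB f32 accumulator tile, so the input side sets the allocation.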

>From 26da27e3c5b5198750b1f80fbd9d1e5d6dfe229d Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:39:08 +0000
Subject: [PATCH 2/8] format with yapf

---
 .../GPU/CUDA/sm90/python/matmul.py            |  51 +-
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 655 ++++++++++++------
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  15 +-
 3 files changed, 483 insertions(+), 238 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index e153fcb44b9860..00313270d723d6 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -85,6 +85,7 @@
 import ctypes
 from mlir import runtime as rt
 
+
 def generate_matmul(input_type=np.float16,
                     output_type=np.float32,
                     M=4096,
@@ -96,16 +97,19 @@ def generate_matmul(input_type=np.float16,
                     use_warp_specilization=True,
                     saveIR=False,
                     max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
+    ):
         if use_warp_specilization:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
-                                                                 BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
         else:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
-                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
 
         mlir_nvgpu_module.operation.verify()
-        
+
         # Save generated IR
         if saveIR:
             # print(mlir_nvgpu_module)
@@ -119,8 +123,11 @@ def generate_matmul(input_type=np.float16,
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
+                                    support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options,
+                                               opt_level=3,
+                                               shared_libs=[support_lib])
 
         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
@@ -144,13 +151,15 @@ def matmul(input_type=np.float16,
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
-          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
-          str(max_num_stages) + " --===")
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
+          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
+          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
+          ", stages " + str(max_num_stages) + " --===")
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
-                             saveIR, max_num_stages)
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
+                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
+                             max_num_stages)
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
@@ -181,6 +190,18 @@ def matmul(input_type=np.float16,
 
 
 # GEMM Multistage       f32 += f16 * f16
-matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+matmul(np.float16,
+       np.float32,
+       128,
+       128,
+       4096,
+       max_num_stages=3,
+       use_warp_specilization=False)
 # GEMM Warp Specilized  f32 += f16 * f16
-matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
+matmul(np.float16,
+       np.float32,
+       256,
+       1024,
+       512,
+       max_num_stages=3,
+       use_warp_specilization=True)
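In both epilogues each thread copies the accumulator tile from shared memory to global memory with a strided loop: thread `tidx` visits linear indices `tidx, tidx + step, ...`, where `step` is `WARP_GROUP_SIZE` in the multistage kernel and `2 * WARP_GROUP_SIZE` in the warp-specialized one. Because the `BLOCK_M x BLOCK_N` shared-memory tile is row-major, a linear index `i` splits into `row = i // BLOCK_N` and `col = i % BLOCK_N`; with the default square 128x128 tile, dividing by `BLOCK_M` gives the same result. The sketch below checks on the host that this mapping writes every element exactly once and stays in bounds, including for a non-square tile. The helpers `epilogue_coords` and `check_epilogue` are hypothetical.

```python
WARP_GROUP_SIZE = 128  # 4 warps x 32 threads, as in matmulBuilder.py


def epilogue_coords(tid, block_m, block_n, num_threads):
    """(row, col) pairs one thread copies in the smem -> gmem epilogue (sketch)."""
    coords = []
    for linear in range(tid, block_m * block_n, num_threads):
        # Row-major BLOCK_M x BLOCK_N tile: BLOCK_N elements per row.
        coords.append((linear // block_n, linear % block_n))
    return coords


def check_epilogue(block_m, block_n, num_threads):
    """Assert that all threads together cover the tile exactly once."""
    seen = set()
    for tid in range(num_threads):
        for row, col in epilogue_coords(tid, block_m, block_n, num_threads):
            assert 0 <= row < block_m and 0 <= col < block_n
            assert (row, col) not in seen   # no element written twice
            seen.add((row, col))
    assert len(seen) == block_m * block_n   # every element written
    return len(seen)


print(check_epilogue(128, 128, WARP_GROUP_SIZE))       # multistage
print(check_epilogue(128, 128, 2 * WARP_GROUP_SIZE))   # warp specialized
print(check_epilogue(64, 128, WARP_GROUP_SIZE))        # non-square tile
```

For the 64x128 tile, dividing the linear index by 64 instead of 128 would produce row indices up to 127, past the 64 rows of the tile.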
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 09ab35d4c5f15e..0fda8698606cd5 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -11,7 +11,6 @@
 from mlir.dialects import scf
 from mlir.dialects import vector
 
-
 TMA_LAST_DIM_F16 = 64  # 128B flaot16
 WARP_SIZE = 32
 WARP_GROUP_SIZE = WARP_SIZE * 4
@@ -81,7 +80,8 @@ def generate_matmul_ws(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
+                                          BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -99,25 +99,36 @@ def generate_matmul_ws(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
+               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
+                            str(smem_space) + ", num_barriers = " +
                             str(num_stages) + ">")
-    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
+                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
                            str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
 
     with ir.InsertionPoint(module.body):
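The string concatenations above build MLIR type text that is easier to review as f-strings. The sketch below is a hypothetical helper (`nvgpu_type_strings` is not part of the patch) that produces the same strings, which could then be passed to `ir.Type.parse`. It also computes the per-stage `txcount` byte count that is passed to `mbarrier_arrive_expect_tx`. It assumes `f16_byte = 2` and `TMA_LAST_DIM_F16 = 64` (consistent with the `c(64)` coordinate step in the TMA loop); both are defined outside this hunk.

```python
F16_BYTE = 2           # bytes per f16 element (assumed value of f16_byte)
TMA_LAST_DIM_F16 = 64  # assumed: columns per TMA box (128-byte swizzle / 2 bytes)
SMEM = "#gpu.address_space<workgroup>"
TMA_ATTRS = "swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none"


def nvgpu_type_strings(BLOCK_M, BLOCK_N, BLOCK_K, num_stages,
                       input_type_str="f16", output_type_str="f32"):
    """Hypothetical helper: same type text as the string concatenations."""
    a_tile_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    return {
        # Bytes that one stage of TMA loads (A tile + B tile) writes to smem.
        "txcount": (b_tile_shape[0] * b_tile_shape[1] +
                    a_tile_shape[0] * a_tile_shape[1]) * F16_BYTE,
        "mbar_ty": (f"!nvgpu.mbarrier.group<memorySpace = {SMEM}, "
                    f"num_barriers = {num_stages}>"),
        "a_tma_desc_ty": ("!nvgpu.tensormap.descriptor<tensor = memref<"
                          f"{BLOCK_M}x{TMA_LAST_DIM_F16}x{input_type_str}, "
                          f"{SMEM}>, {TMA_ATTRS}>"),
        "b_tma_desc_ty": ("!nvgpu.tensormap.descriptor<tensor = memref<"
                          f"{BLOCK_K}x{TMA_LAST_DIM_F16}x{input_type_str}, "
                          f"{SMEM}>, {TMA_ATTRS}>"),
        "acc_ty": ("!nvgpu.warpgroup.accumulator<fragmented=vector<"
                   f"{BLOCK_M}x{BLOCK_N}x{output_type_str}>>"),
        "a_wgmma_ty": ("!nvgpu.warpgroup.descriptor<tensor=memref<"
                       f"{BLOCK_M}x{BLOCK_K}x{input_type_str}, {SMEM}>>"),
        "b_wgmma_ty": ("!nvgpu.warpgroup.descriptor<tensor=memref<"
                       f"{BLOCK_K}x{BLOCK_N}x{input_type_str}, {SMEM}>>"),
    }


types = nvgpu_type_strings(128, 128, 64, num_stages=3)
print(types["txcount"])  # 32768
print(types["acc_ty"])   # !nvgpu.warpgroup.accumulator<fragmented=vector<128x128xf32>>
```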
 
@@ -139,13 +150,18 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
+                         (b_device, b_tma_desc_ty, b_tma_shape)]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
-                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
+                    x_device)
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
+                                                map(c, tile_shape)).result)
             a_tma_desc, b_tma_desc = tma_descs
-            
+
             # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
             cta_m = M // BLOCK_M
             cta_n = N // BLOCK_N
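The launch assigns one CTA per `BLOCK_M x BLOCK_N` output tile. The exact `grid` and `block` tuples are outside this hunk, so the sketch below assumes a `(cta_m, cta_n, 1)` grid (matching the `bidx * BLOCK_M` / `bidy * BLOCK_N` indexing further down) and one 128-thread warpgroup per role. `launch_config` is a hypothetical helper.

```python
WARP_SIZE = 32
WARP_GROUP_SIZE = 4 * WARP_SIZE  # a Hopper warpgroup is 4 consecutive warps


def launch_config(M, N, BLOCK_M, BLOCK_N, num_warpgroups):
    """Hypothetical helper: grid/block sizes implied by the tiling."""
    cta_m = M // BLOCK_M  # CTAs along M, indexed by blockIdx.x
    cta_n = N // BLOCK_N  # CTAs along N, indexed by blockIdx.y
    grid = (cta_m, cta_n, 1)
    # Warp specialization: 1 producer + 1 consumer warpgroup (assumed layout).
    block = (num_warpgroups * WARP_GROUP_SIZE, 1, 1)
    return grid, block


print(launch_config(4096, 4096, 128, 128, num_warpgroups=2))  # ((32, 32, 1), (256, 1, 1))
```

The multistage kernel below launches with a single warpgroup, which corresponds to `num_warpgroups=1` here.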
@@ -155,26 +171,37 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             launch_op = gpu.LaunchOp(token_ty, [t7],
                                      *map(c, grid),
                                      *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+                                     dynamicSharedMemorySize=c(smem_size,
+                                                               ty=i32))
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. This is needed for vectorized ld/st
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC, ),
+                                      i8,
+                                      memory_space=smem_space))
                 ticks = c(10000000)
 
                 # GPU Step 1. Bootstrapping: find the primary thread, warps, warpgroups, etc.
                 tidx = gpu.thread_id(gpu.Dimension.x)
-                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                wgPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq,
+                    arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
                 warp_id = arith.divui(tidx, c(32))
                 warpgroup_id = arith.divui(warp_id, c(4))
-                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
-                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
-                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
-                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
-                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
-                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                is_producer = arith.cmpi(
+                    arith.CmpIPredicate.eq, warpgroup_id,
+                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                is_consumer = arith.cmpi(
+                    arith.CmpIPredicate.eq, warpgroup_id,
+                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
+                                                   tidx,
+                                                   c(PRODUCER_PRIMARY_THREAD))
+                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
+                                                   tidx,
+                                                   c(CONSUMER_PRIMARY_THREAD))
                 bidx = gpu.block_id(gpu.Dimension.x)
                 bidy = gpu.block_id(gpu.Dimension.y)
                 dimX = arith.muli(bidx, c(BLOCK_M))
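The bootstrapping step splits the 256 threads of the CTA into a producer and a consumer warpgroup. A pure-Python model of the same predicates, evaluated per thread index, makes the split easy to check. `PRODUCER_PRIMARY_THREAD = 128` and `CONSUMER_PRIMARY_THREAD = 0` are assumed values; the constants are defined elsewhere in the file, and the conditionals above suggest these.

```python
WARP_SIZE = 32
WARP_GROUP_SIZE = 128
PRODUCER_PRIMARY_THREAD = 128  # assumed value, defined outside this hunk
CONSUMER_PRIMARY_THREAD = 0    # assumed value, defined outside this hunk


def thread_role(tidx):
    """Plain-Python model of 'GPU Step 1' for one thread index."""
    warp_id = tidx // WARP_SIZE
    warpgroup_id = warp_id // 4
    return {
        "wgPrimaryThread": tidx % WARP_GROUP_SIZE == 0,
        "is_producer": warpgroup_id == (1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
        "is_consumer": warpgroup_id == (0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
        "producerPrimaryThread": tidx == PRODUCER_PRIMARY_THREAD,
        "consumerPrimaryThread": tidx == CONSUMER_PRIMARY_THREAD,
    }


roles = [thread_role(t) for t in range(2 * WARP_GROUP_SIZE)]
print(sum(r["is_producer"] for r in roles), sum(r["is_consumer"] for r in roles))  # 128 128
print([t for t, r in enumerate(roles) if r["wgPrimaryThread"]])  # [0, 128]
```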
@@ -184,57 +211,88 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
-                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarTMA,
+                                        c(1),
+                                        c(i),
+                                        predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE,
+                                        c(1),
+                                        c(i),
+                                        predicate=wgPrimaryThread)
                 gpu.barrier()
 
                 # GPU Step 3. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
-                
+                nvgpu.tma_prefetch_descriptor(a_tma_desc,
+                                              predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc,
+                                              predicate=wgPrimaryThread)
+
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                 with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-                    
+
                     # Step 5.1. Reduce register size
-                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE,
+                                        nvvm.SetMaxRegisterAction.decrease)
 
                     # Step 5.2. TMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                       [arith.constant(i1, 1)])
                     with ir.InsertionPoint(for_op.body):
                         phaseParity = for_op.inner_iter_args[0]
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
 
                         # Step 5.2.1. Wait mbarDONE
-                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=producerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
-                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=producerPrimaryThread)
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                        debug_print(
+                            "[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarDONE,
+                                                      phaseParity,
+                                                      ticks,
+                                                      mbarId=stage)
+                        debug_print(
+                            "[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                       c(num_stages - 1))
+                        phaseParity = arith.select(
+                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity)
 
                         # Step 5.2.2. Load TMA
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                                  dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset, [])
-                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset2, [])
-                        debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
-                                    a_offset,
-                                    b_offset,
-                                    b_offset2,
-                                    predicate=producerPrimaryThread)
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(a_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset2, [])
+                        debug_print(
+                            "[prod] a_offset={} b_offset={} b_offset2={}",
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            predicate=producerPrimaryThread)
                         coord = arith.muli(c(64), iv)
                         nvgpu.TmaAsyncLoadOp(a_tma_slice,
                                              mbarTMA,
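Each pipeline stage owns one A slot and one B slot in `dynamic_smem`. All A slots are laid out first, followed by all B slots, and each B slot is filled by two TMA boxes of `BLOCK_K x TMA_LAST_DIM_F16`. The sketch below recomputes the producer's byte offsets in plain Python. `lhs_tile_bytes` and `rhs_tile_bytes` are defined outside this hunk, so their definitions here (tile elements times 2 bytes) are assumptions, as is `TMA_LAST_DIM_F16 = 64`.

```python
F16_BYTE = 2
TMA_LAST_DIM_F16 = 64  # assumed, as elsewhere in the file


def stage_offsets(stage, BLOCK_M, BLOCK_N, BLOCK_K, num_stages):
    """Byte offsets into dynamic_smem used by the producer for one stage."""
    lhs_tile_bytes = BLOCK_M * BLOCK_K * F16_BYTE  # assumed definition
    rhs_tile_bytes = BLOCK_N * BLOCK_K * F16_BYTE  # assumed definition
    a_offset = stage * lhs_tile_bytes
    # All A slots come first, then all B slots.
    b_offset = stage * rhs_tile_bytes + lhs_tile_bytes * num_stages
    # Second TMA box for B: the next 64 columns of the same stage.
    b_offset2 = b_offset + BLOCK_K * TMA_LAST_DIM_F16 * F16_BYTE
    return a_offset, b_offset, b_offset2


for s in range(3):
    print(s, stage_offsets(s, 128, 128, 64, num_stages=3))
# 0 (0, 49152, 57344)
# 1 (16384, 65536, 73728)
# 2 (32768, 81920, 90112)
```

Under these assumptions the two B boxes exactly fill one `rhs_tile_bytes` slot only when `BLOCK_N == 2 * TMA_LAST_DIM_F16`, which holds for the 128-wide tiles used in the commit message.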
@@ -257,8 +315,15 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                                              predicate=producerPrimaryThread)
 
                         # Step 5.2.3. Arrive mbarTMA
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+                        debug_print("[prod] {}  | mbarTMA[{}] arrive",
+                                    iv,
+                                    stage,
+                                    predicate=producerPrimaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA,
+                            c(txcount),
+                            stage,
+                            predicate=producerPrimaryThread)
                         debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
                                     iv,
                                     stage,
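Both loops pass a phase parity to `MBarrierTryWaitParityOp` and flip it whenever the ring of `num_stages` slots wraps around. The producer starts from parity 1 and the consumer from parity 0 (the initial `i1` iter_args of the two `scf.ForOp`s). Under the usual PTX `mbarrier.try_wait.parity` semantics, this lets the producer's first lap over the still-empty slots proceed without blocking. The sketch simulates only the parity bookkeeping, not the barriers themselves; `parity_schedule` is a hypothetical name.

```python
def parity_schedule(k_iters, num_stages, initial_parity):
    """(iv, stage, parity) passed to MBarrierTryWaitParityOp at each iteration."""
    schedule = []
    parity = initial_parity
    for iv in range(k_iters):
        stage = iv % num_stages
        schedule.append((iv, stage, parity))
        # After the last slot of the ring, the next lap waits on the other phase.
        if stage == num_stages - 1:
            parity ^= 1
    return schedule


producer = parity_schedule(7, 3, initial_parity=1)  # waits on mbarDONE
consumer = parity_schedule(7, 3, initial_parity=0)  # waits on mbarTMA
print([p for _, _, p in producer])  # [1, 1, 1, 0, 0, 0, 1]
print([p for _, _, p in consumer])  # [0, 0, 0, 1, 1, 1, 0]
```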
@@ -269,75 +334,104 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                 if_op = scf.IfOp(is_consumer)
                 with ir.InsertionPoint(if_op.then_block):
-                    
+
                     # Step 6.1. Increase register size
-                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
-                    
+                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE,
+                                        nvvm.SetMaxRegisterAction.increase)
+
                     # GPU Step 6.2. Initialize MMA registers
                     acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
 
                     # Step 6.3. MMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                       [acc, arith.constant(i1, 0)])
                     with ir.InsertionPoint(for_op.body):
                         # Step 6.3.1. Wait mbarTMA
                         phaseParity = for_op.inner_iter_args[1]
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
-                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={}",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=consumerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
-                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=consumerPrimaryThread)
-                        
+                        debug_print(
+                            "[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread)
+                        nvgpu.MBarrierTryWaitParityOp(mbarTMA,
+                                                      phaseParity,
+                                                      ticks,
+                                                      mbarId=stage)
+                        debug_print(
+                            "[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread)
+
                         # Step 6.3.2. Create WGMMA Descriptors
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                        a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
-                                                   dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
-                                                   dynamic_smem, b_offset, [])
+                        a_tile_slice = memref.view(
+                            ir.MemRefType.get(a_tile_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tile_slice = memref.view(
+                            ir.MemRefType.get(b_tile_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
                         debug_print("[cons] a_offset={} b_offset={}",
                                     a_offset,
                                     b_offset,
                                     predicate=consumerPrimaryThread)
-                        da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
-                        db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(
+                            a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(
+                            b_wgmma_ty, b_tile_slice, b_tma_desc)
 
                         # Step 6.3.3. MMA
                         carry_acc = for_op.inner_iter_args[0]
-                        new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+                        new_acc = nvgpu.WarpgroupMmaOp(acc.type,
+                                                       da,
+                                                       db,
+                                                       carry_acc,
+                                                       transposeB=True)
 
                         # Step 6.3.4. Arrive mbarDONE
                         p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                         p_arrive = arith.andi(consumerPrimaryThread, p1)
                         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                             p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
-                            barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+                            barId = arith.select(p, c(num_stages - 1),
+                                                 arith.subi(stage, c(1)))
                             remoteCtaId = c(0)
                             pred = consumerPrimaryThread
-                            debug_print("[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
-                                        iv,
-                                        barId,
-                                        remoteCtaId,
-                                        pred,
-                                        predicate=consumerPrimaryThread)
-                            nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
-                            debug_print("[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
-                                        iv,
-                                        barId,
-                                        remoteCtaId,
-                                        pred,
-                                        predicate=consumerPrimaryThread)
+                            debug_print(
+                                "[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
+                                iv,
+                                barId,
+                                remoteCtaId,
+                                pred,
+                                predicate=consumerPrimaryThread)
+                            nvgpu.mbarrier_arrive(
+                                ir.Type.parse("!nvgpu.mbarrier.token"),
+                                mbarDONE, barId)
+                            debug_print(
+                                "[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
+                                iv,
+                                barId,
+                                remoteCtaId,
+                                pred,
+                                predicate=consumerPrimaryThread)
                             scf.yield_([])
 
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                       c(num_stages - 1))
+                        phaseParity = arith.select(
+                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity)
 
                         # Step 6.3.5. Yield
                         scf.yield_([new_acc, phaseParity])
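In Step 6.3.4 the consumer arrives on `mbarDONE` for the slot it read one iteration earlier, which hands that slot back to the producer. The `select` on `stage == 0` is a wrap-around decrement; the sketch below (with the hypothetical name `released_slot`) checks that it equals `(iv - 1) % num_stages`.

```python
def released_slot(iv, num_stages):
    """mbarDONE slot the consumer's primary thread arrives on at iteration iv."""
    if iv == 0:
        return None  # p1 (iv > 0) is false, so no arrive happens
    stage = iv % num_stages
    # barId = select(stage == 0, num_stages - 1, stage - 1)
    return num_stages - 1 if stage == 0 else stage - 1


print([released_slot(iv, 3) for iv in range(7)])  # [None, 0, 1, 2, 0, 1, 2]
```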
@@ -345,32 +439,45 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                     # Step 6.3. Wait All WGMMA
                     nvvm.WgmmaWaitGroupSyncOp(0)
 
-                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+                    with ir.InsertionPoint(
+                            scf.IfOp(consumerPrimaryThread).then_block):
                         barId = c((K // BLOCK_K) % num_stages)
-                        nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                        nvgpu.mbarrier_arrive(
+                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE,
+                            barId)
                         scf.yield_([])
 
                     # Step 6.4. Epilogue (registers --> shared memory)
-                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
+                                                    c_elem_ty,
+                                                    memory_space=smem_space)
                     acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
-                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
+                    debug_print("[cons]  | Storing",
+                                predicate=consumerPrimaryThread)
                     nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                     scf.yield_([])
                 gpu.barrier()
 
                 # GPU Step 9. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
+                                       c_elem_ty,
+                                       memory_space=smem_space)
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
-                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
-                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
-                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
-                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) +
+                                       ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(
+                    rty, c_device, [dimX, dimY], [], [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
+                                   c(vlen * WARP_GROUP_SIZE * 2))
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                    vdata = vector.load(ir.VectorType.get(
+                        (vlen, ), c_elem_ty), collapsed_smem,
                                         [for_op.induction_variable])
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
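In the final epilogue, all threads of the CTA stride over the flattened accumulator tile in shared memory and write each element into the CTA's `BLOCK_M x BLOCK_N` window of C. The loop step is `vlen * WARP_GROUP_SIZE * 2` with `vlen = 1`, which covers the tile exactly when 256 threads are launched. The NumPy model below mimics that loop; `epilogue_copy` is hypothetical. One deliberate difference: the sketch computes the row as `i // BLOCK_N`, the row-major index for a `(BLOCK_M, BLOCK_N)` tile, whereas the patch divides by `BLOCK_M`. The two agree for the square 128x128 tiles in the commit message but would diverge for non-square tiles.

```python
import numpy as np

WARP_GROUP_SIZE = 128


def epilogue_copy(acc_tile, C, bidx, bidy):
    """Model of GPU Step 9 with vlen = 1: 256 threads stride over the tile."""
    BLOCK_M, BLOCK_N = acc_tile.shape
    flat = acc_tile.reshape(-1)  # collapsed_smem: row-major 1-D view
    dimX, dimY = bidx * BLOCK_M, bidy * BLOCK_N
    # c_device_per_block: strided<[N, 1]> view starting at (dimX, dimY).
    c_block = C[dimX:dimX + BLOCK_M, dimY:dimY + BLOCK_N]
    step = 2 * WARP_GROUP_SIZE  # vlen * WARP_GROUP_SIZE * 2 with vlen = 1
    for tidx in range(step):
        for i in range(tidx, BLOCK_M * BLOCK_N, step):
            x, y = i // BLOCK_N, i % BLOCK_N  # row-major (patch: i // BLOCK_M)
            c_block[x, y] = flat[i]


C = np.zeros((256, 256), dtype=np.float32)
tile = np.arange(128 * 128, dtype=np.float32).reshape(128, 128)
epilogue_copy(tile, C, bidx=1, bidy=0)
print(np.array_equal(C[128:, :128], tile))  # True
```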
@@ -382,7 +489,8 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.wait(token_ty, [t9])
 
-    mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    mlir_matmul_warpspecialized.func_op.attributes[
+        "llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module
 
@@ -403,7 +511,8 @@ def generate_matmul_multistage(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
+                                          BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -421,25 +530,36 @@ def generate_matmul_multistage(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
+               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
+                            str(smem_space) + ", num_barriers = " +
                             str(num_stages) + ">")
-    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+        str(smem_space) +
+        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
+                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
                            str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
+    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
+    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+                               str(input_type_str) + ", " + smem_space_str +
+                               ">>")
 
     with ir.InsertionPoint(module.body):
 
@@ -461,11 +581,16 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
+                         (b_device, b_tma_desc_ty, b_tma_shape)]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
-                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
+                    x_device)
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
+                                                map(c, tile_shape)).result)
             a_tma_desc, b_tma_desc = tma_descs
 
             # Step 3. Launch Kernel with 1 Warpgroup
@@ -477,13 +602,16 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             launch_op = gpu.LaunchOp(token_ty, [t7],
                                      *map(c, grid),
                                      *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+                                     dynamicSharedMemorySize=c(smem_size,
+                                                               ty=i32))
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. Bootstrapping
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC, ),
+                                      i8,
+                                      memory_space=smem_space))
                 ticks = c(10000000)
                 tidx = gpu.thread_id(gpu.Dimension.x)
                 primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
@@ -496,39 +624,58 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 # GPU Step 1. Initialize mbarrier groups
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+                    nvgpu.mbarrier_init(mbarTMA,
+                                        c(1),
+                                        c(i),
+                                        predicate=primaryThread)
                 gpu.barrier()
 
                 # GPU Step 2. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc,
+                                              predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc,
+                                              predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
-                for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+                for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
                 with ir.InsertionPoint(for_op.body):
                     iv = for_op.induction_variable
 
                     # Step 3.1. Calculate offsets
                     a_offset = arith.muli(iv, c(lhs_tile_bytes))
-                    a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                              dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                    b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, b_offset, [])
-                    b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                    b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, b_offset2, [])
+                    a_tma_slice = memref.view(
+                        ir.MemRefType.get(a_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)),
+                                          c(lhs_tile_bytes * num_stages))
+                    b_tma_slice_1 = memref.view(
+                        ir.MemRefType.get(b_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset, [])
+                    b_offset2 = arith.addi(
+                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                    b_tma_slice_2 = memref.view(
+                        ir.MemRefType.get(b_tma_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset2, [])
 
                     # Step 3.2. TMA Load
                     coord = arith.muli(c(64), iv)
                     dimY2 = arith.addi(dimY, c(64))
-                    debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                                a_offset,
-                                b_offset,
-                                b_offset2,
-                                coord, dimX,
-                                dimY, coord,
-                                predicate=primaryThread)
+                    debug_print(
+                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                        a_offset,
+                        b_offset,
+                        b_offset2,
+                        coord,
+                        dimX,
+                        dimY,
+                        coord,
+                        predicate=primaryThread)
                     nvgpu.TmaAsyncLoadOp(a_tma_slice,
                                          mbarTMA,
                                          a_tma_desc,
@@ -549,89 +696,153 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                                          predicate=primaryThread)
 
                     # Step 3.2. mbarTMA arrive
-                    debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
-                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
-                    debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive",
+                                iv,
+                                predicate=primaryThread)
+                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA,
+                                                    c(txcount),
+                                                    iv,
+                                                    predicate=primaryThread)
+                    debug_print("[Prologue] mbarTMA[{}] arrive [done]",
+                                iv,
+                                predicate=primaryThread)
                     scf.yield_([])
 
                 # GPU Step 4. Main Loop
                 acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
-                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+                                   [acc, arith.constant(i1, 0)])
                 with ir.InsertionPoint(for_op.body):
                     # Step 4.1. Wait mbarTMA
                     phaseParity = for_op.inner_iter_args[1]
                     iv = for_op.induction_variable
                     stage = arith.remui(iv, c(num_stages))
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}", stage, phaseParity, predicate=primaryThread)
-                    nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={} [done]", stage, phaseParity, predicate=primaryThread)
+                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}",
+                                stage,
+                                phaseParity,
+                                predicate=primaryThread)
+                    nvgpu.MBarrierTryWaitParityOp(mbarTMA,
+                                                  phaseParity,
+                                                  ticks,
+                                                  mbarId=stage)
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread)
 
                     # Step 4.2. Create WGMMA Descriptors
                     a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                    a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
-                                               dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                    b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
-                                               dynamic_smem, b_offset, [])
-                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
-                    da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
-                    db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+                    a_tile_slice = memref.view(
+                        ir.MemRefType.get(a_tile_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, a_offset, [])
+                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)),
+                                          c(lhs_tile_bytes * num_stages))
+                    b_tile_slice = memref.view(
+                        ir.MemRefType.get(b_tile_shape,
+                                          f16,
+                                          memory_space=smem_space),
+                        dynamic_smem, b_offset, [])
+                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}",
+                                iv,
+                                a_offset,
+                                b_offset,
+                                predicate=primaryThread)
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(
+                        a_wgmma_ty, a_tile_slice, a_tma_desc)
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(
+                        b_wgmma_ty, b_tile_slice, b_tma_desc)
 
                     # Step 4.3. MMA
                     carry_acc = for_op.inner_iter_args[0]
-                    new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+                    new_acc = nvgpu.WarpgroupMmaOp(acc.type,
+                                                   da,
+                                                   db,
+                                                   carry_acc,
+                                                   transposeB=True)
 
                     # Step 4.4. Load TMA for next stage
-                    p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+                    p1 = arith.cmpi(arith.CmpIPredicate.ult,
+                                    arith.addi(iv, c(num_stages - 1)),
+                                    c(K // BLOCK_K))
                     p = arith.andi(primaryThread, p1)
                     with ir.InsertionPoint(scf.IfOp(p).then_block):
-                        nextStage = arith.addi(iv, c(num_stages-1))
+                        nextStage = arith.addi(iv, c(num_stages - 1))
                         nextSlot = arith.remui(nextStage, c(num_stages))
                         a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset, [])
-                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset2, [])
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(a_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, a_offset, [])
+                        b_offset = arith.addi(
+                            arith.muli(nextSlot, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages))
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset, [])
+                        b_offset2 = arith.addi(
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(b_tma_shape,
+                                              f16,
+                                              memory_space=smem_space),
+                            dynamic_smem, b_offset2, [])
 
                         coord = arith.muli(c(64), nextStage)
-                        debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                                iv, a_offset,
-                                b_offset,
-                                b_offset2,
-                                coord, dimX,
-                                dimY, coord,
-                                predicate=primaryThread)
+                        debug_print(
+                            "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                            iv,
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            coord,
+                            dimX,
+                            dimY,
+                            coord,
+                            predicate=primaryThread)
                         nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                            mbarTMA,
-                                            a_tma_desc,
-                                            coordinates=[coord, dimX],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
+                                             mbarTMA,
+                                             a_tma_desc,
+                                             coordinates=[coord, dimX],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
                         nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                            mbarTMA,
-                                            b_tma_desc,
-                                            coordinates=[dimY, coord],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY, coord],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
                         dimY2 = arith.addi(dimY, c(64))
                         nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                            mbarTMA,
-                                            b_tma_desc,
-                                            coordinates=[dimY2, coord],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
-
-                        debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
-                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+                                             mbarTMA,
+                                             b_tma_desc,
+                                             coordinates=[dimY2, coord],
+                                             mbarId=nextSlot,
+                                             predicate=primaryThread)
+
+                        debug_print("[MainLoop] mbarTMA[{}] arrive",
+                                    nextSlot,
+                                    predicate=primaryThread)
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA,
+                            c(txcount),
+                            nextSlot,
+                            predicate=primaryThread)
+                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]",
+                                    nextSlot,
+                                    predicate=primaryThread)
                         scf.yield_([])
                     # Step 4.5. Change the phaseParity
-                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                    phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+                                   c(num_stages - 1))
+                    phaseParity = arith.select(
+                        p, arith.xori(phaseParity, arith.constant(i1, 1)),
+                        phaseParity)
 
                     # Step 4.5. Yield
                     scf.yield_([new_acc, phaseParity])
@@ -641,25 +852,34 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 gpu.barrier()
 
                 # Step 6. Epilogue (registers --> shared memory)
-                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
+                                                c_elem_ty,
+                                                memory_space=smem_space)
                 acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                 debug_print("Storing", predicate=primaryThread)
                 nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                 gpu.barrier()
 
                 # GPU Step 7. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
+                                       c_elem_ty,
+                                       memory_space=smem_space)
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
-                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
-                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
-                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
-                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) +
+                                       ", 1], offset: ?>"))
+                c_device_per_block = memref.SubViewOp(
+                    rty, c_device, [dimX, dimY], [], [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
+                                   c(vlen * WARP_GROUP_SIZE))
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+                    vdata = vector.load(ir.VectorType.get(
+                        (vlen, ), c_elem_ty), collapsed_smem,
                                         [for_op.induction_variable])
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
@@ -671,6 +891,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.wait(token_ty, [t9])
 
-    mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    mlir_matmul_multistage.func_op.attributes[
+        "llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module
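A hedged host-side sketch, not part of the patch. The Python model below replays only the slot and parity bookkeeping of GPU Steps 3 and 4 in `mlir_matmul_multistage`, so the schedule can be checked without a GPU. It covers three pieces of index arithmetic: `stage = iv % num_stages`, the prefetch of iteration `iv + num_stages - 1` guarded by `< K // BLOCK_K`, and the parity flip when `stage == num_stages - 1`. It does not model TMA, mbarrier timing, or WGMMA asynchrony. `prologue_slots` and `multistage_schedule` are hypothetical helper names.

```python
# Hypothetical host-side model of the multistage loop's index arithmetic.
# It does not touch MLIR, TMA, or mbarriers; it only replays the bookkeeping.


def prologue_slots(num_stages):
    """Slots filled before the main loop (GPU Step 3): iteration i goes to slot i."""
    return list(range(num_stages - 1))


def multistage_schedule(num_iters, num_stages):
    """Per main-loop iteration: (iv, slot read, parity waited on, slot refilled or None)."""
    events = []
    parity = 0  # mirrors the i1 phaseParity iter_arg, initialised to 0
    for iv in range(num_iters):
        stage = iv % num_stages  # Step 4.1: slot whose mbarrier we wait on
        wait_parity = parity
        # Step 4.4: the primary thread prefetches iteration iv + num_stages - 1,
        # but only while that iteration still exists (iv + S - 1 < K // BLOCK_K).
        next_iter = iv + num_stages - 1
        refill = next_iter % num_stages if next_iter < num_iters else None
        # Step 4.5: flip the parity once the last slot of the ring was consumed.
        if stage == num_stages - 1:
            parity ^= 1
        events.append((iv, stage, wait_parity, refill))
    return events


# Small example: K // BLOCK_K = 5 iterations through a 3-slot ring.
for event in multistage_schedule(5, 3):
    print(event)
```

For 5 iterations and 3 slots this prints `(0, 0, 0, 2)`, `(1, 1, 0, 0)`, `(2, 2, 0, 1)`, `(3, 0, 1, None)`, `(4, 1, 1, None)`. Because the parity flips only after slot `num_stages - 1`, the parity awaited at iteration `iv` works out to `(iv // num_stages) % 2`, which is the number of earlier uses of that slot modulo 2. The prologue loads plus the main-loop refills load every one of the `K // BLOCK_K` tiles exactly once.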
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
index 2a4f67326ee718..ab218812bb4d96 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -15,10 +15,12 @@
 _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(_SCRIPT_PATH)
 
+
 class NvgpuCompiler:
     """Nvgpu class for compiling and building MLIR modules."""
 
-    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+    def __init__(self, options: str, opt_level: int,
+                 shared_libs: Sequence[str]):
         pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
         self.pipeline = pipeline
         self.shared_libs = shared_libs
@@ -31,14 +33,15 @@ def __call__(self, module: ir.Module):
     def compile(self, module: ir.Module):
         """Compiles the module by invoking the nvgpu pipeline."""
         passmanager.PassManager.parse(self.pipeline).run(module.operation)
-    
+
     def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Wraps the module in a JIT execution engine."""
-        return execution_engine.ExecutionEngine(
-            module, opt_level=self.opt_level, shared_libs=self.shared_libs
-        )
+        return execution_engine.ExecutionEngine(module,
+                                                opt_level=self.opt_level,
+                                                shared_libs=self.shared_libs)
 
-    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+    def compile_and_jit(self,
+                        module: ir.Module) -> execution_engine.ExecutionEngine:
         """Compiles and jits the module."""
         self.compile(module)
         return self.jit(module)

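A hedged sketch of the epilogue index mapping in `mlir_matmul_multistage` (GPU Step 7 in the patch above). Each warpgroup thread walks the flattened `BLOCK_M * BLOCK_N` tile with stride `WARP_GROUP_SIZE` and recovers 2-D coordinates as `x = idx // BLOCK_M`, `y = idx % BLOCK_N`. Since `acc_smem` is a row-major `(BLOCK_M, BLOCK_N)` view of the same buffer, the row of flat index `idx` is `idx // BLOCK_N`. The two divisors agree only when `BLOCK_M == BLOCK_N`, as in the 128x128 tiles the tests use. The check below runs on the host. `WARP_GROUP_SIZE = 128` and the helper names are assumptions made for illustration.

```python
# Hypothetical host-side check of the epilogue's flat-index -> (row, col) mapping.
WARP_GROUP_SIZE = 128  # assumption: one warpgroup = 4 warps x 32 threads


def epilogue_coords(tidx, block_m, block_n, row_divisor):
    """(idx, row, col) for every element thread `tidx` copies (vlen = 1)."""
    return [(idx, idx // row_divisor, idx % block_n)
            for idx in range(tidx, block_m * block_n, WARP_GROUP_SIZE)]


def covers_tile_row_major(block_m, block_n, row_divisor):
    """True if all threads together write each (row, col) of the tile exactly at
    the position idx = row * block_n + col of the row-major shared-memory tile."""
    seen = set()
    for tidx in range(WARP_GROUP_SIZE):
        for idx, row, col in epilogue_coords(tidx, block_m, block_n, row_divisor):
            if not 0 <= row < block_m or row * block_n + col != idx:
                return False
            seen.add((row, col))
    return len(seen) == block_m * block_n


print(covers_tile_row_major(128, 128, row_divisor=128))  # square tile: True
print(covers_tile_row_major(64, 128, row_divisor=64))    # divide by BLOCK_M: False
print(covers_tile_row_major(64, 128, row_divisor=128))   # divide by BLOCK_N: True
```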
>From 623e90f380ecfd1ef10ae8b7fe539e39dd269e52 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:59:39 +0000
Subject: [PATCH 3/8] format it with black

---
 .../GPU/CUDA/sm90/python/matmul.py            | 177 ++--
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 975 +++++++++++-------
 .../CUDA/sm90/python/tools/nvgpucompiler.py   |  12 +-
 3 files changed, 701 insertions(+), 463 deletions(-)

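This patch only reformats the code. The host-side constants that `generate_matmul_ws` computes (the `required_stages` and `txcount` hunks in `matmulBuilder.py`) can still be reproduced in plain Python as a quick sanity check. This is a sketch. `TMA_LAST_DIM_F16 = 64` is an assumption, consistent with the hard-coded 64-element steps in the TMA coordinates. `pick_num_stages` and `tma_txcount` are hypothetical names.

```python
# Hypothetical re-derivation of two host-side constants from matmulBuilder.py.
f16_byte = 2
TMA_LAST_DIM_F16 = 64  # assumption: matches the hard-coded 64-element TMA coordinate steps


def pick_num_stages(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, max_num_stages):
    """Same formula as generate_matmul_ws: capped by max_num_stages."""
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    return min(required_stages, max_num_stages)


def tma_txcount(BLOCK_M, BLOCK_N, BLOCK_K):
    """Bytes one stage's TMA loads must deliver before its mbarrier phase completes."""
    a_tile_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    return ((b_tile_shape[0] * b_tile_shape[1]) +
            (a_tile_shape[0] * a_tile_shape[1])) * f16_byte


# The two problem shapes used at the bottom of matmul.py (BLOCK_K = 64).
print(pick_num_stages(128, 128, 4096, 128, 128, 64, 3))  # required 64, capped to 3
print(pick_num_stages(256, 1024, 512, 128, 128, 64, 3))  # required 40, capped to 3
print(tma_txcount(128, 128, 64))                          # 32768 bytes (32 KiB)
```

With 128x128x64 tiles, the 32768 bytes equal one `(128, 64)` f16 load for A plus two `(64, 64)` f16 loads for B (at `dimY` and `dimY + 64`). That matches the three `TmaAsyncLoadOp`s issued per stage.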
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index 00313270d723d6..ff9c950193e695 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -86,27 +86,44 @@
 from mlir import runtime as rt
 
 
-def generate_matmul(input_type=np.float16,
-                    output_type=np.float32,
-                    M=4096,
-                    N=4096,
-                    K=4096,
-                    BLOCK_M=128,
-                    BLOCK_N=128,
-                    BLOCK_K=64,
-                    use_warp_specilization=True,
-                    saveIR=False,
-                    max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
-    ):
+def generate_matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
         if use_warp_specilization:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
-                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
-                max_num_stages)
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
         else:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
-                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
-                max_num_stages)
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
 
         mlir_nvgpu_module.operation.verify()
 
@@ -114,7 +131,7 @@ def generate_matmul(input_type=np.float16,
         if saveIR:
             # print(mlir_nvgpu_module)
             original_stdout = sys.stdout
-            with open('gemm.mlir', 'w') as f:
+            with open("gemm.mlir", "w") as f:
                 sys.stdout = f
                 print(mlir_nvgpu_module)
                 sys.stdout = original_stdout
@@ -123,43 +140,77 @@ def generate_matmul(input_type=np.float16,
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
-                                    support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options,
-                                               opt_level=3,
-                                               shared_libs=[support_lib])
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+            )
+        compiler = nvgpucompiler.NvgpuCompiler(
+            options, opt_level=3, shared_libs=[support_lib]
+        )
 
         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
         return engine
 
 
-def matmul(input_type=np.float16,
-           output_type=np.float32,
-           M=128,
-           N=128,
-           K=128,
-           BLOCK_M=128,
-           BLOCK_N=128,
-           BLOCK_K=64,
-           use_warp_specilization=True,
-           saveIR=False,
-           max_num_stages=3,
-           print_results=False,
-           no_verify=False):
+def matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=128,
+    N=128,
+    K=128,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+    print_results=False,
+    no_verify=False,
+):
     # Print the configuration
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
-          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
-          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
-          ", stages " + str(max_num_stages) + " --===")
+    print(
+        "===-- Running GEMM "
+        + gemmty
+        + " "
+        + oty
+        + " += "
+        + ity
+        + " * "
+        + ity
+        + ", Size "
+        + str(M)
+        + "x"
+        + str(N)
+        + "x"
+        + str(K)
+        + ", Tile "
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(BLOCK_K)
+        + ", stages "
+        + str(max_num_stages)
+        + " --==="
+    )
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
-                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
-                             max_num_stages)
+    engine = generate_matmul(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        use_warp_specilization,
+        saveIR,
+        max_num_stages,
+    )
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
@@ -168,13 +219,17 @@ def matmul(input_type=np.float16,
     mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
     mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+    kernelName = (
+        "mlir_matmul_warpspecialized"
+        if use_warp_specilization
+        else "mlir_matmul_multistage"
+    )
 
     # Launch the MLIR generated kernel
     engine.invoke(kernelName, mem_a, mem_b, mem_c)
 
     float_formatter = "{:.2f}".format
-    np.set_printoptions(formatter={'float_kind': float_formatter})
+    np.set_printoptions(formatter={"float_kind": float_formatter})
 
     if print_results:
         print(c)
@@ -190,18 +245,22 @@ def matmul(input_type=np.float16,
 
 
 # GEMM Multistage       f32 += f16 * f16
-matmul(np.float16,
-       np.float32,
-       128,
-       128,
-       4096,
-       max_num_stages=3,
-       use_warp_specilization=False)
+matmul(
+    np.float16,
+    np.float32,
+    128,
+    128,
+    4096,
+    max_num_stages=3,
+    use_warp_specilization=False,
+)
 # GEMM Warp Specilized  f32 += f16 * f16
-matmul(np.float16,
-       np.float32,
-       256,
-       1024,
-       512,
-       max_num_stages=3,
-       use_warp_specilization=True)
+matmul(
+    np.float16,
+    np.float32,
+    256,
+    1024,
+    512,
+    max_num_stages=3,
+    use_warp_specilization=True,
+)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 0fda8698606cd5..4c132ec992a461 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -64,15 +64,17 @@ def c(value, ty=None):
     return arith.constant(ty, value)
 
 
-def generate_matmul_ws(input_type=np.float16,
-                       output_type=np.float32,
-                       M=4096,
-                       N=4096,
-                       K=4096,
-                       BLOCK_M=128,
-                       BLOCK_N=128,
-                       BLOCK_K=128,
-                       max_num_stages=3):
+def generate_matmul_ws(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=128,
+    max_num_stages=3,
+):
     # Limitaitons for now
     assert input_type == np.float16
     assert output_type == np.float32
@@ -80,8 +82,7 @@ def generate_matmul_ws(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
-                                          BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -99,36 +100,73 @@ def generate_matmul_ws(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
-               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = (
+        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+    ) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
-                            str(smem_space) + ", num_barriers = " +
-                            str(num_stages) + ">")
+    mbar_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace = "
+        + str(smem_space)
+        + ", num_barriers = "
+        + str(num_stages)
+        + ">"
+    )
     a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
-        str(smem_space) +
-        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
     b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
-        str(smem_space) +
-        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
-                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
-                           str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str +
-                               ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str +
-                               ">>")
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    acc_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(output_type_str)
+        + ">>"
+    )
+    a_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_K)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    b_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
 
     with ir.InsertionPoint(module.body):
 
@@ -150,16 +188,20 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
-                         (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [
+                (a_device, a_tma_desc_ty, a_tma_shape),
+                (b_device, b_tma_desc_ty, b_tma_shape),
+            ]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
                 x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
-                    x_device)
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                )
                 tma_descs.append(
-                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
-                                                map(c, tile_shape)).result)
+                    nvgpu.TmaCreateDescriptorOp(
+                        tensor_map_ty, x_unranked, map(c, tile_shape)
+                    ).result
+                )
             a_tma_desc, b_tma_desc = tma_descs
 
             # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
@@ -168,40 +210,45 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             assert M % BLOCK_M == 0 and N % BLOCK_N == 0
             grid = (cta_m, cta_n, 1)
             block = (WARP_GROUP_SIZE * 2, 1, 1)
-            launch_op = gpu.LaunchOp(token_ty, [t7],
-                                     *map(c, grid),
-                                     *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size,
-                                                               ty=i32))
+            launch_op = gpu.LaunchOp(
+                token_ty,
+                [t7],
+                *map(c, grid),
+                *map(c, block),
+                dynamicSharedMemorySize=c(smem_size, ty=i32)
+            )
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. This is need for vectorized ld/st
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ),
-                                      i8,
-                                      memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                )
                 ticks = c(10000000)
 
                 # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups and etc.
                 tidx = gpu.thread_id(gpu.Dimension.x)
                 wgPrimaryThread = arith.cmpi(
-                    arith.CmpIPredicate.eq,
-                    arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
+                )
                 warp_id = arith.divui(tidx, c(32))
                 warpgroup_id = arith.divui(warp_id, c(4))
                 is_producer = arith.cmpi(
-                    arith.CmpIPredicate.eq, warpgroup_id,
-                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+                    arith.CmpIPredicate.eq,
+                    warpgroup_id,
+                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
+                )
                 is_consumer = arith.cmpi(
-                    arith.CmpIPredicate.eq, warpgroup_id,
-                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
-                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
-                                                   tidx,
-                                                   c(PRODUCER_PRIMARY_THREAD))
-                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
-                                                   tidx,
-                                                   c(CONSUMER_PRIMARY_THREAD))
+                    arith.CmpIPredicate.eq,
+                    warpgroup_id,
+                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
+                )
+                producerPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
+                )
+                consumerPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
+                )
                 bidx = gpu.block_id(gpu.Dimension.x)
                 bidy = gpu.block_id(gpu.Dimension.y)
                 dimX = arith.muli(bidx, c(BLOCK_M))
@@ -211,32 +258,25 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 mbarDONE = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA,
-                                        c(1),
-                                        c(i),
-                                        predicate=wgPrimaryThread)
-                    nvgpu.mbarrier_init(mbarDONE,
-                                        c(1),
-                                        c(i),
-                                        predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+                    nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
                 gpu.barrier()
 
                 # GPU Step 3. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc,
-                                              predicate=wgPrimaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc,
-                                              predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
 
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                 with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-
                     # Step 5.1. Reduce register size
-                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE,
-                                        nvvm.SetMaxRegisterAction.decrease)
+                    nvvm.setmaxregister(
+                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
+                    )
 
                     # Step 5.2. TMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
-                                       [arith.constant(i1, 1)])
+                    for_op = scf.ForOp(
+                        c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)]
+                    )
                     with ir.InsertionPoint(for_op.body):
                         phaseParity = for_op.inner_iter_args[0]
                         iv = for_op.induction_variable
@@ -248,103 +288,126 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             iv,
                             stage,
                             phaseParity,
-                            predicate=producerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarDONE,
-                                                      phaseParity,
-                                                      ticks,
-                                                      mbarId=stage)
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.MBarrierTryWaitParityOp(
+                            mbarDONE, phaseParity, ticks, mbarId=stage
+                        )
                         debug_print(
                             "[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
                             iv,
                             stage,
                             phaseParity,
-                            predicate=producerPrimaryThread)
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
-                                       c(num_stages - 1))
+                            predicate=producerPrimaryThread,
+                        )
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                         phaseParity = arith.select(
-                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
-                            phaseParity)
+                            p,
+                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity,
+                        )
 
                         # Step 5.2.2. Load TMA
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
                         a_tma_slice = memref.view(
-                            ir.MemRefType.get(a_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, a_offset, [])
+                            ir.MemRefType.get(
+                                a_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
                         b_offset = arith.addi(
                             arith.muli(stage, c(rhs_tile_bytes)),
-                            c(lhs_tile_bytes * num_stages))
+                            c(lhs_tile_bytes * num_stages),
+                        )
                         b_tma_slice_1 = memref.view(
-                            ir.MemRefType.get(b_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset, [])
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
                         b_offset2 = arith.addi(
-                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                        )
                         b_tma_slice_2 = memref.view(
-                            ir.MemRefType.get(b_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset2, [])
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset2,
+                            [],
+                        )
                         debug_print(
                             "[prod] a_offset={} b_offset={} b_offset2={}",
                             a_offset,
                             b_offset,
                             b_offset2,
-                            predicate=producerPrimaryThread)
+                            predicate=producerPrimaryThread,
+                        )
                         coord = arith.muli(c(64), iv)
-                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                             mbarTMA,
-                                             a_tma_desc,
-                                             coordinates=[coord, dimX],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY, coord],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(
+                            a_tma_slice,
+                            mbarTMA,
+                            a_tma_desc,
+                            coordinates=[coord, dimX],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_1,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY, coord],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
                         dimY2 = arith.addi(dimY, c(64))
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY2, coord],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_2,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY2, coord],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
 
                         # Step 5.2.3. Arrive mbarTMA
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive",
-                                    iv,
-                                    stage,
-                                    predicate=producerPrimaryThread)
+                        debug_print(
+                            "[prod] {}  | mbarTMA[{}] arrive",
+                            iv,
+                            stage,
+                            predicate=producerPrimaryThread,
+                        )
                         nvgpu.mbarrier_arrive_expect_tx(
-                            mbarTMA,
-                            c(txcount),
+                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
+                        )
+                        debug_print(
+                            "[prod] {}  | mbarTMA[{}] arrive [done]",
+                            iv,
                             stage,
-                            predicate=producerPrimaryThread)
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
-                                    iv,
-                                    stage,
-                                    predicate=producerPrimaryThread)
+                            predicate=producerPrimaryThread,
+                        )
                         scf.yield_([phaseParity])
                     scf.yield_([])
 
                 # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                 if_op = scf.IfOp(is_consumer)
                 with ir.InsertionPoint(if_op.then_block):
-
                     # Step 6.1. Increase register size
-                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE,
-                                        nvvm.SetMaxRegisterAction.increase)
+                    nvvm.setmaxregister(
+                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
+                    )
 
                     # GPU Step 6.2. Initialize MMA registers
                     acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
 
                     # Step 6.3. MMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
-                                       [acc, arith.constant(i1, 0)])
+                    for_op = scf.ForOp(
+                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                    )
                     with ir.InsertionPoint(for_op.body):
                         # Step 6.3.1. Wait mbar1
                         phaseParity = for_op.inner_iter_args[1]
@@ -355,57 +418,68 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             iv,
                             stage,
                             phaseParity,
-                            predicate=consumerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarTMA,
-                                                      phaseParity,
-                                                      ticks,
-                                                      mbarId=stage)
+                            predicate=consumerPrimaryThread,
+                        )
+                        nvgpu.MBarrierTryWaitParityOp(
+                            mbarTMA, phaseParity, ticks, mbarId=stage
+                        )
                         debug_print(
                             "[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
                             iv,
                             stage,
                             phaseParity,
-                            predicate=consumerPrimaryThread)
+                            predicate=consumerPrimaryThread,
+                        )
 
                         # Step 6.3.2. Create WGMMA Descriptors
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
                         a_tile_slice = memref.view(
-                            ir.MemRefType.get(a_tile_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, a_offset, [])
+                            ir.MemRefType.get(
+                                a_tile_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
                         b_offset = arith.addi(
                             arith.muli(stage, c(rhs_tile_bytes)),
-                            c(lhs_tile_bytes * num_stages))
+                            c(lhs_tile_bytes * num_stages),
+                        )
                         b_tile_slice = memref.view(
-                            ir.MemRefType.get(b_tile_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset, [])
-                        debug_print("[cons] a_offset={} b_offset={}",
-                                    a_offset,
-                                    b_offset,
-                                    predicate=consumerPrimaryThread)
+                            ir.MemRefType.get(
+                                b_tile_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
+                        debug_print(
+                            "[cons] a_offset={} b_offset={}",
+                            a_offset,
+                            b_offset,
+                            predicate=consumerPrimaryThread,
+                        )
                         da = nvgpu.WarpgroupGenerateDescriptorOp(
-                            a_wgmma_ty, a_tile_slice, a_tma_desc)
+                            a_wgmma_ty, a_tile_slice, a_tma_desc
+                        )
                         db = nvgpu.WarpgroupGenerateDescriptorOp(
-                            b_wgmma_ty, b_tile_slice, b_tma_desc)
+                            b_wgmma_ty, b_tile_slice, b_tma_desc
+                        )
 
                         # Step 6.3.3. MMA
                         carry_acc = for_op.inner_iter_args[0]
-                        new_acc = nvgpu.WarpgroupMmaOp(acc.type,
-                                                       da,
-                                                       db,
-                                                       carry_acc,
-                                                       transposeB=True)
+                        new_acc = nvgpu.WarpgroupMmaOp(
+                            acc.type, da, db, carry_acc, transposeB=True
+                        )
 
                         # Step 6.3.4. Arrive mbarDONE
                         p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                         p_arrive = arith.andi(consumerPrimaryThread, p1)
                         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                             p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
-                            barId = arith.select(p, c(num_stages - 1),
-                                                 arith.subi(stage, c(1)))
+                            barId = arith.select(
+                                p, c(num_stages - 1), arith.subi(stage, c(1))
+                            )
                             remoteCtaId = c(0)
                             pred = consumerPrimaryThread
                             debug_print(
@@ -414,24 +488,27 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                                 barId,
                                 remoteCtaId,
                                 pred,
-                                predicate=consumerPrimaryThread)
+                                predicate=consumerPrimaryThread,
+                            )
                             nvgpu.mbarrier_arrive(
-                                ir.Type.parse("!nvgpu.mbarrier.token"),
-                                mbarDONE, barId)
+                                ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+                            )
                             debug_print(
                                 "[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
                                 iv,
                                 barId,
                                 remoteCtaId,
                                 pred,
-                                predicate=consumerPrimaryThread)
+                                predicate=consumerPrimaryThread,
+                            )
                             scf.yield_([])
 
-                        p = arith.cmpi(arith.CmpIPredicate.eq, stage,
-                                       c(num_stages - 1))
+                        p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                         phaseParity = arith.select(
-                            p, arith.xori(phaseParity, arith.constant(i1, 1)),
-                            phaseParity)
+                            p,
+                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity,
+                        )
 
                         # Step 6.3.5. Yield
                         scf.yield_([new_acc, phaseParity])
@@ -439,46 +516,55 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                     # Step 6.3. Wait All WGMMA
                     nvvm.WgmmaWaitGroupSyncOp(0)
 
-                    with ir.InsertionPoint(
-                            scf.IfOp(consumerPrimaryThread).then_block):
+                    with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
                         barId = c((K // BLOCK_K) % num_stages)
                         nvgpu.mbarrier_arrive(
-                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE,
-                            barId)
+                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+                        )
                         scf.yield_([])
 
                     # Step 6.4. Epilogue (registers --> shared memory)
-                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
-                                                    c_elem_ty,
-                                                    memory_space=smem_space)
+                    acc_smem_ty = ir.MemRefType.get(
+                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+                    )
                     acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
-                    debug_print("[cons]  | Storing",
-                                predicate=consumerPrimaryThread)
+                    debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
                     nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                     scf.yield_([])
                 gpu.barrier()
 
                 # GPU Step 9. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
-                                       c_elem_ty,
-                                       memory_space=smem_space)
+                fd = ir.MemRefType.get(
+                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+                )
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                 rty = ir.MemRefType.get(
-                    (BLOCK_M, BLOCK_N), c_elem_ty,
-                    ir.Attribute.parse("strided<[" + str(N) +
-                                       ", 1], offset: ?>"))
+                    (BLOCK_M, BLOCK_N),
+                    c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+                )
                 c_device_per_block = memref.SubViewOp(
-                    rty, c_device, [dimX, dimY], [], [],
-                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
+                    rty,
+                    c_device,
+                    [dimX, dimY],
+                    [],
+                    [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                    [BLOCK_M, BLOCK_N],
+                    [1, 1],
+                )
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
-                                   c(vlen * WARP_GROUP_SIZE * 2))
+                for_op = scf.ForOp(
+                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
+                )
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get(
-                        (vlen, ), c_elem_ty), collapsed_smem,
-                                        [for_op.induction_variable])
+                    vdata = vector.load(
+                        ir.VectorType.get((vlen,), c_elem_ty),
+                        collapsed_smem,
+                        [for_op.induction_variable],
+                    )
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
 
@@ -490,20 +576,23 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             gpu.wait(token_ty, [t9])
 
     mlir_matmul_warpspecialized.func_op.attributes[
-        "llvm.emit_c_interface"] = ir.UnitAttr.get()
+        "llvm.emit_c_interface"
+    ] = ir.UnitAttr.get()
     module.operation.verify()
     return module
 
 
-def generate_matmul_multistage(input_type=np.float16,
-                               output_type=np.float32,
-                               M=4096,
-                               N=4096,
-                               K=4096,
-                               BLOCK_M=128,
-                               BLOCK_N=128,
-                               BLOCK_K=64,
-                               max_num_stages=3):
+def generate_matmul_multistage(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    max_num_stages=3,
+):
     # Limitaitons for now
     assert input_type == np.float16
     assert output_type == np.float32
@@ -511,8 +600,7 @@ def generate_matmul_multistage(input_type=np.float16,
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
-                                          BLOCK_K * BLOCK_N)
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
@@ -530,36 +618,73 @@ def generate_matmul_multistage(input_type=np.float16,
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
-               (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = (
+        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+    ) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
-                            str(smem_space) + ", num_barriers = " +
-                            str(num_stages) + ">")
+    mbar_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace = "
+        + str(smem_space)
+        + ", num_barriers = "
+        + str(num_stages)
+        + ">"
+    )
     a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
-        str(smem_space) +
-        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
     b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-        str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
-        str(smem_space) +
-        ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
-                           str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
-                           str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                               str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str +
-                               ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                               str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str +
-                               ">>")
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    acc_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(output_type_str)
+        + ">>"
+    )
+    a_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_K)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    b_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
 
     with ir.InsertionPoint(module.body):
 
@@ -581,16 +706,20 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
-                         (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [
+                (a_device, a_tma_desc_ty, a_tma_shape),
+                (b_device, b_tma_desc_ty, b_tma_shape),
+            ]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
                 x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
-                    x_device)
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                )
                 tma_descs.append(
-                    nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
-                                                map(c, tile_shape)).result)
+                    nvgpu.TmaCreateDescriptorOp(
+                        tensor_map_ty, x_unranked, map(c, tile_shape)
+                    ).result
+                )
             a_tma_desc, b_tma_desc = tma_descs
 
             # Step 3. Launch Kernel with 1 Warpgroup
@@ -599,19 +728,20 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             assert M % BLOCK_M == 0 and N % BLOCK_N == 0
             grid = (cta_m, cta_n, 1)
             block = (WARP_GROUP_SIZE, 1, 1)
-            launch_op = gpu.LaunchOp(token_ty, [t7],
-                                     *map(c, grid),
-                                     *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size,
-                                                               ty=i32))
+            launch_op = gpu.LaunchOp(
+                token_ty,
+                [t7],
+                *map(c, grid),
+                *map(c, block),
+                dynamicSharedMemorySize=c(smem_size, ty=i32)
+            )
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. Bootstrapping
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ),
-                                      i8,
-                                      memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                )
                 ticks = c(10000000)
                 tidx = gpu.thread_id(gpu.Dimension.x)
                 primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
@@ -624,17 +754,12 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 # GPU Step 1. Initialize mbarrier groups
                 mbarTMA = nvgpu.mbarrier_create(mbar_ty)
                 for i in range(num_stages):
-                    nvgpu.mbarrier_init(mbarTMA,
-                                        c(1),
-                                        c(i),
-                                        predicate=primaryThread)
+                    nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
                 gpu.barrier()
 
                 # GPU Step 2. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc,
-                                              predicate=primaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc,
-                                              predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
                 for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
@@ -644,24 +769,30 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     # Step 3.1. Calculate offsets
                     a_offset = arith.muli(iv, c(lhs_tile_bytes))
                     a_tma_slice = memref.view(
-                        ir.MemRefType.get(a_tma_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)),
-                                          c(lhs_tile_bytes * num_stages))
+                        ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(iv, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
                     b_tma_slice_1 = memref.view(
-                        ir.MemRefType.get(b_tma_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, b_offset, [])
+                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
                     b_offset2 = arith.addi(
-                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                    )
                     b_tma_slice_2 = memref.view(
-                        ir.MemRefType.get(b_tma_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, b_offset2, [])
+                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset2,
+                        [],
+                    )
 
                     # Step 3.2. TMA Load
                     coord = arith.muli(c(64), iv)
@@ -675,123 +806,153 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         dimX,
                         dimY,
                         coord,
-                        predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                         mbarTMA,
-                                         a_tma_desc,
-                                         coordinates=[coord, dimX],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                         mbarTMA,
-                                         b_tma_desc,
-                                         coordinates=[dimY, coord],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                         mbarTMA,
-                                         b_tma_desc,
-                                         coordinates=[dimY2, coord],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        a_tma_slice,
+                        mbarTMA,
+                        a_tma_desc,
+                        coordinates=[coord, dimX],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_1,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY, coord],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_2,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY2, coord],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
 
                     # Step 3.2. mbarTMA arrive
-                    debug_print("[Prologue] mbarTMA[{}] arrive",
-                                iv,
-                                predicate=primaryThread)
-                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA,
-                                                    c(txcount),
-                                                    iv,
-                                                    predicate=primaryThread)
-                    debug_print("[Prologue] mbarTMA[{}] arrive [done]",
-                                iv,
-                                predicate=primaryThread)
+                    debug_print(
+                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
+                    )
+                    nvgpu.mbarrier_arrive_expect_tx(
+                        mbarTMA, c(txcount), iv, predicate=primaryThread
+                    )
+                    debug_print(
+                        "[Prologue] mbarTMA[{}] arrive [done]",
+                        iv,
+                        predicate=primaryThread,
+                    )
                     scf.yield_([])
 
                 # GPU Step 4. Main Loop
                 acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
-                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
-                                   [acc, arith.constant(i1, 0)])
+                for_op = scf.ForOp(
+                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                )
                 with ir.InsertionPoint(for_op.body):
                     # Step 4.1. Wait mbarTMA
                     phaseParity = for_op.inner_iter_args[1]
                     iv = for_op.induction_variable
                     stage = arith.remui(iv, c(num_stages))
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}",
-                                stage,
-                                phaseParity,
-                                predicate=primaryThread)
-                    nvgpu.MBarrierTryWaitParityOp(mbarTMA,
-                                                  phaseParity,
-                                                  ticks,
-                                                  mbarId=stage)
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={}",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.MBarrierTryWaitParityOp(
+                        mbarTMA, phaseParity, ticks, mbarId=stage
+                    )
                     debug_print(
                         "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
                         stage,
                         phaseParity,
-                        predicate=primaryThread)
+                        predicate=primaryThread,
+                    )
 
                     # Step 4.2. Create WGMMA Descriptors
                     a_offset = arith.muli(stage, c(lhs_tile_bytes))
                     a_tile_slice = memref.view(
-                        ir.MemRefType.get(a_tile_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)),
-                                          c(lhs_tile_bytes * num_stages))
+                        ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(stage, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
                     b_tile_slice = memref.view(
-                        ir.MemRefType.get(b_tile_shape,
-                                          f16,
-                                          memory_space=smem_space),
-                        dynamic_smem, b_offset, [])
-                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}",
-                                iv,
-                                a_offset,
-                                b_offset,
-                                predicate=primaryThread)
+                        ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    debug_print(
+                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
+                        iv,
+                        a_offset,
+                        b_offset,
+                        predicate=primaryThread,
+                    )
                     da = nvgpu.WarpgroupGenerateDescriptorOp(
-                        a_wgmma_ty, a_tile_slice, a_tma_desc)
+                        a_wgmma_ty, a_tile_slice, a_tma_desc
+                    )
                     db = nvgpu.WarpgroupGenerateDescriptorOp(
-                        b_wgmma_ty, b_tile_slice, b_tma_desc)
+                        b_wgmma_ty, b_tile_slice, b_tma_desc
+                    )
 
                     # Step 4.3. MMA
                     carry_acc = for_op.inner_iter_args[0]
-                    new_acc = nvgpu.WarpgroupMmaOp(acc.type,
-                                                   da,
-                                                   db,
-                                                   carry_acc,
-                                                   transposeB=True)
+                    new_acc = nvgpu.WarpgroupMmaOp(
+                        acc.type, da, db, carry_acc, transposeB=True
+                    )
 
                     # Step 4.4. Load TMA for next stage
-                    p1 = arith.cmpi(arith.CmpIPredicate.ult,
-                                    arith.addi(iv, c(num_stages - 1)),
-                                    c(K // BLOCK_K))
+                    p1 = arith.cmpi(
+                        arith.CmpIPredicate.ult,
+                        arith.addi(iv, c(num_stages - 1)),
+                        c(K // BLOCK_K),
+                    )
                     p = arith.andi(primaryThread, p1)
                     with ir.InsertionPoint(scf.IfOp(p).then_block):
                         nextStage = arith.addi(iv, c(num_stages - 1))
                         nextSlot = arith.remui(nextStage, c(num_stages))
                         a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
                         a_tma_slice = memref.view(
-                            ir.MemRefType.get(a_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, a_offset, [])
+                            ir.MemRefType.get(
+                                a_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
                         b_offset = arith.addi(
                             arith.muli(nextSlot, c(rhs_tile_bytes)),
-                            c(lhs_tile_bytes * num_stages))
+                            c(lhs_tile_bytes * num_stages),
+                        )
                         b_tma_slice_1 = memref.view(
-                            ir.MemRefType.get(b_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset, [])
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
                         b_offset2 = arith.addi(
-                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                        )
                         b_tma_slice_2 = memref.view(
-                            ir.MemRefType.get(b_tma_shape,
-                                              f16,
-                                              memory_space=smem_space),
-                            dynamic_smem, b_offset2, [])
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset2,
+                            [],
+                        )
 
                         coord = arith.muli(c(64), nextStage)
                         debug_print(
@@ -804,45 +965,53 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                             dimX,
                             dimY,
                             coord,
-                            predicate=primaryThread)
-                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                             mbarTMA,
-                                             a_tma_desc,
-                                             coordinates=[coord, dimX],
-                                             mbarId=nextSlot,
-                                             predicate=primaryThread)
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY, coord],
-                                             mbarId=nextSlot,
-                                             predicate=primaryThread)
+                            predicate=primaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            a_tma_slice,
+                            mbarTMA,
+                            a_tma_desc,
+                            coordinates=[coord, dimX],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_1,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY, coord],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
                         dimY2 = arith.addi(dimY, c(64))
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY2, coord],
-                                             mbarId=nextSlot,
-                                             predicate=primaryThread)
-
-                        debug_print("[MainLoop] mbarTMA[{}] arrive",
-                                    nextSlot,
-                                    predicate=primaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_2,
                             mbarTMA,
-                            c(txcount),
+                            b_tma_desc,
+                            coordinates=[dimY2, coord],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
+
+                        debug_print(
+                            "[MainLoop] mbarTMA[{}] arrive",
+                            nextSlot,
+                            predicate=primaryThread,
+                        )
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA, c(txcount), nextSlot, predicate=primaryThread
+                        )
+                        debug_print(
+                            "[MainLoop] mbarTMA[{}] arrive [done]",
                             nextSlot,
-                            predicate=primaryThread)
-                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]",
-                                    nextSlot,
-                                    predicate=primaryThread)
+                            predicate=primaryThread,
+                        )
                         scf.yield_([])
                     # Step 4.5. Change the phaseParity
-                    p = arith.cmpi(arith.CmpIPredicate.eq, stage,
-                                   c(num_stages - 1))
+                    p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                     phaseParity = arith.select(
-                        p, arith.xori(phaseParity, arith.constant(i1, 1)),
-                        phaseParity)
+                        p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity
+                    )
 
                     # Step 4.5. Yield
                     scf.yield_([new_acc, phaseParity])
@@ -852,35 +1021,46 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 gpu.barrier()
 
                 # Step 6. Epilogue (registers --> shared memory)
-                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
-                                                c_elem_ty,
-                                                memory_space=smem_space)
+                acc_smem_ty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+                )
                 acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                 debug_print("Storing", predicate=primaryThread)
                 nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                 gpu.barrier()
 
                 # GPU Step 7. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
-                                       c_elem_ty,
-                                       memory_space=smem_space)
+                fd = ir.MemRefType.get(
+                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+                )
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
                 rty = ir.MemRefType.get(
-                    (BLOCK_M, BLOCK_N), c_elem_ty,
-                    ir.Attribute.parse("strided<[" + str(N) +
-                                       ", 1], offset: ?>"))
+                    (BLOCK_M, BLOCK_N),
+                    c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+                )
                 c_device_per_block = memref.SubViewOp(
-                    rty, c_device, [dimX, dimY], [], [],
-                    [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
+                    rty,
+                    c_device,
+                    [dimX, dimY],
+                    [],
+                    [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                    [BLOCK_M, BLOCK_N],
+                    [1, 1],
+                )
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
-                                   c(vlen * WARP_GROUP_SIZE))
+                for_op = scf.ForOp(
+                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
+                )
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get(
-                        (vlen, ), c_elem_ty), collapsed_smem,
-                                        [for_op.induction_variable])
+                    vdata = vector.load(
+                        ir.VectorType.get((vlen,), c_elem_ty),
+                        collapsed_smem,
+                        [for_op.induction_variable],
+                    )
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
 
@@ -892,6 +1072,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             gpu.wait(token_ty, [t9])
 
     mlir_matmul_multistage.func_op.attributes[
-        "llvm.emit_c_interface"] = ir.UnitAttr.get()
+        "llvm.emit_c_interface"
+    ] = ir.UnitAttr.get()
     module.operation.verify()
     return module
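
`generate_matmul_multistage` emits a fixed pattern of slot and phase-parity bookkeeping:

- the prologue prefetches `num_stages - 1` tiles;
- the main loop uses `stage = iv % num_stages`;
- it refills slot `(iv + num_stages - 1) % num_stages`;
- it flips the parity when `stage == num_stages - 1`.

This pattern can be checked with a small plain-Python model. The model is a sketch under stated assumptions and is not part of the patch. `simulate_multistage` is a hypothetical helper. It assumes every TMA load completes its mbarrier phase, and that waiting on parity P succeeds once the phase with parity P has completed. The model asserts two things: each k-tile is consumed from the slot it was loaded into, and the parity being waited on matches that slot's completed phases.

```python
def simulate_multistage(k_tiles: int, num_stages: int):
    """Replay the slot/phase bookkeeping of the multistage main loop.

    Hypothetical model: an mbarrier slot is represented only by its count of
    completed phases; no GPU, TMA or MLIR API is involved.
    """
    # With one stage the prologue prefetches nothing, so the first wait has no load.
    assert num_stages >= 2, "model assumes at least two stages"
    slot_of_tile = {}               # k-tile index -> shared-memory slot holding it
    completed = [0] * num_stages    # completed phases per mbarrier slot

    def issue_load(tile):
        slot = tile % num_stages
        slot_of_tile[tile] = slot
        completed[slot] += 1        # assume the TMA transaction completes the phase

    # Prologue: tiles 0 .. num_stages - 2 go into slots 0 .. num_stages - 2.
    for tile in range(num_stages - 1):
        issue_load(tile)

    phase_parity = 0
    trace = []
    for iv in range(k_tiles):
        stage = iv % num_stages
        # The tile consumed now must already sit in slot `stage`.
        assert slot_of_tile[iv] == stage
        # Waiting on `phase_parity` succeeds once the phase with that parity completed.
        assert (completed[stage] - 1) % 2 == phase_parity
        trace.append((iv, stage, phase_parity))
        # Refill the slot consumed in the previous iteration, if tiles remain.
        next_tile = iv + num_stages - 1
        if next_tile < k_tiles:
            issue_load(next_tile)
        # Flip the parity after the last slot of each round has been consumed.
        if stage == num_stages - 1:
            phase_parity ^= 1
    return trace


# K = 4096 and BLOCK_K = 64 give 64 k-tiles; 3 stages matches max_num_stages=3.
trace = simulate_multistage(k_tiles=4096 // 64, num_stages=3)
print(trace[:4])  # → [(0, 0, 0), (1, 1, 0), (2, 2, 0), (3, 0, 1)]
```

With 3 stages the parity flips every third iteration. Slot 0 is therefore waited on with parity 0 at `iv=0`, parity 1 at `iv=3`, and parity 0 again at `iv=6`.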
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
index ab218812bb4d96..1c9cc74fcd169c 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -19,8 +19,7 @@
 class NvgpuCompiler:
     """Nvgpu class for compiling and building MLIR modules."""
 
-    def __init__(self, options: str, opt_level: int,
-                 shared_libs: Sequence[str]):
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
         pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
         self.pipeline = pipeline
         self.shared_libs = shared_libs
@@ -36,12 +35,11 @@ def compile(self, module: ir.Module):
 
     def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Wraps the module in a JIT execution engine."""
-        return execution_engine.ExecutionEngine(module,
-                                                opt_level=self.opt_level,
-                                                shared_libs=self.shared_libs)
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
 
-    def compile_and_jit(self,
-                        module: ir.Module) -> execution_engine.ExecutionEngine:
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Compiles and jits the module."""
         self.compile(module)
         return self.jit(module)

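Note on the `NvgpuCompiler.__init__` reformat above: the pass-pipeline string relies on f-string brace escaping. The sketch below shows what `{{{options}}}` expands to. It is a minimal standalone sketch; the helper name `make_pipeline` and the option value in the demo are illustrative assumptions, not code from the test.

```python
# Mirrors the pipeline string built in NvgpuCompiler.__init__.
# make_pipeline is a hypothetical helper used only for illustration.
def make_pipeline(options: str) -> str:
    # In an f-string, "{{" and "}}" emit literal braces, so "{{{options}}}"
    # produces "{", then the value of `options`, then "}".
    return f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"


if __name__ == "__main__":
    # Example option string only; use whatever options the test passes in.
    print(make_pipeline("cubin-chip=sm_90a opt-level=3"))
```

Running the demo prints `builtin.module(gpu-lower-to-nvvm-pipeline{cubin-chip=sm_90a opt-level=3})`. The whole option string ends up inside the single pair of braces.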
>From aa3e3486ef011cd4bc9bb784f81b73981e913151 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 13 Feb 2024 09:09:46 +0000
Subject: [PATCH 4/8] fix the spelling mistake

---
 .../Integration/GPU/CUDA/sm90/python/matmul.py   | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index ff9c950193e695..f5d5b30c70d1e1 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -95,12 +95,12 @@ def generate_matmul(
     BLOCK_M=128,
     BLOCK_N=128,
     BLOCK_K=64,
-    use_warp_specilization=True,
+    use_warp_specialization=True,
     saveIR=False,
     max_num_stages=3,
 ):
     with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
-        if use_warp_specilization:
+        if use_warp_specialization:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
                 input_type,
                 output_type,
@@ -161,7 +161,7 @@ def matmul(
     BLOCK_M=128,
     BLOCK_N=128,
     BLOCK_K=64,
-    use_warp_specilization=True,
+    use_warp_specialization=True,
     saveIR=False,
     max_num_stages=3,
     print_results=False,
@@ -170,7 +170,7 @@ def matmul(
     # Print the configuration
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
-    gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
+    gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
     print(
         "===-- Running GEMM "
         + gemmty
@@ -207,7 +207,7 @@ def matmul(
         BLOCK_M,
         BLOCK_N,
         BLOCK_K,
-        use_warp_specilization,
+        use_warp_specialization,
         saveIR,
         max_num_stages,
     )
@@ -221,7 +221,7 @@ def matmul(
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
     kernelName = (
         "mlir_matmul_warpspecialized"
-        if use_warp_specilization
+        if use_warp_specialization
         else "mlir_matmul_multistage"
     )
 
@@ -252,7 +252,7 @@ def matmul(
     128,
     4096,
     max_num_stages=3,
-    use_warp_specilization=False,
+    use_warp_specialization=False,
 )
 # GEMM Warp Specilized  f32 += f16 * f16
 matmul(
@@ -262,5 +262,5 @@ def matmul(
     1024,
     512,
     max_num_stages=3,
-    use_warp_specilization=True,
+    use_warp_specialization=True,
 )

>From dea779aae915d610b409660a2fc6d0b2a3697a43 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 20 Feb 2024 10:09:11 +0000
Subject: [PATCH 5/8] address comments

---
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 201 +++++++++++-------
 1 file changed, 121 insertions(+), 80 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 4c132ec992a461..434af5e0920466 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -10,6 +10,7 @@
 from mlir.dialects import builtin
 from mlir.dialects import scf
 from mlir.dialects import vector
+from mlir.extras import types as T
 
 TMA_LAST_DIM_F16 = 64  # 128B flaot16
 WARP_SIZE = 32
@@ -22,12 +23,8 @@
 CONSUMER_PRIMARY_THREAD = 0
 
 MLIR_DYNAMIC = -9223372036854775808
-f16_byte = 2
-f32_byte = 4
 
 DEBUG = False
-
-
 def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
     if not DEBUG and not forcePrint:
         return
@@ -59,11 +56,52 @@ def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
         scf.yield_([])
 
 
+def get_type_str(ty):
+    if ir.F16Type.isinstance(ty):
+        return "f16"
+    if ir.F32Type.isinstance(ty):
+        return "f32"
+    if ir.F64Type.isinstance(ty):
+        return "f64"
+    if ir.IntegerType.isinstance(ty):
+        return "i" + str(ir.IntegerType(ty).width)
+    if ir.IndexType.isinstance(ty):
+        return "index"
+    raise NotImplementedError(ty)
+
+
+def get_type_size(ty):
+    if ir.F16Type.isinstance(ty):
+        return 2
+    if ir.F32Type.isinstance(ty):
+        return 4
+    if ir.F64Type.isinstance(ty):
+        return 8
+    if ir.IntegerType.isinstance(ty):
+        return ir.IntegerType(ty).width // 8
+    if ir.IndexType.isinstance(ty):
+        return 8
+    raise NotImplementedError(ty)
+
+
+def get_mlir_ty(dtype):
+    if dtype == np.float16:
+        return T.f16()
+    if dtype == np.float32:
+        return T.f32()
+    if dtype == np.float64:
+        return T.f64()
+    if dtype == np.int32:
+        return T.i32()
+    if dtype == np.int64:
+        return T.i64()
+    raise NotImplementedError(dtype)
+
+
 def c(value, ty=None):
-    ty = ir.IndexType.get() if ty is None else ty
+    ty = T.index() if ty is None else ty
     return arith.constant(ty, value)
 
-
 def generate_matmul_ws(
     input_type=np.float16,
     output_type=np.float32,
@@ -86,27 +124,21 @@ def generate_matmul_ws(
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
-    f16 = ir.F16Type.get()
-    f32 = ir.F32Type.get()
-    i1 = ir.IntegerType.get_signless(1)
-    i32 = ir.IntegerType.get_signless(32)
-    index = ir.IndexType.get()
-    i8 = ir.IntegerType.get_signless(8)
     token_ty = ir.Type.parse("!gpu.async.token")
-    a_ty = ir.MemRefType.get([M, K], f16)
-    b_ty = ir.MemRefType.get((K, N), f16)
-    c_elem_ty = f16 if output_type == np.float16 else f32
+    a_elem_ty = get_mlir_ty(input_type)
+    b_elem_ty = get_mlir_ty(input_type)
+    c_elem_ty = get_mlir_ty(output_type)
+    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
     c_ty = ir.MemRefType.get((M, N), c_elem_ty)
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = (
-        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
-    ) * f16_byte
+    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
+        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
+    )
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
-    input_type_str = "f16" if input_type == np.float16 else "f32"
-    output_type_str = "f16" if output_type == np.float16 else "f32"
     mbar_ty = ir.Type.parse(
         "!nvgpu.mbarrier.group<memorySpace = "
         + str(smem_space)
@@ -120,7 +152,7 @@ def generate_matmul_ws(
         + "x"
         + str(TMA_LAST_DIM_F16)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + str(smem_space)
         + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -131,7 +163,7 @@ def generate_matmul_ws(
         + "x"
         + str(TMA_LAST_DIM_F16)
         + "x"
-        + str(input_type_str)
+        + get_type_str(b_elem_ty)
         + ", "
         + str(smem_space)
         + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -142,7 +174,7 @@ def generate_matmul_ws(
         + "x"
         + str(BLOCK_N)
         + "x"
-        + str(output_type_str)
+        + get_type_str(c_elem_ty)
         + ">>"
     )
     a_wgmma_ty = ir.Type.parse(
@@ -151,7 +183,7 @@ def generate_matmul_ws(
         + "x"
         + str(BLOCK_K)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + smem_space_str
         + ">>"
@@ -162,7 +194,7 @@ def generate_matmul_ws(
         + "x"
         + str(BLOCK_N)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + smem_space_str
         + ">>"
@@ -172,10 +204,10 @@ def generate_matmul_ws(
 
         @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
         def mlir_matmul_warpspecialized(a_host, b_host, c_host):
-            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
-            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
             smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
-            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
             smem_size = max(smem_size_input, smem_size_output)
 
             # Step 1. Allocate device memory and memcpy
@@ -195,7 +227,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
                 x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
                 )
                 tma_descs.append(
                     nvgpu.TmaCreateDescriptorOp(
@@ -215,14 +247,14 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 [t7],
                 *map(c, grid),
                 *map(c, block),
-                dynamicSharedMemorySize=c(smem_size, ty=i32)
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
             )
-            launch_op.body.blocks.append(*([index] * 12))
+            launch_op.body.blocks.append(*([T.index()] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. This is need for vectorized ld/st
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                 )
                 ticks = c(10000000)
 
@@ -275,7 +307,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
 
                     # Step 5.2. TMA Main Loop
                     for_op = scf.ForOp(
-                        c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)]
+                        c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
                     )
                     with ir.InsertionPoint(for_op.body):
                         phaseParity = for_op.inner_iter_args[0]
@@ -303,7 +335,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                         phaseParity = arith.select(
                             p,
-                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                             phaseParity,
                         )
 
@@ -311,7 +343,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
                         a_tma_slice = memref.view(
                             ir.MemRefType.get(
-                                a_tma_shape, f16, memory_space=smem_space
+                                a_tma_shape, a_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             a_offset,
@@ -323,18 +355,19 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         )
                         b_tma_slice_1 = memref.view(
                             ir.MemRefType.get(
-                                b_tma_shape, f16, memory_space=smem_space
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset,
                             [],
                         )
                         b_offset2 = arith.addi(
-                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                            b_offset,
+                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                         )
                         b_tma_slice_2 = memref.view(
                             ir.MemRefType.get(
-                                b_tma_shape, f16, memory_space=smem_space
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset2,
@@ -406,7 +439,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
 
                     # Step 6.3. MMA Main Loop
                     for_op = scf.ForOp(
-                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                     )
                     with ir.InsertionPoint(for_op.body):
                         # Step 6.3.1. Wait mbar1
@@ -435,7 +468,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
                         a_tile_slice = memref.view(
                             ir.MemRefType.get(
-                                a_tile_shape, f16, memory_space=smem_space
+                                a_tile_shape, a_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             a_offset,
@@ -447,7 +480,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         )
                         b_tile_slice = memref.view(
                             ir.MemRefType.get(
-                                b_tile_shape, f16, memory_space=smem_space
+                                b_tile_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset,
@@ -506,7 +539,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                         phaseParity = arith.select(
                             p,
-                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            arith.xori(phaseParity, arith.constant(T.bool(), 1)),
                             phaseParity,
                         )
 
@@ -604,27 +637,21 @@ def generate_matmul_multistage(
     num_stages = min(required_stages, max_num_stages)
 
     module = ir.Module.create()
-    f16 = ir.F16Type.get()
-    f32 = ir.F32Type.get()
-    i1 = ir.IntegerType.get_signless(1)
-    i32 = ir.IntegerType.get_signless(32)
-    index = ir.IndexType.get()
-    i8 = ir.IntegerType.get_signless(8)
     token_ty = ir.Type.parse("!gpu.async.token")
-    a_ty = ir.MemRefType.get([M, K], f16)
-    b_ty = ir.MemRefType.get((K, N), f16)
-    c_elem_ty = f16 if output_type == np.float16 else f32
+    a_elem_ty = get_mlir_ty(input_type)
+    b_elem_ty = get_mlir_ty(input_type)
+    c_elem_ty = get_mlir_ty(output_type)
+    a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+    b_ty = ir.MemRefType.get((K, N), b_elem_ty)
     c_ty = ir.MemRefType.get((M, N), c_elem_ty)
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = (
-        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
-    ) * f16_byte
+    txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
+        a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
+    )
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
-    input_type_str = "f16" if input_type == np.float16 else "f32"
-    output_type_str = "f16" if output_type == np.float16 else "f32"
     mbar_ty = ir.Type.parse(
         "!nvgpu.mbarrier.group<memorySpace = "
         + str(smem_space)
@@ -638,7 +665,7 @@ def generate_matmul_multistage(
         + "x"
         + str(TMA_LAST_DIM_F16)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + str(smem_space)
         + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -649,7 +676,7 @@ def generate_matmul_multistage(
         + "x"
         + str(TMA_LAST_DIM_F16)
         + "x"
-        + str(input_type_str)
+        + get_type_str(b_elem_ty)
         + ", "
         + str(smem_space)
         + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -660,7 +687,7 @@ def generate_matmul_multistage(
         + "x"
         + str(BLOCK_N)
         + "x"
-        + str(output_type_str)
+        + get_type_str(c_elem_ty)
         + ">>"
     )
     a_wgmma_ty = ir.Type.parse(
@@ -669,7 +696,7 @@ def generate_matmul_multistage(
         + "x"
         + str(BLOCK_K)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + smem_space_str
         + ">>"
@@ -680,7 +707,7 @@ def generate_matmul_multistage(
         + "x"
         + str(BLOCK_N)
         + "x"
-        + str(input_type_str)
+        + get_type_str(a_elem_ty)
         + ", "
         + smem_space_str
         + ">>"
@@ -690,10 +717,10 @@ def generate_matmul_multistage(
 
         @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
         def mlir_matmul_multistage(a_host, b_host, c_host):
-            lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
-            rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+            lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+            rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
             smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
-            smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+            smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
             smem_size = max(smem_size_input, smem_size_output)
 
             # Step 1. Allocate device memory and memcpy
@@ -713,7 +740,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
                 x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
                 )
                 tma_descs.append(
                     nvgpu.TmaCreateDescriptorOp(
@@ -733,14 +760,14 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 [t7],
                 *map(c, grid),
                 *map(c, block),
-                dynamicSharedMemorySize=c(smem_size, ty=i32)
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
             )
-            launch_op.body.blocks.append(*([index] * 12))
+            launch_op.body.blocks.append(*([T.index()] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. Bootstrapping
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                    ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
                 )
                 ticks = c(10000000)
                 tidx = gpu.thread_id(gpu.Dimension.x)
@@ -769,7 +796,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     # Step 3.1. Calculate offsets
                     a_offset = arith.muli(iv, c(lhs_tile_bytes))
                     a_tma_slice = memref.view(
-                        ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            a_tma_shape, a_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         a_offset,
                         [],
@@ -779,16 +808,21 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         c(lhs_tile_bytes * num_stages),
                     )
                     b_tma_slice_1 = memref.view(
-                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         b_offset,
                         [],
                     )
                     b_offset2 = arith.addi(
-                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                        b_offset,
+                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                     )
                     b_tma_slice_2 = memref.view(
-                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         b_offset2,
                         [],
@@ -850,7 +884,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 # GPU Step 4. Main Loop
                 acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
                 for_op = scf.ForOp(
-                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
                 )
                 with ir.InsertionPoint(for_op.body):
                     # Step 4.1. Wait mbarTMA
@@ -876,7 +910,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     # Step 4.2. Create WGMMA Descriptors
                     a_offset = arith.muli(stage, c(lhs_tile_bytes))
                     a_tile_slice = memref.view(
-                        ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            a_tile_shape, a_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         a_offset,
                         [],
@@ -886,7 +922,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         c(lhs_tile_bytes * num_stages),
                     )
                     b_tile_slice = memref.view(
-                        ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                        ir.MemRefType.get(
+                            b_tile_shape, b_elem_ty, memory_space=smem_space
+                        ),
                         dynamic_smem,
                         b_offset,
                         [],
@@ -924,7 +962,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
                         a_tma_slice = memref.view(
                             ir.MemRefType.get(
-                                a_tma_shape, f16, memory_space=smem_space
+                                a_tma_shape, a_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             a_offset,
@@ -936,18 +974,19 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                         )
                         b_tma_slice_1 = memref.view(
                             ir.MemRefType.get(
-                                b_tma_shape, f16, memory_space=smem_space
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset,
                             [],
                         )
                         b_offset2 = arith.addi(
-                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                            b_offset,
+                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
                         )
                         b_tma_slice_2 = memref.view(
                             ir.MemRefType.get(
-                                b_tma_shape, f16, memory_space=smem_space
+                                b_tma_shape, b_elem_ty, memory_space=smem_space
                             ),
                             dynamic_smem,
                             b_offset2,
@@ -1010,7 +1049,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     # Step 4.5. Change the phaseParity
                     p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                     phaseParity = arith.select(
-                        p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity
+                        p,
+                        arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+                        phaseParity,
                     )
 
                     # Step 4.5. Yield

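With `f16_byte`/`f32_byte` replaced by `get_type_size()`, the shared-memory budget and the TMA transaction count now scale with the element types. What follows is a hedged, MLIR-free sketch of that arithmetic. The `ELEM_BYTES` table and the `smem_budget` helper are hypothetical stand-ins for `get_type_size()` and the inline computations in `generate_matmul_ws`/`generate_matmul_multistage`.

The sketch pairs each tile with its own element size. By contrast, `txcount` in the patch multiplies the B tile by `a_elem_ty`'s size and the A tile by `b_elem_ty`'s size. The two forms are equivalent only because A and B share `input_type`.

```python
# Byte sizes per element, mirroring get_type_size() for the types it handles.
# Plain-string keys are an assumption so this runs without the MLIR bindings.
ELEM_BYTES = {"f16": 2, "f32": 4, "f64": 8, "i8": 1, "i32": 4, "i64": 8}

TMA_LAST_DIM_F16 = 64  # same constant as in matmulBuilder.py


def smem_budget(in_ty, out_ty, BLOCK_M, BLOCK_N, BLOCK_K, num_stages):
    """Return (lhs_tile_bytes, rhs_tile_bytes, smem_size, txcount)."""
    a_bytes = ELEM_BYTES[in_ty]
    b_bytes = ELEM_BYTES[in_ty]  # A and B both use input_type
    c_bytes = ELEM_BYTES[out_ty]
    lhs_tile_bytes = BLOCK_M * BLOCK_K * a_bytes
    rhs_tile_bytes = BLOCK_N * BLOCK_K * b_bytes
    # Input tiles are buffered once per stage; the output tile is staged in
    # the same dynamic shared memory, so only the larger of the two is needed.
    smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
    smem_size_output = BLOCK_M * BLOCK_N * c_bytes
    smem_size = max(smem_size_input, smem_size_output)
    # Bytes one stage's TMA loads deliver before its mbarrier can complete.
    a_tile_shape = (BLOCK_M, TMA_LAST_DIM_F16)
    b_tile_shape = (BLOCK_K, BLOCK_N)
    txcount = (
        a_tile_shape[0] * a_tile_shape[1] * a_bytes
        + b_tile_shape[0] * b_tile_shape[1] * b_bytes
    )
    return lhs_tile_bytes, rhs_tile_bytes, smem_size, txcount


if __name__ == "__main__":
    print(smem_budget("f16", "f32", 128, 128, 64, 3))
    print(smem_budget("f16", "f32", 128, 128, 64, 1))
```

Consider the default configuration: f16 inputs, f32 output, 128x128x64 tiles and 3 stages. Each input tile is 16 KiB, the budget is 96 KiB, and each stage's transaction count is 32 KiB. With a single stage, the 64 KiB f32 output tile dominates instead.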
>From d4393a47f1d93253bac35c5ce784bc9b06c463f1 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 20 Feb 2024 10:23:23 +0000
Subject: [PATCH 6/8] format

---
 .../Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py    | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 434af5e0920466..52732b1c01426e 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -25,6 +25,8 @@
 MLIR_DYNAMIC = -9223372036854775808
 
 DEBUG = False
+
+
 def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
     if not DEBUG and not forcePrint:
         return
@@ -102,6 +104,7 @@ def c(value, ty=None):
     ty = T.index() if ty is None else ty
     return arith.constant(ty, value)
 
+
 def generate_matmul_ws(
     input_type=np.float16,
     output_type=np.float32,

>From 09fb9c87480422460f374be1133ebfbf13f8d6b1 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sun, 3 Mar 2024 08:23:20 +0000
Subject: [PATCH 7/8] Allow multiple stages, and fix the kernels. Add
 test_short that tests multiple cases.

---
 .../GPU/CUDA/sm90/python/matmul.py            | 157 +++++++--
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 306 ++++++++++--------
 2 files changed, 308 insertions(+), 155 deletions(-)

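This patch derives the stage count from the problem size instead of always using `max_num_stages`. Below is a small sketch of that formula, plus the mbarrier phase schedule produced by the `phaseParity` flip in the main loops. A few assumptions apply:

- The slot selection `stage = iv % num_stages` is an assumption; the hunks only show the flip itself.
- The consumer loop's starting parity of 0 is taken from the `[acc, arith.constant(T.bool(), 0)]` iter-args.
- `num_stages_for` and `phase_schedule` are hypothetical helper names.

```python
def num_stages_for(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, max_num_stages):
    # Ratio of all A and B elements to the A and B elements held by one stage,
    # capped by the user-provided maximum (same formula as matmul() here).
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    return min(required_stages, max_num_stages)


def phase_schedule(K, BLOCK_K, num_stages):
    """Return (iteration, stage, parity) for each main-loop iteration."""
    parity = False  # the MMA loop starts with phaseParity = 0
    schedule = []
    for iv in range(K // BLOCK_K):
        stage = iv % num_stages  # assumed slot selection
        schedule.append((iv, stage, int(parity)))
        # Flip the parity after the last slot, i.e. when the ring wraps around.
        if stage == num_stages - 1:
            parity = not parity
    return schedule


if __name__ == "__main__":
    print(num_stages_for(128, 128, 64, 128, 128, 64, 3))
    for row in phase_schedule(512, 64, 3):
        print(row)
```

For 128x128x64 with 128x128x64 tiles the formula yields 1 stage, which matches the `stages 1` CHECK lines below. With K=512, BLOCK_K=64 and 3 stages, the loop visits slots 0,1,2,0,1,2,0,1 with parity 0,0,0,1,1,1,0,0. The alternating parity is what lets a wait on a reused slot tell its new fill apart from the previous one.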
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index f5d5b30c70d1e1..88609273a8b548 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -1,6 +1,6 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
 # RUN:   %PYTHON %s | FileCheck %s
-# CHECK: PASS
+
 
 # ===--- GEMM Hopper Tensor Core Integration Test ---===
 #
@@ -168,6 +168,8 @@ def matmul(
     no_verify=False,
 ):
     # Print the configuration
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
@@ -193,7 +195,7 @@ def matmul(
         + "x"
         + str(BLOCK_K)
         + ", stages "
-        + str(max_num_stages)
+        + str(num_stages)
         + " --==="
     )
 
@@ -209,7 +211,7 @@ def matmul(
         BLOCK_K,
         use_warp_specialization,
         saveIR,
-        max_num_stages,
+        num_stages,
     )
 
     # Allocate matrices and invoke the matmul
@@ -219,10 +221,17 @@ def matmul(
     mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
     mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-    kernelName = (
-        "mlir_matmul_warpspecialized"
-        if use_warp_specialization
-        else "mlir_matmul_multistage"
+    kernelName = matmulBuilder.make_kernel_name(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        num_stages,
+        use_warp_specialization,
     )
 
     # Launch the MLIR generated kernel
@@ -243,24 +252,118 @@ def matmul(
 
     print("PASS ")
 
+# Takes longer to run
+def test_long():
+    for stages in range(1, 7):
+        for M in [128, 512, 1024, 4096, 8192]:
+            for N in [128, 512, 1024, 4096, 8192]:
+                for K in [64, 128, 512, 1024, 4096, 8192]:
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M, N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=False,
+                        no_verify=True,
+                        saveIR=True,
+                    )
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M, N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=True,
+                    )
 
-# GEMM Multistage       f32 += f16 * f16
-matmul(
-    np.float16,
-    np.float32,
-    128,
-    128,
-    4096,
-    max_num_stages=3,
-    use_warp_specialization=False,
-)
-# GEMM Warp Specilized  f32 += f16 * f16
-matmul(
-    np.float16,
-    np.float32,
-    256,
-    1024,
-    512,
-    max_num_stages=3,
-    use_warp_specialization=True,
-)
+def test_short():
+    for stages in [1, 3]:
+        for M in [128, 512]:
+            for N in [128, 512]:
+                for K in [64, 512]:
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M, N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=False,
+                        no_verify=True,
+                        saveIR=True,
+                    )
+                    matmul(
+                        np.float16,
+                        np.float32,
+                        M, N,
+                        K,
+                        max_num_stages=stages,
+                        use_warp_specialization=True,
+                    )
+
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS 
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+
+test_short()
\ No newline at end of file
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 52732b1c01426e..a541378359fe97 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -105,6 +105,41 @@ def c(value, ty=None):
     return arith.constant(ty, value)
 
 
+def make_kernel_name(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=128,
+    num_stages=3,
+    use_warp_specialization=False,
+):
+    kernelName = (
+        "warpspecialized"
+        if use_warp_specialization
+        else "multistage"
+    )
+    return (
+        kernelName
+        + "_"
+        + str(M)
+        + "x"
+        + str(N)
+        + "x"
+        + str(K)
+        + "_"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(BLOCK_K)
+        + "_"
+        + str(num_stages)
+    )
+
 def generate_matmul_ws(
     input_type=np.float16,
     output_type=np.float32,
@@ -114,7 +149,7 @@ def generate_matmul_ws(
     BLOCK_M=128,
     BLOCK_N=128,
     BLOCK_K=128,
-    max_num_stages=3,
+    num_stages=3,
 ):
     # Limitaitons for now
     assert input_type == np.float16
@@ -123,9 +158,6 @@ def generate_matmul_ws(
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
-    num_stages = min(required_stages, max_num_stages)
-
     module = ir.Module.create()
     token_ty = ir.Type.parse("!gpu.async.token")
     a_elem_ty = get_mlir_ty(input_type)
@@ -202,11 +234,15 @@ def generate_matmul_ws(
         + smem_space_str
         + ">>"
     )
-
+    kernelName = make_kernel_name(
+        input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
+    )
     with ir.InsertionPoint(module.body):
-
-        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
-        def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+        with ir.InsertionPoint(fop.add_entry_block()):
+            a_host = fop.arguments[0]
+            b_host = fop.arguments[1]
+            c_host = fop.arguments[2]
             lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
             rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
             smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
@@ -301,6 +337,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                 nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
                 nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
 
+                ns = num_stages if num_stages == 1 else num_stages - 1
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                 with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
                     # Step 5.1. Reduce register size
@@ -319,7 +356,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
 
                         # Step 5.2.1. Wait mbarDONE
                         debug_print(
-                            "[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={}",
                             iv,
                             stage,
                             phaseParity,
@@ -329,7 +366,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             mbarDONE, phaseParity, ticks, mbarId=stage
                         )
                         debug_print(
-                            "[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                            "[prod] iv={}  | mbarDONE[{}] try_wait  phase={} [done]",
                             iv,
                             stage,
                             phaseParity,
@@ -412,7 +449,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
 
                         # Step 5.2.3. Arrive mbarTMA
                         debug_print(
-                            "[prod] {}  | mbarTMA[{}] arrive",
+                            "[prod] iv={}  | mbarTMA[{}] arrive",
                             iv,
                             stage,
                             predicate=producerPrimaryThread,
@@ -421,7 +458,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
                         )
                         debug_print(
-                            "[prod] {}  | mbarTMA[{}] arrive [done]",
+                            "[prod] iv={}  | mbarTMA[{}] arrive [done]",
                             iv,
                             stage,
                             predicate=producerPrimaryThread,
@@ -450,7 +487,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
                         debug_print(
-                            "[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={}",
                             iv,
                             stage,
                             phaseParity,
@@ -460,7 +497,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                             mbarTMA, phaseParity, ticks, mbarId=stage
                         )
                         debug_print(
-                            "[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                            "[cons] iv={}  | mbarTMA[{}] try_wait   phase={} [done]",
                             iv,
                             stage,
                             phaseParity,
@@ -509,32 +546,29 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
                         )
 
                         # Step 6.3.4. Arrive mbarDONE
-                        p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
-                        p_arrive = arith.andi(consumerPrimaryThread, p1)
+                        if num_stages == 1:
+                            p_arrive = consumerPrimaryThread
+                        else:
+                            p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+                            p_arrive = arith.andi(consumerPrimaryThread, p1)
                         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                             p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
                             barId = arith.select(
                                 p, c(num_stages - 1), arith.subi(stage, c(1))
                             )
-                            remoteCtaId = c(0)
-                            pred = consumerPrimaryThread
                             debug_print(
-                                "[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
+                                "[cons] iv={}  | mbarDONE[{}] arrive ",
                                 iv,
                                 barId,
-                                remoteCtaId,
-                                pred,
                                 predicate=consumerPrimaryThread,
                             )
                             nvgpu.mbarrier_arrive(
                                 ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
                             )
                             debug_print(
-                                "[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
+                                "[cons] iv={}  | mbarDONE[{}] arrive [done]",
                                 iv,
                                 barId,
-                                remoteCtaId,
-                                pred,
                                 predicate=consumerPrimaryThread,
                             )
                             scf.yield_([])
@@ -609,11 +643,13 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
             # Step 4. Copy back to host
             t8 = gpu.wait(token_ty, [launch_op])
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.dealloc(token_ty, [t8], a_device)
+            gpu.dealloc(token_ty, [t8], b_device)            
             gpu.wait(token_ty, [t9])
+            gpu.dealloc(token_ty, [t8], c_device)
+            func.ReturnOp([])
 
-    mlir_matmul_warpspecialized.func_op.attributes[
-        "llvm.emit_c_interface"
-    ] = ir.UnitAttr.get()
+    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module
 
@@ -627,7 +663,7 @@ def generate_matmul_multistage(
     BLOCK_M=128,
     BLOCK_N=128,
     BLOCK_K=64,
-    max_num_stages=3,
+    num_stages=3,
 ):
     # Limitaitons for now
     assert input_type == np.float16
@@ -636,9 +672,6 @@ def generate_matmul_multistage(
     assert N % BLOCK_N == 0
     assert K % BLOCK_K == 0
 
-    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
-    num_stages = min(required_stages, max_num_stages)
-
     module = ir.Module.create()
     token_ty = ir.Type.parse("!gpu.async.token")
     a_elem_ty = get_mlir_ty(input_type)
@@ -717,9 +750,23 @@ def generate_matmul_multistage(
     )
 
     with ir.InsertionPoint(module.body):
-
-        @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
-        def mlir_matmul_multistage(a_host, b_host, c_host):
+        kernelName = make_kernel_name(
+            input_type,
+            output_type,
+            M,
+            N,
+            K,
+            BLOCK_M,
+            BLOCK_N,
+            BLOCK_K,
+            num_stages,
+            False,
+        )
+        fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+        with ir.InsertionPoint(fop.add_entry_block()):
+            a_host = fop.arguments[0]
+            b_host = fop.arguments[1]
+            c_host = fop.arguments[2]
             lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
             rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
             smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
@@ -792,7 +839,8 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                 nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
-                for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
+                ns = num_stages if num_stages == 1 else num_stages - 1
+                for_op = scf.ForOp(c(0), c(ns), c(1))
                 with ir.InsertionPoint(for_op.body):
                     iv = for_op.induction_variable
 
@@ -951,104 +999,105 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
                     new_acc = nvgpu.WarpgroupMmaOp(
                         acc.type, da, db, carry_acc, transposeB=True
                     )
+                    if num_stages == 1:
+                        nvvm.WgmmaWaitGroupSyncOp(0)
 
                     # Step 4.4. Load TMA for next stage
                     p1 = arith.cmpi(
                         arith.CmpIPredicate.ult,
-                        arith.addi(iv, c(num_stages - 1)),
+                        arith.addi(iv, c(ns)),
                         c(K // BLOCK_K),
                     )
                     p = arith.andi(primaryThread, p1)
-                    with ir.InsertionPoint(scf.IfOp(p).then_block):
-                        nextStage = arith.addi(iv, c(num_stages - 1))
-                        nextSlot = arith.remui(nextStage, c(num_stages))
-                        a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(
-                            ir.MemRefType.get(
-                                a_tma_shape, a_elem_ty, memory_space=smem_space
-                            ),
-                            dynamic_smem,
-                            a_offset,
-                            [],
-                        )
-                        b_offset = arith.addi(
-                            arith.muli(nextSlot, c(rhs_tile_bytes)),
-                            c(lhs_tile_bytes * num_stages),
-                        )
-                        b_tma_slice_1 = memref.view(
-                            ir.MemRefType.get(
-                                b_tma_shape, b_elem_ty, memory_space=smem_space
-                            ),
-                            dynamic_smem,
-                            b_offset,
-                            [],
-                        )
-                        b_offset2 = arith.addi(
-                            b_offset,
-                            c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
-                        )
-                        b_tma_slice_2 = memref.view(
-                            ir.MemRefType.get(
-                                b_tma_shape, b_elem_ty, memory_space=smem_space
-                            ),
-                            dynamic_smem,
-                            b_offset2,
-                            [],
-                        )
+                    nextStage = arith.addi(iv, c(ns))
+                    nextSlot = arith.remui(nextStage, c(num_stages))
+                    a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
 
-                        coord = arith.muli(c(64), nextStage)
-                        debug_print(
-                            "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                            iv,
-                            a_offset,
-                            b_offset,
-                            b_offset2,
-                            coord,
-                            dimX,
-                            dimY,
-                            coord,
-                            predicate=primaryThread,
-                        )
-                        nvgpu.TmaAsyncLoadOp(
-                            a_tma_slice,
-                            mbarTMA,
-                            a_tma_desc,
-                            coordinates=[coord, dimX],
-                            mbarId=nextSlot,
-                            predicate=primaryThread,
-                        )
-                        nvgpu.TmaAsyncLoadOp(
-                            b_tma_slice_1,
-                            mbarTMA,
-                            b_tma_desc,
-                            coordinates=[dimY, coord],
-                            mbarId=nextSlot,
-                            predicate=primaryThread,
-                        )
-                        dimY2 = arith.addi(dimY, c(64))
-                        nvgpu.TmaAsyncLoadOp(
-                            b_tma_slice_2,
-                            mbarTMA,
-                            b_tma_desc,
-                            coordinates=[dimY2, coord],
-                            mbarId=nextSlot,
-                            predicate=primaryThread,
-                        )
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] arrive",
+                        nextSlot,
+                        predicate=p,
+                    )
+                    nvgpu.mbarrier_arrive_expect_tx(
+                        mbarTMA, c(txcount), nextSlot, predicate=p
+                    )
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] arrive [done]",
+                        nextSlot,
+                        predicate=p,
+                    )
 
-                        debug_print(
-                            "[MainLoop] mbarTMA[{}] arrive",
-                            nextSlot,
-                            predicate=primaryThread,
-                        )
-                        nvgpu.mbarrier_arrive_expect_tx(
-                            mbarTMA, c(txcount), nextSlot, predicate=primaryThread
-                        )
-                        debug_print(
-                            "[MainLoop] mbarTMA[{}] arrive [done]",
-                            nextSlot,
-                            predicate=primaryThread,
-                        )
-                        scf.yield_([])
+                    a_tma_slice = memref.view(
+                        ir.MemRefType.get(
+                            a_tma_shape, a_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(nextSlot, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
+                    b_tma_slice_1 = memref.view(
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    b_offset2 = arith.addi(
+                        b_offset,
+                        c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+                    )
+                    b_tma_slice_2 = memref.view(
+                        ir.MemRefType.get(
+                            b_tma_shape, b_elem_ty, memory_space=smem_space
+                        ),
+                        dynamic_smem,
+                        b_offset2,
+                        [],
+                    )
+
+                    coord = arith.muli(c(64), nextStage)
+                    debug_print(
+                        "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                        iv,
+                        a_offset,
+                        b_offset,
+                        b_offset2,
+                        coord,
+                        dimX,
+                        dimY,
+                        coord,
+                        predicate=p,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        a_tma_slice,
+                        mbarTMA,
+                        a_tma_desc,
+                        coordinates=[coord, dimX],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_1,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY, coord],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
+                    dimY2 = arith.addi(dimY, c(64))
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_2,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY2, coord],
+                        mbarId=nextSlot,
+                        predicate=p,
+                    )
                     # Step 4.5. Change the phaseParity
                     p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
                     phaseParity = arith.select(
@@ -1062,7 +1111,6 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
 
                 # Step 5. Wait All WGMMA groups
                 nvvm.WgmmaWaitGroupSyncOp(0)
-                gpu.barrier()
 
                 # Step 6. Epilogue (registers --> shared memory)
                 acc_smem_ty = ir.MemRefType.get(
@@ -1113,10 +1161,12 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
             # Step 4. Copy back to host
             t8 = gpu.wait(token_ty, [launch_op])
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+            gpu.dealloc(token_ty, [t8], a_device)
+            gpu.dealloc(token_ty, [t8], b_device)            
             gpu.wait(token_ty, [t9])
+            gpu.dealloc(token_ty, [t8], c_device)
+            func.ReturnOp([])
 
-    mlir_matmul_multistage.func_op.attributes[
-        "llvm.emit_c_interface"
-    ] = ir.UnitAttr.get()
+    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
     module.operation.verify()
     return module

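Note for reviewers on the multistage `ns` change in this patch: the plain-Python model below traces which k-tile goes into which shared-memory slot in `generate_matmul_multistage`. It is a hypothetical sketch, not part of the patch. `multistage_schedule` and `loaded_before_read` are invented names. The model assumes that the prologue loads tile `i` into slot `i`, because the prologue body is not visible in this hunk. It also assumes that main-loop iteration `iv` consumes tile `iv` from slot `iv % num_stages`.

Under those assumptions, the old prologue bound `num_stages - 1` issues no loads when `num_stages == 1`, so the first tile is read before anything was requested for it. With `ns`, the prologue loads one tile. In the single-stage case, the next-tile TMA load also targets the same slot the WGMMA just read. That is consistent with the `nvvm.WgmmaWaitGroupSyncOp(0)` this patch adds for `num_stages == 1`.

```python
def multistage_schedule(k_tiles, num_stages, ns=None):
    """Model which k-tile is loaded into which shared-memory slot.

    Hypothetical helper. Returns (prologue, steps):
      prologue: list of (tile, slot) loads issued before the main loop
      steps:    list of (iv, read_slot, load), where load is (tile, slot) or None
    Assumes k_tiles >= ns, as in the shapes exercised by test_short.
    """
    if ns is None:
        # Same expression as the patch: a single stage still prefetches one tile.
        ns = num_stages if num_stages == 1 else num_stages - 1
    # Assumption: the prologue loads tile i into slot i.
    prologue = [(i, i % num_stages) for i in range(ns)]
    steps = []
    for iv in range(k_tiles):
        read_slot = iv % num_stages  # slot consumed by the WGMMA at this iteration
        next_tile = iv + ns
        load = None
        # Mirrors the predicate arith.cmpi(ult, iv + ns, K // BLOCK_K).
        if next_tile < k_tiles:
            load = (next_tile, next_tile % num_stages)
        steps.append((iv, read_slot, load))
    return prologue, steps


def loaded_before_read(prologue, steps):
    """True if every tile is in its slot when the main loop reads it."""
    slot_contents = dict((slot, tile) for tile, slot in prologue)
    for iv, read_slot, load in steps:
        if slot_contents.get(read_slot) != iv:
            return False
        if load is not None:  # the next-tile load is issued after the MMA
            tile, slot = load
            slot_contents[slot] = tile
    return True
```

In this model, `multistage_schedule(4, 3)` always loads into a slot other than the one being read. `multistage_schedule(4, 1)` reuses slot 0 for both. Passing `ns=0` reproduces the old single-stage prologue, which fails `loaded_before_read`.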
>From ad767a649d0aab25cf775a2bc2c2fb7ee6ad61fa Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sun, 3 Mar 2024 08:42:00 +0000
Subject: [PATCH 8/8] format

---
 .../GPU/CUDA/sm90/python/matmul.py            | 108 +++++++-----------
 .../CUDA/sm90/python/tools/matmulBuilder.py   |  11 +-
 2 files changed, 44 insertions(+), 75 deletions(-)

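The CHECK lines in the reduced `test_short` can be cross-checked by recomputing the stage count that each run reports. This sketch assumes that `matmul()` clamps `max_num_stages` with the `required_stages` formula removed from `matmulBuilder.py` in patch 7. The `matmul.py` side of that move is not shown here. `effective_stages` and `expected_check_lines` are hypothetical helper names, and the banner text is copied from the CHECK lines.

```python
def effective_stages(M, N, K, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, max_num_stages=3):
    # Assumed clamp: same formula as the lines patch 7 removed from matmulBuilder.py.
    required = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    return min(required, max_num_stages)


def expected_check_lines():
    """Rebuild the banner/PASS sequence printed by test_short (loop order as in the patch)."""
    lines = []
    for stages in [1, 3]:
        for M in [128, 512]:
            for N in [128]:
                for K in [64, 256]:
                    s = effective_stages(M, N, K, max_num_stages=stages)
                    # matmul() is called twice per shape: Multistage first, then warp specialized.
                    for algo in ("Multistage", "Warp specialization"):
                        lines.append(
                            f"===-- Running GEMM {algo} f32 += f16 * f16, "
                            f"Size {M}x{N}x{K}, Tile 128x128x64, stages {s} --==="
                        )
                        lines.append("PASS")
    return lines
```

For example, 512x128x64 needs only (512*64 + 64*128) // (128*64 + 64*128) = 2 stages. That is why this size reports `stages 2` even when `max_num_stages=3`.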
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index 88609273a8b548..dd474b01ffabef 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -252,118 +252,90 @@ def matmul(
 
     print("PASS ")
 
+
 # Takes longer time to run
 def test_long():
-    for stages in range(1,7):
+    for stages in range(1, 7):
         for M in [128, 512, 1024, 4096, 8192]:
             for N in [128, 512, 1024, 4096, 8192]:
                 for K in [64, 128, 512, 1024, 4096, 8192]:
                     matmul(
                         np.float16,
                         np.float32,
-                        M,N,
+                        M,
+                        N,
                         K,
                         max_num_stages=stages,
                         use_warp_specialization=False,
                         no_verify=True,
-                        saveIR=True
                     )
                     matmul(
                         np.float16,
                         np.float32,
-                        M,N,
+                        M,
+                        N,
                         K,
                         max_num_stages=stages,
-                        use_warp_specialization=True,                    
+                        use_warp_specialization=True,
                     )
 
+
 def test_short():
     for stages in [1, 3]:
         for M in [128, 512]:
-            for N in [128, 512]:
-                for K in [64, 512]:
+            for N in [128]:
+                for K in [64, 256]:
                     matmul(
                         np.float16,
                         np.float32,
-                        M,N,
+                        M,
+                        N,
                         K,
                         max_num_stages=stages,
                         use_warp_specialization=False,
-                        no_verify=True,
-                        saveIR=True
                     )
                     matmul(
                         np.float16,
                         np.float32,
-                        M,N,
+                        M,
+                        N,
                         K,
                         max_num_stages=stages,
-                        use_warp_specialization=True,                    
+                        use_warp_specialization=True,
                     )
 
+
 # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS 
+# CHECK: PASS
 # CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS 
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
 # CHECK: PASS
 
-test_short()
\ No newline at end of file
+test_short()
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index a541378359fe97..fefc65b854040f 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -117,11 +117,7 @@ def make_kernel_name(
     num_stages=3,
     use_warp_specialization=False,
 ):
-    kernelName = (
-        "warpspecialized"
-        if use_warp_specialization
-        else "multistage"
-    )
+    kernelName = "warpspecialized" if use_warp_specialization else "multistage"
     return (
         kernelName
         + "_"
@@ -140,6 +136,7 @@ def make_kernel_name(
         + str(num_stages)
     )
 
+
 def generate_matmul_ws(
     input_type=np.float16,
     output_type=np.float32,
@@ -644,7 +641,7 @@ def generate_matmul_ws(
             t8 = gpu.wait(token_ty, [launch_op])
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.dealloc(token_ty, [t8], a_device)
-            gpu.dealloc(token_ty, [t8], b_device)            
+            gpu.dealloc(token_ty, [t8], b_device)
             gpu.wait(token_ty, [t9])
             gpu.dealloc(token_ty, [t8], c_device)
             func.ReturnOp([])
@@ -1162,7 +1159,7 @@ def generate_matmul_multistage(
             t8 = gpu.wait(token_ty, [launch_op])
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.dealloc(token_ty, [t8], a_device)
-            gpu.dealloc(token_ty, [t8], b_device)            
+            gpu.dealloc(token_ty, [t8], b_device)
             gpu.wait(token_ty, [t9])
             gpu.dealloc(token_ty, [t8], c_device)
             func.ReturnOp([])


