[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 12 05:36:55 PST 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {darker}-->


:warning: The Python code formatter, darker, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
darker --check --diff -r 7bc079c85219ad6e954fb6071cd108151203c85e...a9b4035aa421506302e74c6f6ff19a0aec62b735 mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
``````````

</details>

<details>
<summary>
View the diff from darker here.
</summary>

``````````diff
--- matmul.py	2024-02-12 13:33:26.000000 +0000
+++ matmul.py	2024-02-12 13:36:46.432817 +0000
@@ -83,91 +83,155 @@
 import sys
 import pathlib
 import ctypes
 from mlir import runtime as rt
 
-def generate_matmul(input_type=np.float16,
-                    output_type=np.float32,
-                    M=4096,
-                    N=4096,
-                    K=4096,
-                    BLOCK_M=128,
-                    BLOCK_N=128,
-                    BLOCK_K=64,
-                    use_warp_specilization=True,
-                    saveIR=False,
-                    max_num_stages=3):
+
+def generate_matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+):
     with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
         if use_warp_specilization:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
-                                                                 BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
         else:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
-                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
 
         mlir_nvgpu_module.operation.verify()
-        
+
         # Save generated IR
         if saveIR:
             # print(mlir_nvgpu_module)
             original_stdout = sys.stdout
-            with open('gemm.mlir', 'w') as f:
+            with open("gemm.mlir", "w") as f:
                 sys.stdout = f
                 print(mlir_nvgpu_module)
                 sys.stdout = original_stdout
 
         # Get compiler
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+            )
+        compiler = nvgpucompiler.NvgpuCompiler(
+            options, opt_level=3, shared_libs=[support_lib]
+        )
 
         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
         return engine
 
 
-def matmul(input_type=np.float16,
-           output_type=np.float32,
-           M=128,
-           N=128,
-           K=128,
-           BLOCK_M=128,
-           BLOCK_N=128,
-           BLOCK_K=64,
-           use_warp_specilization=True,
-           saveIR=False,
-           max_num_stages=3,
-           print_results=False,
-           no_verify=False):
+def matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=128,
+    N=128,
+    K=128,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+    print_results=False,
+    no_verify=False,
+):
     # Print the configuration
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
-          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
-          str(max_num_stages) + " --===")
+    print(
+        "===-- Running GEMM "
+        + gemmty
+        + " "
+        + oty
+        + " += "
+        + ity
+        + " * "
+        + ity
+        + ", Size "
+        + str(M)
+        + "x"
+        + str(N)
+        + "x"
+        + str(K)
+        + ", Tile "
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(BLOCK_K)
+        + ", stages "
+        + str(max_num_stages)
+        + " --==="
+    )
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
-                             saveIR, max_num_stages)
+    engine = generate_matmul(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        use_warp_specilization,
+        saveIR,
+        max_num_stages,
+    )
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
     a = np.random.randn(M, K).astype(input_type)
     b = np.random.randn(K, N).astype(input_type)
     mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
     mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+    kernelName = (
+        "mlir_matmul_warpspecialized"
+        if use_warp_specilization
+        else "mlir_matmul_multistage"
+    )
 
     # Launch the MLIR generated kernel
     engine.invoke(kernelName, mem_a, mem_b, mem_c)
 
     float_formatter = "{:.2f}".format
-    np.set_printoptions(formatter={'float_kind': float_formatter})
+    np.set_printoptions(formatter={"float_kind": float_formatter})
 
     if print_results:
         print(c)
 
     # Verify the results
@@ -179,8 +243,24 @@
 
     print("PASS ")
 
 
 # GEMM Multistage       f32 += f16 * f16
-matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+matmul(
+    np.float16,
+    np.float32,
+    128,
+    128,
+    4096,
+    max_num_stages=3,
+    use_warp_specilization=False,
+)
 # GEMM Warp Specilized  f32 += f16 * f16
-matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
+matmul(
+    np.float16,
+    np.float32,
+    256,
+    1024,
+    512,
+    max_num_stages=3,
+    use_warp_specilization=True,
+)
--- tools/matmulBuilder.py	2024-02-12 13:33:26.000000 +0000
+++ tools/matmulBuilder.py	2024-02-12 13:36:46.754115 +0000
@@ -63,19 +63,21 @@
 def c(value, ty=None):
     ty = ir.IndexType.get() if ty is None else ty
     return arith.constant(ty, value)
 
 
-def generate_matmul_ws(input_type=np.float16,
-                       output_type=np.float32,
-                       M=4096,
-                       N=4096,
-                       K=4096,
-                       BLOCK_M=128,
-                       BLOCK_N=128,
-                       BLOCK_K=128,
-                       max_num_stages=3):
+def generate_matmul_ws(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=128,
+    max_num_stages=3,
+):
     # Limitaitons for now
     assert input_type == np.float16
     assert output_type == np.float32
     assert M % BLOCK_M == 0
     assert N % BLOCK_N == 0
@@ -97,29 +99,77 @@
     c_elem_ty = f16 if output_type == np.float16 else f32
     c_ty = ir.MemRefType.get((M, N), c_elem_ty)
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = (
+        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+    ) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
-                            str(num_stages) + ">")
-    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
-                           str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
+    mbar_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace = "
+        + str(smem_space)
+        + ", num_barriers = "
+        + str(num_stages)
+        + ">"
+    )
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    acc_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(output_type_str)
+        + ">>"
+    )
+    a_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_K)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    b_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
 
     with ir.InsertionPoint(module.body):
 
         @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
         def mlir_matmul_warpspecialized(a_host, b_host, c_host):
@@ -137,46 +187,71 @@
             t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
             t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [
+                (a_device, a_tma_desc_ty, a_tma_shape),
+                (b_device, b_tma_desc_ty, b_tma_shape),
+            ]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
-                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                )
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(
+                        tensor_map_ty, x_unranked, map(c, tile_shape)
+                    ).result
+                )
             a_tma_desc, b_tma_desc = tma_descs
-            
+
             # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
             cta_m = M // BLOCK_M
             cta_n = N // BLOCK_N
             assert M % BLOCK_M == 0 and N % BLOCK_N == 0
             grid = (cta_m, cta_n, 1)
             block = (WARP_GROUP_SIZE * 2, 1, 1)
-            launch_op = gpu.LaunchOp(token_ty, [t7],
-                                     *map(c, grid),
-                                     *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op = gpu.LaunchOp(
+                token_ty,
+                [t7],
+                *map(c, grid),
+                *map(c, block),
+                dynamicSharedMemorySize=c(smem_size, ty=i32)
+            )
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. This is need for vectorized ld/st
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                )
                 ticks = c(10000000)
 
                 # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups and etc.
                 tidx = gpu.thread_id(gpu.Dimension.x)
-                wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+                wgPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
+                )
                 warp_id = arith.divui(tidx, c(32))
                 warpgroup_id = arith.divui(warp_id, c(4))
-                is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
-                                         c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
-                is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
-                                         c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
-                producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
-                consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+                is_producer = arith.cmpi(
+                    arith.CmpIPredicate.eq,
+                    warpgroup_id,
+                    c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
+                )
+                is_consumer = arith.cmpi(
+                    arith.CmpIPredicate.eq,
+                    warpgroup_id,
+                    c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
+                )
+                producerPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
+                )
+                consumerPrimaryThread = arith.cmpi(
+                    arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
+                )
                 bidx = gpu.block_id(gpu.Dimension.x)
                 bidy = gpu.block_id(gpu.Dimension.y)
                 dimX = arith.muli(bidx, c(BLOCK_M))
                 dimY = arith.muli(bidy, c(BLOCK_N))
 
@@ -189,215 +264,338 @@
                 gpu.barrier()
 
                 # GPU Step 3. Prefetch TMA descriptors
                 nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
                 nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
-                
+
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
                 with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-                    
                     # Step 5.1. Reduce register size
-                    nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+                    nvvm.setmaxregister(
+                        PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
+                    )
 
                     # Step 5.2. TMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+                    for_op = scf.ForOp(
+                        c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)]
+                    )
                     with ir.InsertionPoint(for_op.body):
                         phaseParity = for_op.inner_iter_args[0]
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
 
                         # Step 5.2.1. Wait mbarDONE
-                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={}",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=producerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
-                        debug_print("[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=producerPrimaryThread)
+                        debug_print(
+                            "[prod] {}  | mbarDONE[{}] try_wait  phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.MBarrierTryWaitParityOp(
+                            mbarDONE, phaseParity, ticks, mbarId=stage
+                        )
+                        debug_print(
+                            "[prod] {}  | mbarDONE[{}] try_wait  phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=producerPrimaryThread,
+                        )
                         p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                        phaseParity = arith.select(
+                            p,
+                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity,
+                        )
 
                         # Step 5.2.2. Load TMA
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                                  dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset, [])
-                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset2, [])
-                        debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
-                                    a_offset,
-                                    b_offset,
-                                    b_offset2,
-                                    predicate=producerPrimaryThread)
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(
+                                a_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages),
+                        )
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
+                        b_offset2 = arith.addi(
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                        )
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset2,
+                            [],
+                        )
+                        debug_print(
+                            "[prod] a_offset={} b_offset={} b_offset2={}",
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            predicate=producerPrimaryThread,
+                        )
                         coord = arith.muli(c(64), iv)
-                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                             mbarTMA,
-                                             a_tma_desc,
-                                             coordinates=[coord, dimX],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY, coord],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(
+                            a_tma_slice,
+                            mbarTMA,
+                            a_tma_desc,
+                            coordinates=[coord, dimX],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_1,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY, coord],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
                         dimY2 = arith.addi(dimY, c(64))
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                             mbarTMA,
-                                             b_tma_desc,
-                                             coordinates=[dimY2, coord],
-                                             mbarId=stage,
-                                             predicate=producerPrimaryThread)
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_2,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY2, coord],
+                            mbarId=stage,
+                            predicate=producerPrimaryThread,
+                        )
 
                         # Step 5.2.3. Arrive mbarTMA
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
-                        debug_print("[prod] {}  | mbarTMA[{}] arrive [done]",
-                                    iv,
-                                    stage,
-                                    predicate=producerPrimaryThread)
+                        debug_print(
+                            "[prod] {}  | mbarTMA[{}] arrive",
+                            iv,
+                            stage,
+                            predicate=producerPrimaryThread,
+                        )
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
+                        )
+                        debug_print(
+                            "[prod] {}  | mbarTMA[{}] arrive [done]",
+                            iv,
+                            stage,
+                            predicate=producerPrimaryThread,
+                        )
                         scf.yield_([phaseParity])
                     scf.yield_([])
 
                 # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
                 if_op = scf.IfOp(is_consumer)
                 with ir.InsertionPoint(if_op.then_block):
-                    
                     # Step 6.1. Increase register size
-                    nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
-                    
+                    nvvm.setmaxregister(
+                        CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
+                    )
+
                     # GPU Step 6.2. Initialize MMA registers
                     acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
 
                     # Step 6.3. MMA Main Loop
-                    for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                    for_op = scf.ForOp(
+                        c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                    )
                     with ir.InsertionPoint(for_op.body):
                         # Step 6.3.1. Wait mbar1
                         phaseParity = for_op.inner_iter_args[1]
                         iv = for_op.induction_variable
                         stage = arith.remui(iv, c(num_stages))
-                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={}",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=consumerPrimaryThread)
-                        nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
-                        debug_print("[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
-                                    iv,
-                                    stage,
-                                    phaseParity,
-                                    predicate=consumerPrimaryThread)
-                        
+                        debug_print(
+                            "[cons] {}  | mbarTMA[{}] try_wait   phase={}",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread,
+                        )
+                        nvgpu.MBarrierTryWaitParityOp(
+                            mbarTMA, phaseParity, ticks, mbarId=stage
+                        )
+                        debug_print(
+                            "[cons] {}  | mbarTMA[{}] try_wait   phase={} [done]",
+                            iv,
+                            stage,
+                            phaseParity,
+                            predicate=consumerPrimaryThread,
+                        )
+
                         # Step 6.3.2. Create WGMMA Descriptors
                         a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                        a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
-                                                   dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
-                                                   dynamic_smem, b_offset, [])
-                        debug_print("[cons] a_offset={} b_offset={}",
-                                    a_offset,
-                                    b_offset,
-                                    predicate=consumerPrimaryThread)
-                        da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
-                        db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+                        a_tile_slice = memref.view(
+                            ir.MemRefType.get(
+                                a_tile_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
+                        b_offset = arith.addi(
+                            arith.muli(stage, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages),
+                        )
+                        b_tile_slice = memref.view(
+                            ir.MemRefType.get(
+                                b_tile_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
+                        debug_print(
+                            "[cons] a_offset={} b_offset={}",
+                            a_offset,
+                            b_offset,
+                            predicate=consumerPrimaryThread,
+                        )
+                        da = nvgpu.WarpgroupGenerateDescriptorOp(
+                            a_wgmma_ty, a_tile_slice, a_tma_desc
+                        )
+                        db = nvgpu.WarpgroupGenerateDescriptorOp(
+                            b_wgmma_ty, b_tile_slice, b_tma_desc
+                        )
 
                         # Step 6.3.3. MMA
                         carry_acc = for_op.inner_iter_args[0]
-                        new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+                        new_acc = nvgpu.WarpgroupMmaOp(
+                            acc.type, da, db, carry_acc, transposeB=True
+                        )
 
                         # Step 6.3.4. Arrive mbarDONE
                         p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
                         p_arrive = arith.andi(consumerPrimaryThread, p1)
                         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
                             p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
-                            barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+                            barId = arith.select(
+                                p, c(num_stages - 1), arith.subi(stage, c(1))
+                            )
                             remoteCtaId = c(0)
                             pred = consumerPrimaryThread
-                            debug_print("[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
-                                        iv,
-                                        barId,
-                                        remoteCtaId,
-                                        pred,
-                                        predicate=consumerPrimaryThread)
-                            nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
-                            debug_print("[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
-                                        iv,
-                                        barId,
-                                        remoteCtaId,
-                                        pred,
-                                        predicate=consumerPrimaryThread)
+                            debug_print(
+                                "[cons] {}  | mbarDONE[{}] arrive.mapa  pred={} ",
+                                iv,
+                                barId,
+                                remoteCtaId,
+                                pred,
+                                predicate=consumerPrimaryThread,
+                            )
+                            nvgpu.mbarrier_arrive(
+                                ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+                            )
+                            debug_print(
+                                "[cons] {}  | mbarDONE[{}] arrive  pred={} [done]",
+                                iv,
+                                barId,
+                                remoteCtaId,
+                                pred,
+                                predicate=consumerPrimaryThread,
+                            )
                             scf.yield_([])
 
                         p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                        phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                        phaseParity = arith.select(
+                            p,
+                            arith.xori(phaseParity, arith.constant(i1, 1)),
+                            phaseParity,
+                        )
 
                         # Step 6.3.5. Yield
                         scf.yield_([new_acc, phaseParity])
 
                     # Step 6.3. Wait All WGMMA
                     nvvm.WgmmaWaitGroupSyncOp(0)
 
                     with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
                         barId = c((K // BLOCK_K) % num_stages)
-                        nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+                        nvgpu.mbarrier_arrive(
+                            ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+                        )
                         scf.yield_([])
 
                     # Step 6.4. Epilogue (registers --> shared memory)
-                    acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                    acc_smem_ty = ir.MemRefType.get(
+                        (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+                    )
                     acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                     debug_print("[cons]  | Storing", predicate=consumerPrimaryThread)
                     nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                     scf.yield_([])
                 gpu.barrier()
 
                 # GPU Step 9. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                fd = ir.MemRefType.get(
+                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+                )
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
-                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
-                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
-                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
-                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N),
+                    c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+                )
+                c_device_per_block = memref.SubViewOp(
+                    rty,
+                    c_device,
+                    [dimX, dimY],
+                    [],
+                    [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                    [BLOCK_M, BLOCK_N],
+                    [1, 1],
+                )
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+                for_op = scf.ForOp(
+                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
+                )
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
-                                        [for_op.induction_variable])
+                    vdata = vector.load(
+                        ir.VectorType.get((vlen,), c_elem_ty),
+                        collapsed_smem,
+                        [for_op.induction_variable],
+                    )
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
 
                 gpu.terminator()
 
             # Step 4. Copy back to host
             t8 = gpu.wait(token_ty, [launch_op])
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.wait(token_ty, [t9])
 
-    mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    mlir_matmul_warpspecialized.func_op.attributes[
+        "llvm.emit_c_interface"
+    ] = ir.UnitAttr.get()
     module.operation.verify()
     return module
 
 
-def generate_matmul_multistage(input_type=np.float16,
-                               output_type=np.float32,
-                               M=4096,
-                               N=4096,
-                               K=4096,
-                               BLOCK_M=128,
-                               BLOCK_N=128,
-                               BLOCK_K=64,
-                               max_num_stages=3):
+def generate_matmul_multistage(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    max_num_stages=3,
+):
     # Limitaitons for now
     assert input_type == np.float16
     assert output_type == np.float32
     assert M % BLOCK_M == 0
     assert N % BLOCK_N == 0
@@ -419,29 +617,77 @@
     c_elem_ty = f16 if output_type == np.float16 else f32
     c_ty = ir.MemRefType.get((M, N), c_elem_ty)
     a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
     b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
     b_tile_shape = (BLOCK_K, BLOCK_N)
-    txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+    txcount = (
+        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+    ) * f16_byte
     smem_space_str = "#gpu.address_space<workgroup>"
     smem_space = ir.Attribute.parse(smem_space_str)
     input_type_str = "f16" if input_type == np.float16 else "f32"
     output_type_str = "f16" if output_type == np.float16 else "f32"
-    mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
-                            str(num_stages) + ">")
-    a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
-                                  str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
-                                  ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
-    acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
-                           str(output_type_str) + ">>")
-    a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
-    b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
-                               str(input_type_str) + ", " + smem_space_str + ">>")
+    mbar_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace = "
+        + str(smem_space)
+        + ", num_barriers = "
+        + str(num_stages)
+        + ">"
+    )
+    a_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    b_tma_desc_ty = ir.Type.parse(
+        "!nvgpu.tensormap.descriptor<tensor = memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(TMA_LAST_DIM_F16)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + str(smem_space)
+        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+    )
+    acc_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(output_type_str)
+        + ">>"
+    )
+    a_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_K)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
+    b_wgmma_ty = ir.Type.parse(
+        "!nvgpu.warpgroup.descriptor<tensor=memref<"
+        + str(BLOCK_K)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(input_type_str)
+        + ", "
+        + smem_space_str
+        + ">>"
+    )
 
     with ir.InsertionPoint(module.body):
 
         @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
         def mlir_matmul_multistage(a_host, b_host, c_host):
@@ -459,33 +705,46 @@
             t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
             t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+            tma_specs = [
+                (a_device, a_tma_desc_ty, a_tma_shape),
+                (b_device, b_tma_desc_ty, b_tma_shape),
+            ]
             tma_descs = []
             for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
-                tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+                x_unranked = memref.cast(
+                    ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+                )
+                tma_descs.append(
+                    nvgpu.TmaCreateDescriptorOp(
+                        tensor_map_ty, x_unranked, map(c, tile_shape)
+                    ).result
+                )
             a_tma_desc, b_tma_desc = tma_descs
 
             # Step 3. Launch Kernel with 1 Warpgroup
             cta_m = M // BLOCK_M
             cta_n = N // BLOCK_N
             assert M % BLOCK_M == 0 and N % BLOCK_N == 0
             grid = (cta_m, cta_n, 1)
             block = (WARP_GROUP_SIZE, 1, 1)
-            launch_op = gpu.LaunchOp(token_ty, [t7],
-                                     *map(c, grid),
-                                     *map(c, block),
-                                     dynamicSharedMemorySize=c(smem_size, ty=i32))
+            launch_op = gpu.LaunchOp(
+                token_ty,
+                [t7],
+                *map(c, grid),
+                *map(c, block),
+                dynamicSharedMemorySize=c(smem_size, ty=i32)
+            )
             launch_op.body.blocks.append(*([index] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
                 # GPU Step 0. Bootstrapping
                 memref.assume_alignment(c_device, 16)
                 dynamic_smem = gpu.dynamic_shared_memory(
-                    ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+                    ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+                )
                 ticks = c(10000000)
                 tidx = gpu.thread_id(gpu.Dimension.x)
                 primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
                 warpId = arith.divui(tidx, c(32))
                 bidx = gpu.block_id(gpu.Dimension.x)
@@ -502,175 +761,319 @@
                 # GPU Step 2. Prefetch TMA descriptors
                 nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
                 nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
-                for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+                for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
                 with ir.InsertionPoint(for_op.body):
                     iv = for_op.induction_variable
 
                     # Step 3.1. Calculate offsets
                     a_offset = arith.muli(iv, c(lhs_tile_bytes))
-                    a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                              dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                    b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, b_offset, [])
-                    b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                    b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, b_offset2, [])
+                    a_tma_slice = memref.view(
+                        ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(iv, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
+                    b_tma_slice_1 = memref.view(
+                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    b_offset2 = arith.addi(
+                        b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                    )
+                    b_tma_slice_2 = memref.view(
+                        ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset2,
+                        [],
+                    )
 
                     # Step 3.2. TMA Load
                     coord = arith.muli(c(64), iv)
                     dimY2 = arith.addi(dimY, c(64))
-                    debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                                a_offset,
-                                b_offset,
-                                b_offset2,
-                                coord, dimX,
-                                dimY, coord,
-                                predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                         mbarTMA,
-                                         a_tma_desc,
-                                         coordinates=[coord, dimX],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                         mbarTMA,
-                                         b_tma_desc,
-                                         coordinates=[dimY, coord],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
-                    nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                         mbarTMA,
-                                         b_tma_desc,
-                                         coordinates=[dimY2, coord],
-                                         mbarId=iv,
-                                         predicate=primaryThread)
+                    debug_print(
+                        "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                        a_offset,
+                        b_offset,
+                        b_offset2,
+                        coord,
+                        dimX,
+                        dimY,
+                        coord,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        a_tma_slice,
+                        mbarTMA,
+                        a_tma_desc,
+                        coordinates=[coord, dimX],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_1,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY, coord],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.TmaAsyncLoadOp(
+                        b_tma_slice_2,
+                        mbarTMA,
+                        b_tma_desc,
+                        coordinates=[dimY2, coord],
+                        mbarId=iv,
+                        predicate=primaryThread,
+                    )
 
                     # Step 3.2. mbarTMA arrive
-                    debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
-                    nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
-                    debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+                    debug_print(
+                        "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
+                    )
+                    nvgpu.mbarrier_arrive_expect_tx(
+                        mbarTMA, c(txcount), iv, predicate=primaryThread
+                    )
+                    debug_print(
+                        "[Prologue] mbarTMA[{}] arrive [done]",
+                        iv,
+                        predicate=primaryThread,
+                    )
                     scf.yield_([])
 
                 # GPU Step 4. Main Loop
                 acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
-                for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+                for_op = scf.ForOp(
+                    c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+                )
                 with ir.InsertionPoint(for_op.body):
                     # Step 4.1. Wait mbarTMA
                     phaseParity = for_op.inner_iter_args[1]
                     iv = for_op.induction_variable
                     stage = arith.remui(iv, c(num_stages))
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={}", stage, phaseParity, predicate=primaryThread)
-                    nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
-                    debug_print("[MainLoop] mbarTMA[{}] try_wait   phase={} [done]", stage, phaseParity, predicate=primaryThread)
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={}",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread,
+                    )
+                    nvgpu.MBarrierTryWaitParityOp(
+                        mbarTMA, phaseParity, ticks, mbarId=stage
+                    )
+                    debug_print(
+                        "[MainLoop] mbarTMA[{}] try_wait   phase={} [done]",
+                        stage,
+                        phaseParity,
+                        predicate=primaryThread,
+                    )
 
                     # Step 4.2. Create WGMMA Descriptors
                     a_offset = arith.muli(stage, c(lhs_tile_bytes))
-                    a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
-                                               dynamic_smem, a_offset, [])
-                    b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                    b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
-                                               dynamic_smem, b_offset, [])
-                    debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
-                    da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
-                    db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+                    a_tile_slice = memref.view(
+                        ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        a_offset,
+                        [],
+                    )
+                    b_offset = arith.addi(
+                        arith.muli(stage, c(rhs_tile_bytes)),
+                        c(lhs_tile_bytes * num_stages),
+                    )
+                    b_tile_slice = memref.view(
+                        ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+                        dynamic_smem,
+                        b_offset,
+                        [],
+                    )
+                    debug_print(
+                        "[MainLoop] iv={} MMA a_offset={} b_offset={}",
+                        iv,
+                        a_offset,
+                        b_offset,
+                        predicate=primaryThread,
+                    )
+                    da = nvgpu.WarpgroupGenerateDescriptorOp(
+                        a_wgmma_ty, a_tile_slice, a_tma_desc
+                    )
+                    db = nvgpu.WarpgroupGenerateDescriptorOp(
+                        b_wgmma_ty, b_tile_slice, b_tma_desc
+                    )
 
                     # Step 4.3. MMA
                     carry_acc = for_op.inner_iter_args[0]
-                    new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+                    new_acc = nvgpu.WarpgroupMmaOp(
+                        acc.type, da, db, carry_acc, transposeB=True
+                    )
 
                     # Step 4.4. Load TMA for next stage
-                    p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+                    p1 = arith.cmpi(
+                        arith.CmpIPredicate.ult,
+                        arith.addi(iv, c(num_stages - 1)),
+                        c(K // BLOCK_K),
+                    )
                     p = arith.andi(primaryThread, p1)
                     with ir.InsertionPoint(scf.IfOp(p).then_block):
-                        nextStage = arith.addi(iv, c(num_stages-1))
+                        nextStage = arith.addi(iv, c(num_stages - 1))
                         nextSlot = arith.remui(nextStage, c(num_stages))
                         a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
-                        a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
-                                                dynamic_smem, a_offset, [])
-                        b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
-                        b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset, [])
-                        b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
-                        b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
-                                                    dynamic_smem, b_offset2, [])
+                        a_tma_slice = memref.view(
+                            ir.MemRefType.get(
+                                a_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            a_offset,
+                            [],
+                        )
+                        b_offset = arith.addi(
+                            arith.muli(nextSlot, c(rhs_tile_bytes)),
+                            c(lhs_tile_bytes * num_stages),
+                        )
+                        b_tma_slice_1 = memref.view(
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset,
+                            [],
+                        )
+                        b_offset2 = arith.addi(
+                            b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+                        )
+                        b_tma_slice_2 = memref.view(
+                            ir.MemRefType.get(
+                                b_tma_shape, f16, memory_space=smem_space
+                            ),
+                            dynamic_smem,
+                            b_offset2,
+                            [],
+                        )
 
                         coord = arith.muli(c(64), nextStage)
-                        debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
-                                iv, a_offset,
-                                b_offset,
-                                b_offset2,
-                                coord, dimX,
-                                dimY, coord,
-                                predicate=primaryThread)
-                        nvgpu.TmaAsyncLoadOp(a_tma_slice,
-                                            mbarTMA,
-                                            a_tma_desc,
-                                            coordinates=[coord, dimX],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
-                                            mbarTMA,
-                                            b_tma_desc,
-                                            coordinates=[dimY, coord],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
+                        debug_print(
+                            "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+                            iv,
+                            a_offset,
+                            b_offset,
+                            b_offset2,
+                            coord,
+                            dimX,
+                            dimY,
+                            coord,
+                            predicate=primaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            a_tma_slice,
+                            mbarTMA,
+                            a_tma_desc,
+                            coordinates=[coord, dimX],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_1,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY, coord],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
                         dimY2 = arith.addi(dimY, c(64))
-                        nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
-                                            mbarTMA,
-                                            b_tma_desc,
-                                            coordinates=[dimY2, coord],
-                                            mbarId=nextSlot,
-                                            predicate=primaryThread)
-
-                        debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
-                        nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
-                        debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+                        nvgpu.TmaAsyncLoadOp(
+                            b_tma_slice_2,
+                            mbarTMA,
+                            b_tma_desc,
+                            coordinates=[dimY2, coord],
+                            mbarId=nextSlot,
+                            predicate=primaryThread,
+                        )
+
+                        debug_print(
+                            "[MainLoop] mbarTMA[{}] arrive",
+                            nextSlot,
+                            predicate=primaryThread,
+                        )
+                        nvgpu.mbarrier_arrive_expect_tx(
+                            mbarTMA, c(txcount), nextSlot, predicate=primaryThread
+                        )
+                        debug_print(
+                            "[MainLoop] mbarTMA[{}] arrive [done]",
+                            nextSlot,
+                            predicate=primaryThread,
+                        )
                         scf.yield_([])
                     # Step 4.5. Change the phaseParity
                     p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
-                    phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+                    phaseParity = arith.select(
+                        p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity
+                    )
 
                     # Step 4.5. Yield
                     scf.yield_([new_acc, phaseParity])
 
                 # Step 5. Wait All WGMMA groups
                 nvvm.WgmmaWaitGroupSyncOp(0)
                 gpu.barrier()
 
                 # Step 6. Epilogue (registers --> shared memory)
-                acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+                acc_smem_ty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+                )
                 acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
                 debug_print("Storing", predicate=primaryThread)
                 nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
                 gpu.barrier()
 
                 # GPU Step 7. Epilogue (shared memory --> global memory)
-                fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+                fd = ir.MemRefType.get(
+                    [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+                )
                 collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
-                rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
-                                        ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
-                c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
-                                                      [BLOCK_M, BLOCK_N], [1, 1])
+                rty = ir.MemRefType.get(
+                    (BLOCK_M, BLOCK_N),
+                    c_elem_ty,
+                    ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+                )
+                c_device_per_block = memref.SubViewOp(
+                    rty,
+                    c_device,
+                    [dimX, dimY],
+                    [],
+                    [],
+                    [MLIR_DYNAMIC, MLIR_DYNAMIC],
+                    [BLOCK_M, BLOCK_N],
+                    [1, 1],
+                )
                 vlen = 1
-                for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+                for_op = scf.ForOp(
+                    tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
+                )
                 with ir.InsertionPoint(for_op.body):
                     x = arith.divui(for_op.induction_variable, c(BLOCK_M))
                     y = arith.remui(for_op.induction_variable, c(BLOCK_N))
-                    vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
-                                        [for_op.induction_variable])
+                    vdata = vector.load(
+                        ir.VectorType.get((vlen,), c_elem_ty),
+                        collapsed_smem,
+                        [for_op.induction_variable],
+                    )
                     vector.store(vdata, c_device_per_block, [x, y])
                     scf.yield_([])
 
                 gpu.terminator()
 
             # Step 4. Copy back to host
             t8 = gpu.wait(token_ty, [launch_op])
             t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
             gpu.wait(token_ty, [t9])
 
-    mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    mlir_matmul_multistage.func_op.attributes[
+        "llvm.emit_c_interface"
+    ] = ir.UnitAttr.get()
     module.operation.verify()
     return module
--- tools/nvgpucompiler.py	2024-02-12 13:33:26.000000 +0000
+++ tools/nvgpucompiler.py	2024-02-12 13:36:46.799197 +0000
@@ -13,10 +13,11 @@
 import sys
 
 _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(_SCRIPT_PATH)
 
+
 class NvgpuCompiler:
     """Nvgpu class for compiling and building MLIR modules."""
 
     def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
         pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
@@ -29,11 +30,11 @@
         self.compile(module)
 
     def compile(self, module: ir.Module):
         """Compiles the module by invoking the nvgpu pipeline."""
         passmanager.PassManager.parse(self.pipeline).run(module.operation)
-    
+
     def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Wraps the module in a JIT execution engine."""
         return execution_engine.ExecutionEngine(
             module, opt_level=self.opt_level, shared_libs=self.shared_libs
         )

``````````

</details>


https://github.com/llvm/llvm-project/pull/81478


More information about the Mlir-commits mailing list