[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

Guray Ozen llvmlistbot at llvm.org
Fri Apr 12 11:20:02 PDT 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/87065

>From 5a2e21c9c071a8ba2b7a4cbbf68dc65889256efd Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 29 Mar 2024 13:09:45 +0000
Subject: [PATCH 01/14] [mlir][nvgpu] Zero to Hero: Programming Nvidia Hopper
 Tensor Core with MLIR's NVGPU Dialect

I am giving a tutorial on the nvgpu dialect using the Python bindings at EuroLLVM 2024, and I implemented the tutorial code in Python for it. The focus is the nvgpu dialect and how to use its advanced features. I thought it might be useful to upstream this.

The tutorial codes are as follows:
- **Ch0.py:** Hello World
- **Ch1.py:** 2D Saxpy
- **Ch2.py:** 2D Saxpy using TMA engine
- **Ch3.py:** GEMM 64x64x64 using Tensor Core and TMA
- **Ch4.py:** Multistage performant GEMM using Tensor Core and TMA

I might implement two more chapters, but they are more about GPU programming than about the compiler.
- **Ch5.py:** Warp Specialized GEMM
- **Ch6.py:** Warp Specialized Persistent GEMM

This PR also introduces the nvdsl class, making IR building in the tutorial easier.
---
 mlir/test/Examples/nvgpu/Ch0.py               |  46 ++
 mlir/test/Examples/nvgpu/Ch1.py               |  68 +++
 mlir/test/Examples/nvgpu/Ch2.py               |  94 ++++
 mlir/test/Examples/nvgpu/Ch3.py               |  92 ++++
 mlir/test/Examples/nvgpu/Ch4.py               | 291 +++++++++++++
 mlir/test/Examples/nvgpu/lit.local.cfg        |   3 +
 mlir/test/Examples/nvgpu/tools/lit.local.cfg  |   3 +
 mlir/test/Examples/nvgpu/tools/nvdsl.py       | 404 ++++++++++++++++++
 .../Examples/nvgpu/tools/nvgpucompiler.py     |  45 ++
 9 files changed, 1046 insertions(+)
 create mode 100644 mlir/test/Examples/nvgpu/Ch0.py
 create mode 100644 mlir/test/Examples/nvgpu/Ch1.py
 create mode 100644 mlir/test/Examples/nvgpu/Ch2.py
 create mode 100644 mlir/test/Examples/nvgpu/Ch3.py
 create mode 100644 mlir/test/Examples/nvgpu/Ch4.py
 create mode 100644 mlir/test/Examples/nvgpu/lit.local.cfg
 create mode 100644 mlir/test/Examples/nvgpu/tools/lit.local.cfg
 create mode 100644 mlir/test/Examples/nvgpu/tools/nvdsl.py
 create mode 100644 mlir/test/Examples/nvgpu/tools/nvgpucompiler.py

diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
new file mode 100644
index 00000000000000..221ca43d37307a
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -0,0 +1,46 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 0 : Hello World
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates Hello World
+#
+# This chapter demonstrates how to:
+#   1. Build MLIR function with arguments
+#   2. Build MLIR GPU kernel
+#   3. Print from a GPU thread
+#   4. Pass arguments, JIT compile and run the MLIR function
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir.dialects import gpu
+from tools.nvdsl import *
+
+
+# 1. Build function with arguments
+ at NVDSL.mlir_func
+def main(alpha):
+    # 2. Build GPU kernel: one block of 4 threads
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
+    def kernel():
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        # alpha is captured from the enclosing MLIR function's arguments
+        myValue = alpha + tidx
+        # 3. Print from a GPU thread
+        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
+
+    # Calling the decorated function emits the launch of the kernel above
+    kernel()
+
+
+# 4. Pass arguments, JIT compile and run the MLIR function
+alpha = 100
+main(alpha)
+
+
+# CHECK: GPU thread 0 has 100
+# CHECK: GPU thread 1 has 101
+# CHECK: GPU thread 2 has 102
+# CHECK: GPU thread 3 has 103
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
new file mode 100644
index 00000000000000..a888c61358a024
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -0,0 +1,68 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 1 : 2D Saxpy
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy
+#
+# This chapter demonstrates how to:
+#  1. Use MLIR GPU dialect to allocate and copy memory
+#  2. Compute 2D SAXPY kernel
+#  3. Pass numpy arrays to MLIR
+#  4. Verify MLIR with reference computation
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, memref
+from tools.nvdsl import *
+import numpy as np
+
+
+ at NVDSL.mlir_func
+def saxpy(x, y, alpha):
+    # 1. Use MLIR GPU dialect to allocate and copy memory
+    # Each async op takes the previous token as a dependency, so allocation
+    # and the host-to-device copies are ordered; t6 waits for all of them.
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+    t6 = gpu.wait(token_ty, [t5])
+
+    # 2. Compute 2D SAXPY kernel
+    # One block per row (M blocks) and one thread per column (N threads);
+    # M and N are module globals assigned below before saxpy() is called.
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
+    def saxpy_kernel():
+        bidx = gpu.block_id(gpu.Dimension.x)
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        x_val = memref.load(x_dev, [bidx, tidx])
+        y_val = memref.load(y_dev, [bidx, tidx])
+
+        # SAXPY: y[i] += a * x[i];
+        y_val += x_val * alpha
+
+        memref.store(y_val, y_dev, [bidx, tidx])
+
+    saxpy_kernel()
+
+    # Copy the result back into the host array y.
+    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+    gpu.wait(token_ty, [t7])
+
+
+# 3. Pass numpy arrays to MLIR
+M = 256
+N = 32
+alpha = 2.0
+x = np.ones((M, N), np.float32)
+y = np.ones((M, N), np.float32)
+ref = np.ones((M, N), np.float32)
+saxpy(x, y, alpha)
+
+#  4. Verify MLIR with reference computation
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
new file mode 100644
index 00000000000000..caa2691a06d69e
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -0,0 +1,94 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 2 : 2D Saxpy with TMA
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
+# but it loads data using TMA (Tensor Memory Accelerator)
+#
+# This chapter demonstrates how to:
+#  1. Create and initialize asynchronous transactional barrier (mbarrier)
+#  2. Execute Tensor Memory Accelerator (TMA) Load
+#  3. Wait for completion of TMA load with mbarrier
+#
+# ===----------------------------------------------------------------------===//
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir import runtime as rt
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy_tma(x, y, alpha):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+    t6 = gpu.wait(token_ty, [t5])
+
+    x_tma = TMA((M, N), x.type)
+    y_tma = TMA((M, N), y.type)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    # Two M x N f32 tiles of M * N * 4 bytes each; with M=256, N=32 they fill
+    # the 65536 bytes of dynamic shared memory exactly.
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=65536)
+    def saxpy_tma_kernel():
+        bidx = gpu.block_id(gpu.Dimension.x)
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        isThread0 = tidx == 0
+
+        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
+        mbar_group = Mbarriers(number_of_barriers=1)
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            mbar_group[0].init(1)
+            x_tma.prefetch()
+            y_tma.prefetch()
+            scf.yield_([])
+        # Only thread 0 initialized the mbarrier; make that visible to every
+        # thread before any of them waits on it.
+        gpu.barrier()
+
+        # y must start right after the whole x tile (M * N f32 = M * N * 4
+        # bytes). A smaller offset would make the two concurrent TMA loads
+        # write overlapping shared memory and corrupt x.
+        x_smem = get_dynamic_shared_memory((M, N), T.f32())
+        y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 4)
+
+        # 2. Execute Tensor Memory Accelerator (TMA) Load
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            x_tma.load(x_smem, mbar_group[0])
+            y_tma.load(y_smem, mbar_group[0])
+            # Expected bytes: 2 tiles x M x N elements x 4 bytes (f32)
+            mbar_group[0].arrive(txcount=M * N * 2 * 4)
+            scf.yield_([])
+
+        # 3. Wait for completion of TMA load with mbarrier
+        mbar_group[0].try_wait()
+
+        x_val = memref.load(x_smem, [bidx, tidx])
+        y_val = memref.load(y_smem, [bidx, tidx])
+
+        # SAXPY: y[i] += a * x[i];
+        y_val += x_val * alpha
+
+        memref.store(y_val, y_dev, [bidx, tidx])
+
+    saxpy_tma_kernel()
+
+    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+    gpu.wait(token_ty, [t7])
+
+
+M = 256
+N = 32
+alpha = 2.0
+# x is random (unlike y) so that any corruption of the x tile in shared
+# memory shows up in the comparison below.
+x = np.random.randn(M, N).astype(np.float32)
+y = np.ones((M, N), np.float32)
+ref = np.ones((M, N), np.float32)
+saxpy_tma(x, y, alpha)
+
+#  4. Verify MLIR with reference computation
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
new file mode 100644
index 00000000000000..802cb59ead1555
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -0,0 +1,92 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 3 : GEMM 64x64x64 with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation with 64x64x64 matrix multiplication
+#
+# This chapter demonstrates how to:
+# 1. Execute TMA Load for two input matrices
+# 2. Perform Tensor Core GEMM 64x64x64 by warpgroup
+# 3. Store fragmented registers to global memory by warpgroup
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def gemm_64_64_64(x, y, z):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    x_tma = TMA([N, N], x.type, swizzle=sw)
+    y_tma = TMA([N, N], y.type, swizzle=sw)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    # Two N x N f16 tiles (N * N * 2 bytes each) = 16384 bytes for N=64.
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
+    def gemm_tma_kernel():
+        tidx = gpu.thread_id(gpu.Dimension.x)
+
+        mbar_group = Mbarriers(number_of_barriers=1)
+        isThread0 = tidx == 0
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            mbar_group[0].init(1)
+            x_tma.prefetch()
+            y_tma.prefetch()
+            scf.yield_([])
+        # Warps other than warp 0 skip the if above; without this barrier
+        # they could reach try_wait before thread 0 initialized the mbarrier.
+        gpu.barrier()
+
+        x_smem = get_dynamic_shared_memory((N, N), T.f16())
+        y_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+
+        # 1. Execute TMA Load for two input matrices
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            x_tma.load(x_smem, mbar_group[0])
+            y_tma.load(y_smem, mbar_group[0])
+            tx_count = get_type_size(x_tma.tma_memref) + get_type_size(y_tma.tma_memref)
+            mbar_group[0].arrive(tx_count)
+            scf.yield_([])
+
+        mbar_group[0].try_wait()
+
+        # 2. Perform Tensor Core GEMM 64x64x64 by warpgroup
+        A = Matrix(x_smem, x_tma, N, N)
+        B = Matrix(y_smem, y_tma, N, N)
+        C = MatrixAccumulator(N, N, T.f32()).op()
+        D = Matrix.matmul(A, B, C)
+
+        # 3. Store fragmented registers to global memory by warpgroup
+        nvgpu.warpgroup_mma_store(D, z_dev)
+
+    gemm_tma_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    gpu.wait(None, [t8])
+
+
+# Pass numpy arrays to MLIR
+N = 64
+x = np.random.randn(N, N).astype(np.float16)
+y = np.random.randn(N, N).astype(np.float16)
+z = np.zeros((N, N), np.float32)
+gemm_64_64_64(x, y, z)
+
+# The kernel accumulates in f32, so compare against a float32 matmul; a
+# float16 reference would add its own rounding error to the comparison.
+ref = x.astype(np.float32) @ y.astype(np.float32)
+np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
new file mode 100644
index 00000000000000..d222d6e60aafad
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -0,0 +1,291 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a multistage (software-pipelined) GEMM, tiled into
+# 128x128x64 (TILE_M x TILE_N x TILE_K) blocks
+#
+# This chapter demonstrates how to:
+#  1. Partition shape based on block IDs
+#  2. Prologue
+#    2.1 Execute TMA Load for two input matrices for each stage
+#  3. Main loop
+#    3.1 Wait for completion of TMA load with mbarrier
+#    3.2 Perform Tensor Core GEMM 128x128x64 by warpgroup
+#    3.3 Load next stage if needed
+#  4. Epilogue
+#    4.1 Store fragmented registers to shared memory
+#    4.2 Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+    """
+    Calculate this thread block's output-tile offsets from the block IDs.
+
+    It partitions the shape like below:
+    for(.. i < M ...)   --> blockIdx.x
+     for(.. j < N ...)  --> blockIdx.y
+      for(.. k < K ...)
+
+    Emits gpu.block_id ops, so it must be called while building a kernel
+    body. Reads the module-level TILE_M / TILE_N globals.
+
+    Returns:
+        dimX: MLIR index value, blockIdx.x * TILE_M (row offset into M).
+        dimY: MLIR index value, blockIdx.y * TILE_N (column offset into N).
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    x_tma: TMA,
+    y_tma: TMA,
+    slot,
+    stage,
+    p=None,
+):
+    """
+    Issue the TMA loads that fill one pipeline slot with the tiles of `stage`.
+
+    Loads from global into dynamic shared memory at the slot's offsets:
+       - one x tile (x_tma shape, 128x64) at coordinate [k, dimX]
+       - two y tiles (y_tma shape, 64x64 each) at coordinates [dimY, k] and
+         [dimY + y tile width, k], which together form the B tile
+
+    It also arms mbar_group[slot] with the expected transaction byte count,
+    size(x tile) + 2 * size(y tile) (= 128x64x2 + 2x64x64x2 bytes for f16).
+
+    Args:
+        mbar_group: mbarrier group; barrier `slot` tracks these loads.
+        x_tma, y_tma: TMA descriptors of the two input matrices.
+        slot: pipeline slot in shared memory (int or index value).
+        stage: K-iteration whose tiles are loaded; k = stage * TILE_K.
+        p: optional i1 predicate; defaults to "thread 0 only".
+    """
+    dimX, dimY = partition_shape()
+
+    size_tma_x = get_type_size(x_tma.tma_memref)
+    size_tma_y = get_type_size(y_tma.tma_memref)
+    # All x slots come first in shared memory, followed by all y slots.
+    begin_y = NUM_STAGES * size_tma_x
+    tx_count = size_tma_x + (size_tma_y * 2)
+
+    # Only one thread issues the loads unless the caller supplies a predicate.
+    if p is None:
+        p = gpu.thread_id(gpu.Dimension.x) == 0
+
+    # NOTE: y slots use the x tile size as their stride. That equals the
+    # size of two y tiles only because TILE_M == TILE_N; mainloop() uses the
+    # same layout, so keep both in sync if the tile shapes change.
+    off_x = slot * size_tma_x
+    off_y = (slot * size_tma_x) + begin_y
+    off_y2 = off_y + size_tma_y
+    x = get_dynamic_shared_memory(
+        x_tma.tma_memref.shape, x_tma.tma_memref.element_type, off_x
+    )
+    y1 = get_dynamic_shared_memory(
+        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y
+    )
+    y2 = get_dynamic_shared_memory(
+        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y2
+    )
+
+    mbar_group[slot].arrive(tx_count, predicate=p)
+
+    # Coordinates are listed innermost (contiguous) dimension first.
+    k = stage * TILE_K
+    x_tma.load(x, mbar_group[slot], coords=[k, dimX], predicate=p)
+    y_tma.load(y1, mbar_group[slot], coords=[dimY, k], predicate=p)
+    # The second y tile starts one y-tile width to the right of the first.
+    y_tma.load(
+        y2, mbar_group[slot], coords=[dimY + y_tma.tma_shape[1], k], predicate=p
+    )
+
+
+def bootstrap(x_tma: TMA, y_tma: TMA):
+    """
+    Initialize mbarriers and prefetch TMA descriptors.
+
+    Creates NUM_STAGES mbarriers. Thread 0 initializes each one with an
+    arrival count of 1 and prefetches both TMA descriptors. A CTA-wide
+    barrier then makes the initialized mbarriers visible to every thread
+    before any of them arrives on or waits for one.
+
+    Returns:
+        The Mbarriers group (one barrier per pipeline stage).
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    mbar_group = Mbarriers(number_of_barriers=NUM_STAGES)
+    isThread0 = tidx == const(0)
+    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+        for i in scf.for_(0, NUM_STAGES, 1):
+            mbar_group[i].init(1)
+            scf.yield_([])
+        x_tma.prefetch()
+        y_tma.prefetch()
+        scf.yield_([])
+
+    # The other warps skip the scf.if above. Without this barrier they could
+    # reach an mbarrier try_wait in the main loop before thread 0 has
+    # initialized it.
+    gpu.barrier()
+
+    return mbar_group
+
+
+def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+    """
+    Prologue of the GEMM kernel: fill the pipeline before the main loop.
+
+    It issues tma_load for the first `ns` stages, where ns = NUM_STAGES - 1
+    (or 1 when NUM_STAGES == 1):
+
+    for stage in range(ns):
+        tma_load x, y, slot=stage, stage=stage
+
+    One slot is left empty. Each main-loop iteration then refills the slot
+    consumed by the previous iteration with stage iv + ns.
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+    for iv in scf.for_(0, ns, 1):
+        tma_load(mbar_group, x_tma, y_tma, iv, iv)
+        scf.yield_([])
+
+
+def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+    """
+    Main loop of the Multistage GEMM kernel. It iterates over the K tiles,
+    multiplying the tiles already in shared memory while TMA loads later
+    stages. It works like the following:
+
+    MatrixAccumulator D
+    for k in range(K // TILE_K):
+
+        try_wait(stage, ...)    # Wait TMA load
+
+        Matrix A(stage, ...)    # Find shared memory slot
+        Matrix B(stage, ...)    # Find shared memory slot
+        D += A @ B              # Multiply and accumulate
+
+        if(needLoad)            # Load next stage if needed
+            tma_load(x, y, nextSlot, nextStage)
+
+    Returns:
+        The final accumulator value (result 0 of the scf.for).
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    # Start of the y slots; all NUM_STAGES x slots come first.
+    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+
+    # Bytes per x slot (TILE_M x TILE_K f16); also used as the y slot stride.
+    size_x = TILE_M * TILE_K * get_type_size(T.f16())
+
+    C = MatrixAccumulator(TILE_M, TILE_N, T.f32()).op()
+    # Loop-carried mbarrier phase parity; starts at phase 0 (False).
+    pp = const(False, ty=T.bool())
+
+    # Main Loop
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C, pp])
+    with ir.InsertionPoint(for_op.body):
+        pp = for_op.inner_iter_args[1]
+        iv = for_op.induction_variable
+        stage = iv % NUM_STAGES
+
+        # Wait for current stage
+        mbar_group[stage].try_wait(phase=pp)
+
+        # Find shared memory slot
+        offset_x = stage * size_x
+        offset_y = offset_x + begin_y
+        x_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_x)
+        y_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_y)
+
+        # Matrix Multiply
+        A = Matrix(x_smem, x_tma, TILE_M, TILE_K)
+        B = Matrix(y_smem, y_tma, TILE_K, TILE_N)
+        C = for_op.inner_iter_args[0]
+        D = Matrix.matmul(A, B, C)
+        # With a single slot, the load issued below overwrites the slot just
+        # read, so wait for the warpgroup MMA to finish first.
+        if NUM_STAGES == 1:
+            nvvm.WgmmaWaitGroupSyncOp(0)
+
+        # Load next stage
+        # NOTE(review): for NUM_STAGES > 1, nextSlot is the slot read by the
+        # previous iteration's MMA. Confirm that the warpgroup MMA lowering
+        # has completed that read before the TMA overwrite (or add a wait).
+        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
+        nextStage = iv + ns
+        nextSlot = nextStage % NUM_STAGES
+        tma_load(mbar_group, x_tma, y_tma, nextSlot, nextStage, pred)
+
+        # Switch phase parity for the mbarrier
+        # Flip once per full pass over the slots, i.e. after the last slot.
+        switched = pp ^ const(True, ty=T.bool())
+        newPP = arith.select(
+            stage == (NUM_STAGES - 1),
+            switched,
+            pp,
+        )
+        scf.yield_([D, newPP])
+
+    # Drain outstanding warpgroup MMAs before the accumulator is used.
+    nvvm.WgmmaWaitGroupSyncOp(0)
+
+    return for_op.results[0]
+
+
+def epilogue(D, z_dev):
+    """
+    Epilogue of the GEMM kernel. It stores the fragmented accumulator
+    registers to global memory, staging them through shared memory.
+
+    MatrixAccumulator D               # Fragmented results
+    store D -> Shared Memory          # Store Shared Memory
+    Shared Memory -> Z[dimX][dimY]    # Store Shared Memory to Global Memory
+
+    Args:
+        D: final warpgroup accumulator (TILE_M x TILE_N f32).
+        z_dev: device memref of the output matrix.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    dimX, dimY = partition_shape()
+
+    # Offset 0: reuses the shared memory that held the x pipeline slots.
+    z_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    # This block's TILE_M x TILE_N window of the output.
+    z_gmem = memref.subview(z_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+    # Store (registers -> shared memory)
+    nvgpu.WarpgroupMmaStoreOp(D, z_smem)
+    gpu.barrier()
+
+    # Store (shared memory --> global memory)
+    # Each thread copies column `tidx` row by row. This covers the tile only
+    # if blockDim.x == TILE_N (128 == 128 as launched); keep them in sync.
+    for i in scf.for_(0, TILE_M, 1):
+        val = memref.load(z_smem, [i, tidx])
+        memref.store(val, z_gmem, [i, tidx])
+        scf.yield_([])
+
+
+ at NVDSL.mlir_func
+def gemm_multistage(x, y, z):
+    """Host entry point of the multistage GEMM: z = x @ y.
+
+    Copies x and y to the device, launches one 128-thread block per
+    TILE_M x TILE_N output tile, and copies z back to the host.
+    """
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    t7 = gpu.wait(token_ty, [t6])
+
+    # x tile = TILE_M x TILE_K; y is loaded as two 64x64 halves of the
+    # TILE_K x TILE_N tile (see tma_load).
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    x_tma = TMA([128, 64], x.type, swizzle=sw)
+    y_tma = TMA([64, 64], y.type, swizzle=sw)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    grid = [(M // TILE_M), (N // TILE_N), 1]
+    block = [128, 1, 1]
+    # NOTE(review): smem is hard-coded. The pipeline needs NUM_STAGES x
+    # (16384 + 16384) = 229376 bytes for 7 stages; 229440 adds 64 unexplained
+    # bytes and does not track NUM_STAGES. Confirm the intent.
+    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
+    def gemm_multistage_kernel():
+        # Initialize mbarriers and prefetch TMA descriptors
+        mbar_group = bootstrap(x_tma, y_tma)
+
+        # Fill the pipeline stages
+        prologue(mbar_group, x_tma, y_tma)
+
+        # Main loop
+        D = mainloop(mbar_group, x_tma, y_tma)
+
+        # Store registers to global memory
+        epilogue(D, z_dev)
+
+    gemm_multistage_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    gpu.wait(None, [t8])
+
+
+# Problem size and tiling; read as globals by the kernel builders above.
+NUM_STAGES = 7
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+# The grid and the main-loop trip count use floor division, so a
+# non-multiple size would silently leave part of z uncomputed.
+assert M % TILE_M == 0 and N % TILE_N == 0 and K % TILE_K == 0
+x = np.random.randn(M, K).astype(np.float16)
+y = np.random.randn(K, N).astype(np.float16)
+z = np.zeros((M, N), np.float32)
+
+gemm_multistage(x, y, z)
+
+
+# Verify MLIR with reference computation. The kernel accumulates in f32, so
+# compare against a float32 matmul rather than a float16 one.
+ref = x.astype(np.float32) @ y.astype(np.float32)
+np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/lit.local.cfg b/mlir/test/Examples/nvgpu/lit.local.cfg
new file mode 100644
index 00000000000000..e586b55573898a
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/lit.local.cfg
@@ -0,0 +1,3 @@
+config.unsupported = not (
+  config.enable_cuda_runner and config.mlir_run_cuda_sm90_tests
+)
\ No newline at end of file
diff --git a/mlir/test/Examples/nvgpu/tools/lit.local.cfg b/mlir/test/Examples/nvgpu/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are helper modules imported by the chapter scripts
+# (e.g. `from tools.nvdsl import *`), not tests, so lit must not run them.
+config.unsupported = True
+
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
new file mode 100644
index 00000000000000..4c774f4d4deac9
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -0,0 +1,404 @@
+from enum import Enum
+import functools, sys, ctypes, os, errno
+import numpy as np
+from functools import partialmethod
+from mlir import ir
+from mlir.dialects import arith, func, gpu, memref, nvgpu
+from mlir.extras import types as T
+from mlir import runtime as rt
+from tools import nvgpucompiler
+
+# NOTE(review): not referenced in the visible part of this file.
+DEBUG = True
+# Dynamic-dimension sentinel (INT64_MIN) for MemRefType shapes; used for the
+# unsized dynamic shared memory buffer in get_dynamic_shared_memory.
+# Presumably mirrors ir.ShapedType.get_dynamic_size() — confirm.
+MLIR_DYNAMIC = -9223372036854775808
+
+
+def const(value, ty=None):
+    """Return `value` as an MLIR SSA value.
+
+    An existing `ir.Value` (of any type: index, i1, f32, ...) is returned
+    unchanged. A Python scalar is materialized with `arith.constant` of type
+    `ty`, which defaults to `index`.
+    """
+    if isinstance(value, ir.Value):
+        # Pass SSA values through regardless of type. Arithmetic helpers call
+        # const() on float operands too and then inspect the result's type.
+        return value
+    ty = T.index() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def get_type_size(ty):
+    """Return the size in bytes of an MLIR type.
+
+    A memref reports its element size multiplied by every dimension of its
+    shape. Float and integer types report width // 8. Any other type raises
+    NotImplementedError.
+    """
+    if ir.MemRefType.isinstance(ty):
+        total = get_type_size(ty.element_type)
+        for dim in ty.shape:
+            total *= dim
+        return total
+    for scalar_cls in (ir.FloatType, ir.IntegerType):
+        if scalar_cls.isinstance(ty):
+            return scalar_cls(ty).width // 8
+    raise NotImplementedError(ty)
+
+def get_mlir_func_obj_ty(inputArgs):
+    """Wrap Python arguments as ctypes objects for an MLIR JIT invocation.
+
+    Scalars are stored in one-element ctypes arrays so they are passed by
+    reference: bool -> c_bool, int -> c_int64, float -> c_float. A numpy
+    array becomes a pointer to a pointer to its ranked memref descriptor.
+
+    Raises:
+        NotImplementedError: for any other argument type.
+    """
+    # Python ints are typed as MLIR `index` (see get_mlir_ty), which is 8
+    # bytes on a 64-bit host. A 4-byte c_int would be read past its end and
+    # would lose the sign of negative values.
+    c_int_p = ctypes.c_int64 * 1
+    c_float_p = ctypes.c_float * 1
+    c_bool_p = ctypes.c_bool * 1
+    args = []
+    for arg in inputArgs:
+        # bool must be tested before int because bool subclasses int.
+        if isinstance(arg, bool):
+            args.append(c_bool_p(arg))
+        elif isinstance(arg, int):
+            args.append(c_int_p(arg))
+        elif isinstance(arg, float):
+            args.append(c_float_p(arg))
+        elif isinstance(arg, np.ndarray):
+            args.append(
+                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
+            )
+        else:
+            raise NotImplementedError(arg)
+    return args
+
+
+class Mbarriers:
+    """A group of `number_of_barriers` mbarriers in workgroup memory.
+
+    `group[i]` returns a handle bound to barrier `i`. init / arrive /
+    try_wait act on the barrier selected by that handle.
+    """
+
+    def __init__(self, number_of_barriers=1):
+        self.mbar_ty = ir.Type.parse(
+            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
+            + str(number_of_barriers)
+            + ">"
+        )
+        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
+        self.number_of_barriers = number_of_barriers
+
+    def __getitem__(self, key):
+        """Return a handle to barrier `key` (a Python int or an index value).
+
+        The handle is a shallow copy that shares `mbar_group_op`, so handles
+        to different barriers no longer overwrite each other's id.
+        """
+        import copy
+
+        handle = copy.copy(self)
+        handle.id_op = const(key)
+        # Also record the last-selected id on the group itself, for callers
+        # that read `group.id_op` after indexing.
+        self.id_op = handle.id_op
+        return handle
+
+    def init(self, count: int, predicate=None):
+        """Initialize the selected barrier with an expected arrival `count`."""
+        count_op = const(count)
+        if predicate is None:
+            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
+        else:
+            nvgpu.mbarrier_init(
+                self.mbar_group_op, count_op, self.id_op, predicate=predicate
+            )
+
+    def arrive(self, txcount: int = 0, predicate=None):
+        """Arrive on the selected barrier.
+
+        A non-zero `txcount` uses arrive-expect-tx to also announce that
+        many transaction bytes (e.g. from TMA loads).
+        """
+        if txcount != 0:
+            txcount_op = const(txcount)
+            nvgpu.mbarrier_arrive_expect_tx(
+                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
+            )
+        else:
+            nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op, predicate=predicate)
+
+    def try_wait(self, phase: bool = False, ticks: int = 10000000):
+        """Wait on the selected barrier for phase parity `phase`.
+
+        `phase` may be a Python bool or an i1 value. `ticks` is forwarded as
+        the ticks operand of the try-wait-parity op.
+        """
+        ticks_op = const(ticks)
+        phase_op = const(phase, T.bool())
+        nvgpu.MBarrierTryWaitParityOp(
+            self.mbar_group_op,
+            phase_op,
+            ticks_op,
+            mbarId=self.id_op,
+        )
+
+
+class TMA:
+    """A class that builds a TMA descriptor.
+
+    Holds the tile shape and tensor-map options for one global memref and
+    emits the nvgpu ops that create, prefetch, and load through the
+    descriptor. `create_descriptor` must be called before `prefetch`/`load`.
+    """
+
+    def __init__(
+        self,
+        shape,
+        memref_ty,
+        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
+        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
+        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+    ):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.shape = shape
+        self.memref_ty = memref_ty  # MemRefType of the global tensor
+        self.lastDim = 64
+        self.requiredLoad = 1
+        self.tma_shape = shape  # tile shape passed to the descriptor
+        # Type of one loaded tile; callers use it for shared-memory byte counts.
+        self.tma_memref = ir.MemRefType.get(shape, memref_ty.element_type)
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type (any tile rank)."""
+        dims = "x".join(str(d) for d in self.tma_shape)
+        memref_str = f"memref<{dims}x{self.memref_ty.element_type}, 3>"
+        parse_str = (
+            f"!nvgpu.tensormap.descriptor<tensor = {memref_str}, "
+            f"swizzle = {self.swizzle}, "
+            f"l2promo = {self.l2promo}, "
+            f"oob = {self.oob}, "
+            f"interleave = {self.interleave}>"
+        )
+        return ir.Type.parse(parse_str)
+
+    def create_descriptor(self, device_ptr):
+        """Emit the TMA descriptor for `device_ptr` and return its value."""
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(self.memref_ty.element_type,
+                                      self.memref_ty.memory_space),
+            device_ptr,
+        )
+        # Materialize the operands as a list rather than a lazy map iterator.
+        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref,
+            [const(d) for d in self.tma_shape])
+        return self.tma_descriptor.result
+
+    def prefetch(self, predicate=None):
+        """Prefetch the descriptor created by `create_descriptor`."""
+        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
+
+    def load(self, dest, mbarrier: Mbarriers, coords=(0, 0), predicate=None):
+        """Issue an async TMA load of the tile at `coords` into `dest`.
+
+        Completion is tracked by the barrier selected in `mbarrier` (e.g.
+        `group[i]`). `coords` may mix Python ints and index values.
+        """
+        coord_ops = [const(c) for c in coords]
+        nvgpu.TmaAsyncLoadOp(
+            dest,
+            mbarrier.mbar_group_op,
+            self.tma_descriptor,
+            coordinates=coord_ops,
+            mbarId=mbarrier.id_op,
+            predicate=predicate,
+        )
+
+
+class MatrixAccumulator:
+    """Warpgroup MMA accumulator of shape M x N with element type `ty`."""
+
+    def __init__(self, M, N, ty):
+        self.M, self.N, self.ty = M, N, ty
+
+    @property
+    def acc_ty(self):
+        """The `!nvgpu.warpgroup.accumulator` type for this shape/element."""
+        fragmented = f"vector<{self.M}x{self.N}x{self.ty}>"
+        return ir.Type.parse(
+            f"!nvgpu.warpgroup.accumulator<fragmented={fragmented}>"
+        )
+
+    def op(self):
+        """Emit warpgroup_mma_init_accumulator for `acc_ty`; return its value."""
+        return nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+
+
+class Matrix:
+    """A shared-memory matrix tile used as a warpgroup MMA operand.
+
+    `smem` is the shared-memory buffer, `tma_descriptor` the TMA object that
+    loaded it (its element type and descriptor are reused), and M x N is the
+    tile shape.
+    """
+
+    def __init__(self, smem, tma_descriptor: TMA, M, N):
+        self.smem = smem
+        self.tma_descriptor = tma_descriptor
+        self.M, self.N = M, N
+
+    @property
+    def wgmma_ty(self):
+        """The `!nvgpu.warpgroup.descriptor` type for this tile."""
+        elem_ty = self.tma_descriptor.memref_ty.element_type
+        tensor = f"memref<{self.M}x{self.N}x{elem_ty}, #gpu.address_space<workgroup>>"
+        return ir.Type.parse(f"!nvgpu.warpgroup.descriptor<tensor={tensor}>")
+
+    def matmul(lhs, rhs, acc):
+        """Emit one warpgroup MMA of `lhs` x `rhs` into `acc` (B transposed).
+
+        Callable as `Matrix.matmul(A, B, C)` or `A.matmul(B, C)`.
+        """
+        lhs_desc, rhs_desc = [
+            nvgpu.warpgroup_generate_descriptor(
+                operand.wgmma_ty, operand.smem, operand.tma_descriptor.tma_descriptor
+            )
+            for operand in (lhs, rhs)
+        ]
+        return nvgpu.WarpgroupMmaOp(
+            acc.type, lhs_desc, rhs_desc, acc, transposeB=True
+        )
+
+def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
+    """Return the kernel's dynamic shared memory, optionally as a typed view.
+
+    Without `shape`, returns the raw 1-D i8 workgroup buffer. Otherwise
+    returns a `memref.view` of `shape` x `ty` that starts `offset` i8
+    elements (bytes) into that buffer; `offset` may be an int or index value.
+    """
+    workgroup = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    raw_ty = ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=workgroup)
+    raw = gpu.dynamic_shared_memory(raw_ty)
+    if shape is None:
+        return raw
+    view_ty = ir.MemRefType.get(shape, ty, memory_space=workgroup)
+    return memref.view(view_ty, raw, const(offset), [])
+
+def get_mlir_ty(arg):
+    """Return the MLIR type used for a Python/numpy argument.
+
+    bool -> i1, int -> index, float -> f32, numpy array -> memref with the
+    array's shape and element type.
+
+    Raises:
+        NotImplementedError: for an unsupported argument type or dtype.
+    """
+    # Plain module-level function: a module-level @staticmethod object is
+    # not callable before Python 3.10.
+
+    def get_mlir_ty_from_np(dtype):
+        if dtype == np.float16:
+            return T.f16()
+        if dtype == np.float32:
+            return T.f32()
+        if dtype == np.float64:
+            return T.f64()
+        if dtype == np.int32:
+            return T.i32()
+        if dtype == np.int64:
+            return T.i64()
+        raise NotImplementedError(dtype)
+
+    # bool must be tested before int because bool subclasses int.
+    if isinstance(arg, bool):
+        return T.bool()
+    elif isinstance(arg, int):
+        return T.index()
+    elif isinstance(arg, float):
+        return T.f32()
+    elif isinstance(arg, np.ndarray):
+        return ir.MemRefType.get(arg.shape, get_mlir_ty_from_np(arg.dtype))
+    raise NotImplementedError(arg)
+
+class NVDSL:
+    @staticmethod
+    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
+        """Decorator that emits the wrapped function's body inside a gpu.launch.
+
+        Meant to be used inside an @NVDSL.mlir_func body, where an MLIR
+        context and insertion point are active.
+
+        Args:
+          grid: grid dimensions (3 entries, ints or index values).
+          block: block dimensions (3 entries, ints or index values).
+          smem: dynamic shared memory size in bytes, emitted as an i32 constant.
+
+        Returns:
+          A decorator. Calling the decorated function emits the launch and
+          returns whatever the wrapped function returned.
+        """
+
+        def decorator(func):
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                launch_op = gpu.LaunchOp(
+                    None,
+                    [],
+                    *map(const, grid),
+                    *map(const, block),
+                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+                )
+                # The gpu.launch body block takes 12 index arguments (block ids,
+                # thread ids, grid sizes and block sizes, 3 of each).
+                launch_op.body.blocks.append(*([T.index()] * 12))
+                with ir.InsertionPoint(launch_op.body.blocks[0]):
+                    result = func(*args, **kwargs)
+                    gpu.terminator()
+                    return result
+
+            return wrapper
+
+        return decorator
+
+    @staticmethod
+    def mlir_func(funcBody):
+        """Decorator that builds, compiles and runs `funcBody` as an MLIR function.
+
+        Calling the decorated function with host arguments (bool/int/float or
+        NumPy arrays):
+          1. creates a module holding a `func.func` whose argument types come
+             from `get_mlir_ty`, and calls `funcBody` to fill its body;
+          2. verifies the module and lowers/JITs it with `NvgpuCompiler`. This
+             requires the SUPPORT_LIB environment variable to name an
+             existing runtime support library;
+          3. invokes the compiled function with the host arguments.
+
+        Returns whatever `funcBody` returned while building the IR.
+        """
+
+        @functools.wraps(funcBody)
+        def wrapper(*args, **kwargs):
+            function_name = funcBody.__name__
+
+            def saveIR(module):
+                """Save generated IR to nvdsl.mlir (debugging aid)."""
+                # print(file=...) avoids swapping sys.stdout, which the previous
+                # version would have left redirected if printing raised.
+                with open("nvdsl.mlir", "w") as f:
+                    print(module, file=f)
+
+            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
+                """Generate MLIR's Arith dialects binary operations."""
+                rhs = const(rhs)
+                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+                    op += "F"
+                    if op.startswith("Cmp"):
+                        predicateAttr = arith.CmpFPredicate.__dict__[predAtt]
+                elif arith._is_integer_like_type(
+                    lhs.type
+                ) and arith._is_integer_like_type(rhs.type):
+                    # Integer division/remainder are emitted as unsigned ops.
+                    if op == "Div" or op == "Rem":
+                        op += "U"
+                    op += "I"
+                    if op.startswith("Cmp"):
+                        predicateAttr = arith.CmpIPredicate.__dict__[predAtt]
+                else:
+                    raise NotImplementedError(
+                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
+                    )
+
+                op_cls = getattr(arith, f"{op}Op")
+                if op.startswith("Cmp"):
+                    return op_cls(predicateAttr, lhs, rhs).result
+                return op_cls(lhs, rhs).result
+
+            @ir.register_value_caster(ir.IndexType.static_typeid)
+            @ir.register_value_caster(ir.F32Type.static_typeid)
+            @ir.register_value_caster(ir.F16Type.static_typeid)
+            @ir.register_value_caster(ir.F64Type.static_typeid)
+            @ir.register_value_caster(ir.IntegerType.static_typeid)
+            class ArithValue(ir.Value):
+                """Overloads operators for MLIR's Arith dialects binary operations."""
+
+                def __init__(self, v):
+                    super().__init__(v)
+
+                __add__ = partialmethod(_binary_op, op="Add")
+                __sub__ = partialmethod(_binary_op, op="Sub")
+                __mul__ = partialmethod(_binary_op, op="Mul")
+                __truediv__ = partialmethod(_binary_op, op="Div")
+                __mod__ = partialmethod(_binary_op, op="Rem")
+                __xor__ = partialmethod(_binary_op, op="XOr")
+                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
+                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
+                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
+                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
+                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
+                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
+                __and__ = partialmethod(_binary_op, op="And")
+                __or__ = partialmethod(_binary_op, op="Or")
+
+                def __str__(self):
+                    return (
+                        super()
+                        .__str__()
+                        .replace(ir.Value.__name__, ArithValue.__name__)
+                    )
+
+            # Generate MLIR Context and start generating IR
+            with ir.Context(), ir.Location.unknown():
+                types = [get_mlir_ty(arg) for arg in args]
+
+                # Build IR
+                module = ir.Module.create()
+                with ir.InsertionPoint(module.body):
+                    fop = func.FuncOp(function_name, (types, []))
+                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+                    with ir.InsertionPoint(fop.add_entry_block()):
+                        fargs = [fop.arguments[i] for i in range(len(types))]
+
+                        # Call user function body
+                        result = funcBody(*fargs, **kwargs)
+                        func.ReturnOp([])
+
+                # Verify the module
+                module.operation.verify()
+
+                # Uncomment to dump the generated IR to nvdsl.mlir.
+                # saveIR(module)
+
+                # Compile and JIT MLIR module
+                options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+                support_lib = os.getenv("SUPPORT_LIB")
+                if support_lib is None:
+                    # os.path.exists(None) would raise an unhelpful TypeError.
+                    raise FileNotFoundError(
+                        errno.ENOENT, "SUPPORT_LIB environment variable is not set"
+                    )
+                if not os.path.exists(support_lib):
+                    raise FileNotFoundError(
+                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+                    )
+                compiler = nvgpucompiler.NvgpuCompiler(
+                    options, opt_level=3, shared_libs=[support_lib]
+                )
+                engine = compiler.compile_and_jit(module)
+
+            # Convert input arguments to MLIR arguments
+            newArgs = get_mlir_func_obj_ty(args)
+
+            # Run the compiled program
+            engine.invoke(function_name, *newArgs)
+
+            return result
+
+        return wrapper
diff --git a/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py b/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..1c9cc74fcd169c
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py
@@ -0,0 +1,45 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the NvgpuCompiler class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+
+class NvgpuCompiler:
+    """Lowers MLIR modules with gpu-lower-to-nvvm-pipeline and JITs them.
+
+    Args:
+      options: option string spliced into the pipeline's `{...}` block.
+      opt_level: optimization level handed to the ExecutionEngine.
+      shared_libs: shared libraries loaded by the ExecutionEngine.
+    """
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        self.pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+        self.shared_libs = shared_libs
+        self.opt_level = opt_level
+
+    def __call__(self, module: ir.Module):
+        """Shorthand for `compile(module)`; returns None."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Run the configured pass pipeline over `module`."""
+        pm = passmanager.PassManager.parse(self.pipeline)
+        pm.run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Build an ExecutionEngine for an already-lowered `module`."""
+        engine_kwargs = {"opt_level": self.opt_level, "shared_libs": self.shared_libs}
+        return execution_engine.ExecutionEngine(module, **engine_kwargs)
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Lower `module`, then return an ExecutionEngine for it."""
+        self.compile(module)
+        return self.jit(module)

>From c07ade60c2dd343165827587bebe48ae03f6785b Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 29 Mar 2024 13:51:21 +0000
Subject: [PATCH 02/14] format

---
 mlir/test/Examples/nvgpu/Ch2.py         |  2 +-
 mlir/test/Examples/nvgpu/Ch4.py         |  1 +
 mlir/test/Examples/nvgpu/tools/nvdsl.py | 49 +++++++++++++++----------
 3 files changed, 32 insertions(+), 20 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index caa2691a06d69e..e9369227af0e85 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -5,7 +5,7 @@
 #  Chapter 2 : 2D Saxpy with TMA
 # ===----------------------------------------------------------------------===//
 #
-# This program demonstrates 2D Saxpy. It is same as Chapter 1, 
+# This program demonstrates 2D Saxpy. It is same as Chapter 1,
 # but it loads data using TMA (Tensor Memory Accelerator)
 #
 # This chapter introduces demonstrates:
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index d222d6e60aafad..5f790c4c190ca9 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -247,6 +247,7 @@ def gemm_multistage(x, y, z):
 
     grid = [(M // TILE_M), (N // TILE_N), 1]
     block = [128, 1, 1]
+
     @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
     def gemm_multistage_kernel():
         # Initialize mbarriers and prefetch TMA descriptors
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 4c774f4d4deac9..e27857d8ed1d79 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -14,7 +14,9 @@
 
 def const(value: int, ty=None):
     ty = T.index() if ty is None else ty
-    if isinstance(value, ir.Value) and (value.type.isinstance(value.type) or T.bool().isinstance(value.type)):
+    if isinstance(value, ir.Value) and (
+        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
+    ):
         return value
     return arith.constant(ty, value)
 
@@ -31,6 +33,7 @@ def get_type_size(ty):
         return ir.IntegerType(ty).width // 8
     raise NotImplementedError(ty)
 
+
 def get_mlir_func_obj_ty(inputArgs):
     args = []
     c_int_p = ctypes.c_int * 1
@@ -133,19 +136,20 @@ def tensormap_descriptor_ty(self):
     def create_descriptor(self, device_ptr):
         tma_descriptor_ty = self.tensormap_descriptor_ty
         device_unranked_memref = memref.CastOp(
-            ir.UnrankedMemRefType.get(self.memref_ty.element_type,
-                                      self.memref_ty.memory_space),
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
             device_ptr,
         )
         self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
-            tma_descriptor_ty, device_unranked_memref,
-            map(const, self.tma_shape))
+            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_shape)
+        )
         return self.tma_descriptor.result
 
     def prefetch(self, predicate=None):
         nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
 
-    def load(self, dest, mbarrier: Mbarriers, coords=[0,0], predicate=None):
+    def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
         coord_ops = [const(c) for c in coords]
         nvgpu.TmaAsyncLoadOp(
             dest,
@@ -180,7 +184,6 @@ def op(self):
 
 
 class Matrix:
-
     def __init__(self, smem, tma_descriptor: TMA, M, N):
         self.tma_descriptor = tma_descriptor
         self.smem = smem
@@ -189,22 +192,27 @@ def __init__(self, smem, tma_descriptor: TMA, M, N):
 
     @property
     def wgmma_ty(self):
-        return ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                             str(self.M) + "x" +
-                             str(self.N) + "x" +
-                             str(self.tma_descriptor.memref_ty.element_type) +
-                             ", #gpu.address_space<workgroup>>>")
+        return ir.Type.parse(
+            "!nvgpu.warpgroup.descriptor<tensor=memref<"
+            + str(self.M)
+            + "x"
+            + str(self.N)
+            + "x"
+            + str(self.tma_descriptor.memref_ty.element_type)
+            + ", #gpu.address_space<workgroup>>>"
+        )
 
     def matmul(lhs, rhs, acc):
         wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(
-            lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor)
+            lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor
+        )
         wgmma_desc_rhs = nvgpu.warpgroup_generate_descriptor(
-            rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor)
-        return nvgpu.WarpgroupMmaOp(acc.type,
-                                    wgmma_desc_lhs,
-                                    wgmma_desc_rhs,
-                                    acc,
-                                    transposeB=True)
+            rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor
+        )
+        return nvgpu.WarpgroupMmaOp(
+            acc.type, wgmma_desc_lhs, wgmma_desc_rhs, acc, transposeB=True
+        )
+
 
 def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
     smem_space_str = "#gpu.address_space<workgroup>"
@@ -224,6 +232,7 @@ def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
         [],
     )
 
+
 @staticmethod
 def get_mlir_ty(arg):
     def get_mlir_ty_from_np(dtype):
@@ -238,6 +247,7 @@ def get_mlir_ty_from_np(dtype):
         if dtype == np.int64:
             return T.i64()
         raise NotImplementedError(dtype)
+
     if isinstance(arg, bool):
         return T.bool()
     elif isinstance(arg, int):
@@ -251,6 +261,7 @@ def get_mlir_ty_from_np(dtype):
         return memref.MemRefType.get(shape, dtype)
     raise NotImplementedError(arg)
 
+
 class NVDSL:
     @staticmethod
     def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):

>From dd0b01e185fea20207455892cd4326fa0a5956e3 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 1 Apr 2024 14:05:12 +0000
Subject: [PATCH 03/14] address comments

---
 mlir/test/Examples/nvgpu/Ch3.py         |  63 +++++------
 mlir/test/Examples/nvgpu/Ch4.py         | 134 ++++++++++++------------
 mlir/test/Examples/nvgpu/tools/nvdsl.py |  26 ++---
 3 files changed, 107 insertions(+), 116 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 802cb59ead1555..b1623f9d645c2c 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -23,23 +23,24 @@
 
 
 @NVDSL.mlir_func
-def gemm_64_64_64(x, y, z):
+def gemm_64_64_64(x, y, d):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
-    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
-    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
-    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
-    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
-    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
     t7 = gpu.wait(token_ty, [t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
-    x_tma = TMA([N, N], x.type, swizzle=sw)
-    y_tma = TMA([N, N], y.type, swizzle=sw)
-    x_tma.create_descriptor(x_dev)
-    y_tma.create_descriptor(y_dev)
+    a_tma = TMA([N, N], x.type, swizzle=sw)
+    b_tma = TMA([N, N], y.type, swizzle=sw)
+    a_tma.create_descriptor(a_dev)
+    b_tma.create_descriptor(b_dev)
+    smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
 
-    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
     def gemm_tma_kernel():
         tidx = gpu.thread_id(gpu.Dimension.x)
 
@@ -47,46 +48,46 @@ def gemm_tma_kernel():
         isThread0 = tidx == 0
         with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
             mbar_group[0].init(1)
-            x_tma.prefetch()
-            y_tma.prefetch()
+            a_tma.prefetch()
+            b_tma.prefetch()
             scf.yield_([])
 
-        x_smem = get_dynamic_shared_memory((N, N), T.f16())
-        y_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+        a_smem = get_dynamic_shared_memory((N, N), T.f16())
+        b_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
 
         # 1. Execute TMA Load for two input matrices
         with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
-            x_tma.load(x_smem, mbar_group[0])
-            y_tma.load(y_smem, mbar_group[0])
-            tx_count = get_type_size(x_tma.tma_memref) + get_type_size(y_tma.tma_memref)
-            mbar_group[0].arrive(tx_count)
+            a_tma.load(a_smem, mbar_group[0])
+            b_tma.load(b_smem, mbar_group[0])
+            ta_count = get_type_size(a_tma.tma_memref) + get_type_size(b_tma.tma_memref)
+            mbar_group[0].arrive(ta_count)
             scf.yield_([])
 
         mbar_group[0].try_wait()
 
         # 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
-        A = Matrix(x_smem, x_tma, N, N)
-        B = Matrix(y_smem, y_tma, N, N)
-        C = MatrixAccumulator(N, N, T.f32()).op()
-        D = Matrix.matmul(A, B, C)
+        A = WarpgroupMatrix(a_smem, a_tma, N, N)
+        B = WarpgroupMatrix(b_smem, b_tma, N, N)
+        C = WarpgroupAccumulatorMatrix(N, N, T.f32()).op()
+        D = WarpgroupMatrix.matmul(A, B, C)
 
         # 3. Stores fragmented registers to global memory by warpgroup
-        nvgpu.warpgroup_mma_store(D, z_dev)
+        nvgpu.warpgroup_mma_store(D, d_dev)
 
     gemm_tma_kernel()
 
-    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
     gpu.wait(None, [t8])
 
 
 # Python pass arguments to MLIR
 N = 64
-x = np.random.randn(N, N).astype(np.float16)
-y = np.random.randn(N, N).astype(np.float16)
-z = np.zeros((N, N), np.float32)
-gemm_64_64_64(x, y, z)
+a = np.random.randn(N, N).astype(np.float16)
+b = np.random.randn(N, N).astype(np.float16)
+d = np.zeros((N, N), np.float32)
+gemm_64_64_64(a, b, d)
 
-ref = x.astype(np.float16) @ y.astype(np.float16)
-np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
 print("PASS")
 # CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 5f790c4c190ca9..2f38a49501e792 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -51,8 +51,8 @@ def partition_shape():
 
 def tma_load(
     mbar_group: Mbarriers,
-    x_tma: TMA,
-    y_tma: TMA,
+    a_tma: TMA,
+    b_tma: TMA,
     slot,
     stage,
     p=None,
@@ -60,45 +60,45 @@ def tma_load(
     """
     TMA loads two input matrices from global memory to shared memory. It performs the following operations:
 
-       - tma.load x_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
-       - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
-       - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
+       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
 
-       mbarrier.arrive tx_count = 128x64x2x4
+       mbarrier.arrive ta_count = 128x64x2x4
     """
     dimX, dimY = partition_shape()
 
     tidx = gpu.thread_id(gpu.Dimension.x)
-    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
-    size_tma_x = get_type_size(x_tma.tma_memref)
-    size_tma_y = get_type_size(y_tma.tma_memref)
-    tx_count = size_tma_x + (size_tma_y * 2)
+    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    ta_count = size_tma_a + (size_tma_b * 2)
     tidx = gpu.thread_id(gpu.Dimension.x)
 
     p = tidx == 0 if p is None else p
 
-    off_x = slot * size_tma_x
-    off_y = (slot * size_tma_x) + begin_y
-    off_y2 = off_y + size_tma_y
+    off_a = slot * size_tma_a
+    off_b = (slot * size_tma_a) + begin_b
+    off_b2 = off_b + size_tma_b
     x = get_dynamic_shared_memory(
-        x_tma.tma_memref.shape, x_tma.tma_memref.element_type, off_x
+        a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
     )
     y1 = get_dynamic_shared_memory(
-        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y
+        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
     )
     y2 = get_dynamic_shared_memory(
-        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y2
+        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
     )
 
-    mbar_group[slot].arrive(tx_count, predicate=p)
+    mbar_group[slot].arrive(ta_count, predicate=p)
 
     c1 = stage * 64
-    x_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
-    y_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
-    y_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+    a_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    b_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    b_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
 
 
-def bootstrap(x_tma: TMA, y_tma: TMA):
+def bootstrap(a_tma: TMA, b_tma: TMA):
     """
     Initialize mbarriers and prefetch TMA descriptors.
     """
@@ -109,14 +109,14 @@ def bootstrap(x_tma: TMA, y_tma: TMA):
         for i in scf.for_(0, NUM_STAGES, 1):
             mbar_group[i].init(1)
             scf.yield_([])
-        x_tma.prefetch()
-        y_tma.prefetch()
+        a_tma.prefetch()
+        b_tma.prefetch()
         scf.yield_([])
 
     return mbar_group
 
 
-def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
     """
     Prologue of the GEMM kernel. It loads 2 input matrices for each stage in loop like below:
 
@@ -126,11 +126,11 @@ def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
     """
     ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
     for iv in scf.for_(0, ns, 1):
-        tma_load(mbar_group, x_tma, y_tma, iv, iv)
+        tma_load(mbar_group, a_tma, b_tma, iv, iv)
         scf.yield_([])
 
 
-def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
     """
     Main loop of the Multistage GEMM kernel. It iterates through
     stages and performs matrix multiplication, loading data by TMA to shared memory. It like following
@@ -151,11 +151,11 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
     ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
 
     tidx = gpu.thread_id(gpu.Dimension.x)
-    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
 
-    size_x = TILE_M * TILE_K * get_type_size(T.f16())
+    size_a = TILE_M * TILE_K * get_type_size(T.f16())
 
-    C = MatrixAccumulator(TILE_M, TILE_N, T.f32()).op()
+    C = WarpgroupAccumulatorMatrix(TILE_M, TILE_N, T.f32()).op()
     pp = const(False, ty=T.bool())
 
     # Main Loop
@@ -169,16 +169,16 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
         mbar_group[stage].try_wait(phase=pp)
 
         # Find shared memory slot
-        offset_x = stage * size_x
-        offset_y = offset_x + begin_y
-        x_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_x)
-        y_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_y)
+        offset_a = stage * size_a
+        offset_b = offset_a + begin_b
+        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
 
         # Matrix Multiply
-        A = Matrix(x_smem, x_tma, TILE_M, TILE_K)
-        B = Matrix(y_smem, y_tma, TILE_K, TILE_N)
+        A = WarpgroupMatrix(a_smem, a_tma, TILE_M, TILE_K)
+        B = WarpgroupMatrix(b_smem, b_tma, TILE_K, TILE_N)
         C = for_op.inner_iter_args[0]
-        D = Matrix.matmul(A, B, C)
+        D = WarpgroupMatrix.matmul(A, B, C)
         if NUM_STAGES == 1:
             nvvm.WgmmaWaitGroupSyncOp(0)
 
@@ -186,7 +186,7 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
         pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
         nextStage = iv + ns
         nextSlot = nextStage % NUM_STAGES
-        tma_load(mbar_group, x_tma, y_tma, nextSlot, nextStage, pred)
+        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, pred)
 
         # Switch phase parity for the mbarrier
         switched = pp ^ const(True, ty=T.bool())
@@ -202,7 +202,7 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
     return for_op.results[0]
 
 
-def epilogue(D, z_dev):
+def epilogue(D, d_dev):
     """
     Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
 
@@ -214,57 +214,61 @@ def epilogue(D, z_dev):
     tidx = gpu.thread_id(gpu.Dimension.x)
     dimX, dimY = partition_shape()
 
-    z_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
-    z_gmem = memref.subview(z_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
 
     # Store (registers -> shared memory)
-    nvgpu.WarpgroupMmaStoreOp(D, z_smem)
+    nvgpu.WarpgroupMmaStoreOp(D, d_smem)
     gpu.barrier()
 
     # Store (shared memory --> global memory)
     for i in scf.for_(0, TILE_M, 1):
-        val = memref.load(z_smem, [i, tidx])
-        memref.store(val, z_gmem, [i, tidx])
+        val = memref.load(d_smem, [i, tidx])
+        memref.store(val, d_gmem, [i, tidx])
         scf.yield_([])
 
 
 @NVDSL.mlir_func
-def gemm_multistage(x, y, z):
+def gemm_multistage(x, y, d):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
-    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
-    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
-    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
-    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
-    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
     t7 = gpu.wait(token_ty, [t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
-    x_tma = TMA([128, 64], x.type, swizzle=sw)
-    y_tma = TMA([64, 64], y.type, swizzle=sw)
-    x_tma.create_descriptor(x_dev)
-    y_tma.create_descriptor(y_dev)
+    a_tma = TMA([128, 64], x.type, swizzle=sw)
+    b_tma = TMA([64, 64], y.type, swizzle=sw)
+    a_tma.create_descriptor(a_dev)
+    b_tma.create_descriptor(b_dev)
 
     grid = [(M // TILE_M), (N // TILE_N), 1]
     block = [128, 1, 1]
 
-    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
+    size_a = get_type_size(x.type.element_type) * TILE_M * TILE_K
+    size_b = get_type_size(x.type.element_type) * TILE_N * TILE_K
+    smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+
+    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
     def gemm_multistage_kernel():
         # Initialize mbarriers and prefetch TMA descriptors
-        mbar_group = bootstrap(x_tma, y_tma)
+        mbar_group = bootstrap(a_tma, b_tma)
 
         # Fill the pipeline stages
-        prologue(mbar_group, x_tma, y_tma)
+        prologue(mbar_group, a_tma, b_tma)
 
         # Main loop
-        D = mainloop(mbar_group, x_tma, y_tma)
+        D = mainloop(mbar_group, a_tma, b_tma)
 
         # Store registers to global memory
-        epilogue(D, z_dev)
+        epilogue(D, d_dev)
 
     gemm_multistage_kernel()
 
-    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
     gpu.wait(None, [t8])
 
 
@@ -276,16 +280,16 @@ def gemm_multistage_kernel():
 TILE_M = 128
 TILE_N = 128
 TILE_K = 64
-x = np.random.randn(M, K).astype(np.float16)
-y = np.random.randn(K, N).astype(np.float16)
-z = np.zeros((M, N), np.float32)
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
 
-gemm_multistage(x, y, z)
+gemm_multistage(a, b, d)
 
 
 # Verify MLIR with reference computation
-ref = x.astype(np.float16) @ y.astype(np.float16)
-np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
 
 
 print("PASS")
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index e27857d8ed1d79..bb312a9e259ef3 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -161,7 +161,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
         )
 
 
-class MatrixAccumulator:
+class WarpgroupAccumulatorMatrix:
     def __init__(self, M, N, ty):
         self.M = M
         self.N = N
@@ -169,21 +169,14 @@ def __init__(self, M, N, ty):
 
     @property
     def acc_ty(self):
-        return ir.Type.parse(
-            "!nvgpu.warpgroup.accumulator<fragmented=vector<"
-            + str(self.M)
-            + "x"
-            + str(self.N)
-            + "x"
-            + str(self.ty)
-            + ">>"
-        )
+        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
+        return ir.Type.parse(parse_str)
 
     def op(self):
         return nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
 
 
-class Matrix:
+class WarpgroupMatrix:
     def __init__(self, smem, tma_descriptor: TMA, M, N):
         self.tma_descriptor = tma_descriptor
         self.smem = smem
@@ -192,15 +185,8 @@ def __init__(self, smem, tma_descriptor: TMA, M, N):
 
     @property
     def wgmma_ty(self):
-        return ir.Type.parse(
-            "!nvgpu.warpgroup.descriptor<tensor=memref<"
-            + str(self.M)
-            + "x"
-            + str(self.N)
-            + "x"
-            + str(self.tma_descriptor.memref_ty.element_type)
-            + ", #gpu.address_space<workgroup>>>"
-        )
+        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.tma_descriptor.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+        return ir.Type.parse(parse_str)
 
     def matmul(lhs, rhs, acc):
         wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(

>From 709e2a6d520fe47b9bbc853311f44b6785e383d8 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 1 Apr 2024 14:42:23 +0000
Subject: [PATCH 04/14] Perform 128x128x64 GEMM instead of 64x64x64

---
 mlir/test/Examples/nvgpu/Ch3.py | 100 +++++++++++++++++++++++---------
 1 file changed, 74 insertions(+), 26 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index b1623f9d645c2c..12e6def2395555 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -22,23 +22,71 @@
 import numpy as np
 
 
+def tma_load(
+    mbar_group: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    slot,
+    stage,
+    p=None,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
+       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+
+       mbarrier.arrive ta_count = 128x64x2x4
+    """
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_b = get_type_size(a_tma.tma_memref)
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    ta_count = size_tma_a + (size_tma_b * 2)
+    tidx = gpu.thread_id(gpu.Dimension.x)
+
+    p = tidx == 0 if p is None else p
+
+    off_a = slot * size_tma_a
+    off_b = (slot * size_tma_a) + begin_b
+    off_b2 = off_b + size_tma_b
+    x = get_dynamic_shared_memory(
+        a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
+    )
+    y1 = get_dynamic_shared_memory(
+        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
+    )
+    y2 = get_dynamic_shared_memory(
+        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
+    )
+
+    mbar_group[slot].arrive(ta_count, predicate=p)
+
+    c1 = stage * 64
+    a_tma.load(x, mbar_group[slot], coords=[c1, 0], predicate=p)
+    b_tma.load(y1, mbar_group[slot], coords=[0, c1], predicate=p)
+    b_tma.load(y2, mbar_group[slot], coords=[64, c1], predicate=p)
+
+
 @NVDSL.mlir_func
-def gemm_64_64_64(x, y, d):
+def gemm_128_128_64(a, b, d):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
-    a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
-    b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
     d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
-    t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
-    t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
     t7 = gpu.wait(token_ty, [t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
-    a_tma = TMA([N, N], x.type, swizzle=sw)
-    b_tma = TMA([N, N], y.type, swizzle=sw)
+    a_tma = TMA([128, 64], a.type, swizzle=sw)
+    b_tma = TMA([64, 64], b.type, swizzle=sw)
     a_tma.create_descriptor(a_dev)
     b_tma.create_descriptor(b_dev)
-    smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
+    smem_size_in_bytes = get_type_size(a.type) + get_type_size(b.type)
 
     @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
     def gemm_tma_kernel():
@@ -52,26 +100,24 @@ def gemm_tma_kernel():
             b_tma.prefetch()
             scf.yield_([])
 
-        a_smem = get_dynamic_shared_memory((N, N), T.f16())
-        b_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+        a_smem = get_dynamic_shared_memory((M, K), T.f16())
+        b_smem = get_dynamic_shared_memory(
+            (K, N), T.f16(), offset=get_type_size(a.type)
+        )
 
         # 1. Execute TMA Load for two input matrices
-        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
-            a_tma.load(a_smem, mbar_group[0])
-            b_tma.load(b_smem, mbar_group[0])
-            ta_count = get_type_size(a_tma.tma_memref) + get_type_size(b_tma.tma_memref)
-            mbar_group[0].arrive(ta_count)
-            scf.yield_([])
+        tma_load(mbar_group, a_tma, b_tma, 0, 0, p=isThread0)
 
+        # 2. All threads wait TMA load completion
         mbar_group[0].try_wait()
 
-        # 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
-        A = WarpgroupMatrix(a_smem, a_tma, N, N)
-        B = WarpgroupMatrix(b_smem, b_tma, N, N)
-        C = WarpgroupAccumulatorMatrix(N, N, T.f32()).op()
+        # 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
+        A = WarpgroupMatrix(a_smem, a_tma, M, K)
+        B = WarpgroupMatrix(b_smem, b_tma, K, N)
+        C = WarpgroupAccumulatorMatrix(M, N, T.f32()).op()
         D = WarpgroupMatrix.matmul(A, B, C)
 
-        # 3. Stores fragmented registers to global memory by warpgroup
+        # 4. Stores fragmented registers to global memory by warpgroup
         nvgpu.warpgroup_mma_store(D, d_dev)
 
     gemm_tma_kernel()
@@ -81,11 +127,13 @@ def gemm_tma_kernel():
 
 
 # Python pass arguments to MLIR
-N = 64
-a = np.random.randn(N, N).astype(np.float16)
-b = np.random.randn(N, N).astype(np.float16)
-d = np.zeros((N, N), np.float32)
-gemm_64_64_64(a, b, d)
+M = 128
+N = 128
+K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+gemm_128_128_64(a, b, d)
 
 ref_d = a.astype(np.float16) @ b.astype(np.float16)
 np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)

>From f6160e6a37bc2a48940a070c42ba9cdb38e94977 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 1 Apr 2024 14:52:36 +0000
Subject: [PATCH 05/14] Fix names and simplify tma_load for Ch3

---
 mlir/test/Examples/nvgpu/Ch3.py | 45 ++++++++++++---------------------
 mlir/test/Examples/nvgpu/Ch4.py | 44 +++++++++++++++-----------------
 2 files changed, 36 insertions(+), 53 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 12e6def2395555..ac0ee9b3f62c89 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -26,48 +26,35 @@ def tma_load(
     mbar_group: Mbarriers,
     a_tma: TMA,
     b_tma: TMA,
-    slot,
-    stage,
-    p=None,
+    p,
 ):
     """
     TMA loads two input matrices from global memory to shared memory. It performs the following operations:
 
-       - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
-       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
-       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load a_shared_memory[0] at coordinate [0, 0]  (Loads 128x64)
+       - tma.load b_shared_memory[0] at coordinate [0, 0]  (Loads 64x64)
+       - tma.load b_shared_memory[0] at coordinate [64, 0] (Loads 64x64)
 
-       mbarrier.arrive ta_count = 128x64x2x4
+       mbarrier.arrive ta_count = 128x64xf16 + 64x128xf16
     """
 
-    tidx = gpu.thread_id(gpu.Dimension.x)
-    begin_b = get_type_size(a_tma.tma_memref)
     size_tma_a = get_type_size(a_tma.tma_memref)
     size_tma_b = get_type_size(b_tma.tma_memref)
     ta_count = size_tma_a + (size_tma_b * 2)
-    tidx = gpu.thread_id(gpu.Dimension.x)
 
-    p = tidx == 0 if p is None else p
-
-    off_a = slot * size_tma_a
-    off_b = (slot * size_tma_a) + begin_b
+    off_b = size_tma_a
     off_b2 = off_b + size_tma_b
-    x = get_dynamic_shared_memory(
-        a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
-    )
-    y1 = get_dynamic_shared_memory(
-        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
-    )
-    y2 = get_dynamic_shared_memory(
-        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
-    )
+    a_elem_ty = a_tma.tma_memref.element_type
+    b_elem_ty = b_tma.tma_memref.element_type
+    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
+    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
 
-    mbar_group[slot].arrive(ta_count, predicate=p)
+    mbar_group[0].arrive(ta_count, predicate=p)
 
-    c1 = stage * 64
-    a_tma.load(x, mbar_group[slot], coords=[c1, 0], predicate=p)
-    b_tma.load(y1, mbar_group[slot], coords=[0, c1], predicate=p)
-    b_tma.load(y2, mbar_group[slot], coords=[64, c1], predicate=p)
+    a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
+    b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
+    b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)
 
 
 @NVDSL.mlir_func
@@ -106,7 +93,7 @@ def gemm_tma_kernel():
         )
 
         # 1. Execute TMA Load for two input matrices
-        tma_load(mbar_group, a_tma, b_tma, 0, 0, p=isThread0)
+        tma_load(mbar_group, a_tma, b_tma, isThread0)
 
         # 2. All threads wait TMA load completion
         mbar_group[0].try_wait()
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 2f38a49501e792..4b426522a89c52 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -60,9 +60,9 @@ def tma_load(
     """
     TMA loads two input matrices from global memory to shared memory. It performs the following operations:
 
-       - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
-       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
-       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load a_shared_memory[off_x]  at coordinate [x, z]      (Loads 128x64)
+       - tma.load b_shared_memory[off_y1] at coordinate [y, x]      (Loads 64x64)
+       - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
 
        mbarrier.arrive ta_count = 128x64x2x4
     """
@@ -80,22 +80,18 @@ def tma_load(
     off_a = slot * size_tma_a
     off_b = (slot * size_tma_a) + begin_b
     off_b2 = off_b + size_tma_b
-    x = get_dynamic_shared_memory(
-        a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
-    )
-    y1 = get_dynamic_shared_memory(
-        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
-    )
-    y2 = get_dynamic_shared_memory(
-        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
-    )
+    a_elem_ty = a_tma.tma_memref.element_type
+    b_elem_ty = b_tma.tma_memref.element_type
+    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
 
     mbar_group[slot].arrive(ta_count, predicate=p)
 
     c1 = stage * 64
-    a_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
-    b_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
-    b_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
 
 
 def bootstrap(a_tma: TMA, b_tma: TMA):
@@ -229,27 +225,27 @@ def epilogue(D, d_dev):
 
 
 @NVDSL.mlir_func
-def gemm_multistage(x, y, d):
+def gemm_multistage(a, b, d):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
-    a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
-    b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
     d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
-    t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
-    t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
     t7 = gpu.wait(token_ty, [t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
-    a_tma = TMA([128, 64], x.type, swizzle=sw)
-    b_tma = TMA([64, 64], y.type, swizzle=sw)
+    a_tma = TMA([128, 64], a.type, swizzle=sw)
+    b_tma = TMA([64, 64], b.type, swizzle=sw)
     a_tma.create_descriptor(a_dev)
     b_tma.create_descriptor(b_dev)
 
     grid = [(M // TILE_M), (N // TILE_N), 1]
     block = [128, 1, 1]
 
-    size_a = get_type_size(x.type.element_type) * TILE_M * TILE_K
-    size_b = get_type_size(x.type.element_type) * TILE_N * TILE_K
+    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
+    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
     smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
 
     @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)

>From 9c7934dc82e20389faf01bdb90f5e9dd648db3ea Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 2 Apr 2024 14:35:22 +0000
Subject: [PATCH 06/14] Operator overload for `C += A @ B`

---
 mlir/test/Examples/nvgpu/Ch3.py         | 12 +++--
 mlir/test/Examples/nvgpu/Ch4.py         | 24 ++++++---
 mlir/test/Examples/nvgpu/tools/nvdsl.py | 70 ++++++++++++++++---------
 3 files changed, 69 insertions(+), 37 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index ac0ee9b3f62c89..8c2864dd0d465d 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -99,13 +99,15 @@ def gemm_tma_kernel():
         mbar_group[0].try_wait()
 
         # 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
-        A = WarpgroupMatrix(a_smem, a_tma, M, K)
-        B = WarpgroupMatrix(b_smem, b_tma, K, N)
-        C = WarpgroupAccumulatorMatrix(M, N, T.f32()).op()
-        D = WarpgroupMatrix.matmul(A, B, C)
+        A = WGMMAMatrix(WGMMAType.Descriptor, [M,K], desc=a_tma, smem=a_smem)
+        B = WGMMAMatrix(WGMMAType.Descriptor, [K,N], desc=b_tma, smem=b_smem)
+        C = WGMMAMatrix(WGMMAType.Accumulator, shape=[M,N], ty=T.f32())
+
+        # Matrix Multiply
+        C += A @ B
 
         # 4. Stores fragmented registers to global memory by warpgroup
-        nvgpu.warpgroup_mma_store(D, d_dev)
+        nvgpu.warpgroup_mma_store(C, d_dev)
 
     gemm_tma_kernel()
 
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 4b426522a89c52..9555b367a0078b 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -151,11 +151,15 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
 
     size_a = TILE_M * TILE_K * get_type_size(T.f16())
 
-    C = WarpgroupAccumulatorMatrix(TILE_M, TILE_N, T.f32()).op()
+    # Initialize A and B (input matrices) and C (accumulator)
+    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+    C = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+    
     pp = const(False, ty=T.bool())
 
     # Main Loop
-    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C, pp])
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C.acc_op, pp])
     with ir.InsertionPoint(for_op.body):
         pp = for_op.inner_iter_args[1]
         iv = for_op.induction_variable
@@ -170,14 +174,18 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
         a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
         b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
 
+        # Initialize matrices
+        A.update_smem(a_smem)
+        B.update_smem(b_smem)
+        C.update_accumulator(for_op.inner_iter_args[0])
+
         # Matrix Multiply
-        A = WarpgroupMatrix(a_smem, a_tma, TILE_M, TILE_K)
-        B = WarpgroupMatrix(b_smem, b_tma, TILE_K, TILE_N)
-        C = for_op.inner_iter_args[0]
-        D = WarpgroupMatrix.matmul(A, B, C)
+        C += A @ B
+
+        # Wait Tensor Core for single stage
         if NUM_STAGES == 1:
             nvvm.WgmmaWaitGroupSyncOp(0)
-
+            
         # Load next stage
         pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
         nextStage = iv + ns
@@ -191,7 +199,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
             switched,
             pp,
         )
-        scf.yield_([D, newPP])
+        scf.yield_([C, newPP])
 
     nvvm.WgmmaWaitGroupSyncOp(0)
 
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index bb312a9e259ef3..cfa90cdf3c6539 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -161,42 +161,64 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
         )
 
 
-class WarpgroupAccumulatorMatrix:
-    def __init__(self, M, N, ty):
-        self.M = M
-        self.N = N
-        self.ty = ty
+class WGMMAType(Enum):
+    Accumulator = 1
+    Descriptor = 2
+
+
+class WGMMAMatrix:
+    def __init__(
+        self,
+        matrix_type: WGMMAType,
+        shape: list = None,
+        desc: TMA = None,
+        smem=None,
+        ty=None,
+        acc_op=None,
+    ):
+        if acc_op is None:
+            self.M = shape[0]
+            self.N = shape[1]
+            self.ty = ty
+            self.matrix_type = matrix_type
+            self.desc = desc
+            self.smem = smem
+            if matrix_type is WGMMAType.Accumulator:
+               self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty) 
+        elif acc_op:
+            self.acc_op = acc_op
+            self.matrix_type = WGMMAType.Accumulator
 
     @property
     def acc_ty(self):
         parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
         return ir.Type.parse(parse_str)
 
-    def op(self):
-        return nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
-
-
-class WarpgroupMatrix:
-    def __init__(self, smem, tma_descriptor: TMA, M, N):
-        self.tma_descriptor = tma_descriptor
-        self.smem = smem
-        self.M = M
-        self.N = N
-
     @property
     def wgmma_ty(self):
-        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.tma_descriptor.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
         return ir.Type.parse(parse_str)
 
-    def matmul(lhs, rhs, acc):
-        wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(
-            lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor
+    def update_smem(self, smem):
+        self.smem = smem 
+    
+    def update_accumulator(self, acc_op):
+        self.acc_op = acc_op 
+
+    def __matmul__(self, rhs):
+        lhs = nvgpu.warpgroup_generate_descriptor(
+            self.wgmma_ty, self.smem, self.desc.tma_descriptor
         )
-        wgmma_desc_rhs = nvgpu.warpgroup_generate_descriptor(
-            rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor
+        rhs = nvgpu.warpgroup_generate_descriptor(
+            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
         )
+        return [lhs, rhs]
+
+    def __iadd__(self, matmulResult):
+        lhs = matmulResult[0]
+        rhs = matmulResult[1]
         return nvgpu.WarpgroupMmaOp(
-            acc.type, wgmma_desc_lhs, wgmma_desc_rhs, acc, transposeB=True
+            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
         )
 
 
@@ -376,7 +398,7 @@ def __str__(self):
                 module.operation.verify()
 
                 # Save IR in a file
-                # saveIR(module)
+                saveIR(module)
 
                 # Compile and JIT MLIR module
                 options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"

>From b5b786cf4c5bf7b480a4980b85b0de75915bea45 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Wed, 3 Apr 2024 09:48:35 +0000
Subject: [PATCH 07/14] Format code and simplify further

---
 mlir/test/Examples/nvgpu/Ch2.py         | 14 ++++----------
 mlir/test/Examples/nvgpu/Ch3.py         | 23 +++++++++++------------
 mlir/test/Examples/nvgpu/Ch4.py         |  4 ++--
 mlir/test/Examples/nvgpu/tools/nvdsl.py |  8 ++++----
 4 files changed, 21 insertions(+), 28 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index e9369227af0e85..749e1a00585fd6 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -46,21 +46,15 @@ def saxpy_tma_kernel():
 
         # 1. Create and initialize asynchronous transactional barrier (mbarrier)
         mbar_group = Mbarriers(number_of_barriers=1)
-        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
-            mbar_group[0].init(1)
-            x_tma.prefetch()
-            y_tma.prefetch()
-            scf.yield_([])
+        mbar_group[0].init(1, predicate=isThread0)
 
         x_smem = get_dynamic_shared_memory((M, N), T.f32())
         y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 2)
 
         # 2. Execute Tensor Memory Accelerator (TMA) Load
-        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
-            x_tma.load(x_smem, mbar_group[0])
-            y_tma.load(y_smem, mbar_group[0])
-            mbar_group[0].arrive(txcount=M * N * 2 * 4)
-            scf.yield_([])
+        x_tma.load(x_smem, mbar_group[0], predicate=isThread0)
+        y_tma.load(y_smem, mbar_group[0], predicate=isThread0)
+        mbar_group[0].arrive(txcount=M * N * 2 * 4, predicate=isThread0)
 
         # 3. Wait for completion of TMA load with mbarrier
         mbar_group[0].try_wait()
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 8c2864dd0d465d..6b1d36908f4365 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -73,7 +73,9 @@ def gemm_128_128_64(a, b, d):
     b_tma = TMA([64, 64], b.type, swizzle=sw)
     a_tma.create_descriptor(a_dev)
     b_tma.create_descriptor(b_dev)
-    smem_size_in_bytes = get_type_size(a.type) + get_type_size(b.type)
+    a_size = get_type_size(a.type)
+    b_size = get_type_size(b.type)
+    smem_size_in_bytes = a_size + b_size
 
     @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
     def gemm_tma_kernel():
@@ -81,16 +83,13 @@ def gemm_tma_kernel():
 
         mbar_group = Mbarriers(number_of_barriers=1)
         isThread0 = tidx == 0
-        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
-            mbar_group[0].init(1)
-            a_tma.prefetch()
-            b_tma.prefetch()
-            scf.yield_([])
+
+        mbar_group[0].init(1, predicate=isThread0)
+        a_tma.prefetch(predicate=isThread0)
+        b_tma.prefetch(predicate=isThread0)
 
         a_smem = get_dynamic_shared_memory((M, K), T.f16())
-        b_smem = get_dynamic_shared_memory(
-            (K, N), T.f16(), offset=get_type_size(a.type)
-        )
+        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)
 
         # 1. Execute TMA Load for two input matrices
         tma_load(mbar_group, a_tma, b_tma, isThread0)
@@ -99,9 +98,9 @@ def gemm_tma_kernel():
         mbar_group[0].try_wait()
 
         # 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
-        A = WGMMAMatrix(WGMMAType.Descriptor, [M,K], desc=a_tma, smem=a_smem)
-        B = WGMMAMatrix(WGMMAType.Descriptor, [K,N], desc=b_tma, smem=b_smem)
-        C = WGMMAMatrix(WGMMAType.Accumulator, shape=[M,N], ty=T.f32())
+        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
+        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
+        C = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
 
         # Matrix Multiply
         C += A @ B
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 9555b367a0078b..1f23f423ce4092 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -155,7 +155,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
     A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
     B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
     C = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
-    
+
     pp = const(False, ty=T.bool())
 
     # Main Loop
@@ -185,7 +185,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
         # Wait Tensor Core for single stage
         if NUM_STAGES == 1:
             nvvm.WgmmaWaitGroupSyncOp(0)
-            
+
         # Load next stage
         pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
         nextStage = iv + ns
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index cfa90cdf3c6539..b9754491d50e8b 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -184,7 +184,7 @@ def __init__(
             self.desc = desc
             self.smem = smem
             if matrix_type is WGMMAType.Accumulator:
-               self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty) 
+                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
         elif acc_op:
             self.acc_op = acc_op
             self.matrix_type = WGMMAType.Accumulator
@@ -200,10 +200,10 @@ def wgmma_ty(self):
         return ir.Type.parse(parse_str)
 
     def update_smem(self, smem):
-        self.smem = smem 
-    
+        self.smem = smem
+
     def update_accumulator(self, acc_op):
-        self.acc_op = acc_op 
+        self.acc_op = acc_op
 
     def __matmul__(self, rhs):
         lhs = nvgpu.warpgroup_generate_descriptor(

>From cc384c47867c9a5168c95a36dbbcff75f68c73a6 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Wed, 3 Apr 2024 09:53:26 +0000
Subject: [PATCH 08/14] Calculate shared memory size dynamically

---
 mlir/test/Examples/nvgpu/Ch2.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index 749e1a00585fd6..347da9ab900184 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -37,8 +37,9 @@ def saxpy_tma(x, y, alpha):
     y_tma = TMA((M, N), y.type)
     x_tma.create_descriptor(x_dev)
     y_tma.create_descriptor(y_dev)
+    smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
 
-    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=65536)
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=smem_size_in_bytes)
     def saxpy_tma_kernel():
         bidx = gpu.block_id(gpu.Dimension.x)
         tidx = gpu.thread_id(gpu.Dimension.x)

>From 2e659a351229bc278354d115c307b326f02e72ee Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Thu, 4 Apr 2024 08:02:13 +0000
Subject: [PATCH 09/14] Add `store_accumulator`

---
 mlir/test/Examples/nvgpu/Ch2.py         |  3 +--
 mlir/test/Examples/nvgpu/Ch3.py         |  8 +++----
 mlir/test/Examples/nvgpu/Ch4.py         | 30 ++++++++++++-------------
 mlir/test/Examples/nvgpu/tools/nvdsl.py | 10 ++++++---
 4 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index 347da9ab900184..f7817747fde70c 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -49,10 +49,9 @@ def saxpy_tma_kernel():
         mbar_group = Mbarriers(number_of_barriers=1)
         mbar_group[0].init(1, predicate=isThread0)
 
+        # 2. Execute Tensor Memory Accelerator (TMA) Load
         x_smem = get_dynamic_shared_memory((M, N), T.f32())
         y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 2)
-
-        # 2. Execute Tensor Memory Accelerator (TMA) Load
         x_tma.load(x_smem, mbar_group[0], predicate=isThread0)
         y_tma.load(y_smem, mbar_group[0], predicate=isThread0)
         mbar_group[0].arrive(txcount=M * N * 2 * 4, predicate=isThread0)
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 6b1d36908f4365..9e91a6bf931ee1 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -91,7 +91,7 @@ def gemm_tma_kernel():
         a_smem = get_dynamic_shared_memory((M, K), T.f16())
         b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)
 
-        # 1. Execute TMA Load for two input matrices
+        # 1. TMA Load for two input matrices
         tma_load(mbar_group, a_tma, b_tma, isThread0)
 
         # 2. All threads wait TMA load completion
@@ -100,13 +100,13 @@ def gemm_tma_kernel():
         # 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
         A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
         B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
-        C = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
+        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
 
         # Matrix Multiply
-        C += A @ B
+        D += A @ B
 
         # 4. Stores fragmented registers to global memory by warpgroup
-        nvgpu.warpgroup_mma_store(C, d_dev)
+        D.store_accumulator(d_dev)
 
     gemm_tma_kernel()
 
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 1f23f423ce4092..ef0edadefa0094 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -154,19 +154,19 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
     # Initialize A and B (input matrices) and C (accumulator)
     A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
     B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
-    C = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
 
-    pp = const(False, ty=T.bool())
+    phase = const(False, ty=T.bool())
 
     # Main Loop
-    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C.acc_op, pp])
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
     with ir.InsertionPoint(for_op.body):
-        pp = for_op.inner_iter_args[1]
+        phase = for_op.inner_iter_args[1]
         iv = for_op.induction_variable
         stage = iv % NUM_STAGES
 
         # Wait for current stage
-        mbar_group[stage].try_wait(phase=pp)
+        mbar_group[stage].try_wait(phase=phase)
 
         # Find shared memory slot
         offset_a = stage * size_a
@@ -177,10 +177,10 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
         # Initialize matrices
         A.update_smem(a_smem)
         B.update_smem(b_smem)
-        C.update_accumulator(for_op.inner_iter_args[0])
+        D.update_accumulator(for_op.inner_iter_args[0])
 
         # Matrix Multiply
-        C += A @ B
+        D += A @ B
 
         # Wait Tensor Core for single stage
         if NUM_STAGES == 1:
@@ -193,20 +193,20 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
         tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, pred)
 
         # Switch phase parity for the mbarrier
-        switched = pp ^ const(True, ty=T.bool())
-        newPP = arith.select(
+        newPhase = arith.select(
             stage == (NUM_STAGES - 1),
-            switched,
-            pp,
+            (phase ^ const(True, ty=T.bool())),
+            phase,
         )
-        scf.yield_([C, newPP])
+        scf.yield_([D.acc_op, newPhase])
 
     nvvm.WgmmaWaitGroupSyncOp(0)
 
-    return for_op.results[0]
+    D.update_accumulator(for_op.results[0])
+    return D
 
 
-def epilogue(D, d_dev):
+def epilogue(D: WGMMAMatrix, d_dev):
     """
     Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
 
@@ -222,7 +222,7 @@ def epilogue(D, d_dev):
     d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
 
     # Store (registers -> shared memory)
-    nvgpu.WarpgroupMmaStoreOp(D, d_smem)
+    D.store_accumulator(d_smem)
     gpu.barrier()
 
     # Store (shared memory --> global memory)
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index b9754491d50e8b..6904fe68b5a01d 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -150,12 +150,11 @@ def prefetch(self, predicate=None):
         nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
 
     def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
-        coord_ops = [const(c) for c in coords]
         nvgpu.TmaAsyncLoadOp(
             dest,
             mbarrier.mbar_group_op,
             self.tma_descriptor,
-            coordinates=coord_ops,
+            coordinates=map(const, coords),
             mbarId=mbarrier.id_op,
             predicate=predicate,
         )
@@ -199,6 +198,10 @@ def wgmma_ty(self):
         parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
         return ir.Type.parse(parse_str)
 
+    def store_accumulator(self, dest):
+        assert self.matrix_type == WGMMAType.Accumulator
+        nvgpu.warpgroup_mma_store(self.acc_op, dest)
+
     def update_smem(self, smem):
         self.smem = smem
 
@@ -217,9 +220,10 @@ def __matmul__(self, rhs):
     def __iadd__(self, matmulResult):
         lhs = matmulResult[0]
         rhs = matmulResult[1]
-        return nvgpu.WarpgroupMmaOp(
+        acc_op = nvgpu.WarpgroupMmaOp(
             self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
         )
+        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
 
 
 def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):

>From 778069ad72f28498455b0c7e8869416d174c0245 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Thu, 4 Apr 2024 13:28:23 +0000
Subject: [PATCH 10/14] Fix saxpy tma, now it loads partial data to smem

---
 mlir/test/Examples/nvgpu/Ch1.py         |  4 +--
 mlir/test/Examples/nvgpu/Ch2.py         | 33 ++++++++++++++-----------
 mlir/test/Examples/nvgpu/tools/nvdsl.py | 24 +++++++++++-------
 3 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
index a888c61358a024..22e1eb7c92e005 100644
--- a/mlir/test/Examples/nvgpu/Ch1.py
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -56,12 +56,12 @@ def saxpy_kernel():
 M = 256
 N = 32
 alpha = 2.0
-x = np.ones((M, N), np.float32)
+x = np.random.randn(M, N).astype(np.float32)
 y = np.ones((M, N), np.float32)
-ref = np.ones((M, N), np.float32)
 saxpy(x, y, alpha)
 
 #  4. Verify MLIR with reference computation
+ref = np.ones((M, N), np.float32)
 ref += x * alpha
 np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
 print("PASS")
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index f7817747fde70c..2e316cf406e99f 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -24,7 +24,7 @@
 
 
 @NVDSL.mlir_func
-def saxpy_tma(x, y, alpha):
+def saxpy(x, y, alpha):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
     x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
@@ -33,13 +33,15 @@ def saxpy_tma(x, y, alpha):
     t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
     t6 = gpu.wait(token_ty, [t5])
 
-    x_tma = TMA((M, N), x.type)
-    y_tma = TMA((M, N), y.type)
+    x_tma = TMA([1, N], x.type)
+    y_tma = TMA([1, N], y.type)
     x_tma.create_descriptor(x_dev)
     y_tma.create_descriptor(y_dev)
-    smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
+    sz_x = get_type_size(x_tma.tma_memref)
+    sz_y = get_type_size(y_tma.tma_memref)
+    sz = sz_x + sz_y  # smem bytes for one x tile + one y tile; also the TMA txcount
 
-    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=smem_size_in_bytes)
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
     def saxpy_tma_kernel():
         bidx = gpu.block_id(gpu.Dimension.x)
         tidx = gpu.thread_id(gpu.Dimension.x)
@@ -50,17 +52,17 @@ def saxpy_tma_kernel():
         mbar_group[0].init(1, predicate=isThread0)
 
         # 2. Execute Tensor Memory Accelerator (TMA) Load
-        x_smem = get_dynamic_shared_memory((M, N), T.f32())
-        y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 2)
-        x_tma.load(x_smem, mbar_group[0], predicate=isThread0)
-        y_tma.load(y_smem, mbar_group[0], predicate=isThread0)
-        mbar_group[0].arrive(txcount=M * N * 2 * 4, predicate=isThread0)
+        x_smem = get_dynamic_shared_memory([1, N], T.f32())
+        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
+        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
+        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
+        mbar_group[0].arrive(txcount=sz, predicate=isThread0)
 
         # 3. Wait for completion of TMA load with mbarrier
         mbar_group[0].try_wait()
 
-        x_val = memref.load(x_smem, [bidx, tidx])
-        y_val = memref.load(y_smem, [bidx, tidx])
+        x_val = memref.load(x_smem, [const(0), tidx])
+        y_val = memref.load(y_smem, [const(0), tidx])
 
         # SAXPY: y[i] += a * x[i];
         y_val += x_val * alpha
@@ -73,15 +75,16 @@ def saxpy_tma_kernel():
     gpu.wait(token_ty, [t7])
 
 
+# 3. Pass numpy arrays to MLIR
 M = 256
 N = 32
 alpha = 2.0
-x = np.ones((M, N), np.float32)
+x = np.random.randn(M, N).astype(np.float32)
 y = np.ones((M, N), np.float32)
-ref = np.ones((M, N), np.float32)
-saxpy_tma(x, y, alpha)
+saxpy(x, y, alpha)
 
 #  4. Verify MLIR with reference computation
+ref = np.ones((M, N), np.float32)
 ref += x * alpha
 np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
 print("PASS")
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 6904fe68b5a01d..28105911db8134 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -8,7 +8,6 @@
 from mlir import runtime as rt
 from tools import nvgpucompiler
 
-DEBUG = True
 MLIR_DYNAMIC = -9223372036854775808
 
 
@@ -124,13 +123,20 @@ def __init__(
     @property
     def tensormap_descriptor_ty(self):
         """Returns a tensormap descriptor type."""
-        memref_str = f"memref<{self.tma_shape[0]}x{self.tma_shape[1]}x{self.memref_ty.element_type}, 3>"
-        parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
-                                              swizzle = {self.swizzle},\
-                                              l2promo = {self.l2promo},\
-                                              oob = {self.oob},\
-                                              interleave = {self.interleave}>"
-
+        memref_str = (
+            "memref<"
+            + "x".join(map(str, self.tma_shape))
+            + "x"
+            + str(self.memref_ty.element_type)
+            + ", 3>"
+        )
+        parse_str = (
+            f"!nvgpu.tensormap.descriptor<tensor = {memref_str}, "
+            f"swizzle = {self.swizzle}, "
+            f"l2promo = {self.l2promo}, "
+            f"oob = {self.oob}, "
+            f"interleave = {self.interleave}>"
+        )
         return ir.Type.parse(parse_str)
 
     def create_descriptor(self, device_ptr):
@@ -149,7 +155,7 @@ def create_descriptor(self, device_ptr):
     def prefetch(self, predicate=None):
         nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
 
-    def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
+    def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
         nvgpu.TmaAsyncLoadOp(
             dest,
             mbarrier.mbar_group_op,

>From 291467bde17fee58991964e0410c6be85d36b1bc Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sat, 6 Apr 2024 09:43:09 +0000
Subject: [PATCH 11/14] better comment

---
 mlir/test/Examples/nvgpu/Ch4.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index ef0edadefa0094..47a81cd9a65688 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -5,7 +5,26 @@
 #  Chapter 4 : Multistage GEMM with Tensor Core
 # ===----------------------------------------------------------------------===//
 #
-# This program demonstrates a GEMM operation with 64x64x64 matrix multiplication
+# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# Multistage method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches one Warp
+# Group (128 threads in total) and allocates multiple slots/stages in the
+# shared memory. The program consists of three main parts: prologue, mainloop,
+# and epilogue. In the prologue, thread 0 issues TMA requests that load data
+# into the shared memory slots. The mainloop executes MMA while simultaneously
+# issuing TMA loads for upcoming stages. This overlap of TMA and MMA operations
+# enhances performance by maximizing computational throughput.
+#
+# Loops illustration:
+#
+#  for s in range(NUM_STAGES):
+#    TMA_128x64_64x128...
+#  for ti in range(M//128):  # -> blockIdx.x
+#   for tj in range(N//128): # -> blockIdx.y
+#    for tk in range(K//64):
+#      MMA_128x128x64...
+#      TMA_128x64_64x128...
+#  Epilogue...
 #
 # This chapter introduces demonstrates:
 #  1. Partition shape based on block IDs
@@ -174,7 +193,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
         a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
         b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
 
-        # Initialize matrices
+        # Iterate input matrices, update accumulator
         A.update_smem(a_smem)
         B.update_smem(b_smem)
         D.update_accumulator(for_op.inner_iter_args[0])

>From db23db73a9b30d2d26fef324ab4602d7c237ea8a Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sat, 6 Apr 2024 09:43:33 +0000
Subject: [PATCH 12/14] Add Ch5.py Warp Specialized Kernel

---
 mlir/test/Examples/nvgpu/Ch5.py         | 320 ++++++++++++++++++++++++
 mlir/test/Examples/nvgpu/tools/nvdsl.py |  35 ++-
 2 files changed, 352 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Examples/nvgpu/Ch5.py

diff --git a/mlir/test/Examples/nvgpu/Ch5.py b/mlir/test/Examples/nvgpu/Ch5.py
new file mode 100644
index 00000000000000..7881af2ab54439
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch5.py
@@ -0,0 +1,320 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 5 : Warp Specialized GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# Warp Specialized method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches two Warp
+# Groups (256 threads in total): one for the producer and the other for the consumer.
+# Each group takes a different control-flow. The producer thread group is responsible
+# for loading data into shared memory, while the consumer group executes the Tensor
+# Core GEMM operation and epilogue.
+#
+#  for ti in range(M//128):  # -> blockIdx.x
+#   for tj in range(N//128): # -> blockIdx.y
+#    with wg_producer:
+#     for tk in range(K//64):
+#        TMA_128x64_64x128...
+#    with wg_consumer:
+#     for tk in range(K//64):
+#        MMA_128x128x64...
+#     Epilogue..
+#
+# This chapter demonstrates:
+#  2 WGs (warpgroups)
+#    Producer:
+#       1. Wait MMA Barrier
+#       2. Load TMA with TMA barrier
+#       3. Arrive TMA barrier with txcount
+#    Consumer:
+#       Loop
+#           Wait TMA barrier
+#           Performs Tensor Core GEMM 128x128x64 by warpgroup
+#           Arrive MMA Barrier
+#       Epilogue
+#           Store fragmented registers to shared memory
+#           Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+PRODUCER_PRIMARY_THREAD = 128  # First thread of the producer warpgroup (warpgroup 1)
+CONSUMER_PRIMARY_THREAD = 0  # First thread of the consumer warpgroup (warpgroup 0)
+PRODUCER_REGISTER_SIZE = 40  # Producer register budget, passed to setmaxregister
+CONSUMER_REGISTER_SIZE = 232  # Consumer register budget, passed to setmaxregister
+
+
+def partition_shape():
+    """
+    Calculate the partition shape based on the block IDs.
+
+    It partitions the shape like below:
+    for(.. i < M ...)   --> blockIdx.x
+     for(.. j < N ...)  --> blockIdx.y
+      for(.. k < K ...)
+
+    Returns:
+        dimX (int): Dimension along the x-axis.
+        dimY (int): Dimension along the y-axis.
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    slot,
+    stage,
+    p=None,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load a_shared_memory[off_a]  at coordinate [k, x]      (Loads 128x64)
+       - tma.load b_shared_memory[off_b]  at coordinate [y, k]      (Loads 64x64)
+       - tma.load b_shared_memory[off_b2] at coordinate [y + 64, k] (Loads 64x64)
+
+       mbarrier.arrive ta_count = bytes(A) + 2 * bytes(B); x, y = dimX, dimY; k = stage * 64
+    """
+    dimX, dimY = partition_shape()
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    ta_count = size_tma_a + (size_tma_b * 2)
+
+    off_a = slot * size_tma_a
+    off_b = (slot * size_tma_a) + begin_b  # assumes 2 B tiles == 1 A tile in bytes
+    off_b2 = off_b + size_tma_b
+    a_elem_ty = a_tma.tma_memref.element_type
+    b_elem_ty = b_tma.tma_memref.element_type
+    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+    p = ((tidx % WARP_GROUP_SIZE) == 0) if p is None else p
+    mbar_group[slot].arrive(ta_count, predicate=p)  # same predicate as the loads
+    c1 = stage * 64
+    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def bootstrap(a_tma: TMA, b_tma: TMA):
+    """
+    Initialize mbarriers and prefetch TMA descriptors.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    mbar_group_tma = Mbarriers(number_of_barriers=NUM_STAGES)
+    mbar_group_mma = Mbarriers(number_of_barriers=NUM_STAGES)
+    isThread0 = tidx == const(0)
+    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+        for i in scf.for_(0, NUM_STAGES, 1):
+            mbar_group_tma[i].init(1)
+            mbar_group_mma[i].init(1)
+            scf.yield_([])
+        a_tma.prefetch()
+        b_tma.prefetch()
+        scf.yield_([])
+
+    return mbar_group_tma, mbar_group_mma
+
+
+def switch_phase(stage, phase):
+    p = stage == (NUM_STAGES - 1)
+    phase = arith.select(
+        p,
+        (phase ^ const(True, ty=T.bool())),
+        phase,
+    )
+    return phase
+
+
+def producer_loop(
+    mbar_group_tma: Mbarriers,
+    mbar_group_mma: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    wg_me: Warpgroup,
+):
+    phase = const(True, ty=T.bool())
+
+    for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
+        stage = iv % NUM_STAGES
+        # Wait MMA to be done
+        mbar_group_mma[stage].try_wait(phase)
+        # New phase for mbarrier
+        phase = switch_phase(stage, phase)
+        # TMA Load
+        tma_load(mbar_group_tma, a_tma, b_tma, stage, iv, wg_me.is_wg_primary)
+        scf.yield_([phase])
+
+
+def consumer_loop(
+    mbar_group_tma: Mbarriers,
+    mbar_group_mma: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    wg_me: Warpgroup,
+):
+    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+
+    size_a = TILE_M * TILE_K * get_type_size(T.f16())
+
+    phase = const(False, ty=T.bool())
+    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
+    with ir.InsertionPoint(for_op.body):
+        phase = for_op.inner_iter_args[1]
+        iv = for_op.induction_variable
+        stage = iv % NUM_STAGES
+
+        # Wait TMA for current stage
+        mbar_group_tma[stage].try_wait(phase)
+
+        # Find shared memory slot
+        offset_a = stage * size_a
+        offset_b = offset_a + begin_b
+        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
+
+        # Iterate input matrices, update accumulator
+        A.update_smem(a_smem)
+        B.update_smem(b_smem)
+        D.update_accumulator(for_op.inner_iter_args[0])
+
+        # Matrix Multiply
+        D += A @ B
+
+        # MMA Barrier Arrive
+        p_arrive = (iv > 0) & wg_me.is_wg_primary
+        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+            barId = arith.select((stage == 0), const(NUM_STAGES - 1), (stage - 1))
+            mbar_group_mma[barId].arrive()
+            scf.yield_([])
+
+        phase = switch_phase(stage, phase)
+        scf.yield_([D.acc_op, phase])
+
+    nvvm.WgmmaWaitGroupSyncOp(0)
+    D.update_accumulator(for_op.results[0])
+    return D
+
+
+def epilogue(D: WGMMAMatrix, d_dev):
+    """
+    Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
+
+    MatrixAccumulator D               # Fragmented results
+    store D -> Shared Memory          # Store Shared Memory
+    Shared Memory -> Z[dimX][dimY]    # Store Shared Memory to Global Memory
+
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    dimX, dimY = partition_shape()
+    # s = tidx - WARP_GROUP_SIZE
+    # debug_print("[Epilogue] store to global memory @ s={}", s)
+
+    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+    # Store (registers -> shared memory)
+    D.store_accumulator(d_smem)
+    gpu.barrier()
+
+    # Store (shared memory --> global memory)
+    for i in scf.for_(0, TILE_M, 1):
+        val = memref.load(d_smem, [i, tidx])
+        memref.store(val, d_gmem, [i, tidx])
+        scf.yield_([])
+
+
+ at NVDSL.mlir_func
+def gemm_warp_specialized(a, b, d):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
+    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    a_tma = TMA([128, 64], a.type, swizzle=sw)
+    b_tma = TMA([64, 64], b.type, swizzle=sw)
+    a_tma.create_descriptor(a_dev)
+    b_tma.create_descriptor(b_dev)
+
+    grid = [(M // TILE_M), (N // TILE_N), 1]
+    block = [256, 1, 1]
+
+    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
+    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
+    smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+
+    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
+    def gemm_warp_specialized_kernel():
+        # Init Warpgroups
+        wg_producer = Warpgroup(PRODUCER_PRIMARY_THREAD, PRODUCER_REGISTER_SIZE)
+        wg_consumer = Warpgroup(CONSUMER_PRIMARY_THREAD, CONSUMER_REGISTER_SIZE)
+
+        # Initialize mbarriers and prefetch TMA descriptors
+        mbar_group_mma, mbar_group_tma = bootstrap(a_tma, b_tma)
+
+        # Producer performs TMA
+        with wg_producer:
+            producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer)
+
+        # Producer performs MMA/Tensor Core
+        with wg_consumer:
+            D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer)
+            epilogue(D, d_dev)
+
+    gemm_warp_specialized_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
+    gpu.wait(None, [t8])
+
+
+# Python pass arguments to MLIR
+NUM_STAGES = 7
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+
+gemm_warp_specialized(a, b, d)
+
+
+# Verify MLIR with reference computation (f32 accumulation, like the kernel)
+ref_d = a.astype(np.float32) @ b.astype(np.float32)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 28105911db8134..7d9e13c7f8627c 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -3,7 +3,7 @@
 import numpy as np
 from functools import partialmethod
 from mlir import ir
-from mlir.dialects import arith, func, gpu, memref, nvgpu
+from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
 from mlir.extras import types as T
 from mlir import runtime as rt
 from tools import nvgpucompiler
@@ -84,7 +84,9 @@ def arrive(self, txcount: int = 0, predicate=None):
                 self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
             )
         else:
-            nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op, predicate=predicate)
+            nvgpu.mbarrier_arrive(
+                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
+            )
 
     def try_wait(self, phase: bool = False, ticks: int = 10000000):
         ticks_op = const(ticks)
@@ -166,6 +168,33 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
         )
 
 
+WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup
+
+
+class Warpgroup:
+    def __init__(self, primaryThread, registerSize):
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        self.primary_thread = primaryThread
+        self.register_size = registerSize
+        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
+        self.wg_id = tidx / WARP_GROUP_SIZE
+        self.is_me = self.wg_id == (primaryThread // WARP_GROUP_SIZE)
+
+    def __enter__(self):
+        if_op = scf.IfOp(self.is_me)
+        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
+        self.ipoint_op.__enter__()
+        if self.register_size < 64:  # small budget: release registers
+            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
+        else:
+            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        scf.yield_([])
+        self.ipoint_op.__exit__(exc_type, exc_value, traceback)
+        return False  # never swallow exceptions raised inside the `with` body
+
+
 class WGMMAType(Enum):
     Accumulator = 1
     Descriptor = 2
@@ -408,7 +437,7 @@ def __str__(self):
                 module.operation.verify()
 
                 # Save IR in a file
-                saveIR(module)
+                # saveIR(module)
 
                 # Compile and JIT MLIR module
                 options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"

>From 77910b74b3d6e7fd77972b0a9449a8f22ac48a2f Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 12 Apr 2024 17:44:13 +0000
Subject: [PATCH 13/14] fix comments

---
 mlir/test/Examples/nvgpu/Ch0.py         | 16 +++++++++++-----
 mlir/test/Examples/nvgpu/Ch1.py         |  2 +-
 mlir/test/Examples/nvgpu/Ch5.py         |  6 +++---
 mlir/test/Examples/nvgpu/tools/nvdsl.py |  8 ++++----
 4 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
index 221ca43d37307a..0c3cc30fde85e4 100644
--- a/mlir/test/Examples/nvgpu/Ch0.py
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -7,7 +7,7 @@
 #
 # This program demonstrates Hello World
 #
-# This chapter introduces demonstrates:
+# This chapter demonstrates:
 #   1. Build MLIR function with arguments
 #   2. Build MLIR GPU kernel
 #   3. Print from a GPU thread
@@ -20,13 +20,19 @@
 from tools.nvdsl import *
 
 
-# 1. Build function with arguments
+# 1. The `mlir_func` decorator generates an MLIR `func.func`.
+#    Everything inside the Python function becomes the body of the `func`.
+#    The decorator also translates `alpha` to an `index` type.
 @NVDSL.mlir_func
 def main(alpha):
-    # 2. Build GPU kernel
+    # 2. The `mlir_gpu_launch` decorator generates an MLIR `gpu.launch`.
+    #    Everything inside the Python function becomes the body of the `gpu.launch`.
+    #    This allows for late outlining of the GPU kernel, enabling optimizations
+    #    like constant folding from host to device.
     @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
     def kernel():
         tidx = gpu.thread_id(gpu.Dimension.x)
+        # The `+` operator generates an `arith.addi`
         myValue = alpha + tidx
         # Print from a GPU thread
         gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
@@ -34,12 +40,12 @@ def kernel():
     # 3. Call the GPU kernel
     kernel()
 
-
-# 4. Pass arguments, JIT compile and run the MLIR function
 alpha = 100
+# 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
 main(alpha)
 
 
+
 # CHECK: GPU thread 0 has 100
 # CHECK: GPU thread 1 has 101
 # CHECK: GPU thread 2 has 102
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
index 22e1eb7c92e005..315480138f66c8 100644
--- a/mlir/test/Examples/nvgpu/Ch1.py
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -7,7 +7,7 @@
 #
 # This program demonstrates 2D Saxpy
 #
-# This chapter introduces demonstrates:
+# This chapter demonstrates:
 #  1. Use MLIR GPU dialect to allocate and copy memory
 #  2. Compute 2D SAXPY kernel
 #  3. Pass numpy arrays to MLIR
diff --git a/mlir/test/Examples/nvgpu/Ch5.py b/mlir/test/Examples/nvgpu/Ch5.py
index 7881af2ab54439..a22de802452792 100644
--- a/mlir/test/Examples/nvgpu/Ch5.py
+++ b/mlir/test/Examples/nvgpu/Ch5.py
@@ -275,8 +275,8 @@ def gemm_warp_specialized(a, b, d):
     @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
     def gemm_warp_specialized_kernel():
         # Init Warpgroups
-        wg_producer = Warpgroup(PRODUCER_PRIMARY_THREAD, PRODUCER_REGISTER_SIZE)
-        wg_consumer = Warpgroup(CONSUMER_PRIMARY_THREAD, CONSUMER_REGISTER_SIZE)
+        wg_producer = Warpgroup(primary_thread=128, register_size=40)
+        wg_consumer = Warpgroup(primary_thread=0, register_size=232)
 
         # Initialize mbarriers and prefetch TMA descriptors
         mbar_group_mma, mbar_group_tma = bootstrap(a_tma, b_tma)
@@ -285,7 +285,7 @@ def gemm_warp_specialized_kernel():
         with wg_producer:
             producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer)
 
-        # Producer performs MMA/Tensor Core
+        # Consumer performs MMA/Tensor Core
         with wg_consumer:
             D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer)
             epilogue(D, d_dev)
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 7d9e13c7f8627c..2c1615d6f9d68b 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -172,13 +172,13 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
 
 
 class Warpgroup:
-    def __init__(self, primaryThread, registerSize):
+    def __init__(self, primary_thread, register_size):
         tidx = gpu.thread_id(gpu.Dimension.x)
-        self.primary_thread = primaryThread
-        self.register_size = registerSize
+        self.primary_thread = primary_thread
+        self.register_size = register_size
         self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
         self.wg_id = tidx / WARP_GROUP_SIZE
-        self.is_me = self.wg_id == (primaryThread // WARP_GROUP_SIZE)
+        self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)
 
     def __enter__(self):
         if_op = scf.IfOp(self.is_me)

>From 6b0a3d13328ba98a5c8b96dcea241dc06b8864ca Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 12 Apr 2024 18:19:44 +0000
Subject: [PATCH 14/14] address comments

---
 mlir/test/Examples/nvgpu/Ch0.py         |   4 +-
 mlir/test/Examples/nvgpu/Ch1.py         |  12 +-
 mlir/test/Examples/nvgpu/Ch2.py         |   8 +-
 mlir/test/Examples/nvgpu/Ch4.py         |  51 ++++---
 mlir/test/Examples/nvgpu/Ch5.py         |  65 +++++----
 mlir/test/Examples/nvgpu/lit.local.cfg  |   3 +-
 mlir/test/Examples/nvgpu/nvdsl.mlir     | 186 ++++++++++++++++++++++++
 mlir/test/Examples/nvgpu/tools/nvdsl.py |  10 +-
 8 files changed, 264 insertions(+), 75 deletions(-)
 create mode 100644 mlir/test/Examples/nvgpu/nvdsl.mlir

diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
index 0c3cc30fde85e4..ff294c7ec416ef 100644
--- a/mlir/test/Examples/nvgpu/Ch0.py
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -5,9 +5,7 @@
 #  Chapter 0 : Hello World
 # ===----------------------------------------------------------------------===//
 #
-# This program demonstrates Hello World
-#
-# This chapter demonstrates:
+# This program demonstrates Hello World:
 #   1. Build MLIR function with arguments
 #   2. Build MLIR GPU kernel
 #   3. Print from a GPU thread
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
index 315480138f66c8..da65aa2ef6a172 100644
--- a/mlir/test/Examples/nvgpu/Ch1.py
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -5,13 +5,11 @@
 #  Chapter 1 : 2D Saxpy
 # ===----------------------------------------------------------------------===//
 #
-# This program demonstrates 2D Saxpy
-#
-# This chapter demonstrates:
-#  1. Use MLIR GPU dialect to allocate and copy memory
-#  2. Compute 2D SAXPY kernel
-#  3. Pass numpy arrays to MLIR
-#  4. Verify MLIR with reference computation
+# This program demonstrates 2D Saxpy:
+#  1. Use GPU dialect to allocate and copy memory from host to GPU and vice versa
+#  2. Compute 2D SAXPY kernel using operator overloading
+#  3. Pass numpy arrays to MLIR as memref arguments
+#  4. Verify MLIR program with reference computation in Python
 #
 # ===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index 2e316cf406e99f..78c14cb2c7ad8c 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -9,9 +9,11 @@
 # but it loads data using TMA (Tensor Memory Accelerator)
 #
 # This chapter introduces demonstrates:
-#  1. Create and initialize asynchronous transactional barrier (mbarrier)
-#  2. Execute Tensor Memory Accelerator (TMA) Load
-#  3. Wait for completion of TMA load with mbarrier
+#  1. Compute 2D SAXPY in the same way as Ch1.py but load data using TMA
+#  2. Create and initialize 1 asynchronous transactional barrier (mbarrier)
+#  3. Thread 0 issues the TMA load request for each thread block
+#  4. Each thread block loads <1x32xf32> for x and y
+#  5. Wait for completion of TMA load with mbarrier
 #
 # ===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 47a81cd9a65688..175d0aad5f77d7 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -17,7 +17,7 @@
 #
 # Loops illustration:
 #
-#  for s in range(NUM_STAGES):
+#  for s in range(num_stages):
 #    TMA_128x64_64x128...
 #  for ti in range(M//128):  # -> blockIdx.x
 #   for tj in range(N//128): # -> blockIdx.y
@@ -74,6 +74,7 @@ def tma_load(
     b_tma: TMA,
     slot,
     stage,
+    num_stages,
     p=None,
 ):
     """
@@ -88,7 +89,7 @@ def tma_load(
     dimX, dimY = partition_shape()
 
     tidx = gpu.thread_id(gpu.Dimension.x)
-    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+    begin_b = num_stages * get_type_size(a_tma.tma_memref)
     size_tma_a = get_type_size(a_tma.tma_memref)
     size_tma_b = get_type_size(b_tma.tma_memref)
     ta_count = size_tma_a + (size_tma_b * 2)
@@ -113,15 +114,15 @@ def tma_load(
     b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
 
 
-def bootstrap(a_tma: TMA, b_tma: TMA):
+def initialize(a_tma: TMA, b_tma: TMA, num_stages):
     """
     Initialize mbarriers and prefetch TMA descriptors.
     """
     tidx = gpu.thread_id(gpu.Dimension.x)
-    mbar_group = Mbarriers(number_of_barriers=NUM_STAGES)
+    mbar_group = Mbarriers(number_of_barriers=num_stages)
     isThread0 = tidx == const(0)
     with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
-        for i in scf.for_(0, NUM_STAGES, 1):
+        for i in scf.for_(0, num_stages, 1):
             mbar_group[i].init(1)
             scf.yield_([])
         a_tma.prefetch()
@@ -131,7 +132,7 @@ def bootstrap(a_tma: TMA, b_tma: TMA):
     return mbar_group
 
 
-def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
+def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
     """
     Prologue of the GEMM kernel. It loads 2 input matrices for each stage in loop like below:
 
@@ -139,13 +140,13 @@ def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
         tma_load x, y, stage
 
     """
-    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+    ns = num_stages if num_stages == 1 else num_stages - 1
     for iv in scf.for_(0, ns, 1):
-        tma_load(mbar_group, a_tma, b_tma, iv, iv)
+        tma_load(mbar_group, a_tma, b_tma, iv, iv, num_stages)
         scf.yield_([])
 
 
-def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
+def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
     """
     Main loop of the Multistage GEMM kernel. It iterates through
     stages and performs matrix multiplication, loading data by TMA to shared memory. It like following
@@ -163,10 +164,10 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
             tma_load(x, y, nextSlot, nextStage)
 
     """
-    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+    ns = num_stages if num_stages == 1 else num_stages - 1
 
     tidx = gpu.thread_id(gpu.Dimension.x)
-    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+    begin_b = num_stages * get_type_size(a_tma.tma_memref)
 
     size_a = TILE_M * TILE_K * get_type_size(T.f16())
 
@@ -182,7 +183,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
     with ir.InsertionPoint(for_op.body):
         phase = for_op.inner_iter_args[1]
         iv = for_op.induction_variable
-        stage = iv % NUM_STAGES
+        stage = iv % num_stages
 
         # Wait for current stage
         mbar_group[stage].try_wait(phase=phase)
@@ -202,18 +203,18 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
         D += A @ B
 
         # Wait Tensor Core for single stage
-        if NUM_STAGES == 1:
+        if num_stages == 1:
             nvvm.WgmmaWaitGroupSyncOp(0)
 
         # Load next stage
         pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
         nextStage = iv + ns
-        nextSlot = nextStage % NUM_STAGES
-        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, pred)
+        nextSlot = nextStage % num_stages
+        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)
 
         # Switch phase parity for the mbarrier
         newPhase = arith.select(
-            stage == (NUM_STAGES - 1),
+            stage == (num_stages - 1),
             (phase ^ const(True, ty=T.bool())),
             phase,
         )
@@ -250,9 +251,12 @@ def epilogue(D: WGMMAMatrix, d_dev):
         memref.store(val, d_gmem, [i, tidx])
         scf.yield_([])
 
-
+# The decorator generates the following memref arguments:
+#   a -> memref<MxKxf16>
+#   b -> memref<KxNxf16>
+#   d -> memref<MxNxf32>
 @NVDSL.mlir_func
-def gemm_multistage(a, b, d):
+def gemm_multistage(a, b, d, num_stages):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
@@ -273,18 +277,18 @@ def gemm_multistage(a, b, d):
 
     size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
     size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
-    smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+    smem_size_in_bytes = (size_a + size_b) * num_stages
 
     @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
     def gemm_multistage_kernel():
         # Initialize mbarriers and prefetch TMA descriptors
-        mbar_group = bootstrap(a_tma, b_tma)
+        mbar_group = initialize(a_tma, b_tma, num_stages)
 
         # Fill the pipeline stages
-        prologue(mbar_group, a_tma, b_tma)
+        prologue(mbar_group, a_tma, b_tma, num_stages)
 
         # Main loop
-        D = mainloop(mbar_group, a_tma, b_tma)
+        D = mainloop(mbar_group, a_tma, b_tma, num_stages)
 
         # Store registers to global memory
         epilogue(D, d_dev)
@@ -296,7 +300,6 @@ def gemm_multistage_kernel():
 
 
 # Python pass arguments to MLIR
-NUM_STAGES = 7
 N = 256
 M = 512
 K = 1024
@@ -307,7 +310,7 @@ def gemm_multistage_kernel():
 b = np.random.randn(K, N).astype(np.float16)
 d = np.zeros((M, N), np.float32)
 
-gemm_multistage(a, b, d)
+gemm_multistage(a, b, d, num_stages=7)
 
 
 # Verify MLIR with reference computation
diff --git a/mlir/test/Examples/nvgpu/Ch5.py b/mlir/test/Examples/nvgpu/Ch5.py
index a22de802452792..daffb38ba0db7e 100644
--- a/mlir/test/Examples/nvgpu/Ch5.py
+++ b/mlir/test/Examples/nvgpu/Ch5.py
@@ -5,7 +5,7 @@
 #  Chapter 5 : Warp Specialized GEMM with Tensor Core
 # ===----------------------------------------------------------------------===//
 #
-# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# This program demonstrates a GEMM operation for `f32+=f16*f16`, utilizing the
 # Warp Specialized method with a tile size of 128x128x64. The code completely
 # parallelizes the two outermost loops into thread blocks. It launches two Warp
 # Groups (256 threads in total): one for the producer and the other for the consumer.
@@ -48,20 +48,19 @@
 import numpy as np
 
 
-PRODUCER_PRIMARY_THREAD = 128  # Producer primary thread
-CONSUMER_PRIMARY_THREAD = 0  # Consumer primary thread
-PRODUCER_REGISTER_SIZE = 40  # Producer primary thread
-CONSUMER_REGISTER_SIZE = 232  # Consumer primary thread
-
-
 def partition_shape():
     """
     Calculate the partition shape based on the block IDs.
 
-    It partitions the shape like below:
-    for(.. i < M ...)   --> blockIdx.x
-     for(.. j < N ...)  --> blockIdx.y
-      for(.. k < K ...)
+    It parallelizes the two outermost loops into thread blocks.
+    for ti in range(M//128):    # -> blockIdx.x
+     for tj in range(N//128):   # -> blockIdx.y
+      D = 0
+      for tk in range(K//64):
+       for i in range(128):
+        for j in range(128):
+         for k in range(64):
+           FMA
 
     Returns:
         dimX (int): Dimension along the x-axis.
@@ -80,6 +79,7 @@ def tma_load(
     b_tma: TMA,
     slot,
     stage,
+    num_stages,
     p=None,
 ):
     """
@@ -94,7 +94,7 @@ def tma_load(
     dimX, dimY = partition_shape()
 
     tidx = gpu.thread_id(gpu.Dimension.x)
-    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+    begin_b = num_stages * get_type_size(a_tma.tma_memref)
     size_tma_a = get_type_size(a_tma.tma_memref)
     size_tma_b = get_type_size(b_tma.tma_memref)
     ta_count = size_tma_a + (size_tma_b * 2)
@@ -116,16 +116,16 @@ def tma_load(
     b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
 
 
-def bootstrap(a_tma: TMA, b_tma: TMA):
+def initialize(a_tma: TMA, b_tma: TMA, num_stages):
     """
     Initialize mbarriers and prefetch TMA descriptors.
     """
     tidx = gpu.thread_id(gpu.Dimension.x)
-    mbar_group_tma = Mbarriers(number_of_barriers=NUM_STAGES)
-    mbar_group_mma = Mbarriers(number_of_barriers=NUM_STAGES)
+    mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
+    mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
     isThread0 = tidx == const(0)
     with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
-        for i in scf.for_(0, NUM_STAGES, 1):
+        for i in scf.for_(0, num_stages, 1):
             mbar_group_tma[i].init(1)
             mbar_group_mma[i].init(1)
             scf.yield_([])
@@ -136,8 +136,8 @@ def bootstrap(a_tma: TMA, b_tma: TMA):
     return mbar_group_tma, mbar_group_mma
 
 
-def switch_phase(stage, phase):
-    p = stage == (NUM_STAGES - 1)
+def switch_phase(stage, phase, num_stages):
+    p = stage == (num_stages - 1)
     phase = arith.select(
         p,
         (phase ^ const(True, ty=T.bool())),
@@ -152,17 +152,18 @@ def producer_loop(
     a_tma: TMA,
     b_tma: TMA,
     wg_me: Warpgroup,
+    num_stages
 ):
     phase = const(True, ty=T.bool())
 
     for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
-        stage = iv % NUM_STAGES
+        stage = iv % num_stages
         # Wait MMA to be done
         mbar_group_mma[stage].try_wait(phase)
         # New phase for mbarrier
-        phase = switch_phase(stage, phase)
+        phase = switch_phase(stage, phase, num_stages)
         # TMA Load
-        tma_load(mbar_group_tma, a_tma, b_tma, stage, iv, wg_me.is_wg_primary)
+        tma_load(mbar_group_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
         scf.yield_([phase])
 
 
@@ -172,8 +173,9 @@ def consumer_loop(
     a_tma: TMA,
     b_tma: TMA,
     wg_me: Warpgroup,
+    num_stages
 ):
-    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+    begin_b = num_stages * get_type_size(a_tma.tma_memref)
 
     size_a = TILE_M * TILE_K * get_type_size(T.f16())
 
@@ -186,7 +188,7 @@ def consumer_loop(
     with ir.InsertionPoint(for_op.body):
         phase = for_op.inner_iter_args[1]
         iv = for_op.induction_variable
-        stage = iv % NUM_STAGES
+        stage = iv % num_stages
 
         # Wait TMA for current stage
         mbar_group_tma[stage].try_wait(phase)
@@ -208,11 +210,11 @@ def consumer_loop(
         # MMA Barrier Arrive
         p_arrive = (iv > 0) & wg_me.is_wg_primary
         with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
-            barId = arith.select((stage == 0), const(NUM_STAGES - 1), (stage - 1))
+            barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
             mbar_group_mma[barId].arrive()
             scf.yield_([])
 
-        phase = switch_phase(stage, phase)
+        phase = switch_phase(stage, phase, num_stages)
         scf.yield_([D.acc_op, phase])
 
     nvvm.WgmmaWaitGroupSyncOp(0)
@@ -249,7 +251,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
 
 
 @NVDSL.mlir_func
-def gemm_warp_specialized(a, b, d):
+def gemm_warp_specialized(a, b, d, num_stages):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
@@ -270,7 +272,7 @@ def gemm_warp_specialized(a, b, d):
 
     size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
     size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
-    smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+    smem_size_in_bytes = (size_a + size_b) * num_stages
 
     @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
     def gemm_warp_specialized_kernel():
@@ -279,15 +281,15 @@ def gemm_warp_specialized_kernel():
         wg_consumer = Warpgroup(primary_thread=0, register_size=232)
 
         # Initialize mbarriers and prefetch TMA descriptors
-        mbar_group_mma, mbar_group_tma = bootstrap(a_tma, b_tma)
+        mbar_group_mma, mbar_group_tma = initialize(a_tma, b_tma, num_stages)
 
         # Producer performs TMA
         with wg_producer:
-            producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer)
+            producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer, num_stages)
 
         # Consumer performs MMA/Tensor Core
         with wg_consumer:
-            D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer)
+            D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer, num_stages)
             epilogue(D, d_dev)
 
     gemm_warp_specialized_kernel()
@@ -297,7 +299,6 @@ def gemm_warp_specialized_kernel():
 
 
 # Python pass arguments to MLIR
-NUM_STAGES = 7
 N = 256
 M = 512
 K = 1024
@@ -308,7 +309,7 @@ def gemm_warp_specialized_kernel():
 b = np.random.randn(K, N).astype(np.float16)
 d = np.zeros((M, N), np.float32)
 
-gemm_warp_specialized(a, b, d)
+gemm_warp_specialized(a, b, d, num_stages = 7)
 
 
 # Verify MLIR with reference computation
diff --git a/mlir/test/Examples/nvgpu/lit.local.cfg b/mlir/test/Examples/nvgpu/lit.local.cfg
index e586b55573898a..689cd252e7a254 100644
--- a/mlir/test/Examples/nvgpu/lit.local.cfg
+++ b/mlir/test/Examples/nvgpu/lit.local.cfg
@@ -1,3 +1,4 @@
 config.unsupported = False
 if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
-  config.unsupported = True
\ No newline at end of file
+  config.unsupported = True
+  
\ No newline at end of file
diff --git a/mlir/test/Examples/nvgpu/nvdsl.mlir b/mlir/test/Examples/nvgpu/nvdsl.mlir
new file mode 100644
index 00000000000000..e61f7ba0e9a67d
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/nvdsl.mlir
@@ -0,0 +1,186 @@
+module {
+  func.func @gemm_warp_specialized(%arg0: memref<512x1024xf16>, %arg1: memref<1024x256xf16>, %arg2: memref<512x256xf32>) attributes {llvm.emit_c_interface} {
+    %0 = gpu.wait async
+    %memref, %asyncToken = gpu.alloc async [%0] () : memref<512x1024xf16>
+    %memref_0, %asyncToken_1 = gpu.alloc async [%asyncToken] () : memref<1024x256xf16>
+    %memref_2, %asyncToken_3 = gpu.alloc async [%asyncToken_1] () : memref<512x256xf32>
+    %1 = gpu.memcpy async [%asyncToken_3] %memref, %arg0 : memref<512x1024xf16>, memref<512x1024xf16>
+    %2 = gpu.memcpy async [%1] %memref_0, %arg1 : memref<1024x256xf16>, memref<1024x256xf16>
+    %3 = gpu.wait async [%2]
+    %cast = memref.cast %memref : memref<512x1024xf16> to memref<*xf16>
+    %c128 = arith.constant 128 : index
+    %c64 = arith.constant 64 : index
+    %4 = nvgpu.tma.create.descriptor %cast box[%c128, %c64] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+    %cast_4 = memref.cast %memref_0 : memref<1024x256xf16> to memref<*xf16>
+    %c64_5 = arith.constant 64 : index
+    %c64_6 = arith.constant 64 : index
+    %5 = nvgpu.tma.create.descriptor %cast_4 box[%c64_5, %c64_6] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+    %c4 = arith.constant 4 : index
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %c256 = arith.constant 256 : index
+    %c1_7 = arith.constant 1 : index
+    %c1_8 = arith.constant 1 : index
+    %c229376_i32 = arith.constant 229376 : i32
+    gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c4, %arg10 = %c2, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c256, %arg13 = %c1_7, %arg14 = %c1_8) dynamic_shared_memory_size %c229376_i32 {
+      %thread_id_x = gpu.thread_id  x
+      %c128_9 = arith.constant 128 : index
+      %7 = arith.remui %thread_id_x, %c128_9 : index
+      %c0 = arith.constant 0 : index
+      %8 = arith.cmpi eq, %7, %c0 : index
+      %c128_10 = arith.constant 128 : index
+      %9 = arith.divui %thread_id_x, %c128_10 : index
+      %c1_11 = arith.constant 1 : index
+      %10 = arith.cmpi eq, %9, %c1_11 : index
+      %thread_id_x_12 = gpu.thread_id  x
+      %c128_13 = arith.constant 128 : index
+      %11 = arith.remui %thread_id_x_12, %c128_13 : index
+      %c0_14 = arith.constant 0 : index
+      %12 = arith.cmpi eq, %11, %c0_14 : index
+      %c128_15 = arith.constant 128 : index
+      %13 = arith.divui %thread_id_x_12, %c128_15 : index
+      %c0_16 = arith.constant 0 : index
+      %14 = arith.cmpi eq, %13, %c0_16 : index
+      %thread_id_x_17 = gpu.thread_id  x
+      %15 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+      %16 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+      %c0_18 = arith.constant 0 : index
+      %17 = arith.cmpi eq, %thread_id_x_17, %c0_18 : index
+      scf.if %17 {
+        %c0_19 = arith.constant 0 : index
+        %c7 = arith.constant 7 : index
+        %c1_20 = arith.constant 1 : index
+        scf.for %arg15 = %c0_19 to %c7 step %c1_20 {
+          %c1_21 = arith.constant 1 : index
+          nvgpu.mbarrier.init %15[%arg15], %c1_21 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+          %c1_22 = arith.constant 1 : index
+          nvgpu.mbarrier.init %16[%arg15], %c1_22 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+        }
+        nvgpu.tma.prefetch.descriptor %4 : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+        nvgpu.tma.prefetch.descriptor %5 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+      }
+      scf.if %10 {
+        nvvm.setmaxregister  decrease 40
+        %true = arith.constant true
+        %c0_19 = arith.constant 0 : index
+        %c16 = arith.constant 16 : index
+        %c1_20 = arith.constant 1 : index
+        %18 = scf.for %arg15 = %c0_19 to %c16 step %c1_20 iter_args(%arg16 = %true) -> (i1) {
+          %c7 = arith.constant 7 : index
+          %19 = arith.remui %arg15, %c7 : index
+          %c10000000 = arith.constant 10000000 : index
+          nvgpu.mbarrier.try_wait.parity %15[%19], %arg16, %c10000000 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+          %c6 = arith.constant 6 : index
+          %20 = arith.cmpi eq, %19, %c6 : index
+          %true_21 = arith.constant true
+          %21 = arith.xori %arg16, %true_21 : i1
+          %22 = arith.select %20, %21, %arg16 : i1
+          %block_id_x = gpu.block_id  x
+          %block_id_y = gpu.block_id  y
+          %c128_22 = arith.constant 128 : index
+          %23 = arith.muli %block_id_x, %c128_22 : index
+          %c128_23 = arith.constant 128 : index
+          %24 = arith.muli %block_id_y, %c128_23 : index
+          %thread_id_x_24 = gpu.thread_id  x
+          %c16384 = arith.constant 16384 : index
+          %25 = arith.muli %19, %c16384 : index
+          %c16384_25 = arith.constant 16384 : index
+          %26 = arith.muli %19, %c16384_25 : index
+          %c114688 = arith.constant 114688 : index
+          %27 = arith.addi %26, %c114688 : index
+          %c8192 = arith.constant 8192 : index
+          %28 = arith.addi %27, %c8192 : index
+          %29 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+          %view = memref.view %29[%25][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+          %30 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+          %view_26 = memref.view %30[%27][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+          %31 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+          %view_27 = memref.view %31[%28][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+          %c32768 = arith.constant 32768 : index
+          nvgpu.mbarrier.arrive.expect_tx %16[%19], %c32768, predicate = %8 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+          %c128_28 = arith.constant 128 : index
+          %32 = arith.remui %thread_id_x_24, %c128_28 : index
+          %c0_29 = arith.constant 0 : index
+          %33 = arith.cmpi eq, %32, %c0_29 : index
+          %c64_30 = arith.constant 64 : index
+          %34 = arith.muli %arg15, %c64_30 : index
+          nvgpu.tma.async.load %4[%34, %23], %16[%19] to %view, predicate = %33 : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+          nvgpu.tma.async.load %5[%24, %34], %16[%19] to %view_26, predicate = %33 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+          %c64_31 = arith.constant 64 : index
+          %35 = arith.addi %24, %c64_31 : index
+          nvgpu.tma.async.load %5[%35, %34], %16[%19] to %view_27, predicate = %33 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+          scf.yield %22 : i1
+        }
+      }
+      scf.if %14 {
+        nvvm.setmaxregister  increase 232
+        %false = arith.constant false
+        %18 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+        %c0_19 = arith.constant 0 : index
+        %c16 = arith.constant 16 : index
+        %c1_20 = arith.constant 1 : index
+        %19:2 = scf.for %arg15 = %c0_19 to %c16 step %c1_20 iter_args(%arg16 = %18, %arg17 = %false) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) {
+          %c7 = arith.constant 7 : index
+          %23 = arith.remui %arg15, %c7 : index
+          %c10000000 = arith.constant 10000000 : index
+          nvgpu.mbarrier.try_wait.parity %16[%23], %arg17, %c10000000 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+          %c16384 = arith.constant 16384 : index
+          %24 = arith.muli %23, %c16384 : index
+          %c114688 = arith.constant 114688 : index
+          %25 = arith.addi %24, %c114688 : index
+          %26 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+          %view_28 = memref.view %26[%24][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+          %27 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+          %view_29 = memref.view %27[%25][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+          %28 = nvgpu.warpgroup.generate.descriptor %view_28, %4 : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+          %29 = nvgpu.warpgroup.generate.descriptor %view_29, %5 : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+          %30 = nvgpu.warpgroup.mma %28, %29, %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+          %c0_30 = arith.constant 0 : index
+          %31 = arith.cmpi ugt, %arg15, %c0_30 : index
+          %32 = arith.andi %31, %12 : i1
+          scf.if %32 {
+            %c0_31 = arith.constant 0 : index
+            %36 = arith.cmpi eq, %23, %c0_31 : index
+            %c6_32 = arith.constant 6 : index
+            %c1_33 = arith.constant 1 : index
+            %37 = arith.subi %23, %c1_33 : index
+            %38 = arith.select %36, %c6_32, %37 : index
+            %39 = nvgpu.mbarrier.arrive %15[%38] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> !nvgpu.mbarrier.token
+          }
+          %c6 = arith.constant 6 : index
+          %33 = arith.cmpi eq, %23, %c6 : index
+          %true = arith.constant true
+          %34 = arith.xori %arg17, %true : i1
+          %35 = arith.select %33, %34, %arg17 : i1
+          scf.yield %30, %35 : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1
+        }
+        nvvm.wgmma.wait.group.sync.aligned 0
+        %thread_id_x_21 = gpu.thread_id  x
+        %block_id_x = gpu.block_id  x
+        %block_id_y = gpu.block_id  y
+        %c128_22 = arith.constant 128 : index
+        %20 = arith.muli %block_id_x, %c128_22 : index
+        %c128_23 = arith.constant 128 : index
+        %21 = arith.muli %block_id_y, %c128_23 : index
+        %22 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+        %c0_24 = arith.constant 0 : index
+        %view = memref.view %22[%c0_24][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+        %subview = memref.subview %memref_2[%20, %21] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
+        nvgpu.warpgroup.mma.store %19#0, %view : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+        gpu.barrier
+        %c0_25 = arith.constant 0 : index
+        %c128_26 = arith.constant 128 : index
+        %c1_27 = arith.constant 1 : index
+        scf.for %arg15 = %c0_25 to %c128_26 step %c1_27 {
+          %23 = memref.load %view[%arg15, %thread_id_x_21] : memref<128x128xf32, #gpu.address_space<workgroup>>
+          memref.store %23, %subview[%arg15, %thread_id_x_21] : memref<128x128xf32, strided<[256, 1], offset: ?>>
+        }
+      }
+      gpu.terminator
+    }
+    %6 = gpu.memcpy async [%3] %arg2, %memref_2 : memref<512x256xf32>, memref<512x256xf32>
+    gpu.wait [%6]
+    return
+  }
+}
+
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 2c1615d6f9d68b..06faf692346d03 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -173,6 +173,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
 
 class Warpgroup:
     def __init__(self, primary_thread, register_size):
+        assert (primary_thread % WARP_GROUP_SIZE) == 0
         tidx = gpu.thread_id(gpu.Dimension.x)
         self.primary_thread = primary_thread
         self.register_size = register_size
@@ -280,7 +281,6 @@ def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
     )
 
 
- at staticmethod
 def get_mlir_ty(arg):
     def get_mlir_ty_from_np(dtype):
         if dtype == np.float16:
@@ -433,11 +433,11 @@ def __str__(self):
                         result = funcBody(*fargs, **kwargs)
                         func.ReturnOp([])
 
-                # Verify the module
-                module.operation.verify()
-
                 # Save IR in a file
-                # saveIR(module)
+                saveIR(module)
+
+                # Verify the module
+                # module.operation.verify()
 
                 # Compile and JIT MLIR module
                 options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"



More information about the Mlir-commits mailing list