[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

Guray Ozen llvmlistbot at llvm.org
Mon Apr 1 07:06:36 PDT 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/87065

>From 5a2e21c9c071a8ba2b7a4cbbf68dc65889256efd Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 29 Mar 2024 13:09:45 +0000
Subject: [PATCH 1/3] [mlir][nvgpu] Zero to Hero: Programming Nvidia Hopper
 Tensor Core with MLIR's NVGPU Dialect

I am giving a tutorial on the nvgpu dialect using the Python bindings at EuroLLVM 2024, and I implemented the tutorial code in Python for it. The focus is the nvgpu dialect and how to use its advanced features. I thought it might be useful to upstream this.

The tutorial codes are as follows:
- **Ch0.py:** Hello World
- **Ch1.py:** 2D Saxpy
- **Ch2.py:** 2D Saxpy using TMA engine
- **Ch3.py:** GEMM 64x64x64 using Tensor Core and TMA
- **Ch4.py:** Multistage performant GEMM using Tensor Core and TMA

I might implement two more chapters, but they are more about GPU programming than about the compiler.
- **Ch5.py:** Warp Specialized GEMM
- **Ch6.py:** Warp Specialized Persistent GEMM

This PR also introduces the nvdsl class, making IR building in the tutorial easier.
---
 mlir/test/Examples/nvgpu/Ch0.py               |  46 ++
 mlir/test/Examples/nvgpu/Ch1.py               |  68 +++
 mlir/test/Examples/nvgpu/Ch2.py               |  94 ++++
 mlir/test/Examples/nvgpu/Ch3.py               |  92 ++++
 mlir/test/Examples/nvgpu/Ch4.py               | 291 +++++++++++++
 mlir/test/Examples/nvgpu/lit.local.cfg        |   3 +
 mlir/test/Examples/nvgpu/tools/lit.local.cfg  |   3 +
 mlir/test/Examples/nvgpu/tools/nvdsl.py       | 404 ++++++++++++++++++
 .../Examples/nvgpu/tools/nvgpucompiler.py     |  45 ++
 9 files changed, 1046 insertions(+)
 create mode 100644 mlir/test/Examples/nvgpu/Ch0.py
 create mode 100644 mlir/test/Examples/nvgpu/Ch1.py
 create mode 100644 mlir/test/Examples/nvgpu/Ch2.py
 create mode 100644 mlir/test/Examples/nvgpu/Ch3.py
 create mode 100644 mlir/test/Examples/nvgpu/Ch4.py
 create mode 100644 mlir/test/Examples/nvgpu/lit.local.cfg
 create mode 100644 mlir/test/Examples/nvgpu/tools/lit.local.cfg
 create mode 100644 mlir/test/Examples/nvgpu/tools/nvdsl.py
 create mode 100644 mlir/test/Examples/nvgpu/tools/nvgpucompiler.py

diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
new file mode 100644
index 00000000000000..221ca43d37307a
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -0,0 +1,46 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 0 : Hello World
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates Hello World
+#
+# This chapter introduces demonstrates:
+#   1. Build MLIR function with arguments
+#   2. Build MLIR GPU kernel
+#   3. Print from a GPU thread
+#   4. Pass arguments, JIT compile and run the MLIR function
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir.dialects import gpu
+from tools.nvdsl import *
+
+
+# 1. Build function with arguments
+ at NVDSL.mlir_func
+def main(alpha):
+    # 2. Build GPU kernel
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
+    def kernel():
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        myValue = alpha + tidx
+        # Print from a GPU thread
+        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
+
+    # 3. Call the GPU kernel
+    kernel()
+
+
+# 4. Pass arguments, JIT compile and run the MLIR function
+alpha = 100
+main(alpha)
+
+
+# CHECK: GPU thread 0 has 100
+# CHECK: GPU thread 1 has 101
+# CHECK: GPU thread 2 has 102
+# CHECK: GPU thread 3 has 103
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
new file mode 100644
index 00000000000000..a888c61358a024
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -0,0 +1,68 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 1 : 2D Saxpy
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy
+#
+# This chapter introduces demonstrates:
+#  1. Use MLIR GPU dialect to allocate and copy memory
+#  2. Compute 2D SAXPY kernel
+#  3. Pass numpy arrays to MLIR
+#  4. Verify MLIR with reference computation
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, memref
+from tools.nvdsl import *
+import numpy as np
+
+
+ at NVDSL.mlir_func
+def saxpy(x, y, alpha):
+    # 1. Use MLIR GPU dialect to allocate and copy memory
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+    t6 = gpu.wait(token_ty, [t5])
+
+    # 2. Compute 2D SAXPY kernel
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
+    def saxpy_kernel():
+        bidx = gpu.block_id(gpu.Dimension.x)
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        x_val = memref.load(x_dev, [bidx, tidx])
+        y_val = memref.load(y_dev, [bidx, tidx])
+
+        # SAXPY: y[i] += a * x[i];
+        y_val += x_val * alpha
+
+        memref.store(y_val, y_dev, [bidx, tidx])
+
+    saxpy_kernel()
+
+    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+    gpu.wait(token_ty, [t7])
+
+
+# 3. Pass numpy arrays to MLIR
+M = 256
+N = 32
+alpha = 2.0
+x = np.ones((M, N), np.float32)
+y = np.ones((M, N), np.float32)
+ref = np.ones((M, N), np.float32)
+saxpy(x, y, alpha)
+
+#  4. Verify MLIR with reference computation
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
new file mode 100644
index 00000000000000..caa2691a06d69e
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -0,0 +1,94 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 2 : 2D Saxpy with TMA
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy. It is same as Chapter 1, 
+# but it loads data using TMA (Tensor Memory Accelerator)
+#
+# This chapter introduces demonstrates:
+#  1. Create and initialize asynchronous transactional barrier (mbarrier)
+#  2. Execute Tensor Memory Accelerator (TMA) Load
+#  3. Wait for completion of TMA load with mbarrier
+#
+# ===----------------------------------------------------------------------===//
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir import runtime as rt
+from mlir.extras import types as T
+import numpy as np
+
+
+ at NVDSL.mlir_func
+def saxpy_tma(x, y, alpha):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+    t6 = gpu.wait(token_ty, [t5])
+
+    x_tma = TMA((M, N), x.type)
+    y_tma = TMA((M, N), y.type)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=65536)
+    def saxpy_tma_kernel():
+        bidx = gpu.block_id(gpu.Dimension.x)
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        isThread0 = tidx == 0
+
+        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
+        mbar_group = Mbarriers(number_of_barriers=1)
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            mbar_group[0].init(1)
+            x_tma.prefetch()
+            y_tma.prefetch()
+            scf.yield_([])
+
+        x_smem = get_dynamic_shared_memory((M, N), T.f32())
+        y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 2)
+
+        # 2. Execute Tensor Memory Accelerator (TMA) Load
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            x_tma.load(x_smem, mbar_group[0])
+            y_tma.load(y_smem, mbar_group[0])
+            mbar_group[0].arrive(txcount=M * N * 2 * 4)
+            scf.yield_([])
+
+        # 3. Wait for completion of TMA load with mbarrier
+        mbar_group[0].try_wait()
+
+        x_val = memref.load(x_smem, [bidx, tidx])
+        y_val = memref.load(y_smem, [bidx, tidx])
+
+        # SAXPY: y[i] += a * x[i];
+        y_val += x_val * alpha
+
+        memref.store(y_val, y_dev, [bidx, tidx])
+
+    saxpy_tma_kernel()
+
+    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+    gpu.wait(token_ty, [t7])
+
+
+M = 256
+N = 32
+alpha = 2.0
+x = np.ones((M, N), np.float32)
+y = np.ones((M, N), np.float32)
+ref = np.ones((M, N), np.float32)
+saxpy_tma(x, y, alpha)
+
+#  4. Verify MLIR with reference computation
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
new file mode 100644
index 00000000000000..802cb59ead1555
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -0,0 +1,92 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 3 : GEMM 64x64x64 with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation with 64x64x64 matrix multiplication
+#
+# This chapter introduces demonstrates:
+# 1. Execute TMA Load for two input matrices
+# 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
+# 3. Stores fragmented registers to global memory by warpgroup
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir.extras import types as T
+import numpy as np
+
+
+ at NVDSL.mlir_func
+def gemm_64_64_64(x, y, z):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    x_tma = TMA([N, N], x.type, swizzle=sw)
+    y_tma = TMA([N, N], y.type, swizzle=sw)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
+    def gemm_tma_kernel():
+        tidx = gpu.thread_id(gpu.Dimension.x)
+
+        mbar_group = Mbarriers(number_of_barriers=1)
+        isThread0 = tidx == 0
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            mbar_group[0].init(1)
+            x_tma.prefetch()
+            y_tma.prefetch()
+            scf.yield_([])
+
+        x_smem = get_dynamic_shared_memory((N, N), T.f16())
+        y_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+
+        # 1. Execute TMA Load for two input matrices
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            x_tma.load(x_smem, mbar_group[0])
+            y_tma.load(y_smem, mbar_group[0])
+            tx_count = get_type_size(x_tma.tma_memref) + get_type_size(y_tma.tma_memref)
+            mbar_group[0].arrive(tx_count)
+            scf.yield_([])
+
+        mbar_group[0].try_wait()
+
+        # 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
+        A = Matrix(x_smem, x_tma, N, N)
+        B = Matrix(y_smem, y_tma, N, N)
+        C = MatrixAccumulator(N, N, T.f32()).op()
+        D = Matrix.matmul(A, B, C)
+
+        # 3. Stores fragmented registers to global memory by warpgroup
+        nvgpu.warpgroup_mma_store(D, z_dev)
+
+    gemm_tma_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    gpu.wait(None, [t8])
+
+
+# Python pass arguments to MLIR
+N = 64
+x = np.random.randn(N, N).astype(np.float16)
+y = np.random.randn(N, N).astype(np.float16)
+z = np.zeros((N, N), np.float32)
+gemm_64_64_64(x, y, z)
+
+ref = x.astype(np.float16) @ y.astype(np.float16)
+np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
new file mode 100644
index 00000000000000..d222d6e60aafad
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -0,0 +1,291 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a multistage GEMM (M=512, N=256, K=1024) with 128x128x64 tiles
+#
+# This chapter demonstrates how to:
+#  1. Partition shape based on block IDs
+#  2. Prologue
+#    2.1 Execute TMA Load for two input matrices for each stage
+#  3. Main loop
+#    3.1 Wait for completion of TMA load with mbarrier
+#    3.2 Perform Tensor Core GEMM 128x128x64 by warpgroup
+#    3.3 Load next stage if needed
+#  4. Epilogue
+#    4.1 Store fragmented registers to shared memory
+#    4.2 Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+    """
+    Calculate the partition shape based on the block IDs.
+
+    It partitions the shape like below:
+    for(.. i < M ...)   --> blockIdx.x
+     for(.. j < N ...)  --> blockIdx.y
+      for(.. k < K ...)
+
+    Returns:
+        dimX (int): Dimension along the x-axis.
+        dimY (int): Dimension along the y-axis.
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    x_tma: TMA,
+    y_tma: TMA,
+    slot,
+    stage,
+    p=None,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load x_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
+       - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+
+       mbarrier.arrive tx_count = 128x64x2x4
+    """
+    dimX, dimY = partition_shape()
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+    size_tma_x = get_type_size(x_tma.tma_memref)
+    size_tma_y = get_type_size(y_tma.tma_memref)
+    tx_count = size_tma_x + (size_tma_y * 2)
+    tidx = gpu.thread_id(gpu.Dimension.x)
+
+    p = tidx == 0 if p is None else p
+
+    off_x = slot * size_tma_x
+    off_y = (slot * size_tma_x) + begin_y
+    off_y2 = off_y + size_tma_y
+    x = get_dynamic_shared_memory(
+        x_tma.tma_memref.shape, x_tma.tma_memref.element_type, off_x
+    )
+    y1 = get_dynamic_shared_memory(
+        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y
+    )
+    y2 = get_dynamic_shared_memory(
+        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y2
+    )
+
+    mbar_group[slot].arrive(tx_count, predicate=p)
+
+    c1 = stage * 64
+    x_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    y_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    y_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def bootstrap(x_tma: TMA, y_tma: TMA):
+    """Thread 0 inits NUM_STAGES mbarriers (arrival count 1) and prefetches the TMA descriptors.
+    Returns the Mbarriers group; the init/prefetch ops are emitted inside `if tidx == 0`.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    mbar_group = Mbarriers(number_of_barriers=NUM_STAGES)
+    isThread0 = tidx == const(0)
+    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+        for i in scf.for_(0, NUM_STAGES, 1):
+            mbar_group[i].init(1)
+            scf.yield_([])
+        x_tma.prefetch()
+        y_tma.prefetch()
+        scf.yield_([])
+
+    return mbar_group
+
+
+def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+    """
+    Prologue of the GEMM kernel. It loads 2 input matrices for each stage in loop like below:
+
+    for stage in range(NUM_STAGES):
+        tma_load x, y, stage
+
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+    for iv in scf.for_(0, ns, 1):
+        tma_load(mbar_group, x_tma, y_tma, iv, iv)
+        scf.yield_([])
+
+
+def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+    """
+    Main loop of the Multistage GEMM kernel. It iterates through
+    stages and performs matrix multiplication, loading data by TMA to shared memory. It like following
+
+    MatrixAccumulator D
+    for k in range(K // TILE_K):
+
+        try_wait(stage, ...)    # Wait TMA load
+
+        Matrix A(stage, ...)    # Find shared memory slot
+        Matrix B(stage, ...)    # Find shared memory slot
+        D += A @ B              # Multiply and accumulate
+
+        if(needLoad)            # Load next stage if needed
+            tma_load(x, y, nextSlot, nextStage)
+
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+
+    size_x = TILE_M * TILE_K * get_type_size(T.f16())
+
+    C = MatrixAccumulator(TILE_M, TILE_N, T.f32()).op()
+    pp = const(False, ty=T.bool())
+
+    # Main Loop
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C, pp])
+    with ir.InsertionPoint(for_op.body):
+        pp = for_op.inner_iter_args[1]
+        iv = for_op.induction_variable
+        stage = iv % NUM_STAGES
+
+        # Wait for current stage
+        mbar_group[stage].try_wait(phase=pp)
+
+        # Find shared memory slot
+        offset_x = stage * size_x
+        offset_y = offset_x + begin_y
+        x_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_x)
+        y_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_y)
+
+        # Matrix Multiply
+        A = Matrix(x_smem, x_tma, TILE_M, TILE_K)
+        B = Matrix(y_smem, y_tma, TILE_K, TILE_N)
+        C = for_op.inner_iter_args[0]
+        D = Matrix.matmul(A, B, C)
+        if NUM_STAGES == 1:
+            nvvm.WgmmaWaitGroupSyncOp(0)
+
+        # Load next stage
+        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
+        nextStage = iv + ns
+        nextSlot = nextStage % NUM_STAGES
+        tma_load(mbar_group, x_tma, y_tma, nextSlot, nextStage, pred)
+
+        # Switch phase parity for the mbarrier
+        switched = pp ^ const(True, ty=T.bool())
+        newPP = arith.select(
+            stage == (NUM_STAGES - 1),
+            switched,
+            pp,
+        )
+        scf.yield_([D, newPP])
+
+    nvvm.WgmmaWaitGroupSyncOp(0)
+
+    return for_op.results[0]
+
+
+def epilogue(D, z_dev):
+    """
+    Epilogue of the GEMM kernel. It stores the fragmented accumulator D to global memory.
+
+    store D -> z_smem (TILE_M x TILE_N f32 at dynamic smem offset 0); gpu.barrier()
+    z_smem -> z_dev[dimX:dimX+TILE_M, dimY:dimY+TILE_N], thread tidx copying column tidx
+    NOTE(review): assumes blockDim.x == TILE_N, and z_smem aliases the pipeline buffers
+    at offset 0 -- confirm every warp's wgmma reads have retired before the store.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    dimX, dimY = partition_shape()
+
+    z_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    z_gmem = memref.subview(z_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+    # Store (registers -> shared memory)
+    nvgpu.WarpgroupMmaStoreOp(D, z_smem)
+    gpu.barrier()
+
+    # Store (shared memory --> global memory)
+    for i in scf.for_(0, TILE_M, 1):
+        val = memref.load(z_smem, [i, tidx])
+        memref.store(val, z_gmem, [i, tidx])
+        scf.yield_([])
+
+
+ at NVDSL.mlir_func
+def gemm_multistage(x, y, z):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    x_tma = TMA([128, 64], x.type, swizzle=sw)
+    y_tma = TMA([64, 64], y.type, swizzle=sw)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    grid = [(M // TILE_M), (N // TILE_N), 1]
+    block = [128, 1, 1]
+    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
+    def gemm_multistage_kernel():
+        # Initialize mbarriers and prefetch TMA descriptors
+        mbar_group = bootstrap(x_tma, y_tma)
+
+        # Fill the pipeline stages
+        prologue(mbar_group, x_tma, y_tma)
+
+        # Main loop
+        D = mainloop(mbar_group, x_tma, y_tma)
+
+        # Store registers to global memory
+        epilogue(D, z_dev)
+
+    gemm_multistage_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    gpu.wait(None, [t8])
+
+
+# Python pass arguments to MLIR
+NUM_STAGES = 7
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+x = np.random.randn(M, K).astype(np.float16)
+y = np.random.randn(K, N).astype(np.float16)
+z = np.zeros((M, N), np.float32)
+
+gemm_multistage(x, y, z)
+
+
+# Verify MLIR with reference computation
+ref = x.astype(np.float16) @ y.astype(np.float16)
+np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/lit.local.cfg b/mlir/test/Examples/nvgpu/lit.local.cfg
new file mode 100644
index 00000000000000..e586b55573898a
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/lit.local.cfg
@@ -0,0 +1,3 @@
+config.unsupported = False
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+  config.unsupported = True
\ No newline at end of file
diff --git a/mlir/test/Examples/nvgpu/tools/lit.local.cfg b/mlir/test/Examples/nvgpu/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
new file mode 100644
index 00000000000000..4c774f4d4deac9
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -0,0 +1,404 @@
+from enum import Enum
+import functools, sys, ctypes, os, errno
+import numpy as np
+from functools import partialmethod
+from mlir import ir
+from mlir.dialects import arith, func, gpu, memref, nvgpu
+from mlir.extras import types as T
+from mlir import runtime as rt
+from tools import nvgpucompiler
+
+DEBUG = True
+MLIR_DYNAMIC = -9223372036854775808
+
+
+def const(value: int, ty=None):
+    ty = T.index() if ty is None else ty
+    if isinstance(value, ir.Value) and (value.type.isinstance(value.type) or T.bool().isinstance(value.type)):
+        return value
+    return arith.constant(ty, value)
+
+
+def get_type_size(ty):
+    if ir.MemRefType.isinstance(ty):
+        size = get_type_size(ty.element_type)
+        for sz in ty.shape:
+            size *= sz
+        return size
+    if ir.FloatType.isinstance(ty):
+        return ir.FloatType(ty).width // 8
+    if ir.IntegerType.isinstance(ty):
+        return ir.IntegerType(ty).width // 8
+    raise NotImplementedError(ty)
+
+def get_mlir_func_obj_ty(inputArgs):
+    args = []
+    c_int_p = ctypes.c_int * 1
+    c_float_p = ctypes.c_float * 1
+    c_bool_p = ctypes.c_bool * 1
+    for arg in inputArgs:
+        if isinstance(arg, bool):
+            args.append(c_bool_p(arg))
+        elif isinstance(arg, int):
+            args.append(c_int_p(arg))
+        elif isinstance(arg, float):
+            args.append(c_float_p(arg))
+        elif isinstance(arg, np.ndarray):
+            args.append(
+                ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
+            )
+        else:
+            raise NotImplementedError(arg)
+    return args
+
+
+class Mbarriers:
+    def __init__(self, number_of_barriers=1):
+        self.mbar_ty = ir.Type.parse(
+            "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
+            + str(number_of_barriers)
+            + ">"
+        )
+        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
+        self.number_of_barriers = number_of_barriers
+
+    def __getitem__(self, key):
+        self.id_op = const(key)
+        return self
+
+    def init(self, count: int, predicate=None):
+        count_op = const(count)
+        if predicate is None:
+            nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
+        else:
+            nvgpu.mbarrier_init(
+                self.mbar_group_op, count_op, self.id_op, predicate=predicate
+            )
+
+    def arrive(self, txcount: int = 0, predicate=None):
+        if txcount != 0:
+            txcount_op = const(txcount)
+            nvgpu.mbarrier_arrive_expect_tx(
+                self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
+            )
+        else:
+            nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op, predicate=predicate)
+
+    def try_wait(self, phase: bool = False, ticks: int = 10000000):
+        ticks_op = const(ticks)
+        phase_op = const(phase, T.bool())
+        nvgpu.MBarrierTryWaitParityOp(
+            self.mbar_group_op,
+            phase_op,
+            ticks_op,
+            mbarId=self.id_op,
+        )
+
+
+class TMA:
+    """A class that builds a TMA descriptor."""
+
+    def __init__(
+        self,
+        shape,
+        memref_ty,
+        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
+        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
+        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+    ):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.shape = shape
+        self.memref_ty = memref_ty  # MemRefType
+        self.lastDim = 64
+        self.requiredLoad = 1
+        self.tma_shape = shape
+        self.tma_memref = ir.MemRefType.get(shape, memref_ty.element_type)
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        memref_str = f"memref<{self.tma_shape[0]}x{self.tma_shape[1]}x{self.memref_ty.element_type}, 3>"
+        parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
+                                              swizzle = {self.swizzle},\
+                                              l2promo = {self.l2promo},\
+                                              oob = {self.oob},\
+                                              interleave = {self.interleave}>"
+
+        return ir.Type.parse(parse_str)
+
+    def create_descriptor(self, device_ptr):
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(self.memref_ty.element_type,
+                                      self.memref_ty.memory_space),
+            device_ptr,
+        )
+        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref,
+            map(const, self.tma_shape))
+        return self.tma_descriptor.result
+
+    def prefetch(self, predicate=None):
+        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
+
+    def load(self, dest, mbarrier: Mbarriers, coords=[0,0], predicate=None):
+        coord_ops = [const(c) for c in coords]
+        nvgpu.TmaAsyncLoadOp(
+            dest,
+            mbarrier.mbar_group_op,
+            self.tma_descriptor,
+            coordinates=coord_ops,
+            mbarId=mbarrier.id_op,
+            predicate=predicate,
+        )
+
+
+class MatrixAccumulator:
+    def __init__(self, M, N, ty):
+        self.M = M
+        self.N = N
+        self.ty = ty
+
+    @property
+    def acc_ty(self):
+        return ir.Type.parse(
+            "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+            + str(self.M)
+            + "x"
+            + str(self.N)
+            + "x"
+            + str(self.ty)
+            + ">>"
+        )
+
+    def op(self):
+        return nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+
+
+class Matrix:
+
+    def __init__(self, smem, tma_descriptor: TMA, M, N):
+        self.tma_descriptor = tma_descriptor
+        self.smem = smem
+        self.M = M
+        self.N = N
+
+    @property
+    def wgmma_ty(self):
+        return ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+                             str(self.M) + "x" +
+                             str(self.N) + "x" +
+                             str(self.tma_descriptor.memref_ty.element_type) +
+                             ", #gpu.address_space<workgroup>>>")
+
+    def matmul(lhs, rhs, acc):
+        wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(
+            lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor)
+        wgmma_desc_rhs = nvgpu.warpgroup_generate_descriptor(
+            rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor)
+        return nvgpu.WarpgroupMmaOp(acc.type,
+                                    wgmma_desc_lhs,
+                                    wgmma_desc_rhs,
+                                    acc,
+                                    transposeB=True)
+
+def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
+    smem_space_str = "#gpu.address_space<workgroup>"
+    smem_space = ir.Attribute.parse(smem_space_str)
+    dynamic_smem = gpu.dynamic_shared_memory(
+        ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+    )
+    if shape is None:
+        return dynamic_smem
+    memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
+    return memref.view(
+        ir.MemRefType.get(
+            memref_ty.shape, memref_ty.element_type, memory_space=smem_space
+        ),
+        dynamic_smem,
+        const(offset),
+        [],
+    )
+
+ at staticmethod
+def get_mlir_ty(arg):
+    def get_mlir_ty_from_np(dtype):
+        if dtype == np.float16:
+            return T.f16()
+        if dtype == np.float32:
+            return T.f32()
+        if dtype == np.float64:
+            return T.f64()
+        if dtype == np.int32:
+            return T.i32()
+        if dtype == np.int64:
+            return T.i64()
+        raise NotImplementedError(dtype)
+    if isinstance(arg, bool):
+        return T.bool()
+    elif isinstance(arg, int):
+        return T.index()
+    elif isinstance(arg, float):
+        return T.f32()
+    elif isinstance(arg, np.ndarray):
+        descriptor = rt.get_ranked_memref_descriptor(arg)
+        dtype = get_mlir_ty_from_np(arg.dtype)
+        shape = descriptor.shape
+        return memref.MemRefType.get(shape, dtype)
+    raise NotImplementedError(arg)
+
+class NVDSL:
+    @staticmethod
+    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
+        def decorator(func):
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                launch_op = gpu.LaunchOp(
+                    None,
+                    [],
+                    *map(const, grid),
+                    *map(const, block),
+                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+                )
+                launch_op.body.blocks.append(*([T.index()] * 12))
+                with ir.InsertionPoint(launch_op.body.blocks[0]):
+                    result = func(*args, **kwargs)
+                    gpu.terminator()
+                    return result
+
+            return wrapper
+
+        return decorator
+
+    @staticmethod
+    def mlir_func(funcBody):
+        @functools.wraps(funcBody)
+        def wrapper(*args, **kwargs):
+            function_name = funcBody.__name__
+
+            def saveIR(module):
+                """Save generated IR"""
+                if True:  # self.saveIR:
+                    # print(mlir_nvgpu_module)
+                    original_stdout = sys.stdout
+                    with open("nvdsl.mlir", "w") as f:
+                        sys.stdout = f
+                        print(module)
+                        sys.stdout = original_stdout
+
+            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
+                """Generate MLIR's Arith dialects binary operations."""
+                rhs = const(rhs)
+                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+                    op += "F"
+                    if op.startswith("Cmp"):
+                        predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
+                            predAtt
+                        ]
+                elif arith._is_integer_like_type(
+                    lhs.type
+                ) and arith._is_integer_like_type(lhs.type):
+                    if op == "Div" or op == "Rem":
+                        op += "U"
+                    op += "I"
+                    if op.startswith("Cmp"):
+                        predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
+                            predAtt
+                        ]
+                else:
+                    raise NotImplementedError(
+                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
+                    )
+
+                if op.startswith("Cmp"):
+                    op = getattr(arith, f"{op}Op")
+
+                    return op(predicateAttr, lhs, rhs).result
+                else:
+                    op = getattr(arith, f"{op}Op")
+                    return op(lhs, rhs).result
+
+            @ir.register_value_caster(ir.IndexType.static_typeid)
+            @ir.register_value_caster(ir.F32Type.static_typeid)
+            @ir.register_value_caster(ir.F16Type.static_typeid)
+            @ir.register_value_caster(ir.F64Type.static_typeid)
+            @ir.register_value_caster(ir.IntegerType.static_typeid)
+            class ArithValue(ir.Value):
+                """Overloads operators for MLIR's Arith dialects binary operations."""
+
+                def __init__(self, v):
+                    super().__init__(v)
+
+                __add__ = partialmethod(_binary_op, op="Add")
+                __sub__ = partialmethod(_binary_op, op="Sub")
+                __mul__ = partialmethod(_binary_op, op="Mul")
+                __truediv__ = partialmethod(_binary_op, op="Div")
+                __mod__ = partialmethod(_binary_op, op="Rem")
+                __xor__ = partialmethod(_binary_op, op="XOr")
+                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
+                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
+                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
+                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
+                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
+                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
+                __and__ = partialmethod(_binary_op, op="And")
+                __or__ = partialmethod(_binary_op, op="Or")
+
+                def __str__(self):
+                    return (
+                        super()
+                        .__str__()
+                        .replace(ir.Value.__name__, ArithValue.__name__)
+                    )
+
+            # Generate MLIR Context and start generating IR
+            with ir.Context(), ir.Location.unknown():
+                types = []
+                for arg in args:
+                    types.append(get_mlir_ty(arg))
+
+                # Build IR
+                module = ir.Module.create()
+                with ir.InsertionPoint(module.body):
+                    fop = func.FuncOp(function_name, (types, []))
+                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+                    with ir.InsertionPoint(fop.add_entry_block()):
+                        fargs = []
+                        for i, a in enumerate(types):
+                            fargs.append(fop.arguments[i])
+
+                        # Call user function body
+                        result = funcBody(*fargs, **kwargs)
+                        func.ReturnOp([])
+
+                # Verify the module
+                module.operation.verify()
+
+                # Save IR in a file
+                # saveIR(module)
+
+                # Compile and JIT MLIR module
+                options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+                support_lib = os.getenv("SUPPORT_LIB")
+                if not os.path.exists(support_lib):
+                    raise FileNotFoundError(
+                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+                    )
+                compiler = nvgpucompiler.NvgpuCompiler(
+                    options, opt_level=3, shared_libs=[support_lib]
+                )
+                engine = compiler.compile_and_jit(module)
+
+            # Convert input arguments to MLIR arguments
+            newArgs = get_mlir_func_obj_ty(args)
+
+            # Run the compiled program
+            engine.invoke(function_name, *newArgs)
+
+            return result
+
+        return wrapper
diff --git a/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py b/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..1c9cc74fcd169c
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py
@@ -0,0 +1,45 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the Nvgpu class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+# Directory containing this file.
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+# Import-time side effect: append this directory to sys.path, presumably so
+# sibling modules in it can be imported by name -- confirm with the callers.
+sys.path.append(_SCRIPT_PATH)
+
+
+class NvgpuCompiler:
+    """Nvgpu class for compiling and building MLIR modules."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+        self.pipeline = pipeline
+        self.shared_libs = shared_libs
+        self.opt_level = opt_level
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the nvgpu pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)

>From c07ade60c2dd343165827587bebe48ae03f6785b Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 29 Mar 2024 13:51:21 +0000
Subject: [PATCH 2/3] format

---
 mlir/test/Examples/nvgpu/Ch2.py         |  2 +-
 mlir/test/Examples/nvgpu/Ch4.py         |  1 +
 mlir/test/Examples/nvgpu/tools/nvdsl.py | 49 +++++++++++++++----------
 3 files changed, 32 insertions(+), 20 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index caa2691a06d69e..e9369227af0e85 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -5,7 +5,7 @@
 #  Chapter 2 : 2D Saxpy with TMA
 # ===----------------------------------------------------------------------===//
 #
-# This program demonstrates 2D Saxpy. It is same as Chapter 1, 
+# This program demonstrates 2D Saxpy. It is same as Chapter 1,
 # but it loads data using TMA (Tensor Memory Accelerator)
 #
 # This chapter introduces demonstrates:
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index d222d6e60aafad..5f790c4c190ca9 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -247,6 +247,7 @@ def gemm_multistage(x, y, z):
 
     grid = [(M // TILE_M), (N // TILE_N), 1]
     block = [128, 1, 1]
+
     @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
     def gemm_multistage_kernel():
         # Initialize mbarriers and prefetch TMA descriptors
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 4c774f4d4deac9..e27857d8ed1d79 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -14,7 +14,9 @@
 
 def const(value: int, ty=None):
     ty = T.index() if ty is None else ty
-    if isinstance(value, ir.Value) and (value.type.isinstance(value.type) or T.bool().isinstance(value.type)):
+    if isinstance(value, ir.Value) and (
+        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
+    ):
         return value
     return arith.constant(ty, value)
 
@@ -31,6 +33,7 @@ def get_type_size(ty):
         return ir.IntegerType(ty).width // 8
     raise NotImplementedError(ty)
 
+
 def get_mlir_func_obj_ty(inputArgs):
     args = []
     c_int_p = ctypes.c_int * 1
@@ -133,19 +136,20 @@ def tensormap_descriptor_ty(self):
     def create_descriptor(self, device_ptr):
         tma_descriptor_ty = self.tensormap_descriptor_ty
         device_unranked_memref = memref.CastOp(
-            ir.UnrankedMemRefType.get(self.memref_ty.element_type,
-                                      self.memref_ty.memory_space),
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
             device_ptr,
         )
         self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
-            tma_descriptor_ty, device_unranked_memref,
-            map(const, self.tma_shape))
+            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_shape)
+        )
         return self.tma_descriptor.result
 
     def prefetch(self, predicate=None):
         nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
 
-    def load(self, dest, mbarrier: Mbarriers, coords=[0,0], predicate=None):
+    def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
         coord_ops = [const(c) for c in coords]
         nvgpu.TmaAsyncLoadOp(
             dest,
@@ -180,7 +184,6 @@ def op(self):
 
 
 class Matrix:
-
     def __init__(self, smem, tma_descriptor: TMA, M, N):
         self.tma_descriptor = tma_descriptor
         self.smem = smem
@@ -189,22 +192,27 @@ def __init__(self, smem, tma_descriptor: TMA, M, N):
 
     @property
     def wgmma_ty(self):
-        return ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
-                             str(self.M) + "x" +
-                             str(self.N) + "x" +
-                             str(self.tma_descriptor.memref_ty.element_type) +
-                             ", #gpu.address_space<workgroup>>>")
+        return ir.Type.parse(
+            "!nvgpu.warpgroup.descriptor<tensor=memref<"
+            + str(self.M)
+            + "x"
+            + str(self.N)
+            + "x"
+            + str(self.tma_descriptor.memref_ty.element_type)
+            + ", #gpu.address_space<workgroup>>>"
+        )
 
     def matmul(lhs, rhs, acc):
         wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(
-            lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor)
+            lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor
+        )
         wgmma_desc_rhs = nvgpu.warpgroup_generate_descriptor(
-            rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor)
-        return nvgpu.WarpgroupMmaOp(acc.type,
-                                    wgmma_desc_lhs,
-                                    wgmma_desc_rhs,
-                                    acc,
-                                    transposeB=True)
+            rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor
+        )
+        return nvgpu.WarpgroupMmaOp(
+            acc.type, wgmma_desc_lhs, wgmma_desc_rhs, acc, transposeB=True
+        )
+
 
 def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
     smem_space_str = "#gpu.address_space<workgroup>"
@@ -224,6 +232,7 @@ def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
         [],
     )
 
+
 @staticmethod
 def get_mlir_ty(arg):
     def get_mlir_ty_from_np(dtype):
@@ -238,6 +247,7 @@ def get_mlir_ty_from_np(dtype):
         if dtype == np.int64:
             return T.i64()
         raise NotImplementedError(dtype)
+
     if isinstance(arg, bool):
         return T.bool()
     elif isinstance(arg, int):
@@ -251,6 +261,7 @@ def get_mlir_ty_from_np(dtype):
         return memref.MemRefType.get(shape, dtype)
     raise NotImplementedError(arg)
 
+
 class NVDSL:
     @staticmethod
     def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):

>From dd0b01e185fea20207455892cd4326fa0a5956e3 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 1 Apr 2024 14:05:12 +0000
Subject: [PATCH 3/3] address comments

---
 mlir/test/Examples/nvgpu/Ch3.py         |  63 +++++------
 mlir/test/Examples/nvgpu/Ch4.py         | 134 ++++++++++++------------
 mlir/test/Examples/nvgpu/tools/nvdsl.py |  26 ++---
 3 files changed, 107 insertions(+), 116 deletions(-)

diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 802cb59ead1555..b1623f9d645c2c 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -23,23 +23,24 @@
 
 
 @NVDSL.mlir_func
-def gemm_64_64_64(x, y, z):
+def gemm_64_64_64(x, y, d):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
-    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
-    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
-    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
-    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
-    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
     t7 = gpu.wait(token_ty, [t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
-    x_tma = TMA([N, N], x.type, swizzle=sw)
-    y_tma = TMA([N, N], y.type, swizzle=sw)
-    x_tma.create_descriptor(x_dev)
-    y_tma.create_descriptor(y_dev)
+    a_tma = TMA([N, N], x.type, swizzle=sw)
+    b_tma = TMA([N, N], y.type, swizzle=sw)
+    a_tma.create_descriptor(a_dev)
+    b_tma.create_descriptor(b_dev)
+    smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
 
-    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
     def gemm_tma_kernel():
         tidx = gpu.thread_id(gpu.Dimension.x)
 
@@ -47,46 +48,46 @@ def gemm_tma_kernel():
         isThread0 = tidx == 0
         with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
             mbar_group[0].init(1)
-            x_tma.prefetch()
-            y_tma.prefetch()
+            a_tma.prefetch()
+            b_tma.prefetch()
             scf.yield_([])
 
-        x_smem = get_dynamic_shared_memory((N, N), T.f16())
-        y_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+        a_smem = get_dynamic_shared_memory((N, N), T.f16())
+        b_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
 
         # 1. Execute TMA Load for two input matrices
         with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
-            x_tma.load(x_smem, mbar_group[0])
-            y_tma.load(y_smem, mbar_group[0])
-            tx_count = get_type_size(x_tma.tma_memref) + get_type_size(y_tma.tma_memref)
-            mbar_group[0].arrive(tx_count)
+            a_tma.load(a_smem, mbar_group[0])
+            b_tma.load(b_smem, mbar_group[0])
+            ta_count = get_type_size(a_tma.tma_memref) + get_type_size(b_tma.tma_memref)
+            mbar_group[0].arrive(ta_count)
             scf.yield_([])
 
         mbar_group[0].try_wait()
 
         # 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
-        A = Matrix(x_smem, x_tma, N, N)
-        B = Matrix(y_smem, y_tma, N, N)
-        C = MatrixAccumulator(N, N, T.f32()).op()
-        D = Matrix.matmul(A, B, C)
+        A = WarpgroupMatrix(a_smem, a_tma, N, N)
+        B = WarpgroupMatrix(b_smem, b_tma, N, N)
+        C = WarpgroupAccumulatorMatrix(N, N, T.f32()).op()
+        D = WarpgroupMatrix.matmul(A, B, C)
 
         # 3. Stores fragmented registers to global memory by warpgroup
-        nvgpu.warpgroup_mma_store(D, z_dev)
+        nvgpu.warpgroup_mma_store(D, d_dev)
 
     gemm_tma_kernel()
 
-    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
     gpu.wait(None, [t8])
 
 
 # Python pass arguments to MLIR
 N = 64
-x = np.random.randn(N, N).astype(np.float16)
-y = np.random.randn(N, N).astype(np.float16)
-z = np.zeros((N, N), np.float32)
-gemm_64_64_64(x, y, z)
+a = np.random.randn(N, N).astype(np.float16)
+b = np.random.randn(N, N).astype(np.float16)
+d = np.zeros((N, N), np.float32)
+gemm_64_64_64(a, b, d)
 
-ref = x.astype(np.float16) @ y.astype(np.float16)
-np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
 print("PASS")
 # CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 5f790c4c190ca9..2f38a49501e792 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -51,8 +51,8 @@ def partition_shape():
 
 def tma_load(
     mbar_group: Mbarriers,
-    x_tma: TMA,
-    y_tma: TMA,
+    a_tma: TMA,
+    b_tma: TMA,
     slot,
     stage,
     p=None,
@@ -60,45 +60,45 @@ def tma_load(
     """
     TMA loads two input matrices from global memory to shared memory. It performs the following operations:
 
-       - tma.load x_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
-       - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
-       - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
+       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
 
-       mbarrier.arrive tx_count = 128x64x2x4
+       mbarrier.arrive ta_count = 128x64x2x4
     """
     dimX, dimY = partition_shape()
 
     tidx = gpu.thread_id(gpu.Dimension.x)
-    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
-    size_tma_x = get_type_size(x_tma.tma_memref)
-    size_tma_y = get_type_size(y_tma.tma_memref)
-    tx_count = size_tma_x + (size_tma_y * 2)
+    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    ta_count = size_tma_a + (size_tma_b * 2)
     tidx = gpu.thread_id(gpu.Dimension.x)
 
     p = tidx == 0 if p is None else p
 
-    off_x = slot * size_tma_x
-    off_y = (slot * size_tma_x) + begin_y
-    off_y2 = off_y + size_tma_y
+    off_a = slot * size_tma_a
+    off_b = (slot * size_tma_a) + begin_b
+    off_b2 = off_b + size_tma_b
     x = get_dynamic_shared_memory(
-        x_tma.tma_memref.shape, x_tma.tma_memref.element_type, off_x
+        a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
     )
     y1 = get_dynamic_shared_memory(
-        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y
+        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
     )
     y2 = get_dynamic_shared_memory(
-        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y2
+        b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
     )
 
-    mbar_group[slot].arrive(tx_count, predicate=p)
+    mbar_group[slot].arrive(ta_count, predicate=p)
 
     c1 = stage * 64
-    x_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
-    y_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
-    y_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+    a_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    b_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    b_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
 
 
-def bootstrap(x_tma: TMA, y_tma: TMA):
+def bootstrap(a_tma: TMA, b_tma: TMA):
     """
     Initialize mbarriers and prefetch TMA descriptors.
     """
@@ -109,14 +109,14 @@ def bootstrap(x_tma: TMA, y_tma: TMA):
         for i in scf.for_(0, NUM_STAGES, 1):
             mbar_group[i].init(1)
             scf.yield_([])
-        x_tma.prefetch()
-        y_tma.prefetch()
+        a_tma.prefetch()
+        b_tma.prefetch()
         scf.yield_([])
 
     return mbar_group
 
 
-def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
     """
     Prologue of the GEMM kernel. It loads 2 input matrices for each stage in loop like below:
 
@@ -126,11 +126,11 @@ def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
     """
     ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
     for iv in scf.for_(0, ns, 1):
-        tma_load(mbar_group, x_tma, y_tma, iv, iv)
+        tma_load(mbar_group, a_tma, b_tma, iv, iv)
         scf.yield_([])
 
 
-def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
     """
     Main loop of the Multistage GEMM kernel. It iterates through
     stages and performs matrix multiplication, loading data by TMA to shared memory. It like following
@@ -151,11 +151,11 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
     ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
 
     tidx = gpu.thread_id(gpu.Dimension.x)
-    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
 
-    size_x = TILE_M * TILE_K * get_type_size(T.f16())
+    size_a = TILE_M * TILE_K * get_type_size(T.f16())
 
-    C = MatrixAccumulator(TILE_M, TILE_N, T.f32()).op()
+    C = WarpgroupAccumulatorMatrix(TILE_M, TILE_N, T.f32()).op()
     pp = const(False, ty=T.bool())
 
     # Main Loop
@@ -169,16 +169,16 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
         mbar_group[stage].try_wait(phase=pp)
 
         # Find shared memory slot
-        offset_x = stage * size_x
-        offset_y = offset_x + begin_y
-        x_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_x)
-        y_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_y)
+        offset_a = stage * size_a
+        offset_b = offset_a + begin_b
+        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
 
         # Matrix Multiply
-        A = Matrix(x_smem, x_tma, TILE_M, TILE_K)
-        B = Matrix(y_smem, y_tma, TILE_K, TILE_N)
+        A = WarpgroupMatrix(a_smem, a_tma, TILE_M, TILE_K)
+        B = WarpgroupMatrix(b_smem, b_tma, TILE_K, TILE_N)
         C = for_op.inner_iter_args[0]
-        D = Matrix.matmul(A, B, C)
+        D = WarpgroupMatrix.matmul(A, B, C)
         if NUM_STAGES == 1:
             nvvm.WgmmaWaitGroupSyncOp(0)
 
@@ -186,7 +186,7 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
         pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
         nextStage = iv + ns
         nextSlot = nextStage % NUM_STAGES
-        tma_load(mbar_group, x_tma, y_tma, nextSlot, nextStage, pred)
+        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, pred)
 
         # Switch phase parity for the mbarrier
         switched = pp ^ const(True, ty=T.bool())
@@ -202,7 +202,7 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
     return for_op.results[0]
 
 
-def epilogue(D, z_dev):
+def epilogue(D, d_dev):
     """
     Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
 
@@ -214,57 +214,61 @@ def epilogue(D, z_dev):
     tidx = gpu.thread_id(gpu.Dimension.x)
     dimX, dimY = partition_shape()
 
-    z_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
-    z_gmem = memref.subview(z_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
 
     # Store (registers -> shared memory)
-    nvgpu.WarpgroupMmaStoreOp(D, z_smem)
+    nvgpu.WarpgroupMmaStoreOp(D, d_smem)
     gpu.barrier()
 
     # Store (shared memory --> global memory)
     for i in scf.for_(0, TILE_M, 1):
-        val = memref.load(z_smem, [i, tidx])
-        memref.store(val, z_gmem, [i, tidx])
+        val = memref.load(d_smem, [i, tidx])
+        memref.store(val, d_gmem, [i, tidx])
         scf.yield_([])
 
 
 @NVDSL.mlir_func
-def gemm_multistage(x, y, z):
+def gemm_multistage(x, y, d):
     token_ty = ir.Type.parse("!gpu.async.token")
     t1 = gpu.wait(token_ty, [])
-    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
-    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
-    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
-    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
-    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
     t7 = gpu.wait(token_ty, [t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
-    x_tma = TMA([128, 64], x.type, swizzle=sw)
-    y_tma = TMA([64, 64], y.type, swizzle=sw)
-    x_tma.create_descriptor(x_dev)
-    y_tma.create_descriptor(y_dev)
+    a_tma = TMA([128, 64], x.type, swizzle=sw)
+    b_tma = TMA([64, 64], y.type, swizzle=sw)
+    a_tma.create_descriptor(a_dev)
+    b_tma.create_descriptor(b_dev)
 
     grid = [(M // TILE_M), (N // TILE_N), 1]
     block = [128, 1, 1]
 
-    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
+    size_a = get_type_size(x.type.element_type) * TILE_M * TILE_K
+    size_b = get_type_size(x.type.element_type) * TILE_N * TILE_K
+    smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+
+    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
     def gemm_multistage_kernel():
         # Initialize mbarriers and prefetch TMA descriptors
-        mbar_group = bootstrap(x_tma, y_tma)
+        mbar_group = bootstrap(a_tma, b_tma)
 
         # Fill the pipeline stages
-        prologue(mbar_group, x_tma, y_tma)
+        prologue(mbar_group, a_tma, b_tma)
 
         # Main loop
-        D = mainloop(mbar_group, x_tma, y_tma)
+        D = mainloop(mbar_group, a_tma, b_tma)
 
         # Store registers to global memory
-        epilogue(D, z_dev)
+        epilogue(D, d_dev)
 
     gemm_multistage_kernel()
 
-    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
     gpu.wait(None, [t8])
 
 
@@ -276,16 +280,16 @@ def gemm_multistage_kernel():
 TILE_M = 128
 TILE_N = 128
 TILE_K = 64
-x = np.random.randn(M, K).astype(np.float16)
-y = np.random.randn(K, N).astype(np.float16)
-z = np.zeros((M, N), np.float32)
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
 
-gemm_multistage(x, y, z)
+gemm_multistage(a, b, d)
 
 
 # Verify MLIR with reference computation
-ref = x.astype(np.float16) @ y.astype(np.float16)
-np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
 
 
 print("PASS")
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index e27857d8ed1d79..bb312a9e259ef3 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -161,7 +161,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
         )
 
 
-class MatrixAccumulator:
+class WarpgroupAccumulatorMatrix:
     def __init__(self, M, N, ty):
         self.M = M
         self.N = N
@@ -169,21 +169,14 @@ def __init__(self, M, N, ty):
 
     @property
     def acc_ty(self):
-        return ir.Type.parse(
-            "!nvgpu.warpgroup.accumulator<fragmented=vector<"
-            + str(self.M)
-            + "x"
-            + str(self.N)
-            + "x"
-            + str(self.ty)
-            + ">>"
-        )
+        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
+        return ir.Type.parse(parse_str)
 
     def op(self):
         return nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
 
 
-class Matrix:
+class WarpgroupMatrix:
     def __init__(self, smem, tma_descriptor: TMA, M, N):
         self.tma_descriptor = tma_descriptor
         self.smem = smem
@@ -192,15 +185,8 @@ def __init__(self, smem, tma_descriptor: TMA, M, N):
 
     @property
     def wgmma_ty(self):
-        return ir.Type.parse(
-            "!nvgpu.warpgroup.descriptor<tensor=memref<"
-            + str(self.M)
-            + "x"
-            + str(self.N)
-            + "x"
-            + str(self.tma_descriptor.memref_ty.element_type)
-            + ", #gpu.address_space<workgroup>>>"
-        )
+        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.tma_descriptor.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+        return ir.Type.parse(parse_str)
 
     def matmul(lhs, rhs, acc):
         wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(



More information about the Mlir-commits mailing list