[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 29 06:13:26 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

I am giving a tutorial at EuroLLVM 2024 ([Zero to Hero: Programming Nvidia Hopper Tensor Core with MLIR's NVGPU Dialect](https://llvm.swoogo.com/2024eurollvm/session/2086997/zero-to-hero-programming-nvidia-hopper-tensor-core-with-mlir's-nvgpu-dialect)). For it, I implemented the tutorial code in Python. The focus is the nvgpu dialect and how to use its advanced features. I thought it might be useful to upstream this.

The tutorial chapters are as follows:
- **Ch0.py:** Hello World
- **Ch1.py:** 2D Saxpy
- **Ch2.py:** 2D Saxpy using TMA engine
- **Ch3.py:** GEMM 64x64x64 using Tensor Core and TMA
- **Ch4.py:** Multistage performant GEMM using Tensor Core and TMA
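
As a host-side illustration (not part of this PR), Chapter 1's kernel maps one block to each row and one thread to each column: with `grid=(M, 1, 1)` and `block=(N, 1, 1)`, block `bidx` and thread `tidx` update element `[bidx][tidx]`. The plain-Python sketch below emulates that mapping with two nested loops. `saxpy_2d_reference` is a hypothetical helper; it models only the index mapping, not memory transfers or parallel execution.

```python
def saxpy_2d_reference(x, y, alpha):
    """Emulate Ch1's launch on the host: one 'block' per row, one 'thread' per column.

    Returns a new matrix holding y[bidx][tidx] + alpha * x[bidx][tidx].
    """
    M = len(x)        # grid=(M, 1, 1): number of blocks
    N = len(x[0])     # block=(N, 1, 1): threads per block
    out = [row[:] for row in y]  # copy so the input y is left untouched
    for bidx in range(M):        # plays the role of gpu.block_id(x)
        for tidx in range(N):    # plays the role of gpu.thread_id(x)
            # SAXPY: y[i] += a * x[i]
            out[bidx][tidx] += alpha * x[bidx][tidx]
    return out


M, N, alpha = 256, 32, 2.0
x = [[1.0] * N for _ in range(M)]
y = [[1.0] * N for _ in range(M)]
result = saxpy_2d_reference(x, y, alpha)
print(result[0][0], result[M - 1][N - 1])  # 3.0 3.0
```

Each `(bidx, tidx)` pair touches a distinct element, which is why the Chapter 1 kernel needs no synchronization between threads.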

I might implement two more chapters, but they are more about GPU programming than about the compiler.
- **Ch5.py:** Warp Specialized GEMM
- **Ch6.py:** Warp Specialized Persistent GEMM

This PR also introduces the nvdsl class, making IR building in the tutorial easier.

---

Patch is 37.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87065.diff


9 Files Affected:

- (added) mlir/test/Examples/nvgpu/Ch0.py (+46) 
- (added) mlir/test/Examples/nvgpu/Ch1.py (+68) 
- (added) mlir/test/Examples/nvgpu/Ch2.py (+94) 
- (added) mlir/test/Examples/nvgpu/Ch3.py (+92) 
- (added) mlir/test/Examples/nvgpu/Ch4.py (+291) 
- (added) mlir/test/Examples/nvgpu/lit.local.cfg (+3) 
- (added) mlir/test/Examples/nvgpu/tools/lit.local.cfg (+3) 
- (added) mlir/test/Examples/nvgpu/tools/nvdsl.py (+404) 
- (added) mlir/test/Examples/nvgpu/tools/nvgpucompiler.py (+45) 


``````````diff
diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
new file mode 100644
index 00000000000000..221ca43d37307a
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -0,0 +1,46 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 0 : Hello World
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates Hello World
+#
+# This chapter demonstrates:
+#   1. Build MLIR function with arguments
+#   2. Build MLIR GPU kernel
+#   3. Print from a GPU thread
+#   4. Pass arguments, JIT compile and run the MLIR function
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir.dialects import gpu
+from tools.nvdsl import *
+
+
+# 1. Build function with arguments
+@NVDSL.mlir_func
+def main(alpha):
+    # 2. Build GPU kernel
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
+    def kernel():
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        myValue = alpha + tidx
+        # 3. Print from a GPU thread
+        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
+
+    # Call the GPU kernel
+    kernel()
+
+
+# 4. Pass arguments, JIT compile and run the MLIR function
+alpha = 100
+main(alpha)
+
+
+# CHECK: GPU thread 0 has 100
+# CHECK: GPU thread 1 has 101
+# CHECK: GPU thread 2 has 102
+# CHECK: GPU thread 3 has 103
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
new file mode 100644
index 00000000000000..a888c61358a024
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -0,0 +1,68 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 1 : 2D Saxpy
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy
+#
+# This chapter demonstrates:
+#  1. Use MLIR GPU dialect to allocate and copy memory
+#  2. Compute 2D SAXPY kernel
+#  3. Pass numpy arrays to MLIR
+#  4. Verify MLIR with reference computation
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, memref
+from tools.nvdsl import *
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy(x, y, alpha):
+    # 1. Use MLIR GPU dialect to allocate and copy memory
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+    t6 = gpu.wait(token_ty, [t5])
+
+    # 2. Compute 2D SAXPY kernel
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
+    def saxpy_kernel():
+        bidx = gpu.block_id(gpu.Dimension.x)
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        x_val = memref.load(x_dev, [bidx, tidx])
+        y_val = memref.load(y_dev, [bidx, tidx])
+
+        # SAXPY: y[i] += a * x[i];
+        y_val += x_val * alpha
+
+        memref.store(y_val, y_dev, [bidx, tidx])
+
+    saxpy_kernel()
+
+    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+    gpu.wait(token_ty, [t7])
+
+
+# 3. Pass numpy arrays to MLIR
+M = 256
+N = 32
+alpha = 2.0
+x = np.ones((M, N), np.float32)
+y = np.ones((M, N), np.float32)
+ref = np.ones((M, N), np.float32)
+saxpy(x, y, alpha)
+
+#  4. Verify MLIR with reference computation
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
new file mode 100644
index 00000000000000..caa2691a06d69e
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -0,0 +1,94 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 2 : 2D Saxpy with TMA
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
+# but it loads data using TMA (Tensor Memory Accelerator).
+#
+# This chapter demonstrates:
+#  1. Create and initialize asynchronous transactional barrier (mbarrier)
+#  2. Execute Tensor Memory Accelerator (TMA) Load
+#  3. Wait for completion of TMA load with mbarrier
+#
+# ===----------------------------------------------------------------------===//
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir import runtime as rt
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy_tma(x, y, alpha):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+    t6 = gpu.wait(token_ty, [t5])
+
+    x_tma = TMA((M, N), x.type)
+    y_tma = TMA((M, N), y.type)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=65536)
+    def saxpy_tma_kernel():
+        bidx = gpu.block_id(gpu.Dimension.x)
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        isThread0 = tidx == 0
+
+        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
+        mbar_group = Mbarriers(number_of_barriers=1)
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            mbar_group[0].init(1)
+            x_tma.prefetch()
+            y_tma.prefetch()
+            scf.yield_([])
+
+        x_smem = get_dynamic_shared_memory((M, N), T.f32())
+        y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 4)  # bytes past x_smem (f32)
+
+        # 2. Execute Tensor Memory Accelerator (TMA) Load
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            x_tma.load(x_smem, mbar_group[0])
+            y_tma.load(y_smem, mbar_group[0])
+            mbar_group[0].arrive(txcount=M * N * 2 * 4)
+            scf.yield_([])
+
+        # 3. Wait for completion of TMA load with mbarrier
+        mbar_group[0].try_wait()
+
+        x_val = memref.load(x_smem, [bidx, tidx])
+        y_val = memref.load(y_smem, [bidx, tidx])
+
+        # SAXPY: y[i] += a * x[i];
+        y_val += x_val * alpha
+
+        memref.store(y_val, y_dev, [bidx, tidx])
+
+    saxpy_tma_kernel()
+
+    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+    gpu.wait(token_ty, [t7])
+
+
+M = 256
+N = 32
+alpha = 2.0
+x = np.ones((M, N), np.float32)
+y = np.ones((M, N), np.float32)
+ref = np.ones((M, N), np.float32)
+saxpy_tma(x, y, alpha)
+
+#  4. Verify MLIR with reference computation
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
new file mode 100644
index 00000000000000..802cb59ead1555
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -0,0 +1,92 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 3 : GEMM 64x64x64 with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation with 64x64x64 matrix multiplication
+#
+# This chapter demonstrates:
+# 1. Execute TMA Load for two input matrices
+# 2. Perform Tensor Core GEMM 64x64x64 by warpgroup
+# 3. Store fragmented registers to global memory by warpgroup
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def gemm_64_64_64(x, y, z):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    x_tma = TMA([N, N], x.type, swizzle=sw)
+    y_tma = TMA([N, N], y.type, swizzle=sw)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
+    def gemm_tma_kernel():
+        tidx = gpu.thread_id(gpu.Dimension.x)
+
+        mbar_group = Mbarriers(number_of_barriers=1)
+        isThread0 = tidx == 0
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            mbar_group[0].init(1)
+            x_tma.prefetch()
+            y_tma.prefetch()
+            scf.yield_([])
+
+        x_smem = get_dynamic_shared_memory((N, N), T.f16())
+        y_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+
+        # 1. Execute TMA Load for two input matrices
+        with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+            x_tma.load(x_smem, mbar_group[0])
+            y_tma.load(y_smem, mbar_group[0])
+            tx_count = get_type_size(x_tma.tma_memref) + get_type_size(y_tma.tma_memref)
+            mbar_group[0].arrive(tx_count)
+            scf.yield_([])
+
+        mbar_group[0].try_wait()
+
+        # 2. Perform Tensor Core GEMM 64x64x64 by warpgroup
+        A = Matrix(x_smem, x_tma, N, N)
+        B = Matrix(y_smem, y_tma, N, N)
+        C = MatrixAccumulator(N, N, T.f32()).op()
+        D = Matrix.matmul(A, B, C)
+
+        # 3. Store fragmented registers to global memory by warpgroup
+        nvgpu.warpgroup_mma_store(D, z_dev)
+
+    gemm_tma_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    gpu.wait(None, [t8])
+
+
+# Pass numpy arrays from Python to MLIR
+N = 64
+x = np.random.randn(N, N).astype(np.float16)
+y = np.random.randn(N, N).astype(np.float16)
+z = np.zeros((N, N), np.float32)
+gemm_64_64_64(x, y, z)
+
+ref = x.astype(np.float32) @ y.astype(np.float32)
+np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
new file mode 100644
index 00000000000000..d222d6e60aafad
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -0,0 +1,291 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a multistage GEMM (M=512, N=256, K=1024) using 128x128x64 tiles
+#
+# This chapter demonstrates:
+#  1. Partition shape based on block IDs
+#  2. Prologue
+#    2.1 Execute TMA Load for two input matrices for each stage
+#  3. Main loop
+#    3.1 Wait for completion of TMA load with mbarrier
+#    3.2 Perform Tensor Core GEMM 128x128x64 by warpgroup
+#    3.3 Load next stage if needed
+#  4. Epilogue
+#    4.1 Store fragmented registers to shared memory
+#    4.2 Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import arith, gpu, memref, nvgpu, nvvm, scf
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+    """
+    Calculate this block's tile offsets from the block IDs.
+
+    It partitions the iteration space as follows:
+    for(.. i < M ...)   --> blockIdx.x
+     for(.. j < N ...)  --> blockIdx.y
+      for(.. k < K ...)
+
+    Returns:
+        dimX: Row offset of the block's tile (blockIdx.x * TILE_M).
+        dimY: Column offset of the block's tile (blockIdx.y * TILE_N).
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    x_tma: TMA,
+    y_tma: TMA,
+    slot,
+    stage,
+    p=None,
+):
+    """
+    Issue TMA loads that copy both input tiles of one stage from global to shared memory:
+
+       - tma.load x_shared_memory[off_x] at coordinate [stage * 64, dimX] (loads 128x64)
+       - tma.load y_shared_memory[off_y] at coordinate [dimY, stage * 64] (loads 64x64)
+       - tma.load y_shared_memory[off_y2] at coordinate [dimY + 64, stage * 64] (loads 64x64)
+
+       mbarrier.arrive tx_count = (128x64 + 2x64x64) x 2 bytes = 32768 (issued before the loads)
+    """
+    dimX, dimY = partition_shape()
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+    size_tma_x = get_type_size(x_tma.tma_memref)
+    size_tma_y = get_type_size(y_tma.tma_memref)
+    tx_count = size_tma_x + (size_tma_y * 2)
+    # Only thread 0 issues the loads unless a predicate is given
+
+    p = tidx == 0 if p is None else p
+
+    off_x = slot * size_tma_x
+    off_y = (slot * size_tma_x) + begin_y
+    off_y2 = off_y + size_tma_y
+    x = get_dynamic_shared_memory(
+        x_tma.tma_memref.shape, x_tma.tma_memref.element_type, off_x
+    )
+    y1 = get_dynamic_shared_memory(
+        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y
+    )
+    y2 = get_dynamic_shared_memory(
+        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y2
+    )
+
+    mbar_group[slot].arrive(tx_count, predicate=p)
+
+    c1 = stage * 64
+    x_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    y_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    y_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def bootstrap(x_tma: TMA, y_tma: TMA):
+    """
+    Initialize mbarriers and prefetch TMA descriptors.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    mbar_group = Mbarriers(number_of_barriers=NUM_STAGES)
+    isThread0 = tidx == const(0)
+    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+        for i in scf.for_(0, NUM_STAGES, 1):
+            mbar_group[i].init(1)
+            scf.yield_([])
+        x_tma.prefetch()
+        y_tma.prefetch()
+        scf.yield_([])
+
+    return mbar_group
+
+
+def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+    """
+    Prologue of the GEMM kernel. It fills the pipeline by loading both input tiles for each of the first NUM_STAGES - 1 stages (one stage if NUM_STAGES == 1):
+
+    for stage in range(NUM_STAGES - 1):
+        tma_load x, y, stage
+
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+    for iv in scf.for_(0, ns, 1):
+        tma_load(mbar_group, x_tma, y_tma, iv, iv)
+        scf.yield_([])
+
+
+def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+    """
+    Main loop of the multistage GEMM kernel. It iterates over the K dimension,
+    multiplying the tiles of the current stage while TMA loads later stages into shared memory, as follows:
+
+    MatrixAccumulator D
+    for k in range(K // TILE_K):
+
+        try_wait(stage, ...)    # Wait TMA load
+
+        Matrix A(stage, ...)    # Find shared memory slot
+        Matrix B(stage, ...)    # Find shared memory slot
+        D += A @ B              # Multiply and accumulate
+
+        if(needLoad)            # Load next stage if needed
+            tma_load(x, y, nextSlot, nextStage)
+
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+
+    size_x = TILE_M * TILE_K * get_type_size(T.f16())
+
+    C = MatrixAccumulator(TILE_M, TILE_N, T.f32()).op()
+    pp = const(False, ty=T.bool())
+
+    # Main Loop
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C, pp])
+    with ir.InsertionPoint(for_op.body):
+        pp = for_op.inner_iter_args[1]
+        iv = for_op.induction_variable
+        stage = iv % NUM_STAGES
+
+        # Wait for current stage
+        mbar_group[stage].try_wait(phase=pp)
+
+        # Find shared memory slot
+        offset_x = stage * size_x
+        offset_y = offset_x + begin_y
+        x_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_x)
+        y_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_y)
+
+        # Matrix Multiply
+        A = Matrix(x_smem, x_tma, TILE_M, TILE_K)
+        B = Matrix(y_smem, y_tma, TILE_K, TILE_N)
+        C = for_op.inner_iter_args[0]
+        D = Matrix.matmul(A, B, C)
+        if NUM_STAGES == 1:
+            nvvm.WgmmaWaitGroupSyncOp(0)
+
+        # Load next stage
+        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
+        nextStage = iv + ns
+        nextSlot = nextStage % NUM_STAGES
+        tma_load(mbar_group, x_tma, y_tma, nextSlot, nextStage, pred)
+
+        # Switch phase parity for the mbarrier
+        switched = pp ^ const(True, ty=T.bool())
+        newPP = arith.select(
+            stage == (NUM_STAGES - 1),
+            switched,
+            pp,
+        )
+        scf.yield_([D, newPP])
+
+    nvvm.WgmmaWaitGroupSyncOp(0)
+
+    return for_op.results[0]
+
+
+def epilogue(D, z_dev):
+    """
+    Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
+
+    MatrixAccumulator D               # Fragmented results
+    store D -> Shared Memory          # Registers to shared memory
+    Shared Memory -> Z[dimX][dimY]    # Shared memory to global memory
+
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    dimX, dimY = partition_shape()
+
+    z_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    z_gmem = memref.subview(z_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+    # Store (registers -> shared memory)
+    nvgpu.WarpgroupMmaStoreOp(D, z_smem)
+    gpu.barrier()
+
+    # Store (shared memory --> global memory)
+    for i in scf.for_(0, TILE_M, 1):
+        val = memref.load(z_smem, [i, tidx])
+        memref.store(val, z_gmem, [i, tidx])
+        scf.yield_([])
+
+
+@NVDSL.mlir_func
+def gemm_multistage(x, y, z):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    x_tma = TMA([128, 64], x.type, swizzle=sw)
+    y_tma = TMA([64, 64], y.type, swizzle=sw)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    grid = [(M // TILE_M), (N // TILE_N), 1]
+    block = [128, 1, 1]
+    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
+    def gemm_multistage_kernel():
+        # Initialize mbarriers and prefetch TMA descriptors
+        mbar_group = bootstrap(x_tma, y_tma)
+
+        # Fill the pipeline stages
+        prologue(mbar_group, x_tma, y_tma)
+
+        # Main loop
+        D = mainloop(mbar_group, x_tma, y_tma)
+
+        # Store registers to global memory
+        epilogue(D, z_dev)
+
+    gemm_multistage_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+    gpu.wait(None, [t8])
+
+
+# Pass numpy arrays from Python to MLIR
+NUM_STAGES = 7
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+x = np.random.randn(M, K).astype(np.float16)
+y = np.random.randn(K, N).astype(np.float16)
+z = np.zeros((M, N), np.float32)
+
+gemm_multistage(x, y, z)
+
+
+# Verify MLIR with reference computation
+ref = x.astype(np.float32) @ y.astype(np.float32)
+np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/lit.local.cfg b/mlir/test/Examples/nvgpu/lit.local.cfg
new file mode 100644
index 00000000000000..e586b55573898a
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/lit.local.cfg
@@ -0,0 +1,3 @@
+config.unsupported = False
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+  config.unsupported = True
\ No newl...
[truncated]

``````````
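
As Chapter 3's `offset=N * N * 2` (for `f16`) and Chapter 4's `slot * size_tma_x` suggest, the `offset` given to `get_dynamic_shared_memory` is a byte offset. Each shared-memory buffer must therefore start where the previous one ends, and the mbarrier `txcount` must equal the total bytes the TMA loads deliver. The plain-Python arithmetic below checks both for Chapters 2 and 3; the helpers `tile_bytes` and `pack` are hypothetical, not part of `nvdsl`. With `f32` data, Chapter 2's second buffer has to start at `M * N * 4` bytes, because `M * N * 2` would overlap the first buffer.

```python
# Element sizes in bytes for the element types the chapters use.
F16_BYTES = 2
F32_BYTES = 4


def tile_bytes(rows, cols, elem_bytes):
    """Size in bytes of a rows x cols tile."""
    return rows * cols * elem_bytes


def pack(sizes):
    """Place buffers back to back; return a list of (byte_offset, size)."""
    placed, offset = [], 0
    for size in sizes:
        placed.append((offset, size))
        offset += size
    return placed


# Chapter 2: x_smem and y_smem are 256x32 f32 tiles; the launch uses smem=65536.
M, N = 256, 32
ch2 = pack([tile_bytes(M, N, F32_BYTES)] * 2)
ch2_tx = sum(size for _, size in ch2)  # bytes both TMA loads deliver

# Chapter 3: x_smem and y_smem are 64x64 f16 tiles; the launch uses smem=16384.
ch3 = pack([tile_bytes(64, 64, F16_BYTES)] * 2)
ch3_tx = sum(size for _, size in ch3)

print(ch2, ch2_tx)  # [(0, 32768), (32768, 32768)] 65536
print(ch3, ch3_tx)  # [(0, 8192), (8192, 8192)] 16384
```

In both chapters the total equals the requested `smem` size, and Chapter 2's total matches its `txcount=M * N * 2 * 4`.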

</details>
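
Chapter 4's pipeline schedule is easiest to check away from the GPU. The plain-Python model below (the function `simulate_pipeline` is hypothetical, written only for illustration) replays the index arithmetic of `prologue()` and `mainloop()`: the prologue requests the first `NUM_STAGES - 1` K tiles, iteration `iv` waits on slot `iv % NUM_STAGES` and then requests tile `iv + NUM_STAGES - 1`, and the parity passed to `try_wait` flips each time the loop wraps around the ring of slots. It models only the indices, not TMA or mbarrier timing. It also recomputes the shared-memory footprint: 7 stages of 32768 bytes each is 229376 bytes, just under the `smem=229440` requested at launch, and 32768 is also the `tx_count` that `tma_load()` computes as `size_tma_x + 2 * size_tma_y`.

```python
def simulate_pipeline(num_stages, num_k_tiles):
    """Host-side model of Ch4's prologue + main-loop schedule.

    Returns (loads, waits):
      loads: (k_tile, slot) pairs in the order the TMA loads are issued
      waits: (iteration, slot, phase_parity) for each mbarrier try_wait
    """
    # Tiles prefetched by the prologue (same formula as prologue()/mainloop()).
    ns = num_stages if num_stages == 1 else num_stages - 1
    loads = [(k, k % num_stages) for k in range(ns)]

    waits = []
    pp = False  # phase parity passed to try_wait, starts at 0
    for iv in range(num_k_tiles):
        stage = iv % num_stages
        # The tile consumed now must already have been requested.
        assert (iv, stage) in loads
        waits.append((iv, stage, pp))
        next_k = iv + ns
        if next_k < num_k_tiles:  # mirrors `pred` in mainloop()
            loads.append((next_k, next_k % num_stages))
        if stage == num_stages - 1:  # flip parity after each pass over the ring
            pp = not pp
    return loads, waits


NUM_STAGES, K, TILE_K = 7, 1024, 64
TILE_M, TILE_N = 128, 128
F16_BYTES = 2

loads, waits = simulate_pipeline(NUM_STAGES, K // TILE_K)

# Every K tile is loaded exactly once.
assert sorted(k for k, _ in loads) == list(range(K // TILE_K))
# The parity waited on equals how often that slot was used before, mod 2.
assert all(pp == bool((iv // NUM_STAGES) % 2) for iv, _, pp in waits)

# Bytes per stage: a TILE_M x TILE_K tile of x plus a TILE_K x TILE_N tile of y.
stage_bytes = (TILE_M * TILE_K + TILE_K * TILE_N) * F16_BYTES
print(stage_bytes, NUM_STAGES * stage_bytes)  # 32768 229376
```

Note that with `ns = NUM_STAGES - 1`, the slot refilled at iteration `iv` is the one consumed at iteration `iv - 1`, which is why the ring needs at least two stages to overlap loads with compute.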


https://github.com/llvm/llvm-project/pull/87065

