[Mlir-commits] [mlir] 4d33082 - [mlir][nvgpu] NVGPU Tutorials (#87065)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 24 03:00:16 PDT 2024


Author: Guray Ozen
Date: 2024-04-24T12:00:12+02:00
New Revision: 4d3308202e52b213a05023c8b8b470b346151de6

URL: https://github.com/llvm/llvm-project/commit/4d3308202e52b213a05023c8b8b470b346151de6
DIFF: https://github.com/llvm/llvm-project/commit/4d3308202e52b213a05023c8b8b470b346151de6.diff

LOG: [mlir][nvgpu] NVGPU Tutorials (#87065)

I have a tutorial at EuroLLVM 2024 ([Zero to Hero: Programming Nvidia
Hopper Tensor Core with MLIR's NVGPU
Dialect](https://llvm.swoogo.com/2024eurollvm/session/2086997/zero-to-hero-programming-nvidia-hopper-tensor-core-with-mlir's-nvgpu-dialect)).
For that, I implemented tutorial codes in Python. The focus is the nvgpu
dialect and how to use its advanced features. I thought it might be
useful to upstream this.

The tutorial codes are as follows:
- **Ch0.py:** Hello World
- **Ch1.py:** 2D Saxpy
- **Ch2.py:** 2D Saxpy using TMA
- **Ch3.py:** GEMM 128x128x64 using Tensor Core and TMA 
- **Ch4.py:** Multistage performant GEMM using Tensor Core and TMA
- **Ch5.py:** Warp Specialized GEMM using Tensor Core and TMA

I might implement one more chapter:

- **Ch6.py:** Warp Specialized Persistent ping-pong GEMM

This PR also introduces the nvdsl class, making IR building in the
tutorial easier.

Added: 
    mlir/test/Examples/NVGPU/Ch0.py
    mlir/test/Examples/NVGPU/Ch1.py
    mlir/test/Examples/NVGPU/Ch2.py
    mlir/test/Examples/NVGPU/Ch3.py
    mlir/test/Examples/NVGPU/Ch4.py
    mlir/test/Examples/NVGPU/Ch5.py
    mlir/test/Examples/NVGPU/lit.local.cfg
    mlir/test/Examples/NVGPU/tools/lit.local.cfg
    mlir/test/Examples/NVGPU/tools/nvdsl.py
    mlir/test/Examples/NVGPU/tools/nvgpucompiler.py

Modified: 
    

Removed: 
    


################################################################################
diff  --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py
new file mode 100644
index 00000000000000..8f60088178d119
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch0.py
@@ -0,0 +1,50 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 0 : Hello World
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates Hello World:
+#   1. Build MLIR function with arguments
+#   2. Build MLIR GPU kernel
+#   3. Print from a GPU thread
+#   4. Pass arguments, JIT compile and run the MLIR function
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir.dialects import gpu
+from tools.nvdsl import *
+
+
+# 1. The decorator generates an MLIR func.func.
+# Everything inside the Python function becomes the body of the func.
+# The decorator also translates `alpha` to an `index` type.
+@NVDSL.mlir_func
+def main(alpha):
+    # 2. The decorator generates an MLIR gpu.launch.
+    # Everything inside the Python function becomes the body of the gpu.launch.
+    # This allows for late outlining of the GPU kernel, enabling optimizations
+    # like constant folding from host to device.
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
+    def kernel():
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        # + operator generates arith.addi
+        myValue = alpha + tidx
+        # Print from a GPU thread
+        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
+
+    # 3. Call the GPU kernel
+    kernel()
+
+
+alpha = 100
+# 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
+main(alpha)
+
+
+# CHECK: GPU thread 0 has 100
+# CHECK: GPU thread 1 has 101
+# CHECK: GPU thread 2 has 102
+# CHECK: GPU thread 3 has 103

diff  --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py
new file mode 100644
index 00000000000000..da65aa2ef6a172
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch1.py
@@ -0,0 +1,66 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 1 : 2D Saxpy
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy:
+#  1. Use GPU dialect to allocate and copy memory host to gpu and vice versa
+#  2. Computes 2D SAXPY kernel using operator overloading
+#  3. Pass numpy arrays to MLIR as memref arguments
+#  4. Verify MLIR program with reference computation in python
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, memref
+from tools.nvdsl import *
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy(x, y, alpha):
+    # 1. Use MLIR GPU dialect to allocate and copy memory
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+    t6 = gpu.wait(token_ty, [t5])
+
+    # 2. Compute 2D SAXPY kernel
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
+    def saxpy_kernel():
+        bidx = gpu.block_id(gpu.Dimension.x)
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        x_val = memref.load(x_dev, [bidx, tidx])
+        y_val = memref.load(y_dev, [bidx, tidx])
+
+        # SAXPY: y[i] += a * x[i];
+        y_val += x_val * alpha
+
+        memref.store(y_val, y_dev, [bidx, tidx])
+
+    saxpy_kernel()
+
+    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+    gpu.wait(token_ty, [t7])
+
+
+# 3. Pass numpy arrays to MLIR
+M = 256
+N = 32
+alpha = 2.0
+x = np.random.randn(M, N).astype(np.float32)
+y = np.ones((M, N), np.float32)
+saxpy(x, y, alpha)
+
+#  4. Verify MLIR with reference computation
+ref = np.ones((M, N), np.float32)
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements

diff  --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py
new file mode 100644
index 00000000000000..78c14cb2c7ad8c
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch2.py
@@ -0,0 +1,93 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 2 : 2D Saxpy with TMA
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
+# but it loads data using TMA (Tensor Memory Accelerator)
+#
+# This chapter demonstrates:
+#  1. Computes 2D SAXPY in the same way as Ch1.py but loads data using TMA
+#  2. Create and initialize 1 asynchronous transactional barrier (mbarrier)
+#  3. Thread 0 issues the TMA load request for each thread block
+#  4. Each thread block loads <1x32xf32> for x and y.
+#  5. Wait for completion of TMA load with mbarrier
+#
+# ===----------------------------------------------------------------------===//
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir import runtime as rt
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy(x, y, alpha):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+    t6 = gpu.wait(token_ty, [t5])
+
+    x_tma = TMA([1, N], x.type)
+    y_tma = TMA([1, N], y.type)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+    sz_x = get_type_size(x_tma.tma_memref)
+    sz_y = get_type_size(x_tma.tma_memref)
+    sz = sz_x + sz_y
+
+    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
+    def saxpy_tma_kernel():
+        bidx = gpu.block_id(gpu.Dimension.x)
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        isThread0 = tidx == 0
+
+        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
+        mbar_group = Mbarriers(number_of_barriers=1)
+        mbar_group[0].init(1, predicate=isThread0)
+
+        # 2. Execute Tensor Memory Accelerator (TMA) Load
+        x_smem = get_dynamic_shared_memory([1, N], T.f32())
+        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
+        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
+        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
+        mbar_group[0].arrive(txcount=sz, predicate=isThread0)
+
+        # 3. Wait for completion of TMA load with mbarrier
+        mbar_group[0].try_wait()
+
+        x_val = memref.load(x_smem, [const(0), tidx])
+        y_val = memref.load(y_smem, [const(0), tidx])
+
+        # SAXPY: y[i] += a * x[i];
+        y_val += x_val * alpha
+
+        memref.store(y_val, y_dev, [bidx, tidx])
+
+    saxpy_tma_kernel()
+
+    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+    gpu.wait(token_ty, [t7])
+
+
+# 3. Pass numpy arrays to MLIR
+M = 256
+N = 32
+alpha = 2.0
+x = np.random.randn(M, N).astype(np.float32)
+y = np.ones((M, N), np.float32)
+saxpy(x, y, alpha)
+
+#  4. Verify MLIR with reference computation
+ref = np.ones((M, N), np.float32)
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements

diff  --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py
new file mode 100644
index 00000000000000..a417014de8b49a
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch3.py
@@ -0,0 +1,129 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 3 : GEMM 128x128x64 with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation with 128x128x64 matrix multiplication
+#
+# This chapter demonstrates:
+# 1. Execute TMA Load for two input matrices
+# 2. Performs Tensor Core GEMM 128x128x64 by warpgroup
+# 3. Stores fragmented registers to global memory by warpgroup
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir.extras import types as T
+import numpy as np
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    p,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load a_shared_memory[0] at coordinate [0, 0]  (Loads 128x64)
+       - tma.load b_shared_memory[0] at coordinate [0, 0]  (Loads 64x64)
+       - tma.load b_shared_memory[0] at coordinate [64, 0] (Loads 64x64)
+
+       mbarrier.arrive ta_count = 128x64xf16 + 64x128xf16
+    """
+
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    ta_count = size_tma_a + (size_tma_b * 2)
+
+    off_b = size_tma_a
+    off_b2 = off_b + size_tma_b
+    a_elem_ty = a_tma.tma_memref.element_type
+    b_elem_ty = b_tma.tma_memref.element_type
+    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
+    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+    mbar_group[0].arrive(ta_count, predicate=p)
+
+    a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
+    b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
+    b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)
+
+
+@NVDSL.mlir_func
+def gemm_128_128_64(a, b, d):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
+    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    a_tma = TMA([128, 64], a.type, swizzle=sw)
+    b_tma = TMA([64, 64], b.type, swizzle=sw)
+    a_tma.create_descriptor(a_dev)
+    b_tma.create_descriptor(b_dev)
+    a_size = get_type_size(a.type)
+    b_size = get_type_size(b.type)
+    smem_size_in_bytes = a_size + b_size
+
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
+    def gemm_tma_kernel():
+        tidx = gpu.thread_id(gpu.Dimension.x)
+
+        mbar_group = Mbarriers(number_of_barriers=1)
+        isThread0 = tidx == 0
+
+        mbar_group[0].init(1, predicate=isThread0)
+        a_tma.prefetch(predicate=isThread0)
+        b_tma.prefetch(predicate=isThread0)
+
+        a_smem = get_dynamic_shared_memory((M, K), T.f16())
+        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)
+
+        # 1. TMA Load for two input matrices
+        tma_load(mbar_group, a_tma, b_tma, isThread0)
+
+        # 2. All threads wait TMA load completion
+        mbar_group[0].try_wait()
+
+        # 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
+        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
+        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
+        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
+
+        # Matrix Multiply
+        D += A @ B
+
+        # 4. Stores fragmented registers to global memory by warpgroup
+        D.store_accumulator(d_dev)
+
+    gemm_tma_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
+    gpu.wait(None, [t8])
+
+
+# Python pass arguments to MLIR
+M = 128
+N = 128
+K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+gemm_128_128_64(a, b, d)
+
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements

diff  --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py
new file mode 100644
index 00000000000000..8f38d8a90add31
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch4.py
@@ -0,0 +1,323 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# Multistage method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches one Warp
+# Group (128 threads in total) and allocates multiple slots/stages in the
+# shared memory. The program consists of three main parts: prologue, mainloop,
+# and epilogue. In the prologue, thread0 requests for TMA to load data into
+# shared memory slots. The mainloop executes MMA while simultaneously loading
+# TMA for the utilized slots. This overlap of TMA and MMA operations enhances
+# performance by maximizing computational throughput.
+#
+# Loops illustration:
+#
+#  for s in range(num_stages):
+#    TMA_128x64_64x128...
+#  for ti in range(M//128):  # -> blockIdx.x
+#   for tj in range(N//128): # -> blockIdx.y
+#    for tk in range(K//64):
+#      MMA_128x128x64...
+#      TMA_128x64_64x128...
+#  Epilogue...
+#
+# This chapter demonstrates:
+#  1. Partition shape based on block IDs
+#  2. Prologue
+#    2.1 Execute TMA Load for two input matrices for each stage
+#  3. Main loop
+#    3.1 Wait for completion of TMA load with mbarrier
+#    3.2 Performs Tensor Core GEMM 64x128x64 by warpgroup
+#    3.3 Load next stage if needed
+#  4. Epilogue
+#    4.1 Store fragmented registers to shared memory
+#    4.2 Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+    """
+    Calculate the partition shape based on the block IDs.
+
+    It partitions the shape like below:
+    for(.. i < M ...)   --> blockIdx.x
+     for(.. j < N ...)  --> blockIdx.y
+      for(.. k < K ...)
+
+    Returns:
+        dimX (int): Dimension along the x-axis.
+        dimY (int): Dimension along the y-axis.
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    slot,
+    stage,
+    num_stages,
+    p=None,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load a_shared_memory[off_x]  at coordinate [x, z]      (Loads 128x64)
+       - tma.load b_shared_memory[off_y1] at coordinate [y, x]      (Loads 64x64)
+       - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
+
+       mbarrier.arrive ta_count = 128x64x2x4
+    """
+    dimX, dimY = partition_shape()
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_b = num_stages * get_type_size(a_tma.tma_memref)
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    ta_count = size_tma_a + (size_tma_b * 2)
+    tidx = gpu.thread_id(gpu.Dimension.x)
+
+    p = tidx == 0 if p is None else p
+
+    off_a = slot * size_tma_a
+    off_b = (slot * size_tma_a) + begin_b
+    off_b2 = off_b + size_tma_b
+    a_elem_ty = a_tma.tma_memref.element_type
+    b_elem_ty = b_tma.tma_memref.element_type
+    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+    mbar_group[slot].arrive(ta_count, predicate=p)
+
+    c1 = stage * 64
+    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def initialize(a_tma: TMA, b_tma: TMA, num_stages):
+    """
+    Initialize mbarriers and prefetch TMA descriptors.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    mbar_group = Mbarriers(number_of_barriers=num_stages)
+    isThread0 = tidx == const(0)
+    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+        for i in scf.for_(0, num_stages, 1):
+            mbar_group[i].init(1)
+            scf.yield_([])
+        a_tma.prefetch()
+        b_tma.prefetch()
+        scf.yield_([])
+
+    return mbar_group
+
+
+def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
+    """
+    Prologue of the GEMM kernel. It fills the pipeline by loading the 2 input matrices for each stage in a loop like below:
+
+    for stage in range(NUM_STAGES - 1):  # a single iteration when NUM_STAGES == 1
+        tma_load x, y, stage
+
+    """
+    ns = num_stages if num_stages == 1 else num_stages - 1
+    for iv in scf.for_(0, ns, 1):
+        tma_load(mbar_group, a_tma, b_tma, iv, iv, num_stages)
+        scf.yield_([])
+
+
+def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
+    """
+    Main loop of the Multistage GEMM kernel. It iterates through
+    stages and performs matrix multiplication, loading data by TMA to shared memory. It looks like the following:
+
+    MatrixAccumulator D
+    for k in range(K // TILE_K):
+
+        try_wait(stage, ...)    # Wait TMA load
+
+        Matrix A(stage, ...)    # Find shared memory slot
+        Matrix B(stage, ...)    # Find shared memory slot
+        D += A @ B              # Multiply and accumulate
+
+        if(needLoad)            # Load next stage if needed
+            tma_load(x, y, nextSlot, nextStage)
+
+    """
+    ns = num_stages if num_stages == 1 else num_stages - 1
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_b = num_stages * get_type_size(a_tma.tma_memref)
+
+    size_a = TILE_M * TILE_K * get_type_size(T.f16())
+
+    # Initialize A and B (input matrices) and C (accumulator)
+    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+
+    phase = const(False, ty=T.bool())
+
+    # Main Loop
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
+    with ir.InsertionPoint(for_op.body):
+        phase = for_op.inner_iter_args[1]
+        iv = for_op.induction_variable
+        stage = iv % num_stages
+
+        # Wait for current stage
+        mbar_group[stage].try_wait(phase=phase)
+
+        # Find shared memory slot
+        offset_a = stage * size_a
+        offset_b = offset_a + begin_b
+        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
+
+        # Iterate input matrices, update accumulator
+        A.update_smem(a_smem)
+        B.update_smem(b_smem)
+        D.update_accumulator(for_op.inner_iter_args[0])
+
+        # Matrix Multiply
+        D += A @ B
+
+        # Wait Tensor Core for single stage
+        if num_stages == 1:
+            nvvm.WgmmaWaitGroupSyncOp(0)
+
+        # Load next stage
+        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
+        nextStage = iv + ns
+        nextSlot = nextStage % num_stages
+        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)
+
+        # Switch phase parity for the mbarrier
+        newPhase = arith.select(
+            stage == (num_stages - 1),
+            (phase ^ const(True, ty=T.bool())),
+            phase,
+        )
+        scf.yield_([D.acc_op, newPhase])
+
+    nvvm.WgmmaWaitGroupSyncOp(0)
+
+    D.update_accumulator(for_op.results[0])
+    return D
+
+
+def epilogue(D: WGMMAMatrix, d_dev):
+    """
+    Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
+
+    MatrixAccumulator D               # Fragmented results
+    store D -> Shared Memory          # Store Shared Memory
+    Shared Memory -> Z[dimX][dimY]    # Store Shared Memory to Global Memory
+
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    dimX, dimY = partition_shape()
+
+    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+    # Store (registers -> shared memory)
+    D.store_accumulator(d_smem)
+    gpu.barrier()
+
+    # Store (shared memory --> global memory)
+    for i in scf.for_(0, TILE_M, 1):
+        val = memref.load(d_smem, [i, tidx])
+        memref.store(val, d_gmem, [i, tidx])
+        scf.yield_([])
+
+
+# The decorator generates
+#   a -> memref<MxKxf16>
+#   b -> memref<KxNxf16>
+#   d -> memref<MxNxf32>
+@NVDSL.mlir_func
+def gemm_multistage(a, b, d, num_stages):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
+    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    a_tma = TMA([128, 64], a.type, swizzle=sw)
+    b_tma = TMA([64, 64], b.type, swizzle=sw)
+    a_tma.create_descriptor(a_dev)
+    b_tma.create_descriptor(b_dev)
+
+    grid = [(M // TILE_M), (N // TILE_N), 1]
+    block = [128, 1, 1]
+
+    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
+    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
+    smem_size_in_bytes = (size_a + size_b) * num_stages
+
+    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
+    def gemm_multistage_kernel():
+        # Initialize mbarriers and prefetch TMA descriptors
+        mbar_group = initialize(a_tma, b_tma, num_stages)
+
+        # Fill the pipeline stages
+        prologue(mbar_group, a_tma, b_tma, num_stages)
+
+        # Main loop
+        D = mainloop(mbar_group, a_tma, b_tma, num_stages)
+
+        # Store registers to global memory
+        epilogue(D, d_dev)
+
+    gemm_multistage_kernel()
+
+    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
+    gpu.wait(None, [t8])
+
+
+# Python pass arguments to MLIR
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+
+gemm_multistage(a, b, d, num_stages=7)
+
+
+# Verify MLIR with reference computation
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements

diff  --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
new file mode 100644
index 00000000000000..92e9314e1b812d
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -0,0 +1,321 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 5 : Warp Specialized GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation for `f32+=f16*f16`, utilizing the
+# Warp Specialized method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches two Warp
+# Groups (256 threads in total): one for the producer and the other for the consumer.
+# Each group takes a different control-flow. The producer thread group is responsible
+# for loading data into shared memory, while the consumer group executes the Tensor
+# Core GEMM operation and epilogue.
+#
+#  for ti in range(M//128):  # -> blockIdx.x
+#   for tj in range(N//128): # -> blockIdx.y
+#    with wg_producer:
+#     for tk in range(K//64):
+#        TMA_128x64_64x128...
+#    with wg_consumer:
+#     for tk in range(K//64):
+#        MMA_128x128x64...
+#     Epilogue..
+#
+# This chapter demonstrates:
+#  2 WG (warpgroups)
+#    Producer:
+#       2.1.1 Wait MMA Barrier
+#       2.1.2 Load TMA with TMA barrier
+#       2.1.3 Arrive TMA barrier with txcount
+#    Consumer:
+#       Loop
+#           Wait TMA barrier
+#           Performs Tensor Core GEMM 64x128x64 by warpgroup
+#           Arrive MMA Barrier
+#       Epilogue
+#           Store fragmented registers to shared memory
+#           Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+    """
+    Calculate the partition shape based on the block IDs.
+
+    It parallelizes the two outermost loops into thread blocks.
+    for ti in range(M//128):    # -> blockIdx.x
+     for tj in range(N//128):   # -> blockIdx.y
+      D = 0
+      for tk in range(K//64):
+       for i in range(128):
+        for j in range(128):
+         for k in range(64):
+           FMA
+
+    Returns:
+        dimX (int): Dimension along the x-axis.
+        dimY (int): Dimension along the y-axis.
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    slot,
+    stage,
+    num_stages,
+    p=None,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load a_shared_memory[off_x]  at coordinate [x, z]      (Loads 128x64)
+       - tma.load b_shared_memory[off_y1] at coordinate [y, x]      (Loads 64x64)
+       - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
+
+       mbarrier.arrive ta_count = 128x64x2x4
+    """
+    dimX, dimY = partition_shape()
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_b = num_stages * get_type_size(a_tma.tma_memref)
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    ta_count = size_tma_a + (size_tma_b * 2)
+
+    off_a = slot * size_tma_a
+    off_b = (slot * size_tma_a) + begin_b
+    off_b2 = off_b + size_tma_b
+    a_elem_ty = a_tma.tma_memref.element_type
+    b_elem_ty = b_tma.tma_memref.element_type
+    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+    mbar_group[slot].arrive(ta_count, predicate=p)
+    p = (tidx % WARP_GROUP_SIZE) == 0
+    c1 = stage * 64
+    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def initialize(a_tma: TMA, b_tma: TMA, num_stages):
+    """
+    Initialize mbarriers and prefetch TMA descriptors.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
+    mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
+    isThread0 = tidx == const(0)
+    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+        for i in scf.for_(0, num_stages, 1):
+            mbar_group_tma[i].init(1)
+            mbar_group_mma[i].init(1)
+            scf.yield_([])
+        a_tma.prefetch()
+        b_tma.prefetch()
+        scf.yield_([])
+
+    return mbar_group_tma, mbar_group_mma
+
+
+def switch_phase(stage, phase, num_stages):
+    p = stage == (num_stages - 1)
+    phase = arith.select(
+        p,
+        (phase ^ const(True, ty=T.bool())),
+        phase,
+    )
+    return phase
+
+
+def producer_loop(
+    mbar_tma: Mbarriers,
+    mbar_mma: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    wg_me: Warpgroup,
+    num_stages,
+):
+    phase = const(True, ty=T.bool())
+
+    for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
+        stage = iv % num_stages
+        # Wait MMA to be done
+        mbar_mma[stage].try_wait(phase)
+        # New phase for mbarrier
+        phase = switch_phase(stage, phase, num_stages)
+        # TMA Load
+        tma_load(mbar_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
+        scf.yield_([phase])
+
+
+def consumer_loop(
+    mbar_tma: Mbarriers,
+    mbar_mma: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    wg_me: Warpgroup,
+    num_stages,
+):
+    """
+    Emit the consumer loop: one warpgroup MMA per K tile (K // TILE_K in total).
+
+    The loop carries two values: the accumulator and the phase bit. Each
+    iteration waits on the TMA barrier of the current stage, views that
+    stage's A and B tiles in dynamic shared memory, and accumulates A @ B.
+
+    Returns the accumulator matrix `D`, updated with the loop's result.
+    """
+    # Byte offset of the first B tile: it follows `num_stages` A boxes.
+    begin_b = num_stages * get_type_size(a_tma.tma_memref)
+
+    # Size in bytes of one f16 A tile.
+    size_a = TILE_M * TILE_K * get_type_size(T.f16())
+
+    # The consumer starts with the phase bit set to false.
+    phase = const(False, ty=T.bool())
+    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+
+    # iter_args: [0] = accumulator, [1] = phase bit.
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
+    with ir.InsertionPoint(for_op.body):
+        phase = for_op.inner_iter_args[1]
+        iv = for_op.induction_variable
+        # Stages are reused round-robin across iterations.
+        stage = iv % num_stages
+
+        # Wait TMA for current stage
+        mbar_tma[stage].try_wait(phase)
+
+        # Find shared memory slot
+        offset_a = stage * size_a
+        offset_b = offset_a + begin_b
+        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
+
+        # Iterate input matrices, update accumulator
+        A.update_smem(a_smem)
+        B.update_smem(b_smem)
+        D.update_accumulator(for_op.inner_iter_args[0])
+
+        # Matrix Multiply
+        D += A @ B
+
+        # MMA Barrier Arrive
+        # Only the warpgroup's primary thread arrives, and not on the first
+        # iteration. The barrier signalled is the one of the previous stage
+        # (wrapping to `num_stages - 1` when the current stage is 0).
+        p_arrive = (iv > 0) & wg_me.is_wg_primary
+        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+            barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
+            mbar_mma[barId].arrive()
+            scf.yield_([])
+
+        phase = switch_phase(stage, phase, num_stages)
+        scf.yield_([D.acc_op, phase])
+
+    nvvm.WgmmaWaitGroupSyncOp(0)
+    # Rebind D to the accumulator value produced by the loop.
+    D.update_accumulator(for_op.results[0])
+    return D
+
+
+def epilogue(D: WGMMAMatrix, d_dev):
+    """
+    Epilogue of the GEMM kernel: write the accumulator held by `D` to `d_dev`.
+
+    Steps:
+      1. accumulator registers -> shared memory (`D.store_accumulator`)
+      2. barrier
+      3. shared memory -> the TILE_M x TILE_N subview of `d_dev` located at
+         the offsets returned by `partition_shape()`; each thread copies the
+         column selected by its x thread id, one row per loop iteration
+    """
+    col = gpu.thread_id(gpu.Dimension.x)
+    off_x, off_y = partition_shape()
+
+    smem_tile = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    gmem_tile = memref.subview(d_dev, [off_x, off_y], [TILE_M, TILE_N], [1, 1])
+
+    # Registers -> shared memory
+    D.store_accumulator(smem_tile)
+    gpu.barrier()
+
+    # Shared memory -> global memory
+    for row in scf.for_(0, TILE_M, 1):
+        element = memref.load(smem_tile, [row, col])
+        memref.store(element, gmem_tile, [row, col])
+        scf.yield_([])
+
+
+@NVDSL.mlir_func
+def gemm_warp_specialized(a, b, d, num_stages):
+    """
+    Build the host function and the warp-specialized GEMM kernel.
+
+    Host side: allocate device buffers for `a`, `b` and `d`, copy `a` and `b`
+    to the device, create the TMA descriptors, launch the kernel, and copy the
+    result back into `d`.
+
+    Kernel side: each block has 256 threads. The warpgroup whose primary
+    thread is 128 is the producer (TMA loads); the warpgroup whose primary
+    thread is 0 is the consumer (MMA and epilogue).
+
+    `num_stages` is the number of pipeline stages; it sets the number of
+    mbarriers per group and the amount of dynamic shared memory.
+    """
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
+    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    a_tma = TMA([128, 64], a.type, swizzle=sw)
+    b_tma = TMA([64, 64], b.type, swizzle=sw)
+    a_tma.create_descriptor(a_dev)
+    b_tma.create_descriptor(b_dev)
+
+    grid = [(M // TILE_M), (N // TILE_N), 1]
+    block = [256, 1, 1]
+
+    # Dynamic shared memory holds one A tile and one B tile per stage.
+    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
+    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
+    smem_size_in_bytes = (size_a + size_b) * num_stages
+
+    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
+    def gemm_warp_specialized_kernel():
+        # Init Warpgroups
+        wg_producer = Warpgroup(primary_thread=128, register_size=40)
+        wg_consumer = Warpgroup(primary_thread=0, register_size=232)
+
+        # Initialize mbarriers and prefetch TMA descriptors.
+        # `initialize` returns (TMA barrier group, MMA barrier group).
+        mbar_tma, mbar_mma = initialize(a_tma, b_tma, num_stages)
+
+        # Producer performs TMA
+        with wg_producer:
+            producer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_producer, num_stages)
+
+        # Consumer performs MMA/Tensor Core
+        with wg_consumer:
+            D = consumer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_consumer, num_stages)
+            epilogue(D, d_dev)
+
+    gemm_warp_specialized_kernel()
+
+    # Copy the result back to the host buffer and wait for completion.
+    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
+    gpu.wait(None, [t8])
+
+
+# Problem sizes and tile sizes. These module-level names are read as globals
+# by the IR-building functions above. The arrays below are passed from Python
+# to the generated MLIR function.
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+# Inputs are float16; the output buffer is float32.
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+
+gemm_warp_specialized(a, b, d, num_stages=7)
+
+
+# Verify MLIR with reference computation
+# NOTE(review): `a` and `b` are already float16, so the reference product is
+# float16 while `d` is float32 - confirm the tolerances below are meant to
+# absorb that difference.
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements

diff  --git a/mlir/test/Examples/NVGPU/lit.local.cfg b/mlir/test/Examples/NVGPU/lit.local.cfg
new file mode 100644
index 00000000000000..689cd252e7a254
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/lit.local.cfg
@@ -0,0 +1,4 @@
+# Mark this directory unsupported unless both the CUDA runner and the sm90
+# test configuration are enabled.
+config.unsupported = not (
+    config.enable_cuda_runner and config.mlir_run_cuda_sm90_tests
+)
+  
\ No newline at end of file

diff  --git a/mlir/test/Examples/NVGPU/tools/lit.local.cfg b/mlir/test/Examples/NVGPU/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+

diff  --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
new file mode 100644
index 00000000000000..600cae5b47eeec
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -0,0 +1,456 @@
+from enum import Enum
+import functools, sys, ctypes, os, errno
+import numpy as np
+from functools import partialmethod
+from mlir import ir
+from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
+from mlir.extras import types as T
+from mlir import runtime as rt
+from tools import nvgpucompiler
+
+MLIR_DYNAMIC = -9223372036854775808
+
+
+def const(value, ty=None):
+    """
+    Return `value` as an MLIR SSA value.
+
+    An existing `ir.Value` is returned unchanged, whatever its type, and `ty`
+    is ignored in that case. Any other value is materialized with
+    `arith.constant`, using `ty`, or the `index` type when `ty` is None.
+    """
+    # The previous check compared `value.type` against itself, so every
+    # `ir.Value` was already passed through; this states that directly.
+    if isinstance(value, ir.Value):
+        return value
+    ty = T.index() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def get_type_size(ty):
+    """
+    Return the size of `ty` in bytes.
+
+    Float and integer types yield `width // 8`. A memref type yields the
+    element size multiplied by every entry of its shape. Any other type
+    raises NotImplementedError.
+    """
+    if ir.MemRefType.isinstance(ty):
+        element_bytes = get_type_size(ty.element_type)
+        return functools.reduce(
+            lambda total, dim: total * dim, ty.shape, element_bytes
+        )
+    for scalar_ty in (ir.FloatType, ir.IntegerType):
+        if scalar_ty.isinstance(ty):
+            return scalar_ty(ty).width // 8
+    raise NotImplementedError(ty)
+
+
+def get_mlir_func_obj_ty(inputArgs):
+    """
+    Convert Python arguments into ctypes objects, one per input argument.
+
+    bool, int and float become one-element ctypes arrays (c_bool, c_int and
+    c_float). A numpy array becomes a pointer to a pointer to its ranked
+    memref descriptor. Any other argument raises NotImplementedError.
+    """
+
+    def to_ctypes(arg):
+        # bool must be tested before int: bool is a subclass of int.
+        if isinstance(arg, bool):
+            return (ctypes.c_bool * 1)(arg)
+        if isinstance(arg, int):
+            return (ctypes.c_int * 1)(arg)
+        if isinstance(arg, float):
+            return (ctypes.c_float * 1)(arg)
+        if isinstance(arg, np.ndarray):
+            descriptor = rt.get_ranked_memref_descriptor(arg)
+            return ctypes.pointer(ctypes.pointer(descriptor))
+        raise NotImplementedError(arg)
+
+    return [to_ctypes(arg) for arg in inputArgs]
+
+
+class Mbarriers:
+    """
+    Wrapper around an `nvgpu.mbarrier.group` value in workgroup memory.
+
+    Indexing (`group[i]`) stores `i` as the selected barrier id on this same
+    object and returns the object itself; `init`, `arrive` and `try_wait`
+    then act on the most recently selected id.
+    """
+
+    def __init__(self, number_of_barriers=1):
+        self.mbar_ty = ir.Type.parse(
+            f"!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = {number_of_barriers}>"
+        )
+        self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
+        self.number_of_barriers = number_of_barriers
+
+    def __getitem__(self, key):
+        """Select barrier `key` and return this object."""
+        self.id_op = const(key)
+        return self
+
+    def init(self, count: int, predicate=None):
+        """Emit `nvgpu.mbarrier_init` for the selected barrier."""
+        count_op = const(count)
+        # The predicate keyword is only passed when a predicate was given.
+        optional = {} if predicate is None else {"predicate": predicate}
+        nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op, **optional)
+
+    def arrive(self, txcount: int = 0, predicate=None):
+        """
+        Emit an arrive on the selected barrier.
+
+        A non-zero `txcount` emits `mbarrier_arrive_expect_tx`; otherwise a
+        plain `mbarrier_arrive` is emitted and `predicate` is not used.
+        """
+        if txcount != 0:
+            expected_tx = const(txcount)
+            nvgpu.mbarrier_arrive_expect_tx(
+                self.mbar_group_op, expected_tx, self.id_op, predicate=predicate
+            )
+            return
+        token_ty = ir.Type.parse("!nvgpu.mbarrier.token")
+        nvgpu.mbarrier_arrive(token_ty, self.mbar_group_op, self.id_op)
+
+    def try_wait(self, phase: bool = False, ticks: int = 10000000):
+        """Emit `MBarrierTryWaitParityOp` on the selected barrier."""
+        tick_limit = const(ticks)
+        parity = const(phase, T.bool())
+        nvgpu.MBarrierTryWaitParityOp(
+            self.mbar_group_op,
+            parity,
+            tick_limit,
+            mbarId=self.id_op,
+        )
+
+
+class TMA:
+    """
+    A class that builds a TMA descriptor.
+
+    The object stores the box shape, the source memref type and the tensor
+    map options. `create_descriptor` must be called before `prefetch` or
+    `load`, because both read `self.tma_descriptor`.
+    """
+
+    def __init__(
+        self,
+        tma_box_shape,
+        memref_ty,
+        swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
+        l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+        oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
+        interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+    ):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.tma_box_shape = tma_box_shape
+        self.memref_ty = memref_ty  # MemRefType
+        # Memref type of one box: the box shape with the source element type.
+        self.tma_memref = ir.MemRefType.get(tma_box_shape, memref_ty.element_type)
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        tensorMemrefType = ir.MemRefType.get(
+            self.tma_box_shape,
+            self.memref_ty.element_type,
+            memory_space=ir.Attribute.parse("3"),
+        )
+        return nvgpu.TensorMapDescriptorType.get(
+            tensorMemrefType,
+            self.swizzle,
+            self.l2promo,
+            self.oob,
+            self.interleave,
+        )
+
+    def create_descriptor(self, device_ptr):
+        """
+        Emit the ops that create the descriptor for `device_ptr`.
+
+        `device_ptr` is cast to an unranked memref first. The created op is
+        stored in `self.tma_descriptor` and its result is returned.
+        """
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
+            device_ptr,
+        )
+        self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
+        )
+        return self.tma_descriptor.result
+
+    def prefetch(self, predicate=None):
+        """Emit a prefetch of the descriptor, optionally predicated."""
+        nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
+
+    def load(self, dest, mbarrier: Mbarriers, coords=(0,), predicate=None):
+        """
+        Emit an asynchronous TMA load into `dest`.
+
+        The load uses the barrier currently selected on `mbarrier`. Every
+        entry of `coords` is passed through `const`. The default is an
+        immutable tuple, so it cannot be shared and mutated across calls.
+        """
+        nvgpu.TmaAsyncLoadOp(
+            dest,
+            mbarrier.mbar_group_op,
+            self.tma_descriptor,
+            coordinates=map(const, coords),
+            mbarId=mbarrier.id_op,
+            predicate=predicate,
+        )
+
+
+WARP_GROUP_SIZE = 128  # Number of threads in a warpgroup
+
+
+class Warpgroup:
+    """
+    Context manager that emits code for a single warpgroup.
+
+    Inside `with wg:` the insertion point is the then-block of an `scf.if`
+    guarded by `is_me`, i.e. by whether the thread's warpgroup id equals
+    `primary_thread // WARP_GROUP_SIZE`.
+    """
+
+    def __init__(self, primary_thread, register_size):
+        assert (primary_thread % WARP_GROUP_SIZE) == 0
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        self.primary_thread = primary_thread
+        self.register_size = register_size
+        # True for the thread whose x id is a multiple of WARP_GROUP_SIZE.
+        self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
+        self.wg_id = tidx / WARP_GROUP_SIZE
+        self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)
+
+    def __enter__(self):
+        """Open the `scf.if` region and emit the register-count request."""
+        if_op = scf.IfOp(self.is_me)
+        self.ipoint_op = ir.InsertionPoint(if_op.then_block)
+        self.ipoint_op.__enter__()
+        # Below 64 registers the request is a decrease, otherwise an increase.
+        if self.register_size < 64:
+            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
+        else:
+            nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """
+        Terminate the region and restore the previous insertion point.
+
+        Exceptions raised inside the `with` body are propagated. Returning
+        True here would silently discard them and leave half-built IR behind.
+        """
+        try:
+            # Only terminate the block when the body completed normally.
+            if exc_type is None:
+                scf.yield_([])
+        finally:
+            self.ipoint_op.__exit__(exc_type, exc_value, traceback)
+        return False
+
+
+class WGMMAType(Enum):
+    """Role of a WGMMAMatrix."""
+
+    # Matrix that holds an accumulator op/value.
+    Accumulator = 1
+    # Matrix described by a TMA descriptor and a shared-memory buffer.
+    Descriptor = 2
+
+
+class WGMMAMatrix:
+    """
+    Operand or accumulator of a warpgroup matrix multiply.
+
+    `A @ B` generates the two warpgroup descriptors, and `D += A @ B` emits
+    the `nvgpu.WarpgroupMmaOp` that accumulates into `D`.
+    """
+
+    def __init__(
+        self,
+        matrix_type: WGMMAType,
+        shape: list = None,
+        desc: TMA = None,
+        smem=None,
+        ty=None,
+        acc_op=None,
+    ):
+        if acc_op is None:
+            self.M = shape[0]
+            self.N = shape[1]
+            self.ty = ty
+            self.matrix_type = matrix_type
+            self.desc = desc
+            self.smem = smem
+            # A fresh accumulator is created with an init-accumulator op.
+            if matrix_type is WGMMAType.Accumulator:
+                self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+        elif acc_op:
+            # Wrap an existing accumulator. Only `acc_op` and `matrix_type`
+            # are set on this path; M, N, ty, desc and smem are not.
+            self.acc_op = acc_op
+            self.matrix_type = WGMMAType.Accumulator
+
+    @property
+    def acc_ty(self):
+        """Accumulator type built from M, N and the element type `ty`."""
+        parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
+        return ir.Type.parse(parse_str)
+
+    @property
+    def wgmma_ty(self):
+        """Warpgroup descriptor type built from M, N and the descriptor's element type."""
+        parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+        return ir.Type.parse(parse_str)
+
+    def store_accumulator(self, dest):
+        """Emit a store of the accumulator into `dest`."""
+        assert self.matrix_type == WGMMAType.Accumulator
+        nvgpu.warpgroup_mma_store(self.acc_op, dest)
+
+    def update_smem(self, smem):
+        """Replace the shared-memory buffer used by `@`."""
+        self.smem = smem
+
+    def update_accumulator(self, acc_op):
+        """Replace the accumulator op/value held by this matrix."""
+        self.acc_op = acc_op
+
+    def __matmul__(self, rhs):
+        """Generate the descriptors of both operands; return them as [lhs, rhs]."""
+        lhs = nvgpu.warpgroup_generate_descriptor(
+            self.wgmma_ty, self.smem, self.desc.tma_descriptor
+        )
+        rhs = nvgpu.warpgroup_generate_descriptor(
+            rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
+        )
+        return [lhs, rhs]
+
+    def __iadd__(self, matmulResult):
+        """
+        Emit the warpgroup MMA for a `[lhs, rhs]` pair produced by `@`.
+
+        Returns a new accumulator WGMMAMatrix wrapping the emitted op, so
+        `D += A @ B` rebinds `D` to that new object.
+        """
+        lhs = matmulResult[0]
+        rhs = matmulResult[1]
+        acc_op = nvgpu.WarpgroupMmaOp(
+            self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
+        )
+        return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
+
+
+def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
+    """
+    Return the dynamic shared memory, optionally as a typed view.
+
+    Without `shape`, the result is the raw dynamically sized i8 memref in the
+    workgroup address space. With `shape`, the result is a `memref.view` of
+    that buffer with the given shape and element type `ty`, taken at
+    `offset`.
+    """
+    workgroup_space = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    raw_bytes_ty = ir.MemRefType.get(
+        (MLIR_DYNAMIC,), T.i8(), memory_space=workgroup_space
+    )
+    raw_smem = gpu.dynamic_shared_memory(raw_bytes_ty)
+    if shape is None:
+        return raw_smem
+
+    view_ty = ir.MemRefType.get(shape, ty, memory_space=workgroup_space)
+    return memref.view(view_ty, raw_smem, const(offset), [])
+
+
+def get_mlir_ty(arg):
+    """
+    Return the MLIR type used for the Python argument `arg`.
+
+    bool -> bool type, int -> index, float -> f32. A numpy array becomes a
+    memref with the array's shape and an element type chosen from its dtype
+    (float16/32/64, int32/64). Anything else raises NotImplementedError.
+    """
+    # bool must be tested before int: bool is a subclass of int.
+    if isinstance(arg, bool):
+        return T.bool()
+    if isinstance(arg, int):
+        return T.index()
+    if isinstance(arg, float):
+        return T.f32()
+    if not isinstance(arg, np.ndarray):
+        raise NotImplementedError(arg)
+
+    descriptor = rt.get_ranked_memref_descriptor(arg)
+    # (numpy dtype, MLIR type factory) pairs, tested in order.
+    element_types = (
+        (np.float16, T.f16),
+        (np.float32, T.f32),
+        (np.float64, T.f64),
+        (np.int32, T.i32),
+        (np.int64, T.i64),
+    )
+    for np_ty, make_mlir_ty in element_types:
+        if arg.dtype == np_ty:
+            return memref.MemRefType.get(descriptor.shape, make_mlir_ty())
+    raise NotImplementedError(arg.dtype)
+
+
+class NVDSL:
+    """
+    Decorators that turn Python functions into MLIR builders.
+
+    `mlir_func` builds a `func.func` from the decorated function, compiles and
+    JITs the module, and invokes it. `mlir_gpu_launch` places the IR emitted
+    by the decorated function inside a `gpu.launch` region.
+    """
+
+    @staticmethod
+    def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
+        """
+        Return a decorator that emits the decorated function inside `gpu.launch`.
+
+        `grid` and `block` are the launch sizes; `smem` is the dynamic shared
+        memory size, emitted as an i32 constant.
+        """
+
+        def decorator(kernel_body):
+            @functools.wraps(kernel_body)
+            def wrapper(*args, **kwargs):
+                launch_op = gpu.LaunchOp(
+                    None,
+                    [],
+                    *map(const, grid),
+                    *map(const, block),
+                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+                )
+                # The launch body block takes 12 index arguments.
+                launch_op.body.blocks.append(*([T.index()] * 12))
+                with ir.InsertionPoint(launch_op.body.blocks[0]):
+                    result = kernel_body(*args, **kwargs)
+                    gpu.terminator()
+                    return result
+
+            return wrapper
+
+        return decorator
+
+    @staticmethod
+    def mlir_func(funcBody):
+        """
+        Decorate `funcBody` so that calling it builds, compiles and runs MLIR.
+
+        The wrapper maps the positional arguments to MLIR types, builds a
+        `func.func` named after `funcBody`, calls `funcBody` with the function's
+        block arguments to emit the body, compiles and JITs the module, and
+        invokes the compiled function with the original arguments.
+
+        Raises FileNotFoundError when the SUPPORT_LIB environment variable is
+        unset or does not name an existing path.
+        """
+
+        @functools.wraps(funcBody)
+        def wrapper(*args, **kwargs):
+            function_name = funcBody.__name__
+
+            def saveIR(module):
+                """Save generated IR to nvdsl.mlir (debugging aid)."""
+                with open("nvdsl.mlir", "w") as f:
+                    print(module, file=f)
+
+            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
+                """Generate MLIR's Arith dialects binary operations."""
+                rhs = const(rhs)
+                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+                    op += "F"
+                    if op.startswith("Cmp"):
+                        predicateAttr = arith.CmpFPredicate.__dict__[predAtt]
+                # Both operands must be integer-like; the right-hand side is
+                # checked as well so mixed int/float operands are rejected.
+                elif arith._is_integer_like_type(
+                    lhs.type
+                ) and arith._is_integer_like_type(rhs.type):
+                    # Integer division and remainder use the unsigned ops.
+                    if op == "Div" or op == "Rem":
+                        op += "U"
+                    op += "I"
+                    if op.startswith("Cmp"):
+                        predicateAttr = arith.CmpIPredicate.__dict__[predAtt]
+                else:
+                    raise NotImplementedError(
+                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
+                    )
+
+                op_cls = getattr(arith, f"{op}Op")
+                if op.startswith("Cmp"):
+                    return op_cls(predicateAttr, lhs, rhs).result
+                return op_cls(lhs, rhs).result
+
+            @ir.register_value_caster(ir.IndexType.static_typeid)
+            @ir.register_value_caster(ir.F32Type.static_typeid)
+            @ir.register_value_caster(ir.F16Type.static_typeid)
+            @ir.register_value_caster(ir.F64Type.static_typeid)
+            @ir.register_value_caster(ir.IntegerType.static_typeid)
+            class ArithValue(ir.Value):
+                """Overloads operators for MLIR's Arith dialects binary operations."""
+
+                def __init__(self, v):
+                    super().__init__(v)
+
+                __add__ = partialmethod(_binary_op, op="Add")
+                __sub__ = partialmethod(_binary_op, op="Sub")
+                __mul__ = partialmethod(_binary_op, op="Mul")
+                __truediv__ = partialmethod(_binary_op, op="Div")
+                __mod__ = partialmethod(_binary_op, op="Rem")
+                __xor__ = partialmethod(_binary_op, op="XOr")
+                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
+                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
+                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
+                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
+                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
+                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
+                __and__ = partialmethod(_binary_op, op="And")
+                __or__ = partialmethod(_binary_op, op="Or")
+
+                def __str__(self):
+                    return (
+                        super()
+                        .__str__()
+                        .replace(ir.Value.__name__, ArithValue.__name__)
+                    )
+
+            # Generate MLIR Context and start generating IR
+            with ir.Context(), ir.Location.unknown():
+                types = [get_mlir_ty(arg) for arg in args]
+
+                # Build IR
+                module = ir.Module.create()
+                with ir.InsertionPoint(module.body):
+                    fop = func.FuncOp(function_name, (types, []))
+                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+                    with ir.InsertionPoint(fop.add_entry_block()):
+                        fargs = [fop.arguments[i] for i in range(len(types))]
+
+                        # Call user function body
+                        result = funcBody(*fargs, **kwargs)
+                        func.ReturnOp([])
+
+                # Save IR in a file
+                # saveIR(module)
+
+                # Verify the module
+                # module.operation.verify()
+
+                # Compile and JIT MLIR module
+                options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+                support_lib = os.getenv("SUPPORT_LIB")
+                # Report an unset variable as a clear error rather than the
+                # TypeError that os.path.exists(None) would raise.
+                if support_lib is None:
+                    raise FileNotFoundError(
+                        errno.ENOENT, "SUPPORT_LIB environment variable is not set"
+                    )
+                if not os.path.exists(support_lib):
+                    raise FileNotFoundError(
+                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+                    )
+                compiler = nvgpucompiler.NvgpuCompiler(
+                    options, opt_level=3, shared_libs=[support_lib]
+                )
+                engine = compiler.compile_and_jit(module)
+
+            # Convert input arguments to MLIR arguments
+            newArgs = get_mlir_func_obj_ty(args)
+
+            # Run the compiled program
+            engine.invoke(function_name, *newArgs)
+
+            return result
+
+        return wrapper

diff  --git a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..1c9cc74fcd169c
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
@@ -0,0 +1,45 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the NvgpuCompiler class, which compiles and JITs MLIR modules.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+
+class NvgpuCompiler:
+    """Runs the gpu-lower-to-nvvm pipeline on MLIR modules and JITs them."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        # `options` is spliced between the braces of the pipeline string.
+        self.pipeline = "builtin.module(gpu-lower-to-nvvm-pipeline{{{}}})".format(
+            options
+        )
+        self.opt_level = opt_level
+        self.shared_libs = shared_libs
+
+    def __call__(self, module: ir.Module):
+        """Shorthand for `compile(module)`."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Run the stored pass pipeline on the module's operation."""
+        pass_manager = passmanager.PassManager.parse(self.pipeline)
+        pass_manager.run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Create an execution engine for the module with the stored settings."""
+        engine = execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+        return engine
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compile the module, then wrap it in an execution engine."""
+        self.compile(module)
+        return self.jit(module)


        


More information about the Mlir-commits mailing list