[Mlir-commits] [mlir] 4d33082 - [mlir][nvgpu] NVGPU Tutorials (#87065)
llvmlistbot at llvm.org
Wed Apr 24 03:00:16 PDT 2024
Author: Guray Ozen
Date: 2024-04-24T12:00:12+02:00
New Revision: 4d3308202e52b213a05023c8b8b470b346151de6
URL: https://github.com/llvm/llvm-project/commit/4d3308202e52b213a05023c8b8b470b346151de6
DIFF: https://github.com/llvm/llvm-project/commit/4d3308202e52b213a05023c8b8b470b346151de6.diff
LOG: [mlir][nvgpu] NVGPU Tutorials (#87065)
I have a tutorial at EuroLLVM 2024 ([Zero to Hero: Programming Nvidia
Hopper Tensor Core with MLIR's NVGPU
Dialect](https://llvm.swoogo.com/2024eurollvm/session/2086997/zero-to-hero-programming-nvidia-hopper-tensor-core-with-mlir's-nvgpu-dialect)).
For that, I implemented the tutorial code in Python. The focus is the nvgpu
dialect and how to use its advanced features. I thought it might be
useful to upstream this.
The tutorial chapters are as follows:
- **Ch0.py:** Hello World
- **Ch1.py:** 2D Saxpy
- **Ch2.py:** 2D Saxpy using TMA
- **Ch3.py:** GEMM 128x128x64 using Tensor Core and TMA
- **Ch4.py:** Multistage performant GEMM using Tensor Core and TMA
- **Ch5.py:** Warp Specialized GEMM using Tensor Core and TMA
I might implement one more chapter:
- **Ch6.py:** Warp Specialized Persistent ping-pong GEMM
This PR also introduces the NVDSL helper class (tools/nvdsl.py), which makes
IR building in the tutorials easier.
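
Building and running a kernel with these helpers looks roughly like the
following (a minimal sketch mirroring Ch0.py below; the chapters contain the
full, runnable versions):

    from mlir.dialects import gpu
    from tools.nvdsl import *

    @NVDSL.mlir_func
    def main(alpha):
        @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
        def kernel():
            tidx = gpu.thread_id(gpu.Dimension.x)
            myValue = alpha + tidx  # `+` maps to arith.addi
            gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
        kernel()

    main(100)  # JIT-compiles the generated IR and runs it
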
Added:
mlir/test/Examples/NVGPU/Ch0.py
mlir/test/Examples/NVGPU/Ch1.py
mlir/test/Examples/NVGPU/Ch2.py
mlir/test/Examples/NVGPU/Ch3.py
mlir/test/Examples/NVGPU/Ch4.py
mlir/test/Examples/NVGPU/Ch5.py
mlir/test/Examples/NVGPU/lit.local.cfg
mlir/test/Examples/NVGPU/tools/lit.local.cfg
mlir/test/Examples/NVGPU/tools/nvdsl.py
mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
Modified:
Removed:
################################################################################
diff --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py
new file mode 100644
index 00000000000000..8f60088178d119
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch0.py
@@ -0,0 +1,50 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 0 : Hello World
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates Hello World:
+# 1. Build an MLIR function with arguments
+# 2. Build an MLIR GPU kernel
+# 3. Print from a GPU thread
+# 4. Pass arguments, JIT compile and run the MLIR function
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir.dialects import gpu
+from tools.nvdsl import *
+
+
+# 1. The decorator generates an MLIR func.func.
+# Everything inside the Python function becomes the body of the func.
+# The decorator also translates `alpha` to an `index` type.
+@NVDSL.mlir_func
+def main(alpha):
+    # 2. The decorator generates an MLIR gpu.launch.
+ # Everything inside the Python function becomes the body of the gpu.launch.
+ # This allows for late outlining of the GPU kernel, enabling optimizations
+ # like constant folding from host to device.
+ @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
+ def kernel():
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ # + operator generates arith.addi
+ myValue = alpha + tidx
+ # Print from a GPU thread
+ gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
+
+ # 3. Call the GPU kernel
+ kernel()
+
+
+alpha = 100
+# 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
+main(alpha)
+
+
+# CHECK: GPU thread 0 has 100
+# CHECK: GPU thread 1 has 101
+# CHECK: GPU thread 2 has 102
+# CHECK: GPU thread 3 has 103
diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py
new file mode 100644
index 00000000000000..da65aa2ef6a172
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch1.py
@@ -0,0 +1,66 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 1 : 2D Saxpy
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy:
+# 1. Use the GPU dialect to allocate and copy memory between host and device
+# 2. Compute the 2D SAXPY kernel using operator overloading
+# 3. Pass numpy arrays to MLIR as memref arguments
+# 4. Verify the MLIR program against a reference computation in Python
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, memref
+from tools.nvdsl import *
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy(x, y, alpha):
+ # 1. Use MLIR GPU dialect to allocate and copy memory
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+ y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+ t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+ t6 = gpu.wait(token_ty, [t5])
+
+ # 2. Compute 2D SAXPY kernel
+ @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
+ def saxpy_kernel():
+ bidx = gpu.block_id(gpu.Dimension.x)
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ x_val = memref.load(x_dev, [bidx, tidx])
+ y_val = memref.load(y_dev, [bidx, tidx])
+
+ # SAXPY: y[i] += a * x[i];
+ y_val += x_val * alpha
+
+ memref.store(y_val, y_dev, [bidx, tidx])
+
+ saxpy_kernel()
+
+ t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+ gpu.wait(token_ty, [t7])
+
+
+# 3. Pass numpy arrays to MLIR
+M = 256
+N = 32
+alpha = 2.0
+x = np.random.randn(M, N).astype(np.float32)
+y = np.ones((M, N), np.float32)
+saxpy(x, y, alpha)
+
+# 4. Verify MLIR with reference computation
+ref = np.ones((M, N), np.float32)
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py
new file mode 100644
index 00000000000000..78c14cb2c7ad8c
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch2.py
@@ -0,0 +1,93 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 2 : 2D Saxpy with TMA
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
+# but it loads data using TMA (the Tensor Memory Accelerator).
+#
+# This chapter demonstrates:
+# 1. Compute 2D SAXPY in the same way as Ch1.py, but load the data using TMA
+# 2. Create and initialize one asynchronous transactional barrier (mbarrier)
+# 3. Thread 0 issues the TMA load requests for each thread block
+# 4. Each thread block loads <1x32xf32> for x and y
+# 5. Wait for completion of the TMA loads with the mbarrier
+#
+# ===----------------------------------------------------------------------===//
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir import runtime as rt
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy(x, y, alpha):
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+ y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+ t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+ t6 = gpu.wait(token_ty, [t5])
+
+ x_tma = TMA([1, N], x.type)
+ y_tma = TMA([1, N], y.type)
+ x_tma.create_descriptor(x_dev)
+ y_tma.create_descriptor(y_dev)
+ sz_x = get_type_size(x_tma.tma_memref)
+    sz_y = get_type_size(y_tma.tma_memref)
+ sz = sz_x + sz_y
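+    # Rough sanity check, assuming the defaults at the bottom of this file
+    # (N=32, f32 elements): each 1x32xf32 TMA box is 1*32*4 = 128 bytes, so
+    # sz = 256 bytes of dynamic shared memory and of expected mbarrier
+    # transactions per thread block.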
+
+ @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
+ def saxpy_tma_kernel():
+ bidx = gpu.block_id(gpu.Dimension.x)
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ isThread0 = tidx == 0
+
+ # 1. Create and initialize asynchronous transactional barrier (mbarrier)
+ mbar_group = Mbarriers(number_of_barriers=1)
+ mbar_group[0].init(1, predicate=isThread0)
+
+ # 2. Execute Tensor Memory Accelerator (TMA) Load
+ x_smem = get_dynamic_shared_memory([1, N], T.f32())
+ y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
+ x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
+ y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
+ mbar_group[0].arrive(txcount=sz, predicate=isThread0)
+
+ # 3. Wait for completion of TMA load with mbarrier
+ mbar_group[0].try_wait()
+
+ x_val = memref.load(x_smem, [const(0), tidx])
+ y_val = memref.load(y_smem, [const(0), tidx])
+
+ # SAXPY: y[i] += a * x[i];
+ y_val += x_val * alpha
+
+ memref.store(y_val, y_dev, [bidx, tidx])
+
+ saxpy_tma_kernel()
+
+ t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+ gpu.wait(token_ty, [t7])
+
+
+# 3. Pass numpy arrays to MLIR
+M = 256
+N = 32
+alpha = 2.0
+x = np.random.randn(M, N).astype(np.float32)
+y = np.ones((M, N), np.float32)
+saxpy(x, y, alpha)
+
+# 4. Verify MLIR with reference computation
+ref = np.ones((M, N), np.float32)
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py
new file mode 100644
index 00000000000000..a417014de8b49a
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch3.py
@@ -0,0 +1,129 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 3 : GEMM 128x128x64 with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation with 128x128x64 matrix multiplication
+#
+# This chapter demonstrates:
+# 1. Execute TMA loads for the two input matrices
+# 2. Perform Tensor Core GEMM 128x128x64 by warpgroup
+# 3. Store fragmented registers to global memory by warpgroup
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir.extras import types as T
+import numpy as np
+
+
+def tma_load(
+ mbar_group: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ p,
+):
+ """
+ TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+ - tma.load a_shared_memory[0] at coordinate [0, 0] (Loads 128x64)
+ - tma.load b_shared_memory[0] at coordinate [0, 0] (Loads 64x64)
+ - tma.load b_shared_memory[0] at coordinate [64, 0] (Loads 64x64)
+
+ mbarrier.arrive ta_count = 128x64xf16 + 64x128xf16
+ """
+
+ size_tma_a = get_type_size(a_tma.tma_memref)
+ size_tma_b = get_type_size(b_tma.tma_memref)
+ ta_count = size_tma_a + (size_tma_b * 2)
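+    # For the 128x64 and 64x64 f16 TMA boxes used in this chapter, this is
+    # 128*64*2 + 2*(64*64*2) = 32768 bytes of expected transactions.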
+
+ off_b = size_tma_a
+ off_b2 = off_b + size_tma_b
+ a_elem_ty = a_tma.tma_memref.element_type
+ b_elem_ty = b_tma.tma_memref.element_type
+ a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
+ b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+ b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+ mbar_group[0].arrive(ta_count, predicate=p)
+
+ a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
+ b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
+ b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)
+
+
+@NVDSL.mlir_func
+def gemm_128_128_64(a, b, d):
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+ b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
+ d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+ t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
+ t7 = gpu.wait(token_ty, [t6])
+
+ sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+ a_tma = TMA([128, 64], a.type, swizzle=sw)
+ b_tma = TMA([64, 64], b.type, swizzle=sw)
+ a_tma.create_descriptor(a_dev)
+ b_tma.create_descriptor(b_dev)
+ a_size = get_type_size(a.type)
+ b_size = get_type_size(b.type)
+ smem_size_in_bytes = a_size + b_size
+
+ @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
+ def gemm_tma_kernel():
+ tidx = gpu.thread_id(gpu.Dimension.x)
+
+ mbar_group = Mbarriers(number_of_barriers=1)
+ isThread0 = tidx == 0
+
+ mbar_group[0].init(1, predicate=isThread0)
+ a_tma.prefetch(predicate=isThread0)
+ b_tma.prefetch(predicate=isThread0)
+
+ a_smem = get_dynamic_shared_memory((M, K), T.f16())
+ b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)
+
+ # 1. TMA Load for two input matrices
+ tma_load(mbar_group, a_tma, b_tma, isThread0)
+
+        # 2. All threads wait for TMA load completion
+ mbar_group[0].try_wait()
+
+        # 3. Perform Tensor Core GEMM 128x128x64 by warpgroup
+ A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
+ B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
+ D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
+
+ # Matrix Multiply
+ D += A @ B
+
+        # 4. Store fragmented registers to global memory by warpgroup
+ D.store_accumulator(d_dev)
+
+ gemm_tma_kernel()
+
+ t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
+ gpu.wait(None, [t8])
+
+
+# Pass arguments from Python to MLIR
+M = 128
+N = 128
+K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+gemm_128_128_64(a, b, d)
+
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py
new file mode 100644
index 00000000000000..8f38d8a90add31
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch4.py
@@ -0,0 +1,323 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# Multistage method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches one warp
+# group (128 threads in total) and allocates multiple slots/stages in shared
+# memory. The program consists of three main parts: prologue, main loop, and
+# epilogue. In the prologue, thread 0 issues TMA requests to load data into the
+# shared memory slots. The main loop executes MMA while simultaneously issuing
+# TMA loads for the slots that have already been consumed. This overlap of TMA
+# and MMA operations enhances performance by maximizing computational throughput.
+#
+# Loops illustration:
+#
+# for s in range(num_stages):
+# TMA_128x64_64x128...
+# for ti in range(M//128): # -> blockIdx.x
+# for tj in range(N//128): # -> blockIdx.y
+# for tk in range(K//64):
+# MMA_128x128x64...
+# TMA_128x64_64x128...
+# Epilogue...
+#
+# This chapter demonstrates:
+# 1. Partition shape based on block IDs
+# 2. Prologue
+# 2.1 Execute TMA Load for two input matrices for each stage
+# 3. Main loop
+# 3.1 Wait for completion of TMA load with mbarrier
+# 3.2 Performs Tensor Core GEMM 64x128x64 by warpgroup
+# 3.3 Load next stage if needed
+# 4. Epilogue
+# 4.1 Store fragmented registers to shared memory
+# 4.2 Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+ """
+ Calculate the partition shape based on the block IDs.
+
+ It partitions the shape like below:
+ for(.. i < M ...) --> blockIdx.x
+ for(.. j < N ...) --> blockIdx.y
+ for(.. k < K ...)
+
+ Returns:
+ dimX (int): Dimension along the x-axis.
+ dimY (int): Dimension along the y-axis.
+ """
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = bidx * TILE_M
+ dimY = bidy * TILE_N
+ return dimX, dimY
+
+
+def tma_load(
+ mbar_group: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ slot,
+ stage,
+ num_stages,
+ p=None,
+):
+ """
+ TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+ - tma.load a_shared_memory[off_x] at coordinate [x, z] (Loads 128x64)
+ - tma.load b_shared_memory[off_y1] at coordinate [y, x] (Loads 64x64)
+ - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
+
+ mbarrier.arrive ta_count = 128x64x2x4
+ """
+ dimX, dimY = partition_shape()
+
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ begin_b = num_stages * get_type_size(a_tma.tma_memref)
+ size_tma_a = get_type_size(a_tma.tma_memref)
+ size_tma_b = get_type_size(b_tma.tma_memref)
+ ta_count = size_tma_a + (size_tma_b * 2)
+
+ p = tidx == 0 if p is None else p
+
+ off_a = slot * size_tma_a
+ off_b = (slot * size_tma_a) + begin_b
+ off_b2 = off_b + size_tma_b
+ a_elem_ty = a_tma.tma_memref.element_type
+ b_elem_ty = b_tma.tma_memref.element_type
+ a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+ b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+ b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+ mbar_group[slot].arrive(ta_count, predicate=p)
+
+ c1 = stage * 64
+ a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+ b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+ b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def initialize(a_tma: TMA, b_tma: TMA, num_stages):
+ """
+ Initialize mbarriers and prefetch TMA descriptors.
+ """
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ mbar_group = Mbarriers(number_of_barriers=num_stages)
+ isThread0 = tidx == const(0)
+ with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+ for i in scf.for_(0, num_stages, 1):
+ mbar_group[i].init(1)
+ scf.yield_([])
+ a_tma.prefetch()
+ b_tma.prefetch()
+ scf.yield_([])
+
+ return mbar_group
+
+
+def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
+ """
+    Prologue of the GEMM kernel. It loads the two input matrices for each stage, in a loop like the one below:
+
+ for stage in range(NUM_STAGES):
+ tma_load x, y, stage
+
+ """
+ ns = num_stages if num_stages == 1 else num_stages - 1
+ for iv in scf.for_(0, ns, 1):
+ tma_load(mbar_group, a_tma, b_tma, iv, iv, num_stages)
+ scf.yield_([])
+
+
+def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
+ """
+    Main loop of the Multistage GEMM kernel. It iterates through the
+    stages and performs matrix multiplication while loading data with TMA into
+    shared memory, as illustrated below:
+
+ MatrixAccumulator D
+ for k in range(K // TILE_K):
+
+ try_wait(stage, ...) # Wait TMA load
+
+ Matrix A(stage, ...) # Find shared memory slot
+ Matrix B(stage, ...) # Find shared memory slot
+ D += A @ B # Multiply and accumulate
+
+ if(needLoad) # Load next stage if needed
+ tma_load(x, y, nextSlot, nextStage)
+
+ """
+ ns = num_stages if num_stages == 1 else num_stages - 1
+
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ begin_b = num_stages * get_type_size(a_tma.tma_memref)
+
+ size_a = TILE_M * TILE_K * get_type_size(T.f16())
+
+ # Initialize A and B (input matrices) and C (accumulator)
+ A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+ B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+ D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+
+ phase = const(False, ty=T.bool())
+
+ # Main Loop
+ for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
+ with ir.InsertionPoint(for_op.body):
+ phase = for_op.inner_iter_args[1]
+ iv = for_op.induction_variable
+ stage = iv % num_stages
+
+ # Wait for current stage
+ mbar_group[stage].try_wait(phase=phase)
+
+ # Find shared memory slot
+ offset_a = stage * size_a
+ offset_b = offset_a + begin_b
+ a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+ b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
+
+ # Iterate input matrices, update accumulator
+ A.update_smem(a_smem)
+ B.update_smem(b_smem)
+ D.update_accumulator(for_op.inner_iter_args[0])
+
+ # Matrix Multiply
+ D += A @ B
+
+        # Wait for the Tensor Core when there is only a single stage
+ if num_stages == 1:
+ nvvm.WgmmaWaitGroupSyncOp(0)
+
+ # Load next stage
+ pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
+ nextStage = iv + ns
+ nextSlot = nextStage % num_stages
+ tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)
+
+ # Switch phase parity for the mbarrier
+ newPhase = arith.select(
+ stage == (num_stages - 1),
+ (phase ^ const(True, ty=T.bool())),
+ phase,
+ )
+ scf.yield_([D.acc_op, newPhase])
+
+ nvvm.WgmmaWaitGroupSyncOp(0)
+
+ D.update_accumulator(for_op.results[0])
+ return D
+
+
+def epilogue(D: WGMMAMatrix, d_dev):
+ """
+ Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
+
+ MatrixAccumulator D # Fragmented results
+ store D -> Shared Memory # Store Shared Memory
+ Shared Memory -> Z[dimX][dimY] # Store Shared Memory to Global Memory
+
+ """
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ dimX, dimY = partition_shape()
+
+ d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+ d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+ # Store (registers -> shared memory)
+ D.store_accumulator(d_smem)
+ gpu.barrier()
+
+ # Store (shared memory --> global memory)
+ for i in scf.for_(0, TILE_M, 1):
+ val = memref.load(d_smem, [i, tidx])
+ memref.store(val, d_gmem, [i, tidx])
+ scf.yield_([])
+
+
+# The decorator generates
+# a -> memref<MxKxf16>
+# b -> memref<KxNxf16>
+# d -> memref<MxNxf32>
+@NVDSL.mlir_func
+def gemm_multistage(a, b, d, num_stages):
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+ b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
+ d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+ t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
+ t7 = gpu.wait(token_ty, [t6])
+
+ sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+ a_tma = TMA([128, 64], a.type, swizzle=sw)
+ b_tma = TMA([64, 64], b.type, swizzle=sw)
+ a_tma.create_descriptor(a_dev)
+ b_tma.create_descriptor(b_dev)
+
+ grid = [(M // TILE_M), (N // TILE_N), 1]
+ block = [128, 1, 1]
+
+ size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
+ size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
+ smem_size_in_bytes = (size_a + size_b) * num_stages
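+    # For example, with the default TILE_M=TILE_N=128, TILE_K=64 and f16
+    # inputs below, each stage needs 16384 bytes for the A tile plus 16384
+    # bytes for the B tile (32 KiB per stage), so num_stages=7 requests
+    # 224 KiB of dynamic shared memory.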
+
+ @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
+ def gemm_multistage_kernel():
+ # Initialize mbarriers and prefetch TMA descriptors
+ mbar_group = initialize(a_tma, b_tma, num_stages)
+
+ # Fill the pipeline stages
+ prologue(mbar_group, a_tma, b_tma, num_stages)
+
+ # Main loop
+ D = mainloop(mbar_group, a_tma, b_tma, num_stages)
+
+ # Store registers to global memory
+ epilogue(D, d_dev)
+
+ gemm_multistage_kernel()
+
+ t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
+ gpu.wait(None, [t8])
+
+
+# Pass arguments from Python to MLIR
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+
+gemm_multistage(a, b, d, num_stages=7)
+
+
+# Verify MLIR with reference computation
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
new file mode 100644
index 00000000000000..92e9314e1b812d
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -0,0 +1,321 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 5 : Warp Specialized GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation for `f32+=f16*f16`, utilizing the
+# Warp Specialized method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches two warp
+# groups (256 threads in total): one producer and one consumer. Each group takes
+# a different control-flow path. The producer warp group is responsible for
+# loading data into shared memory, while the consumer group executes the Tensor
+# Core GEMM operation and the epilogue.
+#
+# for ti in range(M//128): # -> blockIdx.x
+# for tj in range(N//128): # -> blockIdx.y
+# with wg_producer:
+# for tk in range(K//64):
+# TMA_128x64_64x128...
+# with wg_consumer:
+# for tk in range(K//64):
+# MMA_128x128x64...
+# Epilogue..
+#
+# This chapter demonstrates:
+#  2 WG (warpgroups)
+#  Producer warpgroup:
+#    1. Wait for the MMA barrier
+#    2. Load data with TMA, using the TMA barrier
+#    3. Arrive at the TMA barrier with txcount
+#  Consumer warpgroup:
+#    Loop:
+#      1. Wait for the TMA barrier
+#      2. Perform Tensor Core GEMM 64x128x64 by warpgroup
+#      3. Arrive at the MMA barrier
+#    Epilogue:
+#      1. Store fragmented registers to shared memory
+#      2. Store shared memory to global memory
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+ """
+ Calculate the partition shape based on the block IDs.
+
+ It parallelizes the two outermost loops into thread blocks.
+ for ti in range(M//128): # -> blockIdx.x
+ for tj in range(N//128): # -> blockIdx.y
+ D = 0
+ for tk in range(K//64):
+ for i in range(128):
+ for j in range(128):
+ for k in range(64):
+ FMA
+
+ Returns:
+ dimX (int): Dimension along the x-axis.
+ dimY (int): Dimension along the y-axis.
+ """
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = bidx * TILE_M
+ dimY = bidy * TILE_N
+ return dimX, dimY
+
+
+def tma_load(
+ mbar_group: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ slot,
+ stage,
+ num_stages,
+ p=None,
+):
+ """
+ TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+ - tma.load a_shared_memory[off_x] at coordinate [x, z] (Loads 128x64)
+ - tma.load b_shared_memory[off_y1] at coordinate [y, x] (Loads 64x64)
+ - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
+
+ mbarrier.arrive ta_count = 128x64x2x4
+ """
+ dimX, dimY = partition_shape()
+
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ begin_b = num_stages * get_type_size(a_tma.tma_memref)
+ size_tma_a = get_type_size(a_tma.tma_memref)
+ size_tma_b = get_type_size(b_tma.tma_memref)
+ ta_count = size_tma_a + (size_tma_b * 2)
+
+ off_a = slot * size_tma_a
+ off_b = (slot * size_tma_a) + begin_b
+ off_b2 = off_b + size_tma_b
+ a_elem_ty = a_tma.tma_memref.element_type
+ b_elem_ty = b_tma.tma_memref.element_type
+ a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+ b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+ b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+ mbar_group[slot].arrive(ta_count, predicate=p)
+ p = (tidx % WARP_GROUP_SIZE) == 0
+ c1 = stage * 64
+ a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+ b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+ b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def initialize(a_tma: TMA, b_tma: TMA, num_stages):
+ """
+ Initialize mbarriers and prefetch TMA descriptors.
+ """
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
+ mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
+ isThread0 = tidx == const(0)
+ with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+ for i in scf.for_(0, num_stages, 1):
+ mbar_group_tma[i].init(1)
+ mbar_group_mma[i].init(1)
+ scf.yield_([])
+ a_tma.prefetch()
+ b_tma.prefetch()
+ scf.yield_([])
+
+ return mbar_group_tma, mbar_group_mma
+
+
+def switch_phase(stage, phase, num_stages):
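+    """
+    Flip the mbarrier phase parity bit after the last of the `num_stages`
+    slots has been processed; otherwise return the phase unchanged.
+    """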
+ p = stage == (num_stages - 1)
+ phase = arith.select(
+ p,
+ (phase ^ const(True, ty=T.bool())),
+ phase,
+ )
+ return phase
+
+
+def producer_loop(
+ mbar_tma: Mbarriers,
+ mbar_mma: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ wg_me: Warpgroup,
+ num_stages,
+):
+ phase = const(True, ty=T.bool())
+
+ for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
+ stage = iv % num_stages
+        # Wait for MMA to be done
+ mbar_mma[stage].try_wait(phase)
+ # New phase for mbarrier
+ phase = switch_phase(stage, phase, num_stages)
+ # TMA Load
+ tma_load(mbar_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
+ scf.yield_([phase])
+
+
+def consumer_loop(
+ mbar_tma: Mbarriers,
+ mbar_mma: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ wg_me: Warpgroup,
+ num_stages,
+):
+ begin_b = num_stages * get_type_size(a_tma.tma_memref)
+
+ size_a = TILE_M * TILE_K * get_type_size(T.f16())
+
+ phase = const(False, ty=T.bool())
+ A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+ B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+ D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+
+ for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
+ with ir.InsertionPoint(for_op.body):
+ phase = for_op.inner_iter_args[1]
+ iv = for_op.induction_variable
+ stage = iv % num_stages
+
+ # Wait TMA for current stage
+ mbar_tma[stage].try_wait(phase)
+
+ # Find shared memory slot
+ offset_a = stage * size_a
+ offset_b = offset_a + begin_b
+ a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+ b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
+
+ # Iterate input matrices, update accumulator
+ A.update_smem(a_smem)
+ B.update_smem(b_smem)
+ D.update_accumulator(for_op.inner_iter_args[0])
+
+ # Matrix Multiply
+ D += A @ B
+
+ # MMA Barrier Arrive
+ p_arrive = (iv > 0) & wg_me.is_wg_primary
+ with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+ barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
+ mbar_mma[barId].arrive()
+ scf.yield_([])
+
+ phase = switch_phase(stage, phase, num_stages)
+ scf.yield_([D.acc_op, phase])
+
+ nvvm.WgmmaWaitGroupSyncOp(0)
+ D.update_accumulator(for_op.results[0])
+ return D
+
+
+def epilogue(D: WGMMAMatrix, d_dev):
+ """
+ Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
+
+ MatrixAccumulator D # Fragmented results
+ store D -> Shared Memory # Store Shared Memory
+ Shared Memory -> Z[dimX][dimY] # Store Shared Memory to Global Memory
+
+ """
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ dimX, dimY = partition_shape()
+ # s = tidx - WARP_GROUP_SIZE
+ # debug_print("[Epilogue] store to global memory @ s={}", s)
+
+ d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+ d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+ # Store (registers -> shared memory)
+ D.store_accumulator(d_smem)
+ gpu.barrier()
+
+ # Store (shared memory --> global memory)
+ for i in scf.for_(0, TILE_M, 1):
+ val = memref.load(d_smem, [i, tidx])
+ memref.store(val, d_gmem, [i, tidx])
+ scf.yield_([])
+
+
+@NVDSL.mlir_func
+def gemm_warp_specialized(a, b, d, num_stages):
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+ b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
+ d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+ t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
+ t7 = gpu.wait(token_ty, [t6])
+
+ sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+ a_tma = TMA([128, 64], a.type, swizzle=sw)
+ b_tma = TMA([64, 64], b.type, swizzle=sw)
+ a_tma.create_descriptor(a_dev)
+ b_tma.create_descriptor(b_dev)
+
+ grid = [(M // TILE_M), (N // TILE_N), 1]
+ block = [256, 1, 1]
+
+ size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
+ size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
+ smem_size_in_bytes = (size_a + size_b) * num_stages
+
+ @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
+ def gemm_warp_specialized_kernel():
+ # Init Warpgroups
+ wg_producer = Warpgroup(primary_thread=128, register_size=40)
+ wg_consumer = Warpgroup(primary_thread=0, register_size=232)
+
+ # Initialize mbarriers and prefetch TMA descriptors
+        mbar_tma, mbar_mma = initialize(a_tma, b_tma, num_stages)
+
+ # Producer performs TMA
+ with wg_producer:
+ producer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_producer, num_stages)
+
+ # Consumer performs MMA/Tensor Core
+ with wg_consumer:
+ D = consumer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_consumer, num_stages)
+ epilogue(D, d_dev)
+
+ gemm_warp_specialized_kernel()
+
+ t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
+ gpu.wait(None, [t8])
+
+
+# Pass arguments from Python to MLIR
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+
+gemm_warp_specialized(a, b, d, num_stages=7)
+
+
+# Verify MLIR with reference computation
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/NVGPU/lit.local.cfg b/mlir/test/Examples/NVGPU/lit.local.cfg
new file mode 100644
index 00000000000000..689cd252e7a254
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/lit.local.cfg
@@ -0,0 +1,4 @@
+config.unsupported = False
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+ config.unsupported = True
+
\ No newline at end of file
diff --git a/mlir/test/Examples/NVGPU/tools/lit.local.cfg b/mlir/test/Examples/NVGPU/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
new file mode 100644
index 00000000000000..600cae5b47eeec
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -0,0 +1,456 @@
+from enum import Enum
+import functools, sys, ctypes, os, errno
+import numpy as np
+from functools import partialmethod
+from mlir import ir
+from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
+from mlir.extras import types as T
+from mlir import runtime as rt
+from tools import nvgpucompiler
+
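+# MLIR's sentinel for a dynamic dimension (ShapedType::kDynamic, i.e. INT64_MIN),
+# used below when building the dynamically sized shared-memory memref.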
+MLIR_DYNAMIC = -9223372036854775808
+
+
+def const(value: int, ty=None):
+ ty = T.index() if ty is None else ty
+ if isinstance(value, ir.Value) and (
+ value.type.isinstance(value.type) or T.bool().isinstance(value.type)
+ ):
+ return value
+ return arith.constant(ty, value)
+
+
+def get_type_size(ty):
+ if ir.MemRefType.isinstance(ty):
+ size = get_type_size(ty.element_type)
+ for sz in ty.shape:
+ size *= sz
+ return size
+ if ir.FloatType.isinstance(ty):
+ return ir.FloatType(ty).width // 8
+ if ir.IntegerType.isinstance(ty):
+ return ir.IntegerType(ty).width // 8
+ raise NotImplementedError(ty)
+
+
+def get_mlir_func_obj_ty(inputArgs):
+ args = []
+ c_int_p = ctypes.c_int * 1
+ c_float_p = ctypes.c_float * 1
+ c_bool_p = ctypes.c_bool * 1
+ for arg in inputArgs:
+ if isinstance(arg, bool):
+ args.append(c_bool_p(arg))
+ elif isinstance(arg, int):
+ args.append(c_int_p(arg))
+ elif isinstance(arg, float):
+ args.append(c_float_p(arg))
+ elif isinstance(arg, np.ndarray):
+ args.append(
+ ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
+ )
+ else:
+ raise NotImplementedError(arg)
+ return args
+
+
+class Mbarriers:
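+    """Creates an `nvgpu.mbarrier.group` in workgroup (shared) memory and wraps
+    its init / arrive / try_wait operations. Indexing with `[]` selects which
+    barrier in the group subsequent calls operate on."""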
+ def __init__(self, number_of_barriers=1):
+ self.mbar_ty = ir.Type.parse(
+ "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
+ + str(number_of_barriers)
+ + ">"
+ )
+ self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
+ self.number_of_barriers = number_of_barriers
+
+ def __getitem__(self, key):
+ self.id_op = const(key)
+ return self
+
+ def init(self, count: int, predicate=None):
+ count_op = const(count)
+ if predicate is None:
+ nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
+ else:
+ nvgpu.mbarrier_init(
+ self.mbar_group_op, count_op, self.id_op, predicate=predicate
+ )
+
+ def arrive(self, txcount: int = 0, predicate=None):
+ if txcount != 0:
+ txcount_op = const(txcount)
+ nvgpu.mbarrier_arrive_expect_tx(
+ self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
+ )
+ else:
+ nvgpu.mbarrier_arrive(
+ ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
+ )
+
+ def try_wait(self, phase: bool = False, ticks: int = 10000000):
+ ticks_op = const(ticks)
+ phase_op = const(phase, T.bool())
+ nvgpu.MBarrierTryWaitParityOp(
+ self.mbar_group_op,
+ phase_op,
+ ticks_op,
+ mbarId=self.id_op,
+ )
+
+
+class TMA:
+ """A class that builds a TMA descriptor."""
+
+ def __init__(
+ self,
+ tma_box_shape,
+ memref_ty,
+ swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
+ l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+ oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
+ interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+ ):
+ self.swizzle = swizzle # mlir.nvgpu.TensorMapSwizzleKind
+ self.l2promo = l2promo # mlir.nvgpu.TensorMapL2PromoKind
+ self.oob = oob # mlir.nvgpu.TensorMapOOBKind
+ self.interleave = interleave # mlir.nvgpu.TensorMapInterleaveKind
+ self.tma_box_shape = tma_box_shape
+ self.memref_ty = memref_ty # MemRefType
+ self.tma_memref = ir.MemRefType.get(tma_box_shape, memref_ty.element_type)
+
+ @property
+ def tensormap_descriptor_ty(self):
+ """Returns a tensormap descriptor type."""
+ tensorMemrefType = ir.MemRefType.get(
+ self.tma_box_shape,
+ self.memref_ty.element_type,
+ memory_space=ir.Attribute.parse("3"),
+ )
+ return nvgpu.TensorMapDescriptorType.get(
+ tensorMemrefType,
+ self.swizzle,
+ self.l2promo,
+ self.oob,
+ self.interleave,
+ )
+
+ def create_descriptor(self, device_ptr):
+ tma_descriptor_ty = self.tensormap_descriptor_ty
+ device_unranked_memref = memref.CastOp(
+ ir.UnrankedMemRefType.get(
+ self.memref_ty.element_type, self.memref_ty.memory_space
+ ),
+ device_ptr,
+ )
+ self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
+ tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
+ )
+ return self.tma_descriptor.result
+
+ def prefetch(self, predicate=None):
+ nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
+
+ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
+ nvgpu.TmaAsyncLoadOp(
+ dest,
+ mbarrier.mbar_group_op,
+ self.tma_descriptor,
+ coordinates=map(const, coords),
+ mbarId=mbarrier.id_op,
+ predicate=predicate,
+ )
+
+
+WARP_GROUP_SIZE = 128 # Number of threads in a warpgroup
+
+
+class Warpgroup:
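+    """A warpgroup of WARP_GROUP_SIZE threads identified by its primary thread.
+    Used as a context manager: the body is emitted inside an `scf.if` that only
+    this warpgroup executes, and `nvvm.setmaxregister` shrinks or grows the
+    group's register budget to `register_size`."""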
+ def __init__(self, primary_thread, register_size):
+ assert (primary_thread % WARP_GROUP_SIZE) == 0
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ self.primary_thread = primary_thread
+ self.register_size = register_size
+ self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
+ self.wg_id = tidx / WARP_GROUP_SIZE
+ self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)
+
+ def __enter__(self):
+ if_op = scf.IfOp(self.is_me)
+ self.ipoint_op = ir.InsertionPoint(if_op.then_block)
+ self.ipoint_op.__enter__()
+ if self.register_size < 64:
+ nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
+ else:
+ nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ scf.yield_([])
+ self.ipoint_op.__exit__(exc_type, exc_value, traceback)
+ return True
+
+
+class WGMMAType(Enum):
+ Accumulator = 1
+ Descriptor = 2
+
+
+class WGMMAMatrix:
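+    """Wraps either a warpgroup MMA accumulator or a shared-memory tile
+    described by a TMA descriptor. `A @ B` builds the wgmma descriptors for the
+    operands, and `D += A @ B` emits `nvgpu.warpgroup.mma`, yielding the
+    updated accumulator."""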
+ def __init__(
+ self,
+ matrix_type: WGMMAType,
+ shape: list = None,
+ desc: TMA = None,
+ smem=None,
+ ty=None,
+ acc_op=None,
+ ):
+ if acc_op is None:
+ self.M = shape[0]
+ self.N = shape[1]
+ self.ty = ty
+ self.matrix_type = matrix_type
+ self.desc = desc
+ self.smem = smem
+ if matrix_type is WGMMAType.Accumulator:
+ self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+ elif acc_op:
+ self.acc_op = acc_op
+ self.matrix_type = WGMMAType.Accumulator
+
+ @property
+ def acc_ty(self):
+ parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
+ return ir.Type.parse(parse_str)
+
+ @property
+ def wgmma_ty(self):
+ parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+ return ir.Type.parse(parse_str)
+
+ def store_accumulator(self, dest):
+ assert self.matrix_type == WGMMAType.Accumulator
+ nvgpu.warpgroup_mma_store(self.acc_op, dest)
+
+ def update_smem(self, smem):
+ self.smem = smem
+
+ def update_accumulator(self, acc_op):
+ self.acc_op = acc_op
+
+ def __matmul__(self, rhs):
+ lhs = nvgpu.warpgroup_generate_descriptor(
+ self.wgmma_ty, self.smem, self.desc.tma_descriptor
+ )
+ rhs = nvgpu.warpgroup_generate_descriptor(
+ rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
+ )
+ return [lhs, rhs]
+
+ def __iadd__(self, matmulResult):
+ lhs = matmulResult[0]
+ rhs = matmulResult[1]
+ acc_op = nvgpu.WarpgroupMmaOp(
+ self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
+ )
+ return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
+
+
+def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
+ smem_space_str = "#gpu.address_space<workgroup>"
+ smem_space = ir.Attribute.parse(smem_space_str)
+ dynamic_smem = gpu.dynamic_shared_memory(
+ ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+ )
+ if shape is None:
+ return dynamic_smem
+ memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
+ return memref.view(
+ ir.MemRefType.get(
+ memref_ty.shape, memref_ty.element_type, memory_space=smem_space
+ ),
+ dynamic_smem,
+ const(offset),
+ [],
+ )
+
+
+def get_mlir_ty(arg):
+ def get_mlir_ty_from_np(dtype):
+ if dtype == np.float16:
+ return T.f16()
+ if dtype == np.float32:
+ return T.f32()
+ if dtype == np.float64:
+ return T.f64()
+ if dtype == np.int32:
+ return T.i32()
+ if dtype == np.int64:
+ return T.i64()
+ raise NotImplementedError(dtype)
+
+ if isinstance(arg, bool):
+ return T.bool()
+ elif isinstance(arg, int):
+ return T.index()
+ elif isinstance(arg, float):
+ return T.f32()
+ elif isinstance(arg, np.ndarray):
+ descriptor = rt.get_ranked_memref_descriptor(arg)
+ dtype = get_mlir_ty_from_np(arg.dtype)
+ shape = descriptor.shape
+ return memref.MemRefType.get(shape, dtype)
+ raise NotImplementedError(arg)
+
+
+class NVDSL:
+ @staticmethod
+ def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ launch_op = gpu.LaunchOp(
+ None,
+ [],
+ *map(const, grid),
+ *map(const, block),
+ dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+ )
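+                # The gpu.launch region takes 12 index arguments: block ids,
+                # thread ids, grid sizes and block sizes (x/y/z each).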
+ launch_op.body.blocks.append(*([T.index()] * 12))
+ with ir.InsertionPoint(launch_op.body.blocks[0]):
+ result = func(*args, **kwargs)
+ gpu.terminator()
+ return result
+
+ return wrapper
+
+ return decorator
+
+ @staticmethod
+ def mlir_func(funcBody):
+ @functools.wraps(funcBody)
+ def wrapper(*args, **kwargs):
+ function_name = funcBody.__name__
+
+ def saveIR(module):
+ """Save generated IR"""
+ if True: # self.saveIR:
+ # print(mlir_nvgpu_module)
+ original_stdout = sys.stdout
+ with open("nvdsl.mlir", "w") as f:
+ sys.stdout = f
+ print(module)
+ sys.stdout = original_stdout
+
+ def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
+ """Generate MLIR's Arith dialects binary operations."""
+ rhs = const(rhs)
+ if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+ op += "F"
+ if op.startswith("Cmp"):
+ predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
+ predAtt
+ ]
+                elif arith._is_integer_like_type(
+                    lhs.type
+                ) and arith._is_integer_like_type(rhs.type):
+ if op == "Div" or op == "Rem":
+ op += "U"
+ op += "I"
+ if op.startswith("Cmp"):
+ predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
+ predAtt
+ ]
+ else:
+ raise NotImplementedError(
+ f"Unsupported '{op}' operands: {lhs}, {rhs}"
+ )
+
+ if op.startswith("Cmp"):
+ op = getattr(arith, f"{op}Op")
+
+ return op(predicateAttr, lhs, rhs).result
+ else:
+ op = getattr(arith, f"{op}Op")
+ return op(lhs, rhs).result
+
+            @ir.register_value_caster(ir.IndexType.static_typeid)
+            @ir.register_value_caster(ir.F32Type.static_typeid)
+            @ir.register_value_caster(ir.F16Type.static_typeid)
+            @ir.register_value_caster(ir.F64Type.static_typeid)
+            @ir.register_value_caster(ir.IntegerType.static_typeid)
+ class ArithValue(ir.Value):
+ """Overloads operators for MLIR's Arith dialects binary operations."""
+
+ def __init__(self, v):
+ super().__init__(v)
+
+ __add__ = partialmethod(_binary_op, op="Add")
+ __sub__ = partialmethod(_binary_op, op="Sub")
+ __mul__ = partialmethod(_binary_op, op="Mul")
+ __truediv__ = partialmethod(_binary_op, op="Div")
+ __mod__ = partialmethod(_binary_op, op="Rem")
+ __xor__ = partialmethod(_binary_op, op="XOr")
+ __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
+ __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
+ __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
+ __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
+ __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
+ __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
+ __and__ = partialmethod(_binary_op, op="And")
+ __or__ = partialmethod(_binary_op, op="Or")
+
+ def __str__(self):
+ return (
+ super()
+ .__str__()
+ .replace(ir.Value.__name__, ArithValue.__name__)
+ )
+
+ # Generate MLIR Context and start generating IR
+ with ir.Context(), ir.Location.unknown():
+ types = []
+ for arg in args:
+ types.append(get_mlir_ty(arg))
+
+ # Build IR
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
+ fop = func.FuncOp(function_name, (types, []))
+ fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ with ir.InsertionPoint(fop.add_entry_block()):
+ fargs = []
+ for i, a in enumerate(types):
+ fargs.append(fop.arguments[i])
+
+ # Call user function body
+ result = funcBody(*fargs, **kwargs)
+ func.ReturnOp([])
+
+ # Save IR in a file
+ # saveIR(module)
+
+ # Verify the module
+ # module.operation.verify()
+
+ # Compile and JIT MLIR module
+ options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+ support_lib = os.getenv("SUPPORT_LIB")
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+ )
+ compiler = nvgpucompiler.NvgpuCompiler(
+ options, opt_level=3, shared_libs=[support_lib]
+ )
+ engine = compiler.compile_and_jit(module)
+
+ # Convert input arguments to MLIR arguments
+ newArgs = get_mlir_func_obj_ty(args)
+
+ # Run the compiled program
+ engine.invoke(function_name, *newArgs)
+
+ return result
+
+ return wrapper
diff --git a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..1c9cc74fcd169c
--- /dev/null
+++ b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
@@ -0,0 +1,45 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file contains the Nvgpu class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+
+class NvgpuCompiler:
+ """Nvgpu class for compiling and building MLIR modules."""
+
+ def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+ pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+ self.pipeline = pipeline
+ self.shared_libs = shared_libs
+ self.opt_level = opt_level
+
+ def __call__(self, module: ir.Module):
+ """Convenience application method."""
+ self.compile(module)
+
+ def compile(self, module: ir.Module):
+ """Compiles the module by invoking the nvgpu pipeline."""
+ passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+ def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Wraps the module in a JIT execution engine."""
+ return execution_engine.ExecutionEngine(
+ module, opt_level=self.opt_level, shared_libs=self.shared_libs
+ )
+
+ def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Compiles and jits the module."""
+ self.compile(module)
+ return self.jit(module)