[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)
Guray Ozen
llvmlistbot at llvm.org
Fri Apr 19 05:56:24 PDT 2024
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/87065
From 25392c335bffb4dad7c8161a296fbb83bea199fd Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 29 Mar 2024 13:09:45 +0000
Subject: [PATCH 01/18] [mlir][nvgpu] Zero to Hero: Programming Nvidia Hopper
Tensor Core with MLIR's NVGPU Dialect
I am giving a tutorial on the nvgpu dialect using Python bindings at EuroLLVM 2024. For that, I implemented the tutorial code in Python. The focus is the nvgpu dialect and how to use its advanced features. I thought it might be useful to upstream this.
The tutorial chapters are as follows:
- **Ch0.py:** Hello World
- **Ch1.py:** 2D Saxpy
- **Ch2.py:** 2D Saxpy using TMA engine
- **Ch3.py:** GEMM 64x64x64 using Tensor Core and TMA
- **Ch4.py:** Multistage performant GEMM using Tensor Core and TMA
I might implement two more chapters, but they are more about GPU programming than about the compiler.
- **Ch5.py:** Warp Specialized GEMM
- **Ch6.py:** Warp Specialized Persistent GEMM
This PR also introduces the NVDSL helper class (`tools/nvdsl.py`), which makes IR building in the tutorial easier.
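For context, the chapters build IR through two decorators from the NVDSL helpers: `@NVDSL.mlir_func` builds, JIT-compiles, and runs an MLIR function generated from a Python function, while `@NVDSL.mlir_gpu_launch` wraps a nested Python function in a `gpu.launch` region. The snippet below is a condensed copy of Ch0.py from this PR, shown only to illustrate the pattern:

```python
from mlir.dialects import gpu
from tools.nvdsl import *


@NVDSL.mlir_func
def main(alpha):
    # The body of this nested function becomes a gpu.launch region.
    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
    def kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)
        # Overloaded Python operators build arith-dialect ops on MLIR values.
        myValue = alpha + tidx
        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])

    # Emit the kernel launch inside the function body.
    kernel()


# Calling the decorated function JIT-compiles the module and runs it.
main(100)
```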
---
mlir/test/Examples/nvgpu/Ch0.py | 46 ++
mlir/test/Examples/nvgpu/Ch1.py | 68 +++
mlir/test/Examples/nvgpu/Ch2.py | 94 ++++
mlir/test/Examples/nvgpu/Ch3.py | 92 ++++
mlir/test/Examples/nvgpu/Ch4.py | 291 +++++++++++++
mlir/test/Examples/nvgpu/lit.local.cfg | 3 +
mlir/test/Examples/nvgpu/tools/lit.local.cfg | 3 +
mlir/test/Examples/nvgpu/tools/nvdsl.py | 404 ++++++++++++++++++
.../Examples/nvgpu/tools/nvgpucompiler.py | 45 ++
9 files changed, 1046 insertions(+)
create mode 100644 mlir/test/Examples/nvgpu/Ch0.py
create mode 100644 mlir/test/Examples/nvgpu/Ch1.py
create mode 100644 mlir/test/Examples/nvgpu/Ch2.py
create mode 100644 mlir/test/Examples/nvgpu/Ch3.py
create mode 100644 mlir/test/Examples/nvgpu/Ch4.py
create mode 100644 mlir/test/Examples/nvgpu/lit.local.cfg
create mode 100644 mlir/test/Examples/nvgpu/tools/lit.local.cfg
create mode 100644 mlir/test/Examples/nvgpu/tools/nvdsl.py
create mode 100644 mlir/test/Examples/nvgpu/tools/nvgpucompiler.py
diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
new file mode 100644
index 00000000000000..221ca43d37307a
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -0,0 +1,46 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 0 : Hello World
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates Hello World
+#
+# This chapter demonstrates:
+# 1. Build MLIR function with arguments
+# 2. Build MLIR GPU kernel
+# 3. Print from a GPU thread
+# 4. Pass arguments, JIT compile and run the MLIR function
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir.dialects import gpu
+from tools.nvdsl import *
+
+
+# 1. Build function with arguments
+@NVDSL.mlir_func
+def main(alpha):
+ # 2. Build GPU kernel
+ @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
+ def kernel():
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ myValue = alpha + tidx
+ # Print from a GPU thread
+ gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
+
+ # 3. Call the GPU kernel
+ kernel()
+
+
+# 4. Pass arguments, JIT compile and run the MLIR function
+alpha = 100
+main(alpha)
+
+
+# CHECK: GPU thread 0 has 100
+# CHECK: GPU thread 1 has 101
+# CHECK: GPU thread 2 has 102
+# CHECK: GPU thread 3 has 103
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
new file mode 100644
index 00000000000000..a888c61358a024
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -0,0 +1,68 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 1 : 2D Saxpy
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy
+#
+# This chapter demonstrates:
+# 1. Use MLIR GPU dialect to allocate and copy memory
+# 2. Compute 2D SAXPY kernel
+# 3. Pass numpy arrays to MLIR
+# 4. Verify MLIR with reference computation
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, memref
+from tools.nvdsl import *
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy(x, y, alpha):
+ # 1. Use MLIR GPU dialect to allocate and copy memory
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+ y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+ t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+ t6 = gpu.wait(token_ty, [t5])
+
+ # 2. Compute 2D SAXPY kernel
+ @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
+ def saxpy_kernel():
+ bidx = gpu.block_id(gpu.Dimension.x)
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ x_val = memref.load(x_dev, [bidx, tidx])
+ y_val = memref.load(y_dev, [bidx, tidx])
+
+ # SAXPY: y[i] += a * x[i];
+ y_val += x_val * alpha
+
+ memref.store(y_val, y_dev, [bidx, tidx])
+
+ saxpy_kernel()
+
+ t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+ gpu.wait(token_ty, [t7])
+
+
+# 3. Pass numpy arrays to MLIR
+M = 256
+N = 32
+alpha = 2.0
+x = np.ones((M, N), np.float32)
+y = np.ones((M, N), np.float32)
+ref = np.ones((M, N), np.float32)
+saxpy(x, y, alpha)
+
+# 4. Verify MLIR with reference computation
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
new file mode 100644
index 00000000000000..caa2691a06d69e
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -0,0 +1,94 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 2 : 2D Saxpy with TMA
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
+# but it loads data using TMA (Tensor Memory Accelerator)
+#
+# This chapter demonstrates:
+# 1. Create and initialize asynchronous transactional barrier (mbarrier)
+# 2. Execute Tensor Memory Accelerator (TMA) Load
+# 3. Wait for completion of TMA load with mbarrier
+#
+# ===----------------------------------------------------------------------===//
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir import runtime as rt
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def saxpy_tma(x, y, alpha):
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+ y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
+ t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
+ t6 = gpu.wait(token_ty, [t5])
+
+ x_tma = TMA((M, N), x.type)
+ y_tma = TMA((M, N), y.type)
+ x_tma.create_descriptor(x_dev)
+ y_tma.create_descriptor(y_dev)
+
+ @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=65536)
+ def saxpy_tma_kernel():
+ bidx = gpu.block_id(gpu.Dimension.x)
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ isThread0 = tidx == 0
+
+ # 1. Create and initialize asynchronous transactional barrier (mbarrier)
+ mbar_group = Mbarriers(number_of_barriers=1)
+ with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+ mbar_group[0].init(1)
+ x_tma.prefetch()
+ y_tma.prefetch()
+ scf.yield_([])
+
+ x_smem = get_dynamic_shared_memory((M, N), T.f32())
+ y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 2)
+
+ # 2. Execute Tensor Memory Accelerator (TMA) Load
+ with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+ x_tma.load(x_smem, mbar_group[0])
+ y_tma.load(y_smem, mbar_group[0])
+ mbar_group[0].arrive(txcount=M * N * 2 * 4)
+ scf.yield_([])
+
+ # 3. Wait for completion of TMA load with mbarrier
+ mbar_group[0].try_wait()
+
+ x_val = memref.load(x_smem, [bidx, tidx])
+ y_val = memref.load(y_smem, [bidx, tidx])
+
+ # SAXPY: y[i] += a * x[i];
+ y_val += x_val * alpha
+
+ memref.store(y_val, y_dev, [bidx, tidx])
+
+ saxpy_tma_kernel()
+
+ t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
+ gpu.wait(token_ty, [t7])
+
+
+M = 256
+N = 32
+alpha = 2.0
+x = np.ones((M, N), np.float32)
+y = np.ones((M, N), np.float32)
+ref = np.ones((M, N), np.float32)
+saxpy_tma(x, y, alpha)
+
+# 4. Verify MLIR with reference computation
+ref += x * alpha
+np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
new file mode 100644
index 00000000000000..802cb59ead1555
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -0,0 +1,92 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 3 : GEMM 64x64x64 with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation with 64x64x64 matrix multiplication
+#
+# This chapter demonstrates:
+# 1. Execute TMA Load for two input matrices
+# 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
+# 3. Stores fragmented registers to global memory by warpgroup
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def gemm_64_64_64(x, y, z):
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+ y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+ t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+ t7 = gpu.wait(token_ty, [t6])
+
+ sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+ x_tma = TMA([N, N], x.type, swizzle=sw)
+ y_tma = TMA([N, N], y.type, swizzle=sw)
+ x_tma.create_descriptor(x_dev)
+ y_tma.create_descriptor(y_dev)
+
+ @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
+ def gemm_tma_kernel():
+ tidx = gpu.thread_id(gpu.Dimension.x)
+
+ mbar_group = Mbarriers(number_of_barriers=1)
+ isThread0 = tidx == 0
+ with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+ mbar_group[0].init(1)
+ x_tma.prefetch()
+ y_tma.prefetch()
+ scf.yield_([])
+
+ x_smem = get_dynamic_shared_memory((N, N), T.f16())
+ y_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+
+ # 1. Execute TMA Load for two input matrices
+ with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+ x_tma.load(x_smem, mbar_group[0])
+ y_tma.load(y_smem, mbar_group[0])
+ tx_count = get_type_size(x_tma.tma_memref) + get_type_size(y_tma.tma_memref)
+ mbar_group[0].arrive(tx_count)
+ scf.yield_([])
+
+ mbar_group[0].try_wait()
+
+ # 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
+ A = Matrix(x_smem, x_tma, N, N)
+ B = Matrix(y_smem, y_tma, N, N)
+ C = MatrixAccumulator(N, N, T.f32()).op()
+ D = Matrix.matmul(A, B, C)
+
+ # 3. Stores fragmented registers to global memory by warpgroup
+ nvgpu.warpgroup_mma_store(D, z_dev)
+
+ gemm_tma_kernel()
+
+ t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+ gpu.wait(None, [t8])
+
+
+# Pass Python arguments to MLIR
+N = 64
+x = np.random.randn(N, N).astype(np.float16)
+y = np.random.randn(N, N).astype(np.float16)
+z = np.zeros((N, N), np.float32)
+gemm_64_64_64(x, y, z)
+
+ref = x.astype(np.float16) @ y.astype(np.float16)
+np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
new file mode 100644
index 00000000000000..d222d6e60aafad
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -0,0 +1,291 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a multistage GEMM operation using Tensor Core and TMA
+#
+# This chapter demonstrates:
+# 1. Partition shape based on block IDs
+# 2. Prologue
+# 2.1 Execute TMA Load for two input matrices for each stage
+# 3. Main loop
+# 3.1 Wait for completion of TMA load with mbarrier
+# 3.2 Performs Tensor Core GEMM 64x128x64 by warpgroup
+# 3.3 Load next stage if needed
+# 4. Epilogue
+# 4.1 Store fragmented registers to shared memory
+# 4.2 Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+ """
+ Calculate the partition shape based on the block IDs.
+
+ It partitions the shape like below:
+ for(.. i < M ...) --> blockIdx.x
+ for(.. j < N ...) --> blockIdx.y
+ for(.. k < K ...)
+
+ Returns:
+ dimX (int): Dimension along the x-axis.
+ dimY (int): Dimension along the y-axis.
+ """
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = bidx * TILE_M
+ dimY = bidy * TILE_N
+ return dimX, dimY
+
+
+def tma_load(
+ mbar_group: Mbarriers,
+ x_tma: TMA,
+ y_tma: TMA,
+ slot,
+ stage,
+ p=None,
+):
+ """
+ TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+ - tma.load x_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
+ - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+ - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+
+ mbarrier.arrive tx_count = 128x64x2x4
+ """
+ dimX, dimY = partition_shape()
+
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+ size_tma_x = get_type_size(x_tma.tma_memref)
+ size_tma_y = get_type_size(y_tma.tma_memref)
+ tx_count = size_tma_x + (size_tma_y * 2)
+ tidx = gpu.thread_id(gpu.Dimension.x)
+
+ p = tidx == 0 if p is None else p
+
+ off_x = slot * size_tma_x
+ off_y = (slot * size_tma_x) + begin_y
+ off_y2 = off_y + size_tma_y
+ x = get_dynamic_shared_memory(
+ x_tma.tma_memref.shape, x_tma.tma_memref.element_type, off_x
+ )
+ y1 = get_dynamic_shared_memory(
+ y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y
+ )
+ y2 = get_dynamic_shared_memory(
+ y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y2
+ )
+
+ mbar_group[slot].arrive(tx_count, predicate=p)
+
+ c1 = stage * 64
+ x_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
+ y_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+ y_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def bootstrap(x_tma: TMA, y_tma: TMA):
+ """
+ Initialize mbarriers and prefetch TMA descriptors.
+ """
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ mbar_group = Mbarriers(number_of_barriers=NUM_STAGES)
+ isThread0 = tidx == const(0)
+ with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+ for i in scf.for_(0, NUM_STAGES, 1):
+ mbar_group[i].init(1)
+ scf.yield_([])
+ x_tma.prefetch()
+ y_tma.prefetch()
+ scf.yield_([])
+
+ return mbar_group
+
+
+def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+ """
+ Prologue of the GEMM kernel. It loads two input matrices for each stage in a loop, as shown below:
+
+ for stage in range(NUM_STAGES):
+ tma_load x, y, stage
+
+ """
+ ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+ for iv in scf.for_(0, ns, 1):
+ tma_load(mbar_group, x_tma, y_tma, iv, iv)
+ scf.yield_([])
+
+
+def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+ """
+ Main loop of the Multistage GEMM kernel. It iterates through
+ stages and performs matrix multiplication, loading data by TMA to shared memory. It looks like the following:
+
+ MatrixAccumulator D
+ for k in range(K // TILE_K):
+
+ try_wait(stage, ...) # Wait TMA load
+
+ Matrix A(stage, ...) # Find shared memory slot
+ Matrix B(stage, ...) # Find shared memory slot
+ D += A @ B # Multiply and accumulate
+
+ if(needLoad) # Load next stage if needed
+ tma_load(x, y, nextSlot, nextStage)
+
+ """
+ ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+
+ size_x = TILE_M * TILE_K * get_type_size(T.f16())
+
+ C = MatrixAccumulator(TILE_M, TILE_N, T.f32()).op()
+ pp = const(False, ty=T.bool())
+
+ # Main Loop
+ for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C, pp])
+ with ir.InsertionPoint(for_op.body):
+ pp = for_op.inner_iter_args[1]
+ iv = for_op.induction_variable
+ stage = iv % NUM_STAGES
+
+ # Wait for current stage
+ mbar_group[stage].try_wait(phase=pp)
+
+ # Find shared memory slot
+ offset_x = stage * size_x
+ offset_y = offset_x + begin_y
+ x_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_x)
+ y_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_y)
+
+ # Matrix Multiply
+ A = Matrix(x_smem, x_tma, TILE_M, TILE_K)
+ B = Matrix(y_smem, y_tma, TILE_K, TILE_N)
+ C = for_op.inner_iter_args[0]
+ D = Matrix.matmul(A, B, C)
+ if NUM_STAGES == 1:
+ nvvm.WgmmaWaitGroupSyncOp(0)
+
+ # Load next stage
+ pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
+ nextStage = iv + ns
+ nextSlot = nextStage % NUM_STAGES
+ tma_load(mbar_group, x_tma, y_tma, nextSlot, nextStage, pred)
+
+ # Switch phase parity for the mbarrier
+ switched = pp ^ const(True, ty=T.bool())
+ newPP = arith.select(
+ stage == (NUM_STAGES - 1),
+ switched,
+ pp,
+ )
+ scf.yield_([D, newPP])
+
+ nvvm.WgmmaWaitGroupSyncOp(0)
+
+ return for_op.results[0]
+
+
+def epilogue(D, z_dev):
+ """
+ Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
+
+ MatrixAccumulator D # Fragmented results
+ store D -> Shared Memory # Store registers to shared memory
+ Shared Memory -> Z[dimX][dimY] # Store Shared Memory to Global Memory
+
+ """
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ dimX, dimY = partition_shape()
+
+ z_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+ z_gmem = memref.subview(z_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+ # Store (registers -> shared memory)
+ nvgpu.WarpgroupMmaStoreOp(D, z_smem)
+ gpu.barrier()
+
+ # Store (shared memory --> global memory)
+ for i in scf.for_(0, TILE_M, 1):
+ val = memref.load(z_smem, [i, tidx])
+ memref.store(val, z_gmem, [i, tidx])
+ scf.yield_([])
+
+
+@NVDSL.mlir_func
+def gemm_multistage(x, y, z):
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+ y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+ t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+ t7 = gpu.wait(token_ty, [t6])
+
+ sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+ x_tma = TMA([128, 64], x.type, swizzle=sw)
+ y_tma = TMA([64, 64], y.type, swizzle=sw)
+ x_tma.create_descriptor(x_dev)
+ y_tma.create_descriptor(y_dev)
+
+ grid = [(M // TILE_M), (N // TILE_N), 1]
+ block = [128, 1, 1]
+ @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
+ def gemm_multistage_kernel():
+ # Initialize mbarriers and prefetch TMA descriptors
+ mbar_group = bootstrap(x_tma, y_tma)
+
+ # Fill the pipeline stages
+ prologue(mbar_group, x_tma, y_tma)
+
+ # Main loop
+ D = mainloop(mbar_group, x_tma, y_tma)
+
+ # Store registers to global memory
+ epilogue(D, z_dev)
+
+ gemm_multistage_kernel()
+
+ t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+ gpu.wait(None, [t8])
+
+
+# Pass Python arguments to MLIR
+NUM_STAGES = 7
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+x = np.random.randn(M, K).astype(np.float16)
+y = np.random.randn(K, N).astype(np.float16)
+z = np.zeros((M, N), np.float32)
+
+gemm_multistage(x, y, z)
+
+
+# Verify MLIR with reference computation
+ref = x.astype(np.float16) @ y.astype(np.float16)
+np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/lit.local.cfg b/mlir/test/Examples/nvgpu/lit.local.cfg
new file mode 100644
index 00000000000000..e586b55573898a
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/lit.local.cfg
@@ -0,0 +1,3 @@
+config.unsupported = False
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+ config.unsupported = True
\ No newline at end of file
diff --git a/mlir/test/Examples/nvgpu/tools/lit.local.cfg b/mlir/test/Examples/nvgpu/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
new file mode 100644
index 00000000000000..4c774f4d4deac9
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -0,0 +1,404 @@
+from enum import Enum
+import functools, sys, ctypes, os, errno
+import numpy as np
+from functools import partialmethod
+from mlir import ir
+from mlir.dialects import arith, func, gpu, memref, nvgpu
+from mlir.extras import types as T
+from mlir import runtime as rt
+from tools import nvgpucompiler
+
+DEBUG = True
+MLIR_DYNAMIC = -9223372036854775808
+
+
+def const(value: int, ty=None):
+ ty = T.index() if ty is None else ty
+ if isinstance(value, ir.Value) and (value.type.isinstance(value.type) or T.bool().isinstance(value.type)):
+ return value
+ return arith.constant(ty, value)
+
+
+def get_type_size(ty):
+ if ir.MemRefType.isinstance(ty):
+ size = get_type_size(ty.element_type)
+ for sz in ty.shape:
+ size *= sz
+ return size
+ if ir.FloatType.isinstance(ty):
+ return ir.FloatType(ty).width // 8
+ if ir.IntegerType.isinstance(ty):
+ return ir.IntegerType(ty).width // 8
+ raise NotImplementedError(ty)
+
+def get_mlir_func_obj_ty(inputArgs):
+ args = []
+ c_int_p = ctypes.c_int * 1
+ c_float_p = ctypes.c_float * 1
+ c_bool_p = ctypes.c_bool * 1
+ for arg in inputArgs:
+ if isinstance(arg, bool):
+ args.append(c_bool_p(arg))
+ elif isinstance(arg, int):
+ args.append(c_int_p(arg))
+ elif isinstance(arg, float):
+ args.append(c_float_p(arg))
+ elif isinstance(arg, np.ndarray):
+ args.append(
+ ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arg)))
+ )
+ else:
+ raise NotImplementedError(arg)
+ return args
+
+
+class Mbarriers:
+ def __init__(self, number_of_barriers=1):
+ self.mbar_ty = ir.Type.parse(
+ "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>, num_barriers = "
+ + str(number_of_barriers)
+ + ">"
+ )
+ self.mbar_group_op = nvgpu.mbarrier_create(self.mbar_ty)
+ self.number_of_barriers = number_of_barriers
+
+ def __getitem__(self, key):
+ self.id_op = const(key)
+ return self
+
+ def init(self, count: int, predicate=None):
+ count_op = const(count)
+ if predicate is None:
+ nvgpu.mbarrier_init(self.mbar_group_op, count_op, self.id_op)
+ else:
+ nvgpu.mbarrier_init(
+ self.mbar_group_op, count_op, self.id_op, predicate=predicate
+ )
+
+ def arrive(self, txcount: int = 0, predicate=None):
+ if txcount != 0:
+ txcount_op = const(txcount)
+ nvgpu.mbarrier_arrive_expect_tx(
+ self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
+ )
+ else:
+ nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op, predicate=predicate)
+
+ def try_wait(self, phase: bool = False, ticks: int = 10000000):
+ ticks_op = const(ticks)
+ phase_op = const(phase, T.bool())
+ nvgpu.MBarrierTryWaitParityOp(
+ self.mbar_group_op,
+ phase_op,
+ ticks_op,
+ mbarId=self.id_op,
+ )
+
+
+class TMA:
+ """A class that builds a TMA descriptor."""
+
+ def __init__(
+ self,
+ shape,
+ memref_ty,
+ swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
+ l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+ oob=nvgpu.TensorMapOOBKind.OOB_ZERO,
+ interleave=nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+ ):
+ self.swizzle = swizzle # mlir.nvgpu.TensorMapSwizzleKind
+ self.l2promo = l2promo # mlir.nvgpu.TensorMapL2PromoKind
+ self.oob = oob # mlir.nvgpu.TensorMapOOBKind
+ self.interleave = interleave # mlir.nvgpu.TensorMapInterleaveKind
+ self.shape = shape
+ self.memref_ty = memref_ty # MemRefType
+ self.lastDim = 64
+ self.requiredLoad = 1
+ self.tma_shape = shape
+ self.tma_memref = ir.MemRefType.get(shape, memref_ty.element_type)
+
+ @property
+ def tensormap_descriptor_ty(self):
+ """Returns a tensormap descriptor type."""
+ memref_str = f"memref<{self.tma_shape[0]}x{self.tma_shape[1]}x{self.memref_ty.element_type}, 3>"
+ parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
+ swizzle = {self.swizzle},\
+ l2promo = {self.l2promo},\
+ oob = {self.oob},\
+ interleave = {self.interleave}>"
+
+ return ir.Type.parse(parse_str)
+
+ def create_descriptor(self, device_ptr):
+ tma_descriptor_ty = self.tensormap_descriptor_ty
+ device_unranked_memref = memref.CastOp(
+ ir.UnrankedMemRefType.get(self.memref_ty.element_type,
+ self.memref_ty.memory_space),
+ device_ptr,
+ )
+ self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
+ tma_descriptor_ty, device_unranked_memref,
+ map(const, self.tma_shape))
+ return self.tma_descriptor.result
+
+ def prefetch(self, predicate=None):
+ nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
+
+ def load(self, dest, mbarrier: Mbarriers, coords=[0,0], predicate=None):
+ coord_ops = [const(c) for c in coords]
+ nvgpu.TmaAsyncLoadOp(
+ dest,
+ mbarrier.mbar_group_op,
+ self.tma_descriptor,
+ coordinates=coord_ops,
+ mbarId=mbarrier.id_op,
+ predicate=predicate,
+ )
+
+
+class MatrixAccumulator:
+ def __init__(self, M, N, ty):
+ self.M = M
+ self.N = N
+ self.ty = ty
+
+ @property
+ def acc_ty(self):
+ return ir.Type.parse(
+ "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ + str(self.M)
+ + "x"
+ + str(self.N)
+ + "x"
+ + str(self.ty)
+ + ">>"
+ )
+
+ def op(self):
+ return nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+
+
+class Matrix:
+
+ def __init__(self, smem, tma_descriptor: TMA, M, N):
+ self.tma_descriptor = tma_descriptor
+ self.smem = smem
+ self.M = M
+ self.N = N
+
+ @property
+ def wgmma_ty(self):
+ return ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+ str(self.M) + "x" +
+ str(self.N) + "x" +
+ str(self.tma_descriptor.memref_ty.element_type) +
+ ", #gpu.address_space<workgroup>>>")
+
+ def matmul(lhs, rhs, acc):
+ wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(
+ lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor)
+ wgmma_desc_rhs = nvgpu.warpgroup_generate_descriptor(
+ rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor)
+ return nvgpu.WarpgroupMmaOp(acc.type,
+ wgmma_desc_lhs,
+ wgmma_desc_rhs,
+ acc,
+ transposeB=True)
+
+def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
+ smem_space_str = "#gpu.address_space<workgroup>"
+ smem_space = ir.Attribute.parse(smem_space_str)
+ dynamic_smem = gpu.dynamic_shared_memory(
+ ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+ )
+ if shape is None:
+ return dynamic_smem
+ memref_ty = ir.MemRefType.get(shape, ty, memory_space=smem_space)
+ return memref.view(
+ ir.MemRefType.get(
+ memref_ty.shape, memref_ty.element_type, memory_space=smem_space
+ ),
+ dynamic_smem,
+ const(offset),
+ [],
+ )
+
+@staticmethod
+def get_mlir_ty(arg):
+ def get_mlir_ty_from_np(dtype):
+ if dtype == np.float16:
+ return T.f16()
+ if dtype == np.float32:
+ return T.f32()
+ if dtype == np.float64:
+ return T.f64()
+ if dtype == np.int32:
+ return T.i32()
+ if dtype == np.int64:
+ return T.i64()
+ raise NotImplementedError(dtype)
+ if isinstance(arg, bool):
+ return T.bool()
+ elif isinstance(arg, int):
+ return T.index()
+ elif isinstance(arg, float):
+ return T.f32()
+ elif isinstance(arg, np.ndarray):
+ descriptor = rt.get_ranked_memref_descriptor(arg)
+ dtype = get_mlir_ty_from_np(arg.dtype)
+ shape = descriptor.shape
+ return memref.MemRefType.get(shape, dtype)
+ raise NotImplementedError(arg)
+
+class NVDSL:
+ @staticmethod
+ def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ launch_op = gpu.LaunchOp(
+ None,
+ [],
+ *map(const, grid),
+ *map(const, block),
+ dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+ )
+ launch_op.body.blocks.append(*([T.index()] * 12))
+ with ir.InsertionPoint(launch_op.body.blocks[0]):
+ result = func(*args, **kwargs)
+ gpu.terminator()
+ return result
+
+ return wrapper
+
+ return decorator
+
+ @staticmethod
+ def mlir_func(funcBody):
+ @functools.wraps(funcBody)
+ def wrapper(*args, **kwargs):
+ function_name = funcBody.__name__
+
+ def saveIR(module):
+ """Save generated IR"""
+ if True: # self.saveIR:
+ # print(mlir_nvgpu_module)
+ original_stdout = sys.stdout
+ with open("nvdsl.mlir", "w") as f:
+ sys.stdout = f
+ print(module)
+ sys.stdout = original_stdout
+
+ def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
+ """Generate MLIR's Arith dialects binary operations."""
+ rhs = const(rhs)
+ if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+ op += "F"
+ if op.startswith("Cmp"):
+ predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
+ predAtt
+ ]
+ elif arith._is_integer_like_type(
+ lhs.type
+ ) and arith._is_integer_like_type(lhs.type):
+ if op == "Div" or op == "Rem":
+ op += "U"
+ op += "I"
+ if op.startswith("Cmp"):
+ predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
+ predAtt
+ ]
+ else:
+ raise NotImplementedError(
+ f"Unsupported '{op}' operands: {lhs}, {rhs}"
+ )
+
+ if op.startswith("Cmp"):
+ op = getattr(arith, f"{op}Op")
+
+ return op(predicateAttr, lhs, rhs).result
+ else:
+ op = getattr(arith, f"{op}Op")
+ return op(lhs, rhs).result
+
+ @ir.register_value_caster(ir.IndexType.static_typeid)
+ @ir.register_value_caster(ir.F32Type.static_typeid)
+ @ir.register_value_caster(ir.F16Type.static_typeid)
+ @ir.register_value_caster(ir.F64Type.static_typeid)
+ @ir.register_value_caster(ir.IntegerType.static_typeid)
+ class ArithValue(ir.Value):
+ """Overloads operators for MLIR's Arith dialects binary operations."""
+
+ def __init__(self, v):
+ super().__init__(v)
+
+ __add__ = partialmethod(_binary_op, op="Add")
+ __sub__ = partialmethod(_binary_op, op="Sub")
+ __mul__ = partialmethod(_binary_op, op="Mul")
+ __truediv__ = partialmethod(_binary_op, op="Div")
+ __mod__ = partialmethod(_binary_op, op="Rem")
+ __xor__ = partialmethod(_binary_op, op="XOr")
+ __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
+ __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
+ __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
+ __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
+ __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
+ __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
+ __and__ = partialmethod(_binary_op, op="And")
+ __or__ = partialmethod(_binary_op, op="Or")
+
+ def __str__(self):
+ return (
+ super()
+ .__str__()
+ .replace(ir.Value.__name__, ArithValue.__name__)
+ )
+
+ # Generate MLIR Context and start generating IR
+ with ir.Context(), ir.Location.unknown():
+ types = []
+ for arg in args:
+ types.append(get_mlir_ty(arg))
+
+ # Build IR
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
+ fop = func.FuncOp(function_name, (types, []))
+ fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ with ir.InsertionPoint(fop.add_entry_block()):
+ fargs = []
+ for i, a in enumerate(types):
+ fargs.append(fop.arguments[i])
+
+ # Call user function body
+ result = funcBody(*fargs, **kwargs)
+ func.ReturnOp([])
+
+ # Verify the module
+ module.operation.verify()
+
+ # Save IR in a file
+ # saveIR(module)
+
+ # Compile and JIT MLIR module
+ options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+ support_lib = os.getenv("SUPPORT_LIB")
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+ )
+ compiler = nvgpucompiler.NvgpuCompiler(
+ options, opt_level=3, shared_libs=[support_lib]
+ )
+ engine = compiler.compile_and_jit(module)
+
+ # Convert input arguments to MLIR arguments
+ newArgs = get_mlir_func_obj_ty(args)
+
+ # Run the compiled program
+ engine.invoke(function_name, *newArgs)
+
+ return result
+
+ return wrapper
diff --git a/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py b/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..1c9cc74fcd169c
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/tools/nvgpucompiler.py
@@ -0,0 +1,45 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file contains the NvgpuCompiler class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+
+class NvgpuCompiler:
+ """Nvgpu class for compiling and building MLIR modules."""
+
+ def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+ pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+ self.pipeline = pipeline
+ self.shared_libs = shared_libs
+ self.opt_level = opt_level
+
+ def __call__(self, module: ir.Module):
+ """Convenience application method."""
+ self.compile(module)
+
+ def compile(self, module: ir.Module):
+ """Compiles the module by invoking the nvgpu pipeline."""
+ passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+ def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Wraps the module in a JIT execution engine."""
+ return execution_engine.ExecutionEngine(
+ module, opt_level=self.opt_level, shared_libs=self.shared_libs
+ )
+
+ def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Compiles and jits the module."""
+ self.compile(module)
+ return self.jit(module)
From 843db9b99cbb58ef1ac0140d8a32684d43a99a3b Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 29 Mar 2024 13:51:21 +0000
Subject: [PATCH 02/18] format
---
mlir/test/Examples/nvgpu/Ch2.py | 2 +-
mlir/test/Examples/nvgpu/Ch4.py | 1 +
mlir/test/Examples/nvgpu/tools/nvdsl.py | 49 +++++++++++++++----------
3 files changed, 32 insertions(+), 20 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index caa2691a06d69e..e9369227af0e85 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -5,7 +5,7 @@
# Chapter 2 : 2D Saxpy with TMA
# ===----------------------------------------------------------------------===//
#
-# This program demonstrates 2D Saxpy. It is the same as Chapter 1, 
+# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
# but it loads data using TMA (Tensor Memory Accelerator)
#
# This chapter demonstrates:
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index d222d6e60aafad..5f790c4c190ca9 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -247,6 +247,7 @@ def gemm_multistage(x, y, z):
grid = [(M // TILE_M), (N // TILE_N), 1]
block = [128, 1, 1]
+
@NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
def gemm_multistage_kernel():
# Initialize mbarriers and prefetch TMA descriptors
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 4c774f4d4deac9..e27857d8ed1d79 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -14,7 +14,9 @@
def const(value: int, ty=None):
ty = T.index() if ty is None else ty
- if isinstance(value, ir.Value) and (value.type.isinstance(value.type) or T.bool().isinstance(value.type)):
+ if isinstance(value, ir.Value) and (
+ value.type.isinstance(value.type) or T.bool().isinstance(value.type)
+ ):
return value
return arith.constant(ty, value)
@@ -31,6 +33,7 @@ def get_type_size(ty):
return ir.IntegerType(ty).width // 8
raise NotImplementedError(ty)
+
def get_mlir_func_obj_ty(inputArgs):
args = []
c_int_p = ctypes.c_int * 1
@@ -133,19 +136,20 @@ def tensormap_descriptor_ty(self):
def create_descriptor(self, device_ptr):
tma_descriptor_ty = self.tensormap_descriptor_ty
device_unranked_memref = memref.CastOp(
- ir.UnrankedMemRefType.get(self.memref_ty.element_type,
- self.memref_ty.memory_space),
+ ir.UnrankedMemRefType.get(
+ self.memref_ty.element_type, self.memref_ty.memory_space
+ ),
device_ptr,
)
self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
- tma_descriptor_ty, device_unranked_memref,
- map(const, self.tma_shape))
+ tma_descriptor_ty, device_unranked_memref, map(const, self.tma_shape)
+ )
return self.tma_descriptor.result
def prefetch(self, predicate=None):
nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
- def load(self, dest, mbarrier: Mbarriers, coords=[0,0], predicate=None):
+ def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
coord_ops = [const(c) for c in coords]
nvgpu.TmaAsyncLoadOp(
dest,
@@ -180,7 +184,6 @@ def op(self):
class Matrix:
-
def __init__(self, smem, tma_descriptor: TMA, M, N):
self.tma_descriptor = tma_descriptor
self.smem = smem
@@ -189,22 +192,27 @@ def __init__(self, smem, tma_descriptor: TMA, M, N):
@property
def wgmma_ty(self):
- return ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
- str(self.M) + "x" +
- str(self.N) + "x" +
- str(self.tma_descriptor.memref_ty.element_type) +
- ", #gpu.address_space<workgroup>>>")
+ return ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(self.M)
+ + "x"
+ + str(self.N)
+ + "x"
+ + str(self.tma_descriptor.memref_ty.element_type)
+ + ", #gpu.address_space<workgroup>>>"
+ )
def matmul(lhs, rhs, acc):
wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(
- lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor)
+ lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor
+ )
wgmma_desc_rhs = nvgpu.warpgroup_generate_descriptor(
- rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor)
- return nvgpu.WarpgroupMmaOp(acc.type,
- wgmma_desc_lhs,
- wgmma_desc_rhs,
- acc,
- transposeB=True)
+ rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor
+ )
+ return nvgpu.WarpgroupMmaOp(
+ acc.type, wgmma_desc_lhs, wgmma_desc_rhs, acc, transposeB=True
+ )
+
def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
smem_space_str = "#gpu.address_space<workgroup>"
@@ -224,6 +232,7 @@ def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
[],
)
+
@staticmethod
def get_mlir_ty(arg):
def get_mlir_ty_from_np(dtype):
@@ -238,6 +247,7 @@ def get_mlir_ty_from_np(dtype):
if dtype == np.int64:
return T.i64()
raise NotImplementedError(dtype)
+
if isinstance(arg, bool):
return T.bool()
elif isinstance(arg, int):
@@ -251,6 +261,7 @@ def get_mlir_ty_from_np(dtype):
return memref.MemRefType.get(shape, dtype)
raise NotImplementedError(arg)
+
class NVDSL:
@staticmethod
def mlir_gpu_launch(grid=(1, 1, 1), block=(1, 1, 1), smem=0):
From 2961d61cb75bfe53fe861fa5eaa0468192704564 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 1 Apr 2024 14:05:12 +0000
Subject: [PATCH 03/18] address comments
---
mlir/test/Examples/nvgpu/Ch3.py | 63 +++++------
mlir/test/Examples/nvgpu/Ch4.py | 134 ++++++++++++------------
mlir/test/Examples/nvgpu/tools/nvdsl.py | 26 ++---
3 files changed, 107 insertions(+), 116 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 802cb59ead1555..b1623f9d645c2c 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -23,23 +23,24 @@
@NVDSL.mlir_func
-def gemm_64_64_64(x, y, z):
+def gemm_64_64_64(x, y, d):
token_ty = ir.Type.parse("!gpu.async.token")
t1 = gpu.wait(token_ty, [])
- x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
- y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
- z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
- t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
- t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+ a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+ b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
+ t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
t7 = gpu.wait(token_ty, [t6])
sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
- x_tma = TMA([N, N], x.type, swizzle=sw)
- y_tma = TMA([N, N], y.type, swizzle=sw)
- x_tma.create_descriptor(x_dev)
- y_tma.create_descriptor(y_dev)
+ a_tma = TMA([N, N], x.type, swizzle=sw)
+ b_tma = TMA([N, N], y.type, swizzle=sw)
+ a_tma.create_descriptor(a_dev)
+ b_tma.create_descriptor(b_dev)
+ smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
- @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
+ @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
def gemm_tma_kernel():
tidx = gpu.thread_id(gpu.Dimension.x)
@@ -47,46 +48,46 @@ def gemm_tma_kernel():
isThread0 = tidx == 0
with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
mbar_group[0].init(1)
- x_tma.prefetch()
- y_tma.prefetch()
+ a_tma.prefetch()
+ b_tma.prefetch()
scf.yield_([])
- x_smem = get_dynamic_shared_memory((N, N), T.f16())
- y_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+ a_smem = get_dynamic_shared_memory((N, N), T.f16())
+ b_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
# 1. Execute TMA Load for two input matrices
with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
- x_tma.load(x_smem, mbar_group[0])
- y_tma.load(y_smem, mbar_group[0])
- tx_count = get_type_size(x_tma.tma_memref) + get_type_size(y_tma.tma_memref)
- mbar_group[0].arrive(tx_count)
+ a_tma.load(a_smem, mbar_group[0])
+ b_tma.load(b_smem, mbar_group[0])
+ ta_count = get_type_size(a_tma.tma_memref) + get_type_size(b_tma.tma_memref)
+ mbar_group[0].arrive(ta_count)
scf.yield_([])
mbar_group[0].try_wait()
# 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
- A = Matrix(x_smem, x_tma, N, N)
- B = Matrix(y_smem, y_tma, N, N)
- C = MatrixAccumulator(N, N, T.f32()).op()
- D = Matrix.matmul(A, B, C)
+ A = WarpgroupMatrix(a_smem, a_tma, N, N)
+ B = WarpgroupMatrix(b_smem, b_tma, N, N)
+ C = WarpgroupAccumulatorMatrix(N, N, T.f32()).op()
+ D = WarpgroupMatrix.matmul(A, B, C)
# 3. Stores fragmented registers to global memory by warpgroup
- nvgpu.warpgroup_mma_store(D, z_dev)
+ nvgpu.warpgroup_mma_store(D, d_dev)
gemm_tma_kernel()
- t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+ t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
gpu.wait(None, [t8])
# Pass Python arguments to MLIR
N = 64
-x = np.random.randn(N, N).astype(np.float16)
-y = np.random.randn(N, N).astype(np.float16)
-z = np.zeros((N, N), np.float32)
-gemm_64_64_64(x, y, z)
+a = np.random.randn(N, N).astype(np.float16)
+b = np.random.randn(N, N).astype(np.float16)
+d = np.zeros((N, N), np.float32)
+gemm_64_64_64(a, b, d)
-ref = x.astype(np.float16) @ y.astype(np.float16)
-np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 5f790c4c190ca9..2f38a49501e792 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -51,8 +51,8 @@ def partition_shape():
def tma_load(
mbar_group: Mbarriers,
- x_tma: TMA,
- y_tma: TMA,
+ a_tma: TMA,
+ b_tma: TMA,
slot,
stage,
p=None,
@@ -60,45 +60,45 @@ def tma_load(
"""
TMA loads two input matrices from global memory to shared memory. It performs the following operations:
- - tma.load x_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
- - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
- - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+ - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
+ - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+ - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
- mbarrier.arrive tx_count = 128x64x2x4
+ mbarrier.arrive ta_count = 128x64x2x4
"""
dimX, dimY = partition_shape()
tidx = gpu.thread_id(gpu.Dimension.x)
- begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
- size_tma_x = get_type_size(x_tma.tma_memref)
- size_tma_y = get_type_size(y_tma.tma_memref)
- tx_count = size_tma_x + (size_tma_y * 2)
+ begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+ size_tma_a = get_type_size(a_tma.tma_memref)
+ size_tma_b = get_type_size(b_tma.tma_memref)
+ ta_count = size_tma_a + (size_tma_b * 2)
tidx = gpu.thread_id(gpu.Dimension.x)
p = tidx == 0 if p is None else p
- off_x = slot * size_tma_x
- off_y = (slot * size_tma_x) + begin_y
- off_y2 = off_y + size_tma_y
+ off_a = slot * size_tma_a
+ off_b = (slot * size_tma_a) + begin_b
+ off_b2 = off_b + size_tma_b
x = get_dynamic_shared_memory(
- x_tma.tma_memref.shape, x_tma.tma_memref.element_type, off_x
+ a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
)
y1 = get_dynamic_shared_memory(
- y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y
+ b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
)
y2 = get_dynamic_shared_memory(
- y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y2
+ b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
)
- mbar_group[slot].arrive(tx_count, predicate=p)
+ mbar_group[slot].arrive(ta_count, predicate=p)
c1 = stage * 64
- x_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
- y_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
- y_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+ a_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
+ b_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+ b_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
-def bootstrap(x_tma: TMA, y_tma: TMA):
+def bootstrap(a_tma: TMA, b_tma: TMA):
"""
Initialize mbarriers and prefetch TMA descriptors.
"""
@@ -109,14 +109,14 @@ def bootstrap(x_tma: TMA, y_tma: TMA):
for i in scf.for_(0, NUM_STAGES, 1):
mbar_group[i].init(1)
scf.yield_([])
- x_tma.prefetch()
- y_tma.prefetch()
+ a_tma.prefetch()
+ b_tma.prefetch()
scf.yield_([])
return mbar_group
-def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
"""
Prologue of the GEMM kernel. It loads two input matrices for each stage in a loop, as shown below:
@@ -126,11 +126,11 @@ def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
"""
ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
for iv in scf.for_(0, ns, 1):
- tma_load(mbar_group, x_tma, y_tma, iv, iv)
+ tma_load(mbar_group, a_tma, b_tma, iv, iv)
scf.yield_([])
-def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
"""
Main loop of the Multistage GEMM kernel. It iterates through
stages and performs matrix multiplication, loading data by TMA to shared memory. It looks like the following:
@@ -151,11 +151,11 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
tidx = gpu.thread_id(gpu.Dimension.x)
- begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+ begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
- size_x = TILE_M * TILE_K * get_type_size(T.f16())
+ size_a = TILE_M * TILE_K * get_type_size(T.f16())
- C = MatrixAccumulator(TILE_M, TILE_N, T.f32()).op()
+ C = WarpgroupAccumulatorMatrix(TILE_M, TILE_N, T.f32()).op()
pp = const(False, ty=T.bool())
# Main Loop
@@ -169,16 +169,16 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
mbar_group[stage].try_wait(phase=pp)
# Find shared memory slot
- offset_x = stage * size_x
- offset_y = offset_x + begin_y
- x_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_x)
- y_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_y)
+ offset_a = stage * size_a
+ offset_b = offset_a + begin_b
+ a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+ b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
# Matrix Multiply
- A = Matrix(x_smem, x_tma, TILE_M, TILE_K)
- B = Matrix(y_smem, y_tma, TILE_K, TILE_N)
+ A = WarpgroupMatrix(a_smem, a_tma, TILE_M, TILE_K)
+ B = WarpgroupMatrix(b_smem, b_tma, TILE_K, TILE_N)
C = for_op.inner_iter_args[0]
- D = Matrix.matmul(A, B, C)
+ D = WarpgroupMatrix.matmul(A, B, C)
if NUM_STAGES == 1:
nvvm.WgmmaWaitGroupSyncOp(0)
@@ -186,7 +186,7 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
nextStage = iv + ns
nextSlot = nextStage % NUM_STAGES
- tma_load(mbar_group, x_tma, y_tma, nextSlot, nextStage, pred)
+ tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, pred)
# Switch phase parity for the mbarrier
switched = pp ^ const(True, ty=T.bool())
@@ -202,7 +202,7 @@ def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
return for_op.results[0]
-def epilogue(D, z_dev):
+def epilogue(D, d_dev):
"""
Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
@@ -214,57 +214,61 @@ def epilogue(D, z_dev):
tidx = gpu.thread_id(gpu.Dimension.x)
dimX, dimY = partition_shape()
- z_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
- z_gmem = memref.subview(z_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+ d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+ d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
# Store (registers -> shared memory)
- nvgpu.WarpgroupMmaStoreOp(D, z_smem)
+ nvgpu.WarpgroupMmaStoreOp(D, d_smem)
gpu.barrier()
# Store (shared memory --> global memory)
for i in scf.for_(0, TILE_M, 1):
- val = memref.load(z_smem, [i, tidx])
- memref.store(val, z_gmem, [i, tidx])
+ val = memref.load(d_smem, [i, tidx])
+ memref.store(val, d_gmem, [i, tidx])
scf.yield_([])
@NVDSL.mlir_func
-def gemm_multistage(x, y, z):
+def gemm_multistage(x, y, d):
token_ty = ir.Type.parse("!gpu.async.token")
t1 = gpu.wait(token_ty, [])
- x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
- y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
- z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
- t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
- t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+ a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+ b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
+ t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
t7 = gpu.wait(token_ty, [t6])
sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
- x_tma = TMA([128, 64], x.type, swizzle=sw)
- y_tma = TMA([64, 64], y.type, swizzle=sw)
- x_tma.create_descriptor(x_dev)
- y_tma.create_descriptor(y_dev)
+ a_tma = TMA([128, 64], x.type, swizzle=sw)
+ b_tma = TMA([64, 64], y.type, swizzle=sw)
+ a_tma.create_descriptor(a_dev)
+ b_tma.create_descriptor(b_dev)
grid = [(M // TILE_M), (N // TILE_N), 1]
block = [128, 1, 1]
- @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=229440)
+ size_a = get_type_size(x.type.element_type) * TILE_M * TILE_K
+ size_b = get_type_size(x.type.element_type) * TILE_N * TILE_K
+ smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+
+ @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
def gemm_multistage_kernel():
# Initialize mbarriers and prefetch TMA descriptors
- mbar_group = bootstrap(x_tma, y_tma)
+ mbar_group = bootstrap(a_tma, b_tma)
# Fill the pipeline stages
- prologue(mbar_group, x_tma, y_tma)
+ prologue(mbar_group, a_tma, b_tma)
# Main loop
- D = mainloop(mbar_group, x_tma, y_tma)
+ D = mainloop(mbar_group, a_tma, b_tma)
# Store registers to global memory
- epilogue(D, z_dev)
+ epilogue(D, d_dev)
gemm_multistage_kernel()
- t8 = gpu.memcpy(token_ty, [t7], z, z_dev)
+ t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
gpu.wait(None, [t8])
@@ -276,16 +280,16 @@ def gemm_multistage_kernel():
TILE_M = 128
TILE_N = 128
TILE_K = 64
-x = np.random.randn(M, K).astype(np.float16)
-y = np.random.randn(K, N).astype(np.float16)
-z = np.zeros((M, N), np.float32)
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
-gemm_multistage(x, y, z)
+gemm_multistage(a, b, d)
# Verify MLIR with reference computation
-ref = x.astype(np.float16) @ y.astype(np.float16)
-np.testing.assert_allclose(z, ref, rtol=5e-03, atol=1e-01)
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index e27857d8ed1d79..bb312a9e259ef3 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -161,7 +161,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
)
-class MatrixAccumulator:
+class WarpgroupAccumulatorMatrix:
def __init__(self, M, N, ty):
self.M = M
self.N = N
@@ -169,21 +169,14 @@ def __init__(self, M, N, ty):
@property
def acc_ty(self):
- return ir.Type.parse(
- "!nvgpu.warpgroup.accumulator<fragmented=vector<"
- + str(self.M)
- + "x"
- + str(self.N)
- + "x"
- + str(self.ty)
- + ">>"
- )
+ parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
+ return ir.Type.parse(parse_str)
def op(self):
return nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
-class Matrix:
+class WarpgroupMatrix:
def __init__(self, smem, tma_descriptor: TMA, M, N):
self.tma_descriptor = tma_descriptor
self.smem = smem
@@ -192,15 +185,8 @@ def __init__(self, smem, tma_descriptor: TMA, M, N):
@property
def wgmma_ty(self):
- return ir.Type.parse(
- "!nvgpu.warpgroup.descriptor<tensor=memref<"
- + str(self.M)
- + "x"
- + str(self.N)
- + "x"
- + str(self.tma_descriptor.memref_ty.element_type)
- + ", #gpu.address_space<workgroup>>>"
- )
+ parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.tma_descriptor.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+ return ir.Type.parse(parse_str)
def matmul(lhs, rhs, acc):
wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(
From 83f3445b3bebde76d4afa34e5ce4542ce102d44a Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 1 Apr 2024 14:42:23 +0000
Subject: [PATCH 04/18] Perform 128x128x64 GEMM instead of 64x64x64
---
mlir/test/Examples/nvgpu/Ch3.py | 100 +++++++++++++++++++++++---------
1 file changed, 74 insertions(+), 26 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index b1623f9d645c2c..12e6def2395555 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -22,23 +22,71 @@
import numpy as np
+def tma_load(
+ mbar_group: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ slot,
+ stage,
+ p=None,
+):
+ """
+ TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+ - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
+ - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+ - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+
+ mbarrier.arrive ta_count = 128x64x2x4
+ """
+
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ begin_b = get_type_size(a_tma.tma_memref)
+ size_tma_a = get_type_size(a_tma.tma_memref)
+ size_tma_b = get_type_size(b_tma.tma_memref)
+ ta_count = size_tma_a + (size_tma_b * 2)
+ tidx = gpu.thread_id(gpu.Dimension.x)
+
+ p = tidx == 0 if p is None else p
+
+ off_a = slot * size_tma_a
+ off_b = (slot * size_tma_a) + begin_b
+ off_b2 = off_b + size_tma_b
+ x = get_dynamic_shared_memory(
+ a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
+ )
+ y1 = get_dynamic_shared_memory(
+ b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
+ )
+ y2 = get_dynamic_shared_memory(
+ b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
+ )
+
+ mbar_group[slot].arrive(ta_count, predicate=p)
+
+ c1 = stage * 64
+ a_tma.load(x, mbar_group[slot], coords=[c1, 0], predicate=p)
+ b_tma.load(y1, mbar_group[slot], coords=[0, c1], predicate=p)
+ b_tma.load(y2, mbar_group[slot], coords=[64, c1], predicate=p)
+
+
@NVDSL.mlir_func
-def gemm_64_64_64(x, y, d):
+def gemm_128_128_64(a, b, d):
token_ty = ir.Type.parse("!gpu.async.token")
t1 = gpu.wait(token_ty, [])
- a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
- b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+ b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
- t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
- t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
+ t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+ t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
t7 = gpu.wait(token_ty, [t6])
sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
- a_tma = TMA([N, N], x.type, swizzle=sw)
- b_tma = TMA([N, N], y.type, swizzle=sw)
+ a_tma = TMA([128, 64], a.type, swizzle=sw)
+ b_tma = TMA([64, 64], b.type, swizzle=sw)
a_tma.create_descriptor(a_dev)
b_tma.create_descriptor(b_dev)
- smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
+ smem_size_in_bytes = get_type_size(a.type) + get_type_size(b.type)
@NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
def gemm_tma_kernel():
@@ -52,26 +100,24 @@ def gemm_tma_kernel():
b_tma.prefetch()
scf.yield_([])
- a_smem = get_dynamic_shared_memory((N, N), T.f16())
- b_smem = get_dynamic_shared_memory((N, N), T.f16(), offset=N * N * 2)
+ a_smem = get_dynamic_shared_memory((M, K), T.f16())
+ b_smem = get_dynamic_shared_memory(
+ (K, N), T.f16(), offset=get_type_size(a.type)
+ )
# 1. Execute TMA Load for two input matrices
- with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
- a_tma.load(a_smem, mbar_group[0])
- b_tma.load(b_smem, mbar_group[0])
- ta_count = get_type_size(a_tma.tma_memref) + get_type_size(b_tma.tma_memref)
- mbar_group[0].arrive(ta_count)
- scf.yield_([])
+ tma_load(mbar_group, a_tma, b_tma, 0, 0, p=isThread0)
+ # 2. All threads wait TMA load completion
mbar_group[0].try_wait()
- # 2. Performs Tensor Core GEMM 64x64x64 by warpgroup
- A = WarpgroupMatrix(a_smem, a_tma, N, N)
- B = WarpgroupMatrix(b_smem, b_tma, N, N)
- C = WarpgroupAccumulatorMatrix(N, N, T.f32()).op()
+ # 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
+ A = WarpgroupMatrix(a_smem, a_tma, M, K)
+ B = WarpgroupMatrix(b_smem, b_tma, K, N)
+ C = WarpgroupAccumulatorMatrix(M, N, T.f32()).op()
D = WarpgroupMatrix.matmul(A, B, C)
- # 3. Stores fragmented registers to global memory by warpgroup
+ # 4. Stores fragmented registers to global memory by warpgroup
nvgpu.warpgroup_mma_store(D, d_dev)
gemm_tma_kernel()
@@ -81,11 +127,13 @@ def gemm_tma_kernel():
# Python pass arguments to MLIR
-N = 64
-a = np.random.randn(N, N).astype(np.float16)
-b = np.random.randn(N, N).astype(np.float16)
-d = np.zeros((N, N), np.float32)
-gemm_64_64_64(a, b, d)
+M = 128
+N = 128
+K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+gemm_128_128_64(a, b, d)
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
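For reference, the transaction count that `tma_load` above arrives on the mbarrier with is simply the number of bytes in flight: one 128x64 f16 TMA load for the A tile plus two 64x64 f16 loads covering the B tile. A quick plain-Python sketch of that arithmetic (assuming the f16 element size and the tile shapes used in this chapter):

# Sketch of the byte counts behind `ta_count` in tma_load above.
# Assumes f16 elements (2 bytes) and the 128x128x64 tiling of this chapter.
BYTES_F16 = 2

size_tma_a = 128 * 64 * BYTES_F16        # one TMA load for the A tile
size_tma_b = 64 * 64 * BYTES_F16         # each of the two TMA loads for B
ta_count = size_tma_a + 2 * size_tma_b   # bytes the mbarrier waits for

print(size_tma_a, size_tma_b, ta_count)  # 16384 8192 32768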
>From 6d05893a44768fb911eddeeccf7af78e98d0e2f1 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 1 Apr 2024 14:52:36 +0000
Subject: [PATCH 05/18] fix names and simplify tma_load for ch3
---
mlir/test/Examples/nvgpu/Ch3.py | 45 ++++++++++++---------------------
mlir/test/Examples/nvgpu/Ch4.py | 44 +++++++++++++++-----------------
2 files changed, 36 insertions(+), 53 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 12e6def2395555..ac0ee9b3f62c89 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -26,48 +26,35 @@ def tma_load(
mbar_group: Mbarriers,
a_tma: TMA,
b_tma: TMA,
- slot,
- stage,
- p=None,
+ p,
):
"""
TMA loads two input matrices from global memory to shared memory. It performs the following operations:
- - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
- - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
- - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+ - tma.load a_shared_memory[0] at coordinate [0, 0] (Loads 128x64)
+ - tma.load b_shared_memory[0] at coordinate [0, 0] (Loads 64x64)
+ - tma.load b_shared_memory[0] at coordinate [64, 0] (Loads 64x64)
- mbarrier.arrive ta_count = 128x64x2x4
+ mbarrier.arrive ta_count = 128x64xf16 + 64x128xf16
"""
- tidx = gpu.thread_id(gpu.Dimension.x)
- begin_b = get_type_size(a_tma.tma_memref)
size_tma_a = get_type_size(a_tma.tma_memref)
size_tma_b = get_type_size(b_tma.tma_memref)
ta_count = size_tma_a + (size_tma_b * 2)
- tidx = gpu.thread_id(gpu.Dimension.x)
- p = tidx == 0 if p is None else p
-
- off_a = slot * size_tma_a
- off_b = (slot * size_tma_a) + begin_b
+ off_b = size_tma_a
off_b2 = off_b + size_tma_b
- x = get_dynamic_shared_memory(
- a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
- )
- y1 = get_dynamic_shared_memory(
- b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
- )
- y2 = get_dynamic_shared_memory(
- b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
- )
+ a_elem_ty = a_tma.tma_memref.element_type
+ b_elem_ty = b_tma.tma_memref.element_type
+ a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
+ b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+ b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
- mbar_group[slot].arrive(ta_count, predicate=p)
+ mbar_group[0].arrive(ta_count, predicate=p)
- c1 = stage * 64
- a_tma.load(x, mbar_group[slot], coords=[c1, 0], predicate=p)
- b_tma.load(y1, mbar_group[slot], coords=[0, c1], predicate=p)
- b_tma.load(y2, mbar_group[slot], coords=[64, c1], predicate=p)
+ a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
+ b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
+ b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)
@NVDSL.mlir_func
@@ -106,7 +93,7 @@ def gemm_tma_kernel():
)
# 1. Execute TMA Load for two input matrices
- tma_load(mbar_group, a_tma, b_tma, 0, 0, p=isThread0)
+ tma_load(mbar_group, a_tma, b_tma, isThread0)
# 2. All threads wait TMA load completion
mbar_group[0].try_wait()
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 2f38a49501e792..4b426522a89c52 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -60,9 +60,9 @@ def tma_load(
"""
TMA loads two input matrices from global memory to shared memory. It performs the following operations:
- - tma.load a_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
- - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
- - tma.load b_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+ - tma.load a_shared_memory[off_x] at coordinate [x, z] (Loads 128x64)
+ - tma.load b_shared_memory[off_y1] at coordinate [y, x] (Loads 64x64)
+ - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
mbarrier.arrive ta_count = 128x64x2x4
"""
@@ -80,22 +80,18 @@ def tma_load(
off_a = slot * size_tma_a
off_b = (slot * size_tma_a) + begin_b
off_b2 = off_b + size_tma_b
- x = get_dynamic_shared_memory(
- a_tma.tma_memref.shape, a_tma.tma_memref.element_type, off_a
- )
- y1 = get_dynamic_shared_memory(
- b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b
- )
- y2 = get_dynamic_shared_memory(
- b_tma.tma_memref.shape, b_tma.tma_memref.element_type, off_b2
- )
+ a_elem_ty = a_tma.tma_memref.element_type
+ b_elem_ty = b_tma.tma_memref.element_type
+ a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+ b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+ b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
mbar_group[slot].arrive(ta_count, predicate=p)
c1 = stage * 64
- a_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
- b_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
- b_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+ a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+ b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+ b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
def bootstrap(a_tma: TMA, b_tma: TMA):
@@ -229,27 +225,27 @@ def epilogue(D, d_dev):
@NVDSL.mlir_func
-def gemm_multistage(x, y, d):
+def gemm_multistage(a, b, d):
token_ty = ir.Type.parse("!gpu.async.token")
t1 = gpu.wait(token_ty, [])
- a_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
- b_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+ a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+ b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
- t5 = gpu.memcpy(token_ty, [t4], a_dev, x)
- t6 = gpu.memcpy(token_ty, [t5], b_dev, y)
+ t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+ t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
t7 = gpu.wait(token_ty, [t6])
sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
- a_tma = TMA([128, 64], x.type, swizzle=sw)
- b_tma = TMA([64, 64], y.type, swizzle=sw)
+ a_tma = TMA([128, 64], a.type, swizzle=sw)
+ b_tma = TMA([64, 64], b.type, swizzle=sw)
a_tma.create_descriptor(a_dev)
b_tma.create_descriptor(b_dev)
grid = [(M // TILE_M), (N // TILE_N), 1]
block = [128, 1, 1]
- size_a = get_type_size(x.type.element_type) * TILE_M * TILE_K
- size_b = get_type_size(x.type.element_type) * TILE_N * TILE_K
+ size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
+ size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
@NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
>From ca1625bd8ecdfbd679cbc8766f4d94622d7e1897 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 2 Apr 2024 14:35:22 +0000
Subject: [PATCH 06/18] Operator overload for `C += A @ B`
---
mlir/test/Examples/nvgpu/Ch3.py | 12 +++--
mlir/test/Examples/nvgpu/Ch4.py | 24 ++++++---
mlir/test/Examples/nvgpu/tools/nvdsl.py | 70 ++++++++++++++++---------
3 files changed, 69 insertions(+), 37 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index ac0ee9b3f62c89..8c2864dd0d465d 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -99,13 +99,15 @@ def gemm_tma_kernel():
mbar_group[0].try_wait()
# 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
- A = WarpgroupMatrix(a_smem, a_tma, M, K)
- B = WarpgroupMatrix(b_smem, b_tma, K, N)
- C = WarpgroupAccumulatorMatrix(M, N, T.f32()).op()
- D = WarpgroupMatrix.matmul(A, B, C)
+ A = WGMMAMatrix(WGMMAType.Descriptor, [M,K], desc=a_tma, smem=a_smem)
+ B = WGMMAMatrix(WGMMAType.Descriptor, [K,N], desc=b_tma, smem=b_smem)
+ C = WGMMAMatrix(WGMMAType.Accumulator, shape=[M,N], ty=T.f32())
+
+ # Matrix Multiply
+ C += A @ B
# 4. Stores fragmented registers to global memory by warpgroup
- nvgpu.warpgroup_mma_store(D, d_dev)
+ nvgpu.warpgroup_mma_store(C, d_dev)
gemm_tma_kernel()
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 4b426522a89c52..9555b367a0078b 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -151,11 +151,15 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
size_a = TILE_M * TILE_K * get_type_size(T.f16())
- C = WarpgroupAccumulatorMatrix(TILE_M, TILE_N, T.f32()).op()
+ # Initialize A and B (input matrices) and C (accumulator)
+ A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+ B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+ C = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+
pp = const(False, ty=T.bool())
# Main Loop
- for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C, pp])
+ for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C.acc_op, pp])
with ir.InsertionPoint(for_op.body):
pp = for_op.inner_iter_args[1]
iv = for_op.induction_variable
@@ -170,14 +174,18 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
+ # Initialize matrices
+ A.update_smem(a_smem)
+ B.update_smem(b_smem)
+ C.update_accumulator(for_op.inner_iter_args[0])
+
# Matrix Multiply
- A = WarpgroupMatrix(a_smem, a_tma, TILE_M, TILE_K)
- B = WarpgroupMatrix(b_smem, b_tma, TILE_K, TILE_N)
- C = for_op.inner_iter_args[0]
- D = WarpgroupMatrix.matmul(A, B, C)
+ C += A @ B
+
+ # Wait Tensor Core for single stage
if NUM_STAGES == 1:
nvvm.WgmmaWaitGroupSyncOp(0)
-
+
# Load next stage
pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
nextStage = iv + ns
@@ -191,7 +199,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
switched,
pp,
)
- scf.yield_([D, newPP])
+ scf.yield_([C, newPP])
nvvm.WgmmaWaitGroupSyncOp(0)
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index bb312a9e259ef3..cfa90cdf3c6539 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -161,42 +161,64 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
)
-class WarpgroupAccumulatorMatrix:
- def __init__(self, M, N, ty):
- self.M = M
- self.N = N
- self.ty = ty
+class WGMMAType(Enum):
+ Accumulator = 1
+ Descriptor = 2
+
+
+class WGMMAMatrix:
+ def __init__(
+ self,
+ matrix_type: WGMMAType,
+ shape: list = None,
+ desc: TMA = None,
+ smem=None,
+ ty=None,
+ acc_op=None,
+ ):
+ if acc_op is None:
+ self.M = shape[0]
+ self.N = shape[1]
+ self.ty = ty
+ self.matrix_type = matrix_type
+ self.desc = desc
+ self.smem = smem
+ if matrix_type is WGMMAType.Accumulator:
+ self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+ elif acc_op:
+ self.acc_op = acc_op
+ self.matrix_type = WGMMAType.Accumulator
@property
def acc_ty(self):
parse_str = f"!nvgpu.warpgroup.accumulator<fragmented=vector<{self.M}x{self.N}x{self.ty}>>"
return ir.Type.parse(parse_str)
- def op(self):
- return nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
-
-
-class WarpgroupMatrix:
- def __init__(self, smem, tma_descriptor: TMA, M, N):
- self.tma_descriptor = tma_descriptor
- self.smem = smem
- self.M = M
- self.N = N
-
@property
def wgmma_ty(self):
- parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.tma_descriptor.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
+ parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
return ir.Type.parse(parse_str)
- def matmul(lhs, rhs, acc):
- wgmma_desc_lhs = nvgpu.warpgroup_generate_descriptor(
- lhs.wgmma_ty, lhs.smem, lhs.tma_descriptor.tma_descriptor
+ def update_smem(self, smem):
+ self.smem = smem
+
+ def update_accumulator(self, acc_op):
+ self.acc_op = acc_op
+
+ def __matmul__(self, rhs):
+ lhs = nvgpu.warpgroup_generate_descriptor(
+ self.wgmma_ty, self.smem, self.desc.tma_descriptor
)
- wgmma_desc_rhs = nvgpu.warpgroup_generate_descriptor(
- rhs.wgmma_ty, rhs.smem, rhs.tma_descriptor.tma_descriptor
+ rhs = nvgpu.warpgroup_generate_descriptor(
+ rhs.wgmma_ty, rhs.smem, rhs.desc.tma_descriptor
)
+ return [lhs, rhs]
+
+ def __iadd__(self, matmulResult):
+ lhs = matmulResult[0]
+ rhs = matmulResult[1]
return nvgpu.WarpgroupMmaOp(
- acc.type, wgmma_desc_lhs, wgmma_desc_rhs, acc, transposeB=True
+ self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
)
@@ -376,7 +398,7 @@ def __str__(self):
module.operation.verify()
# Save IR in a file
- # saveIR(module)
+ saveIR(module)
# Compile and JIT MLIR module
options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
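The `C += A @ B` sugar introduced in this patch relies on standard Python operator dispatch: `A @ B` calls `__matmul__`, and the augmented assignment then passes its result to `__iadd__`, whose return value is re-bound to `C`. A minimal plain-Python sketch of that dispatch order, with no MLIR involved and with made-up class names (`Acc`, `Operand`) used purely for illustration:

# `A @ B` returns the operand pair; `C += ...` folds it into a new accumulator.
class Acc:
    def __init__(self, value=0):
        self.value = value

    def __iadd__(self, matmul_result):
        lhs, rhs = matmul_result
        # The DSL emits a WarpgroupMmaOp at this point; the sketch just
        # combines the operands into a fresh accumulator object.
        return Acc(self.value + lhs * rhs)


class Operand:
    def __init__(self, value):
        self.value = value

    def __matmul__(self, rhs):
        # Return both operands so the accumulator's __iadd__ can consume them.
        return (self.value, rhs.value)


C = Acc()
A, B = Operand(3), Operand(4)
C += A @ B
print(C.value)  # 12

In the DSL itself, `__matmul__` builds the two wgmma descriptors and `__iadd__` is where the `nvgpu.WarpgroupMmaOp` is actually created.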
>From 5464e4dc228d7670334aa53cb6e3b2851739b4f7 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Wed, 3 Apr 2024 09:48:35 +0000
Subject: [PATCH 07/18] format, more simplification
---
mlir/test/Examples/nvgpu/Ch2.py | 14 ++++----------
mlir/test/Examples/nvgpu/Ch3.py | 23 +++++++++++------------
mlir/test/Examples/nvgpu/Ch4.py | 4 ++--
mlir/test/Examples/nvgpu/tools/nvdsl.py | 8 ++++----
4 files changed, 21 insertions(+), 28 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index e9369227af0e85..749e1a00585fd6 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -46,21 +46,15 @@ def saxpy_tma_kernel():
# 1. Create and initialize asynchronous transactional barrier (mbarrier)
mbar_group = Mbarriers(number_of_barriers=1)
- with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
- mbar_group[0].init(1)
- x_tma.prefetch()
- y_tma.prefetch()
- scf.yield_([])
+ mbar_group[0].init(1, predicate=isThread0)
x_smem = get_dynamic_shared_memory((M, N), T.f32())
y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 2)
# 2. Execute Tensor Memory Accelerator (TMA) Load
- with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
- x_tma.load(x_smem, mbar_group[0])
- y_tma.load(y_smem, mbar_group[0])
- mbar_group[0].arrive(txcount=M * N * 2 * 4)
- scf.yield_([])
+ x_tma.load(x_smem, mbar_group[0], predicate=isThread0)
+ y_tma.load(y_smem, mbar_group[0], predicate=isThread0)
+ mbar_group[0].arrive(txcount=M * N * 2 * 4, predicate=isThread0)
# 3. Wait for completion of TMA load with mbarrier
mbar_group[0].try_wait()
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 8c2864dd0d465d..6b1d36908f4365 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -73,7 +73,9 @@ def gemm_128_128_64(a, b, d):
b_tma = TMA([64, 64], b.type, swizzle=sw)
a_tma.create_descriptor(a_dev)
b_tma.create_descriptor(b_dev)
- smem_size_in_bytes = get_type_size(a.type) + get_type_size(b.type)
+ a_size = get_type_size(a.type)
+ b_size = get_type_size(b.type)
+ smem_size_in_bytes = a_size + b_size
@NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
def gemm_tma_kernel():
@@ -81,16 +83,13 @@ def gemm_tma_kernel():
mbar_group = Mbarriers(number_of_barriers=1)
isThread0 = tidx == 0
- with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
- mbar_group[0].init(1)
- a_tma.prefetch()
- b_tma.prefetch()
- scf.yield_([])
+
+ mbar_group[0].init(1, predicate=isThread0)
+ a_tma.prefetch(predicate=isThread0)
+ b_tma.prefetch(predicate=isThread0)
a_smem = get_dynamic_shared_memory((M, K), T.f16())
- b_smem = get_dynamic_shared_memory(
- (K, N), T.f16(), offset=get_type_size(a.type)
- )
+ b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)
# 1. Execute TMA Load for two input matrices
tma_load(mbar_group, a_tma, b_tma, isThread0)
@@ -99,9 +98,9 @@ def gemm_tma_kernel():
mbar_group[0].try_wait()
# 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
- A = WGMMAMatrix(WGMMAType.Descriptor, [M,K], desc=a_tma, smem=a_smem)
- B = WGMMAMatrix(WGMMAType.Descriptor, [K,N], desc=b_tma, smem=b_smem)
- C = WGMMAMatrix(WGMMAType.Accumulator, shape=[M,N], ty=T.f32())
+ A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
+ B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
+ C = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
# Matrix Multiply
C += A @ B
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 9555b367a0078b..1f23f423ce4092 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -155,7 +155,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
C = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
-
+
pp = const(False, ty=T.bool())
# Main Loop
@@ -185,7 +185,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
# Wait Tensor Core for single stage
if NUM_STAGES == 1:
nvvm.WgmmaWaitGroupSyncOp(0)
-
+
# Load next stage
pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
nextStage = iv + ns
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index cfa90cdf3c6539..b9754491d50e8b 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -184,7 +184,7 @@ def __init__(
self.desc = desc
self.smem = smem
if matrix_type is WGMMAType.Accumulator:
- self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
+ self.acc_op = nvgpu.warpgroup_mma_init_accumulator(self.acc_ty)
elif acc_op:
self.acc_op = acc_op
self.matrix_type = WGMMAType.Accumulator
@@ -200,10 +200,10 @@ def wgmma_ty(self):
return ir.Type.parse(parse_str)
def update_smem(self, smem):
- self.smem = smem
-
+ self.smem = smem
+
def update_accumulator(self, acc_op):
- self.acc_op = acc_op
+ self.acc_op = acc_op
def __matmul__(self, rhs):
lhs = nvgpu.warpgroup_generate_descriptor(
>From 80a0c14323d38b509df790c11731ca02e0f1ccac Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Wed, 3 Apr 2024 09:53:26 +0000
Subject: [PATCH 08/18] calculate shmem dynamically
---
mlir/test/Examples/nvgpu/Ch2.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index 749e1a00585fd6..347da9ab900184 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -37,8 +37,9 @@ def saxpy_tma(x, y, alpha):
y_tma = TMA((M, N), y.type)
x_tma.create_descriptor(x_dev)
y_tma.create_descriptor(y_dev)
+ smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
- @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=65536)
+ @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=smem_size_in_bytes)
def saxpy_tma_kernel():
bidx = gpu.block_id(gpu.Dimension.x)
tidx = gpu.thread_id(gpu.Dimension.x)
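As a sanity check of the size computed above, a plain-Python sketch with the M=256, N=32, f32 setup of Ch2:

# Two MxN f32 buffers (x and y) staged in dynamic shared memory.
M, N, BYTES_F32 = 256, 32, 4
smem_size_in_bytes = M * N * BYTES_F32 + M * N * BYTES_F32
print(smem_size_in_bytes)  # 65536

So the dynamically computed value coincides with the previously hard-coded 65536, but now tracks the actual buffer shapes.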
>From 22ddcb250a4d3cdd24f3fd38c71a58723860bcaf Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Thu, 4 Apr 2024 08:02:13 +0000
Subject: [PATCH 09/18] add `update_accumulator`
---
mlir/test/Examples/nvgpu/Ch2.py | 3 +--
mlir/test/Examples/nvgpu/Ch3.py | 8 +++----
mlir/test/Examples/nvgpu/Ch4.py | 30 ++++++++++++-------------
mlir/test/Examples/nvgpu/tools/nvdsl.py | 10 ++++++---
4 files changed, 27 insertions(+), 24 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index 347da9ab900184..f7817747fde70c 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -49,10 +49,9 @@ def saxpy_tma_kernel():
mbar_group = Mbarriers(number_of_barriers=1)
mbar_group[0].init(1, predicate=isThread0)
+ # 2. Execute Tensor Memory Accelerator (TMA) Load
x_smem = get_dynamic_shared_memory((M, N), T.f32())
y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 2)
-
- # 2. Execute Tensor Memory Accelerator (TMA) Load
x_tma.load(x_smem, mbar_group[0], predicate=isThread0)
y_tma.load(y_smem, mbar_group[0], predicate=isThread0)
mbar_group[0].arrive(txcount=M * N * 2 * 4, predicate=isThread0)
diff --git a/mlir/test/Examples/nvgpu/Ch3.py b/mlir/test/Examples/nvgpu/Ch3.py
index 6b1d36908f4365..9e91a6bf931ee1 100644
--- a/mlir/test/Examples/nvgpu/Ch3.py
+++ b/mlir/test/Examples/nvgpu/Ch3.py
@@ -91,7 +91,7 @@ def gemm_tma_kernel():
a_smem = get_dynamic_shared_memory((M, K), T.f16())
b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)
- # 1. Execute TMA Load for two input matrices
+ # 1. TMA Load for two input matrices
tma_load(mbar_group, a_tma, b_tma, isThread0)
# 2. All threads wait TMA load completion
@@ -100,13 +100,13 @@ def gemm_tma_kernel():
# 3. Performs Tensor Core GEMM 128x128x64 by warpgroup
A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
- C = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
+ D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())
# Matrix Multiply
- C += A @ B
+ D += A @ B
# 4. Stores fragmented registers to global memory by warpgroup
- nvgpu.warpgroup_mma_store(C, d_dev)
+ D.store_accumulator(d_dev)
gemm_tma_kernel()
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 1f23f423ce4092..ef0edadefa0094 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -154,19 +154,19 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
# Initialize A and B (input matrices) and C (accumulator)
A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
- C = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+ D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
- pp = const(False, ty=T.bool())
+ phase = const(False, ty=T.bool())
# Main Loop
- for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [C.acc_op, pp])
+ for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
with ir.InsertionPoint(for_op.body):
- pp = for_op.inner_iter_args[1]
+ phase = for_op.inner_iter_args[1]
iv = for_op.induction_variable
stage = iv % NUM_STAGES
# Wait for current stage
- mbar_group[stage].try_wait(phase=pp)
+ mbar_group[stage].try_wait(phase=phase)
# Find shared memory slot
offset_a = stage * size_a
@@ -177,10 +177,10 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
# Initialize matrices
A.update_smem(a_smem)
B.update_smem(b_smem)
- C.update_accumulator(for_op.inner_iter_args[0])
+ D.update_accumulator(for_op.inner_iter_args[0])
# Matrix Multiply
- C += A @ B
+ D += A @ B
# Wait Tensor Core for single stage
if NUM_STAGES == 1:
@@ -193,20 +193,20 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, pred)
# Switch phase parity for the mbarrier
- switched = pp ^ const(True, ty=T.bool())
- newPP = arith.select(
+ newPhase = arith.select(
stage == (NUM_STAGES - 1),
- switched,
- pp,
+ (phase ^ const(True, ty=T.bool())),
+ phase,
)
- scf.yield_([C, newPP])
+ scf.yield_([D.acc_op, newPhase])
nvvm.WgmmaWaitGroupSyncOp(0)
- return for_op.results[0]
+ D.update_accumulator(for_op.results[0])
+ return D
-def epilogue(D, d_dev):
+def epilogue(D: WGMMAMatrix, d_dev):
"""
Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
@@ -222,7 +222,7 @@ def epilogue(D, d_dev):
d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
# Store (registers -> shared memory)
- nvgpu.WarpgroupMmaStoreOp(D, d_smem)
+ D.store_accumulator(d_smem)
gpu.barrier()
# Store (shared memory --> global memory)
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index b9754491d50e8b..6904fe68b5a01d 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -150,12 +150,11 @@ def prefetch(self, predicate=None):
nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
- coord_ops = [const(c) for c in coords]
nvgpu.TmaAsyncLoadOp(
dest,
mbarrier.mbar_group_op,
self.tma_descriptor,
- coordinates=coord_ops,
+ coordinates=map(const, coords),
mbarId=mbarrier.id_op,
predicate=predicate,
)
@@ -199,6 +198,10 @@ def wgmma_ty(self):
parse_str = f"!nvgpu.warpgroup.descriptor<tensor=memref<{self.M}x{self.N}x{self.desc.memref_ty.element_type}, #gpu.address_space<workgroup>>>"
return ir.Type.parse(parse_str)
+ def store_accumulator(self, dest):
+ assert self.matrix_type == WGMMAType.Accumulator
+ nvgpu.warpgroup_mma_store(self.acc_op, dest)
+
def update_smem(self, smem):
self.smem = smem
@@ -217,9 +220,10 @@ def __matmul__(self, rhs):
def __iadd__(self, matmulResult):
lhs = matmulResult[0]
rhs = matmulResult[1]
- return nvgpu.WarpgroupMmaOp(
+ acc_op = nvgpu.WarpgroupMmaOp(
self.acc_op.type, lhs, rhs, self.acc_op, transposeB=True
)
+ return WGMMAMatrix(WGMMAType.Accumulator, acc_op=acc_op)
def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
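The pattern this patch moves Ch4 to is a loop-carried accumulator: the `scf.for` carries `D.acc_op` as an iter_arg, each iteration rebinds it via `update_accumulator` and yields the new value, and the final result is read back from `for_op.results[0]`. A plain-Python sketch of the same data flow (the helper name `mainloop` is made up for illustration):

def mainloop(tiles, acc=0.0):
    for a, b in tiles:        # scf.ForOp(..., [D.acc_op, phase])
        acc = acc + a * b     # D.update_accumulator(iter_arg); D += A @ B; yield
    return acc                # D.update_accumulator(for_op.results[0])

print(mainloop([(1.0, 2.0), (3.0, 4.0)]))  # 14.0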
>From 5c7288bf1cdf328cf8f7454e1d0fb1374105197c Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Thu, 4 Apr 2024 13:28:23 +0000
Subject: [PATCH 10/18] Fix saxpy tma, now it loads partial data to smem
---
mlir/test/Examples/nvgpu/Ch1.py | 4 +--
mlir/test/Examples/nvgpu/Ch2.py | 33 ++++++++++++++-----------
mlir/test/Examples/nvgpu/tools/nvdsl.py | 24 +++++++++++-------
3 files changed, 35 insertions(+), 26 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
index a888c61358a024..22e1eb7c92e005 100644
--- a/mlir/test/Examples/nvgpu/Ch1.py
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -56,12 +56,12 @@ def saxpy_kernel():
M = 256
N = 32
alpha = 2.0
-x = np.ones((M, N), np.float32)
+x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
-ref = np.ones((M, N), np.float32)
saxpy(x, y, alpha)
# 4. Verify MLIR with reference computation
+ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index f7817747fde70c..2e316cf406e99f 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -24,7 +24,7 @@
@NVDSL.mlir_func
-def saxpy_tma(x, y, alpha):
+def saxpy(x, y, alpha):
token_ty = ir.Type.parse("!gpu.async.token")
t1 = gpu.wait(token_ty, [])
x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
@@ -33,13 +33,15 @@ def saxpy_tma(x, y, alpha):
t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
t6 = gpu.wait(token_ty, [t5])
- x_tma = TMA((M, N), x.type)
- y_tma = TMA((M, N), y.type)
+ x_tma = TMA([1, N], x.type)
+ y_tma = TMA([1, N], y.type)
x_tma.create_descriptor(x_dev)
y_tma.create_descriptor(y_dev)
- smem_size_in_bytes = get_type_size(x.type) + get_type_size(y.type)
+ sz_x = get_type_size(x_tma.tma_memref)
+ sz_y = get_type_size(y_tma.tma_memref)
+ sz = sz_x + sz_y
- @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=smem_size_in_bytes)
+ @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
def saxpy_tma_kernel():
bidx = gpu.block_id(gpu.Dimension.x)
tidx = gpu.thread_id(gpu.Dimension.x)
@@ -50,17 +52,17 @@ def saxpy_tma_kernel():
mbar_group[0].init(1, predicate=isThread0)
# 2. Execute Tensor Memory Accelerator (TMA) Load
- x_smem = get_dynamic_shared_memory((M, N), T.f32())
- y_smem = get_dynamic_shared_memory((M, N), T.f32(), offset=M * N * 2)
- x_tma.load(x_smem, mbar_group[0], predicate=isThread0)
- y_tma.load(y_smem, mbar_group[0], predicate=isThread0)
- mbar_group[0].arrive(txcount=M * N * 2 * 4, predicate=isThread0)
+ x_smem = get_dynamic_shared_memory([1, N], T.f32())
+ y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
+ x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
+ y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
+ mbar_group[0].arrive(txcount=sz, predicate=isThread0)
# 3. Wait for completion of TMA load with mbarrier
mbar_group[0].try_wait()
- x_val = memref.load(x_smem, [bidx, tidx])
- y_val = memref.load(y_smem, [bidx, tidx])
+ x_val = memref.load(x_smem, [const(0), tidx])
+ y_val = memref.load(y_smem, [const(0), tidx])
# SAXPY: y[i] += a * x[i];
y_val += x_val * alpha
@@ -73,15 +75,16 @@ def saxpy_tma_kernel():
gpu.wait(token_ty, [t7])
+# 3. Pass numpy arrays to MLIR
M = 256
N = 32
alpha = 2.0
-x = np.ones((M, N), np.float32)
+x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
-ref = np.ones((M, N), np.float32)
-saxpy_tma(x, y, alpha)
+saxpy(x, y, alpha)
# 4. Verify MLIR with reference computation
+ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 6904fe68b5a01d..28105911db8134 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -8,7 +8,6 @@
from mlir import runtime as rt
from tools import nvgpucompiler
-DEBUG = True
MLIR_DYNAMIC = -9223372036854775808
@@ -124,13 +123,20 @@ def __init__(
@property
def tensormap_descriptor_ty(self):
"""Returns a tensormap descriptor type."""
- memref_str = f"memref<{self.tma_shape[0]}x{self.tma_shape[1]}x{self.memref_ty.element_type}, 3>"
- parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
- swizzle = {self.swizzle},\
- l2promo = {self.l2promo},\
- oob = {self.oob},\
- interleave = {self.interleave}>"
-
+ memref_str = (
+ "memref<"
+ + "x".join(map(str, self.tma_shape))
+ + "x"
+ + str(self.memref_ty.element_type)
+ + ", 3>"
+ )
+ parse_str = (
+ f"!nvgpu.tensormap.descriptor<tensor = {memref_str}, "
+ f"swizzle = {self.swizzle}, "
+ f"l2promo = {self.l2promo}, "
+ f"oob = {self.oob}, "
+ f"interleave = {self.interleave}>"
+ )
return ir.Type.parse(parse_str)
def create_descriptor(self, device_ptr):
@@ -149,7 +155,7 @@ def create_descriptor(self, device_ptr):
def prefetch(self, predicate=None):
nvgpu.tma_prefetch_descriptor(self.tma_descriptor, predicate=predicate)
- def load(self, dest, mbarrier: Mbarriers, coords=[0, 0], predicate=None):
+ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
nvgpu.TmaAsyncLoadOp(
dest,
mbarrier.mbar_group_op,
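The net effect of this patch on Ch2 is that each thread block now stages only a [1, N] row of x and y through shared memory, at TMA coordinate [0, bidx], instead of the whole MxN buffers. A NumPy-only sketch of that per-block data flow (not the MLIR path, just the indexing):

import numpy as np

M, N, alpha = 256, 32, 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
ref = y + x * alpha

for bidx in range(M):                    # grid = (M, 1, 1)
    x_smem = x[bidx : bidx + 1].copy()   # TMA load of a [1, N] tile of x
    y_smem = y[bidx : bidx + 1].copy()   # TMA load of a [1, N] tile of y
    for tidx in range(N):                # block = (N, 1, 1)
        y[bidx, tidx] = y_smem[0, tidx] + x_smem[0, tidx] * alpha

np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")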
>From ef6c05deeabf044f35d68b26f71dc9f6cf54729c Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sat, 6 Apr 2024 09:43:09 +0000
Subject: [PATCH 11/18] better comment
---
mlir/test/Examples/nvgpu/Ch4.py | 23 +++++++++++++++++++++--
1 file changed, 21 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index ef0edadefa0094..47a81cd9a65688 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -5,7 +5,26 @@
# Chapter 4 : Multistage GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
-# This program demonstrates a GEMM operation with 64x64x64 matrix multiplication
+# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# Multistage method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches one Warp
+# Group (128 threads in total) and allocates one shared memory slot per stage.
+# The program consists of three main parts: prologue, mainloop, and epilogue.
+# In the prologue, thread 0 issues TMA requests to load data into the shared
+# memory slots. The mainloop executes MMA while simultaneously issuing TMA
+# loads into the slots that have already been consumed. This overlap of TMA and
+# MMA operations improves performance by maximizing computational throughput.
+#
+# Loops illustration:
+#
+# for s in range(NUM_STAGES):
+# TMA_128x64_64x128...
+# for ti in range(M//128): # -> blockIdx.x
+# for tj in range(N//128): # -> blockIdx.y
+# for tk in range(K//64):
+# MMA_128x128x64...
+# TMA_128x64_64x128...
+# Epilogue...
#
# This chapter demonstrates:
# 1. Partition shape based on block IDs
@@ -174,7 +193,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
- # Initialize matrices
+ # Iterate input matrices, update accumulator
A.update_smem(a_smem)
B.update_smem(b_smem)
D.update_accumulator(for_op.inner_iter_args[0])
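The loop illustration added above can be made concrete with a few lines of plain Python showing how stages map to shared memory slots: the prologue fills NUM_STAGES-1 slots, and each mainloop iteration waits on one slot, runs the MMA, and prefetches the stage that is NUM_STAGES-1 steps ahead into the slot it is about to free. A sketch with illustrative values:

NUM_STAGES = 7
K, TILE_K = 1024, 64
num_iters = K // TILE_K
ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1

for iv in range(min(num_iters, 3)):      # show only the first few iterations
    slot = iv % NUM_STAGES               # slot whose TMA load we wait on, then MMA
    next_stage = iv + ns                 # stage prefetched while the MMA runs
    next_slot = next_stage % NUM_STAGES if next_stage < num_iters else None
    print(f"iter {iv}: compute slot {slot}, prefetch into slot {next_slot}")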
>From 78b372aa53fba93756046543b2321eefb7453b82 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sat, 6 Apr 2024 09:43:33 +0000
Subject: [PATCH 12/18] Add Ch5.py Warp Specialized Kernel
---
mlir/test/Examples/nvgpu/Ch5.py | 320 ++++++++++++++++++++++++
mlir/test/Examples/nvgpu/tools/nvdsl.py | 35 ++-
2 files changed, 352 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Examples/nvgpu/Ch5.py
diff --git a/mlir/test/Examples/nvgpu/Ch5.py b/mlir/test/Examples/nvgpu/Ch5.py
new file mode 100644
index 00000000000000..7881af2ab54439
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/Ch5.py
@@ -0,0 +1,320 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+# Chapter 5 : Warp Specialized GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# Warp Specialized method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches two Warp
+# Groups (256 threads in total): one for the producer and the other for the consumer.
+# Each group follows a different control flow. The producer warpgroup is responsible
+# for loading data into shared memory, while the consumer group executes the Tensor
+# Core GEMM operation and epilogue.
+#
+# for ti in range(M//128): # -> blockIdx.x
+# for tj in range(N//128): # -> blockIdx.y
+# with wg_producer:
+# for tk in range(K//64):
+# TMA_128x64_64x128...
+# with wg_consumer:
+# for tk in range(K//64):
+# MMA_128x128x64...
+# Epilogue..
+#
+# This chapter demonstrates:
+# 2 WG (warpgroups)
+# Producer:
+# 2.1.1 Wait MMA Barrier
+# 2.1.2 Load TMA with TMA barrier
+# 2.1.3 Arrive TMA barrier with txcount
+# Consumer:
+# Loop
+# Wait TMA barrier
+# Performs Tensor Core GEMM 64x128x64 by warpgroup
+# Arrive MMA Barrier
+# Epilogue
+# Store fragmented registers to shared memory
+# Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+PRODUCER_PRIMARY_THREAD = 128 # Primary thread of the producer warpgroup
+CONSUMER_PRIMARY_THREAD = 0 # Primary thread of the consumer warpgroup
+PRODUCER_REGISTER_SIZE = 40 # Registers per producer thread
+CONSUMER_REGISTER_SIZE = 232 # Registers per consumer thread
+
+
+def partition_shape():
+ """
+ Calculate the partition shape based on the block IDs.
+
+ It partitions the shape like below:
+ for(.. i < M ...) --> blockIdx.x
+ for(.. j < N ...) --> blockIdx.y
+ for(.. k < K ...)
+
+ Returns:
+ dimX (int): Dimension along the x-axis.
+ dimY (int): Dimension along the y-axis.
+ """
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = bidx * TILE_M
+ dimY = bidy * TILE_N
+ return dimX, dimY
+
+
+def tma_load(
+ mbar_group: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ slot,
+ stage,
+ p=None,
+):
+ """
+ TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+ - tma.load a_shared_memory[off_x] at coordinate [x, z] (Loads 128x64)
+ - tma.load b_shared_memory[off_y1] at coordinate [y, x] (Loads 64x64)
+ - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
+
+ mbarrier.arrive ta_count = 128x64xf16 + 2 * 64x64xf16
+ """
+ dimX, dimY = partition_shape()
+
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+ size_tma_a = get_type_size(a_tma.tma_memref)
+ size_tma_b = get_type_size(b_tma.tma_memref)
+ ta_count = size_tma_a + (size_tma_b * 2)
+
+ off_a = slot * size_tma_a
+ off_b = (slot * size_tma_a) + begin_b
+ off_b2 = off_b + size_tma_b
+ a_elem_ty = a_tma.tma_memref.element_type
+ b_elem_ty = b_tma.tma_memref.element_type
+ a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+ b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+ b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+ mbar_group[slot].arrive(ta_count, predicate=p)
+ p = (tidx % WARP_GROUP_SIZE) == 0
+ c1 = stage * 64
+ a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+ b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+ b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def bootstrap(a_tma: TMA, b_tma: TMA):
+ """
+ Initialize mbarriers and prefetch TMA descriptors.
+ """
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ mbar_group_tma = Mbarriers(number_of_barriers=NUM_STAGES)
+ mbar_group_mma = Mbarriers(number_of_barriers=NUM_STAGES)
+ isThread0 = tidx == const(0)
+ with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+ for i in scf.for_(0, NUM_STAGES, 1):
+ mbar_group_tma[i].init(1)
+ mbar_group_mma[i].init(1)
+ scf.yield_([])
+ a_tma.prefetch()
+ b_tma.prefetch()
+ scf.yield_([])
+
+ return mbar_group_tma, mbar_group_mma
+
+
+def switch_phase(stage, phase):
+ p = stage == (NUM_STAGES - 1)
+ phase = arith.select(
+ p,
+ (phase ^ const(True, ty=T.bool())),
+ phase,
+ )
+ return phase
+
+
+def producer_loop(
+ mbar_group_tma: Mbarriers,
+ mbar_group_mma: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ wg_me: Warpgroup,
+):
+ phase = const(True, ty=T.bool())
+
+ for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
+ stage = iv % NUM_STAGES
+ # Wait MMA to be done
+ mbar_group_mma[stage].try_wait(phase)
+ # New phase for mbarrier
+ phase = switch_phase(stage, phase)
+ # TMA Load
+ tma_load(mbar_group_tma, a_tma, b_tma, stage, iv, wg_me.is_wg_primary)
+ scf.yield_([phase])
+
+
+def consumer_loop(
+ mbar_group_tma: Mbarriers,
+ mbar_group_mma: Mbarriers,
+ a_tma: TMA,
+ b_tma: TMA,
+ wg_me: Warpgroup,
+):
+ begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+
+ size_a = TILE_M * TILE_K * get_type_size(T.f16())
+
+ phase = const(False, ty=T.bool())
+ A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+ B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+ D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+
+ for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
+ with ir.InsertionPoint(for_op.body):
+ phase = for_op.inner_iter_args[1]
+ iv = for_op.induction_variable
+ stage = iv % NUM_STAGES
+
+ # Wait TMA for current stage
+ mbar_group_tma[stage].try_wait(phase)
+
+ # Find shared memory slot
+ offset_a = stage * size_a
+ offset_b = offset_a + begin_b
+ a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+ b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
+
+ # Iterate input matrices, update accumulator
+ A.update_smem(a_smem)
+ B.update_smem(b_smem)
+ D.update_accumulator(for_op.inner_iter_args[0])
+
+ # Matrix Multiply
+ D += A @ B
+
+ # MMA Barrier Arrive
+ p_arrive = (iv > 0) & wg_me.is_wg_primary
+ with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+ barId = arith.select((stage == 0), const(NUM_STAGES - 1), (stage - 1))
+ mbar_group_mma[barId].arrive()
+ scf.yield_([])
+
+ phase = switch_phase(stage, phase)
+ scf.yield_([D.acc_op, phase])
+
+ nvvm.WgmmaWaitGroupSyncOp(0)
+ D.update_accumulator(for_op.results[0])
+ return D
+
+
+def epilogue(D: WGMMAMatrix, d_dev):
+ """
+ Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
+
+ MatrixAccumulator D # Fragmented results
+ store D -> Shared Memory # Store Shared Memory
+ Shared Memory -> Z[dimX][dimY] # Store Shared Memory to Global Memory
+
+ """
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ dimX, dimY = partition_shape()
+ # s = tidx - WARP_GROUP_SIZE
+ # debug_print("[Epilogue] store to global memory @ s={}", s)
+
+ d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+ d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+ # Store (registers -> shared memory)
+ D.store_accumulator(d_smem)
+ gpu.barrier()
+
+ # Store (shared memory --> global memory)
+ for i in scf.for_(0, TILE_M, 1):
+ val = memref.load(d_smem, [i, tidx])
+ memref.store(val, d_gmem, [i, tidx])
+ scf.yield_([])
+
+
+ at NVDSL.mlir_func
+def gemm_warp_specialized(a, b, d):
+ token_ty = ir.Type.parse("!gpu.async.token")
+ t1 = gpu.wait(token_ty, [])
+ a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
+ b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
+ d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
+ t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
+ t7 = gpu.wait(token_ty, [t6])
+
+ sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+ a_tma = TMA([128, 64], a.type, swizzle=sw)
+ b_tma = TMA([64, 64], b.type, swizzle=sw)
+ a_tma.create_descriptor(a_dev)
+ b_tma.create_descriptor(b_dev)
+
+ grid = [(M // TILE_M), (N // TILE_N), 1]
+ block = [256, 1, 1]
+
+ size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
+ size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
+ smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+
+ @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
+ def gemm_warp_specialized_kernel():
+ # Init Warpgroups
+ wg_producer = Warpgroup(PRODUCER_PRIMARY_THREAD, PRODUCER_REGISTER_SIZE)
+ wg_consumer = Warpgroup(CONSUMER_PRIMARY_THREAD, CONSUMER_REGISTER_SIZE)
+
+ # Initialize mbarriers and prefetch TMA descriptors
+ mbar_group_tma, mbar_group_mma = bootstrap(a_tma, b_tma)
+
+ # Producer performs TMA
+ with wg_producer:
+ producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer)
+
+ # Producer performs MMA/Tensor Core
+ with wg_consumer:
+ D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer)
+ epilogue(D, d_dev)
+
+ gemm_warp_specialized_kernel()
+
+ t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
+ gpu.wait(None, [t8])
+
+
+# Python pass arguments to MLIR
+NUM_STAGES = 7
+N = 256
+M = 512
+K = 1024
+TILE_M = 128
+TILE_N = 128
+TILE_K = 64
+a = np.random.randn(M, K).astype(np.float16)
+b = np.random.randn(K, N).astype(np.float16)
+d = np.zeros((M, N), np.float32)
+
+gemm_warp_specialized(a, b, d)
+
+
+# Verify MLIR with reference computation
+ref_d = a.astype(np.float16) @ b.astype(np.float16)
+np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+
+
+print("PASS")
+# CHECK-NOT: Mismatched elements
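One detail of Ch5.py above worth spelling out is the mbarrier phase parity handled by `switch_phase`: the parity bit a warpgroup waits on flips every time the stage index wraps past the last shared memory slot. A plain-Python sketch of that bookkeeping:

NUM_STAGES = 7

def switch_phase(stage, phase):
    return (not phase) if stage == (NUM_STAGES - 1) else phase

phase = False
for iv in range(2 * NUM_STAGES):
    stage = iv % NUM_STAGES
    # ... try_wait(mbar[stage], phase) and do the work for this stage ...
    phase = switch_phase(stage, phase)

print("parity after two full sweeps:", phase)  # flipped twice -> False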
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 28105911db8134..7d9e13c7f8627c 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -3,7 +3,7 @@
import numpy as np
from functools import partialmethod
from mlir import ir
-from mlir.dialects import arith, func, gpu, memref, nvgpu
+from mlir.dialects import arith, func, gpu, memref, nvgpu, scf, nvvm
from mlir.extras import types as T
from mlir import runtime as rt
from tools import nvgpucompiler
@@ -84,7 +84,9 @@ def arrive(self, txcount: int = 0, predicate=None):
self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
)
else:
- nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op, predicate=predicate)
+ nvgpu.mbarrier_arrive(
+ ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
+ )
def try_wait(self, phase: bool = False, ticks: int = 10000000):
ticks_op = const(ticks)
@@ -166,6 +168,33 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
)
+WARP_GROUP_SIZE = 128 # Number of threads in a warpgroup
+
+
+class Warpgroup:
+ def __init__(self, primaryThread, registerSize):
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ self.primary_thread = primaryThread
+ self.register_size = registerSize
+ self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
+ self.wg_id = tidx / WARP_GROUP_SIZE
+ self.is_me = self.wg_id == (primaryThread // WARP_GROUP_SIZE)
+
+ def __enter__(self):
+ if_op = scf.IfOp(self.is_me)
+ self.ipoint_op = ir.InsertionPoint(if_op.then_block)
+ self.ipoint_op.__enter__()
+ if self.register_size < 64:
+ nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.decrease)
+ else:
+ nvvm.setmaxregister(self.register_size, nvvm.SetMaxRegisterAction.increase)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ scf.yield_([])
+ self.ipoint_op.__exit__(exc_type, exc_value, traceback)
+ return True
+
+
class WGMMAType(Enum):
Accumulator = 1
Descriptor = 2
@@ -408,7 +437,7 @@ def __str__(self):
module.operation.verify()
# Save IR in a file
- saveIR(module)
+ # saveIR(module)
# Compile and JIT MLIR module
options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
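The `Warpgroup` context manager added to nvdsl.py in this patch wraps the body of each `with` block in an `scf.if`, so only the warpgroup that contains the given primary thread executes it. A plain-Python sketch of just the predicate (the helper `runs_body` is made up for illustration):

WARP_GROUP_SIZE = 128

def runs_body(tidx, primary_thread):
    wg_id = tidx // WARP_GROUP_SIZE
    return wg_id == primary_thread // WARP_GROUP_SIZE

for tidx in (0, 64, 128, 200):
    print(
        f"tidx={tidx}: producer body runs={runs_body(tidx, 128)}, "
        f"consumer body runs={runs_body(tidx, 0)}"
    )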
>From a43378cf888da1787b564b2dcc2c3ec05a4c821e Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 12 Apr 2024 17:44:13 +0000
Subject: [PATCH 13/18] fix comments
---
mlir/test/Examples/nvgpu/Ch0.py | 16 +++++++++++-----
mlir/test/Examples/nvgpu/Ch1.py | 2 +-
mlir/test/Examples/nvgpu/Ch5.py | 6 +++---
mlir/test/Examples/nvgpu/tools/nvdsl.py | 8 ++++----
4 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
index 221ca43d37307a..0c3cc30fde85e4 100644
--- a/mlir/test/Examples/nvgpu/Ch0.py
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -7,7 +7,7 @@
#
# This program demonstrates Hello World
#
-# This chapter introduces demonstrates:
+# This chapter demonstrates:
# 1. Build MLIR function with arguments
# 2. Build MLIR GPU kernel
# 3. Print from a GPU thread
@@ -20,13 +20,19 @@
from tools.nvdsl import *
-# 1. Build function with arguments
+# 1. The `mlir_func` decorator generates an MLIR `func.func`.
+# Everything inside the Python function becomes the body of the `func`.
+# The decorator also translates `alpha` to an `index` type.
@NVDSL.mlir_func
def main(alpha):
- # 2. Build GPU kernel
+# 2. The `mlir_gpu_launch` decorator generates an MLIR `gpu.launch`.
+# Everything inside the Python function becomes the body of the `gpu.launch`.
+# This allows for late outlining of the GPU kernel, enabling optimizations
+# like constant folding from host to device.
@NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
def kernel():
tidx = gpu.thread_id(gpu.Dimension.x)
+ # The `+` operator generates an `arith.addi`
myValue = alpha + tidx
# Print from a GPU thread
gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
@@ -34,12 +40,12 @@ def kernel():
# 3. Call the GPU kernel
kernel()
-
-# 4. Pass arguments, JIT compile and run the MLIR function
alpha = 100
+# 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
main(alpha)
+
# CHECK: GPU thread 0 has 100
# CHECK: GPU thread 1 has 101
# CHECK: GPU thread 2 has 102
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
index 22e1eb7c92e005..315480138f66c8 100644
--- a/mlir/test/Examples/nvgpu/Ch1.py
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -7,7 +7,7 @@
#
# This program demonstrates 2D Saxpy
#
-# This chapter introduces demonstrates:
+# This chapter demonstrates:
# 1. Use MLIR GPU dialect to allocate and copy memory
# 2. Compute 2D SAXPY kernel
# 3. Pass numpy arrays to MLIR
diff --git a/mlir/test/Examples/nvgpu/Ch5.py b/mlir/test/Examples/nvgpu/Ch5.py
index 7881af2ab54439..a22de802452792 100644
--- a/mlir/test/Examples/nvgpu/Ch5.py
+++ b/mlir/test/Examples/nvgpu/Ch5.py
@@ -275,8 +275,8 @@ def gemm_warp_specialized(a, b, d):
@NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
def gemm_warp_specialized_kernel():
# Init Warpgroups
- wg_producer = Warpgroup(PRODUCER_PRIMARY_THREAD, PRODUCER_REGISTER_SIZE)
- wg_consumer = Warpgroup(CONSUMER_PRIMARY_THREAD, CONSUMER_REGISTER_SIZE)
+ wg_producer = Warpgroup(primary_thread=128, register_size=40)
+ wg_consumer = Warpgroup(primary_thread=0, register_size=232)
# Initialize mbarriers and prefetch TMA descriptors
mbar_group_mma, mbar_group_tma = bootstrap(a_tma, b_tma)
@@ -285,7 +285,7 @@ def gemm_warp_specialized_kernel():
with wg_producer:
producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer)
- # Producer performs MMA/Tensor Core
+ # Consumer performs MMA/Tensor Core
with wg_consumer:
D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer)
epilogue(D, d_dev)
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 7d9e13c7f8627c..2c1615d6f9d68b 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -172,13 +172,13 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
class Warpgroup:
- def __init__(self, primaryThread, registerSize):
+ def __init__(self, primary_thread, register_size):
tidx = gpu.thread_id(gpu.Dimension.x)
- self.primary_thread = primaryThread
- self.register_size = registerSize
+ self.primary_thread = primary_thread
+ self.register_size = register_size
self.is_wg_primary = (tidx % WARP_GROUP_SIZE) == 0
self.wg_id = tidx / WARP_GROUP_SIZE
- self.is_me = self.wg_id == (primaryThread // WARP_GROUP_SIZE)
+ self.is_me = self.wg_id == (primary_thread // WARP_GROUP_SIZE)
def __enter__(self):
if_op = scf.IfOp(self.is_me)
>From 4c56628420aa6714ff49f5bf26f547b27a1c5c8d Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 12 Apr 2024 18:19:44 +0000
Subject: [PATCH 14/18] address comments
---
mlir/test/Examples/nvgpu/Ch0.py | 4 +-
mlir/test/Examples/nvgpu/Ch1.py | 12 +-
mlir/test/Examples/nvgpu/Ch2.py | 8 +-
mlir/test/Examples/nvgpu/Ch4.py | 51 ++++---
mlir/test/Examples/nvgpu/Ch5.py | 65 +++++----
mlir/test/Examples/nvgpu/lit.local.cfg | 3 +-
mlir/test/Examples/nvgpu/nvdsl.mlir | 186 ++++++++++++++++++++++++
mlir/test/Examples/nvgpu/tools/nvdsl.py | 10 +-
8 files changed, 264 insertions(+), 75 deletions(-)
create mode 100644 mlir/test/Examples/nvgpu/nvdsl.mlir
diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
index 0c3cc30fde85e4..ff294c7ec416ef 100644
--- a/mlir/test/Examples/nvgpu/Ch0.py
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -5,9 +5,7 @@
# Chapter 0 : Hello World
# ===----------------------------------------------------------------------===//
#
-# This program demonstrates Hello World
-#
-# This chapter demonstrates:
+# This program demonstrates Hello World:
# 1. Build MLIR function with arguments
# 2. Build MLIR GPU kernel
# 3. Print from a GPU thread
diff --git a/mlir/test/Examples/nvgpu/Ch1.py b/mlir/test/Examples/nvgpu/Ch1.py
index 315480138f66c8..da65aa2ef6a172 100644
--- a/mlir/test/Examples/nvgpu/Ch1.py
+++ b/mlir/test/Examples/nvgpu/Ch1.py
@@ -5,13 +5,11 @@
# Chapter 1 : 2D Saxpy
# ===----------------------------------------------------------------------===//
#
-# This program demonstrates 2D Saxpy
-#
-# This chapter demonstrates:
-# 1. Use MLIR GPU dialect to allocate and copy memory
-# 2. Compute 2D SAXPY kernel
-# 3. Pass numpy arrays to MLIR
-# 4. Verify MLIR with reference computation
+# This program demonstrates 2D Saxpy:
+# 1. Use the GPU dialect to allocate and copy memory between host and GPU
+# 2. Compute the 2D SAXPY kernel using operator overloading
+# 3. Pass numpy arrays to MLIR as memref arguments
+# 4. Verify the MLIR program against a reference computation in Python
#
# ===----------------------------------------------------------------------===//
diff --git a/mlir/test/Examples/nvgpu/Ch2.py b/mlir/test/Examples/nvgpu/Ch2.py
index 2e316cf406e99f..78c14cb2c7ad8c 100644
--- a/mlir/test/Examples/nvgpu/Ch2.py
+++ b/mlir/test/Examples/nvgpu/Ch2.py
@@ -9,9 +9,11 @@
# but it loads data using TMA (Tensor Memory Accelerator)
#
# This chapter demonstrates:
-# 1. Create and initialize asynchronous transactional barrier (mbarrier)
-# 2. Execute Tensor Memory Accelerator (TMA) Load
-# 3. Wait for completion of TMA load with mbarrier
+# 1. Compute 2D SAXPY in the same way as Ch1.py, but load data using TMA
+# 2. Create and initialize a single asynchronous transactional barrier (mbarrier)
+# 3. Thread 0 requests a TMA data load for each thread block
+# 4. Each thread block loads a <1x32xf32> tile of x and y
+# 5. Wait for completion of the TMA load with the mbarrier
#
# ===----------------------------------------------------------------------===//
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 47a81cd9a65688..175d0aad5f77d7 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -17,7 +17,7 @@
#
# Loops illustration:
#
-# for s in range(NUM_STAGES):
+# for s in range(num_stages):
# TMA_128x64_64x128...
# for ti in range(M//128): # -> blockIdx.x
# for tj in range(N//128): # -> blockIdx.y
@@ -74,6 +74,7 @@ def tma_load(
b_tma: TMA,
slot,
stage,
+ num_stages,
p=None,
):
"""
@@ -88,7 +89,7 @@ def tma_load(
dimX, dimY = partition_shape()
tidx = gpu.thread_id(gpu.Dimension.x)
- begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+ begin_b = num_stages * get_type_size(a_tma.tma_memref)
size_tma_a = get_type_size(a_tma.tma_memref)
size_tma_b = get_type_size(b_tma.tma_memref)
ta_count = size_tma_a + (size_tma_b * 2)
@@ -113,15 +114,15 @@ def tma_load(
b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
-def bootstrap(a_tma: TMA, b_tma: TMA):
+def initialize(a_tma: TMA, b_tma: TMA, num_stages):
"""
Initialize mbarriers and prefetch TMA descriptors.
"""
tidx = gpu.thread_id(gpu.Dimension.x)
- mbar_group = Mbarriers(number_of_barriers=NUM_STAGES)
+ mbar_group = Mbarriers(number_of_barriers=num_stages)
isThread0 = tidx == const(0)
with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
- for i in scf.for_(0, NUM_STAGES, 1):
+ for i in scf.for_(0, num_stages, 1):
mbar_group[i].init(1)
scf.yield_([])
a_tma.prefetch()
@@ -131,7 +132,7 @@ def bootstrap(a_tma: TMA, b_tma: TMA):
return mbar_group
-def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
+def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
"""
    Prologue of the GEMM kernel. It loads 2 input matrices for each stage in a loop like below:
@@ -139,13 +140,13 @@ def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
tma_load x, y, stage
"""
- ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+ ns = num_stages if num_stages == 1 else num_stages - 1
for iv in scf.for_(0, ns, 1):
- tma_load(mbar_group, a_tma, b_tma, iv, iv)
+ tma_load(mbar_group, a_tma, b_tma, iv, iv, num_stages)
scf.yield_([])
-def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
+def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
"""
Main loop of the Multistage GEMM kernel. It iterates through
    stages and performs matrix multiplication, loading data via TMA into shared memory. It looks like the following:
@@ -163,10 +164,10 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
tma_load(x, y, nextSlot, nextStage)
"""
- ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+ ns = num_stages if num_stages == 1 else num_stages - 1
tidx = gpu.thread_id(gpu.Dimension.x)
- begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+ begin_b = num_stages * get_type_size(a_tma.tma_memref)
size_a = TILE_M * TILE_K * get_type_size(T.f16())
@@ -182,7 +183,7 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
with ir.InsertionPoint(for_op.body):
phase = for_op.inner_iter_args[1]
iv = for_op.induction_variable
- stage = iv % NUM_STAGES
+ stage = iv % num_stages
# Wait for current stage
mbar_group[stage].try_wait(phase=phase)
@@ -202,18 +203,18 @@ def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
D += A @ B
# Wait Tensor Core for single stage
- if NUM_STAGES == 1:
+ if num_stages == 1:
nvvm.WgmmaWaitGroupSyncOp(0)
# Load next stage
pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
nextStage = iv + ns
- nextSlot = nextStage % NUM_STAGES
- tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, pred)
+ nextSlot = nextStage % num_stages
+ tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)
# Switch phase parity for the mbarrier
newPhase = arith.select(
- stage == (NUM_STAGES - 1),
+ stage == (num_stages - 1),
(phase ^ const(True, ty=T.bool())),
phase,
)
@@ -250,9 +251,12 @@ def epilogue(D: WGMMAMatrix, d_dev):
memref.store(val, d_gmem, [i, tidx])
scf.yield_([])
-
+# The decorator generates
+# a -> memref<MxKxf16>
+# b -> memref<NxKxf16>
+# d -> memref<MxNxf32>
@NVDSL.mlir_func
-def gemm_multistage(a, b, d):
+def gemm_multistage(a, b, d, num_stages):
token_ty = ir.Type.parse("!gpu.async.token")
t1 = gpu.wait(token_ty, [])
a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
@@ -273,18 +277,18 @@ def gemm_multistage(a, b, d):
size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
- smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+ smem_size_in_bytes = (size_a + size_b) * num_stages
@NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
def gemm_multistage_kernel():
# Initialize mbarriers and prefetch TMA descriptors
- mbar_group = bootstrap(a_tma, b_tma)
+ mbar_group = initialize(a_tma, b_tma, num_stages)
# Fill the pipeline stages
- prologue(mbar_group, a_tma, b_tma)
+ prologue(mbar_group, a_tma, b_tma, num_stages)
# Main loop
- D = mainloop(mbar_group, a_tma, b_tma)
+ D = mainloop(mbar_group, a_tma, b_tma, num_stages)
# Store registers to global memory
epilogue(D, d_dev)
@@ -296,7 +300,6 @@ def gemm_multistage_kernel():
# Pass arguments from Python to MLIR
-NUM_STAGES = 7
N = 256
M = 512
K = 1024
@@ -307,7 +310,7 @@ def gemm_multistage_kernel():
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
-gemm_multistage(a, b, d)
+gemm_multistage(a, b, d, num_stages=7)
# Verify MLIR with reference computation
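The index arithmetic that prologue and mainloop share is easy to lose among the IR-building code, so here is the same pipeline traced in plain Python. It reuses only the arithmetic shown in the hunks above (ns, stage, nextSlot, the load predicate, and the parity flip); the list that stands in for the in-flight TMA loads is illustrative.

num_stages, K, TILE_K = 7, 1024, 64
ns = num_stages if num_stages == 1 else num_stages - 1

# Prologue: fill the first `ns` shared-memory slots with TMA loads.
issued = [(stage, stage) for stage in range(ns)]  # (slot, stage) pairs

phase = False  # initial parity; the kernel threads this through iter_args
for iv in range(K // TILE_K):
    stage = iv % num_stages        # slot whose data this iteration consumes
    # mbar_group[stage].try_wait(phase=phase) blocks here in the real kernel
    # until the TMA load issued `ns` iterations earlier has landed.
    if (iv + ns) < (K // TILE_K):  # predicate guarding the next TMA load
        next_stage = iv + ns
        issued.append((next_stage % num_stages, next_stage))
    if stage == num_stages - 1:    # flip parity after the last ring-buffer slot
        phase = not phase

assert len(issued) == K // TILE_K  # every K-tile is loaded exactly once
print(f"{len(issued)} TMA loads for {K // TILE_K} mainloop iterations")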
diff --git a/mlir/test/Examples/nvgpu/Ch5.py b/mlir/test/Examples/nvgpu/Ch5.py
index a22de802452792..daffb38ba0db7e 100644
--- a/mlir/test/Examples/nvgpu/Ch5.py
+++ b/mlir/test/Examples/nvgpu/Ch5.py
@@ -5,7 +5,7 @@
# Chapter 5 : Warp Specialized GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
-# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# This program demonstrates a GEMM operation for `f32+=f16*f16`, utilizing the
# Warp Specialized method with a tile size of 128x128x64. The code completely
# parallelizes the two outermost loops into thread blocks. It launches two Warp
# Groups (256 threads in total): one for the producer and the other for the consumer.
@@ -48,20 +48,19 @@
import numpy as np
-PRODUCER_PRIMARY_THREAD = 128 # Producer primary thread
-CONSUMER_PRIMARY_THREAD = 0 # Consumer primary thread
-PRODUCER_REGISTER_SIZE = 40 # Producer primary thread
-CONSUMER_REGISTER_SIZE = 232 # Consumer primary thread
-
-
def partition_shape():
"""
Calculate the partition shape based on the block IDs.
- It partitions the shape like below:
- for(.. i < M ...) --> blockIdx.x
- for(.. j < N ...) --> blockIdx.y
- for(.. k < K ...)
+ It parallelizes the two outermost loops into thread blocks.
+ for ti in range(M//128): # -> blockIdx.x
+ for tj in range(N//128): # -> blockIdx.y
+ D = 0
+ for tk in range(K//64):
+ for i in range(128):
+ for j in range(128):
+ for k in range(64):
+ FMA
Returns:
dimX (int): Dimension along the x-axis.
@@ -80,6 +79,7 @@ def tma_load(
b_tma: TMA,
slot,
stage,
+ num_stages,
p=None,
):
"""
@@ -94,7 +94,7 @@ def tma_load(
dimX, dimY = partition_shape()
tidx = gpu.thread_id(gpu.Dimension.x)
- begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+ begin_b = num_stages * get_type_size(a_tma.tma_memref)
size_tma_a = get_type_size(a_tma.tma_memref)
size_tma_b = get_type_size(b_tma.tma_memref)
ta_count = size_tma_a + (size_tma_b * 2)
@@ -116,16 +116,16 @@ def tma_load(
b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
-def bootstrap(a_tma: TMA, b_tma: TMA):
+def initialize(a_tma: TMA, b_tma: TMA, num_stages):
"""
Initialize mbarriers and prefetch TMA descriptors.
"""
tidx = gpu.thread_id(gpu.Dimension.x)
- mbar_group_tma = Mbarriers(number_of_barriers=NUM_STAGES)
- mbar_group_mma = Mbarriers(number_of_barriers=NUM_STAGES)
+ mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
+ mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
isThread0 = tidx == const(0)
with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
- for i in scf.for_(0, NUM_STAGES, 1):
+ for i in scf.for_(0, num_stages, 1):
mbar_group_tma[i].init(1)
mbar_group_mma[i].init(1)
scf.yield_([])
@@ -136,8 +136,8 @@ def bootstrap(a_tma: TMA, b_tma: TMA):
return mbar_group_tma, mbar_group_mma
-def switch_phase(stage, phase):
- p = stage == (NUM_STAGES - 1)
+def switch_phase(stage, phase, num_stages):
+ p = stage == (num_stages - 1)
phase = arith.select(
p,
(phase ^ const(True, ty=T.bool())),
@@ -152,17 +152,18 @@ def producer_loop(
a_tma: TMA,
b_tma: TMA,
wg_me: Warpgroup,
+ num_stages
):
phase = const(True, ty=T.bool())
for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
- stage = iv % NUM_STAGES
+ stage = iv % num_stages
# Wait MMA to be done
mbar_group_mma[stage].try_wait(phase)
# New phase for mbarrier
- phase = switch_phase(stage, phase)
+ phase = switch_phase(stage, phase, num_stages)
# TMA Load
- tma_load(mbar_group_tma, a_tma, b_tma, stage, iv, wg_me.is_wg_primary)
+ tma_load(mbar_group_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
scf.yield_([phase])
@@ -172,8 +173,9 @@ def consumer_loop(
a_tma: TMA,
b_tma: TMA,
wg_me: Warpgroup,
+ num_stages
):
- begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+ begin_b = num_stages * get_type_size(a_tma.tma_memref)
size_a = TILE_M * TILE_K * get_type_size(T.f16())
@@ -186,7 +188,7 @@ def consumer_loop(
with ir.InsertionPoint(for_op.body):
phase = for_op.inner_iter_args[1]
iv = for_op.induction_variable
- stage = iv % NUM_STAGES
+ stage = iv % num_stages
# Wait TMA for current stage
mbar_group_tma[stage].try_wait(phase)
@@ -208,11 +210,11 @@ def consumer_loop(
# MMA Barrier Arrive
p_arrive = (iv > 0) & wg_me.is_wg_primary
with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
- barId = arith.select((stage == 0), const(NUM_STAGES - 1), (stage - 1))
+ barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
mbar_group_mma[barId].arrive()
scf.yield_([])
- phase = switch_phase(stage, phase)
+ phase = switch_phase(stage, phase, num_stages)
scf.yield_([D.acc_op, phase])
nvvm.WgmmaWaitGroupSyncOp(0)
@@ -249,7 +251,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
@NVDSL.mlir_func
-def gemm_warp_specialized(a, b, d):
+def gemm_warp_specialized(a, b, d, num_stages):
token_ty = ir.Type.parse("!gpu.async.token")
t1 = gpu.wait(token_ty, [])
a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
@@ -270,7 +272,7 @@ def gemm_warp_specialized(a, b, d):
size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
- smem_size_in_bytes = (size_a + size_b) * NUM_STAGES
+ smem_size_in_bytes = (size_a + size_b) * num_stages
@NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
def gemm_warp_specialized_kernel():
@@ -279,15 +281,15 @@ def gemm_warp_specialized_kernel():
wg_consumer = Warpgroup(primary_thread=0, register_size=232)
# Initialize mbarriers and prefetch TMA descriptors
- mbar_group_mma, mbar_group_tma = bootstrap(a_tma, b_tma)
+ mbar_group_mma, mbar_group_tma = initialize(a_tma, b_tma, num_stages)
# Producer performs TMA
with wg_producer:
- producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer)
+ producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer, num_stages)
# Consumer performs MMA/Tensor Core
with wg_consumer:
- D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer)
+ D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer, num_stages)
epilogue(D, d_dev)
gemm_warp_specialized_kernel()
@@ -297,7 +299,6 @@ def gemm_warp_specialized_kernel():
# Pass arguments from Python to MLIR
-NUM_STAGES = 7
N = 256
M = 512
K = 1024
@@ -308,7 +309,7 @@ def gemm_warp_specialized_kernel():
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
-gemm_warp_specialized(a, b, d)
+gemm_warp_specialized(a, b, d, num_stages = 7)
# Verify MLIR with reference computation
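As a side note on the launch configuration: the kernel above launches 256 threads and splits them into the two warp groups handed to Warpgroup(). The snippet below is a plain-Python illustration of that split; WARP_GROUP_SIZE = 128 is an assumption here (the nvdsl helper only asserts that primary_thread is a multiple of it), while the primary threads and the per-thread register budgets passed to nvvm.setmaxregister are the values used above.

WARP_GROUP_SIZE = 128  # assumed; matches the 128/0 primary threads used above

warpgroups = {
    "producer (TMA)": {"primary_thread": 128, "register_size": 40},
    "consumer (WGMMA)": {"primary_thread": 0, "register_size": 232},
}

for name, wg in warpgroups.items():
    first = wg["primary_thread"]
    assert first % WARP_GROUP_SIZE == 0  # mirrors the Warpgroup assertion
    last = first + WARP_GROUP_SIZE - 1
    print(f"{name}: threads {first}-{last}, "
          f"{wg['register_size']} registers per thread")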
diff --git a/mlir/test/Examples/nvgpu/lit.local.cfg b/mlir/test/Examples/nvgpu/lit.local.cfg
index e586b55573898a..689cd252e7a254 100644
--- a/mlir/test/Examples/nvgpu/lit.local.cfg
+++ b/mlir/test/Examples/nvgpu/lit.local.cfg
@@ -1,3 +1,4 @@
config.unsupported = False
if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
- config.unsupported = True
\ No newline at end of file
+ config.unsupported = True
+
\ No newline at end of file
diff --git a/mlir/test/Examples/nvgpu/nvdsl.mlir b/mlir/test/Examples/nvgpu/nvdsl.mlir
new file mode 100644
index 00000000000000..e61f7ba0e9a67d
--- /dev/null
+++ b/mlir/test/Examples/nvgpu/nvdsl.mlir
@@ -0,0 +1,186 @@
+module {
+ func.func @gemm_warp_specialized(%arg0: memref<512x1024xf16>, %arg1: memref<1024x256xf16>, %arg2: memref<512x256xf32>) attributes {llvm.emit_c_interface} {
+ %0 = gpu.wait async
+ %memref, %asyncToken = gpu.alloc async [%0] () : memref<512x1024xf16>
+ %memref_0, %asyncToken_1 = gpu.alloc async [%asyncToken] () : memref<1024x256xf16>
+ %memref_2, %asyncToken_3 = gpu.alloc async [%asyncToken_1] () : memref<512x256xf32>
+ %1 = gpu.memcpy async [%asyncToken_3] %memref, %arg0 : memref<512x1024xf16>, memref<512x1024xf16>
+ %2 = gpu.memcpy async [%1] %memref_0, %arg1 : memref<1024x256xf16>, memref<1024x256xf16>
+ %3 = gpu.wait async [%2]
+ %cast = memref.cast %memref : memref<512x1024xf16> to memref<*xf16>
+ %c128 = arith.constant 128 : index
+ %c64 = arith.constant 64 : index
+ %4 = nvgpu.tma.create.descriptor %cast box[%c128, %c64] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+ %cast_4 = memref.cast %memref_0 : memref<1024x256xf16> to memref<*xf16>
+ %c64_5 = arith.constant 64 : index
+ %c64_6 = arith.constant 64 : index
+ %5 = nvgpu.tma.create.descriptor %cast_4 box[%c64_5, %c64_6] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+ %c4 = arith.constant 4 : index
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %c256 = arith.constant 256 : index
+ %c1_7 = arith.constant 1 : index
+ %c1_8 = arith.constant 1 : index
+ %c229376_i32 = arith.constant 229376 : i32
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c4, %arg10 = %c2, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c256, %arg13 = %c1_7, %arg14 = %c1_8) dynamic_shared_memory_size %c229376_i32 {
+ %thread_id_x = gpu.thread_id x
+ %c128_9 = arith.constant 128 : index
+ %7 = arith.remui %thread_id_x, %c128_9 : index
+ %c0 = arith.constant 0 : index
+ %8 = arith.cmpi eq, %7, %c0 : index
+ %c128_10 = arith.constant 128 : index
+ %9 = arith.divui %thread_id_x, %c128_10 : index
+ %c1_11 = arith.constant 1 : index
+ %10 = arith.cmpi eq, %9, %c1_11 : index
+ %thread_id_x_12 = gpu.thread_id x
+ %c128_13 = arith.constant 128 : index
+ %11 = arith.remui %thread_id_x_12, %c128_13 : index
+ %c0_14 = arith.constant 0 : index
+ %12 = arith.cmpi eq, %11, %c0_14 : index
+ %c128_15 = arith.constant 128 : index
+ %13 = arith.divui %thread_id_x_12, %c128_15 : index
+ %c0_16 = arith.constant 0 : index
+ %14 = arith.cmpi eq, %13, %c0_16 : index
+ %thread_id_x_17 = gpu.thread_id x
+ %15 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+ %16 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+ %c0_18 = arith.constant 0 : index
+ %17 = arith.cmpi eq, %thread_id_x_17, %c0_18 : index
+ scf.if %17 {
+ %c0_19 = arith.constant 0 : index
+ %c7 = arith.constant 7 : index
+ %c1_20 = arith.constant 1 : index
+ scf.for %arg15 = %c0_19 to %c7 step %c1_20 {
+ %c1_21 = arith.constant 1 : index
+ nvgpu.mbarrier.init %15[%arg15], %c1_21 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+ %c1_22 = arith.constant 1 : index
+ nvgpu.mbarrier.init %16[%arg15], %c1_22 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+ }
+ nvgpu.tma.prefetch.descriptor %4 : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+ nvgpu.tma.prefetch.descriptor %5 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+ }
+ scf.if %10 {
+ nvvm.setmaxregister decrease 40
+ %true = arith.constant true
+ %c0_19 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c1_20 = arith.constant 1 : index
+ %18 = scf.for %arg15 = %c0_19 to %c16 step %c1_20 iter_args(%arg16 = %true) -> (i1) {
+ %c7 = arith.constant 7 : index
+ %19 = arith.remui %arg15, %c7 : index
+ %c10000000 = arith.constant 10000000 : index
+ nvgpu.mbarrier.try_wait.parity %15[%19], %arg16, %c10000000 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+ %c6 = arith.constant 6 : index
+ %20 = arith.cmpi eq, %19, %c6 : index
+ %true_21 = arith.constant true
+ %21 = arith.xori %arg16, %true_21 : i1
+ %22 = arith.select %20, %21, %arg16 : i1
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %c128_22 = arith.constant 128 : index
+ %23 = arith.muli %block_id_x, %c128_22 : index
+ %c128_23 = arith.constant 128 : index
+ %24 = arith.muli %block_id_y, %c128_23 : index
+ %thread_id_x_24 = gpu.thread_id x
+ %c16384 = arith.constant 16384 : index
+ %25 = arith.muli %19, %c16384 : index
+ %c16384_25 = arith.constant 16384 : index
+ %26 = arith.muli %19, %c16384_25 : index
+ %c114688 = arith.constant 114688 : index
+ %27 = arith.addi %26, %c114688 : index
+ %c8192 = arith.constant 8192 : index
+ %28 = arith.addi %27, %c8192 : index
+ %29 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %view = memref.view %29[%25][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+ %30 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %view_26 = memref.view %30[%27][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+ %31 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %view_27 = memref.view %31[%28][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+ %c32768 = arith.constant 32768 : index
+ nvgpu.mbarrier.arrive.expect_tx %16[%19], %c32768, predicate = %8 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+ %c128_28 = arith.constant 128 : index
+ %32 = arith.remui %thread_id_x_24, %c128_28 : index
+ %c0_29 = arith.constant 0 : index
+ %33 = arith.cmpi eq, %32, %c0_29 : index
+ %c64_30 = arith.constant 64 : index
+ %34 = arith.muli %arg15, %c64_30 : index
+ nvgpu.tma.async.load %4[%34, %23], %16[%19] to %view, predicate = %33 : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+ nvgpu.tma.async.load %5[%24, %34], %16[%19] to %view_26, predicate = %33 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+ %c64_31 = arith.constant 64 : index
+ %35 = arith.addi %24, %c64_31 : index
+ nvgpu.tma.async.load %5[%35, %34], %16[%19] to %view_27, predicate = %33 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+ scf.yield %22 : i1
+ }
+ }
+ scf.if %14 {
+ nvvm.setmaxregister increase 232
+ %false = arith.constant false
+ %18 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+ %c0_19 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c1_20 = arith.constant 1 : index
+ %19:2 = scf.for %arg15 = %c0_19 to %c16 step %c1_20 iter_args(%arg16 = %18, %arg17 = %false) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) {
+ %c7 = arith.constant 7 : index
+ %23 = arith.remui %arg15, %c7 : index
+ %c10000000 = arith.constant 10000000 : index
+ nvgpu.mbarrier.try_wait.parity %16[%23], %arg17, %c10000000 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+ %c16384 = arith.constant 16384 : index
+ %24 = arith.muli %23, %c16384 : index
+ %c114688 = arith.constant 114688 : index
+ %25 = arith.addi %24, %c114688 : index
+ %26 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %view_28 = memref.view %26[%24][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+ %27 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %view_29 = memref.view %27[%25][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+ %28 = nvgpu.warpgroup.generate.descriptor %view_28, %4 : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+ %29 = nvgpu.warpgroup.generate.descriptor %view_29, %5 : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+ %30 = nvgpu.warpgroup.mma %28, %29, %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+ %c0_30 = arith.constant 0 : index
+ %31 = arith.cmpi ugt, %arg15, %c0_30 : index
+ %32 = arith.andi %31, %12 : i1
+ scf.if %32 {
+ %c0_31 = arith.constant 0 : index
+ %36 = arith.cmpi eq, %23, %c0_31 : index
+ %c6_32 = arith.constant 6 : index
+ %c1_33 = arith.constant 1 : index
+ %37 = arith.subi %23, %c1_33 : index
+ %38 = arith.select %36, %c6_32, %37 : index
+ %39 = nvgpu.mbarrier.arrive %15[%38] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> !nvgpu.mbarrier.token
+ }
+ %c6 = arith.constant 6 : index
+ %33 = arith.cmpi eq, %23, %c6 : index
+ %true = arith.constant true
+ %34 = arith.xori %arg17, %true : i1
+ %35 = arith.select %33, %34, %arg17 : i1
+ scf.yield %30, %35 : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1
+ }
+ nvvm.wgmma.wait.group.sync.aligned 0
+ %thread_id_x_21 = gpu.thread_id x
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %c128_22 = arith.constant 128 : index
+ %20 = arith.muli %block_id_x, %c128_22 : index
+ %c128_23 = arith.constant 128 : index
+ %21 = arith.muli %block_id_y, %c128_23 : index
+ %22 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %c0_24 = arith.constant 0 : index
+ %view = memref.view %22[%c0_24][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+ %subview = memref.subview %memref_2[%20, %21] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
+ nvgpu.warpgroup.mma.store %19#0, %view : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+ gpu.barrier
+ %c0_25 = arith.constant 0 : index
+ %c128_26 = arith.constant 128 : index
+ %c1_27 = arith.constant 1 : index
+ scf.for %arg15 = %c0_25 to %c128_26 step %c1_27 {
+ %23 = memref.load %view[%arg15, %thread_id_x_21] : memref<128x128xf32, #gpu.address_space<workgroup>>
+ memref.store %23, %subview[%arg15, %thread_id_x_21] : memref<128x128xf32, strided<[256, 1], offset: ?>>
+ }
+ }
+ gpu.terminator
+ }
+ %6 = gpu.memcpy async [%3] %arg2, %memref_2 : memref<512x256xf32>, memref<512x256xf32>
+ gpu.wait [%6]
+ return
+ }
+}
+
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 2c1615d6f9d68b..06faf692346d03 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -173,6 +173,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
class Warpgroup:
def __init__(self, primary_thread, register_size):
+ assert (primary_thread % WARP_GROUP_SIZE) == 0
tidx = gpu.thread_id(gpu.Dimension.x)
self.primary_thread = primary_thread
self.register_size = register_size
@@ -280,7 +281,6 @@ def get_dynamic_shared_memory(shape=None, ty=None, offset: int = 0):
)
- at staticmethod
def get_mlir_ty(arg):
def get_mlir_ty_from_np(dtype):
if dtype == np.float16:
@@ -433,11 +433,11 @@ def __str__(self):
result = funcBody(*fargs, **kwargs)
func.ReturnOp([])
- # Verify the module
- module.operation.verify()
-
# Save IR in a file
- # saveIR(module)
+ saveIR(module)
+
+ # Verify the module
+ # module.operation.verify()
# Compile and JIT MLIR module
options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
>From 749a3cf81a710ffbfacdea9da90cb70e488fd225 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Thu, 18 Apr 2024 07:42:15 +0000
Subject: [PATCH 15/18] Use TensorMapDescriptorType
---
mlir/test/Examples/nvgpu/tools/nvdsl.py | 32 ++++++++++---------------
1 file changed, 13 insertions(+), 19 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 06faf692346d03..501cdcd353fa88 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -104,7 +104,7 @@ class TMA:
def __init__(
self,
- shape,
+ tma_box_shape,
memref_ty,
swizzle=nvgpu.TensorMapSwizzleKind.SWIZZLE_NONE,
l2promo=nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
@@ -115,31 +115,25 @@ def __init__(
self.l2promo = l2promo # mlir.nvgpu.TensorMapL2PromoKind
self.oob = oob # mlir.nvgpu.TensorMapOOBKind
self.interleave = interleave # mlir.nvgpu.TensorMapInterleaveKind
- self.shape = shape
+ self.tma_box_shape = tma_box_shape
self.memref_ty = memref_ty # MemRefType
- self.lastDim = 64
- self.requiredLoad = 1
- self.tma_shape = shape
- self.tma_memref = ir.MemRefType.get(shape, memref_ty.element_type)
+ self.tma_memref = ir.MemRefType.get(tma_box_shape, memref_ty.element_type)
@property
def tensormap_descriptor_ty(self):
"""Returns a tensormap descriptor type."""
- memref_str = (
- "memref<"
- + "x".join(map(str, self.tma_shape))
- + "x"
- + str(self.memref_ty.element_type)
- + ", 3>"
+ tensorMemrefType = ir.MemRefType.get(
+ self.tma_box_shape,
+ self.memref_ty.element_type,
+ memory_space=ir.Attribute.parse("3"),
)
- parse_str = (
- f"!nvgpu.tensormap.descriptor<tensor = {memref_str}, "
- f"swizzle = {self.swizzle}, "
- f"l2promo = {self.l2promo}, "
- f"oob = {self.oob}, "
- f"interleave = {self.interleave}>"
+ return nvgpu.TensorMapDescriptorType.get(
+ tensorMemrefType,
+ self.swizzle,
+ self.l2promo,
+ self.oob,
+ self.interleave,
)
- return ir.Type.parse(parse_str)
def create_descriptor(self, device_ptr):
tma_descriptor_ty = self.tensormap_descriptor_ty
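For anyone adapting this helper: the patch swaps a hand-assembled type string for the typed builder exposed by the nvgpu Python bindings. The sketch below shows both spellings side by side, under the assumptions that the upstream mlir Python packages are importable, that the enum member names follow the pattern used above, and that a bare Context plus load_all_available_dialects is enough to parse the nvgpu type; inside nvdsl this code runs under the context that @NVDSL.mlir_func already sets up.

from mlir import ir
from mlir.dialects import nvgpu

with ir.Context() as ctx, ir.Location.unknown():
    ctx.load_all_available_dialects()
    smem_memref = ir.MemRefType.get(
        [128, 64], ir.F16Type.get(), memory_space=ir.Attribute.parse("3")
    )
    # New style: typed builder, as used by tensormap_descriptor_ty above.
    built_ty = nvgpu.TensorMapDescriptorType.get(
        smem_memref,
        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
        nvgpu.TensorMapOOBKind.OOB_ZERO,
        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
    )
    # Old style: assemble and parse the textual form of the same type.
    parsed_ty = ir.Type.parse(
        "!nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, "
        "swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>"
    )
    print(built_ty)
    print(parsed_ty)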
>From b9dec730258c008730175fd92db329a03f84e305 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Thu, 18 Apr 2024 07:57:50 +0000
Subject: [PATCH 16/18] fix typo
---
mlir/test/Examples/nvgpu/tools/nvdsl.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/tools/nvdsl.py b/mlir/test/Examples/nvgpu/tools/nvdsl.py
index 501cdcd353fa88..600cae5b47eeec 100644
--- a/mlir/test/Examples/nvgpu/tools/nvdsl.py
+++ b/mlir/test/Examples/nvgpu/tools/nvdsl.py
@@ -144,7 +144,7 @@ def create_descriptor(self, device_ptr):
device_ptr,
)
self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
- tma_descriptor_ty, device_unranked_memref, map(const, self.tma_shape)
+ tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
)
return self.tma_descriptor.result
@@ -428,7 +428,7 @@ def __str__(self):
func.ReturnOp([])
# Save IR in a file
- saveIR(module)
+ # saveIR(module)
# Verify the module
# module.operation.verify()
>From 82ca231dc2c76e00f35411b79c3c36773d32ff7b Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Thu, 18 Apr 2024 12:38:44 +0000
Subject: [PATCH 17/18] fix format
---
mlir/test/Examples/nvgpu/Ch0.py | 18 +++++++++---------
mlir/test/Examples/nvgpu/Ch4.py | 9 +++++----
mlir/test/Examples/nvgpu/Ch5.py | 28 ++++++++++++++--------------
3 files changed, 28 insertions(+), 27 deletions(-)
diff --git a/mlir/test/Examples/nvgpu/Ch0.py b/mlir/test/Examples/nvgpu/Ch0.py
index ff294c7ec416ef..8f60088178d119 100644
--- a/mlir/test/Examples/nvgpu/Ch0.py
+++ b/mlir/test/Examples/nvgpu/Ch0.py
@@ -18,19 +18,19 @@
from tools.nvdsl import *
-# 1. The `mlir_func` decorator generates a MLIR `func.func`.
-# Everything inside the Python function becomes the body of the `func`.
-# The decorator also translates `alpha` to an `index` type.
+# 1. The decorator generates a MLIR func.func.
+# Everything inside the Python function becomes the body of the func.
+# The decorator also translates `alpha` to an `index` type.
@NVDSL.mlir_func
def main(alpha):
-# 2. The `mlir_gpu_launch` decorator generates a MLIR `gpu.launch`.
-# Everything inside the Python function becomes the body of the `gpu.launch`.
-# This allows for late outlining of the GPU kernel, enabling optimizations
-# like constant folding from host to device.
+ # 2. The decorator generates a MLIR gpu.launch.
+ # Everything inside the Python function becomes the body of the gpu.launch.
+ # This allows for late outlining of the GPU kernel, enabling optimizations
+ # like constant folding from host to device.
@NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
def kernel():
tidx = gpu.thread_id(gpu.Dimension.x)
- # `+`` operator generates a `arith.addi`
+ # + operator generates arith.addi
myValue = alpha + tidx
# Print from a GPU thread
gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
@@ -38,12 +38,12 @@ def kernel():
# 3. Call the GPU kernel
kernel()
+
alpha = 100
# 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
main(alpha)
-
# CHECK: GPU thread 0 has 100
# CHECK: GPU thread 1 has 101
# CHECK: GPU thread 2 has 102
diff --git a/mlir/test/Examples/nvgpu/Ch4.py b/mlir/test/Examples/nvgpu/Ch4.py
index 175d0aad5f77d7..8f38d8a90add31 100644
--- a/mlir/test/Examples/nvgpu/Ch4.py
+++ b/mlir/test/Examples/nvgpu/Ch4.py
@@ -251,10 +251,11 @@ def epilogue(D: WGMMAMatrix, d_dev):
memref.store(val, d_gmem, [i, tidx])
scf.yield_([])
-# The decorator generates
-# a -> memref<MxKxf16>
-# b -> memref<NxKxf16>
-# d -> memref<MxNxf32>
+
+# The decorator generates
+# a -> memref<MxKxf16>
+# b -> memref<NxKxf16>
+# d -> memref<MxNxf32>
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
token_ty = ir.Type.parse("!gpu.async.token")
diff --git a/mlir/test/Examples/nvgpu/Ch5.py b/mlir/test/Examples/nvgpu/Ch5.py
index daffb38ba0db7e..92e9314e1b812d 100644
--- a/mlir/test/Examples/nvgpu/Ch5.py
+++ b/mlir/test/Examples/nvgpu/Ch5.py
@@ -147,33 +147,33 @@ def switch_phase(stage, phase, num_stages):
def producer_loop(
- mbar_group_tma: Mbarriers,
- mbar_group_mma: Mbarriers,
+ mbar_tma: Mbarriers,
+ mbar_mma: Mbarriers,
a_tma: TMA,
b_tma: TMA,
wg_me: Warpgroup,
- num_stages
+ num_stages,
):
phase = const(True, ty=T.bool())
for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
stage = iv % num_stages
# Wait MMA to be done
- mbar_group_mma[stage].try_wait(phase)
+ mbar_mma[stage].try_wait(phase)
# New phase for mbarrier
phase = switch_phase(stage, phase, num_stages)
# TMA Load
- tma_load(mbar_group_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
+ tma_load(mbar_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
scf.yield_([phase])
def consumer_loop(
- mbar_group_tma: Mbarriers,
- mbar_group_mma: Mbarriers,
+ mbar_tma: Mbarriers,
+ mbar_mma: Mbarriers,
a_tma: TMA,
b_tma: TMA,
wg_me: Warpgroup,
- num_stages
+ num_stages,
):
begin_b = num_stages * get_type_size(a_tma.tma_memref)
@@ -191,7 +191,7 @@ def consumer_loop(
stage = iv % num_stages
# Wait TMA for current stage
- mbar_group_tma[stage].try_wait(phase)
+ mbar_tma[stage].try_wait(phase)
# Find shared memory slot
offset_a = stage * size_a
@@ -211,7 +211,7 @@ def consumer_loop(
p_arrive = (iv > 0) & wg_me.is_wg_primary
with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
- mbar_group_mma[barId].arrive()
+ mbar_mma[barId].arrive()
scf.yield_([])
phase = switch_phase(stage, phase, num_stages)
@@ -281,15 +281,15 @@ def gemm_warp_specialized_kernel():
wg_consumer = Warpgroup(primary_thread=0, register_size=232)
# Initialize mbarriers and prefetch TMA descriptors
- mbar_group_mma, mbar_group_tma = initialize(a_tma, b_tma, num_stages)
+ mbar_mma, mbar_tma = initialize(a_tma, b_tma, num_stages)
# Producer performs TMA
with wg_producer:
- producer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_producer, num_stages)
+ producer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_producer, num_stages)
# Consumer performs MMA/Tensor Core
with wg_consumer:
- D = consumer_loop(mbar_group_tma, mbar_group_mma, a_tma, b_tma, wg_consumer, num_stages)
+ D = consumer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_consumer, num_stages)
epilogue(D, d_dev)
gemm_warp_specialized_kernel()
@@ -309,7 +309,7 @@ def gemm_warp_specialized_kernel():
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
-gemm_warp_specialized(a, b, d, num_stages = 7)
+gemm_warp_specialized(a, b, d, num_stages=7)
# Verify MLIR with reference computation
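The comment above refers to the host-side check at the end of the script; written out, a numpy reference for this f32 += f16*f16 GEMM can look like the sketch below. The tolerance values are an assumption, and a, b, d and the numpy import come from the script a few lines earlier.

# (continuing the script above: a, b, d and numpy are already defined/imported)
ref_d = a.astype(np.float32) @ b.astype(np.float32)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)  # tolerances assumed
print("PASS")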
>From 3826649fb04c0e0107f7cbe64777eedbe58e6a03 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 19 Apr 2024 12:56:03 +0000
Subject: [PATCH 18/18] delete temp file
---
mlir/test/Examples/nvgpu/nvdsl.mlir | 186 ----------------------------
1 file changed, 186 deletions(-)
delete mode 100644 mlir/test/Examples/nvgpu/nvdsl.mlir
diff --git a/mlir/test/Examples/nvgpu/nvdsl.mlir b/mlir/test/Examples/nvgpu/nvdsl.mlir
deleted file mode 100644
index e61f7ba0e9a67d..00000000000000
--- a/mlir/test/Examples/nvgpu/nvdsl.mlir
+++ /dev/null
@@ -1,186 +0,0 @@
-module {
- func.func @gemm_warp_specialized(%arg0: memref<512x1024xf16>, %arg1: memref<1024x256xf16>, %arg2: memref<512x256xf32>) attributes {llvm.emit_c_interface} {
- %0 = gpu.wait async
- %memref, %asyncToken = gpu.alloc async [%0] () : memref<512x1024xf16>
- %memref_0, %asyncToken_1 = gpu.alloc async [%asyncToken] () : memref<1024x256xf16>
- %memref_2, %asyncToken_3 = gpu.alloc async [%asyncToken_1] () : memref<512x256xf32>
- %1 = gpu.memcpy async [%asyncToken_3] %memref, %arg0 : memref<512x1024xf16>, memref<512x1024xf16>
- %2 = gpu.memcpy async [%1] %memref_0, %arg1 : memref<1024x256xf16>, memref<1024x256xf16>
- %3 = gpu.wait async [%2]
- %cast = memref.cast %memref : memref<512x1024xf16> to memref<*xf16>
- %c128 = arith.constant 128 : index
- %c64 = arith.constant 64 : index
- %4 = nvgpu.tma.create.descriptor %cast box[%c128, %c64] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
- %cast_4 = memref.cast %memref_0 : memref<1024x256xf16> to memref<*xf16>
- %c64_5 = arith.constant 64 : index
- %c64_6 = arith.constant 64 : index
- %5 = nvgpu.tma.create.descriptor %cast_4 box[%c64_5, %c64_6] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
- %c4 = arith.constant 4 : index
- %c2 = arith.constant 2 : index
- %c1 = arith.constant 1 : index
- %c256 = arith.constant 256 : index
- %c1_7 = arith.constant 1 : index
- %c1_8 = arith.constant 1 : index
- %c229376_i32 = arith.constant 229376 : i32
- gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c4, %arg10 = %c2, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c256, %arg13 = %c1_7, %arg14 = %c1_8) dynamic_shared_memory_size %c229376_i32 {
- %thread_id_x = gpu.thread_id x
- %c128_9 = arith.constant 128 : index
- %7 = arith.remui %thread_id_x, %c128_9 : index
- %c0 = arith.constant 0 : index
- %8 = arith.cmpi eq, %7, %c0 : index
- %c128_10 = arith.constant 128 : index
- %9 = arith.divui %thread_id_x, %c128_10 : index
- %c1_11 = arith.constant 1 : index
- %10 = arith.cmpi eq, %9, %c1_11 : index
- %thread_id_x_12 = gpu.thread_id x
- %c128_13 = arith.constant 128 : index
- %11 = arith.remui %thread_id_x_12, %c128_13 : index
- %c0_14 = arith.constant 0 : index
- %12 = arith.cmpi eq, %11, %c0_14 : index
- %c128_15 = arith.constant 128 : index
- %13 = arith.divui %thread_id_x_12, %c128_15 : index
- %c0_16 = arith.constant 0 : index
- %14 = arith.cmpi eq, %13, %c0_16 : index
- %thread_id_x_17 = gpu.thread_id x
- %15 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
- %16 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
- %c0_18 = arith.constant 0 : index
- %17 = arith.cmpi eq, %thread_id_x_17, %c0_18 : index
- scf.if %17 {
- %c0_19 = arith.constant 0 : index
- %c7 = arith.constant 7 : index
- %c1_20 = arith.constant 1 : index
- scf.for %arg15 = %c0_19 to %c7 step %c1_20 {
- %c1_21 = arith.constant 1 : index
- nvgpu.mbarrier.init %15[%arg15], %c1_21 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
- %c1_22 = arith.constant 1 : index
- nvgpu.mbarrier.init %16[%arg15], %c1_22 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
- }
- nvgpu.tma.prefetch.descriptor %4 : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
- nvgpu.tma.prefetch.descriptor %5 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
- }
- scf.if %10 {
- nvvm.setmaxregister decrease 40
- %true = arith.constant true
- %c0_19 = arith.constant 0 : index
- %c16 = arith.constant 16 : index
- %c1_20 = arith.constant 1 : index
- %18 = scf.for %arg15 = %c0_19 to %c16 step %c1_20 iter_args(%arg16 = %true) -> (i1) {
- %c7 = arith.constant 7 : index
- %19 = arith.remui %arg15, %c7 : index
- %c10000000 = arith.constant 10000000 : index
- nvgpu.mbarrier.try_wait.parity %15[%19], %arg16, %c10000000 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
- %c6 = arith.constant 6 : index
- %20 = arith.cmpi eq, %19, %c6 : index
- %true_21 = arith.constant true
- %21 = arith.xori %arg16, %true_21 : i1
- %22 = arith.select %20, %21, %arg16 : i1
- %block_id_x = gpu.block_id x
- %block_id_y = gpu.block_id y
- %c128_22 = arith.constant 128 : index
- %23 = arith.muli %block_id_x, %c128_22 : index
- %c128_23 = arith.constant 128 : index
- %24 = arith.muli %block_id_y, %c128_23 : index
- %thread_id_x_24 = gpu.thread_id x
- %c16384 = arith.constant 16384 : index
- %25 = arith.muli %19, %c16384 : index
- %c16384_25 = arith.constant 16384 : index
- %26 = arith.muli %19, %c16384_25 : index
- %c114688 = arith.constant 114688 : index
- %27 = arith.addi %26, %c114688 : index
- %c8192 = arith.constant 8192 : index
- %28 = arith.addi %27, %c8192 : index
- %29 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
- %view = memref.view %29[%25][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
- %30 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
- %view_26 = memref.view %30[%27][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
- %31 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
- %view_27 = memref.view %31[%28][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
- %c32768 = arith.constant 32768 : index
- nvgpu.mbarrier.arrive.expect_tx %16[%19], %c32768, predicate = %8 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
- %c128_28 = arith.constant 128 : index
- %32 = arith.remui %thread_id_x_24, %c128_28 : index
- %c0_29 = arith.constant 0 : index
- %33 = arith.cmpi eq, %32, %c0_29 : index
- %c64_30 = arith.constant 64 : index
- %34 = arith.muli %arg15, %c64_30 : index
- nvgpu.tma.async.load %4[%34, %23], %16[%19] to %view, predicate = %33 : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
- nvgpu.tma.async.load %5[%24, %34], %16[%19] to %view_26, predicate = %33 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
- %c64_31 = arith.constant 64 : index
- %35 = arith.addi %24, %c64_31 : index
- nvgpu.tma.async.load %5[%35, %34], %16[%19] to %view_27, predicate = %33 : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
- scf.yield %22 : i1
- }
- }
- scf.if %14 {
- nvvm.setmaxregister increase 232
- %false = arith.constant false
- %18 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
- %c0_19 = arith.constant 0 : index
- %c16 = arith.constant 16 : index
- %c1_20 = arith.constant 1 : index
- %19:2 = scf.for %arg15 = %c0_19 to %c16 step %c1_20 iter_args(%arg16 = %18, %arg17 = %false) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) {
- %c7 = arith.constant 7 : index
- %23 = arith.remui %arg15, %c7 : index
- %c10000000 = arith.constant 10000000 : index
- nvgpu.mbarrier.try_wait.parity %16[%23], %arg17, %c10000000 : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
- %c16384 = arith.constant 16384 : index
- %24 = arith.muli %23, %c16384 : index
- %c114688 = arith.constant 114688 : index
- %25 = arith.addi %24, %c114688 : index
- %26 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
- %view_28 = memref.view %26[%24][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
- %27 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
- %view_29 = memref.view %27[%25][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
- %28 = nvgpu.warpgroup.generate.descriptor %view_28, %4 : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
- %29 = nvgpu.warpgroup.generate.descriptor %view_29, %5 : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
- %30 = nvgpu.warpgroup.mma %28, %29, %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
- %c0_30 = arith.constant 0 : index
- %31 = arith.cmpi ugt, %arg15, %c0_30 : index
- %32 = arith.andi %31, %12 : i1
- scf.if %32 {
- %c0_31 = arith.constant 0 : index
- %36 = arith.cmpi eq, %23, %c0_31 : index
- %c6_32 = arith.constant 6 : index
- %c1_33 = arith.constant 1 : index
- %37 = arith.subi %23, %c1_33 : index
- %38 = arith.select %36, %c6_32, %37 : index
- %39 = nvgpu.mbarrier.arrive %15[%38] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> !nvgpu.mbarrier.token
- }
- %c6 = arith.constant 6 : index
- %33 = arith.cmpi eq, %23, %c6 : index
- %true = arith.constant true
- %34 = arith.xori %arg17, %true : i1
- %35 = arith.select %33, %34, %arg17 : i1
- scf.yield %30, %35 : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1
- }
- nvvm.wgmma.wait.group.sync.aligned 0
- %thread_id_x_21 = gpu.thread_id x
- %block_id_x = gpu.block_id x
- %block_id_y = gpu.block_id y
- %c128_22 = arith.constant 128 : index
- %20 = arith.muli %block_id_x, %c128_22 : index
- %c128_23 = arith.constant 128 : index
- %21 = arith.muli %block_id_y, %c128_23 : index
- %22 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
- %c0_24 = arith.constant 0 : index
- %view = memref.view %22[%c0_24][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>>
- %subview = memref.subview %memref_2[%20, %21] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
- nvgpu.warpgroup.mma.store %19#0, %view : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>>
- gpu.barrier
- %c0_25 = arith.constant 0 : index
- %c128_26 = arith.constant 128 : index
- %c1_27 = arith.constant 1 : index
- scf.for %arg15 = %c0_25 to %c128_26 step %c1_27 {
- %23 = memref.load %view[%arg15, %thread_id_x_21] : memref<128x128xf32, #gpu.address_space<workgroup>>
- memref.store %23, %subview[%arg15, %thread_id_x_21] : memref<128x128xf32, strided<[256, 1], offset: ?>>
- }
- }
- gpu.terminator
- }
- %6 = gpu.memcpy async [%3] %arg2, %memref_2 : memref<512x256xf32>, memref<512x256xf32>
- gpu.wait [%6]
- return
- }
-}
-