[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

Jacques Pienaar llvmlistbot at llvm.org
Wed Apr 10 00:59:25 PDT 2024


================
@@ -0,0 +1,319 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# Multistage method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches one
+# warpgroup (128 threads in total) and allocates multiple slots (stages) in
+# shared memory. The program consists of three main parts: prologue, main loop,
+# and epilogue. In the prologue, thread 0 issues TMA loads to fill the shared
+# memory slots. The main loop executes MMA while TMA loads refill the slots
+# that have already been consumed. Overlapping TMA and MMA this way hides the
+# latency of global memory loads behind Tensor Core computation.
+#
+# Loops illustration:
+#
+#  for s in range(NUM_STAGES):
+#    TMA_128x64_64x128...
+#  for ti in range(M//128):  # -> blockIdx.x
+#   for tj in range(N//128): # -> blockIdx.y
+#    for tk in range(K//64):
+#      MMA_128x128x64...
+#      TMA_128x64_64x128...
+#  Epilogue...
+#
+# This chapter demonstrates:
+#  1. Partition shape based on block IDs
+#  2. Prologue
+#    2.1 Execute TMA loads of the two input matrices for each stage
+#  3. Main loop
+#    3.1 Wait for completion of TMA load with mbarrier
+#    3.2 Perform Tensor Core GEMM 64x128x64 by warpgroup
+#    3.3 Load next stage if needed
+#  4. Epilogue
+#    4.1 Store fragmented registers to shared memory
+#    4.2 Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+    """
+    Calculate the partition shape based on the block IDs.
+
+    The two outermost loops are mapped onto the grid:
+    for(.. i < M ...)   --> blockIdx.x
+     for(.. j < N ...)  --> blockIdx.y
+      for(.. k < K ...)
+
+    Returns:
+        dimX: Offset of this block's tile along M (blockIdx.x * TILE_M).
+        dimY: Offset of this block's tile along N (blockIdx.y * TILE_N).
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    slot,
+    stage,
+    p=None,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load a_shared_memory[off_a]  at coordinate [c1, dimX]      (Loads 128x64)
+       - tma.load b_shared_memory[off_b]  at coordinate [dimY, c1]      (Loads 64x64)
+       - tma.load b_shared_memory[off_b2] at coordinate [dimY + 64, c1] (Loads 64x64)
+
+       where c1 = stage * 64 is the offset along K.
+
+       mbarrier.arrive ta_count = size(A tile) + 2 * size(B tile), in bytes
+    """
+    dimX, dimY = partition_shape()
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    # B slots start after all NUM_STAGES A slots in dynamic shared memory.
+    begin_b = NUM_STAGES * size_tma_a
+    # Transaction bytes the mbarrier waits for: one A tile plus two B halves.
+    ta_count = size_tma_a + (size_tma_b * 2)
+
+    p = tidx == 0 if p is None else p
+
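+    # Each B slot holds two 64x64 halves; using slot * size_tma_a as the B
+    # stride works because it equals slot * 2 * size_tma_b for these f16 tiles.
+    off_a = slot * size_tma_a
+    off_b = (slot * size_tma_a) + begin_b
+    off_b2 = off_b + size_tma_b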
+    a_elem_ty = a_tma.tma_memref.element_type
+    b_elem_ty = b_tma.tma_memref.element_type
+    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+    mbar_group[slot].arrive(ta_count, predicate=p)
+
+    c1 = stage * 64
+    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def bootstrap(a_tma: TMA, b_tma: TMA):
----------------
jpienaar wrote:

Why not just init?
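The prologue / main loop structure from the chapter header can be modeled without a GPU. Below is a minimal plain-Python sketch, not nvdsl code; the function `multistage_schedule`, the `Event` record and its fields are hypothetical. It assumes, as the loop illustration shows, that the prologue fills every slot and that each slot is refilled with stage `k + NUM_STAGES` right after stage `k` is consumed. It also records the mbarrier parity a consumer would wait on, assuming each slot's barrier completes one phase per load.

```python
from dataclasses import dataclass
from typing import List

# Hypothetical, simplified model of the multistage schedule described in the
# chapter header. It only records events; nothing here touches a GPU or nvdsl.


@dataclass
class Event:
    kind: str   # "TMA" (load into a slot) or "MMA" (consume a slot)
    stage: int  # K-tile index handled by this event
    slot: int   # shared-memory slot, always stage % num_stages
    phase: int  # mbarrier parity associated with this use of the slot


def multistage_schedule(k_iters: int, num_stages: int) -> List[Event]:
    events: List[Event] = []

    def load(stage: int) -> None:
        slot = stage % num_stages
        # Each reuse of a slot completes the next phase of its mbarrier.
        events.append(Event("TMA", stage, slot, (stage // num_stages) % 2))

    # Prologue: fill every slot (or fewer if K is short).
    for s in range(min(num_stages, k_iters)):
        load(s)

    # Main loop: consume stage k, then refill its slot with stage k + num_stages.
    for k in range(k_iters):
        slot = k % num_stages
        events.append(Event("MMA", k, slot, (k // num_stages) % 2))
        nxt = k + num_stages
        if nxt < k_iters:  # "Load next stage if needed"
            load(nxt)
    return events


if __name__ == "__main__":
    for e in multistage_schedule(k_iters=6, num_stages=3):
        print(e.kind, "stage", e.stage, "slot", e.slot, "phase", e.phase)
```

For `k_iters=6, num_stages=3`, the MMAs read slots 0, 1, 2, 0, 1, 2 and wait on parity 0 during the first pass and 1 during the second. In this model a slot is refilled only after its MMA is issued; on real hardware the kernel must also make sure the asynchronous MMA reading that slot has finished before TMA overwrites it.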

https://github.com/llvm/llvm-project/pull/87065
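The offset and coordinate arithmetic in `tma_load` (and `partition_shape`, which it calls) can be replayed with concrete numbers. This is a hypothetical plain-Python mirror, not part of the PR: `tma_plan` and `nbytes` are made-up names, the tile shapes come from the text (A tiles of 128x64 f16, B loaded as two 64x64 f16 halves, TILE_K = 64), and `NUM_STAGES = 4` is an assumed value since the excerpt does not show it.

```python
from typing import Dict

F16_BYTES = 2  # element size of f16

# Assumed shapes, taken from the chapter text. NUM_STAGES is hypothetical.
TILE_M, TILE_N, TILE_K = 128, 128, 64
NUM_STAGES = 4


def nbytes(rows: int, cols: int, elem_bytes: int = F16_BYTES) -> int:
    return rows * cols * elem_bytes


def tma_plan(bidx: int, bidy: int, slot: int, stage: int) -> Dict[str, int]:
    """Mirror the offset/coordinate arithmetic of tma_load in plain Python."""
    dim_x, dim_y = bidx * TILE_M, bidy * TILE_N   # partition_shape()
    size_a = nbytes(TILE_M, TILE_K)               # 128x64 f16 A tile
    size_b = nbytes(TILE_K, TILE_N // 2)          # 64x64 f16 B half
    # off_b below reuses size_a as the B slot stride, so this must hold.
    assert size_a == 2 * size_b
    begin_b = NUM_STAGES * size_a                 # B slots follow all A slots
    off_a = slot * size_a
    off_b = slot * size_a + begin_b
    return {
        "off_a": off_a,
        "off_b": off_b,
        "off_b2": off_b + size_b,
        "ta_count": size_a + 2 * size_b,          # bytes expected by the mbarrier
        "k_coord": stage * TILE_K,                # c1 = stage * 64
        "dim_x": dim_x,
        "dim_y": dim_y,
        "smem_bytes": begin_b + NUM_STAGES * 2 * size_b,
    }


if __name__ == "__main__":
    print(tma_plan(bidx=1, bidy=2, slot=3, stage=5))
```

Under these assumptions `ta_count` comes out to 32768 bytes (128x64x2 for A plus 2x64x64x2 for B), and the four stages occupy 131072 bytes of dynamic shared memory. The `size_a == 2 * size_b` check is what makes `slot * size_tma_a` a valid per-slot stride for B.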

