[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

Jacques Pienaar llvmlistbot at llvm.org
Wed Apr 10 00:59:25 PDT 2024


================
@@ -0,0 +1,319 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
+# Multistage method with a tile size of 128x128x64. The code completely
+# parallelizes the two outermost loops into thread blocks. It launches one warp
+# group (128 threads in total) and allocates multiple slots (one per stage) in
+# shared memory. The program consists of three main parts: prologue, mainloop,
+# and epilogue. In the prologue, thread0 requests TMA to load data into the
+# shared memory slots. The mainloop executes MMA while simultaneously issuing
+# TMA loads for upcoming stages into the slots that have been consumed. This
+# overlap of TMA and MMA operations enhances performance by maximizing
+# computational throughput.
+#
+# Loops illustration:
+#
+#  for s in range(NUM_STAGES):
+#    TMA_128x64_64x128...
+#  for ti in range(M//128):  # -> blockIdx.x
+#   for tj in range(N//128): # -> blockIdx.y
+#    for tk in range(K//64):
+#      MMA_128x128x64...
+#      TMA_128x64_64x128...
+#  Epilogue...
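+#
+#  For reference, a NumPy-style sketch of the numerically equivalent result
+#  (variable names here are illustrative, not part of the tutorial):
+#
+#    d = a.astype(np.float32) @ b.astype(np.float32)  # a: MxK f16, b: KxN f16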
+#
+# This chapter demonstrates:
+#  1. Partition shape based on block IDs
+#  2. Prologue
+#    2.1 Execute TMA load of the two input matrices for each stage
+#  3. Main loop
+#    3.1 Wait for completion of the TMA load with an mbarrier
+#    3.2 Perform Tensor Core GEMM 64x128x64 by the warp group
+#    3.3 Load the next stage if needed
+#  4. Epilogue
+#    4.1 Store fragmented registers to shared memory
+#    4.2 Store shared memory to global memory
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+    """
+    Calculate the partition shape based on the block IDs.
+
+    It partitions the shape like below:
+    for(.. i < M ...)   --> blockIdx.x
+     for(.. j < N ...)  --> blockIdx.y
+      for(.. k < K ...)
+
+    Returns:
+        dimX (int): Dimension along the x-axis.
+        dimY (int): Dimension along the y-axis.
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    a_tma: TMA,
+    b_tma: TMA,
+    slot,
+    stage,
+    p=None,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load a_shared_memory[off_x]  at coordinate [x, z]      (Loads 128x64)
+       - tma.load b_shared_memory[off_y1] at coordinate [y, x]      (Loads 64x64)
+       - tma.load b_shared_memory[off_y2] at coordinate [y + 64, x] (Loads 64x64)
+
+       mbarrier.arrive ta_count = 128x64x2x4
+    """
+    dimX, dimY = partition_shape()
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+    size_tma_a = get_type_size(a_tma.tma_memref)
+    size_tma_b = get_type_size(b_tma.tma_memref)
+    ta_count = size_tma_a + (size_tma_b * 2)
+
+    p = tidx == 0 if p is None else p
+
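+    # Dynamic shared memory layout, one slot per stage (byte offsets):
+    #   [ A slot 0 | ... | A slot NUM_STAGES-1 | B slot 0 | ... | B slot NUM_STAGES-1 ]
+    # `begin_b` marks where the B slots start; each B slot holds the two 64x64 loads.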
+    off_a = slot * size_tma_a
+    off_b = (slot * size_tma_a) + begin_b
+    off_b2 = off_b + size_tma_b
+    a_elem_ty = a_tma.tma_memref.element_type
+    b_elem_ty = b_tma.tma_memref.element_type
+    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
+    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
+    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)
+
+    mbar_group[slot].arrive(ta_count, predicate=p)
+
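+    # Each stage advances by TILE_K (= 64) along the K dimension.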
+    c1 = stage * 64
+    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def bootstrap(a_tma: TMA, b_tma: TMA):
+    """
+    Initialize mbarriers and prefetch TMA descriptors.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    mbar_group = Mbarriers(number_of_barriers=NUM_STAGES)
+    isThread0 = tidx == const(0)
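+    # Only thread 0 initializes the mbarriers and prefetches the TMA descriptors.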
+    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+        for i in scf.for_(0, NUM_STAGES, 1):
+            mbar_group[i].init(1)
+            scf.yield_([])
+        a_tma.prefetch()
+        b_tma.prefetch()
+        scf.yield_([])
+
+    return mbar_group
+
+
+def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
+    """
+    Prologue of the GEMM kernel. It loads the two input matrices for each stage in a loop like the one below:
+
+    for stage in range(NUM_STAGES):
+        tma_load x, y, stage
+
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
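+    # Preload all but the last stage (or the single stage when NUM_STAGES == 1);
+    # the main loop keeps refilling slots as they are consumed.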
+    for iv in scf.for_(0, ns, 1):
+        tma_load(mbar_group, a_tma, b_tma, iv, iv)
+        scf.yield_([])
+
+
+def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA):
+    """
+    Main loop of the Multistage GEMM kernel. It iterates over the K tiles, performing
+    matrix multiplication while TMA loads data into shared memory for later stages. It works like the following:
+
+    MatrixAccumulator D
+    for k in range(K // TILE_K):
+
+        try_wait(stage, ...)    # Wait for the TMA load
+
+        Matrix A(stage, ...)    # Find shared memory slot
+        Matrix B(stage, ...)    # Find shared memory slot
+        D += A @ B              # Multiply and accumulate
+
+        if(needLoad)            # Load next stage if needed
+            tma_load(x, y, nextSlot, nextStage)
+
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_b = NUM_STAGES * get_type_size(a_tma.tma_memref)
+
+    size_a = TILE_M * TILE_K * get_type_size(T.f16())
+
+    # Initialize A and B (input matrices) and C (accumulator)
+    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
+    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
+    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())
+
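+    # Parity bit for mbarrier try_wait; it flips once all NUM_STAGES slots have
+    # been consumed (see the arith.select at the end of the loop body).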
+    phase = const(False, ty=T.bool())
+
+    # Main Loop
+    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
+    with ir.InsertionPoint(for_op.body):
+        phase = for_op.inner_iter_args[1]
+        iv = for_op.induction_variable
+        stage = iv % NUM_STAGES
+
+        # Wait for current stage
+        mbar_group[stage].try_wait(phase=phase)
+
+        # Find shared memory slot
+        offset_a = stage * size_a
+        offset_b = offset_a + begin_b
+        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
+        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)
+
+        # Iterate input matrices, update accumulator
+        A.update_smem(a_smem)
+        B.update_smem(b_smem)
+        D.update_accumulator(for_op.inner_iter_args[0])
+
+        # Matrix Multiply
+        D += A @ B
+
+        # Wait for the Tensor Core when there is only a single stage
+        if NUM_STAGES == 1:
+            nvvm.WgmmaWaitGroupSyncOp(0)
+
+        # Load next stage
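+        # (only thread 0 issues the load, and only while stages remain)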
+        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
+        nextStage = iv + ns
+        nextSlot = nextStage % NUM_STAGES
+        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, pred)
+
+        # Switch phase parity for the mbarrier
+        newPhase = arith.select(
+            stage == (NUM_STAGES - 1),
+            (phase ^ const(True, ty=T.bool())),
+            phase,
+        )
+        scf.yield_([D.acc_op, newPhase])
+
+    nvvm.WgmmaWaitGroupSyncOp(0)
+
+    D.update_accumulator(for_op.results[0])
+    return D
+
+
+def epilogue(D: WGMMAMatrix, d_dev):
+    """
+    Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.
+
+    MatrixAccumulator D               # Fragmented results
+    D -> Shared Memory                # Store registers to shared memory
+    Shared Memory -> d[dimX][dimY]    # Store shared memory to global memory
+
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    dimX, dimY = partition_shape()
+
+    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
+    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])
+
+    # Store (registers -> shared memory)
+    D.store_accumulator(d_smem)
+    gpu.barrier()
+
+    # Store (shared memory --> global memory)
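+    # Each of the 128 threads copies one column of the TILE_M x TILE_N tile.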
+    for i in scf.for_(0, TILE_M, 1):
+        val = memref.load(d_smem, [i, tidx])
+        memref.store(val, d_gmem, [i, tidx])
+        scf.yield_([])
+
+
+@NVDSL.mlir_func
+def gemm_multistage(a, b, d):
----------------
jpienaar wrote:

Pseudo code showing what a, b, d represent?

https://github.com/llvm/llvm-project/pull/87065

