[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

Manish Gupta llvmlistbot at llvm.org
Sat Mar 30 12:01:12 PDT 2024


================
@@ -0,0 +1,292 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 4 : Multistage GEMM with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation with 64x64x64 matrix multiplication
+#
+# This chapter demonstrates:
+#  1. Partition shape based on block IDs
+#  2. Prologue
+#    2.1 Execute TMA Load for two input matrices for each stage
+#  3. Main loop
+#    3.1 Wait for completion of TMA load with mbarrier
+#    3.2 Perform Tensor Core GEMM 64x128x64 by warpgroup
+#    3.3 Load next stage if needed
+#  4. Epilogue
+#    4.1 Store fragmented registers to shared memory
+#    4.2 Store shared memory to global
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import gpu, scf, nvgpu, nvvm
+from mlir.extras import types as T
+from tools.nvdsl import *
+import numpy as np
+
+
+def partition_shape():
+    """
+    Calculate the partition shape based on the block IDs.
+
+    It partitions the shape like below:
+    for(.. i < M ...)   --> blockIdx.x
+     for(.. j < N ...)  --> blockIdx.y
+      for(.. k < K ...)
+
+    Returns:
+        dimX (int): Dimension along the x-axis.
+        dimY (int): Dimension along the y-axis.
+    """
+    bidx = gpu.block_id(gpu.Dimension.x)
+    bidy = gpu.block_id(gpu.Dimension.y)
+    dimX = bidx * TILE_M
+    dimY = bidy * TILE_N
+    return dimX, dimY
+
+
+def tma_load(
+    mbar_group: Mbarriers,
+    x_tma: TMA,
+    y_tma: TMA,
+    slot,
+    stage,
+    p=None,
+):
+    """
+    TMA loads two input matrices from global memory to shared memory. It performs the following operations:
+
+       - tma.load x_shared_memory[offset] at coordinate [x, y] (Loads 128x64)
+       - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+       - tma.load y_shared_memory[offset] at coordinate [x, y] (Loads 64x64)
+
+       mbarrier.arrive tx_count = 128x64x2x4
+    """
+    dimX, dimY = partition_shape()
+
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    begin_y = NUM_STAGES * get_type_size(x_tma.tma_memref)
+    size_tma_x = get_type_size(x_tma.tma_memref)
+    size_tma_y = get_type_size(y_tma.tma_memref)
+    tx_count = size_tma_x + (size_tma_y * 2)
+
+    p = tidx == 0 if p is None else p
+
+    off_x = slot * size_tma_x
+    off_y = (slot * size_tma_x) + begin_y
+    off_y2 = off_y + size_tma_y
+    x = get_dynamic_shared_memory(
+        x_tma.tma_memref.shape, x_tma.tma_memref.element_type, off_x
+    )
+    y1 = get_dynamic_shared_memory(
+        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y
+    )
+    y2 = get_dynamic_shared_memory(
+        y_tma.tma_memref.shape, y_tma.tma_memref.element_type, off_y2
+    )
+
+    mbar_group[slot].arrive(tx_count, predicate=p)
+
+    c1 = stage * 64
+    x_tma.load(x, mbar_group[slot], coords=[c1, dimX], predicate=p)
+    y_tma.load(y1, mbar_group[slot], coords=[dimY, c1], predicate=p)
+    y_tma.load(y2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)
+
+
+def bootstrap(x_tma: TMA, y_tma: TMA):
+    """
+    Initialize mbarriers and prefetch TMA descriptors.
+    """
+    tidx = gpu.thread_id(gpu.Dimension.x)
+    mbar_group = Mbarriers(number_of_barriers=NUM_STAGES)
+    isThread0 = tidx == const(0)
+    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
+        for i in scf.for_(0, NUM_STAGES, 1):
+            mbar_group[i].init(1)
+            scf.yield_([])
+        x_tma.prefetch()
+        y_tma.prefetch()
+        scf.yield_([])
+
+    return mbar_group
+
+
+def prologue(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+    """
+    Prologue of the GEMM kernel. It loads the two input matrices for each stage in a loop like below:
+
+    for stage in range(NUM_STAGES):
+        tma_load x, y, stage
+
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
+    for iv in scf.for_(0, ns, 1):
+        tma_load(mbar_group, x_tma, y_tma, iv, iv)
+        scf.yield_([])
+
+
+def mainloop(mbar_group: Mbarriers, x_tma: TMA, y_tma: TMA):
+    """
+    Main loop of the Multistage GEMM kernel. It iterates through the stages
+    and performs matrix multiplication, loading data into shared memory via TMA.
+    It works like the following:
+
+    MatrixAccumulator D
+    for k in range(K // TILE_K):
+
+        try_wait(stage, ...)    # Wait for TMA load
+
+        Matrix A(stage, ...)    # Find shared memory slot
+        Matrix B(stage, ...)    # Find shared memory slot
+        D += A @ B              # Multiply and accumulate
+
+        if(needLoad)            # Load next stage if needed
+            tma_load(x, y, nextSlot, nextStage)
+
+    """
+    ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1
----------------
manishucsd wrote:

I won't worry about NUM_STAGES = 1. For Ampere and beyond, you will almost always have 3 stages or more in a GEMM kernel. I will just have an assert that `num_stages > 2` and implement a multi-staged mainloop, as the comment suggests: "Main loop of the Multistage GEMM kernel". This should simplify the code and the tutorial. I believe your performance runs would show that a single stage is not winning for any GEMM problem shape?

**Different mainloops:**
1. Single stage (_one_ shared memory stage)
2. Double buffered (_two_ shared memory stages. Used until Turing)
3. Multi-staged (_three_ or more stages. Used from Ampere and beyond)
4. Warp-specialized (Hopper and beyond. We have enough shared memory to have more than _two_ stages)
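
Below is a minimal, pure-Python sketch (no MLIR required) of the staging schedule this suggestion implies, assuming three or more stages. The helpers `issue_tma_load`, `wait_stage`, and `mma`, and the constant `K_TILES`, are hypothetical stand-ins for the tutorial's `tma_load`, mbarrier wait, and warpgroup MMA; they only print the schedule:

```python
# Hypothetical stand-ins for the tutorial's TMA load, mbarrier wait, and
# warpgroup MMA; they only trace the multistage schedule.
NUM_STAGES = 3          # assumed pipeline depth (>= 3 per the suggestion)
K_TILES = 8             # assumed number of k iterations (K // TILE_K)

assert NUM_STAGES > 2, "multistage mainloop expects 3 or more stages"

def issue_tma_load(slot, stage):
    print(f"tma_load  -> slot {slot}, k-tile {stage}")

def wait_stage(slot):
    print(f"mbar_wait -> slot {slot}")

def mma(slot):
    print(f"wgmma     -> operands from slot {slot}")

# Prologue: fill all but one slot so the mainloop always has loads in flight.
for stage in range(NUM_STAGES - 1):
    issue_tma_load(slot=stage, stage=stage)

# Mainloop: consume one stage per k-tile, then refill the slot just vacated
# with the load for a future k-tile (NUM_STAGES - 1 iterations ahead).
for k in range(K_TILES):
    slot = k % NUM_STAGES
    wait_stage(slot)
    mma(slot)
    next_stage = k + NUM_STAGES - 1
    if next_stage < K_TILES:
        issue_tma_load(slot=next_stage % NUM_STAGES, stage=next_stage)
```

With `NUM_STAGES >= 3`, the `NUM_STAGES == 1` branch of `ns = NUM_STAGES if NUM_STAGES == 1 else NUM_STAGES - 1` goes away; the prologue always fills `NUM_STAGES - 1` slots and the mainloop rotates through them.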

https://github.com/llvm/llvm-project/pull/87065


More information about the Mlir-commits mailing list