[all-commits] [llvm/llvm-project] 72e6a3: [mlir] GEMM Hopper Tensor Core Integration Test

Guray Ozen via All-commits all-commits at lists.llvm.org
Sun Mar 3 03:50:05 PST 2024


  Branch: refs/reviewable/pr81478/r2
  Home:   https://github.com/llvm/llvm-project
  Commit: 72e6a310edaa2e7d8e7ede385de09e52bd302f9b
      https://github.com/llvm/llvm-project/commit/72e6a310edaa2e7d8e7ede385de09e52bd302f9b
  Author: grypp <guray.ozen at gmail.com>
  Date:   2024-02-13 (Tue, 13 Feb 2024)

  Changed paths:
    A mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
    A mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
    A mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
    A mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
    A mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py

  Log Message:
  -----------
  [mlir] GEMM Hopper Tensor Core Integration Test

This test validates the correctness of the GEMM kernels supported by the
NVGPU dialect; Multistage and Warp Specialization kernels are currently
supported.
The test metaprograms the IR through the MLIR Python bindings, allowing
generic IR building. This flexibility makes it possible to change the shape,
tile size, or data type of the GEMM for testing purposes.
The entry point is the `matmul` function, where one specifies the GEMM shape,
tile size, data type, GEMM algorithm (Multistage or Warp Specialization), and
the maximum number of stages.
Results are verified against NumPy's matmul (sketched after the example below).

Example:
```
matmul(input_type=np.float16,                # input type
       output_type=np.float32,               # output type
       M=4096, N=4096, K=4096,               # GEMM shape
       BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # tile size
       use_warp_specialization=True,         # enable warp specialization
       max_num_stages=3)                     # number of stages in shared memory
```
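A minimal sketch of the verification step described above, assuming
hypothetical host-side arrays; `c` stands in for the compiled kernel's output,
which the real test obtains through the MLIR execution engine:
```python
import numpy as np

# Hypothetical sketch of the verification step: the kernel's output `c`
# is checked against NumPy's reference matmul (all names are illustrative).
M, N, K = 128, 128, 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)

ref = a.astype(np.float32) @ b.astype(np.float32)  # float32 accumulation
c = ref.copy()  # placeholder: the test obtains `c` from the compiled kernel
assert np.allclose(c, ref, rtol=5e-3, atol=1e-2)
```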
### Parallelism Across CTAs

A GEMM consists of three loops over the shape specified in the `matmul`
function.
The program builds IR using the following loop structure, tiling the loops
with the given tile size and parallelizing the two outermost loops across the
first and second CTA dimensions (a grid sketch follows the loop nest).
```
for(bi = 0; bi < M; bi += BLOCK_M)          # parallelize across blockIdx.x
    for(bj = 0; bj < N; bj += BLOCK_N)      # parallelize across blockIdx.y
        for(bk = 0; bk < K; bk += BLOCK_K)
            for(i = bi; i < (bi + BLOCK_M); ++i)
                for(j = bj; j < (bj + BLOCK_N); ++j)
                    for(k = bk; k < (bk + BLOCK_K); ++k)
```
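The CTA grid implied by the two parallelized loops can be sketched as follows
(variable names mirror the example above; this is not the test's actual code):
```python
import math

# Illustrative grid computation: one CTA per (BLOCK_M, BLOCK_N) output tile.
M, N, K = 4096, 4096, 4096
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64

grid_x = math.ceil(M / BLOCK_M)   # bi loop -> blockIdx.x
grid_y = math.ceil(N / BLOCK_N)   # bj loop -> blockIdx.y
k_iters = math.ceil(K / BLOCK_K)  # bk loop stays sequential inside each CTA
print(grid_x, grid_y, k_iters)    # 32 32 64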

## Multistage Kernel

This kernel launches a single warp group (128 threads). The primary thread
(pthread) requests the TMA loads; all threads then wait collectively for the
data and perform the mma operations. After the main loop finishes, the threads
together store the fragmented registers first to shared memory, then from
shared memory to global memory; this part is called the epilogue. A sketch of
the stage rotation follows the timeline below.

Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|       |Prologue ---->   |MainLoop ---->                                                                  |Epilogue  |
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
```
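The stage rotation visible in the timeline can be sketched as below;
`tma_load`, `wait_stage`, and `mma` are hypothetical stand-ins (here just
prints) for the generated NVGPU ops, and the exact prefetch depth in the real
kernel may differ:
```python
# Illustrative multistage pipeline, not the test's actual code.
NUM_STAGES, K, BLOCK_K = 3, 4096, 64
num_k_tiles = K // BLOCK_K

def tma_load(stage, k_tile): print(f"tma stage={stage} tile={k_tile}")
def wait_stage(stage):       print(f"wait-{stage}")
def mma(stage):              print(f"mma on stage {stage}")

# Prologue: the primary thread prefetches tiles into the shared-memory stages.
for s in range(NUM_STAGES):
    tma_load(stage=s, k_tile=s)

# Main loop: wait on the oldest stage, run mma on it, then refill its slot
# with a future tile, rotating modulo NUM_STAGES.
for i in range(num_k_tiles):
    stage = i % NUM_STAGES
    wait_stage(stage)
    mma(stage)
    if i + NUM_STAGES < num_k_tiles:
        tma_load(stage=stage, k_tile=i + NUM_STAGES)
```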

## Warp Specialization Kernel

This kernel launches two warp groups (2x128 threads) per CTA, specializing one
as the `producer warp group` and the other as the `consumer warp group`. The
`producer warp group` is responsible for requesting TMA loads, while the
`consumer warp group` performs the mma operations. The epilogue is handled by
the `consumer warp group`, since its threads own the fragmented registers. A
sketch of the producer/consumer split follows the timeline below.

Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
```
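A minimal sketch of the producer/consumer split shown in the timeline; all
helpers are hypothetical stand-ins (prints) for the synchronization and NVGPU
ops the test emits:
```python
# Illustrative warp-specialized structure, not the test's actual code.
NUM_STAGES, K, BLOCK_K = 2, 4096, 64
num_k_tiles = K // BLOCK_K

def run_warp_group(is_producer):
    if is_producer:
        # Producer warp group: only requests TMA loads, cycling the stages.
        for i in range(num_k_tiles):
            print(f"producer: wait slot {i % NUM_STAGES} free, tma tile {i}")
    else:
        # Consumer warp group: waits on each stage's data and runs mma.
        for i in range(num_k_tiles):
            print(f"consumer: wait-{i % NUM_STAGES}, mma")
        print("consumer: 1st epilogue reg->shmem")
    # Per the timeline, both warp groups join the shmem->global copy.
    print("2nd epilogue shmem->global")

run_warp_group(is_producer=True)
run_warp_group(is_producer=False)
```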


  Commit: 26da27e3c5b5198750b1f80fbd9d1e5d6dfe229d
      https://github.com/llvm/llvm-project/commit/26da27e3c5b5198750b1f80fbd9d1e5d6dfe229d
  Author: grypp <guray.ozen at gmail.com>
  Date:   2024-02-13 (Tue, 13 Feb 2024)

  Changed paths:
    M mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
    M mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
    M mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py

  Log Message:
  -----------
  format with yapf


  Commit: 623e90f380ecfd1ef10ae8b7fe539e39dd269e52
      https://github.com/llvm/llvm-project/commit/623e90f380ecfd1ef10ae8b7fe539e39dd269e52
  Author: grypp <guray.ozen at gmail.com>
  Date:   2024-02-13 (Tue, 13 Feb 2024)

  Changed paths:
    M mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
    M mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
    M mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py

  Log Message:
  -----------
  format it with black


  Commit: aa3e3486ef011cd4bc9bb784f81b73981e913151
      https://github.com/llvm/llvm-project/commit/aa3e3486ef011cd4bc9bb784f81b73981e913151
  Author: grypp <guray.ozen at gmail.com>
  Date:   2024-02-13 (Tue, 13 Feb 2024)

  Changed paths:
    M mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py

  Log Message:
  -----------
  fix the spelling mistake


  Commit: dea779aae915d610b409660a2fc6d0b2a3697a43
      https://github.com/llvm/llvm-project/commit/dea779aae915d610b409660a2fc6d0b2a3697a43
  Author: grypp <guray.ozen at gmail.com>
  Date:   2024-02-20 (Tue, 20 Feb 2024)

  Changed paths:
    M mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py

  Log Message:
  -----------
  address comments


  Commit: d4393a47f1d93253bac35c5ce784bc9b06c463f1
      https://github.com/llvm/llvm-project/commit/d4393a47f1d93253bac35c5ce784bc9b06c463f1
  Author: grypp <guray.ozen at gmail.com>
  Date:   2024-02-20 (Tue, 20 Feb 2024)

  Changed paths:
    M mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py

  Log Message:
  -----------
  format


Compare: https://github.com/llvm/llvm-project/compare/72e6a310edaa%5E...d4393a47f1d9


