[all-commits] [llvm/llvm-project] 72e6a3: [mlir] GEMM Hopper Tensor Core Integration Test
Guray Ozen via All-commits
all-commits at lists.llvm.org
Sun Mar 3 03:50:05 PST 2024
Branch: refs/reviewable/pr81478/r2
Home: https://github.com/llvm/llvm-project
Commit: 72e6a310edaa2e7d8e7ede385de09e52bd302f9b
https://github.com/llvm/llvm-project/commit/72e6a310edaa2e7d8e7ede385de09e52bd302f9b
Author: grypp <guray.ozen at gmail.com>
Date: 2024-02-13 (Tue, 13 Feb 2024)
Changed paths:
A mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
A mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
A mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
A mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
A mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
Log Message:
-----------
[mlir] GEMM Hopper Tensor Core Integration Test
This test validates the correctness of the GEMM kernels supported in the
NVGPU dialect; it currently covers Multistage and Warp Specialization
kernels.
The test builds IR with the MLIR Python bindings and metaprograms it
generically, so the shape, tile size, or data type of the GEMM can easily be
changed for testing.
The entry function is `matmul`, where one can specify the GEMM shape, tile
size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
maximum number of stages.
The result is verified against NumPy's `matmul`.
Example:
```
import numpy as np

# `matmul` is the test's entry function (defined in matmul.py)
matmul(input_type=np.float16,                  # input type
       output_type=np.float32,                 # output type
       M=4096, N=4096, K=4096,                 # GEMM shape
       BLOCK_M=128, BLOCK_N=128, BLOCK_K=64,   # tile size
       use_warp_specialization=True,           # enable Warp Specialization
       max_num_stages=3)                       # number of stages in shared memory
```
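A minimal sketch of the verification step, assuming A is laid out as M×K and B as K×N and that the kernel accumulates in float32. `reference_check` is a hypothetical helper, not a function from matmul.py, and the tolerances are placeholders rather than the test's actual values:

```python
import numpy as np


def reference_check(a, b, c_gpu, rtol=1e-3, atol=1e-2):
    """Compare a kernel's GEMM output with a NumPy reference (hypothetical helper).

    Assumes float16 inputs with float32 accumulation, so the reference upcasts
    before multiplying; the tolerances are illustrative only.
    """
    ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
    return np.allclose(c_gpu, ref, rtol=rtol, atol=atol)


rng = np.random.default_rng(0)
M, N, K = 256, 128, 64
a = rng.standard_normal((M, K)).astype(np.float16)
b = rng.standard_normal((K, N)).astype(np.float16)
# Stand-in for the GPU result; a real run would copy C back from the device.
c_from_kernel = a.astype(np.float32) @ b.astype(np.float32)
ok = reference_check(a, b, c_from_kernel)
```

Upcasting before `np.matmul` keeps the reference from accumulating in float16, which would otherwise introduce more error than the kernel itself.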
## Parallelism Across CTAs
A GEMM consists of three loops whose bounds (M, N, K) define its shape; these
are specified in the `matmul` function.
The program builds IR with the following loop structure: it tiles the loops by
the given tile sizes and parallelizes the two outermost tile loops across the
first and second dimensions of the CTA grid (`blockIdx.x` and `blockIdx.y`).
```
for (bi = 0; bi < M; bi += BLOCK_M)              // parallelized across blockIdx.x
  for (bj = 0; bj < N; bj += BLOCK_N)            // parallelized across blockIdx.y
    for (bk = 0; bk < K; bk += BLOCK_K)          // sequential inside each CTA
      for (i = bi; i < bi + BLOCK_M; ++i)
        for (j = bj; j < bj + BLOCK_N; ++j)
          for (k = bk; k < bk + BLOCK_K; ++k)
            C[i][j] += A[i][k] * B[k][j];
```
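The loop nest can be mirrored on the CPU with NumPy to see how tiles map onto the CTA grid. This is a sketch under the assumption that every dimension is a multiple of its tile size; `tiled_gemm` is a hypothetical name, and the three innermost loops are collapsed into a single tile product:

```python
import numpy as np


def tiled_gemm(a, b, block_m, block_n, block_k):
    """CPU sketch of the tiled loop nest (hypothetical helper, not part of the test).

    Assumes A is (M, K), B is (K, N), and every dimension divides its tile size.
    Returns the float32 result and the CTA grid shape it would launch.
    """
    M, K = a.shape
    K_b, N = b.shape
    assert K == K_b, "inner dimensions must match"
    assert M % block_m == 0 and N % block_n == 0 and K % block_k == 0
    grid = (M // block_m, N // block_n)  # one CTA per BLOCK_M x BLOCK_N output tile
    c = np.zeros((M, N), dtype=np.float32)
    for bx in range(grid[0]):            # bi loop -> blockIdx.x
        for by in range(grid[1]):        # bj loop -> blockIdx.y
            bi, bj = bx * block_m, by * block_n
            acc = np.zeros((block_m, block_n), dtype=np.float32)  # per-CTA accumulator
            for bk in range(0, K, block_k):  # sequential K loop inside the CTA
                a_tile = a[bi:bi + block_m, bk:bk + block_k].astype(np.float32)
                b_tile = b[bk:bk + block_k, bj:bj + block_n].astype(np.float32)
                acc += a_tile @ b_tile   # the i, j, k loops as one tile product
            c[bi:bi + block_m, bj:bj + block_n] = acc
    return c, grid


rng = np.random.default_rng(0)
a = rng.integers(-3, 4, size=(256, 128)).astype(np.float16)
b = rng.integers(-3, 4, size=(128, 256)).astype(np.float16)
c, grid = tiled_gemm(a, b, block_m=128, block_n=128, block_k=64)
print(grid)  # (2, 2)
```

Small integer inputs keep every partial sum exact in float32, so the tiled result can be compared bit-for-bit with a plain `a @ b`.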
## Multistage Kernel
This kernel launches a single warp group (128 threads) per CTA. The primary
thread (pthread) issues the TMA loads. All threads collectively wait for the
data and perform the mma operations. Once the main loop has covered the full K
dimension, the threads together first store the fragmented accumulator
registers to shared memory and then copy the tile from shared memory to global
memory; this part is called the epilogue.
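Each pipeline stage holds one A tile and one B tile in shared memory, which bounds how many stages can fit. A back-of-the-envelope sketch, assuming a 227 KiB per-CTA shared-memory budget (the opt-in maximum on sm_90) and ignoring mbarriers and any epilogue buffer; both helper names are hypothetical:

```python
import numpy as np


def smem_bytes_per_stage(block_m, block_n, block_k, input_type=np.float16):
    """Bytes for one A tile (BLOCK_M x BLOCK_K) plus one B tile (BLOCK_K x BLOCK_N)."""
    itemsize = np.dtype(input_type).itemsize
    return (block_m * block_k + block_k * block_n) * itemsize


def max_stages_that_fit(block_m, block_n, block_k,
                        smem_budget=227 * 1024, input_type=np.float16):
    """Upper bound on pipeline depth; ignores barriers and any epilogue buffer."""
    return smem_budget // smem_bytes_per_stage(block_m, block_n, block_k, input_type)


per_stage = smem_bytes_per_stage(128, 128, 64)   # 32768 bytes
three_stages = 3 * per_stage                      # 98304 bytes
upper_bound = max_stages_that_fit(128, 128, 64)   # 7
```

With the example tiles (128×128×64, float16) one stage takes 32 KiB, so `max_num_stages=3` uses 96 KiB of the budget.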
Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|       |Prologue ---->  |MainLoop ---->                                                      |Epilogue               |
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
```
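In the timeline, the main loop cycles through the stage buffers (`wait-0`, `wait-1`, `wait-2`, then `wait-0` again). One common way to index such a ring, assumed here rather than taken from matmulBuilder.py, is `stage = iteration % num_stages`, with the barrier phase parity flipping each time the ring wraps around; `pipeline_slots` is a hypothetical helper:

```python
def pipeline_slots(num_k_iters, num_stages):
    """For each K iteration, return (stage buffer consumed, barrier phase parity).

    Assumed convention: stages are reused round-robin, and the phase parity a
    waiter expects flips every time the ring of stages wraps around.
    """
    slots = []
    for it in range(num_k_iters):
        stage = it % num_stages            # wait-0, wait-1, wait-2, wait-0, ...
        parity = (it // num_stages) % 2    # 0 on the first pass, 1 on the second, ...
        slots.append((stage, parity))
    return slots


first_five = pipeline_slots(5, 3)
# -> [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1)]

# The example GEMM (K=4096, BLOCK_K=64) runs 4096 // 64 = 64 K iterations.
full_run = pipeline_slots(4096 // 64, 3)
```

The parity bit is needed because a stage buffer's barrier is reused: without it, a waiter on the second pass could not tell a fresh arrival from the one it already consumed.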
## Warp Specialization Kernel
This kernel launches 2 warp groups (2x128 threads) per CTA and specializes one
as the `producer warp group` and the other as the `consumer warp group`. The
`producer warp group` issues the TMA loads, while the `consumer warp group`
performs the mma operations. The first epilogue step (registers to shared
memory) is handled by the `consumer warp group`, because its threads own the
fragmented accumulator registers; as the timeline shows, both warp groups then
copy the tile from shared memory to global memory.
Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
```
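The producer/consumer handshake can be modeled on the CPU with two threads and a ring of `num_stages` slots. This is only an analogy under stated assumptions: semaphores stand in for the synchronization barriers (assumed to be "full"/"empty" pairs per stage), slicing stands in for TMA, and a NumPy product stands in for the mma; `warp_specialized_gemm_tile` is a hypothetical name that computes a single output tile:

```python
import threading

import numpy as np


def warp_specialized_gemm_tile(a, b, block_k, num_stages):
    """CPU analogy of warp specialization (hypothetical helper).

    A producer thread fills a ring of `num_stages` slots with (A, B) K-tiles,
    standing in for TMA loads; a consumer thread waits on each slot,
    accumulates (standing in for mma), then hands the slot back.
    """
    M, K = a.shape
    N = b.shape[1]
    num_k_iters = K // block_k
    slots = [None] * num_stages
    full = [threading.Semaphore(0) for _ in range(num_stages)]   # "data ready"
    empty = [threading.Semaphore(1) for _ in range(num_stages)]  # "slot free"
    acc = np.zeros((M, N), dtype=np.float32)

    def producer():
        for it in range(num_k_iters):
            s = it % num_stages
            empty[s].acquire()          # wait until the consumer released slot s
            k0 = it * block_k
            slots[s] = (a[:, k0:k0 + block_k], b[k0:k0 + block_k, :])
            full[s].release()           # signal: slot s now holds K-tile `it`

    def consumer():
        nonlocal acc
        for it in range(num_k_iters):
            s = it % num_stages
            full[s].acquire()           # wait for K-tile `it` to land in slot s
            a_tile, b_tile = slots[s]
            acc += a_tile.astype(np.float32) @ b_tile.astype(np.float32)
            empty[s].release()          # hand slot s back to the producer
        # The epilogue would start here: the consumer owns the accumulator.

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return acc


rng = np.random.default_rng(0)
a = rng.integers(-3, 4, size=(64, 256)).astype(np.float16)
b = rng.integers(-3, 4, size=(256, 32)).astype(np.float16)
out = warp_specialized_gemm_tile(a, b, block_k=32, num_stages=2)
```

Separate "full" and "empty" signals per slot let the producer run up to `num_stages` tiles ahead of the consumer without overwriting a slot that is still being read.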
Commit: 26da27e3c5b5198750b1f80fbd9d1e5d6dfe229d
https://github.com/llvm/llvm-project/commit/26da27e3c5b5198750b1f80fbd9d1e5d6dfe229d
Author: grypp <guray.ozen at gmail.com>
Date: 2024-02-13 (Tue, 13 Feb 2024)
Changed paths:
M mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
M mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
M mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
Log Message:
-----------
format with yapf
Commit: 623e90f380ecfd1ef10ae8b7fe539e39dd269e52
https://github.com/llvm/llvm-project/commit/623e90f380ecfd1ef10ae8b7fe539e39dd269e52
Author: grypp <guray.ozen at gmail.com>
Date: 2024-02-13 (Tue, 13 Feb 2024)
Changed paths:
M mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
M mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
M mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
Log Message:
-----------
format it with black
Commit: aa3e3486ef011cd4bc9bb784f81b73981e913151
https://github.com/llvm/llvm-project/commit/aa3e3486ef011cd4bc9bb784f81b73981e913151
Author: grypp <guray.ozen at gmail.com>
Date: 2024-02-13 (Tue, 13 Feb 2024)
Changed paths:
M mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
Log Message:
-----------
fix the spelling mistake
Commit: dea779aae915d610b409660a2fc6d0b2a3697a43
https://github.com/llvm/llvm-project/commit/dea779aae915d610b409660a2fc6d0b2a3697a43
Author: grypp <guray.ozen at gmail.com>
Date: 2024-02-20 (Tue, 20 Feb 2024)
Changed paths:
M mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
Log Message:
-----------
address comments
Commit: d4393a47f1d93253bac35c5ce784bc9b06c463f1
https://github.com/llvm/llvm-project/commit/d4393a47f1d93253bac35c5ce784bc9b06c463f1
Author: grypp <guray.ozen at gmail.com>
Date: 2024-02-20 (Tue, 20 Feb 2024)
Changed paths:
M mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
Log Message:
-----------
format
Compare: https://github.com/llvm/llvm-project/compare/72e6a310edaa%5E...d4393a47f1d9