[all-commits] [llvm/llvm-project] 51916f: [mlir] Add sm_90a GEMM test 128x128x128 (F32 += F1...
Guray Ozen via All-commits
all-commits at lists.llvm.org
Fri Nov 10 07:53:56 PST 2023
Branch: refs/heads/main
Home: https://github.com/llvm/llvm-project
Commit: 51916f0c924f2ed4e970dd043a14d70b6b1d3f71
https://github.com/llvm/llvm-project/commit/51916f0c924f2ed4e970dd043a14d70b6b1d3f71
Author: Guray Ozen <guray.ozen at gmail.com>
Date: 2023-11-10 (Fri, 10 Nov 2023)
Changed paths:
A mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
Log Message:
-----------
[mlir] Add sm_90a GEMM test 128x128x128 (F32 += F16 * F16) (#69913)
This PR adds a test that performs GEMM 128x128x128 (F32 += F16 * F16).
It uses `sm_90a` features in the NVGPU dialect.
The simplified algorithm is as follows:
**Prologue**
```
mgroup = mbarriers.init x 2
tma.load ... shmem_buffer_lhs<0 x 128 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>
mbarrier.expect_tx 32768
tma.load ... shmem_buffer_lhs<1 x 128 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>
mbarrier.expect_tx 32768
```
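The `mbarrier.expect_tx 32768` value matches the byte size of the tiles loaded per stage, assuming the count covers one 128x64 LHS tile and two 64x64 RHS tiles of F16 (2 bytes per element). A minimal Python sketch of that arithmetic follows; `tile_bytes` and `stage_tx_bytes` are hypothetical helper names used only for illustration and are not part of the test:

```python
# Sketch: reproduce the 32768-byte transaction count for one pipeline stage.
# Assumption: expect_tx counts the bytes of all three TMA loads in the stage.
F16_BYTES = 2  # an IEEE half-precision element occupies 2 bytes


def tile_bytes(rows, cols, elem_bytes=F16_BYTES):
    """Bytes moved by one TMA load of a rows x cols tile."""
    return rows * cols * elem_bytes


def stage_tx_bytes():
    """Total bytes one stage's mbarrier is told to expect."""
    lhs = tile_bytes(128, 64)     # one 128x64 LHS tile
    rhs = 2 * tile_bytes(64, 64)  # two 64x64 RHS tiles
    return lhs + rhs


print(stage_tx_bytes())  # 32768
```

The two 64x64 RHS loads together cover the 64x128 RHS operand that the mainloop reads from shared memory.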
**Mainloop**
```
matrixD = for(i = 0;...2) {
  mbarrier.try_wait [i]
  lhs = shmem_buffer_lhs<pipe x 128 x 64>
  rhs = shmem_buffer_rhs<pipe x 64 x 128>
  yield nvgpu.warpgroup.mma (lhs, rhs)
  // Expanded : nvgpu.warpgroup.mma [128][128]+=[128][64]*[64][128]
  //   wgmma.m64n128k16(A[0:64][0:16]    * B[0:16][0:128])
  //   wgmma.m64n128k16(A[0:64][16:32]   * B[16:32][0:128])
  //   wgmma.m64n128k16(A[0:64][32:48]   * B[32:48][0:128])
  //   wgmma.m64n128k16(A[0:64][48:64]   * B[48:64][0:128])
  //   wgmma.m64n128k16(A[64:128][0:16]  * B[0:16][0:128])
  //   wgmma.m64n128k16(A[64:128][16:32] * B[16:32][0:128])
  //   wgmma.m64n128k16(A[64:128][32:48] * B[32:48][0:128])
  //   wgmma.m64n128k16(A[64:128][48:64] * B[48:64][0:128])
}
```
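The expansion in the mainloop can be checked on the host: splitting the 128x64 by 64x128 product into 64-row, 16-deep steps and accumulating them should reproduce the full product. The Python sketch below does this with small integers so that the comparison is exact. It models only the index ranges, not the hardware instruction, and `wgmma_steps`, `tiled_matmul`, and `reference_matmul` are hypothetical names introduced for illustration:

```python
# Sketch: enumerate the m64n128k16 steps for one [128][64] x [64][128] product
# and check that accumulating them equals a plain matrix multiply.
# Assumption: integer data stands in for F16 inputs so equality is exact.

def wgmma_steps(m=128, k=64, tile_m=64, tile_k=16):
    """Return (m0, m1, k0, k1) ranges, M outermost, as in the expansion."""
    return [(m0, m0 + tile_m, k0, k0 + tile_k)
            for m0 in range(0, m, tile_m)
            for k0 in range(0, k, tile_k)]


def tiled_matmul(a, b, n=128):
    """Accumulate D += A[m0:m1][k0:k1] * B[k0:k1][0:n] over all steps."""
    m, k = len(a), len(a[0])
    d = [[0] * n for _ in range(m)]
    for m0, m1, k0, k1 in wgmma_steps(m, k):
        for i in range(m0, m1):
            row_d = d[i]
            for kk in range(k0, k1):
                a_ik, row_b = a[i][kk], b[kk]
                for j in range(n):
                    row_d[j] += a_ik * row_b[j]
    return d


def reference_matmul(a, b, n=128):
    """Untiled product used as the ground truth."""
    return [[sum(a[i][kk] * b[kk][j] for kk in range(len(b)))
             for j in range(n)] for i in range(len(a))]


# Deterministic small-integer inputs: A is 128x64, B is 64x128.
A = [[(i + 2 * kk) % 7 for kk in range(64)] for i in range(128)]
B = [[(3 * kk + j) % 5 for j in range(128)] for kk in range(64)]

print(len(wgmma_steps()))  # 8
print(tiled_matmul(A, B) == reference_matmul(A, B))  # True
```

The eight tuples returned by `wgmma_steps()` appear in the same order as the eight `wgmma.m64n128k16` lines above, from `(0, 64, 0, 16)` to `(64, 128, 48, 64)`.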
**Epilogue**
```
// reg -> shmem
warpgroup.mma.store matrixD, shmem
// shmem -> glbmem
parallel-for(i=0;...128)
  parallel-for(j=0;...128)
    store shmem, globalmem
```