[all-commits] [llvm/llvm-project] 51916f: [mlir] Add sm_90a GEMM test 128x128x128 (F32 += F1...

Guray Ozen via All-commits all-commits at lists.llvm.org
Fri Nov 10 07:53:56 PST 2023


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: 51916f0c924f2ed4e970dd043a14d70b6b1d3f71
      https://github.com/llvm/llvm-project/commit/51916f0c924f2ed4e970dd043a14d70b6b1d3f71
  Author: Guray Ozen <guray.ozen at gmail.com>
  Date:   2023-11-10 (Fri, 10 Nov 2023)

  Changed paths:
    A mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir

  Log Message:
  -----------
  [mlir] Add sm_90a GEMM test 128x128x128 (F32 += F16 * F16) (#69913)

This PR adds a test that performs a 128x128x128 GEMM (F32 += F16 * F16).
It uses `sm_90a` features (TMA loads, mbarriers, and warpgroup MMA) through the NVGPU dialect.

The simplified algorithm is as follows:

**Prologue** 
```
mgroup = mbarriers.init x 2
tma.load ... shmem_buffer_lhs<0 x 128 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>    // RHS columns 0..63
tma.load ... shmem_buffer_rhs<0 x 64 x 64>    // RHS columns 64..127
mbarrier.expect_tx 32768
tma.load ... shmem_buffer_lhs<1 x 128 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>    // RHS columns 0..63
tma.load ... shmem_buffer_rhs<1 x 64 x 64>    // RHS columns 64..127
mbarrier.expect_tx 32768
```
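The `mbarrier.expect_tx 32768` count can be checked from the tile shapes: each stage loads one 128x64 F16 LHS tile plus two 64x64 F16 halves of the 64x128 RHS tile. A minimal Python sketch of that arithmetic follows; the helper names are hypothetical and are not part of the test.

```python
# Hypothetical helpers: byte counts for one pipeline stage of the prologue.
F16_BYTES = 2
K = 128       # full reduction dimension of the GEMM
STAGES = 2    # two mbarriers / two shared-memory buffers
BK = 64       # K-extent covered by one stage

def tile_bytes(rows, cols, elem_bytes=F16_BYTES):
    """Size in bytes of a rows x cols tile of F16 elements."""
    return rows * cols * elem_bytes

def stage_tx_bytes():
    """Bytes delivered by one stage's TMA loads, i.e. the expect_tx count."""
    lhs = tile_bytes(128, BK)       # shmem_buffer_lhs<s x 128 x 64>
    rhs = 2 * tile_bytes(BK, 64)    # two shmem_buffer_rhs<s x 64 x 64> halves
    return lhs + rhs
```

With two stages of K=64 each, the prologue alone stages the whole K=128 reduction, so the main loop needs no further loads in this configuration.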
**Mainloop**
```
matrixD =
 for(i = 0;...2) {
   mbarrier.try_wait [i]
   lhs = shmem_buffer_lhs<i x 128 x 64>
   rhs = shmem_buffer_rhs<i x 64 x 128>
   yield nvgpu.warpgroup.mma (lhs, rhs)
 }

//   Expanded : nvgpu.warpgroup.mma [128][128]+=[128][64]*[64][128]
//                  wgmma.m64n128k16(A[0:64][0:16]  *  B[0:16][0:128])
//                  wgmma.m64n128k16(A[0:64][16:32] *  B[16:32][0:128])
//                  wgmma.m64n128k16(A[0:64][32:48] *  B[32:48][0:128])
//                  wgmma.m64n128k16(A[0:64][48:64] *  B[48:64][0:128])
//                  wgmma.m64n128k16(A[64:128][0:16]  *  B[0:16][0:128])
//                  wgmma.m64n128k16(A[64:128][16:32] *  B[16:32][0:128])
//                  wgmma.m64n128k16(A[64:128][32:48] *  B[32:48][0:128])
//                  wgmma.m64n128k16(A[64:128][48:64] *  B[48:64][0:128])
```
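The expansion above can be replayed on the host to check that eight `wgmma.m64n128k16` operations per stage, over two stages, cover the full 128x128x128 product exactly once. The sketch below is a plain-Python model, not MLIR: it uses small integers so the comparison is exact, and it does not model F16 inputs or F32 accumulation precision. All names are hypothetical.

```python
import operator

M, N, K = 128, 128, 128
STAGES = 2                       # two mbarriers / shared-memory buffers
BK = K // STAGES                 # K-extent of one stage: 64
WG_M, WG_N, WG_K = 64, 128, 16   # shape of wgmma.m64n128k16

def wgmma_schedule():
    """(stage, m0, k0) for every wgmma issued, in the order listed above."""
    return [(s, m0, k0)
            for s in range(STAGES)
            for m0 in range(0, M, WG_M)
            for k0 in range(0, BK, WG_K)]

def tiled_gemm(A, B):
    """D += A * B, accumulated one m64n128k16 tile at a time."""
    Bt = [list(col) for col in zip(*B)]  # column view of B (N x K)
    D = [[0] * N for _ in range(M)]
    for s, m0, k0 in wgmma_schedule():
        kk = s * BK + k0                 # global K offset of this tile
        for i in range(m0, m0 + WG_M):
            a = A[i][kk:kk + WG_K]
            Di = D[i]
            for j in range(WG_N):        # n128 spans all of N
                Di[j] += sum(map(operator.mul, a, Bt[j][kk:kk + WG_K]))
    return D

def reference_gemm(A, B):
    """Untiled D = A * B for comparison."""
    Bt = [list(col) for col in zip(*B)]
    return [[sum(map(operator.mul, row, col)) for col in Bt] for row in A]
```

Because n128 already spans the whole N dimension, the M dimension needs two 64-row halves and each 64-wide K slice needs four k16 steps, which gives the eight operations per stage listed above.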

**Epilogue** 
```
//reg->shmem
warpgroup.mma.store matrixD, shmem
//shmem->glbmem
parallel-for(i=0;...128)
 parallel-for(j=0;...128)
   store shmem, globalmem
```
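The two nested `parallel-for` loops in the epilogue copy a 128x128 F32 tile (65536 bytes) from shared to global memory. The sketch below models that copy in Python, assuming one warpgroup of 128 threads and a round-robin distribution of the flattened index space; both the thread count and the mapping are assumptions for illustration, not taken from the test.

```python
M, N = 128, 128
THREADS = 128   # assumed: one warpgroup (4 warps x 32 threads)
F32_BYTES = 4

def epilogue_copy(shmem, num_threads=THREADS):
    """Copy an M x N tile from a simulated shared buffer to a simulated
    global buffer. The 2-D parallel-for is flattened and distributed
    round-robin across threads (hypothetical mapping)."""
    glbmem = [[None] * N for _ in range(M)]
    writes = [[0] * N for _ in range(M)]   # how often each element is stored
    for tid in range(num_threads):
        for flat in range(tid, M * N, num_threads):
            i, j = divmod(flat, N)
            glbmem[i][j] = shmem[i][j]
            writes[i][j] += 1
    return glbmem, writes

staging_bytes = M * N * F32_BYTES   # size of the F32 result tile in shared memory
```

With this mapping, consecutive threads store consecutive columns of the same row, which is the usual layout for coalesced global stores.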