[all-commits] [llvm/llvm-project] a00caa: [mlir] Add sm_90a GEMM test 128x128x128 (F32 =F16*...

Guray Ozen via All-commits all-commits at lists.llvm.org
Fri Nov 10 07:52:14 PST 2023


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: a00caad6bf318a7497d477b434464ca75ecb41fc
      https://github.com/llvm/llvm-project/commit/a00caad6bf318a7497d477b434464ca75ecb41fc
  Author: Guray Ozen <guray.ozen at gmail.com>
  Date:   2023-11-10 (Fri, 10 Nov 2023)

  Changed paths:
    A mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir

  Log Message:
  -----------
  [mlir] Add sm_90a GEMM test 128x128x128 (F32 =F16*F16) with predicate (#70028)

PR #69913 added a GEMM test (128x128x128, F32 += F16 * F16) with an
if-statement. This PR adds the same test using PTX predicates instead.
Predicate support is enabled through _BasicPtxBuilderInterface_
`(nvgpu.opcode ..., predicate = %pred)`.

The predicate condition is computed in `Step 2. [GPU] Elect fastest
thread in CTA`, which is inspired by CUTLASS. It is computed as follows:
```
lane_predicate = nvvm.elect.sync
warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0)
warp_idx_in_warp_group = warp_idx % 4
predicate = (lane_predicate & warp_idx_in_warp_group)
```
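
To make the selection concrete, here is a minimal host-side sketch in Python that simulates which threads of a 1-D CTA would end up with a true predicate. It is not the code from the test. It rests on two assumptions that the message does not state: that `nvvm.elect.sync` picks lane 0 of each warp (the hardware may elect any active lane), and that the final condition means "elected lane in the first warp of its warp group", i.e. `warp_idx_in_warp_group == 0`, following the CUTLASS pattern the message refers to. The function name `elect_predicates` and the `elected_lane` parameter are hypothetical.

```python
WARP_SIZE = 32            # threads per warp on NVIDIA GPUs
WARPS_PER_WARP_GROUP = 4  # a warp group is 4 warps (128 threads)


def elect_predicates(num_threads, elected_lane=0):
    """Simulate the per-thread predicate for a 1-D CTA.

    Assumption: elect.sync selects `elected_lane` in every warp.
    Assumption: the predicate is true only for the elected lane of
    the first warp in each warp group.
    """
    preds = []
    for tid in range(num_threads):
        lane = tid % WARP_SIZE
        lane_predicate = lane == elected_lane   # stand-in for nvvm.elect.sync
        warp_idx = tid // WARP_SIZE             # uniform across the warp
        warp_idx_in_warp_group = warp_idx % WARPS_PER_WARP_GROUP
        preds.append(lane_predicate and warp_idx_in_warp_group == 0)
    return preds


preds = elect_predicates(128)
selected = [tid for tid, p in enumerate(preds) if p]
print(selected)  # [0]
```

Under these assumptions, exactly one thread per warp group passes the predicate, so a predicated instruction is issued once per 128 threads without a branch around it.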

Depends on #70027 #69934 #69935 #69584



