[PATCH] D158059: [AMDGPU/wmma] - Disable 3-address syntax for f16

Jessica Del via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 8 05:51:56 PDT 2023


OutOfCache added a comment.

In D158059#4640554 <https://reviews.llvm.org/D158059#4640554>, @piotr wrote:

> Thanks, that would avoid the regression. However, I still do not fully understand the failing mode - can you show the test case + extra code that triggers the issue?

I admit, it is probably hard to follow without the full context. I'll try again.

During a cooperative matrix calculation, we typically have a loop that computes one or more accumulator matrices.
Each iteration adds a partial result to the accumulator, using a different pair of factor matrices.
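As a rough illustration of that loop pattern, here is a scalar model in plain Python. It is only an assumption-laden sketch: the real shaders compute 16x16x16 tiles with `wmma` on the GPU, and the helper names `matmul` and `accumulate` are hypothetical. The accumulator starts at zero and every iteration adds one partial product.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain product of an n x m matrix and an m x p matrix.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def accumulate(a_tiles: List[Matrix], b_tiles: List[Matrix]) -> Matrix:
    # C starts as a zero matrix, like the accumulator before the first iteration.
    rows, cols = len(a_tiles[0]), len(b_tiles[0][0])
    c = [[0.0] * cols for _ in range(rows)]
    # Each iteration adds one partial product A_k * B_k, the role one wmma plays.
    for a_k, b_k in zip(a_tiles, b_tiles):
        p = matmul(a_k, b_k)
        c = [[c[i][j] + p[i][j] for j in range(cols)] for i in range(rows)]
    return c


a_tiles = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]
b_tiles = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]
print(accumulate(a_tiles, b_tiles))  # [[3.0, 2.0], [3.0, 6.0]]
```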

However, some shaders do not use a loop.

Suppose we have multiple `wmma` instructions that all use the same C matrix. In our case, that is a zero matrix, because there is no previous result yet.

  v_mov_b32 v24, 0
  ; ...
  ; v[24:31] has all zeros
  v_wmma_f16_16x16x16_f16 v[0:7],   ..., v[24:31]
  v_wmma_f16_16x16x16_f16 v[32:39], ..., v[24:31]

Since these are `wmma_f16` instructions, they write each 16-bit result only into the lower 16 bits of the destination registers and leave the upper 16 bits untouched.
Therefore, after the instructions, `v0` contains `0x????IJKL`, where `IJKL` (lower half) is the result of the `wmma` and `????` (upper half) is whatever `v0` held before the `wmma`.
In other words, even though the input accumulator has zeros in both halves of its registers, these zeros are **not** copied into the upper halves of the result registers.
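A minimal bit-level sketch of that behavior, assuming a single 32-bit VGPR is modeled as a Python integer. The helper `write_lo`, the stale value `0xDEAD1234`, and the result bits `0x3C00` (1.0 in IEEE half precision) are all hypothetical; this is not LLVM or hardware code. In conventional hex notation the untouched upper half shows up as the leading four digits.

```python
# Hypothetical model of one 32-bit VGPR; not LLVM or hardware code.
def write_lo(reg: int, value16: int) -> int:
    # An f16 wmma without op_sel replaces bits [15:0] and keeps bits [31:16].
    return (reg & 0xFFFF0000) | (value16 & 0xFFFF)


v24 = 0x00000000  # zero accumulator: both halves are zero
v0 = 0xDEAD1234   # stale contents of v0 before the wmma
ijkl = 0x3C00     # hypothetical f16 result bits (1.0): product + zero accumulator

# Only the 16-bit result is written back; v24's upper zeros never reach v0.
v0 = write_lo(v0, ijkl)
print(f"{v0:#010x}")  # 0xdead3c00
```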

Then, other `wmma`s update the result matrices: each takes the previous result as its input accumulator, with different factor matrices.

  v_wmma_f16_16x16x16_f16 v[0:7],   ..., v[0:7]
  v_wmma_f16_16x16x16_f16 v[32:39], ..., v[32:39]

After that, `v0` contains `0x????PQRS`; the upper half is still the stale `????`.

Usually, this is not an issue, because nothing reads the upper halves. But a future patch tries to fill the `????` halves with the values of another matrix, so we can save VGPRs.
We can do that thanks to the `op_sel` argument: instead of keeping the second matrix in `v[32:39]`, we want its values in the upper halves of `v[0:7]`.
Essentially, we want this:

  v_wmma_f16_16x16x16_f16 v[0:7],   ..., v[0:7]
  v_wmma_f16_16x16x16_f16 v[0:7],   ..., v[0:7] op_sel:[0,0,1] ; instead of v[32:39]

The second instruction reads its input accumulator from the upper 16 bits of `v[0:7]` and also writes its result into the upper 16 bits.
That is normally fine. However, when we initially calculated `v0` with the first `wmma`, it ended up as `0x????IJKL`.
Now, any `wmma` with this `op_sel` reads the upper 16 bits, the `????`, as its input accumulator, and those bits were never initialized to zero.

That is the root of the issue: the stale contents of the registers are added to the `wmma` result, instead of 0 as expected.
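The failure can be sketched per lane with a small model, assuming `D = A*B + C` where C is read from, and D written to, the same 16-bit half selected by `op_sel[2]`. All helper names (`read_half`, `write_half`, `wmma_lane`) and the values are hypothetical: the stale upper half is taken to be `0x4500` (5.0 in f16), the lower half holds the first matrix's result 1.0, and the second matrix's product is 2.0. Python's `struct` format `"e"` is used to convert IEEE half-precision bit patterns.

```python
import struct


def f16_bits_to_float(bits: int) -> float:
    # Decode an IEEE 754 half-precision bit pattern.
    return struct.unpack("<e", (bits & 0xFFFF).to_bytes(2, "little"))[0]


def float_to_f16_bits(x: float) -> int:
    # Encode a float as an IEEE 754 half-precision bit pattern.
    return int.from_bytes(struct.pack("<e", x), "little")


def read_half(reg: int, hi: bool) -> int:
    # hi=True models op_sel:[0,0,1] (bits [31:16]); otherwise bits [15:0].
    return (reg >> 16) & 0xFFFF if hi else reg & 0xFFFF


def write_half(reg: int, value16: int, hi: bool) -> int:
    # Replace one 16-bit half and preserve the other.
    if hi:
        return (reg & 0x0000FFFF) | ((value16 & 0xFFFF) << 16)
    return (reg & 0xFFFF0000) | (value16 & 0xFFFF)


def wmma_lane(reg: int, product: float, hi: bool) -> int:
    # Hypothetical per-lane model: C read from and D written to the same half.
    acc = f16_bits_to_float(read_half(reg, hi))
    return write_half(reg, float_to_f16_bits(product + acc), hi)


# v0 after the first wmma: stale upper half (5.0), lower half = 1.0.
v0 = (0x4500 << 16) | 0x3C00
# The second matrix's first wmma expects a zero accumulator in the upper half.
v0 = wmma_lane(v0, product=2.0, hi=True)
print(f16_bits_to_float(read_half(v0, hi=True)))   # 7.0, not 2.0
print(f16_bits_to_float(read_half(v0, hi=False)))  # 1.0, lower half untouched
```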

Tying the input accumulator to the result accumulator, i.e. dropping the 3-address form for the f16 variants, solves this issue. We no longer have a separate zero matrix that is reused across multiple `wmma`s.
Instead, we initialize each output matrix to all zeros and then accumulate the results into it.

  v_mov_b32 v0, 0
  ; ...
  v_wmma_f16_16x16x16_f16 v[0:7],   ..., v[0:7]

After that, `v0` contains `0x0000IJKL`, as expected: the upper half is zero.
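A minimal sketch of why the tied form helps, again modeling one 32-bit VGPR as an integer. The helper `write_half` and the result bits `0x3C00` (1.0 in f16) are hypothetical. Because the destination itself is zeroed first, the half that a later `op_sel:[0,0,1]` instruction would read as its accumulator is genuinely zero.

```python
def write_half(reg: int, value16: int, hi: bool) -> int:
    # Replace one 16-bit half of a 32-bit register, preserving the other half.
    if hi:
        return (reg & 0x0000FFFF) | ((value16 & 0xFFFF) << 16)
    return (reg & 0xFFFF0000) | (value16 & 0xFFFF)


v0 = 0x00000000  # v_mov_b32 v0, 0: both halves start at zero
ijkl = 0x3C00    # hypothetical f16 result bits (1.0) from the tied wmma
v0 = write_half(v0, ijkl, hi=False)
print(f"{v0:#010x}")  # 0x00003c00

upper_acc = (v0 >> 16) & 0xFFFF  # what an op_sel:[0,0,1] wmma would read as C
print(upper_acc == 0)  # True
```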

Does that clear things up? Let me know if I skipped some detail.


Repository:
  rG LLVM Github Monorepo


https://reviews.llvm.org/D158059


