[PATCH] D158059: [AMDGPU/wmma] - Disable 3-address syntax for f16
Jessica Del via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 8 05:51:56 PDT 2023
OutOfCache added a comment.
In D158059#4640554 <https://reviews.llvm.org/D158059#4640554>, @piotr wrote:
> Thanks, that would avoid the regression. However, I still do not fully understand the failing mode - can you show the test case + extra code that triggers the issue?
I admit, it is probably hard to follow without the full context. I'll try again.
During cooperative matrix calculations, we typically have a loop that computes one or more accumulator matrices.
Every iteration adds a partial result to the accumulator, while iterating over different factor matrices.
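A rough, hypothetical Python sketch of that loop, with nested lists standing in for matrices, may help. The accumulator starts as a zero matrix and each iteration adds the product of one K-slice of the factor matrices. On the GPU, each slice's update would roughly correspond to one `wmma`. The function name `matmul_k_sliced` and the `k_step` parameter are illustrative assumptions, not shader code.

```python
from typing import List

Matrix = List[List[float]]


def matmul_k_sliced(a: Matrix, b: Matrix, k_step: int) -> Matrix:
    """Compute a @ b by accumulating one K-slice per loop iteration.

    Hypothetical illustration: each outer iteration plays the role of one
    wmma, reading the accumulator C and writing the updated D back.
    """
    rows, inner, cols = len(a), len(b), len(b[0])
    # The accumulator starts as a zero matrix: there is no previous result.
    acc = [[0.0] * cols for _ in range(rows)]
    for k0 in range(0, inner, k_step):
        k1 = min(k0 + k_step, inner)
        for i in range(rows):
            for j in range(cols):
                # D = A_slice * B_slice + C, where C is the previous D.
                partial = sum(a[i][k] * b[k][j] for k in range(k0, k1))
                acc[i][j] += partial
    return acc


print(matmul_k_sliced([[1, 2, 3, 4]], [[1], [2], [3], [4]], k_step=2))  # [[30.0]]
```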
However, some shaders do not use a loop.
Suppose we have multiple `wmma` instructions that all use the same C matrix. In our case, they all take a zero matrix as the input accumulator at first, because there is no previous result yet.
v_mov_b32 v24, 0
; ...
; v[24:31] has all zeros
v_wmma_f16_16x16x16_f16 v[0:7], ..., v[24:31]
v_wmma_f16_16x16x16_f16 v[32:39], ..., v[24:31]
Since these are `wmma_f16` instructions, they write the result into the lower 16 bits of each result register.
Therefore, after the instructions, the content of `v0` is `0x????IJKL`, where `IJKL` (the low half) is the result of the `wmma`, while `????` (the high half) is whatever `v0` held before the `wmma`.
In other words, even though the input accumulator registers are all zeros, those zeros are **not** copied into the upper halves of the result registers.
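To make the half-register behavior concrete, here is a small hypothetical Python model of a single 32-bit VGPR lane. It assumes, as described above, that a `wmma_f16` result element goes into bits 15:0 while the upper half is left alone. The helper names, the per-element `product` value, and the stale value `0xBEEF1234` are made up for illustration; the f16 encoding uses the half-precision format `'e'` of Python's `struct` module.

```python
import struct


def f16_bits(x: float) -> int:
    """Encode a float as an IEEE 754 binary16 bit pattern."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


def bits_f16(bits: int) -> float:
    """Decode a 16-bit pattern as an IEEE 754 binary16 value."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


def wmma_f16_low(d_reg: int, c_reg: int, product: float) -> int:
    """Model one element of a wmma_f16 without op_sel (hypothetical helper).

    Reads C from the low half of c_reg, writes D = product + C into the
    low half of d_reg, and leaves the high half of d_reg untouched.
    """
    c = bits_f16(c_reg & 0xFFFF)
    d = f16_bits(product + c)
    return (d_reg & 0xFFFF0000) | d


v24 = 0x00000000  # zero accumulator (C matrix)
v0 = 0xBEEF1234   # whatever v0 held before the wmma (made-up value)
v0 = wmma_f16_low(v0, v24, product=6.0)
print(hex(v0))    # 0xbeef4600: the high half is still the stale 0xbeef
```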
Then, we have other `wmma`s that update the result matrices. Each takes the previous result as its input accumulator, and we swap out the factor matrices for different ones.
v_wmma_f16_16x16x16_f16 v[0:7], ..., v[0:7]
v_wmma_f16_16x16x16_f16 v[32:39], ..., v[32:39]
After that, the content of `v0` is `0x????PQRS`: the new result sits in the low half, and the upper half is still untouched.
Usually, this is not an issue. However, the upcoming patch tries to fill the `????` with the values of another matrix, so that we can save VGPRs.
We can do that thanks to the `op_sel` argument: we want the matrix values currently in `v[32:39]` to live in the upper halves of `v[0:7]` instead.
Essentially, we want this:
v_wmma_f16_16x16x16_f16 v[0:7], ..., v[0:7]
v_wmma_f16_16x16x16_f16 v[0:7], ..., v[0:7] op_sel:[0,0,1] ; instead of v[32:39]
The second instruction reads from the upper 16 bits of `v0`, and also writes its result into the upper 16 bits of `v0`.
Normally, that does not cause an issue. However, when we initially calculated `v0` with the first `wmma`, it ended up holding `0x????IJKL`.
Now, any `wmma` with `op_sel` set reads the upper 16 bits, the `????`, as its input accumulator. These upper bits were never initialized to zero.
That is the root of the issue: the stale contents of the registers are added to the `wmma` result, instead of 0 as expected.
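The failure mode can be reproduced in a small hypothetical Python model of one register lane. It assumes, as described above, that `op_sel:[0,0,1]` makes a `wmma_f16` read C from and write D to the upper 16 bits. The stale upper half `0x4500` (which happens to decode to 5.0 in f16) and the per-element products are made-up values, and the helper `wmma_f16` is illustrative, not an LLVM API.

```python
import struct


def f16_bits(x: float) -> int:
    """Encode a float as an IEEE 754 binary16 bit pattern."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


def bits_f16(bits: int) -> float:
    """Decode a 16-bit pattern as an IEEE 754 binary16 value."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


def wmma_f16(d_reg: int, c_reg: int, product: float, high: bool) -> int:
    """One element of a wmma_f16; high=True models op_sel:[0,0,1]."""
    shift = 16 if high else 0
    c = bits_f16((c_reg >> shift) & 0xFFFF)    # C from the selected half
    d = f16_bits(product + c)                  # D = A*B + C, rounded to f16
    keep = 0x0000FFFF if high else 0xFFFF0000  # the other half is preserved
    return (d_reg & keep) | (d << shift)


v24 = 0x00000000  # shared zero accumulator
v0 = 0x45001234   # stale contents from before the first wmma (made up)
v0 = wmma_f16(v0, v24, product=6.0, high=False)  # untied: C = v24, D = v0
v0 = wmma_f16(v0, v0, product=2.0, high=True)    # packed matrix reads stale C
print(bits_f16(v0 >> 16))  # 7.0 instead of the expected 2.0
```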
Tying the input accumulator to the result accumulator solves this issue. We no longer have a separate zero matrix that is reused across multiple `wmma`s.
Instead, we initialize the output registers to all zeros (both halves) and then calculate the results.
v_mov_b32 v0, 0
; ...
v_wmma_f16_16x16x16_f16 v[0:7], ..., v[0:7]
After that, the content of `v0` is `0x0000IJKL`, as expected: the result in the low half and zeros in the upper half.
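With the same kind of hypothetical single-lane model, tying C to D and zeroing the whole register first leaves a clean zero in the upper half, so a second, packed accumulation through `op_sel` starts from 0. As before, this assumes `op_sel:[0,0,1]` selects the upper half for both C and D; the helper `wmma_f16` and the product values are illustrative, not the patch's code.

```python
import struct


def f16_bits(x: float) -> int:
    """Encode a float as an IEEE 754 binary16 bit pattern."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


def bits_f16(bits: int) -> float:
    """Decode a 16-bit pattern as an IEEE 754 binary16 value."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


def wmma_f16(d_reg: int, c_reg: int, product: float, high: bool) -> int:
    """One element of a wmma_f16; high=True models op_sel:[0,0,1]."""
    shift = 16 if high else 0
    c = bits_f16((c_reg >> shift) & 0xFFFF)    # C from the selected half
    d = f16_bits(product + c)                  # D = A*B + C, rounded to f16
    keep = 0x0000FFFF if high else 0xFFFF0000  # the other half is preserved
    return (d_reg & keep) | (d << shift)


v0 = 0x00000000  # v_mov_b32 v0, 0: both halves cleared, whatever v0 held
v0 = wmma_f16(v0, v0, product=6.0, high=False)  # tied: C and D are both v0
v0 = wmma_f16(v0, v0, product=2.0, high=True)   # upper half starts from 0
print(hex(v0))  # 0x40004600: 2.0 in the high half, 6.0 in the low half
```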
Does that clear things up? Let me know if I skipped some detail.
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D158059