[llvm] [NVPTX] Make nvptx mma instructions convergent. (PR #96521)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 24 10:39:18 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-nvptx
Author: weiwei chen (weiweichen)
<details>
<summary>Changes</summary>
The NVPTX backend generates incorrect code for input of the following form:
```
%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
  ret
else:
  store %0
```
The backend sinks the `mma` instruction into the `else` branch (an effect of the `MachineSink` pass), producing:
```
if laneid == 0:
  ret
else:
  %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
  store %0
```
This is incorrect because `mma` is a warp-level instruction: all threads in the warp must synchronize and execute it together, so it must not be guarded by a condition on a specific thread id. In terms of warp-level synchronization it is similar to the shuffle instruction `shfl`, which is already marked `isConvergent = true`.
This change applies `isConvergent = true` to the `mma` instructions (the `WMMAs` and `MMAs` definitions).
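
To make the failure mode concrete, here is a toy Python model of the two instruction orders. It is only a sketch: `warp_mma`, `original_order`, and `sunk_order` are hypothetical names, the per-lane result is a placeholder, and the exception stands in for what would be undefined behavior on real hardware when a warp-wide operation runs with only part of the warp.

```python
WARP_SIZE = 32  # assumed warp width for this toy model


def warp_mma(participating_lanes):
    """Toy stand-in for a warp-wide mma: only defined if every lane takes part."""
    lanes = set(participating_lanes)
    if lanes != set(range(WARP_SIZE)):
        # Real hardware would not raise; the result is simply undefined.
        raise RuntimeError("mma executed by a partial warp")
    # Placeholder per-lane result fragment.
    return {lane: lane * 2 for lane in lanes}


def original_order():
    """mma first (full warp), then lane 0 returns and the other lanes store."""
    all_lanes = range(WARP_SIZE)
    frag = warp_mma(all_lanes)
    return {lane: frag[lane] for lane in all_lanes if lane != 0}


def sunk_order():
    """mma sunk below the branch: lane 0 has already returned."""
    survivors = [lane for lane in range(WARP_SIZE) if lane != 0]
    frag = warp_mma(survivors)  # now control-dependent on laneid
    return {lane: frag[lane] for lane in survivors}
```

In this model `original_order()` yields values for the 31 storing lanes, while `sunk_order()` fails because the `mma` no longer sees the whole warp. Marking the instruction convergent tells passes such as `MachineSink` not to make it control-dependent on additional values, which rules out the second ordering.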
---
Full diff: https://github.com/llvm/llvm-project/pull/96521.diff
1 Files Affected:
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+4)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> WMMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs = {
} // layout_b
} // layout_a
} // defset
+}
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> MMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs = {
} // layout_b
} // layout_a
} // defset
+}
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
``````````
</details>
https://github.com/llvm/llvm-project/pull/96521