[llvm] [NVPTX] Make nvptx mma instructions convergent. (PR #96521)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 24 10:39:18 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: weiwei chen (weiweichen)

<details>
<summary>Changes</summary>

We are running into the NVPTX backend generating wrong code for the following input:
```
%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
  ret
else:
  store %0
```

The backend reorders the instructions (as an effect of the `MachineSink` pass) to:
```
if laneid == 0:
  ret
else:
  %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
  store %0
```

This is incorrect because `mma` is a warp-level instruction: all threads in the warp must sync before performing the operation, so it must not be guarded by a branch on a specific thread id. In terms of warp-level sync it is similar to the shuffle instruction `shfl`, and `shfl` is marked as `isConvergent = true`.

Apply `isConvergent = true` to the `mma` instructions.
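
As a rough illustration of why the flag matters, here is a toy model of a sinking legality check. It is a sketch only, not `MachineSink`'s actual code: `ToyInstr` and `canSinkIntoBranch` are hypothetical names invented for this example, and the only assumption carried over from the description above is that a convergent instruction must not become control-dependent on a new branch.

```cpp
#include <string>

// Toy model only: hypothetical types, not LLVM's MachineInstr API.
struct ToyInstr {
  std::string name;
  bool isConvergent;   // stands in for the TableGen flag set by this patch
  bool hasSideEffects; // other reasons a pass might refuse to move an instruction
};

// A sinking pass wants to move a definition next to its only use.
// Moving it into the `else` block makes it control-dependent on the
// `laneid == 0` condition, which is not acceptable for a warp-wide operation.
bool canSinkIntoBranch(const ToyInstr &I) {
  if (I.isConvergent)
    return false; // must stay where every lane of the warp executes it
  if (I.hasSideEffects)
    return false;
  return true;
}
```

In this model, an `mma` instruction with `isConvergent == false` is eligible for sinking into the `else` block (the miscompile described above), while the same instruction with `isConvergent == true` stays ahead of the branch.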


---
Full diff: https://github.com/llvm/llvm-project/pull/96521.diff


1 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+4) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> WMMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> MMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16

``````````

</details>


https://github.com/llvm/llvm-project/pull/96521

