[llvm] [NVPTX] Make nvptx mma instructions convergent. (PR #96521)

weiwei chen via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 24 10:38:50 PDT 2024


https://github.com/weiweichen created https://github.com/llvm/llvm-project/pull/96521

We are running into a case where the NVPTX backend generates wrong code for the following input:
```
%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
  ret
else:
  store %0
```

As a result of the `MachineSink` pass, the backend reorders the instructions into:
```
if laneid == 0:
  ret
else:
  %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
  store %0
```

This is incorrect because `mma` is a warp-level instruction. All threads in the warp must reach it and synchronize before the operation is performed, so it must not be moved under control flow that depends on the thread id. In this respect it is similar to the warp-level shuffle instruction `shfl`, which is already marked `isConvergent = true`.
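To make the problem concrete, here is a toy Python model of a 32-lane warp. `warp_collective` is a hypothetical stand-in for `mma`: each participating lane's result depends on inputs held by all participating lanes. It does not model real PTX semantics (per the PTX ISA, executing `mma` with only part of the warp is undefined behavior, not a smaller sum). It only shows that once lane 0 has returned, the sunk version computes something different for the lanes that do store.

```python
WARP_SIZE = 32


def warp_collective(inputs: dict[int, int]) -> dict[int, int]:
    """Toy warp-wide operation: every participating lane receives the sum of the
    inputs of all participating lanes. Hypothetical stand-in for mma, where a
    lane's output depends on operand fragments held by other lanes."""
    total = sum(inputs.values())
    return {lane: total for lane in inputs}


def original_order(a: list[int]) -> dict[int, int]:
    # All 32 lanes execute the collective; then lane 0 returns and the rest store.
    result = warp_collective({lane: a[lane] for lane in range(WARP_SIZE)})
    return {lane: result[lane] for lane in range(WARP_SIZE) if lane != 0}


def sunk_order(a: list[int]) -> dict[int, int]:
    # After sinking, lane 0 has already returned, so only lanes 1..31 take part.
    active = {lane: a[lane] for lane in range(WARP_SIZE) if lane != 0}
    result = warp_collective(active)
    return {lane: result[lane] for lane in active}


a = [1] * WARP_SIZE
print(original_order(a)[1], sunk_order(a)[1])  # 32 31
```

Both versions store to the same lanes, but the stored values differ because lane 0's operands no longer reach the collective.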

This patch applies `isConvergent = true` to the `wmma` and `mma` instructions.
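A simplified sketch of why the flag fixes the bug: a sinking pass such as `MachineSink` moves an instruction into the successor block that holds all of its users, but it must not do so for a convergent instruction, because the move would make it control-dependent on the `laneid == 0` branch. The `Instr` class and `sink_target` function below are hypothetical and far simpler than the real pass; only the role of the convergent bit is meant to carry over.

```python
from dataclasses import dataclass


@dataclass
class Instr:
    name: str
    is_convergent: bool = False  # plays the role of the isConvergent bit in the .td records


def sink_target(instr: Instr, def_block: str, user_blocks: list[str]) -> str:
    """Hypothetical, heavily simplified sinking decision.

    Returns the block where `instr` lives after the toy pass runs.
    """
    # Only sink when every user sits in one successor block
    # (e.g. the `else` branch that performs the store).
    if len(set(user_blocks)) != 1:
        return def_block
    # Moving the instruction there would make it control-dependent on the
    # `laneid == 0` branch; convergent instructions must not gain such a
    # dependence, so they stay where they are.
    if instr.is_convergent:
        return def_block
    return user_blocks[0]


shfl = Instr("shfl.sync", is_convergent=True)
mma_without_flag = Instr("mma")
mma_with_flag = Instr("mma", is_convergent=True)

print(sink_target(mma_without_flag, "entry", ["else"]))  # else: sunk, as in the reported bug
print(sink_target(mma_with_flag, "entry", ["else"]))     # entry: stays before the branch
print(sink_target(shfl, "entry", ["else"]))              # entry
```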


From 2f636761daba3376170a0a197e686cad9ecc410c Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 13:27:33 -0400
Subject: [PATCH] Make nvptx mma instructions convergent.

---
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> WMMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> MMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16


