[llvm] b0e9b00 - [NVPTX] Make nvptx mma instructions convergent. (#96521)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 24 19:16:01 PDT 2024


Author: weiwei chen
Date: 2024-06-24T22:15:58-04:00
New Revision: b0e9b00ce7d623175c5e60e82afe24e7f8a200be

URL: https://github.com/llvm/llvm-project/commit/b0e9b00ce7d623175c5e60e82afe24e7f8a200be
DIFF: https://github.com/llvm/llvm-project/commit/b0e9b00ce7d623175c5e60e82afe24e7f8a200be.diff

LOG: [NVPTX] Make nvptx mma instructions convergent. (#96521)

We are running into the NVPTX backend generating wrong code for an input of the form:
```
%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
  ret
else:
  store %0
```

The backend reorders the instruction (as an effect of the `MachineSink` pass) to:
```
if laneid == 0:
  ret
else:
  %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
  store %0
```

This is incorrect because `mma` is a warp-level instruction that needs all
threads to sync before performing the operation, rather than being guarded
by a specific thread id. In terms of warp-level sync it is similar to the
shuffle instruction `shfl`, which is already marked `isConvergent = true`.
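
For context, the reason `isConvergent` helps here is that machine-level
sinking passes refuse to make convergent instructions control-dependent on
additional conditions. A minimal C++ sketch of that kind of legality check
(illustrative only; the helper name is made up and this is not the actual
`MachineSink` source):

```
#include "llvm/CodeGen/MachineInstr.h"

// Hypothetical helper for illustration: decide whether an instruction may
// be sunk into a successor block, i.e. placed under an additional branch
// condition such as the laneid check above.
static bool mayBeSunkPastBranch(const llvm::MachineInstr &MI) {
  // Instructions whose TableGen definition sets `isConvergent = true`
  // report it here; they must not become control-dependent on new
  // conditions, so sinking is refused.
  if (MI.isConvergent())
    return false;
  // ... remaining legality checks (side effects, memory ordering, ...) ...
  return true;
}
```

Marking `mma` the same way as `shfl` makes such passes keep it in its
original block.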

Apply `isConvergent = true` to `mma` instructions.

Added: 
    llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll

Modified: 
    llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c0509054af1f4..c81dfa68e4bd4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6725,6 +6725,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> WMMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6746,6 +6747,7 @@ defset list<WMMA_INSTR> WMMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6775,6 +6777,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> MMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6794,6 +6797,7 @@ defset list<WMMA_INSTR> MMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16

diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
new file mode 100644
index 0000000000000..a88bc4d43dbef
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -0,0 +1,26 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #1
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #0
+
+; COM: llvm.nvvm.mma should not sink into the next block and get reordered to be after the laneid check.
+; CHECK-LABEL: no_reorder_mma_and_laneid_check
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %arg, ptr %arg1) {
+bb:
+  ; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
+  ; CHECK: laneid
+  %i = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 10, i32 10, i32 8, float 0.0, float 0.0, float 0.0, float 0.0)
+  %i3 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+  %i4 = icmp eq i32 %i3, 0
+  br i1 %i4, label %bb5, label %bb8
+
+bb5:                                              ; preds = %bb
+  %i6 = extractvalue { float, float, float, float } %i, 0
+  %i7 = getelementptr float, ptr %arg, i64 0
+  store float %i6, ptr %i7, align 4
+  br label %bb8
+
+bb8:                                              ; preds = %bb5, %bb
+  ret void
+}



