[llvm] b0e9b00 - [NVPTX] Make nvptx mma instructions convergent. (#96521)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 24 19:16:01 PDT 2024
Author: weiwei chen
Date: 2024-06-24T22:15:58-04:00
New Revision: b0e9b00ce7d623175c5e60e82afe24e7f8a200be
URL: https://github.com/llvm/llvm-project/commit/b0e9b00ce7d623175c5e60e82afe24e7f8a200be
DIFF: https://github.com/llvm/llvm-project/commit/b0e9b00ce7d623175c5e60e82afe24e7f8a200be.diff
LOG: [NVPTX] Make nvptx mma instructions convergent. (#96521)
We are running into NVPTX backend generating wrong code for an input:
```
%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
ret
else:
store %0
```
The backend reorders the instruction (as an effect of the `MachineSink` pass)
to
```
if laneid == 0:
ret
else:
%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
store %0
```
This is incorrect because `mma` is a warp instruction which needs all
threads to sync before performing the operation instead of being guarded
by a specific thread id. It should be similar to the shuffle instruction
`shfl` in terms of warp-level sync, and `shfl` is marked as
`isConvergent = true`.
Apply `isConvergent = true` to `mma` instructions.
Added:
llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
Modified:
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c0509054af1f4..c81dfa68e4bd4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6725,6 +6725,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> WMMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6746,6 +6747,7 @@ defset list<WMMA_INSTR> WMMAs = {
} // layout_b
} // layout_a
} // defset
+}
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6775,6 +6777,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> MMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6794,6 +6797,7 @@ defset list<WMMA_INSTR> MMAs = {
} // layout_b
} // layout_a
} // defset
+}
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
new file mode 100644
index 0000000000000..a88bc4d43dbef
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -0,0 +1,26 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #1
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #0
+
+; COM: llvm.nvvm.mma should not sink to the next block and gets reordered to be after laneid check.
+; CHECK-LABEL: no_reorder_mma_and_laneid_check
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %arg, ptr %arg1) {
+bb:
+ ; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
+ ; CHECK: laneid
+ %i = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 10, i32 10, i32 8, float 0.0, float 0.0, float 0.0, float 0.0)
+ %i3 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+ %i4 = icmp eq i32 %i3, 0
+ br i1 %i4, label %bb5, label %bb8
+
+bb5: ; preds = %bb
+ %i6 = extractvalue { float, float, float, float } %i, 0
+ %i7 = getelementptr float, ptr %arg, i64 0
+ store float %i6, ptr %i7, align 4
+ br label %bb8
+
+bb8: ; preds = %bb5, %bb
+ ret void
+}
More information about the llvm-commits
mailing list