[llvm] [NVPTX] Make nvptx mma instructions convergent. (PR #96521)
weiwei chen via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 24 11:20:57 PDT 2024
https://github.com/weiweichen updated https://github.com/llvm/llvm-project/pull/96521
>From 2f636761daba3376170a0a197e686cad9ecc410c Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 13:27:33 -0400
Subject: [PATCH 1/2] [NVPTX] Make MMA instructions convergent.
---
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> WMMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs = {
} // layout_b
} // layout_a
} // defset
+}
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> MMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs = {
} // layout_b
} // layout_a
} // defset
+}
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
>From 924a2fb25a8e9899a5e167859eb32ebe2ae69247 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 14:20:24 -0400
Subject: [PATCH 2/2] [NVPTX] Add a test that MMA is not sunk below a laneid check.
---
.../NVPTX/mma-no-sink-after-laneid-check.ll | 32 +++++++++++++++++++
1 file changed, 32 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
new file mode 100644
index 0000000000000..44bc9f3140fb8
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -0,0 +1,32 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
+; RUN: %if ptxas-12.1 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | %ptxas-verify -arch=sm_80 %}
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float)
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid()
+
+; COM: mma.sync.aligned must be executed by every thread of the warp, so the
+; COM: NVPTX MMA instructions are convergent and must not be sunk into the
+; COM: block guarded by the `laneid == 0` check; that would move the mma after
+; COM: the laneid read and execute it on a single lane only.
+; CHECK-LABEL: no_reorder_mma_and_laneid_check
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %0, ptr %1, i64 %2) {
+3:
+  ; The mma must be emitted before the laneid read and must not appear after
+  ; it; a sunk mma would only be emitted inside the guarded block below.
+  ; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
+  ; CHECK: laneid
+  ; CHECK-NOT: mma.sync
+  %4 = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 1065353216, i32 1065353216, i32 9, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00)
+  %5 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+  %6 = icmp eq i32 %5, 0
+  br i1 %6, label %7, label %9
+
+7:                                                ; preds = %3
+  %8 = extractvalue { float, float, float, float } %4, 0
+  store float %8, ptr %0, align 4
+  br label %9
+
+9:                                                ; preds = %3, %7
+  ret void
+}
More information about the llvm-commits mailing list