[llvm] [NVPTX] Make nvptx mma instructions convergent. (PR #96521)
weiwei chen via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 24 11:57:57 PDT 2024
weiweichen (https://github.com/weiweichen) updated pull request https://github.com/llvm/llvm-project/pull/96521
>From 2f636761daba3376170a0a197e686cad9ecc410c Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 13:27:33 -0400
Subject: [PATCH 1/3] Make nvptx mma instructions convergent.
---
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> WMMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs = {
} // layout_b
} // layout_a
} // defset
+}
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> MMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs = {
} // layout_b
} // layout_a
} // defset
+}
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
>From 924a2fb25a8e9899a5e167859eb32ebe2ae69247 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 14:20:24 -0400
Subject: [PATCH 2/3] Add a test.
---
.../NVPTX/mma-no-sink-after-laneid-check.ll | 32 +++++++++++++++++++
1 file changed, 32 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
new file mode 100644
index 0000000000000..44bc9f3140fb8
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -0,0 +1,32 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #2
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #1
+
+; COM: llvm.nvvm.mma should not sink to the next block and gets reordered to be after laneid check.
+; CHECK-LABEL: no_reorder_mma_and_laneid_check
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %0, ptr %1, i64 %2) #0 {
+3:
+ ; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
+ ; CHECK: laneid
+ %4 = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 1065353216, i32 1065353216, i32 9, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00)
+ %5 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+ %6 = icmp eq i32 %5, 0
+ br i1 %6, label %7, label %10
+
+7: ; preds = %3
+ %8 = extractvalue { float, float, float, float } %4, 0
+ %9 = getelementptr float, ptr %0, i64 0
+ store float %8, ptr %9, align 4
+ br label %10
+
+10: ; preds = %3, %7
+ ret void
+}
>From 5d29fb22b44eb51155fb2f393a043c6a47878171 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 14:57:29 -0400
Subject: [PATCH 3/3] Update test file to simplify it.
---
.../NVPTX/mma-no-sink-after-laneid-check.ll | 38 ++++++++-----------
1 file changed, 16 insertions(+), 22 deletions(-)
diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
index 44bc9f3140fb8..a88bc4d43dbef 100644
--- a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -1,32 +1,26 @@
; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
-declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #1
-declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1
-
-declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
-
-declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #2
-
-declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #1
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #0
; COM: llvm.nvvm.mma should not sink to the next block and gets reordered to be after laneid check.
; CHECK-LABEL: no_reorder_mma_and_laneid_check
-define dso_local void @no_reorder_mma_and_laneid_check(ptr %0, ptr %1, i64 %2) #0 {
-3:
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %arg, ptr %arg1) {
+bb:
; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
; CHECK: laneid
- %4 = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 1065353216, i32 1065353216, i32 9, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00)
- %5 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
- %6 = icmp eq i32 %5, 0
- br i1 %6, label %7, label %10
-
-7: ; preds = %3
- %8 = extractvalue { float, float, float, float } %4, 0
- %9 = getelementptr float, ptr %0, i64 0
- store float %8, ptr %9, align 4
- br label %10
-
-10: ; preds = %3, %7
+ %i = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 10, i32 10, i32 8, float 0.0, float 0.0, float 0.0, float 0.0)
+ %i3 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+ %i4 = icmp eq i32 %i3, 0
+ br i1 %i4, label %bb5, label %bb8
+
+bb5: ; preds = %bb
+ %i6 = extractvalue { float, float, float, float } %i, 0
+ %i7 = getelementptr float, ptr %arg, i64 0
+ store float %i6, ptr %i7, align 4
+ br label %bb8
+
+bb8: ; preds = %bb5, %bb
ret void
}
More information about the llvm-commits mailing list