[llvm] [NVPTX] Make nvptx mma instructions convergent. (PR #96521)

weiwei chen via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 24 11:20:57 PDT 2024


https://github.com/weiweichen updated https://github.com/llvm/llvm-project/pull/96521

>From 2f636761daba3376170a0a197e686cad9ecc410c Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 13:27:33 -0400
Subject: [PATCH 1/2] Make nvptx mma instructions convergent.

---
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in { // warp-synchronous (.sync.aligned)
 defset list<WMMA_INSTR> WMMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+} // isConvergent = true
 
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in { // warp-synchronous (.sync.aligned)
 defset list<WMMA_INSTR> MMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+} // isConvergent = true
 
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16

>From 924a2fb25a8e9899a5e167859eb32ebe2ae69247 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 14:20:24 -0400
Subject: [PATCH 2/2] [NVPTX] Add test that mma is not sunk past a laneid check.

---
 .../NVPTX/mma-no-sink-after-laneid-check.ll   | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll

diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
new file mode 100644
index 0000000000000..44bc9f3140fb8
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -0,0 +1,32 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
+
+; mma.sync.aligned is a warp-collective operation: all threads of the warp
+; must execute it together. It is therefore marked convergent, which forbids
+; moving it into a block that is only reached under a thread-dependent
+; (divergent) condition such as the laneid == 0 check below.
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float)
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid()
+
+; COM: The mma result is only used when laneid == 0. Without the convergent
+; COM: flag it could be sunk into %lane0, after the laneid check, where only
+; COM: lane 0 runs it. Match "%laneid": bare "laneid" is also in the func name.
+; CHECK-LABEL: no_reorder_mma_and_laneid_check
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %out) {
+entry:
+  ; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
+  ; CHECK: %laneid
+  %d = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 1065353216, i32 1065353216, i32 9, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00)
+  %lane = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+  %is.lane0 = icmp eq i32 %lane, 0
+  br i1 %is.lane0, label %lane0, label %exit
+
+lane0:                                            ; preds = %entry
+  %d0 = extractvalue { float, float, float, float } %d, 0
+  store float %d0, ptr %out, align 4
+  br label %exit
+
+exit:                                             ; preds = %lane0, %entry
+  ret void
+}



More information about the llvm-commits mailing list