[llvm] [NVPTX] Make nvptx mma instructions convergent. (PR #96521)

weiwei chen via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 24 11:57:57 PDT 2024


https://github.com/weiweichen updated https://github.com/llvm/llvm-project/pull/96521

>From 2f636761daba3376170a0a197e686cad9ecc410c Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 13:27:33 -0400
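Subject: [PATCH 1/3] Make nvptx mma instructions convergent.

The PTX wmma.mma and mma instructions are .sync.aligned warp-collective
operations: every thread in the warp must execute the same instruction,
and inside conditionally executed code they may only be used when all
threads in the warp evaluate the condition identically; otherwise the
behavior is undefined.

The WMMA and MMA instruction definitions were not marked convergent, so
machine-level code motion was free to sink an MMA whose result is only
used under a lane-dependent branch (for example `laneid == 0`) into that
branch, leaving only a subset of the warp to execute it.

Wrap both defsets in `let isConvergent = true` so that such
control-dependence-changing transformations are no longer allowed.
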
---
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> WMMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> MMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16

>From 924a2fb25a8e9899a5e167859eb32ebe2ae69247 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 14:20:24 -0400
Subject: [PATCH 2/3] Add a test that mma is not sunk past a laneid check.

---
 .../NVPTX/mma-no-sink-after-laneid-check.ll   | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll

diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
new file mode 100644
index 0000000000000..44bc9f3140fb8
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -0,0 +1,32 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #2
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #1
+
+; COM: llvm.nvvm.mma should not sink to the next block and gets reordered to be after laneid check.
+; CHECK-LABEL: no_reorder_mma_and_laneid_check
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %0, ptr %1, i64 %2) #0 {
+3:
+  ; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
+  ; CHECK: laneid
+  %4 = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 1065353216, i32 1065353216, i32 9, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00)
+  %5 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+  %6 = icmp eq i32 %5, 0
+  br i1 %6, label %7, label %10
+
+7:                                               ; preds = %3
+  %8 = extractvalue { float, float, float, float } %4, 0
+  %9 = getelementptr float, ptr %0, i64 0
+  store float %8, ptr %9, align 4
+  br label %10
+
+10:                                               ; preds = %3, %7
+  ret void
+}

>From 5d29fb22b44eb51155fb2f393a043c6a47878171 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 24 Jun 2024 14:57:29 -0400
Subject: [PATCH 3/3] Simplify the mma-no-sink-after-laneid-check test.

---
 .../NVPTX/mma-no-sink-after-laneid-check.ll   | 38 ++++++++-----------
 1 file changed, 16 insertions(+), 22 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
index 44bc9f3140fb8..a88bc4d43dbef 100644
--- a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -1,32 +1,26 @@
 ; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
 
-declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #1
 
-declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1
-
-declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
-
-declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #2
-
-declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #1
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #0
 
 ; COM: llvm.nvvm.mma should not sink to the next block and gets reordered to be after laneid check.
 ; CHECK-LABEL: no_reorder_mma_and_laneid_check
-define dso_local void @no_reorder_mma_and_laneid_check(ptr %0, ptr %1, i64 %2) #0 {
-3:
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %arg, ptr %arg1) {
+bb:
   ; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
   ; CHECK: laneid
-  %4 = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 1065353216, i32 1065353216, i32 9, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00)
-  %5 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
-  %6 = icmp eq i32 %5, 0
-  br i1 %6, label %7, label %10
-
-7:                                               ; preds = %3
-  %8 = extractvalue { float, float, float, float } %4, 0
-  %9 = getelementptr float, ptr %0, i64 0
-  store float %8, ptr %9, align 4
-  br label %10
-
-10:                                               ; preds = %3, %7
+  %i = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 10, i32 10, i32 8, float 0.0, float 0.0, float 0.0, float 0.0)
+  %i3 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+  %i4 = icmp eq i32 %i3, 0
+  br i1 %i4, label %bb5, label %bb8
+
+bb5:                                              ; preds = %bb
+  %i6 = extractvalue { float, float, float, float } %i, 0
+  %i7 = getelementptr float, ptr %arg, i64 0
+  store float %i6, ptr %i7, align 4
+  br label %bb8
+
+bb8:                                              ; preds = %bb5, %bb
   ret void
 }



More information about the llvm-commits mailing list