[llvm] [NVPTX] Support !"cluster_dim_{x,y,z}" metadata (PR #109548)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 23 08:41:21 PDT 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/109548

>From d1e90398d2abf51c843780b6ee8a08b1d7256682 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Wed, 18 Sep 2024 01:50:00 +0000
Subject: [PATCH 1/3] [NVPTX] Support !"cluster_dim_{x,y,z}" metadata

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 27 ++++++++++++++++++++---
 llvm/lib/Target/NVPTX/NVPTXUtilities.cpp  | 16 ++++++++++++--
 llvm/lib/Target/NVPTX/NVPTXUtilities.h    |  6 ++++-
 llvm/test/CodeGen/NVPTX/cluster-dim.ll    | 17 ++++++++++++++
 4 files changed, 60 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/cluster-dim.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index d7197a7923eaf0..386829f61bcb04 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -575,9 +575,30 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
   // filter it out for lower SM versions, as it causes a hard ptxas crash.
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
-  unsigned Maxclusterrank = 0;
-  if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
-    O << ".maxclusterrank " << Maxclusterrank << "\n";
+
+  if (STI->getSmVersion() >= 90) {
+    std::optional<unsigned> ClusterX = getClusterDimx(F);
+    std::optional<unsigned> ClusterY = getClusterDimy(F);
+    std::optional<unsigned> ClusterZ = getClusterDimz(F);
+
+    if (ClusterX || ClusterY || ClusterZ) {
+      O << ".explicitcluster\n";
+      if (ClusterX.value_or(1) != 0) {
+        assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
+               "clusterx != 0 implies clustery and clusterz should be non-zero "
+               "as well");
+
+        O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
+          << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
+      } else {
+        assert(
+            !ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
+            "clusterx == 0 implies clustery and clusterz should be 0 as well");
+      }
+    }
+    if (auto Maxclusterrank = getMaxClusterRank(F))
+      O << ".maxclusterrank " << *Maxclusterrank << "\n";
+  }
 }
 
 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 80361744fd5b6f..5961598606dc15 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -302,8 +302,20 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
   return std::nullopt;
 }
 
-bool getMaxClusterRank(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "maxclusterrank", x);
+std::optional<unsigned> getClusterDimx(const Function &F) {
+  return findOneNVVMAnnotation(F, "cluster_dim_x");
+}
+
+std::optional<unsigned> getClusterDimy(const Function &F) {
+  return findOneNVVMAnnotation(F, "cluster_dim_y");
+}
+
+std::optional<unsigned> getClusterDimz(const Function &F) {
+  return findOneNVVMAnnotation(F, "cluster_dim_z");
+}
+
+std::optional<unsigned> getMaxClusterRank(const Function &F) {
+  return findOneNVVMAnnotation(F, "maxclusterrank");
 }
 
 std::optional<unsigned> getReqNTIDx(const Function &F) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index eebd91fefe4f03..508447a66b5b7b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -59,7 +59,11 @@ std::optional<unsigned> getReqNTIDy(const Function &);
 std::optional<unsigned> getReqNTIDz(const Function &);
 std::optional<unsigned> getReqNTID(const Function &);
 
-bool getMaxClusterRank(const Function &, unsigned &);
+std::optional<unsigned> getClusterDimx(const Function &F);
+std::optional<unsigned> getClusterDimy(const Function &F);
+std::optional<unsigned> getClusterDimz(const Function &F);
+
+std::optional<unsigned> getMaxClusterRank(const Function &);
 bool getMinCTASm(const Function &, unsigned &);
 bool getMaxNReg(const Function &, unsigned &);
 bool isKernelFunction(const Function &);
diff --git a/llvm/test/CodeGen/NVPTX/cluster-dim.ll b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
new file mode 100644
index 00000000000000..109c9891417c57
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
@@ -0,0 +1,17 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_90 | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_90 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
+
+; CHECK-LABEL: .entry kernel_func_clusterxyz
+define void @kernel_func_clusterxyz() {
+; CHECK: .explicitcluster
+; CHECK: .reqnctapercluster 3, 5, 7
+  ret void
+}
+
+
+!nvvm.annotations = !{!1, !2}
+
+!1 = !{ptr @kernel_func_clusterxyz, !"kernel", i32 1}
+!2 = !{ptr @kernel_func_clusterxyz, !"cluster_dim_x", i32 3, !"cluster_dim_y", i32 5, !"cluster_dim_z", i32 7}

>From 6d0d341da518ae09240c88be02b412a5e4ee32e8 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Sat, 21 Sep 2024 20:03:06 +0000
Subject: [PATCH 2/3] [NVPTX] cleanup annotation API to use optional

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 10 ++-
 llvm/lib/Target/NVPTX/NVPTXUtilities.cpp  | 76 +++++++++--------------
 llvm/lib/Target/NVPTX/NVPTXUtilities.h    |  8 +--
 3 files changed, 39 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 386829f61bcb04..a5cb8d2b4fd63d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -563,13 +563,11 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
     O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
       << ", " << Maxntidz.value_or(1) << "\n";
 
-  unsigned Mincta = 0;
-  if (getMinCTASm(F, Mincta))
-    O << ".minnctapersm " << Mincta << "\n";
+  if (const auto Mincta = getMinCTASm(F))
+    O << ".minnctapersm " << *Mincta << "\n";
 
-  unsigned Maxnreg = 0;
-  if (getMaxNReg(F, Maxnreg))
-    O << ".maxnreg " << Maxnreg << "\n";
+  if (const auto Maxnreg = getMaxNReg(F))
+    O << ".maxnreg " << *Maxnreg << "\n";
 
   // .maxclusterrank directive requires SM_90 or higher, make sure that we
   // filter it out for lower SM versions, as it causes a hard ptxas crash.
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 5961598606dc15..5543bcf105bf9c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -130,8 +130,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
   }
 }
 
-bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
-                           unsigned &retval) {
+std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
+                                              const std::string &prop) {
   auto &AC = getAnnotationCache();
   std::lock_guard<sys::Mutex> Guard(AC.Lock);
   const Module *m = gv->getParent();
@@ -140,17 +140,8 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
   else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
     cacheAnnotationFromMD(m, gv);
   if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
-    return false;
-  retval = AC.Cache[m][gv][prop][0];
-  return true;
-}
-
-static std::optional<unsigned>
-findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
-  unsigned RetVal;
-  if (findOneNVVMAnnotation(&GV, PropName, RetVal))
-    return RetVal;
-  return std::nullopt;
+    return std::nullopt;
+  return AC.Cache[m][gv][prop][0];
 }
 
 bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
@@ -170,9 +161,8 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
 
 bool isTexture(const Value &val) {
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "texture", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a texture symbol");
+    if (const auto Annot = findOneNVVMAnnotation(gv, "texture")) {
+      assert((*Annot == 1) && "Unexpected annotation on a texture symbol");
       return true;
     }
   }
@@ -181,9 +171,8 @@ bool isTexture(const Value &val) {
 
 bool isSurface(const Value &val) {
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "surface", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a surface symbol");
+    if (const auto Annot = findOneNVVMAnnotation(gv, "surface")) {
+      assert((*Annot == 1) && "Unexpected annotation on a surface symbol");
       return true;
     }
   }
@@ -224,9 +213,8 @@ bool isSampler(const Value &val) {
   const char *AnnotationName = "sampler";
 
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
+    if (const auto Annot = findOneNVVMAnnotation(gv, AnnotationName)) {
+      assert((*Annot == 1) && "Unexpected annotation on a sampler symbol");
       return true;
     }
   }
@@ -251,9 +239,8 @@ bool isImage(const Value &val) {
 
 bool isManaged(const Value &val) {
   if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "managed", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a managed symbol");
+    if (const auto Annot = findOneNVVMAnnotation(gv, "managed")) {
+      assert((*Annot == 1) && "Unexpected annotation on a managed symbol");
       return true;
     }
   }
@@ -276,15 +263,15 @@ std::string getSamplerName(const Value &val) {
 }
 
 std::optional<unsigned> getMaxNTIDx(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidx");
+  return findOneNVVMAnnotation(&F, "maxntidx");
 }
 
 std::optional<unsigned> getMaxNTIDy(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidy");
+  return findOneNVVMAnnotation(&F, "maxntidy");
 }
 
 std::optional<unsigned> getMaxNTIDz(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidz");
+  return findOneNVVMAnnotation(&F, "maxntidz");
 }
 
 std::optional<unsigned> getMaxNTID(const Function &F) {
@@ -303,31 +290,31 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
 }
 
 std::optional<unsigned> getClusterDimx(const Function &F) {
-  return findOneNVVMAnnotation(F, "cluster_dim_x");
+  return findOneNVVMAnnotation(&F, "cluster_dim_x");
 }
 
 std::optional<unsigned> getClusterDimy(const Function &F) {
-  return findOneNVVMAnnotation(F, "cluster_dim_y");
+  return findOneNVVMAnnotation(&F, "cluster_dim_y");
 }
 
 std::optional<unsigned> getClusterDimz(const Function &F) {
-  return findOneNVVMAnnotation(F, "cluster_dim_z");
+  return findOneNVVMAnnotation(&F, "cluster_dim_z");
 }
 
 std::optional<unsigned> getMaxClusterRank(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxclusterrank");
+  return findOneNVVMAnnotation(&F, "maxclusterrank");
 }
 
 std::optional<unsigned> getReqNTIDx(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidx");
+  return findOneNVVMAnnotation(&F, "reqntidx");
 }
 
 std::optional<unsigned> getReqNTIDy(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidy");
+  return findOneNVVMAnnotation(&F, "reqntidy");
 }
 
 std::optional<unsigned> getReqNTIDz(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidz");
+  return findOneNVVMAnnotation(&F, "reqntidz");
 }
 
 std::optional<unsigned> getReqNTID(const Function &F) {
@@ -340,21 +327,20 @@ std::optional<unsigned> getReqNTID(const Function &F) {
   return std::nullopt;
 }
 
-bool getMinCTASm(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "minctasm", x);
+std::optional<unsigned> getMinCTASm(const Function &F) {
+  return findOneNVVMAnnotation(&F, "minctasm");
 }
 
-bool getMaxNReg(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "maxnreg", x);
+std::optional<unsigned> getMaxNReg(const Function &F) {
+  return findOneNVVMAnnotation(&F, "maxnreg");
 }
 
 bool isKernelFunction(const Function &F) {
-  unsigned x = 0;
-  if (!findOneNVVMAnnotation(&F, "kernel", x)) {
-    // There is no NVVM metadata, check the calling convention
-    return F.getCallingConv() == CallingConv::PTX_Kernel;
-  }
-  return (x == 1);
+  if (const auto x = findOneNVVMAnnotation(&F, "kernel"))
+    return (*x == 1);
+
+  // There is no NVVM metadata, check the calling convention
+  return F.getCallingConv() == CallingConv::PTX_Kernel;
 }
 
 MaybeAlign getAlign(const Function &F, unsigned Index) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 508447a66b5b7b..3755814f3ea23c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -31,8 +31,8 @@ class TargetMachine;
 
 void clearAnnotationCache(const Module *);
 
-bool findOneNVVMAnnotation(const GlobalValue *, const std::string &,
-                           unsigned &);
+std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *,
+                                              const std::string &);
 bool findAllNVVMAnnotation(const GlobalValue *, const std::string &,
                            std::vector<unsigned> &);
 
@@ -64,8 +64,8 @@ std::optional<unsigned> getClusterDimy(const Function &F);
 std::optional<unsigned> getClusterDimz(const Function &F);
 
 std::optional<unsigned> getMaxClusterRank(const Function &);
-bool getMinCTASm(const Function &, unsigned &);
-bool getMaxNReg(const Function &, unsigned &);
+std::optional<unsigned> getMinCTASm(const Function &);
+std::optional<unsigned> getMaxNReg(const Function &);
 bool isKernelFunction(const Function &);
 bool isParamGridConstant(const Value &);
 

>From ae3753bf64208df5e2ce48bd1be09c9427a365b7 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Mon, 23 Sep 2024 15:41:01 +0000
Subject: [PATCH 3/3] address comments

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 10 +++++-----
 llvm/test/CodeGen/NVPTX/cluster-dim.ll    |  4 +---
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index a5cb8d2b4fd63d..bd5f99bf632bc7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -583,15 +583,15 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
       O << ".explicitcluster\n";
       if (ClusterX.value_or(1) != 0) {
         assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
-               "clusterx != 0 implies clustery and clusterz should be non-zero "
-               "as well");
+               "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
+               "should be non-zero as well");
 
         O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
           << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
       } else {
-        assert(
-            !ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
-            "clusterx == 0 implies clustery and clusterz should be 0 as well");
+        assert(!ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
+               "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
+               "should be 0 as well");
       }
     }
     if (auto Maxclusterrank = getMaxClusterRank(F))
diff --git a/llvm/test/CodeGen/NVPTX/cluster-dim.ll b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
index 109c9891417c57..9c11a10acccb14 100644
--- a/llvm/test/CodeGen/NVPTX/cluster-dim.ll
+++ b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
@@ -1,7 +1,5 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_90 | FileCheck %s
 ; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
-; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_90 | %ptxas-verify %}
-; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
+; RUN: %if ptxas-12.0 %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
 
 ; CHECK-LABEL: .entry kernel_func_clusterxyz
 define void @kernel_func_clusterxyz() {



More information about the llvm-commits mailing list