[llvm] [NVPTX] Support !"cluster_dim_{x,y,z}" metadata (PR #109548)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 21 13:06:01 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-nvptx
Author: Alex MacLean (AlexMaclean)
<details>
<summary>Changes</summary>
Add support for `!"cluster_dim_{x,y,z}"` metadata, which allows the cluster dimensions of a kernel function to be specified in LLVM IR.
When targeting sm_90 or higher, if any of these metadata entries are present, the `.explicitcluster` PTX directive is emitted and the specified dimensions are lowered to the `.reqnctapercluster` directive; any dimension that is not specified defaults to 1. For more details see: [PTX ISA: 11.7. Cluster Dimension Directives](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-dimension-directives)
---
Full diff: https://github.com/llvm/llvm-project/pull/109548.diff
4 Files Affected:
- (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (+28-9)
- (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.cpp (+41-43)
- (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.h (+9-5)
- (added) llvm/test/CodeGen/NVPTX/cluster-dim.ll (+17)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index d7197a7923eaf0..a5cb8d2b4fd63d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -563,21 +563,40 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
<< ", " << Maxntidz.value_or(1) << "\n";
- unsigned Mincta = 0;
- if (getMinCTASm(F, Mincta))
- O << ".minnctapersm " << Mincta << "\n";
+ if (const auto Mincta = getMinCTASm(F))
+ O << ".minnctapersm " << *Mincta << "\n";
- unsigned Maxnreg = 0;
- if (getMaxNReg(F, Maxnreg))
- O << ".maxnreg " << Maxnreg << "\n";
+ if (const auto Maxnreg = getMaxNReg(F))
+ O << ".maxnreg " << *Maxnreg << "\n";
// .maxclusterrank directive requires SM_90 or higher, make sure that we
// filter it out for lower SM versions, as it causes a hard ptxas crash.
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
- unsigned Maxclusterrank = 0;
- if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
- O << ".maxclusterrank " << Maxclusterrank << "\n";
+
+ if (STI->getSmVersion() >= 90) {
+ std::optional<unsigned> ClusterX = getClusterDimx(F);
+ std::optional<unsigned> ClusterY = getClusterDimy(F);
+ std::optional<unsigned> ClusterZ = getClusterDimz(F);
+
+ if (ClusterX || ClusterY || ClusterZ) {
+ O << ".explicitcluster\n";
+ if (ClusterX.value_or(1) != 0) {
+ assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
+ "clusterx != 0 implies clustery and clusterz should be non-zero "
+ "as well");
+
+ O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
+ << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
+ } else {
+ assert(
+ !ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
+ "clusterx == 0 implies clustery and clusterz should be 0 as well");
+ }
+ }
+ if (auto Maxclusterrank = getMaxClusterRank(F))
+ O << ".maxclusterrank " << *Maxclusterrank << "\n";
+ }
}
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 80361744fd5b6f..5543bcf105bf9c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -130,8 +130,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
}
}
-bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
- unsigned &retval) {
+std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
+ const std::string &prop) {
auto &AC = getAnnotationCache();
std::lock_guard<sys::Mutex> Guard(AC.Lock);
const Module *m = gv->getParent();
@@ -140,17 +140,8 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
cacheAnnotationFromMD(m, gv);
if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
- return false;
- retval = AC.Cache[m][gv][prop][0];
- return true;
-}
-
-static std::optional<unsigned>
-findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
- unsigned RetVal;
- if (findOneNVVMAnnotation(&GV, PropName, RetVal))
- return RetVal;
- return std::nullopt;
+ return std::nullopt;
+ return AC.Cache[m][gv][prop][0];
}
bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
@@ -170,9 +161,8 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
bool isTexture(const Value &val) {
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "texture", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a texture symbol");
+ if (const auto Annot = findOneNVVMAnnotation(gv, "texture")) {
+ assert((*Annot == 1) && "Unexpected annotation on a texture symbol");
return true;
}
}
@@ -181,9 +171,8 @@ bool isTexture(const Value &val) {
bool isSurface(const Value &val) {
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "surface", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a surface symbol");
+ if (const auto Annot = findOneNVVMAnnotation(gv, "surface")) {
+ assert((*Annot == 1) && "Unexpected annotation on a surface symbol");
return true;
}
}
@@ -224,9 +213,8 @@ bool isSampler(const Value &val) {
const char *AnnotationName = "sampler";
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
+ if (const auto Annot = findOneNVVMAnnotation(gv, AnnotationName)) {
+ assert((*Annot == 1) && "Unexpected annotation on a sampler symbol");
return true;
}
}
@@ -251,9 +239,8 @@ bool isImage(const Value &val) {
bool isManaged(const Value &val) {
if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "managed", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a managed symbol");
+ if (const auto Annot = findOneNVVMAnnotation(gv, "managed")) {
+ assert((*Annot == 1) && "Unexpected annotation on a managed symbol");
return true;
}
}
@@ -276,15 +263,15 @@ std::string getSamplerName(const Value &val) {
}
std::optional<unsigned> getMaxNTIDx(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidx");
+ return findOneNVVMAnnotation(&F, "maxntidx");
}
std::optional<unsigned> getMaxNTIDy(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidy");
+ return findOneNVVMAnnotation(&F, "maxntidy");
}
std::optional<unsigned> getMaxNTIDz(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidz");
+ return findOneNVVMAnnotation(&F, "maxntidz");
}
std::optional<unsigned> getMaxNTID(const Function &F) {
@@ -302,20 +289,32 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
return std::nullopt;
}
-bool getMaxClusterRank(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "maxclusterrank", x);
+std::optional<unsigned> getClusterDimx(const Function &F) {
+ return findOneNVVMAnnotation(&F, "cluster_dim_x");
+}
+
+std::optional<unsigned> getClusterDimy(const Function &F) {
+ return findOneNVVMAnnotation(&F, "cluster_dim_y");
+}
+
+std::optional<unsigned> getClusterDimz(const Function &F) {
+ return findOneNVVMAnnotation(&F, "cluster_dim_z");
+}
+
+std::optional<unsigned> getMaxClusterRank(const Function &F) {
+ return findOneNVVMAnnotation(&F, "maxclusterrank");
}
std::optional<unsigned> getReqNTIDx(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidx");
+ return findOneNVVMAnnotation(&F, "reqntidx");
}
std::optional<unsigned> getReqNTIDy(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidy");
+ return findOneNVVMAnnotation(&F, "reqntidy");
}
std::optional<unsigned> getReqNTIDz(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidz");
+ return findOneNVVMAnnotation(&F, "reqntidz");
}
std::optional<unsigned> getReqNTID(const Function &F) {
@@ -328,21 +327,20 @@ std::optional<unsigned> getReqNTID(const Function &F) {
return std::nullopt;
}
-bool getMinCTASm(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "minctasm", x);
+std::optional<unsigned> getMinCTASm(const Function &F) {
+ return findOneNVVMAnnotation(&F, "minctasm");
}
-bool getMaxNReg(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "maxnreg", x);
+std::optional<unsigned> getMaxNReg(const Function &F) {
+ return findOneNVVMAnnotation(&F, "maxnreg");
}
bool isKernelFunction(const Function &F) {
- unsigned x = 0;
- if (!findOneNVVMAnnotation(&F, "kernel", x)) {
- // There is no NVVM metadata, check the calling convention
- return F.getCallingConv() == CallingConv::PTX_Kernel;
- }
- return (x == 1);
+ if (const auto x = findOneNVVMAnnotation(&F, "kernel"))
+ return (*x == 1);
+
+ // There is no NVVM metadata, check the calling convention
+ return F.getCallingConv() == CallingConv::PTX_Kernel;
}
MaybeAlign getAlign(const Function &F, unsigned Index) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index eebd91fefe4f03..3755814f3ea23c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -31,8 +31,8 @@ class TargetMachine;
void clearAnnotationCache(const Module *);
-bool findOneNVVMAnnotation(const GlobalValue *, const std::string &,
- unsigned &);
+std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *,
+ const std::string &);
bool findAllNVVMAnnotation(const GlobalValue *, const std::string &,
std::vector<unsigned> &);
@@ -59,9 +59,13 @@ std::optional<unsigned> getReqNTIDy(const Function &);
std::optional<unsigned> getReqNTIDz(const Function &);
std::optional<unsigned> getReqNTID(const Function &);
-bool getMaxClusterRank(const Function &, unsigned &);
-bool getMinCTASm(const Function &, unsigned &);
-bool getMaxNReg(const Function &, unsigned &);
+std::optional<unsigned> getClusterDimx(const Function &F);
+std::optional<unsigned> getClusterDimy(const Function &F);
+std::optional<unsigned> getClusterDimz(const Function &F);
+
+std::optional<unsigned> getMaxClusterRank(const Function &);
+std::optional<unsigned> getMinCTASm(const Function &);
+std::optional<unsigned> getMaxNReg(const Function &);
bool isKernelFunction(const Function &);
bool isParamGridConstant(const Value &);
diff --git a/llvm/test/CodeGen/NVPTX/cluster-dim.ll b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
new file mode 100644
index 00000000000000..109c9891417c57
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
@@ -0,0 +1,17 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_90 | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_90 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
+
+; CHECK-LABEL: .entry kernel_func_clusterxyz
+define void @kernel_func_clusterxyz() {
+; CHECK: .explicitcluster
+; CHECK: .reqnctapercluster 3, 5, 7
+ ret void
+}
+
+
+!nvvm.annotations = !{!1, !2}
+
+!1 = !{ptr @kernel_func_clusterxyz, !"kernel", i32 1}
+!2 = !{ptr @kernel_func_clusterxyz, !"cluster_dim_x", i32 3, !"cluster_dim_y", i32 5, !"cluster_dim_z", i32 7}
``````````
</details>
https://github.com/llvm/llvm-project/pull/109548
More information about the llvm-commits mailing list