[llvm] [NVPTX] Support !"cluster_dim_{x,y,z}" metadata (PR #109548)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 23 08:42:13 PDT 2024


================
@@ -563,21 +563,40 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
     O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
       << ", " << Maxntidz.value_or(1) << "\n";
 
-  unsigned Mincta = 0;
-  if (getMinCTASm(F, Mincta))
-    O << ".minnctapersm " << Mincta << "\n";
+  if (const auto Mincta = getMinCTASm(F))
+    O << ".minnctapersm " << *Mincta << "\n";
 
-  unsigned Maxnreg = 0;
-  if (getMaxNReg(F, Maxnreg))
-    O << ".maxnreg " << Maxnreg << "\n";
+  if (const auto Maxnreg = getMaxNReg(F))
+    O << ".maxnreg " << *Maxnreg << "\n";
 
   // .maxclusterrank directive requires SM_90 or higher, make sure that we
   // filter it out for lower SM versions, as it causes a hard ptxas crash.
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
-  unsigned Maxclusterrank = 0;
-  if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
-    O << ".maxclusterrank " << Maxclusterrank << "\n";
+
+  if (STI->getSmVersion() >= 90) {
+    std::optional<unsigned> ClusterX = getClusterDimx(F);
+    std::optional<unsigned> ClusterY = getClusterDimy(F);
+    std::optional<unsigned> ClusterZ = getClusterDimz(F);
+
+    if (ClusterX || ClusterY || ClusterZ) {
+      O << ".explicitcluster\n";
+      if (ClusterX.value_or(1) != 0) {
+        assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
+               "clusterx != 0 implies clustery and clusterz should be non-zero "
----------------
AlexMaclean wrote:

I've renamed the metadata to "cluster_dim_{x,y,z}".

https://github.com/llvm/llvm-project/pull/109548


More information about the llvm-commits mailing list