[llvm] 489acb2 - [NVPTX][NFC] Refactor utilities to use std::optional (#109883)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 24 19:31:43 PDT 2024


Author: Alex MacLean
Date: 2024-09-24T19:31:40-07:00
New Revision: 489acb2401b51d940fcdbe965d4a5b2d39168b96

URL: https://github.com/llvm/llvm-project/commit/489acb2401b51d940fcdbe965d4a5b2d39168b96
DIFF: https://github.com/llvm/llvm-project/commit/489acb2401b51d940fcdbe965d4a5b2d39168b96.diff

LOG: [NVPTX][NFC] Refactor utilities to use std::optional (#109883)

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
    llvm/lib/Target/NVPTX/NVPTXUtilities.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 38c51666139a89..9bcc911b6c3451 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -563,21 +563,19 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
     O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
       << ", " << Maxntidz.value_or(1) << "\n";
 
-  unsigned Mincta = 0;
-  if (getMinCTASm(F, Mincta))
-    O << ".minnctapersm " << Mincta << "\n";
+  if (const auto Mincta = getMinCTASm(F))
+    O << ".minnctapersm " << *Mincta << "\n";
 
-  unsigned Maxnreg = 0;
-  if (getMaxNReg(F, Maxnreg))
-    O << ".maxnreg " << Maxnreg << "\n";
+  if (const auto Maxnreg = getMaxNReg(F))
+    O << ".maxnreg " << *Maxnreg << "\n";
 
   // .maxclusterrank directive requires SM_90 or higher, make sure that we
   // filter it out for lower SM versions, as it causes a hard ptxas crash.
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
-  unsigned Maxclusterrank = 0;
-  if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
-    O << ".maxclusterrank " << Maxclusterrank << "\n";
+  if (STI->getSmVersion() >= 90)
+    if (const auto Maxclusterrank = getMaxClusterRank(F))
+      O << ".maxclusterrank " << *Maxclusterrank << "\n";
 }
 
 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {

diff  --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 80361744fd5b6f..be1c87d07f4ded 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -13,6 +13,7 @@
 #include "NVPTXUtilities.h"
 #include "NVPTX.h"
 #include "NVPTXTargetMachine.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
@@ -130,8 +131,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
   }
 }
 
-bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
-                           unsigned &retval) {
+static std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
+                                                     const std::string &prop) {
   auto &AC = getAnnotationCache();
   std::lock_guard<sys::Mutex> Guard(AC.Lock);
   const Module *m = gv->getParent();
@@ -140,21 +141,13 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
   else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
     cacheAnnotationFromMD(m, gv);
   if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
-    return false;
-  retval = AC.Cache[m][gv][prop][0];
-  return true;
-}
-
-static std::optional<unsigned>
-findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
-  unsigned RetVal;
-  if (findOneNVVMAnnotation(&GV, PropName, RetVal))
-    return RetVal;
-  return std::nullopt;
+    return std::nullopt;
+  return AC.Cache[m][gv][prop][0];
 }
 
-bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
-                           std::vector<unsigned> &retval) {
+static bool findAllNVVMAnnotation(const GlobalValue *gv,
+                                  const std::string &prop,
+                                  std::vector<unsigned> &retval) {
   auto &AC = getAnnotationCache();
   std::lock_guard<sys::Mutex> Guard(AC.Lock);
   const Module *m = gv->getParent();
@@ -168,25 +161,13 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
   return true;
 }
 
-bool isTexture(const Value &val) {
-  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "texture", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a texture symbol");
+static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) {
+  if (const auto *GV = dyn_cast<GlobalValue>(&V))
+    if (const auto Annot = findOneNVVMAnnotation(GV, Prop)) {
+      assert((*Annot == 1) && "Unexpected annotation on a symbol");
       return true;
     }
-  }
-  return false;
-}
 
-bool isSurface(const Value &val) {
-  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "surface", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a surface symbol");
-      return true;
-    }
-  }
   return false;
 }
 
@@ -220,71 +201,60 @@ bool isParamGridConstant(const Value &V) {
   return false;
 }
 
-bool isSampler(const Value &val) {
+bool isTexture(const Value &V) { return globalHasNVVMAnnotation(V, "texture"); }
+
+bool isSurface(const Value &V) { return globalHasNVVMAnnotation(V, "surface"); }
+
+bool isSampler(const Value &V) {
   const char *AnnotationName = "sampler";
 
-  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
-      return true;
-    }
-  }
-  return argHasNVVMAnnotation(val, AnnotationName);
+  return globalHasNVVMAnnotation(V, AnnotationName) ||
+         argHasNVVMAnnotation(V, AnnotationName);
 }
 
-bool isImageReadOnly(const Value &val) {
-  return argHasNVVMAnnotation(val, "rdoimage");
+bool isImageReadOnly(const Value &V) {
+  return argHasNVVMAnnotation(V, "rdoimage");
 }
 
-bool isImageWriteOnly(const Value &val) {
-  return argHasNVVMAnnotation(val, "wroimage");
+bool isImageWriteOnly(const Value &V) {
+  return argHasNVVMAnnotation(V, "wroimage");
 }
 
-bool isImageReadWrite(const Value &val) {
-  return argHasNVVMAnnotation(val, "rdwrimage");
+bool isImageReadWrite(const Value &V) {
+  return argHasNVVMAnnotation(V, "rdwrimage");
 }
 
-bool isImage(const Value &val) {
-  return isImageReadOnly(val) || isImageWriteOnly(val) || isImageReadWrite(val);
+bool isImage(const Value &V) {
+  return isImageReadOnly(V) || isImageWriteOnly(V) || isImageReadWrite(V);
 }
 
-bool isManaged(const Value &val) {
-  if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "managed", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a managed symbol");
-      return true;
-    }
-  }
-  return false;
-}
+bool isManaged(const Value &V) { return globalHasNVVMAnnotation(V, "managed"); }
 
-std::string getTextureName(const Value &val) {
-  assert(val.hasName() && "Found texture variable with no name");
-  return std::string(val.getName());
+StringRef getTextureName(const Value &V) {
+  assert(V.hasName() && "Found texture variable with no name");
+  return V.getName();
 }
 
-std::string getSurfaceName(const Value &val) {
-  assert(val.hasName() && "Found surface variable with no name");
-  return std::string(val.getName());
+StringRef getSurfaceName(const Value &V) {
+  assert(V.hasName() && "Found surface variable with no name");
+  return V.getName();
 }
 
-std::string getSamplerName(const Value &val) {
-  assert(val.hasName() && "Found sampler variable with no name");
-  return std::string(val.getName());
+StringRef getSamplerName(const Value &V) {
+  assert(V.hasName() && "Found sampler variable with no name");
+  return V.getName();
 }
 
 std::optional<unsigned> getMaxNTIDx(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidx");
+  return findOneNVVMAnnotation(&F, "maxntidx");
 }
 
 std::optional<unsigned> getMaxNTIDy(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidy");
+  return findOneNVVMAnnotation(&F, "maxntidy");
 }
 
 std::optional<unsigned> getMaxNTIDz(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidz");
+  return findOneNVVMAnnotation(&F, "maxntidz");
 }
 
 std::optional<unsigned> getMaxNTID(const Function &F) {
@@ -302,20 +272,20 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
   return std::nullopt;
 }
 
-bool getMaxClusterRank(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "maxclusterrank", x);
+std::optional<unsigned> getMaxClusterRank(const Function &F) {
+  return findOneNVVMAnnotation(&F, "maxclusterrank");
 }
 
 std::optional<unsigned> getReqNTIDx(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidx");
+  return findOneNVVMAnnotation(&F, "reqntidx");
 }
 
 std::optional<unsigned> getReqNTIDy(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidy");
+  return findOneNVVMAnnotation(&F, "reqntidy");
 }
 
 std::optional<unsigned> getReqNTIDz(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidz");
+  return findOneNVVMAnnotation(&F, "reqntidz");
 }
 
 std::optional<unsigned> getReqNTID(const Function &F) {
@@ -328,21 +298,20 @@ std::optional<unsigned> getReqNTID(const Function &F) {
   return std::nullopt;
 }
 
-bool getMinCTASm(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "minctasm", x);
+std::optional<unsigned> getMinCTASm(const Function &F) {
+  return findOneNVVMAnnotation(&F, "minctasm");
 }
 
-bool getMaxNReg(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "maxnreg", x);
+std::optional<unsigned> getMaxNReg(const Function &F) {
+  return findOneNVVMAnnotation(&F, "maxnreg");
 }
 
 bool isKernelFunction(const Function &F) {
-  unsigned x = 0;
-  if (!findOneNVVMAnnotation(&F, "kernel", x)) {
-    // There is no NVVM metadata, check the calling convention
-    return F.getCallingConv() == CallingConv::PTX_Kernel;
-  }
-  return (x == 1);
+  if (const auto X = findOneNVVMAnnotation(&F, "kernel"))
+    return (*X == 1);
+
+  // There is no NVVM metadata, check the calling convention
+  return F.getCallingConv() == CallingConv::PTX_Kernel;
 }
 
 MaybeAlign getAlign(const Function &F, unsigned Index) {

diff  --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 938b9b04b7a449..cf15dff85cbde0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -32,11 +32,6 @@ class TargetMachine;
 
 void clearAnnotationCache(const Module *);
 
-bool findOneNVVMAnnotation(const GlobalValue *, const std::string &,
-                           unsigned &);
-bool findAllNVVMAnnotation(const GlobalValue *, const std::string &,
-                           std::vector<unsigned> &);
-
 bool isTexture(const Value &);
 bool isSurface(const Value &);
 bool isSampler(const Value &);
@@ -46,23 +41,23 @@ bool isImageWriteOnly(const Value &);
 bool isImageReadWrite(const Value &);
 bool isManaged(const Value &);
 
-std::string getTextureName(const Value &);
-std::string getSurfaceName(const Value &);
-std::string getSamplerName(const Value &);
+StringRef getTextureName(const Value &);
+StringRef getSurfaceName(const Value &);
+StringRef getSamplerName(const Value &);
 
 std::optional<unsigned> getMaxNTIDx(const Function &);
 std::optional<unsigned> getMaxNTIDy(const Function &);
 std::optional<unsigned> getMaxNTIDz(const Function &);
-std::optional<unsigned> getMaxNTID(const Function &F);
+std::optional<unsigned> getMaxNTID(const Function &);
 
 std::optional<unsigned> getReqNTIDx(const Function &);
 std::optional<unsigned> getReqNTIDy(const Function &);
 std::optional<unsigned> getReqNTIDz(const Function &);
 std::optional<unsigned> getReqNTID(const Function &);
 
-bool getMaxClusterRank(const Function &, unsigned &);
-bool getMinCTASm(const Function &, unsigned &);
-bool getMaxNReg(const Function &, unsigned &);
+std::optional<unsigned> getMaxClusterRank(const Function &);
+std::optional<unsigned> getMinCTASm(const Function &);
+std::optional<unsigned> getMaxNReg(const Function &);
 bool isKernelFunction(const Function &);
 bool isParamGridConstant(const Value &);
 
@@ -75,10 +70,9 @@ Function *getMaybeBitcastedCallee(const CallBase *CB);
 inline unsigned promoteScalarArgumentSize(unsigned size) {
   if (size <= 32)
     return 32;
-  else if (size <= 64)
+  if (size <= 64)
     return 64;
-  else
-    return size;
+  return size;
 }
 
 bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);


        


More information about the llvm-commits mailing list