[llvm] 489acb2 - [NVPTX][NFC] Refactor utilities to use std::optional (#109883)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 24 19:31:43 PDT 2024
Author: Alex MacLean
Date: 2024-09-24T19:31:40-07:00
New Revision: 489acb2401b51d940fcdbe965d4a5b2d39168b96
URL: https://github.com/llvm/llvm-project/commit/489acb2401b51d940fcdbe965d4a5b2d39168b96
DIFF: https://github.com/llvm/llvm-project/commit/489acb2401b51d940fcdbe965d4a5b2d39168b96.diff
LOG: [NVPTX][NFC] Refactor utilities to use std::optional (#109883)
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
llvm/lib/Target/NVPTX/NVPTXUtilities.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 38c51666139a89..9bcc911b6c3451 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -563,21 +563,19 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
<< ", " << Maxntidz.value_or(1) << "\n";
- unsigned Mincta = 0;
- if (getMinCTASm(F, Mincta))
- O << ".minnctapersm " << Mincta << "\n";
+ if (const auto Mincta = getMinCTASm(F))
+ O << ".minnctapersm " << *Mincta << "\n";
- unsigned Maxnreg = 0;
- if (getMaxNReg(F, Maxnreg))
- O << ".maxnreg " << Maxnreg << "\n";
+ if (const auto Maxnreg = getMaxNReg(F))
+ O << ".maxnreg " << *Maxnreg << "\n";
// .maxclusterrank directive requires SM_90 or higher, make sure that we
// filter it out for lower SM versions, as it causes a hard ptxas crash.
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
- unsigned Maxclusterrank = 0;
- if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
- O << ".maxclusterrank " << Maxclusterrank << "\n";
+ if (STI->getSmVersion() >= 90)
+ if (const auto Maxclusterrank = getMaxClusterRank(F))
+ O << ".maxclusterrank " << *Maxclusterrank << "\n";
}
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 80361744fd5b6f..be1c87d07f4ded 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -13,6 +13,7 @@
#include "NVPTXUtilities.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
@@ -130,8 +131,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
}
}
-bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
- unsigned &retval) {
+static std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
+ const std::string &prop) {
auto &AC = getAnnotationCache();
std::lock_guard<sys::Mutex> Guard(AC.Lock);
const Module *m = gv->getParent();
@@ -140,21 +141,13 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
cacheAnnotationFromMD(m, gv);
if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
- return false;
- retval = AC.Cache[m][gv][prop][0];
- return true;
-}
-
-static std::optional<unsigned>
-findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
- unsigned RetVal;
- if (findOneNVVMAnnotation(&GV, PropName, RetVal))
- return RetVal;
- return std::nullopt;
+ return std::nullopt;
+ return AC.Cache[m][gv][prop][0];
}
-bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
- std::vector<unsigned> &retval) {
+static bool findAllNVVMAnnotation(const GlobalValue *gv,
+ const std::string &prop,
+ std::vector<unsigned> &retval) {
auto &AC = getAnnotationCache();
std::lock_guard<sys::Mutex> Guard(AC.Lock);
const Module *m = gv->getParent();
@@ -168,25 +161,13 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
return true;
}
-bool isTexture(const Value &val) {
- if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "texture", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a texture symbol");
+static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) {
+ if (const auto *GV = dyn_cast<GlobalValue>(&V))
+ if (const auto Annot = findOneNVVMAnnotation(GV, Prop)) {
+ assert((*Annot == 1) && "Unexpected annotation on a symbol");
return true;
}
- }
- return false;
-}
-bool isSurface(const Value &val) {
- if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "surface", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a surface symbol");
- return true;
- }
- }
return false;
}
@@ -220,71 +201,60 @@ bool isParamGridConstant(const Value &V) {
return false;
}
-bool isSampler(const Value &val) {
+bool isTexture(const Value &V) { return globalHasNVVMAnnotation(V, "texture"); }
+
+bool isSurface(const Value &V) { return globalHasNVVMAnnotation(V, "surface"); }
+
+bool isSampler(const Value &V) {
const char *AnnotationName = "sampler";
- if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
- return true;
- }
- }
- return argHasNVVMAnnotation(val, AnnotationName);
+ return globalHasNVVMAnnotation(V, AnnotationName) ||
+ argHasNVVMAnnotation(V, AnnotationName);
}
-bool isImageReadOnly(const Value &val) {
- return argHasNVVMAnnotation(val, "rdoimage");
+bool isImageReadOnly(const Value &V) {
+ return argHasNVVMAnnotation(V, "rdoimage");
}
-bool isImageWriteOnly(const Value &val) {
- return argHasNVVMAnnotation(val, "wroimage");
+bool isImageWriteOnly(const Value &V) {
+ return argHasNVVMAnnotation(V, "wroimage");
}
-bool isImageReadWrite(const Value &val) {
- return argHasNVVMAnnotation(val, "rdwrimage");
+bool isImageReadWrite(const Value &V) {
+ return argHasNVVMAnnotation(V, "rdwrimage");
}
-bool isImage(const Value &val) {
- return isImageReadOnly(val) || isImageWriteOnly(val) || isImageReadWrite(val);
+bool isImage(const Value &V) {
+ return isImageReadOnly(V) || isImageWriteOnly(V) || isImageReadWrite(V);
}
-bool isManaged(const Value &val) {
- if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "managed", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a managed symbol");
- return true;
- }
- }
- return false;
-}
+bool isManaged(const Value &V) { return globalHasNVVMAnnotation(V, "managed"); }
-std::string getTextureName(const Value &val) {
- assert(val.hasName() && "Found texture variable with no name");
- return std::string(val.getName());
+StringRef getTextureName(const Value &V) {
+ assert(V.hasName() && "Found texture variable with no name");
+ return V.getName();
}
-std::string getSurfaceName(const Value &val) {
- assert(val.hasName() && "Found surface variable with no name");
- return std::string(val.getName());
+StringRef getSurfaceName(const Value &V) {
+ assert(V.hasName() && "Found surface variable with no name");
+ return V.getName();
}
-std::string getSamplerName(const Value &val) {
- assert(val.hasName() && "Found sampler variable with no name");
- return std::string(val.getName());
+StringRef getSamplerName(const Value &V) {
+ assert(V.hasName() && "Found sampler variable with no name");
+ return V.getName();
}
std::optional<unsigned> getMaxNTIDx(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidx");
+ return findOneNVVMAnnotation(&F, "maxntidx");
}
std::optional<unsigned> getMaxNTIDy(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidy");
+ return findOneNVVMAnnotation(&F, "maxntidy");
}
std::optional<unsigned> getMaxNTIDz(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidz");
+ return findOneNVVMAnnotation(&F, "maxntidz");
}
std::optional<unsigned> getMaxNTID(const Function &F) {
@@ -302,20 +272,20 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
return std::nullopt;
}
-bool getMaxClusterRank(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "maxclusterrank", x);
+std::optional<unsigned> getMaxClusterRank(const Function &F) {
+ return findOneNVVMAnnotation(&F, "maxclusterrank");
}
std::optional<unsigned> getReqNTIDx(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidx");
+ return findOneNVVMAnnotation(&F, "reqntidx");
}
std::optional<unsigned> getReqNTIDy(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidy");
+ return findOneNVVMAnnotation(&F, "reqntidy");
}
std::optional<unsigned> getReqNTIDz(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidz");
+ return findOneNVVMAnnotation(&F, "reqntidz");
}
std::optional<unsigned> getReqNTID(const Function &F) {
@@ -328,21 +298,20 @@ std::optional<unsigned> getReqNTID(const Function &F) {
return std::nullopt;
}
-bool getMinCTASm(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "minctasm", x);
+std::optional<unsigned> getMinCTASm(const Function &F) {
+ return findOneNVVMAnnotation(&F, "minctasm");
}
-bool getMaxNReg(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "maxnreg", x);
+std::optional<unsigned> getMaxNReg(const Function &F) {
+ return findOneNVVMAnnotation(&F, "maxnreg");
}
bool isKernelFunction(const Function &F) {
- unsigned x = 0;
- if (!findOneNVVMAnnotation(&F, "kernel", x)) {
- // There is no NVVM metadata, check the calling convention
- return F.getCallingConv() == CallingConv::PTX_Kernel;
- }
- return (x == 1);
+ if (const auto X = findOneNVVMAnnotation(&F, "kernel"))
+ return (*X == 1);
+
+ // There is no NVVM metadata, check the calling convention
+ return F.getCallingConv() == CallingConv::PTX_Kernel;
}
MaybeAlign getAlign(const Function &F, unsigned Index) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 938b9b04b7a449..cf15dff85cbde0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -32,11 +32,6 @@ class TargetMachine;
void clearAnnotationCache(const Module *);
-bool findOneNVVMAnnotation(const GlobalValue *, const std::string &,
- unsigned &);
-bool findAllNVVMAnnotation(const GlobalValue *, const std::string &,
- std::vector<unsigned> &);
-
bool isTexture(const Value &);
bool isSurface(const Value &);
bool isSampler(const Value &);
@@ -46,23 +41,23 @@ bool isImageWriteOnly(const Value &);
bool isImageReadWrite(const Value &);
bool isManaged(const Value &);
-std::string getTextureName(const Value &);
-std::string getSurfaceName(const Value &);
-std::string getSamplerName(const Value &);
+StringRef getTextureName(const Value &);
+StringRef getSurfaceName(const Value &);
+StringRef getSamplerName(const Value &);
std::optional<unsigned> getMaxNTIDx(const Function &);
std::optional<unsigned> getMaxNTIDy(const Function &);
std::optional<unsigned> getMaxNTIDz(const Function &);
-std::optional<unsigned> getMaxNTID(const Function &F);
+std::optional<unsigned> getMaxNTID(const Function &);
std::optional<unsigned> getReqNTIDx(const Function &);
std::optional<unsigned> getReqNTIDy(const Function &);
std::optional<unsigned> getReqNTIDz(const Function &);
std::optional<unsigned> getReqNTID(const Function &);
-bool getMaxClusterRank(const Function &, unsigned &);
-bool getMinCTASm(const Function &, unsigned &);
-bool getMaxNReg(const Function &, unsigned &);
+std::optional<unsigned> getMaxClusterRank(const Function &);
+std::optional<unsigned> getMinCTASm(const Function &);
+std::optional<unsigned> getMaxNReg(const Function &);
bool isKernelFunction(const Function &);
bool isParamGridConstant(const Value &);
@@ -75,10 +70,9 @@ Function *getMaybeBitcastedCallee(const CallBase *CB);
inline unsigned promoteScalarArgumentSize(unsigned size) {
if (size <= 32)
return 32;
- else if (size <= 64)
+ if (size <= 64)
return 64;
- else
- return size;
+ return size;
}
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
More information about the llvm-commits mailing list