[llvm] [NVPTX] Improve NVVMReflect Efficiency (PR #134416)
Yonah Goldberg via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 7 13:05:18 PDT 2025
https://github.com/YonahGoldberg updated https://github.com/llvm/llvm-project/pull/134416
>From 2f640ab95abb9bdacbb63ca6f04bab7084cf9bad Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 4 Apr 2025 16:51:40 +0000
Subject: [PATCH 1/4] making nvvm reflect more efficient
---
llvm/lib/Target/NVPTX/NVPTX.h | 10 +-
llvm/lib/Target/NVPTX/NVPTXPassRegistry.def | 2 +-
llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 5 +-
llvm/lib/Target/NVPTX/NVVMReflect.cpp | 153 ++++++++++++-------
4 files changed, 110 insertions(+), 60 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 20a5bf46dc06b..8efa0bb546546 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -43,7 +43,7 @@ ModulePass *createNVPTXAssignValidGlobalNamesPass();
ModulePass *createGenericToNVVMLegacyPass();
ModulePass *createNVPTXCtorDtorLoweringLegacyPass();
FunctionPass *createNVVMIntrRangePass();
-FunctionPass *createNVVMReflectPass(unsigned int SmVersion);
+ModulePass *createNVVMReflectPass(unsigned int SmVersion);
MachineFunctionPass *createNVPTXPrologEpilogPass();
MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
FunctionPass *createNVPTXImageOptimizerPass();
@@ -60,12 +60,13 @@ struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
};
struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
- NVVMReflectPass();
- NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
- PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  NVVMReflectPass() : NVVMReflectPass(0) {}
+  NVVMReflectPass(unsigned SmVersion);
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
private:
- unsigned SmVersion;
+  // Seeded by the constructor with __CUDA_ARCH = SmVersion * 10.
+  StringMap<int> VarMap;
};
struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {
diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
index 34c79b8f77bae..1c813c2c51f70 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
+++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
@@ -18,6 +18,7 @@
#endif
MODULE_PASS("generic-to-nvvm", GenericToNVVMPass())
MODULE_PASS("nvptx-lower-ctor-dtor", NVPTXCtorDtorLoweringPass())
+MODULE_PASS("nvvm-reflect", NVVMReflectPass())
#undef MODULE_PASS
#ifndef FUNCTION_ANALYSIS
@@ -36,7 +37,6 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA())
#define FUNCTION_PASS(NAME, CREATE_PASS)
#endif
FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
-FUNCTION_PASS("nvvm-reflect", NVVMReflectPass())
FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this));
#undef FUNCTION_PASS
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 5bb168704bad0..e84b707725566 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -260,11 +260,12 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
PB.registerPipelineStartEPCallback(
[this](ModulePassManager &PM, OptimizationLevel Level) {
- FunctionPassManager FPM;
// We do not want to fold out calls to nvvm.reflect early if the user
// has not provided a target architecture just yet.
if (Subtarget.hasTargetName())
- FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
+ PM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
+
+ FunctionPassManager FPM;
// Note: NVVMIntrRangePass was causing numerical discrepancies at one
// point, if issues crop up, consider disabling.
FPM.addPass(NVVMIntrRangePass());
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 20b8bef1899b4..ababb7f7c9d1f 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-//===----------------------------------------------------------------------===//
+
//
// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
// with an integer.
@@ -25,7 +25,6 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
-#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
@@ -39,33 +38,47 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/StripGCRelocates.h"
#include <algorithm>
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
using namespace llvm;
-#define DEBUG_TYPE "nvptx-reflect"
+#define DEBUG_TYPE "nvvm-reflect"
namespace llvm {
void initializeNVVMReflectPass(PassRegistry &);
}
namespace {
-class NVVMReflect : public FunctionPass {
+class NVVMReflect : public ModulePass {
+private:
+ StringMap<int> VarMap;
+ /// Process a reflect function by finding all its uses and replacing them with
+ /// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
+ /// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
+ bool handleReflectFunction(Function *F);
+ void setVarMap(Module &M);
+
public:
static char ID;
- unsigned int SmVersion;
NVVMReflect() : NVVMReflect(0) {}
- explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {
+ // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
+ // metadata.
+ explicit NVVMReflect(unsigned int Sm) : ModulePass(ID) {
+ VarMap["__CUDA_ARCH"] = Sm * 10;
initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
}
-
- bool runOnFunction(Function &) override;
+ // This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
+ explicit NVVMReflect(const StringMap<int> &Mapping) : ModulePass(ID), VarMap(Mapping) {
+ initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
+ }
+ bool runOnModule(Module &M) override;
};
} // namespace
-FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
+ModulePass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
return new NVVMReflect(SmVersion);
}
@@ -78,27 +91,51 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
"Replace occurrences of __nvvm_reflect() calls with 0/1", false,
false)
-static bool runNVVMReflect(Function &F, unsigned SmVersion) {
- if (!NVVMReflectEnabled)
- return false;
+static cl::list<std::string>
+ ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
+ cl::desc("A list of string=num assignments"),
+ cl::ValueRequired);
- if (F.getName() == NVVM_REFLECT_FUNCTION ||
- F.getName() == NVVM_REFLECT_OCL_FUNCTION) {
- assert(F.isDeclaration() && "_reflect function should not have a body");
- assert(F.getReturnType()->isIntegerTy() &&
- "_reflect's return type should be integer");
- return false;
+/// The command line can look as follows :
+/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
+/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
+/// ReflectList vector. First, each of ReflectList[i] is 'split'
+/// using "," as the delimiter. Then each of this part is split
+/// using "=" as the delimiter.
+void NVVMReflect::setVarMap(Module &M) {
+ if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
+ M.getModuleFlag("nvvm-reflect-ftz")))
+ VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
+
+ for (unsigned I = 0, E = ReflectList.size(); I != E; ++I) {
+ LLVM_DEBUG(dbgs() << "Option : " << ReflectList[I] << "\n");
+ SmallVector<StringRef, 4> NameValList;
+ StringRef(ReflectList[I]).split(NameValList, ",");
+ for (unsigned J = 0, EJ = NameValList.size(); J != EJ; ++J) {
+ SmallVector<StringRef, 2> NameValPair;
+ NameValList[J].split(NameValPair, "=");
+ assert(NameValPair.size() == 2 && "name=val expected");
+ StringRef ValStr = NameValPair[1].trim();
+ int Val;
+ if (ValStr.getAsInteger(10, Val))
+ report_fatal_error("integer value expected");
+ VarMap[NameValPair[0]] = Val;
+ }
}
+}
+
+bool NVVMReflect::handleReflectFunction(Function *F) {
+ // Validate _reflect function
+ assert(F->isDeclaration() && "_reflect function should not have a body");
+ assert(F->getReturnType()->isIntegerTy() &&
+ "_reflect's return type should be integer");
SmallVector<Instruction *, 4> ToRemove;
SmallVector<Instruction *, 4> ToSimplify;
- // Go through the calls in this function. Each call to __nvvm_reflect or
- // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
- // First validate that. If the c-string corresponding to the ConstantArray can
- // be found successfully, see if it can be found in VarMap. If so, replace the
- // uses of CallInst with the value found in VarMap. If not, replace the use
- // with value 0.
+ // Go through the uses of the reflect function. Each use should be a CallInst
+ // with a ConstantArray argument. Replace the uses with the appropriate
+ // constant values.
// The IR for __nvvm_reflect calls differs between CUDA versions.
//
@@ -119,15 +156,10 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
//
// In this case, we get a Constant with a GlobalVariable operand and we need
// to dig deeper to find its initializer with the string we'll use for lookup.
- for (Instruction &I : instructions(F)) {
- CallInst *Call = dyn_cast<CallInst>(&I);
- if (!Call)
- continue;
- Function *Callee = Call->getCalledFunction();
- if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
- Callee->getName() != NVVM_REFLECT_OCL_FUNCTION &&
- Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
- continue;
+
+ for (User *U : F->users()) {
+ assert(isa<CallInst>(U) && "Only a call instruction can use _reflect");
+ CallInst *Call = cast<CallInst>(U);
// FIXME: Improve error handling here and elsewhere in this pass.
assert(Call->getNumOperands() == 2 &&
@@ -162,20 +194,15 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
"Format of _reflect function not recognized");
StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
+ // Remove the null terminator from the string
ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
int ReflectVal = 0; // The default value is 0
- if (ReflectArg == "__CUDA_FTZ") {
- // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag. Our
- // choice here must be kept in sync with AutoUpgrade, which uses the same
- // technique to detect whether ftz is enabled.
- if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
- F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
- ReflectVal = Flag->getSExtValue();
- } else if (ReflectArg == "__CUDA_ARCH") {
- ReflectVal = SmVersion * 10;
+ if (VarMap.contains(ReflectArg)) {
+ ReflectVal = VarMap[ReflectArg];
}
+ LLVM_DEBUG(dbgs() << "ReflectVal: " << ReflectVal << "\n");
// If the immediate user is a simple comparison we want to simplify it.
for (User *U : Call->users())
@@ -191,7 +218,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
// until we find a terminator that we can then remove.
while (!ToSimplify.empty()) {
Instruction *I = ToSimplify.pop_back_val();
- if (Constant *C = ConstantFoldInstruction(I, F.getDataLayout())) {
+ if (Constant *C = ConstantFoldInstruction(I, F->getDataLayout())) {
for (User *U : I->users())
if (Instruction *I = dyn_cast<Instruction>(U))
ToSimplify.push_back(I);
@@ -208,23 +235,45 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
// Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
// array. Filter out the duplicates before starting to erase from parent.
std::sort(ToRemove.begin(), ToRemove.end());
- auto NewLastIter = llvm::unique(ToRemove);
+ auto *NewLastIter = llvm::unique(ToRemove);
ToRemove.erase(NewLastIter, ToRemove.end());
for (Instruction *I : ToRemove)
I->eraseFromParent();
+ // Remove the __nvvm_reflect function from the module
+ F->eraseFromParent();
return ToRemove.size() > 0;
}
-bool NVVMReflect::runOnFunction(Function &F) {
- return runNVVMReflect(F, SmVersion);
-}
+bool NVVMReflect::runOnModule(Module &M) {
+ if (!NVVMReflectEnabled)
+ return false;
+
+ setVarMap(M);
-NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {}
+ bool Changed = false;
+ // Names of reflect function to find and replace
+ SmallVector<std::string, 3> ReflectNames = {
+ NVVM_REFLECT_FUNCTION, NVVM_REFLECT_OCL_FUNCTION,
+ Intrinsic::getName(Intrinsic::nvvm_reflect).str()};
+
+ // Process all reflect functions
+ for (const std::string &Name : ReflectNames) {
+ Function *ReflectFunction = M.getFunction(Name);
+ if (ReflectFunction) {
+ Changed |= handleReflectFunction(ReflectFunction);
+ }
+ }
+
+ return Changed;
+}
-PreservedAnalyses NVVMReflectPass::run(Function &F,
- FunctionAnalysisManager &AM) {
- return runNVVMReflect(F, SmVersion) ? PreservedAnalyses::none()
- : PreservedAnalyses::all();
+// Implementations for the pass that works with the new pass manager.
+NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
+ VarMap["__CUDA_ARCH"] = SmVersion * 10;
}
+PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
+ return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
+ : PreservedAnalyses::all();
+}
\ No newline at end of file
>From 10c19c83ed7eab6ee80b7721df2ee3d9932f9b10 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 4 Apr 2025 16:59:38 +0000
Subject: [PATCH 2/4] cleanup
---
llvm/lib/Target/NVPTX/NVVMReflect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index ababb7f7c9d1f..ee562f7ede92b 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-
+//===----------------------------------------------------------------------===//
//
// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
// with an integer.
>From 6241fd7baeffbbfde60f2937644a3f8d027529d1 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 4 Apr 2025 17:17:14 +0000
Subject: [PATCH 3/4] newline
---
llvm/lib/Target/NVPTX/NVVMReflect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index ee562f7ede92b..9a8e22032916d 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -276,4 +276,4 @@ NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
: PreservedAnalyses::all();
-}
\ No newline at end of file
+}
>From e86e7a5f8ac28ddffb07aaebf252f682fbc6efcc Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Mon, 7 Apr 2025 20:05:06 +0000
Subject: [PATCH 4/4] reflect improvements
---
llvm/lib/Target/NVPTX/NVVMReflect.cpp | 183 ++++++++++--------
.../CodeGen/NVPTX/nvvm-reflect-options.ll | 26 +++
2 files changed, 124 insertions(+), 85 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 9a8e22032916d..d96f4606ff969 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -4,6 +4,12 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
+// NVIDIA_COPYRIGHT_BEGIN
+//
+// Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA_COPYRIGHT_END
+//
//===----------------------------------------------------------------------===//
//
// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
@@ -38,8 +44,7 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Utils/StripGCRelocates.h"
-#include <algorithm>
+#include "llvm/ADT/StringExtras.h"
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
@@ -58,16 +63,19 @@ class NVVMReflect : public ModulePass {
/// Process a reflect function by finding all its uses and replacing them with
/// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
/// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
- bool handleReflectFunction(Function *F);
+  void handleReflectFunction(Function *F);
+  /// Replace all uses of \p Call with \p NewValue, then constant-fold the
+  /// instructions (including terminators) that become constant as a result.
+  void foldReflectCall(CallInst *Call, Constant *NewValue);
void setVarMap(Module &M);

public:
static char ID;
NVVMReflect() : NVVMReflect(0) {}
// __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
// metadata.
- explicit NVVMReflect(unsigned int Sm) : ModulePass(ID) {
- VarMap["__CUDA_ARCH"] = Sm * 10;
+  explicit NVVMReflect(unsigned SmVersion)
+      : ModulePass(ID), VarMap({{"__CUDA_ARCH", SmVersion * 10}}) {
initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
}
// This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
@@ -91,51 +95,57 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
"Replace occurrences of __nvvm_reflect() calls with 0/1", false,
false)
+// Allow users to specify additional key/value pairs to reflect. These
+// key/value pairs are the last to be added to the VarMap, and therefore take
+// precedence over the initial values (i.e. __CUDA_FTZ from module metadata
+// and __CUDA_ARCH from SmVersion).
static cl::list<std::string>
- ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
- cl::desc("A list of string=num assignments"),
- cl::ValueRequired);
-
-/// The command line can look as follows :
-/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
-/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
-/// ReflectList vector. First, each of ReflectList[i] is 'split'
-/// using "," as the delimiter. Then each of this part is split
-/// using "=" as the delimiter.
+    ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
+                cl::desc("list of comma-separated key=value pairs"),
+                cl::ValueRequired);
+
+// Set the VarMap with, first, the value of __CUDA_FTZ from module metadata,
+// and then the key/value pairs from the command line.
void NVVMReflect::setVarMap(Module &M) {
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
- M.getModuleFlag("nvvm-reflect-ftz")))
+          M.getModuleFlag("nvvm-reflect-ftz")))
VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
- for (unsigned I = 0, E = ReflectList.size(); I != E; ++I) {
- LLVM_DEBUG(dbgs() << "Option : " << ReflectList[I] << "\n");
- SmallVector<StringRef, 4> NameValList;
- StringRef(ReflectList[I]).split(NameValList, ",");
- for (unsigned J = 0, EJ = NameValList.size(); J != EJ; ++J) {
- SmallVector<StringRef, 2> NameValPair;
- NameValList[J].split(NameValPair, "=");
- assert(NameValPair.size() == 2 && "name=val expected");
- StringRef ValStr = NameValPair[1].trim();
+  // The command line can look as follows:
+  //  -nvvm-reflect-add a=1,b=2 -nvvm-reflect-add c=3,d=0 -nvvm-reflect-add e=2
+  // Each ReflectList entry ("a=1,b=2", "c=3,d=0", "e=2") is split on "," and
+  // every resulting piece is then split on "=".
+  for (StringRef Option : ReflectList) {
+    LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
+    while (!Option.empty()) {
+      std::pair<StringRef, StringRef> Split = Option.split(',');
+      StringRef NameVal = Split.first;
+      Option = Split.second;
+
+      auto NameValPair = NameVal.split('=');
+      // Diagnose malformed entries in every build mode: an assert alone would
+      // let release builds silently record an empty key (e.g. for "=5").
+      if (NameValPair.first.empty() || NameValPair.second.empty())
+        report_fatal_error("name=val expected");
+
int Val;
- if (ValStr.getAsInteger(10, Val))
+      if (!to_integer(NameValPair.second.trim(), Val, 10))
report_fatal_error("integer value expected");
- VarMap[NameValPair[0]] = Val;
+      VarMap[NameValPair.first] = Val;
}
}
}
-bool NVVMReflect::handleReflectFunction(Function *F) {
+void NVVMReflect::handleReflectFunction(Function *F) {
// Validate _reflect function
assert(F->isDeclaration() && "_reflect function should not have a body");
assert(F->getReturnType()->isIntegerTy() &&
"_reflect's return type should be integer");
- SmallVector<Instruction *, 4> ToRemove;
- SmallVector<Instruction *, 4> ToSimplify;
// Go through the uses of the reflect function. Each use should be a CallInst
// with a ConstantArray argument. Replace the uses with the appropriate
// constant values.
// The IR for __nvvm_reflect calls differs between CUDA versions.
//
@@ -157,7 +168,8 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
// In this case, we get a Constant with a GlobalVariable operand and we need
// to dig deeper to find its initializer with the string we'll use for lookup.
- for (User *U : F->users()) {
+  // Each call is erased inside this loop, so advance the iterator up front.
+  for (User *U : make_early_inc_range(F->users())) {
assert(isa<CallInst>(U) && "Only a call instruction can use _reflect");
CallInst *Call = cast<CallInst>(U);
@@ -169,21 +180,27 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
// conversion of the string.
const Value *Str = Call->getArgOperand(0);
if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
- // FIXME: Add assertions about ConvCall.
+      // Verify this is the constant-to-generic intrinsic. Callee is only read
+      // by the asserts, hence [[maybe_unused]] for NDEBUG builds.
+      [[maybe_unused]] const Function *Callee = ConvCall->getCalledFunction();
+      assert(Callee && Callee->isIntrinsic() &&
+             Callee->getName().starts_with("llvm.nvvm.ptr.constant.to.gen") &&
+             "Expected llvm.nvvm.ptr.constant.to.gen intrinsic");
+      assert(ConvCall->arg_size() == 1 &&
+             "Expected one argument for ptr conversion");
Str = ConvCall->getArgOperand(0);
}
// Pre opaque pointers we have a constant expression wrapping the constant
// string.
Str = Str->stripPointerCasts();
assert(isa<Constant>(Str) &&
"Format of __nvvm_reflect function not recognized");
const Value *Operand = cast<Constant>(Str)->getOperand(0);
if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
// For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
// initializer.
assert(GV->hasInitializer() &&
"Format of _reflect function not recognized");
const Constant *Initializer = GV->getInitializer();
Operand = Initializer;
}
@@ -196,54 +209,64 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
// Remove the null terminator from the string
ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
- LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
int ReflectVal = 0; // The default value is 0
if (VarMap.contains(ReflectArg)) {
ReflectVal = VarMap[ReflectArg];
}
- LLVM_DEBUG(dbgs() << "ReflectVal: " << ReflectVal << "\n");
+    LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName()
+                      << "(" << ReflectArg << ") with value " << ReflectVal
+                      << "\n");
+    Constant *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
+    foldReflectCall(Call, NewValue);
+    Call->eraseFromParent();
+  }
- // If the immediate user is a simple comparison we want to simplify it.
- for (User *U : Call->users())
- if (Instruction *I = dyn_cast<Instruction>(U))
- ToSimplify.push_back(I);
+  // Remove the reflect function declaration from the module.
+  F->eraseFromParent();
+}
- Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
- ToRemove.push_back(Call);
+// Replace all uses of Call with NewValue, then constant-fold the instructions
+// that transitively depend on it. The code guarded by __nvvm_reflect may be
+// invalid for the target machine, so terminators whose condition folds are
+// rewritten too. Call itself is left for the caller to erase.
+void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
+  // Seed the worklist with every instruction that uses the call's result.
+  SmallVector<Instruction *, 8> Worklist;
+  for (User *U : Call->users()) {
+    if (auto *I = dyn_cast<Instruction>(U))
+      Worklist.push_back(I);
}
- // The code guarded by __nvvm_reflect may be invalid for the target machine.
- // Traverse the use-def chain, continually simplifying constant expressions
- // until we find a terminator that we can then remove.
- while (!ToSimplify.empty()) {
- Instruction *I = ToSimplify.pop_back_val();
- if (Constant *C = ConstantFoldInstruction(I, F->getDataLayout())) {
- for (User *U : I->users())
- if (Instruction *I = dyn_cast<Instruction>(U))
- ToSimplify.push_back(I);
+  Call->replaceAllUsesWith(NewValue);
- I->replaceAllUsesWith(C);
- if (isInstructionTriviallyDead(I)) {
- ToRemove.push_back(I);
+  const DataLayout &DL = Call->getModule()->getDataLayout();
+  // An instruction may be queued more than once (e.g. `add i32 %r, %r` lists
+  // it twice among the call's users), so dead instructions are only collected
+  // here and erased once the worklist is drained. Erasing eagerly would leave
+  // dangling pointers in the worklist.
+  SmallVector<Instruction *, 8> ToRemove;
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.pop_back_val();
+    if (Constant *C = ConstantFoldInstruction(I, DL)) {
+      // Queue I's users before RAUW empties its use list.
+      for (User *U : I->users()) {
+        if (auto *UI = dyn_cast<Instruction>(U))
+          Worklist.push_back(UI);
}
+      I->replaceAllUsesWith(C);
+      if (isInstructionTriviallyDead(I))
+        ToRemove.push_back(I);
} else if (I->isTerminator()) {
ConstantFoldTerminator(I->getParent());
}
}
-
- // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
- // array. Filter out the duplicates before starting to erase from parent.
- std::sort(ToRemove.begin(), ToRemove.end());
- auto *NewLastIter = llvm::unique(ToRemove);
- ToRemove.erase(NewLastIter, ToRemove.end());
-
- for (Instruction *I : ToRemove)
- I->eraseFromParent();
-
- // Remove the __nvvm_reflect function from the module
- F->eraseFromParent();
- return ToRemove.size() > 0;
+
+  // Drop duplicate entries so that no instruction is erased twice.
+  llvm::sort(ToRemove);
+  ToRemove.erase(llvm::unique(ToRemove), ToRemove.end());
+  for (Instruction *I : ToRemove)
+    I->eraseFromParent();
}
bool NVVMReflect::runOnModule(Module &M) {
@@ -254,15 +261,19 @@ bool NVVMReflect::runOnModule(Module &M) {
bool Changed = false;
// Names of reflect function to find and replace
- SmallVector<std::string, 3> ReflectNames = {
- NVVM_REFLECT_FUNCTION, NVVM_REFLECT_OCL_FUNCTION,
- Intrinsic::getName(Intrinsic::nvvm_reflect).str()};
+  SmallVector<StringRef, 3> ReflectNames = {
+      NVVM_REFLECT_FUNCTION,
+      NVVM_REFLECT_OCL_FUNCTION,
+      Intrinsic::getName(Intrinsic::nvvm_reflect),
+  };
// Process all reflect functions
- for (const std::string &Name : ReflectNames) {
- Function *ReflectFunction = M.getFunction(Name);
- if (ReflectFunction) {
- Changed |= handleReflectFunction(ReflectFunction);
+  for (StringRef Name : ReflectNames) {
+    if (Function *ReflectFunction = M.getFunction(Name)) {
+      // handleReflectFunction always erases the declaration, so the module
+      // is modified even when the function has no calls.
+      handleReflectFunction(ReflectFunction);
+      Changed = true;
}
}
@@ -273,7 +284,9 @@ bool NVVMReflect::runOnModule(Module &M) {
NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
VarMap["__CUDA_ARCH"] = SmVersion * 10;
}
+
+// New-PM entry point: runs NVVMReflect::runOnModule with a copy of VarMap.
PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
- : PreservedAnalyses::all();
+                                            : PreservedAnalyses::all();
}
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
new file mode 100644
index 0000000000000..54087897d65b5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
@@ -0,0 +1,32 @@
+; Verify that when command-line options are passed to NVVMReflect, reflect calls
+; are replaced with the corresponding command-line values.
+
+declare i32 @__nvvm_reflect(ptr)
+@ftz = private unnamed_addr addrspace(1) constant [11 x i8] c"__CUDA_FTZ\00"
+@arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
+
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=1,__CUDA_ARCH=350 %s -S | FileCheck %s --check-prefix=CHECK-FTZ1-ARCH350
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=0 -nvvm-reflect-add=__CUDA_ARCH=520 %s -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
+
+; Verify that module metadata setting __CUDA_FTZ=1 is honored, and that it is
+; overridden by the command-line arguments. This file must end with a newline
+; so that the appended metadata lands on lines of its own.
+
+; RUN: cat %s > %t.options
+; RUN: echo '!llvm.module.flags = !{!0}' >> %t.options
+; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.options
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_ARCH=520 %t.options -S | FileCheck %s --check-prefix=CHECK-FTZ1-ARCH520
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=0,__CUDA_ARCH=520 %t.options -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
+
+define i32 @options() {
+  %1 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @ftz to ptr))
+  %2 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch to ptr))
+  %3 = add i32 %1, %2
+  ret i32 %3
+}
+
+; CHECK-FTZ1-ARCH350-NOT: @__nvvm_reflect
+; CHECK-FTZ1-ARCH350: ret i32 351
+; CHECK-FTZ0-ARCH520-NOT: @__nvvm_reflect
+; CHECK-FTZ0-ARCH520: ret i32 520
+; CHECK-FTZ1-ARCH520: ret i32 521
More information about the llvm-commits
mailing list