[llvm] [NVPTX] Improve NVVMReflect Efficiency (PR #134416)

Yonah Goldberg via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 8 15:29:58 PDT 2025


https://github.com/YonahGoldberg updated https://github.com/llvm/llvm-project/pull/134416

>From 21ec83872e1846e1f2453d1dafd8651a76d20494 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 4 Apr 2025 16:51:40 +0000
Subject: [PATCH 01/13] making nvvm reflect more efficient

---
 llvm/lib/Target/NVPTX/NVPTX.h                |  10 +-
 llvm/lib/Target/NVPTX/NVPTXPassRegistry.def  |   2 +-
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp |   5 +-
 llvm/lib/Target/NVPTX/NVVMReflect.cpp        | 155 ++++++++++++-------
 4 files changed, 112 insertions(+), 60 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index ff983e52179af..f98ace3a0d189 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -43,7 +43,7 @@ ModulePass *createNVPTXAssignValidGlobalNamesPass();
 ModulePass *createGenericToNVVMLegacyPass();
 ModulePass *createNVPTXCtorDtorLoweringLegacyPass();
 FunctionPass *createNVVMIntrRangePass();
-FunctionPass *createNVVMReflectPass(unsigned int SmVersion);
+ModulePass *createNVVMReflectPass(unsigned int SmVersion);
 MachineFunctionPass *createNVPTXPrologEpilogPass();
 MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
 FunctionPass *createNVPTXImageOptimizerPass();
@@ -78,12 +78,12 @@ struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
 };
 
 struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
-  NVVMReflectPass();
-  NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  NVVMReflectPass() : NVVMReflectPass(0) {}
+  NVVMReflectPass(unsigned SmVersion);
+  PreservedAnalyses run(Module &F, ModuleAnalysisManager &AM);
 
 private:
-  unsigned SmVersion;
+  StringMap<int> VarMap;
 };
 
 struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {
diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
index 34c79b8f77bae..1c813c2c51f70 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
+++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
@@ -18,6 +18,7 @@
 #endif
 MODULE_PASS("generic-to-nvvm", GenericToNVVMPass())
 MODULE_PASS("nvptx-lower-ctor-dtor", NVPTXCtorDtorLoweringPass())
+MODULE_PASS("nvvm-reflect", NVVMReflectPass())
 #undef MODULE_PASS
 
 #ifndef FUNCTION_ANALYSIS
@@ -36,7 +37,6 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA())
 #define FUNCTION_PASS(NAME, CREATE_PASS)
 #endif
 FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
-FUNCTION_PASS("nvvm-reflect", NVVMReflectPass())
 FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
 FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this));
 #undef FUNCTION_PASS
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 8a25256ea1e4a..d4376ef87b1f1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -240,11 +240,12 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
 
   PB.registerPipelineStartEPCallback(
       [this](ModulePassManager &PM, OptimizationLevel Level) {
-        FunctionPassManager FPM;
         // We do not want to fold out calls to nvvm.reflect early if the user
         // has not provided a target architecture just yet.
         if (Subtarget.hasTargetName())
-          FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
+          PM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
+        
+        FunctionPassManager FPM;
         // Note: NVVMIntrRangePass was causing numerical discrepancies at one
         // point, if issues crop up, consider disabling.
         FPM.addPass(NVVMIntrRangePass());
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 2809ec2303f99..96e23a6699666 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -4,7 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-//===----------------------------------------------------------------------===//
+
 //
 // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
 // with an integer.
@@ -25,7 +25,6 @@
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
-#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
@@ -39,27 +38,43 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/StripGCRelocates.h"
 #include <algorithm>
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
 #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
 
 using namespace llvm;
 
-#define DEBUG_TYPE "nvptx-reflect"
+#define DEBUG_TYPE "nvvm-reflect"
 
 namespace {
-class NVVMReflect : public FunctionPass {
+class NVVMReflect : public ModulePass {
+private:
+  StringMap<int> VarMap;
+  /// Process a reflect function by finding all its uses and replacing them with
+  /// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
+  /// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
+  bool handleReflectFunction(Function *F);
+  void setVarMap(Module &M);
+
 public:
   static char ID;
-  unsigned int SmVersion;
   NVVMReflect() : NVVMReflect(0) {}
-  explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {}
-
-  bool runOnFunction(Function &) override;
+  // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
+  // metadata.
+  explicit NVVMReflect(unsigned int Sm) : ModulePass(ID) {
+    VarMap["__CUDA_ARCH"] = Sm * 10;
+    initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
+  }
+  // This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
+  explicit NVVMReflect(const StringMap<int> &Mapping) : ModulePass(ID), VarMap(Mapping) {
+    initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
+  }
+  bool runOnModule(Module &M) override;
 };
 } // namespace
 
-FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
+ModulePass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
   return new NVVMReflect(SmVersion);
 }
 
@@ -72,27 +87,51 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
                 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
                 false)
 
-static bool runNVVMReflect(Function &F, unsigned SmVersion) {
-  if (!NVVMReflectEnabled)
-    return false;
+static cl::list<std::string>
+    ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
+                cl::desc("A list of string=num assignments"),
+                cl::ValueRequired);
 
-  if (F.getName() == NVVM_REFLECT_FUNCTION ||
-      F.getName() == NVVM_REFLECT_OCL_FUNCTION) {
-    assert(F.isDeclaration() && "_reflect function should not have a body");
-    assert(F.getReturnType()->isIntegerTy() &&
-           "_reflect's return type should be integer");
-    return false;
+/// The command line can look as follows :
+/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
+/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
+/// ReflectList vector. First, each of ReflectList[i] is 'split'
+/// using "," as the delimiter. Then each of this part is split
+/// using "=" as the delimiter.
+void NVVMReflect::setVarMap(Module &M) {
+  if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
+          M.getModuleFlag("nvvm-reflect-ftz")))
+    VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
+
+  for (unsigned I = 0, E = ReflectList.size(); I != E; ++I) {
+    LLVM_DEBUG(dbgs() << "Option : " << ReflectList[I] << "\n");
+    SmallVector<StringRef, 4> NameValList;
+    StringRef(ReflectList[I]).split(NameValList, ",");
+    for (unsigned J = 0, EJ = NameValList.size(); J != EJ; ++J) {
+      SmallVector<StringRef, 2> NameValPair;
+      NameValList[J].split(NameValPair, "=");
+      assert(NameValPair.size() == 2 && "name=val expected");
+      StringRef ValStr = NameValPair[1].trim();
+      int Val;
+      if (ValStr.getAsInteger(10, Val))
+        report_fatal_error("integer value expected");
+      VarMap[NameValPair[0]] = Val;
+    }
   }
+}
+
+bool NVVMReflect::handleReflectFunction(Function *F) {
+  // Validate _reflect function
+  assert(F->isDeclaration() && "_reflect function should not have a body");
+  assert(F->getReturnType()->isIntegerTy() &&
+         "_reflect's return type should be integer");
 
   SmallVector<Instruction *, 4> ToRemove;
   SmallVector<Instruction *, 4> ToSimplify;
 
-  // Go through the calls in this function.  Each call to __nvvm_reflect or
-  // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
-  // First validate that. If the c-string corresponding to the ConstantArray can
-  // be found successfully, see if it can be found in VarMap. If so, replace the
-  // uses of CallInst with the value found in VarMap. If not, replace the use
-  // with value 0.
+  // Go through the uses of the reflect function. Each use should be a CallInst
+  // with a ConstantArray argument. Replace the uses with the appropriate
+  // constant values.
 
   // The IR for __nvvm_reflect calls differs between CUDA versions.
   //
@@ -113,15 +152,10 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
   //
   // In this case, we get a Constant with a GlobalVariable operand and we need
   // to dig deeper to find its initializer with the string we'll use for lookup.
-  for (Instruction &I : instructions(F)) {
-    CallInst *Call = dyn_cast<CallInst>(&I);
-    if (!Call)
-      continue;
-    Function *Callee = Call->getCalledFunction();
-    if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
-                    Callee->getName() != NVVM_REFLECT_OCL_FUNCTION &&
-                    Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
-      continue;
+
+  for (User *U : F->users()) {
+    assert(isa<CallInst>(U) && "Only a call instruction can use _reflect");
+    CallInst *Call = cast<CallInst>(U);
 
     // FIXME: Improve error handling here and elsewhere in this pass.
     assert(Call->getNumOperands() == 2 &&
@@ -156,20 +190,15 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
            "Format of _reflect function not recognized");
 
     StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
+    // Remove the null terminator from the string
     ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
     LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
 
     int ReflectVal = 0; // The default value is 0
-    if (ReflectArg == "__CUDA_FTZ") {
-      // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag.  Our
-      // choice here must be kept in sync with AutoUpgrade, which uses the same
-      // technique to detect whether ftz is enabled.
-      if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
-              F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
-        ReflectVal = Flag->getSExtValue();
-    } else if (ReflectArg == "__CUDA_ARCH") {
-      ReflectVal = SmVersion * 10;
+    if (VarMap.contains(ReflectArg)) {
+      ReflectVal = VarMap[ReflectArg];
     }
+    LLVM_DEBUG(dbgs() << "ReflectVal: " << ReflectVal << "\n");
 
     // If the immediate user is a simple comparison we want to simplify it.
     for (User *U : Call->users())
@@ -185,7 +214,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
   // until we find a terminator that we can then remove.
   while (!ToSimplify.empty()) {
     Instruction *I = ToSimplify.pop_back_val();
-    if (Constant *C = ConstantFoldInstruction(I, F.getDataLayout())) {
+    if (Constant *C = ConstantFoldInstruction(I, F->getDataLayout())) {
       for (User *U : I->users())
         if (Instruction *I = dyn_cast<Instruction>(U))
           ToSimplify.push_back(I);
@@ -202,23 +231,45 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
   // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
   // array. Filter out the duplicates before starting to erase from parent.
   std::sort(ToRemove.begin(), ToRemove.end());
-  auto NewLastIter = llvm::unique(ToRemove);
+  auto *NewLastIter = llvm::unique(ToRemove);
   ToRemove.erase(NewLastIter, ToRemove.end());
 
   for (Instruction *I : ToRemove)
     I->eraseFromParent();
 
+  // Remove the __nvvm_reflect function from the module
+  F->eraseFromParent();
   return ToRemove.size() > 0;
 }
 
-bool NVVMReflect::runOnFunction(Function &F) {
-  return runNVVMReflect(F, SmVersion);
-}
+bool NVVMReflect::runOnModule(Module &M) {
+  if (!NVVMReflectEnabled)
+    return false;
+
+  setVarMap(M);
 
-NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {}
+  bool Changed = false;
+  // Names of reflect function to find and replace
+  SmallVector<std::string, 3> ReflectNames = {
+      NVVM_REFLECT_FUNCTION, NVVM_REFLECT_OCL_FUNCTION,
+      Intrinsic::getName(Intrinsic::nvvm_reflect).str()};
+
+  // Process all reflect functions
+  for (const std::string &Name : ReflectNames) {
+    Function *ReflectFunction = M.getFunction(Name);
+    if (ReflectFunction) {
+      Changed |= handleReflectFunction(ReflectFunction);
+    }
+  }
+
+  return Changed;
+}
 
-PreservedAnalyses NVVMReflectPass::run(Function &F,
-                                       FunctionAnalysisManager &AM) {
-  return runNVVMReflect(F, SmVersion) ? PreservedAnalyses::none()
-                                      : PreservedAnalyses::all();
+// Implementations for the pass that works with the new pass manager.
+NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
+  VarMap["__CUDA_ARCH"] = SmVersion * 10;
 }
+PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
+  return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
+                                            : PreservedAnalyses::all();
+}
\ No newline at end of file

>From cc497638e698ad324ffbe484cd6171f3de54561d Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 4 Apr 2025 16:59:38 +0000
Subject: [PATCH 02/13] cleanup

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 96e23a6699666..85a714b1d5aec 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -4,7 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-
+//===----------------------------------------------------------------------===//
 //
 // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
 // with an integer.

>From db16866c2e8afc4fa30715321b052c66dd895f10 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 4 Apr 2025 17:17:14 +0000
Subject: [PATCH 03/13] newline

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 85a714b1d5aec..f0f7f2baf0206 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -272,4 +272,4 @@ NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
 PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
   return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
                                             : PreservedAnalyses::all();
-}
\ No newline at end of file
+}

>From f167f2c6def6858cf2904520c0f3ba4536b56bc7 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Mon, 7 Apr 2025 20:05:06 +0000
Subject: [PATCH 04/13] reflect improvements

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp         | 183 ++++++++++--------
 .../CodeGen/NVPTX/nvvm-reflect-options.ll     |  26 +++
 2 files changed, 124 insertions(+), 85 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index f0f7f2baf0206..35546d05ab5be 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -4,6 +4,12 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
+// NVIDIA_COPYRIGHT_BEGIN
+//
+// Copyright (c) 2023-2025, NVIDIA CORPORATION.  All rights reserved.
+//
+// NVIDIA_COPYRIGHT_END
+//
 //===----------------------------------------------------------------------===//
 //
 // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
@@ -38,8 +44,7 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Utils/StripGCRelocates.h"
-#include <algorithm>
+#include "llvm/ADT/StringExtras.h"
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
 #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
 
@@ -54,16 +59,15 @@ class NVVMReflect : public ModulePass {
   /// Process a reflect function by finding all its uses and replacing them with
   /// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
   /// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
-  bool handleReflectFunction(Function *F);
+  void handleReflectFunction(Function *F);
   void setVarMap(Module &M);
-
+  void foldReflectCall(CallInst *Call, Constant *NewValue);
 public:
   static char ID;
   NVVMReflect() : NVVMReflect(0) {}
   // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
   // metadata.
-  explicit NVVMReflect(unsigned int Sm) : ModulePass(ID) {
-    VarMap["__CUDA_ARCH"] = Sm * 10;
+  explicit NVVMReflect(unsigned SmVersion) : ModulePass(ID), VarMap({{"__CUDA_ARCH", SmVersion * 10}}) {
     initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
   }
   // This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
@@ -87,51 +91,58 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
                 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
                 false)
 
+// Allow users to specify additional key/value pairs to reflect. These key/value pairs
+// are the last to be added to the VarMap, and therefore will take precedence over initial
+// values (i.e. __CUDA_FTZ from module medadata and __CUDA_ARCH from SmVersion).
 static cl::list<std::string>
-    ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
-                cl::desc("A list of string=num assignments"),
-                cl::ValueRequired);
-
-/// The command line can look as follows :
-/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
-/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
-/// ReflectList vector. First, each of ReflectList[i] is 'split'
-/// using "," as the delimiter. Then each of this part is split
-/// using "=" as the delimiter.
+ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
+            cl::desc("list of comma-separated key=value pairs"),
+            cl::ValueRequired);
+
+// Set the VarMap with, first, the value of __CUDA_FTZ from module metadata, and then
+// the key/value pairs from the command line.
 void NVVMReflect::setVarMap(Module &M) {
+  LLVM_DEBUG(dbgs() << "Reflect list values:\n");
+  for (StringRef Option : ReflectList) {
+    LLVM_DEBUG(dbgs() << "  " << Option << "\n");
+  }
   if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
-          M.getModuleFlag("nvvm-reflect-ftz")))
+      M.getModuleFlag("nvvm-reflect-ftz")))
     VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
 
-  for (unsigned I = 0, E = ReflectList.size(); I != E; ++I) {
-    LLVM_DEBUG(dbgs() << "Option : " << ReflectList[I] << "\n");
-    SmallVector<StringRef, 4> NameValList;
-    StringRef(ReflectList[I]).split(NameValList, ",");
-    for (unsigned J = 0, EJ = NameValList.size(); J != EJ; ++J) {
-      SmallVector<StringRef, 2> NameValPair;
-      NameValList[J].split(NameValPair, "=");
-      assert(NameValPair.size() == 2 && "name=val expected");
-      StringRef ValStr = NameValPair[1].trim();
+  /// The command line can look as follows :
+  /// -nvvm-reflect-add a=1,b=2 -nvvm-reflect-add c=3,d=0 -nvvm-reflect-add e=2
+  /// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
+  /// ReflectList vector. First, each of ReflectList[i] is 'split'
+  /// using "," as the delimiter. Then each of this part is split
+  /// using "=" as the delimiter.
+  for (StringRef Option : ReflectList) {
+    LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
+    while (!Option.empty()) {
+      std::pair<StringRef, StringRef> Split = Option.split(',');
+      StringRef NameVal = Split.first;
+      Option = Split.second;
+
+      auto NameValPair = NameVal.split('=');
+      assert(!NameValPair.first.empty() && !NameValPair.second.empty() && 
+             "name=val expected");
+      
       int Val;
-      if (ValStr.getAsInteger(10, Val))
+      if (!to_integer(NameValPair.second.trim(), Val, 10))
         report_fatal_error("integer value expected");
-      VarMap[NameValPair[0]] = Val;
+      VarMap[NameValPair.first] = Val;
     }
   }
 }
 
-bool NVVMReflect::handleReflectFunction(Function *F) {
+void NVVMReflect::handleReflectFunction(Function *F) {
   // Validate _reflect function
   assert(F->isDeclaration() && "_reflect function should not have a body");
-  assert(F->getReturnType()->isIntegerTy() &&
-         "_reflect's return type should be integer");
+  assert(F->getReturnType()->isIntegerTy() && "_reflect's return type should be integer");
 
-  SmallVector<Instruction *, 4> ToRemove;
-  SmallVector<Instruction *, 4> ToSimplify;
 
   // Go through the uses of the reflect function. Each use should be a CallInst
-  // with a ConstantArray argument. Replace the uses with the appropriate
-  // constant values.
+  // with a ConstantArray argument. Replace the uses with the appropriate constant values.
 
   // The IR for __nvvm_reflect calls differs between CUDA versions.
   //
@@ -153,7 +164,7 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
   // In this case, we get a Constant with a GlobalVariable operand and we need
   // to dig deeper to find its initializer with the string we'll use for lookup.
 
-  for (User *U : F->users()) {
+  for (User *U : make_early_inc_range(F->users())) {
     assert(isa<CallInst>(U) && "Only a call instruction can use _reflect");
     CallInst *Call = cast<CallInst>(U);
 
@@ -165,21 +176,23 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
     // conversion of the string.
     const Value *Str = Call->getArgOperand(0);
     if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
-      // FIXME: Add assertions about ConvCall.
+      // Verify this is the constant-to-generic intrinsic
+      Function *Callee = ConvCall->getCalledFunction();
+      assert(Callee && Callee->isIntrinsic() && 
+             Callee->getName().starts_with("llvm.nvvm.ptr.constant.to.gen") &&
+             "Expected llvm.nvvm.ptr.constant.to.gen intrinsic");
+      assert(ConvCall->getNumOperands() == 2 && "Expected one argument for ptr conversion");
       Str = ConvCall->getArgOperand(0);
     }
     // Pre opaque pointers we have a constant expression wrapping the constant
-    // string.
     Str = Str->stripPointerCasts();
-    assert(isa<Constant>(Str) &&
-           "Format of __nvvm_reflect function not recognized");
+    assert(isa<Constant>(Str) && "Format of __nvvm_reflect function not recognized");
 
     const Value *Operand = cast<Constant>(Str)->getOperand(0);
     if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
       // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
       // initializer.
-      assert(GV->hasInitializer() &&
-             "Format of _reflect function not recognized");
+      assert(GV->hasInitializer() && "Format of _reflect function not recognized");
       const Constant *Initializer = GV->getInitializer();
       Operand = Initializer;
     }
@@ -192,54 +205,48 @@ bool NVVMReflect::handleReflectFunction(Function *F) {
     StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
     // Remove the null terminator from the string
     ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
-    LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
 
     int ReflectVal = 0; // The default value is 0
     if (VarMap.contains(ReflectArg)) {
       ReflectVal = VarMap[ReflectArg];
     }
-    LLVM_DEBUG(dbgs() << "ReflectVal: " << ReflectVal << "\n");
+    LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName() << "(" << ReflectArg << ") with value " << ReflectVal << "\n");
+    Constant *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
+    foldReflectCall(Call, NewValue);
+    Call->eraseFromParent();
+  }
 
-    // If the immediate user is a simple comparison we want to simplify it.
-    for (User *U : Call->users())
-      if (Instruction *I = dyn_cast<Instruction>(U))
-        ToSimplify.push_back(I);
+  // Remove the __nvvm_reflect function from the module
+  F->eraseFromParent();
+}
 
-    Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
-    ToRemove.push_back(Call);
+void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
+  // Initialize worklist with all users of the call
+  SmallVector<Instruction*, 8> Worklist;
+  for (User *U : Call->users()) {
+    if (Instruction *I = dyn_cast<Instruction>(U)) {
+      Worklist.push_back(I);
+    }
   }
 
-  // The code guarded by __nvvm_reflect may be invalid for the target machine.
-  // Traverse the use-def chain, continually simplifying constant expressions
-  // until we find a terminator that we can then remove.
-  while (!ToSimplify.empty()) {
-    Instruction *I = ToSimplify.pop_back_val();
-    if (Constant *C = ConstantFoldInstruction(I, F->getDataLayout())) {
-      for (User *U : I->users())
-        if (Instruction *I = dyn_cast<Instruction>(U))
-          ToSimplify.push_back(I);
+  Call->replaceAllUsesWith(NewValue);
 
-      I->replaceAllUsesWith(C);
-      if (isInstructionTriviallyDead(I)) {
-        ToRemove.push_back(I);
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.pop_back_val();
+    if (Constant *C = ConstantFoldInstruction(I, Call->getModule()->getDataLayout())) {
+      // Add all users of this instruction to the worklist, replace it with the constant
+      // then delete it if it's dead
+      for (User *U : I->users()) {
+        if (Instruction *UI = dyn_cast<Instruction>(U))
+          Worklist.push_back(UI);
       }
+      I->replaceAllUsesWith(C);
+      if (isInstructionTriviallyDead(I))
+        I->eraseFromParent();
     } else if (I->isTerminator()) {
       ConstantFoldTerminator(I->getParent());
     }
   }
-
-  // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
-  // array. Filter out the duplicates before starting to erase from parent.
-  std::sort(ToRemove.begin(), ToRemove.end());
-  auto *NewLastIter = llvm::unique(ToRemove);
-  ToRemove.erase(NewLastIter, ToRemove.end());
-
-  for (Instruction *I : ToRemove)
-    I->eraseFromParent();
-
-  // Remove the __nvvm_reflect function from the module
-  F->eraseFromParent();
-  return ToRemove.size() > 0;
 }
 
 bool NVVMReflect::runOnModule(Module &M) {
@@ -250,15 +257,19 @@ bool NVVMReflect::runOnModule(Module &M) {
 
   bool Changed = false;
   // Names of reflect function to find and replace
-  SmallVector<std::string, 3> ReflectNames = {
-      NVVM_REFLECT_FUNCTION, NVVM_REFLECT_OCL_FUNCTION,
-      Intrinsic::getName(Intrinsic::nvvm_reflect).str()};
+  SmallVector<StringRef, 5> ReflectNames = {
+      NVVM_REFLECT_FUNCTION,
+      NVVM_REFLECT_OCL_FUNCTION,
+      Intrinsic::getName(Intrinsic::nvvm_reflect),
+  };
 
   // Process all reflect functions
-  for (const std::string &Name : ReflectNames) {
-    Function *ReflectFunction = M.getFunction(Name);
-    if (ReflectFunction) {
-      Changed |= handleReflectFunction(ReflectFunction);
+  for (StringRef Name : ReflectNames) {
+    if (Function *ReflectFunction = M.getFunction(Name)) {
+      // If the reflect functition is called, we need to replace the call
+      // with the appropriate constant, modifying the IR.
+      Changed |= ReflectFunction->getNumUses() > 0;
+      handleReflectFunction(ReflectFunction);
     }
   }
 
@@ -269,7 +280,9 @@ bool NVVMReflect::runOnModule(Module &M) {
 NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
   VarMap["__CUDA_ARCH"] = SmVersion * 10;
 }
-PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
+
+PreservedAnalyses NVVMReflectPass::run(Module &M,
+                                    ModuleAnalysisManager &AM) {
   return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
-                                            : PreservedAnalyses::all();
-}
+                                   : PreservedAnalyses::all();
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
new file mode 100644
index 0000000000000..54087897d65b5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
@@ -0,0 +1,26 @@
+; Verify that when passing in command-line options to NVVMReflect, that reflect calls are replaced with
+; the appropriate command line values.
+
+declare i32 @__nvvm_reflect(ptr)
+@ftz = private unnamed_addr addrspace(1) constant [11 x i8] c"__CUDA_FTZ\00"
+@arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
+
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=1,__CUDA_ARCH=350 %s -S | FileCheck %s --check-prefix=CHECK-FTZ1-ARCH350
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=0 -nvvm-reflect-add=__CUDA_ARCH=520 %s -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
+
+; Verify that if we have module metadata that sets __CUDA_FTZ=1, that gets overridden by the command line arguments
+
+; RUN: cat %s > %t.options
+; RUN: echo '!llvm.module.flags = !{!0}' >> %t.options
+; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.options
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=0,__CUDA_ARCH=520 %t.options -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
+
+define i32 @options() {
+  %1 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @ftz to ptr))
+  %2 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch to ptr))
+  %3 = add i32 %1, %2
+  ret i32 %3
+}
+
+; CHECK-FTZ1-ARCH350: ret i32 351
+; CHECK-FTZ0-ARCH520: ret i32 520
\ No newline at end of file

>From 78232a8ffc4fe28f921170ddcb040601f2cc858b Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Mon, 7 Apr 2025 20:07:23 +0000
Subject: [PATCH 05/13] comment move

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 35546d05ab5be..859b63c54bebd 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -56,9 +56,6 @@ namespace {
 class NVVMReflect : public ModulePass {
 private:
   StringMap<int> VarMap;
-  /// Process a reflect function by finding all its uses and replacing them with
-  /// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
-  /// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
   void handleReflectFunction(Function *F);
   void setVarMap(Module &M);
   void foldReflectCall(CallInst *Call, Constant *NewValue);
@@ -135,6 +132,9 @@ void NVVMReflect::setVarMap(Module &M) {
   }
 }
 
+/// Process a reflect function by finding all its uses and replacing them with
+/// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
+/// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
 void NVVMReflect::handleReflectFunction(Function *F) {
   // Validate _reflect function
   assert(F->isDeclaration() && "_reflect function should not have a body");

>From 7385210ce4dd110f293d5fd74b3980c694e123eb Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Mon, 7 Apr 2025 20:11:13 +0000
Subject: [PATCH 06/13] Use auto for the Option.split result

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 859b63c54bebd..b39131de4b7bf 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -116,7 +116,7 @@ void NVVMReflect::setVarMap(Module &M) {
   for (StringRef Option : ReflectList) {
     LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
     while (!Option.empty()) {
-      std::pair<StringRef, StringRef> Split = Option.split(',');
+      auto Split = Option.split(',');
       StringRef NameVal = Split.first;
       Option = Split.second;
 

>From 453c62981c54f9a5fbd46324cebb17c90963ab07 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Mon, 7 Apr 2025 20:17:49 +0000
Subject: [PATCH 07/13] remove nvidia copyright

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index b39131de4b7bf..0915456c3915b 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -4,12 +4,6 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// NVIDIA_COPYRIGHT_BEGIN
-//
-// Copyright (c) 2023-2025, NVIDIA CORPORATION.  All rights reserved.
-//
-// NVIDIA_COPYRIGHT_END
-//
 //===----------------------------------------------------------------------===//
 //
 // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect

>From 8e95772bac5cfcbd8da612d09c2f71b17722fa58 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Mon, 7 Apr 2025 20:22:45 +0000
Subject: [PATCH 08/13] improve error messages

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 34 ++++++++++++++++-----------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 0915456c3915b..2533c1cbee25d 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -159,12 +159,12 @@ void NVVMReflect::handleReflectFunction(Function *F) {
   // to dig deeper to find its initializer with the string we'll use for lookup.
 
   for (User *U : make_early_inc_range(F->users())) {
-    assert(isa<CallInst>(U) && "Only a call instruction can use _reflect");
+    if (!isa<CallInst>(U))
+      report_fatal_error("__nvvm_reflect can only be used in a call instruction");
     CallInst *Call = cast<CallInst>(U);
 
-    // FIXME: Improve error handling here and elsewhere in this pass.
-    assert(Call->getNumOperands() == 2 &&
-           "Wrong number of operands to __nvvm_reflect function");
+    if (Call->getNumOperands() != 2)
+      report_fatal_error("__nvvm_reflect requires exactly one argument");
 
     // In cuda 6.5 and earlier, we will have an extra constant-to-generic
     // conversion of the string.
@@ -172,34 +172,40 @@ void NVVMReflect::handleReflectFunction(Function *F) {
     if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
       // Verify this is the constant-to-generic intrinsic
       Function *Callee = ConvCall->getCalledFunction();
-      assert(Callee && Callee->isIntrinsic() && 
-             Callee->getName().starts_with("llvm.nvvm.ptr.constant.to.gen") &&
-             "Expected llvm.nvvm.ptr.constant.to.gen intrinsic");
-      assert(ConvCall->getNumOperands() == 2 && "Expected one argument for ptr conversion");
+      if (!Callee || !Callee->isIntrinsic() || 
+          !Callee->getName().starts_with("llvm.nvvm.ptr.constant.to.gen"))
+        report_fatal_error("Expected llvm.nvvm.ptr.constant.to.gen intrinsic");
+      if (ConvCall->getNumOperands() != 2)
+        report_fatal_error("Expected one argument for ptr conversion");
       Str = ConvCall->getArgOperand(0);
     }
     // Pre opaque pointers we have a constant expression wrapping the constant
     Str = Str->stripPointerCasts();
-    assert(isa<Constant>(Str) && "Format of __nvvm_reflect function not recognized");
+    if (!isa<Constant>(Str))
+      report_fatal_error("__nvvm_reflect argument must be a constant string");
 
     const Value *Operand = cast<Constant>(Str)->getOperand(0);
     if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
       // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
       // initializer.
-      assert(GV->hasInitializer() && "Format of _reflect function not recognized");
+      if (!GV->hasInitializer())
+        report_fatal_error("__nvvm_reflect string must have an initializer");
       const Constant *Initializer = GV->getInitializer();
       Operand = Initializer;
     }
 
-    assert(isa<ConstantDataSequential>(Operand) &&
-           "Format of _reflect function not recognized");
-    assert(cast<ConstantDataSequential>(Operand)->isCString() &&
-           "Format of _reflect function not recognized");
+    if (!isa<ConstantDataSequential>(Operand))
+      report_fatal_error("__nvvm_reflect argument must be a string constant");
+    if (!cast<ConstantDataSequential>(Operand)->isCString())
+      report_fatal_error("__nvvm_reflect argument must be a null-terminated string");
 
     StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
     // Remove the null terminator from the string
     ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
 
+    if (ReflectArg.empty())
+      report_fatal_error("__nvvm_reflect argument cannot be empty");
+
     int ReflectVal = 0; // The default value is 0
     if (VarMap.contains(ReflectArg)) {
       ReflectVal = VarMap[ReflectArg];

>From d9a0ec6733c8f442c638087b4b1099ef4317433a Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Mon, 7 Apr 2025 21:10:36 +0000
Subject: [PATCH 09/13] Rename option to -nvvm-reflect-list and parse it with cl::CommaSeparated

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp         | 32 +++++++------------
 .../CodeGen/NVPTX/nvvm-reflect-options.ll     |  6 ++--
 2 files changed, 14 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 2533c1cbee25d..2cfd74303a821 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -86,7 +86,8 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
 // are the last to be added to the VarMap, and therefore will take precedence over initial
 // values (i.e. __CUDA_FTZ from module medadata and __CUDA_ARCH from SmVersion).
 static cl::list<std::string>
-ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
+ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
+            cl::CommaSeparated,
             cl::desc("list of comma-separated key=value pairs"),
             cl::ValueRequired);
 
@@ -101,28 +102,17 @@ void NVVMReflect::setVarMap(Module &M) {
       M.getModuleFlag("nvvm-reflect-ftz")))
     VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
 
-  /// The command line can look as follows :
-  /// -nvvm-reflect-add a=1,b=2 -nvvm-reflect-add c=3,d=0 -nvvm-reflect-add e=2
-  /// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
-  /// ReflectList vector. First, each of ReflectList[i] is 'split'
-  /// using "," as the delimiter. Then each of this part is split
-  /// using "=" as the delimiter.
   for (StringRef Option : ReflectList) {
     LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
-    while (!Option.empty()) {
-      auto Split = Option.split(',');
-      StringRef NameVal = Split.first;
-      Option = Split.second;
-
-      auto NameValPair = NameVal.split('=');
-      assert(!NameValPair.first.empty() && !NameValPair.second.empty() && 
-             "name=val expected");
-      
-      int Val;
-      if (!to_integer(NameValPair.second.trim(), Val, 10))
-        report_fatal_error("integer value expected");
-      VarMap[NameValPair.first] = Val;
-    }
+    auto [Name, Val] = Option.split('=');
+    if (Name.empty())
+      report_fatal_error("Empty name in nvvm-reflect-list option '" + Option + "'");
+    if (Val.empty()) 
+      report_fatal_error("Missing value in nvvm-reflect- option '" + Option + "'");
+    int ValInt;
+    if (!to_integer(Val.trim(), ValInt, 10))
+      report_fatal_error("integer value expected in nvvm-reflect-list option '" + Option + "'");
+    VarMap[Name] = ValInt;
   }
 }
 
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
index 54087897d65b5..bd4fb3eb537d0 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
@@ -5,15 +5,15 @@ declare i32 @__nvvm_reflect(ptr)
 @ftz = private unnamed_addr addrspace(1) constant [11 x i8] c"__CUDA_FTZ\00"
 @arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
 
-; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=1,__CUDA_ARCH=350 %s -S | FileCheck %s --check-prefix=CHECK-FTZ1-ARCH350
-; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=0 -nvvm-reflect-add=__CUDA_ARCH=520 %s -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-list=__CUDA_FTZ=1,__CUDA_ARCH=350 %s -S | FileCheck %s --check-prefix=CHECK-FTZ1-ARCH350
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-list=__CUDA_FTZ=0 -nvvm-reflect-list=__CUDA_ARCH=520 %s -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
 
 ; Verify that if we have module metadata that sets __CUDA_FTZ=1, that gets overridden by the command line arguments
 
 ; RUN: cat %s > %t.options
 ; RUN: echo '!llvm.module.flags = !{!0}' >> %t.options
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.options
-; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add=__CUDA_FTZ=0,__CUDA_ARCH=520 %t.options -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-list=__CUDA_FTZ=0,__CUDA_ARCH=520 %t.options -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
 
 define i32 @options() {
   %1 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @ftz to ptr))

>From f5444f388f3de1869ed4681260265829485c366c Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Mon, 7 Apr 2025 21:57:26 +0000
Subject: [PATCH 10/13] Restore -nvvm-reflect-add, taking one name=value pair per option

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp           | 9 ++-------
 llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll | 6 +++---
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 2cfd74303a821..b1e01a1e3fe71 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -86,18 +86,13 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
 // are the last to be added to the VarMap, and therefore will take precedence over initial
 // values (i.e. __CUDA_FTZ from module medadata and __CUDA_ARCH from SmVersion).
 static cl::list<std::string>
-ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
-            cl::CommaSeparated,
-            cl::desc("list of comma-separated key=value pairs"),
+ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
+            cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."),
             cl::ValueRequired);
 
 // Set the VarMap with, first, the value of __CUDA_FTZ from module metadata, and then
 // the key/value pairs from the command line.
 void NVVMReflect::setVarMap(Module &M) {
-  LLVM_DEBUG(dbgs() << "Reflect list values:\n");
-  for (StringRef Option : ReflectList) {
-    LLVM_DEBUG(dbgs() << "  " << Option << "\n");
-  }
   if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
       M.getModuleFlag("nvvm-reflect-ftz")))
     VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
index bd4fb3eb537d0..fae48e554383b 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
@@ -5,15 +5,15 @@ declare i32 @__nvvm_reflect(ptr)
 @ftz = private unnamed_addr addrspace(1) constant [11 x i8] c"__CUDA_FTZ\00"
 @arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
 
-; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-list=__CUDA_FTZ=1,__CUDA_ARCH=350 %s -S | FileCheck %s --check-prefix=CHECK-FTZ1-ARCH350
-; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-list=__CUDA_FTZ=0 -nvvm-reflect-list=__CUDA_ARCH=520 %s -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add __CUDA_FTZ=1 -nvvm-reflect-add __CUDA_ARCH=350 %s -S | FileCheck %s --check-prefix=CHECK-FTZ1-ARCH350
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add __CUDA_FTZ=0 -nvvm-reflect-add __CUDA_ARCH=520 %s -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
 
 ; Verify that if we have module metadata that sets __CUDA_FTZ=1, that gets overridden by the command line arguments
 
 ; RUN: cat %s > %t.options
 ; RUN: echo '!llvm.module.flags = !{!0}' >> %t.options
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.options
-; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-list=__CUDA_FTZ=0,__CUDA_ARCH=520 %t.options -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda -nvvm-reflect-add __CUDA_FTZ=0 -nvvm-reflect-add __CUDA_ARCH=520 %t.options -S | FileCheck %s --check-prefix=CHECK-FTZ0-ARCH520
 
 define i32 @options() {
   %1 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @ftz to ptr))

>From 6fd4fa930ff2e29e79a333d9d5106a1bcf2d394e Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Tue, 8 Apr 2025 22:05:03 +0000
Subject: [PATCH 11/13] final reflect cleanup

---
 llvm/lib/Target/NVPTX/NVPTX.h                |   6 +-
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp |  23 ++-
 llvm/lib/Target/NVPTX/NVVMReflect.cpp        | 194 +++++++------------
 3 files changed, 99 insertions(+), 124 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index f98ace3a0d189..93b31f2a4c01d 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -78,12 +78,12 @@ struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
 };
 
 struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
-  NVVMReflectPass() : NVVMReflectPass(0) {}
-  NVVMReflectPass(unsigned SmVersion);
+  NVVMReflectPass() : SmVersion(0) {}
+  NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
   PreservedAnalyses run(Module &F, ModuleAnalysisManager &AM);
 
 private:
-  StringMap<int> VarMap;
+  unsigned SmVersion;
 };
 
 struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index d4376ef87b1f1..8bc94846337cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -87,6 +87,27 @@ static cl::opt<bool> EarlyByValArgsCopy(
     cl::desc("Create a copy of byval function arguments early."),
     cl::init(false), cl::Hidden);
 
+namespace llvm {
+
+void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
+void initializeNVPTXAllocaHoistingPass(PassRegistry &);
+void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry &);
+void initializeNVPTXAtomicLowerPass(PassRegistry &);
+void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
+void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
+void initializeNVPTXLowerAllocaPass(PassRegistry &);
+void initializeNVPTXLowerUnreachablePass(PassRegistry &);
+void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
+void initializeNVPTXLowerArgsLegacyPassPass(PassRegistry &);
+void initializeNVPTXProxyRegErasurePass(PassRegistry &);
+void initializeNVPTXForwardParamsPassPass(PassRegistry &);
+void initializeNVVMIntrRangePass(PassRegistry &);
+void initializeNVVMReflectLegacyPassPass(PassRegistry &);
+void initializeNVPTXAAWrapperPassPass(PassRegistry &);
+void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
+
+} // end namespace llvm
+
 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
   // Register the target.
   RegisterTargetMachine<NVPTXTargetMachine32> X(getTheNVPTXTarget32());
@@ -95,7 +116,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
   PassRegistry &PR = *PassRegistry::getPassRegistry();
   // FIXME: This pass is really intended to be invoked during IR optimization,
   // but it's very NVPTX-specific.
-  initializeNVVMReflectPass(PR);
+  initializeNVVMReflectLegacyPassPass(PR);
   initializeNVVMIntrRangePass(PR);
   initializeGenericToNVVMLegacyPassPass(PR);
   initializeNVPTXAllocaHoistingPass(PR);
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index b1e01a1e3fe71..f0c80a9f8bd3d 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -41,44 +41,60 @@
 #include "llvm/ADT/StringExtras.h"
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
 #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
+// Argument of reflect call to retrive arch number
+#define CUDA_ARCH_NAME "__CUDA_ARCH"
+// Argument of reflect call to retrive ftz mode
+#define CUDA_FTZ_NAME "__CUDA_FTZ"
+// Name of module metadata where ftz mode is stored
+#define CUDA_FTZ_MODULE_NAME "nvvm-reflect-ftz"
 
 using namespace llvm;
 
 #define DEBUG_TYPE "nvvm-reflect"
 
+namespace llvm {
+void initializeNVVMReflectLegacyPassPass(PassRegistry &);
+}
+
 namespace {
-class NVVMReflect : public ModulePass {
+class NVVMReflect {
 private:
-  StringMap<int> VarMap;
-  void handleReflectFunction(Function *F);
-  void setVarMap(Module &M);
+  // Map from reflect function call arguments to the value to replace the call with.
+  // Should include __CUDA_FTZ and __CUDA_ARCH values.
+  StringMap<int> ReflectMap;
+  bool handleReflectFunction(Module &M, StringRef ReflectName);
+  void populateReflectMap(Module &M);
   void foldReflectCall(CallInst *Call, Constant *NewValue);
 public:
-  static char ID;
-  NVVMReflect() : NVVMReflect(0) {}
   // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
   // metadata.
-  explicit NVVMReflect(unsigned SmVersion) : ModulePass(ID), VarMap({{"__CUDA_ARCH", SmVersion * 10}}) {
-    initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
-  }
-  // This mapping will contain should include __CUDA_FTZ and __CUDA_ARCH values.
-  explicit NVVMReflect(const StringMap<int> &Mapping) : ModulePass(ID), VarMap(Mapping) {
-    initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
+  explicit NVVMReflect(unsigned SmVersion) : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {}
+  bool runOnModule(Module &M);
+};
+} // namespace
+
+class NVVMReflectLegacyPass : public ModulePass {
+private:
+  NVVMReflect Impl;
+public:
+  static char ID;
+    NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {
+    initializeNVVMReflectLegacyPassPass(*PassRegistry::getPassRegistry());
   }
   bool runOnModule(Module &M) override;
 };
-} // namespace
 
 ModulePass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
-  return new NVVMReflect(SmVersion);
+  LLVM_DEBUG(dbgs() << "Creating NVVMReflectPass with SM version " << SmVersion << "\n");
+  return new NVVMReflectLegacyPass(SmVersion);
 }
 
 static cl::opt<bool>
     NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
                        cl::desc("NVVM reflection, enabled by default"));
 
-char NVVMReflect::ID = 0;
-INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
+char NVVMReflectLegacyPass::ID = 0;
+INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
                 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
                 false)
 
@@ -92,10 +108,10 @@ ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
 
 // Set the VarMap with, first, the value of __CUDA_FTZ from module metadata, and then
 // the key/value pairs from the command line.
-void NVVMReflect::setVarMap(Module &M) {
+void NVVMReflect::populateReflectMap(Module &M) {
   if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
-      M.getModuleFlag("nvvm-reflect-ftz")))
-    VarMap["__CUDA_FTZ"] = Flag->getSExtValue();
+      M.getModuleFlag(CUDA_FTZ_MODULE_NAME)))
+    ReflectMap[CUDA_FTZ_NAME] = Flag->getSExtValue();
 
   for (StringRef Option : ReflectList) {
     LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
@@ -107,94 +123,52 @@ void NVVMReflect::setVarMap(Module &M) {
     int ValInt;
     if (!to_integer(Val.trim(), ValInt, 10))
       report_fatal_error("integer value expected in nvvm-reflect-list option '" + Option + "'");
-    VarMap[Name] = ValInt;
+    ReflectMap[Name] = ValInt;
   }
 }
 
-/// Process a reflect function by finding all its uses and replacing them with
+/// Process a reflect function by finding all its calls and replacing them with
 /// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
 /// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
-void NVVMReflect::handleReflectFunction(Function *F) {
-  // Validate _reflect function
+bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
+  Function *F = M.getFunction(ReflectName);
+  if (!F)
+    return false;
   assert(F->isDeclaration() && "_reflect function should not have a body");
   assert(F->getReturnType()->isIntegerTy() && "_reflect's return type should be integer");
 
-
-  // Go through the uses of the reflect function. Each use should be a CallInst
-  // with a ConstantArray argument. Replace the uses with the appropriate constant values.
-
-  // The IR for __nvvm_reflect calls differs between CUDA versions.
-  //
-  // CUDA 6.5 and earlier uses this sequence:
-  //    %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
-  //        (i8 addrspace(4)* getelementptr inbounds
-  //           ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
-  //    %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
-  //
-  // The value returned by Sym->getOperand(0) is a Constant with a
-  // ConstantDataSequential operand which can be converted to string and used
-  // for lookup.
-  //
-  // CUDA 7.0 does it slightly differently:
-  //   %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
-  //        (i8 addrspace(1)* getelementptr inbounds
-  //           ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
-  //
-  // In this case, we get a Constant with a GlobalVariable operand and we need
-  // to dig deeper to find its initializer with the string we'll use for lookup.
-
+  bool Changed = F->getNumUses() > 0;
   for (User *U : make_early_inc_range(F->users())) {
+    // Reflect function calls look like:
+    // @arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
+    // call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch to ptr))
+    // We need to extract the string argument from the call (i.e. "__CUDA_ARCH")
     if (!isa<CallInst>(U))
       report_fatal_error("__nvvm_reflect can only be used in a call instruction");
     CallInst *Call = cast<CallInst>(U);
-
     if (Call->getNumOperands() != 2)
       report_fatal_error("__nvvm_reflect requires exactly one argument");
 
-    // In cuda 6.5 and earlier, we will have an extra constant-to-generic
-    // conversion of the string.
-    const Value *Str = Call->getArgOperand(0);
-    if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
-      // Verify this is the constant-to-generic intrinsic
-      Function *Callee = ConvCall->getCalledFunction();
-      if (!Callee || !Callee->isIntrinsic() || 
-          !Callee->getName().starts_with("llvm.nvvm.ptr.constant.to.gen"))
-        report_fatal_error("Expected llvm.nvvm.ptr.constant.to.gen intrinsic");
-      if (ConvCall->getNumOperands() != 2)
-        report_fatal_error("Expected one argument for ptr conversion");
-      Str = ConvCall->getArgOperand(0);
-    }
-    // Pre opaque pointers we have a constant expression wrapping the constant
-    Str = Str->stripPointerCasts();
-    if (!isa<Constant>(Str))
+    const Value *GlobalStr = Call->getArgOperand(0)->stripPointerCasts();
+    if (!isa<Constant>(GlobalStr))
       report_fatal_error("__nvvm_reflect argument must be a constant string");
 
-    const Value *Operand = cast<Constant>(Str)->getOperand(0);
-    if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
-      // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
-      // initializer.
-      if (!GV->hasInitializer())
-        report_fatal_error("__nvvm_reflect string must have an initializer");
-      const Constant *Initializer = GV->getInitializer();
-      Operand = Initializer;
-    }
-
-    if (!isa<ConstantDataSequential>(Operand))
+    const Value *ConstantStr = cast<Constant>(GlobalStr)->getOperand(0);
+    if (!isa<ConstantDataSequential>(ConstantStr))
       report_fatal_error("__nvvm_reflect argument must be a string constant");
-    if (!cast<ConstantDataSequential>(Operand)->isCString())
+    if (!cast<ConstantDataSequential>(ConstantStr)->isCString())
       report_fatal_error("__nvvm_reflect argument must be a null-terminated string");
 
-    StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
+    StringRef ReflectArg = cast<ConstantDataSequential>(ConstantStr)->getAsString();
     // Remove the null terminator from the string
     ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
-
     if (ReflectArg.empty())
       report_fatal_error("__nvvm_reflect argument cannot be empty");
-
+    // Now that we have extracted the string argument, we can look it up in the VarMap
     int ReflectVal = 0; // The default value is 0
-    if (VarMap.contains(ReflectArg)) {
-      ReflectVal = VarMap[ReflectArg];
-    }
+    if (ReflectMap.contains(ReflectArg))
+      ReflectVal = ReflectMap[ReflectArg];
+
     LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName() << "(" << ReflectArg << ") with value " << ReflectVal << "\n");
     Constant *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
     foldReflectCall(Call, NewValue);
@@ -203,29 +177,26 @@ void NVVMReflect::handleReflectFunction(Function *F) {
 
   // Remove the __nvvm_reflect function from the module
   F->eraseFromParent();
+  return Changed;
 }
 
 void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
-  // Initialize worklist with all users of the call
   SmallVector<Instruction*, 8> Worklist;
-  for (User *U : Call->users()) {
-    if (Instruction *I = dyn_cast<Instruction>(U)) {
-      Worklist.push_back(I);
+  // Replace an instruction with a constant and add all users of the instruction to the worklist
+  auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
+    for (User *U : I->users()) {
+      if (Instruction *UI = dyn_cast<Instruction>(U))
+        Worklist.push_back(UI);
     }
-  }
+    I->replaceAllUsesWith(C);
+  };
 
-  Call->replaceAllUsesWith(NewValue);
+  ReplaceInstructionWithConst(Call, NewValue);
 
   while (!Worklist.empty()) {
     Instruction *I = Worklist.pop_back_val();
     if (Constant *C = ConstantFoldInstruction(I, Call->getModule()->getDataLayout())) {
-      // Add all users of this instruction to the worklist, replace it with the constant
-      // then delete it if it's dead
-      for (User *U : I->users()) {
-        if (Instruction *UI = dyn_cast<Instruction>(U))
-          Worklist.push_back(UI);
-      }
-      I->replaceAllUsesWith(C);
+      ReplaceInstructionWithConst(I, C);
       if (isInstructionTriviallyDead(I))
         I->eraseFromParent();
     } else if (I->isTerminator()) {
@@ -237,37 +208,20 @@ void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
 bool NVVMReflect::runOnModule(Module &M) {
   if (!NVVMReflectEnabled)
     return false;
-
-  setVarMap(M);
-
-  bool Changed = false;
-  // Names of reflect function to find and replace
-  SmallVector<StringRef, 5> ReflectNames = {
-      NVVM_REFLECT_FUNCTION,
-      NVVM_REFLECT_OCL_FUNCTION,
-      Intrinsic::getName(Intrinsic::nvvm_reflect),
-  };
-
-  // Process all reflect functions
-  for (StringRef Name : ReflectNames) {
-    if (Function *ReflectFunction = M.getFunction(Name)) {
-      // If the reflect functition is called, we need to replace the call
-      // with the appropriate constant, modifying the IR.
-      Changed |= ReflectFunction->getNumUses() > 0;
-      handleReflectFunction(ReflectFunction);
-    }
-  }
-
+  populateReflectMap(M);
+  bool Changed = true;
+  handleReflectFunction(M, NVVM_REFLECT_FUNCTION);
+  handleReflectFunction(M, NVVM_REFLECT_OCL_FUNCTION);
+  handleReflectFunction(M, Intrinsic::getName(Intrinsic::nvvm_reflect));
   return Changed;
 }
 
-// Implementations for the pass that works with the new pass manager.
-NVVMReflectPass::NVVMReflectPass(unsigned SmVersion) {
-  VarMap["__CUDA_ARCH"] = SmVersion * 10;
+bool NVVMReflectLegacyPass::runOnModule(Module &M) {
+  return Impl.runOnModule(M);
 }
 
 PreservedAnalyses NVVMReflectPass::run(Module &M,
                                     ModuleAnalysisManager &AM) {
-  return NVVMReflect(VarMap).runOnModule(M) ? PreservedAnalyses::none()
+  return NVVMReflect(SmVersion).runOnModule(M) ? PreservedAnalyses::none()
                                    : PreservedAnalyses::all();
 }
\ No newline at end of file

>From cbf8664c870159f55ce372ccf64417c0a411850d Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Tue, 8 Apr 2025 22:22:43 +0000
Subject: [PATCH 12/13] Move NVVMReflectLegacyPass into an anonymous namespace and drop registration from its constructor

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index f0c80a9f8bd3d..3c22798e6504e 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -73,16 +73,16 @@ class NVVMReflect {
 };
 } // namespace
 
+namespace {
 class NVVMReflectLegacyPass : public ModulePass {
 private:
   NVVMReflect Impl;
 public:
   static char ID;
-    NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {
-    initializeNVVMReflectLegacyPassPass(*PassRegistry::getPassRegistry());
-  }
+    NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {}
   bool runOnModule(Module &M) override;
 };
+} // namespace
 
 ModulePass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
   LLVM_DEBUG(dbgs() << "Creating NVVMReflectPass with SM version " << SmVersion << "\n");

>From 790b4bbd79f0ee5a1af5cb44b84315f2ef7792de Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Tue, 8 Apr 2025 22:29:40 +0000
Subject: [PATCH 13/13] clang format

---
 llvm/lib/Target/NVPTX/NVPTX.h                | 35 ++------
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 23 +----
 llvm/lib/Target/NVPTX/NVVMReflect.cpp        | 92 ++++++++++++--------
 3 files changed, 63 insertions(+), 87 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 93b31f2a4c01d..be1f9c380b680 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -27,14 +27,7 @@ class NVPTXTargetMachine;
 class PassRegistry;
 
 namespace NVPTXCC {
-enum CondCodes {
-  EQ,
-  NE,
-  LT,
-  LE,
-  GT,
-  GE
-};
+enum CondCodes { EQ, NE, LT, LE, GT, GE };
 }
 
 FunctionPass *createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -55,6 +48,7 @@ MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
 MachineFunctionPass *createNVPTXForwardParamsPass();
 
+void initializeNVVMReflectLegacyPassPass(PassRegistry &);
 void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
 void initializeNVPTXAllocaHoistingPass(PassRegistry &);
 void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry &);
@@ -104,10 +98,7 @@ struct NVPTXLowerArgsPass : PassInfoMixin<NVPTXLowerArgsPass> {
 };
 
 namespace NVPTX {
-enum DrvInterface {
-  NVCL,
-  CUDA
-};
+enum DrvInterface { NVCL, CUDA };
 
 // A field inside TSFlags needs a shift and a mask. The usage is
 // always as follows :
@@ -129,10 +120,7 @@ enum VecInstType {
   VecOther = 15
 };
 
-enum SimpleMove {
-  SimpleMoveMask = 0x10,
-  SimpleMoveShift = 4
-};
+enum SimpleMove { SimpleMoveMask = 0x10, SimpleMoveShift = 4 };
 enum LoadStore {
   isLoadMask = 0x20,
   isLoadShift = 5,
@@ -181,17 +169,8 @@ enum AddressSpace : AddressSpaceUnderlyingType {
 };
 
 namespace PTXLdStInstCode {
-enum FromType {
-  Unsigned = 0,
-  Signed,
-  Float,
-  Untyped
-};
-enum VecType {
-  Scalar = 1,
-  V2 = 2,
-  V4 = 4
-};
+enum FromType { Unsigned = 0, Signed, Float, Untyped };
+enum VecType { Scalar = 1, V2 = 2, V4 = 4 };
 } // namespace PTXLdStInstCode
 
 /// PTXCvtMode - Conversion code enumeration
@@ -254,7 +233,7 @@ enum PrmtMode {
   RC16,
 };
 }
-}
+} // namespace NVPTX
 void initializeNVPTXDAGToDAGISelLegacyPass(PassRegistry &);
 } // namespace llvm
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 8bc94846337cf..a4c3b43aec9f2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -87,27 +87,6 @@ static cl::opt<bool> EarlyByValArgsCopy(
     cl::desc("Create a copy of byval function arguments early."),
     cl::init(false), cl::Hidden);
 
-namespace llvm {
-
-void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
-void initializeNVPTXAllocaHoistingPass(PassRegistry &);
-void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry &);
-void initializeNVPTXAtomicLowerPass(PassRegistry &);
-void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
-void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
-void initializeNVPTXLowerAllocaPass(PassRegistry &);
-void initializeNVPTXLowerUnreachablePass(PassRegistry &);
-void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
-void initializeNVPTXLowerArgsLegacyPassPass(PassRegistry &);
-void initializeNVPTXProxyRegErasurePass(PassRegistry &);
-void initializeNVPTXForwardParamsPassPass(PassRegistry &);
-void initializeNVVMIntrRangePass(PassRegistry &);
-void initializeNVVMReflectLegacyPassPass(PassRegistry &);
-void initializeNVPTXAAWrapperPassPass(PassRegistry &);
-void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
-
-} // end namespace llvm
-
 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
   // Register the target.
   RegisterTargetMachine<NVPTXTargetMachine32> X(getTheNVPTXTarget32());
@@ -265,7 +244,7 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
         // has not provided a target architecture just yet.
         if (Subtarget.hasTargetName())
           PM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
-        
+
         FunctionPassManager FPM;
         // Note: NVVMIntrRangePass was causing numerical discrepancies at one
         // point, if issues crop up, consider disabling.
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 3c22798e6504e..3ea6695c3d5f1 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -20,6 +20,7 @@
 
 #include "NVPTX.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/IR/Constants.h"
@@ -38,7 +39,6 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
-#include "llvm/ADT/StringExtras.h"
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
 #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
 // Argument of reflect call to retrive arch number
@@ -54,21 +54,23 @@ using namespace llvm;
 
 namespace llvm {
 void initializeNVVMReflectLegacyPassPass(PassRegistry &);
-}
+} // namespace llvm
 
 namespace {
 class NVVMReflect {
 private:
-  // Map from reflect function call arguments to the value to replace the call with.
-  // Should include __CUDA_FTZ and __CUDA_ARCH values.
+  // Map from reflect function call arguments to the value to replace the call
+  // with. Should include __CUDA_FTZ and __CUDA_ARCH values.
   StringMap<int> ReflectMap;
   bool handleReflectFunction(Module &M, StringRef ReflectName);
   void populateReflectMap(Module &M);
   void foldReflectCall(CallInst *Call, Constant *NewValue);
+
 public:
   // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
   // metadata.
-  explicit NVVMReflect(unsigned SmVersion) : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {}
+  explicit NVVMReflect(unsigned SmVersion)
+      : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {}
   bool runOnModule(Module &M);
 };
 } // namespace
@@ -77,15 +79,17 @@ namespace {
 class NVVMReflectLegacyPass : public ModulePass {
 private:
   NVVMReflect Impl;
+
 public:
   static char ID;
-    NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {}
+  NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {}
   bool runOnModule(Module &M) override;
 };
 } // namespace
 
 ModulePass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
-  LLVM_DEBUG(dbgs() << "Creating NVVMReflectPass with SM version " << SmVersion << "\n");
+  LLVM_DEBUG(dbgs() << "Creating NVVMReflectPass with SM version " << SmVersion
+                    << "\n");
   return new NVVMReflectLegacyPass(SmVersion);
 }
 
@@ -98,31 +102,36 @@ INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
                 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
                 false)
 
-// Allow users to specify additional key/value pairs to reflect. These key/value pairs
-// are the last to be added to the VarMap, and therefore will take precedence over initial
-// values (i.e. __CUDA_FTZ from module medadata and __CUDA_ARCH from SmVersion).
-static cl::list<std::string>
-ReflectList("nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
-            cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."),
-            cl::ValueRequired);
-
-// Set the VarMap with, first, the value of __CUDA_FTZ from module metadata, and then
-// the key/value pairs from the command line.
+// Allow users to specify additional key/value pairs to reflect. These key/value
+// pairs are the last to be added to the ReflectMap, and therefore will take
+// precedence over initial values (i.e. __CUDA_FTZ from module metadata and
+// __CUDA_ARCH from SmVersion).
+static cl::list<std::string> ReflectList(
+    "nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
+    cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."),
+    cl::ValueRequired);
+
+// Populate the ReflectMap with, first, the value of __CUDA_FTZ from module
+// metadata, and then the key/value pairs from the command line.
 void NVVMReflect::populateReflectMap(Module &M) {
   if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
-      M.getModuleFlag(CUDA_FTZ_MODULE_NAME)))
+          M.getModuleFlag(CUDA_FTZ_MODULE_NAME)))
     ReflectMap[CUDA_FTZ_NAME] = Flag->getSExtValue();
 
   for (StringRef Option : ReflectList) {
     LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
     auto [Name, Val] = Option.split('=');
     if (Name.empty())
-      report_fatal_error("Empty name in nvvm-reflect-list option '" + Option + "'");
-    if (Val.empty()) 
-      report_fatal_error("Missing value in nvvm-reflect- option '" + Option + "'");
+      report_fatal_error("Empty name in nvvm-reflect-add option '" + Option +
+                         "'");
+    if (Val.empty())
+      report_fatal_error("Missing value in nvvm-reflect-add option '" + Option +
+                         "'");
     int ValInt;
     if (!to_integer(Val.trim(), ValInt, 10))
-      report_fatal_error("integer value expected in nvvm-reflect-list option '" + Option + "'");
+      report_fatal_error(
+          "integer value expected in nvvm-reflect-add option '" + Option +
+          "'");
     ReflectMap[Name] = ValInt;
   }
 }
@@ -135,16 +144,19 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
   if (!F)
     return false;
   assert(F->isDeclaration() && "_reflect function should not have a body");
-  assert(F->getReturnType()->isIntegerTy() && "_reflect's return type should be integer");
+  assert(F->getReturnType()->isIntegerTy() &&
+         "_reflect's return type should be integer");
 
   bool Changed = F->getNumUses() > 0;
   for (User *U : make_early_inc_range(F->users())) {
     // Reflect function calls look like:
-    // @arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
-    // call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch to ptr))
-    // We need to extract the string argument from the call (i.e. "__CUDA_ARCH")
+    //   @arch = private addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
+    //   call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch
+    //                                               to ptr))
+    // We need to extract the string argument (e.g. "__CUDA_ARCH") from it.
     if (!isa<CallInst>(U))
-      report_fatal_error("__nvvm_reflect can only be used in a call instruction");
+      report_fatal_error(
+          "__nvvm_reflect can only be used in a call instruction");
     CallInst *Call = cast<CallInst>(U);
     if (Call->getNumOperands() != 2)
       report_fatal_error("__nvvm_reflect requires exactly one argument");
@@ -157,19 +169,24 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
     if (!isa<ConstantDataSequential>(ConstantStr))
       report_fatal_error("__nvvm_reflect argument must be a string constant");
     if (!cast<ConstantDataSequential>(ConstantStr)->isCString())
-      report_fatal_error("__nvvm_reflect argument must be a null-terminated string");
+      report_fatal_error(
+          "__nvvm_reflect argument must be a null-terminated string");
 
-    StringRef ReflectArg = cast<ConstantDataSequential>(ConstantStr)->getAsString();
+    StringRef ReflectArg =
+        cast<ConstantDataSequential>(ConstantStr)->getAsString();
     // Remove the null terminator from the string
     ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
     if (ReflectArg.empty())
       report_fatal_error("__nvvm_reflect argument cannot be empty");
-    // Now that we have extracted the string argument, we can look it up in the VarMap
+    // Now that we have extracted the string argument, we can look it up in the
+    // ReflectMap.
     int ReflectVal = 0; // The default value is 0
     if (ReflectMap.contains(ReflectArg))
       ReflectVal = ReflectMap[ReflectArg];
 
-    LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName() << "(" << ReflectArg << ") with value " << ReflectVal << "\n");
+    LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName()
+                      << "(" << ReflectArg << ") with value " << ReflectVal
+                      << "\n");
     Constant *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
     foldReflectCall(Call, NewValue);
     Call->eraseFromParent();
@@ -181,8 +198,9 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
 }
 
 void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
-  SmallVector<Instruction*, 8> Worklist;
-  // Replace an instruction with a constant and add all users of the instruction to the worklist
+  SmallVector<Instruction *, 8> Worklist;
+  // Replace an instruction with a constant and add all users of the instruction
+  // to the worklist
   auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
     for (User *U : I->users()) {
       if (Instruction *UI = dyn_cast<Instruction>(U))
@@ -195,7 +213,8 @@ void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
 
   while (!Worklist.empty()) {
     Instruction *I = Worklist.pop_back_val();
-    if (Constant *C = ConstantFoldInstruction(I, Call->getModule()->getDataLayout())) {
+    if (Constant *C =
+            ConstantFoldInstruction(I, Call->getModule()->getDataLayout())) {
       ReplaceInstructionWithConst(I, C);
       if (isInstructionTriviallyDead(I))
         I->eraseFromParent();
@@ -220,8 +239,7 @@ bool NVVMReflectLegacyPass::runOnModule(Module &M) {
   return Impl.runOnModule(M);
 }
 
-PreservedAnalyses NVVMReflectPass::run(Module &M,
-                                    ModuleAnalysisManager &AM) {
+PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
   return NVVMReflect(SmVersion).runOnModule(M) ? PreservedAnalyses::none()
-                                   : PreservedAnalyses::all();
+                                               : PreservedAnalyses::all();
 }
\ No newline at end of file



More information about the llvm-commits mailing list