[clang] 701d726 - [NVPTX] Improve NVVMReflect Efficiency (#134416)

via cfe-commits cfe-commits at lists.llvm.org
Thu Apr 10 18:33:40 PDT 2025


Author: Yonah Goldberg
Date: 2025-04-10T18:33:37-07:00
New Revision: 701d726ef09ea89909df9bd2fdc63c63758fe8d6

URL: https://github.com/llvm/llvm-project/commit/701d726ef09ea89909df9bd2fdc63c63758fe8d6
DIFF: https://github.com/llvm/llvm-project/commit/701d726ef09ea89909df9bd2fdc63c63758fe8d6.diff

LOG: [NVPTX] Improve NVVMReflect Efficiency (#134416)

The NVVMReflect pass simply replaces calls to nvvm-reflect functions
with the appropriate constant, either the architecture number, or
nvvm-reflect-ftz, found in the module's metadata.

The implementation is inefficient and does this by traversing through
all instructions to find calls. The common case is that you never call
nvvm-reflect, so this traversal is costly.

This PR:
- Updates the pass so that it finds the reflect functions by name, and
then traverses through their uses to find the calls directly.
- Adds a line (245) to make sure the dead nvvm-reflect definitions are
erased.
- Adds the ability to set reflect values via command line

Added: 
    llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll

Modified: 
    clang/test/CodeGen/builtins-nvptx.c
    llvm/lib/Target/NVPTX/NVPTX.h
    llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
    llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
    llvm/lib/Target/NVPTX/NVVMReflect.cpp

Removed: 
    


################################################################################
diff  --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 71b29849618b6..e2c159aac903f 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1,37 +1,37 @@
 // REQUIRES: nvptx-registered-target
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_70 -target-feature +ptx63 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX63_SM70 -check-prefix=LP64 %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx-unknown-unknown -target-cpu sm_80 -target-feature +ptx70 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX70_SM80 -check-prefix=LP32 %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_80 -target-feature +ptx70 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX70_SM80 -check-prefix=LP64 %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx-unknown-unknown -target-cpu sm_60 -target-feature +ptx62 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=LP32 %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_60 -target-feature +ptx62 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=LP64 %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_61 -target-feature +ptx62 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=LP64 %s
 // RUN: %clang_cc1 -triple nvptx-unknown-unknown -target-cpu sm_53 -target-feature +ptx62 \
 // RUN:   -DERROR_CHECK -fcuda-is-device -S -o /dev/null -x cuda -verify %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx-unknown-unknown -target-cpu sm_86 -target-feature +ptx72 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 -check-prefix=LP32 %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_86 -target-feature +ptx72 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 -check-prefix=LP64 %s
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_89 -target-feature +ptx81 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 %s
 // ###  The last run to check with the highest SM and PTX version available
 // ###  to make sure target builtins are still accepted.
 // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_100a -target-feature +ptx87 \
-// RUN:            -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN:            -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 %s
 
 #define __device__ __attribute__((device))
@@ -61,6 +61,7 @@ __device__ bool reflect() {
 
   unsigned x = __nvvm_reflect("__CUDA_ARCH");
   return x >= 700;
+
 }
 
 __device__ int read_ntid() {

diff  --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index ff983e52179af..98e77ca80b8d5 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -43,7 +43,7 @@ ModulePass *createNVPTXAssignValidGlobalNamesPass();
 ModulePass *createGenericToNVVMLegacyPass();
 ModulePass *createNVPTXCtorDtorLoweringLegacyPass();
 FunctionPass *createNVVMIntrRangePass();
-FunctionPass *createNVVMReflectPass(unsigned int SmVersion);
+ModulePass *createNVVMReflectPass(unsigned int SmVersion);
 MachineFunctionPass *createNVPTXPrologEpilogPass();
 MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
 FunctionPass *createNVPTXImageOptimizerPass();
@@ -55,6 +55,7 @@ MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
 MachineFunctionPass *createNVPTXForwardParamsPass();
 
+void initializeNVVMReflectLegacyPassPass(PassRegistry &);
 void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
 void initializeNVPTXAllocaHoistingPass(PassRegistry &);
 void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry &);
@@ -78,9 +79,9 @@ struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
 };
 
 struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
-  NVVMReflectPass();
+  NVVMReflectPass() : SmVersion(0) {}
   NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  PreservedAnalyses run(Module &F, ModuleAnalysisManager &AM);
 
 private:
   unsigned SmVersion;

diff  --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
index 34c79b8f77bae..1c813c2c51f70 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
+++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
@@ -18,6 +18,7 @@
 #endif
 MODULE_PASS("generic-to-nvvm", GenericToNVVMPass())
 MODULE_PASS("nvptx-lower-ctor-dtor", NVPTXCtorDtorLoweringPass())
+MODULE_PASS("nvvm-reflect", NVVMReflectPass())
 #undef MODULE_PASS
 
 #ifndef FUNCTION_ANALYSIS
@@ -36,7 +37,6 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA())
 #define FUNCTION_PASS(NAME, CREATE_PASS)
 #endif
 FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
-FUNCTION_PASS("nvvm-reflect", NVVMReflectPass())
 FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
 FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this));
 #undef FUNCTION_PASS

diff  --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 8a25256ea1e4a..a4c3b43aec9f2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -95,7 +95,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
   PassRegistry &PR = *PassRegistry::getPassRegistry();
   // FIXME: This pass is really intended to be invoked during IR optimization,
   // but it's very NVPTX-specific.
-  initializeNVVMReflectPass(PR);
+  initializeNVVMReflectLegacyPassPass(PR);
   initializeNVVMIntrRangePass(PR);
   initializeGenericToNVVMLegacyPassPass(PR);
   initializeNVPTXAllocaHoistingPass(PR);
@@ -240,11 +240,12 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
 
   PB.registerPipelineStartEPCallback(
       [this](ModulePassManager &PM, OptimizationLevel Level) {
-        FunctionPassManager FPM;
         // We do not want to fold out calls to nvvm.reflect early if the user
         // has not provided a target architecture just yet.
         if (Subtarget.hasTargetName())
-          FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
+          PM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
+
+        FunctionPassManager FPM;
         // Note: NVVMIntrRangePass was causing numerical discrepancies at one
         // point, if issues crop up, consider disabling.
         FPM.addPass(NVVMIntrRangePass());

diff  --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 2809ec2303f99..7273b30e4ae2e 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -20,12 +20,12 @@
 
 #include "NVPTX.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
-#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
@@ -39,186 +39,201 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
-#include <algorithm>
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
 #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
+// Argument of reflect call to retrive arch number
+#define CUDA_ARCH_NAME "__CUDA_ARCH"
+// Argument of reflect call to retrive ftz mode
+#define CUDA_FTZ_NAME "__CUDA_FTZ"
+// Name of module metadata where ftz mode is stored
+#define CUDA_FTZ_MODULE_NAME "nvvm-reflect-ftz"
 
 using namespace llvm;
 
-#define DEBUG_TYPE "nvptx-reflect"
+#define DEBUG_TYPE "nvvm-reflect"
+
+namespace llvm {
+void initializeNVVMReflectLegacyPassPass(PassRegistry &);
+} // namespace llvm
 
 namespace {
-class NVVMReflect : public FunctionPass {
+class NVVMReflect {
+  // Map from reflect function call arguments to the value to replace the call
+  // with. Should include __CUDA_FTZ and __CUDA_ARCH values.
+  StringMap<unsigned> ReflectMap;
+  bool handleReflectFunction(Module &M, StringRef ReflectName);
+  void populateReflectMap(Module &M);
+  void foldReflectCall(CallInst *Call, Constant *NewValue);
+
 public:
-  static char ID;
-  unsigned int SmVersion;
-  NVVMReflect() : NVVMReflect(0) {}
-  explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {}
+  // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
+  // metadata.
+  explicit NVVMReflect(unsigned SmVersion)
+      : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {}
+  bool runOnModule(Module &M);
+};
 
-  bool runOnFunction(Function &) override;
+class NVVMReflectLegacyPass : public ModulePass {
+  NVVMReflect Impl;
+
+public:
+  static char ID;
+  NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {}
+  bool runOnModule(Module &M) override;
 };
 } // namespace
 
-FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
-  return new NVVMReflect(SmVersion);
+ModulePass *llvm::createNVVMReflectPass(unsigned SmVersion) {
+  return new NVVMReflectLegacyPass(SmVersion);
 }
 
 static cl::opt<bool>
     NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
                        cl::desc("NVVM reflection, enabled by default"));
 
-char NVVMReflect::ID = 0;
-INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
+char NVVMReflectLegacyPass::ID = 0;
+INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
                 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
                 false)
 
-static bool runNVVMReflect(Function &F, unsigned SmVersion) {
-  if (!NVVMReflectEnabled)
-    return false;
-
-  if (F.getName() == NVVM_REFLECT_FUNCTION ||
-      F.getName() == NVVM_REFLECT_OCL_FUNCTION) {
-    assert(F.isDeclaration() && "_reflect function should not have a body");
-    assert(F.getReturnType()->isIntegerTy() &&
-           "_reflect's return type should be integer");
-    return false;
+// Allow users to specify additional key/value pairs to reflect. These key/value
+// pairs are the last to be added to the ReflectMap, and therefore will take
+// precedence over initial values (i.e. __CUDA_FTZ from module medadata and
+// __CUDA_ARCH from SmVersion).
+static cl::list<std::string> ReflectList(
+    "nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
+    cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."),
+    cl::ValueRequired);
+
+// Set the ReflectMap with, first, the value of __CUDA_FTZ from module metadata,
+// and then the key/value pairs from the command line.
+void NVVMReflect::populateReflectMap(Module &M) {
+  if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
+          M.getModuleFlag(CUDA_FTZ_MODULE_NAME)))
+    ReflectMap[CUDA_FTZ_NAME] = Flag->getSExtValue();
+
+  for (auto &Option : ReflectList) {
+    LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
+    StringRef OptionRef(Option);
+    auto [Name, Val] = OptionRef.split('=');
+    if (Name.empty())
+      report_fatal_error(Twine("Empty name in nvvm-reflect-add option '") +
+                         Option + "'");
+    if (Val.empty())
+      report_fatal_error(Twine("Missing value in nvvm-reflect-add option '") +
+                         Option + "'");
+    unsigned ValInt;
+    if (!to_integer(Val.trim(), ValInt, 10))
+      report_fatal_error(
+          Twine("integer value expected in nvvm-reflect-add option '") +
+          Option + "'");
+    ReflectMap[Name] = ValInt;
   }
+}
 
-  SmallVector<Instruction *, 4> ToRemove;
-  SmallVector<Instruction *, 4> ToSimplify;
-
-  // Go through the calls in this function.  Each call to __nvvm_reflect or
-  // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
-  // First validate that. If the c-string corresponding to the ConstantArray can
-  // be found successfully, see if it can be found in VarMap. If so, replace the
-  // uses of CallInst with the value found in VarMap. If not, replace the use
-  // with value 0.
-
-  // The IR for __nvvm_reflect calls differs between CUDA versions.
-  //
-  // CUDA 6.5 and earlier uses this sequence:
-  //    %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
-  //        (i8 addrspace(4)* getelementptr inbounds
-  //           ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
-  //    %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
-  //
-  // The value returned by Sym->getOperand(0) is a Constant with a
-  // ConstantDataSequential operand which can be converted to string and used
-  // for lookup.
-  //
-  // CUDA 7.0 does it slightly differently:
-  //   %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
-  //        (i8 addrspace(1)* getelementptr inbounds
-  //           ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
-  //
-  // In this case, we get a Constant with a GlobalVariable operand and we need
-  // to dig deeper to find its initializer with the string we'll use for lookup.
-  for (Instruction &I : instructions(F)) {
-    CallInst *Call = dyn_cast<CallInst>(&I);
+/// Process a reflect function by finding all its calls and replacing them with
+/// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
+/// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
+bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
+  Function *F = M.getFunction(ReflectName);
+  if (!F)
+    return false;
+  assert(F->isDeclaration() && "_reflect function should not have a body");
+  assert(F->getReturnType()->isIntegerTy() &&
+         "_reflect's return type should be integer");
+
+  const bool Changed = F->getNumUses() > 0;
+  for (User *U : make_early_inc_range(F->users())) {
+    // Reflect function calls look like:
+    // @arch = private unnamed_addr addrspace(1) constant [12 x i8]
+    // c"__CUDA_ARCH\00" call i32 @__nvvm_reflect(ptr addrspacecast (ptr
+    // addrspace(1) @arch to ptr)) We need to extract the string argument from
+    // the call (i.e. "__CUDA_ARCH")
+    auto *Call = dyn_cast<CallInst>(U);
     if (!Call)
-      continue;
-    Function *Callee = Call->getCalledFunction();
-    if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
-                    Callee->getName() != NVVM_REFLECT_OCL_FUNCTION &&
-                    Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
-      continue;
-
-    // FIXME: Improve error handling here and elsewhere in this pass.
-    assert(Call->getNumOperands() == 2 &&
-           "Wrong number of operands to __nvvm_reflect function");
-
-    // In cuda 6.5 and earlier, we will have an extra constant-to-generic
-    // conversion of the string.
-    const Value *Str = Call->getArgOperand(0);
-    if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
-      // FIXME: Add assertions about ConvCall.
-      Str = ConvCall->getArgOperand(0);
-    }
-    // Pre opaque pointers we have a constant expression wrapping the constant
-    // string.
-    Str = Str->stripPointerCasts();
-    assert(isa<Constant>(Str) &&
-           "Format of __nvvm_reflect function not recognized");
-
-    const Value *Operand = cast<Constant>(Str)->getOperand(0);
-    if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
-      // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
-      // initializer.
-      assert(GV->hasInitializer() &&
-             "Format of _reflect function not recognized");
-      const Constant *Initializer = GV->getInitializer();
-      Operand = Initializer;
-    }
-
-    assert(isa<ConstantDataSequential>(Operand) &&
-           "Format of _reflect function not recognized");
-    assert(cast<ConstantDataSequential>(Operand)->isCString() &&
-           "Format of _reflect function not recognized");
-
-    StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
-    ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
-    LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
-
-    int ReflectVal = 0; // The default value is 0
-    if (ReflectArg == "__CUDA_FTZ") {
-      // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag.  Our
-      // choice here must be kept in sync with AutoUpgrade, which uses the same
-      // technique to detect whether ftz is enabled.
-      if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
-              F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
-        ReflectVal = Flag->getSExtValue();
-    } else if (ReflectArg == "__CUDA_ARCH") {
-      ReflectVal = SmVersion * 10;
-    }
-
-    // If the immediate user is a simple comparison we want to simplify it.
-    for (User *U : Call->users())
-      if (Instruction *I = dyn_cast<Instruction>(U))
-        ToSimplify.push_back(I);
-
-    Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
-    ToRemove.push_back(Call);
+      report_fatal_error(
+          "__nvvm_reflect can only be used in a call instruction");
+    if (Call->getNumOperands() != 2)
+      report_fatal_error("__nvvm_reflect requires exactly one argument");
+
+    auto *GlobalStr =
+        dyn_cast<Constant>(Call->getArgOperand(0)->stripPointerCasts());
+    if (!GlobalStr)
+      report_fatal_error("__nvvm_reflect argument must be a constant string");
+
+    auto *ConstantStr =
+        dyn_cast<ConstantDataSequential>(GlobalStr->getOperand(0));
+    if (!ConstantStr)
+      report_fatal_error("__nvvm_reflect argument must be a string constant");
+    if (!ConstantStr->isCString())
+      report_fatal_error(
+          "__nvvm_reflect argument must be a null-terminated string");
+
+    StringRef ReflectArg = ConstantStr->getAsString().drop_back();
+    if (ReflectArg.empty())
+      report_fatal_error("__nvvm_reflect argument cannot be empty");
+    // Now that we have extracted the string argument, we can look it up in the
+    // ReflectMap
+    unsigned ReflectVal = 0; // The default value is 0
+    if (ReflectMap.contains(ReflectArg))
+      ReflectVal = ReflectMap[ReflectArg];
+
+    LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName()
+                      << "(" << ReflectArg << ") with value " << ReflectVal
+                      << "\n");
+    auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
+    foldReflectCall(Call, NewValue);
+    Call->eraseFromParent();
   }
 
-  // The code guarded by __nvvm_reflect may be invalid for the target machine.
-  // Traverse the use-def chain, continually simplifying constant expressions
-  // until we find a terminator that we can then remove.
-  while (!ToSimplify.empty()) {
-    Instruction *I = ToSimplify.pop_back_val();
-    if (Constant *C = ConstantFoldInstruction(I, F.getDataLayout())) {
-      for (User *U : I->users())
-        if (Instruction *I = dyn_cast<Instruction>(U))
-          ToSimplify.push_back(I);
-
-      I->replaceAllUsesWith(C);
-      if (isInstructionTriviallyDead(I)) {
-        ToRemove.push_back(I);
-      }
+  // Remove the __nvvm_reflect function from the module
+  F->eraseFromParent();
+  return Changed;
+}
+
+void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
+  SmallVector<Instruction *, 8> Worklist;
+  // Replace an instruction with a constant and add all users of the instruction
+  // to the worklist
+  auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
+    for (auto *U : I->users())
+      if (auto *UI = dyn_cast<Instruction>(U))
+        Worklist.push_back(UI);
+    I->replaceAllUsesWith(C);
+  };
+
+  ReplaceInstructionWithConst(Call, NewValue);
+
+  auto &DL = Call->getModule()->getDataLayout();
+  while (!Worklist.empty()) {
+    auto *I = Worklist.pop_back_val();
+    if (auto *C = ConstantFoldInstruction(I, DL)) {
+      ReplaceInstructionWithConst(I, C);
+      if (isInstructionTriviallyDead(I))
+        I->eraseFromParent();
     } else if (I->isTerminator()) {
       ConstantFoldTerminator(I->getParent());
     }
   }
-
-  // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
-  // array. Filter out the duplicates before starting to erase from parent.
-  std::sort(ToRemove.begin(), ToRemove.end());
-  auto NewLastIter = llvm::unique(ToRemove);
-  ToRemove.erase(NewLastIter, ToRemove.end());
-
-  for (Instruction *I : ToRemove)
-    I->eraseFromParent();
-
-  return ToRemove.size() > 0;
 }
 
-bool NVVMReflect::runOnFunction(Function &F) {
-  return runNVVMReflect(F, SmVersion);
+bool NVVMReflect::runOnModule(Module &M) {
+  if (!NVVMReflectEnabled)
+    return false;
+  populateReflectMap(M);
+  bool Changed = true;
+  Changed |= handleReflectFunction(M, NVVM_REFLECT_FUNCTION);
+  Changed |= handleReflectFunction(M, NVVM_REFLECT_OCL_FUNCTION);
+  Changed |=
+      handleReflectFunction(M, Intrinsic::getName(Intrinsic::nvvm_reflect));
+  return Changed;
 }
 
-NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {}
+bool NVVMReflectLegacyPass::runOnModule(Module &M) {
+  return Impl.runOnModule(M);
+}
 
-PreservedAnalyses NVVMReflectPass::run(Function &F,
-                                       FunctionAnalysisManager &AM) {
-  return runNVVMReflect(F, SmVersion) ? PreservedAnalyses::none()
-                                      : PreservedAnalyses::all();
+PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
+  return NVVMReflect(SmVersion).runOnModule(M) ? PreservedAnalyses::none()
+                                               : PreservedAnalyses::all();
 }

diff  --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
new file mode 100644
index 0000000000000..0706882236d86
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-options.ll
@@ -0,0 +1,54 @@
+; Test the NVVM reflect pass functionality: verifying that reflect calls are replaced with 
+; appropriate values based on command-line options. Verify that we can handle custom reflect arguments
+; that aren't __CUDA_ARCH or __CUDA_FTZ. If that argument is given a value on the command-line,
+; the reflect call should be replaced with that value. Otherwise, the reflect call should be replaced with 0.
+
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda \
+; RUN:   -nvvm-reflect-add __CUDA_FTZ=1 -nvvm-reflect-add __CUDA_ARCH=350 %s -S \
+; RUN:   | FileCheck %s --check-prefixes=COMMON,FTZ1,ARCH350,CUSTOM-ABSENT
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda \
+; RUN:   -nvvm-reflect-add __CUDA_FTZ=0 -nvvm-reflect-add __CUDA_ARCH=520 %s -S \
+; RUN:   | FileCheck %s --check-prefixes=COMMON,FTZ0,ARCH520,CUSTOM-ABSENT
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda \
+; RUN:   -nvvm-reflect-add __CUDA_FTZ=0 -nvvm-reflect-add __CUDA_ARCH=520 \
+; RUN:   -nvvm-reflect-add __CUSTOM_VALUE=42 %s -S \
+; RUN:   | FileCheck %s --check-prefixes=COMMON,CUSTOM-PRESENT
+
+; To ensure that command line options override module options, create a copy of this test file 
+; with module options appended and rerun some tests.
+
+; RUN: cat %s > %t.options
+; RUN: echo '!llvm.module.flags = !{!0}' >> %t.options
+; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.options
+; RUN: opt -passes=nvvm-reflect -mtriple=nvptx-nvidia-cuda \
+; RUN:   -nvvm-reflect-add __CUDA_FTZ=0 -nvvm-reflect-add __CUDA_ARCH=520 %t.options -S \
+; RUN:   | FileCheck %s --check-prefixes=COMMON,FTZ0,ARCH520
+
+declare i32 @__nvvm_reflect(ptr)
+@ftz = private unnamed_addr addrspace(1) constant [11 x i8] c"__CUDA_FTZ\00"
+@arch = private unnamed_addr addrspace(1) constant [12 x i8] c"__CUDA_ARCH\00"
+@custom = private unnamed_addr addrspace(1) constant [15 x i8] c"__CUSTOM_VALUE\00"
+
+; COMMON-LABEL: define i32 @test_ftz()
+; FTZ1: ret i32 1
+; FTZ0: ret i32 0
+define i32 @test_ftz() {
+  %1 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @ftz to ptr))
+  ret i32 %1
+}
+
+; COMMON-LABEL: define i32 @test_arch()
+; ARCH350: ret i32 350
+; ARCH520: ret i32 520
+define i32 @test_arch() {
+  %1 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch to ptr))
+  ret i32 %1
+}
+
+; COMMON-LABEL: define i32 @test_custom()
+; CUSTOM-ABSENT: ret i32 0
+; CUSTOM-PRESENT: ret i32 42
+define i32 @test_custom() {
+  %1 = call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @custom to ptr))
+  ret i32 %1
+}


        


More information about the cfe-commits mailing list