[llvm] [NVPTX] add an optional early copy of byval arguments (PR #113384)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 22 14:19:37 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Artem Belevich (Artem-B)

<details>
<summary>Changes</summary>

byval arguments in NVPTX are special. We're only allowed to read from them using a special instruction, and if we ever need to write to them or take an address, we must make a local copy and use it, instead.

The problem is that local copies are very expensive, and we create them very late in the compilation pipeline, so LLVM does not have much of a chance to eliminate them, if they turn out to be unnecessary.

One way around that is to create such copies early on, and let them percolate through the optimizations. The copying itself will never trigger creation of another copy later on, as the reads are allowed. If LLVM can eliminate it, it's a win. If the full optimization pipeline can't remove the copy, that's as good as it gets in terms of the effort we could've made, and it's certainly a much better effort than what we do now.

This early injection of the copies has the potential to create undesirable side effects, so it's disabled by default for now, until it sees more testing.

---
Full diff: https://github.com/llvm/llvm-project/pull/113384.diff


4 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+4) 
- (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+48-25) 
- (modified) llvm/lib/Target/NVPTX/NVPTXPassRegistry.def (+1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp (+7) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index f6ab81d3ca0bb2..ca915cd3f3732f 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -70,6 +70,10 @@ struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 };
 
+struct NVPTXCopyByValArgsPass : PassInfoMixin<NVPTXCopyByValArgsPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
 namespace NVPTX {
 enum DrvInterface {
   NVCL,
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index bb76cfd6fdb7bd..5d58c460f57a50 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -543,6 +543,33 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
       PI.setAborted(&II);
   }
 }; // struct ArgUseChecker
+
+void copyByValParam(Function &F, Argument &Arg) {
+  LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
+  // Otherwise we have to create a temporary copy.
+  BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
+  Type *StructType = Arg.getParamByValType();
+  const DataLayout &DL = F.getDataLayout();
+  AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
+                                      Arg.getName(), FirstInst);
+  // Set the alignment to alignment of the byval parameter. This is because,
+  // later load/stores assume that alignment, and we are going to replace
+  // the use of the byval parameter with this alloca instruction.
+  AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
+                           .value_or(DL.getPrefTypeAlign(StructType)));
+  Arg.replaceAllUsesWith(AllocA);
+
+  Value *ArgInParam = new AddrSpaceCastInst(
+      &Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
+      Arg.getName(), FirstInst);
+  // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
+  // addrspacecast preserves alignment.  Since params are constant, this load
+  // is definitely not volatile.
+  const auto ArgSize = *AllocA->getAllocationSize(DL);
+  IRBuilder<> IRB(&*FirstInst);
+  IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
+                   ArgSize);
+}
 } // namespace
 
 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
@@ -558,7 +585,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
 
   ArgUseChecker AUC(DL, IsGridConstant);
   ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
-  bool ArgUseIsReadOnly  = !(PI.isEscaped() || PI.isAborted());
+  bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
   // Easy case, accessing parameter directly is fine.
   if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
     // Convert all loads and intermediate operations to use parameter AS and
@@ -587,7 +614,6 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
   // However, we're still not allowed to write to it. If the user specified
   // `__grid_constant__` for the argument, we'll consider escaped pointer as
   // read-only.
-  unsigned AS = DL.getAllocaAddrSpace();
   if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
     LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
     // Replace all argument pointer uses (which might include a device function
@@ -612,29 +638,8 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
 
     // Do not replace Arg in the cast to param space
     CastToParam->setOperand(0, Arg);
-  } else {
-    LLVM_DEBUG(dbgs() << "Creating a local copy of " << *Arg << "\n");
-    // Otherwise we have to create a temporary copy.
-    AllocaInst *AllocA =
-        new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
-    // Set the alignment to alignment of the byval parameter. This is because,
-    // later load/stores assume that alignment, and we are going to replace
-    // the use of the byval parameter with this alloca instruction.
-    AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
-                             .value_or(DL.getPrefTypeAlign(StructType)));
-    Arg->replaceAllUsesWith(AllocA);
-
-    Value *ArgInParam = new AddrSpaceCastInst(
-        Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
-        Arg->getName(), FirstInst);
-    // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
-    // addrspacecast preserves alignment.  Since params are constant, this load
-    // is definitely not volatile.
-    const auto ArgSize = *AllocA->getAllocationSize(DL);
-    IRBuilder<> IRB(&*FirstInst);
-    IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
-                     ArgSize);
-  }
+  } else
+    copyByValParam(*Func, *Arg);
 }
 
 void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
@@ -734,3 +739,21 @@ bool NVPTXLowerArgs::runOnFunction(Function &F) {
 }
 
 FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); }
+
+static bool copyFunctionByValArgs(Function &F) {
+  LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
+                    << "\n");
+  bool Changed = false;
+  for (Argument &Arg : F.args())
+    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
+      copyByValParam(F, Arg);
+      Changed = true;
+    }
+  return Changed;
+}
+
+PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F,
+                                              FunctionAnalysisManager &AM) {
+  return copyFunctionByValArgs(F) ? PreservedAnalyses::none()
+                                  : PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
index 6ff15ab6f13c44..28ea9dd9c02270 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
+++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
@@ -37,4 +37,5 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA())
 #endif
 FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
 FUNCTION_PASS("nvvm-reflect", NVVMReflectPass())
+FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
 #undef FUNCTION_PASS
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 7d04cf3dc51e67..38c90ab953ada7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -64,6 +64,11 @@ static cl::opt<bool> UseShortPointersOpt(
         "Use 32-bit pointers for accessing const/local/shared address spaces."),
     cl::init(false), cl::Hidden);
 
+static cl::opt<bool> EarlyByValArgsCopy(
+    "nvptx-early-byval-copy",
+    cl::desc("Create a copy of byval function arguments early."),
+    cl::init(false), cl::Hidden);
+
 namespace llvm {
 
 void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
@@ -236,6 +241,8 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
         // Note: NVVMIntrRangePass was causing numerical discrepancies at one
         // point, if issues crop up, consider disabling.
         FPM.addPass(NVVMIntrRangePass());
+        if (EarlyByValArgsCopy)
+          FPM.addPass(NVPTXCopyByValArgsPass());
         PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
       });
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/113384


More information about the llvm-commits mailing list