[llvm] [NVPTX] add an optional early copy of byval arguments (PR #113384)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 22 14:19:03 PDT 2024


https://github.com/Artem-B created https://github.com/llvm/llvm-project/pull/113384

byval arguments in NVPTX are special. We're only allowed to read from them using a special instruction (ld.param), and if we ever need to write to them or take their address, we must make a local copy and use it instead.

The problem is that local copies are very expensive, and we create them very late in the compilation pipeline, so LLVM has little chance to eliminate them, even when they turn out to be unnecessary.
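To make "very expensive" more concrete, the sketch below shows roughly what one local copy amounts to. It is modeled on copyByValParam() in this patch: an alloca big enough for the whole byval aggregate, plus a memcpy of the entire argument out of the param address space (101 on NVPTX). The helper is hypothetical and only builds approximate textual IR. The .local/.param value names are made up. It also assumes that allocas live in the default address space (hence plain `ptr`) and that the copy length is 64-bit.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical helper, for illustration only: returns approximate textual IR
// for the sequence copyByValParam() inserts at the top of the entry block.
// The ".local"/".param" value names are made up; LLVM would reuse the
// argument's name and uniquify it instead.
constexpr unsigned kParamAddrSpace = 101; // NVPTX ADDRESS_SPACE_PARAM

std::string describeLocalCopy(const std::string &ArgName,
                              const std::string &ByValType,
                              std::uint64_t SizeInBytes,
                              std::optional<std::uint64_t> ParamAlign,
                              std::uint64_t PrefTypeAlign) {
  // Same fallback as the patch: the byval parameter's own alignment if it has
  // one, otherwise the preferred alignment of the byval type.
  const std::string Align = std::to_string(ParamAlign.value_or(PrefTypeAlign));
  const std::string AS = std::to_string(kParamAddrSpace);
  const std::string Local = "%" + ArgName + ".local";
  const std::string InParam = "%" + ArgName + ".param";

  std::string IR;
  // 1. Local (writable, addressable) storage for the whole aggregate.
  IR += Local + " = alloca " + ByValType + ", align " + Align + "\n";
  // 2. View the argument through the param address space.
  IR += InParam + " = addrspacecast ptr %" + ArgName + " to ptr addrspace(" +
        AS + ")\n";
  // 3. Copy every byte of the argument into the local storage.
  IR += "call void @llvm.memcpy.p0.p" + AS + ".i64(ptr align " + Align + " " +
        Local + ", ptr addrspace(" + AS + ") align " + Align + " " + InParam +
        ", i64 " + std::to_string(SizeInBytes) + ", i1 false)\n";
  return IR;
}
```

For example, describeLocalCopy("s", "%struct.S", 16, std::nullopt, 8) describes an `align 8` alloca filled by a 16-byte memcpy. NVPTX ultimately places that alloca in local memory, which is where the cost comes from.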

One way around that is to create such copies early on and let them percolate through the optimizations. The copying itself never triggers the creation of another copy later on, because it only reads from the argument, and reads are allowed. If LLVM can eliminate the copy, it's a win. If the full optimization pipeline can't remove it, that's the best we could have done anyway, and it's still much better than what we do now.
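A toy cost model can illustrate the "it can only help" argument by counting bytes of local memory per byval argument. This is a deliberately simplified sketch, not how LLVM measures anything. The type and function names are hypothetical, and the model ignores `__grid_constant__` and the cvta.param path. On the early-copy side, it optimistically assumes that SROA-style scalarization removes the copy whenever the copy's address does not escape.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical per-field summary of how a kernel uses a byval struct.
struct FieldUse {
  std::size_t SizeInBytes;
  bool Written; // the kernel stores to this field
};

struct ByValUses {
  std::vector<FieldUse> Fields;
  bool AddressEscapes; // pointer passed to an opaque call, stored, etc.
};

static std::size_t totalSize(const ByValUses &U) {
  std::size_t N = 0;
  for (const FieldUse &F : U.Fields)
    N += F.SizeInBytes;
  return N;
}

static bool anyWrite(const ByValUses &U) {
  for (const FieldUse &F : U.Fields)
    if (F.Written)
      return true;
  return false;
}

// Late copy (current default, simplified): any write or escape forces a copy
// of the whole struct into local memory, after the optimizer has already run.
std::size_t localBytesLateCopy(const ByValUses &U) {
  return (anyWrite(U) || U.AddressEscapes) ? totalSize(U) : 0;
}

// Early copy, under the optimistic assumption stated above: if the copy's
// address does not escape, written fields end up in registers and read-only
// fields are loaded straight from param space, so no local memory is left.
// If it does escape, the copy stays, which costs the same as the late copy.
std::size_t localBytesEarlyCopy(const ByValUses &U) {
  return U.AddressEscapes ? totalSize(U) : 0;
}
```

Under these assumptions, localBytesEarlyCopy(U) <= localBytesLateCopy(U) for every input. A write no longer forces a whole-struct copy, and an escaping pointer costs the same as before.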

This early injection of copies has the potential to create undesirable side effects, so for now it's disabled by default (see the hidden -nvptx-early-byval-copy option) until it sees more testing.

>From ad1b391604ec9a5f1d179c2f2cbdc1d960a7ed17 Mon Sep 17 00:00:00 2001
From: Artem Belevich <tra at google.com>
Date: Tue, 22 Oct 2024 11:28:54 -0700
Subject: [PATCH] [NVPTX] add an optional early copy of byval arguments

byval arguments in NVPTX are special. We're only allowed to read from them
using a special instruction (ld.param), and if we ever need to write to them
or take their address, we must make a local copy and use it instead.

The problem is that local copies are very expensive, and we create them very
late in the compilation pipeline, so LLVM has little chance to eliminate them,
even when they turn out to be unnecessary.

One way around that is to create such copies early on and let them percolate
through the optimizations. The copying itself never triggers the creation of
another copy later on, because it only reads from the argument, and reads are
allowed. If LLVM can eliminate the copy, it's a win. If the full optimization
pipeline can't remove it, that's the best we could have done anyway, and it's
still much better than what we do now.

This early injection of copies has the potential to create undesirable side
effects, so for now it's disabled by default (see the hidden
-nvptx-early-byval-copy option) until it sees more testing.
---
 llvm/lib/Target/NVPTX/NVPTX.h                |  4 ++
 llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp     | 73 +++++++++++++-------
 llvm/lib/Target/NVPTX/NVPTXPassRegistry.def  |  1 +
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp |  7 ++
 4 files changed, 60 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index f6ab81d3ca0bb2..ca915cd3f3732f 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -70,6 +70,10 @@ struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 };
 
+struct NVPTXCopyByValArgsPass : PassInfoMixin<NVPTXCopyByValArgsPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
 namespace NVPTX {
 enum DrvInterface {
   NVCL,
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index bb76cfd6fdb7bd..5d58c460f57a50 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -543,6 +543,33 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
       PI.setAborted(&II);
   }
 }; // struct ArgUseChecker
+
+void copyByValParam(Function &F, Argument &Arg) {
+  LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
+  // Create a temporary local copy of the byval argument.
+  BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
+  Type *StructType = Arg.getParamByValType();
+  const DataLayout &DL = F.getDataLayout();
+  AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
+                                      Arg.getName(), FirstInst);
+  // Set the alignment to alignment of the byval parameter. This is because,
+  // later load/stores assume that alignment, and we are going to replace
+  // the use of the byval parameter with this alloca instruction.
+  AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
+                           .value_or(DL.getPrefTypeAlign(StructType)));
+  Arg.replaceAllUsesWith(AllocA);
+
+  Value *ArgInParam = new AddrSpaceCastInst(
+      &Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
+      Arg.getName(), FirstInst);
+  // Be sure to propagate alignment to the memcpy; LLVM doesn't know that NVPTX
+  // addrspacecast preserves alignment.  Since params are constant, the read
+  // from param space is definitely not volatile.
+  const auto ArgSize = *AllocA->getAllocationSize(DL);
+  IRBuilder<> IRB(&*FirstInst);
+  IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
+                   ArgSize);
+}
 } // namespace
 
 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
@@ -558,7 +585,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
 
   ArgUseChecker AUC(DL, IsGridConstant);
   ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
-  bool ArgUseIsReadOnly  = !(PI.isEscaped() || PI.isAborted());
+  bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
   // Easy case, accessing parameter directly is fine.
   if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
     // Convert all loads and intermediate operations to use parameter AS and
@@ -587,7 +614,6 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
   // However, we're still not allowed to write to it. If the user specified
   // `__grid_constant__` for the argument, we'll consider escaped pointer as
   // read-only.
-  unsigned AS = DL.getAllocaAddrSpace();
   if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
     LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
     // Replace all argument pointer uses (which might include a device function
@@ -612,29 +638,8 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
 
     // Do not replace Arg in the cast to param space
     CastToParam->setOperand(0, Arg);
-  } else {
-    LLVM_DEBUG(dbgs() << "Creating a local copy of " << *Arg << "\n");
-    // Otherwise we have to create a temporary copy.
-    AllocaInst *AllocA =
-        new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
-    // Set the alignment to alignment of the byval parameter. This is because,
-    // later load/stores assume that alignment, and we are going to replace
-    // the use of the byval parameter with this alloca instruction.
-    AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
-                             .value_or(DL.getPrefTypeAlign(StructType)));
-    Arg->replaceAllUsesWith(AllocA);
-
-    Value *ArgInParam = new AddrSpaceCastInst(
-        Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
-        Arg->getName(), FirstInst);
-    // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
-    // addrspacecast preserves alignment.  Since params are constant, this load
-    // is definitely not volatile.
-    const auto ArgSize = *AllocA->getAllocationSize(DL);
-    IRBuilder<> IRB(&*FirstInst);
-    IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
-                     ArgSize);
-  }
+  } else
+    copyByValParam(*Func, *Arg);
 }
 
 void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
@@ -734,3 +739,21 @@ bool NVPTXLowerArgs::runOnFunction(Function &F) {
 }
 
 FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); }
+
+static bool copyFunctionByValArgs(Function &F) {
+  LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
+                    << "\n");
+  bool Changed = false;
+  for (Argument &Arg : F.args())
+    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
+      copyByValParam(F, Arg);
+      Changed = true;
+    }
+  return Changed;
+}
+
+PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F,
+                                              FunctionAnalysisManager &AM) {
+  return copyFunctionByValArgs(F) ? PreservedAnalyses::none()
+                                  : PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
index 6ff15ab6f13c44..28ea9dd9c02270 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
+++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
@@ -37,4 +37,5 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA())
 #endif
 FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
 FUNCTION_PASS("nvvm-reflect", NVVMReflectPass())
+FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
 #undef FUNCTION_PASS
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 7d04cf3dc51e67..38c90ab953ada7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -64,6 +64,11 @@ static cl::opt<bool> UseShortPointersOpt(
         "Use 32-bit pointers for accessing const/local/shared address spaces."),
     cl::init(false), cl::Hidden);
 
+static cl::opt<bool> EarlyByValArgsCopy(
+    "nvptx-early-byval-copy",
+    cl::desc("Create a copy of byval function arguments early."),
+    cl::init(false), cl::Hidden);
+
 namespace llvm {
 
 void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
@@ -236,6 +241,8 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
         // Note: NVVMIntrRangePass was causing numerical discrepancies at one
         // point, if issues crop up, consider disabling.
         FPM.addPass(NVVMIntrRangePass());
+        if (EarlyByValArgsCopy)
+          FPM.addPass(NVPTXCopyByValArgsPass());
         PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
       });
 }
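The control flow of handleByValParam after this refactoring can be restated as a small pure function. This is a simplified sketch based only on the lines visible in the diff above. It assumes that the "easy case" branch returns early, as its comment suggests, and it treats HasCvtaParam as an input computed elsewhere in the function. The names ByValLowering, ArgUseSummary and chooseLowering are hypothetical.

```cpp
// Hypothetical, simplified restatement of the choice made by
// NVPTXLowerArgs::handleByValParam, using only the conditions in the diff.
enum class ByValLowering {
  DirectParamAccess,    // rewrite uses to access the param address space
  GenericPointerNoCopy, // "Using non-copy pointer to ..." branch
  LocalCopy,            // falls through to copyByValParam()
};

struct ArgUseSummary {
  bool Escaped;         // PI.isEscaped()
  bool Aborted;         // PI.isAborted()
  bool HasConditionals; // !AUC.Conditionals.empty()
};

ByValLowering chooseLowering(const ArgUseSummary &S, bool HasCvtaParam,
                             bool IsGridConstant) {
  bool ArgUseIsReadOnly = !(S.Escaped || S.Aborted);
  // Easy case: read-only uses with no conditionals.
  if (ArgUseIsReadOnly && !S.HasConditionals)
    return ByValLowering::DirectParamAccess;
  // A __grid_constant__ argument is treated as read-only even if it escapes.
  if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant))
    return ByValLowering::GenericPointerNoCopy;
  return ByValLowering::LocalCopy;
}
```

With -nvptx-early-byval-copy, the new NVPTXCopyByValArgsPass instead takes the LocalCopy route for every pointer argument with the byval attribute, from the pass-builder callback in NVPTXTargetMachine.cpp. It then leaves it to the rest of the pipeline to remove copies that turn out to be unnecessary.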


