[llvm] d0615a9 - [NVPTX] Handle bitcast and ASC(101) when trying to avoid argument copy.

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 6 13:13:34 PDT 2021


Author: Artem Belevich
Date: 2021-04-06T13:06:00-07:00
New Revision: d0615a93bb6d7aedc43323dc8957fe57e86ed8ae

URL: https://github.com/llvm/llvm-project/commit/d0615a93bb6d7aedc43323dc8957fe57e86ed8ae
DIFF: https://github.com/llvm/llvm-project/commit/d0615a93bb6d7aedc43323dc8957fe57e86ed8ae.diff

LOG: [NVPTX] Handle bitcast and ASC(101) when trying to avoid argument copy.

This allows us to skip the copy in a few more cases.
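
For illustration (a sketch based on the gep_bitcast_asc test added below, not
part of the original commit message), a chain like this on a byval kernel
argument %in is now recognized:

    %gep  = getelementptr inbounds %struct.ham, %struct.ham* %in, i64 0, i32 0, i64 %n64
    %bc   = bitcast i32* %gep to i8*
    %asc  = addrspacecast i8* %bc to i8 addrspace(101)*
    %load = load i8, i8 addrspace(101)* %asc, align 4

and gets rewritten to read directly from the parameter address space, so no
local copy of %in is emitted; roughly:

    %in.param  = addrspacecast %struct.ham* %in to %struct.ham addrspace(101)*
    %gep.param = getelementptr inbounds %struct.ham, %struct.ham addrspace(101)* %in.param, i64 0, i32 0, i64 %n64
    %bc.param  = bitcast i32 addrspace(101)* %gep.param to i8 addrspace(101)*
    %load      = load i8, i8 addrspace(101)* %bc.param, align 4

(The %*.param names are made up for readability; the original addrspacecast to
param space is simply stripped, since the argument is already rewritten into
that address space.) The pass also gains a DEBUG_TYPE of "nvptx-lower-args",
so in an asserts build -debug-only=nvptx-lower-args should print which
arguments still need a copy and which do not.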

Differential Revision: https://reviews.llvm.org/D99979

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
    llvm/test/CodeGen/NVPTX/lower-byval-args.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 0143f4f4b62a6..b27e8a9cd7844 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -99,6 +99,8 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Pass.h"
 
+#define DEBUG_TYPE "nvptx-lower-args"
+
 using namespace llvm;
 
 namespace llvm {
@@ -166,40 +168,60 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
     Value *NewParam;
   };
   SmallVector<IP> ItemsToConvert = {{I, Param}};
-  SmallVector<GetElementPtrInst *> GEPsToDelete;
-  while (!ItemsToConvert.empty()) {
-    IP I = ItemsToConvert.pop_back_val();
-    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction))
+  SmallVector<Instruction *> InstructionsToDelete;
+
+  auto CloneInstInParamAS = [](const IP &I) -> Value * {
+    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
       LI->setOperand(0, I.NewParam);
-    else if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
+      return LI;
+    }
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
       SmallVector<Value *, 4> Indices(GEP->indices());
       auto *NewGEP = GetElementPtrInst::Create(nullptr, I.NewParam, Indices,
                                                GEP->getName(), GEP);
       NewGEP->setIsInBounds(GEP->isInBounds());
-      llvm::for_each(GEP->users(), [NewGEP, &ItemsToConvert](Value *V) {
-        ItemsToConvert.push_back({cast<Instruction>(V), NewGEP});
-      });
-      GEPsToDelete.push_back(GEP);
-    } else
-      llvm_unreachable("Only Load and GEP can be converted to param AS.");
-  }
-  llvm::for_each(GEPsToDelete,
-                 [](GetElementPtrInst *GEP) { GEP->eraseFromParent(); });
-}
+      return NewGEP;
+    }
+    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
+      auto *NewBCType = BC->getType()->getPointerElementType()->getPointerTo(
+          ADDRESS_SPACE_PARAM);
+      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
+                                 BC->getName(), BC);
+    }
+    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
+      assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
+      // Just pass through the argument, the old ASC is no longer needed.
+      return I.NewParam;
+    }
+    llvm_unreachable("Unsupported instruction");
+  };
 
-static bool isALoadChain(Value *Start) {
-  SmallVector<Value *, 16> ValuesToCheck = {Start};
-  while (!ValuesToCheck.empty()) {
-    Value *V = ValuesToCheck.pop_back_val();
-    Instruction *I = dyn_cast<Instruction>(V);
-    if (!I)
-      return false;
-    if (isa<GetElementPtrInst>(I))
-      ValuesToCheck.append(I->user_begin(), I->user_end());
-    else if (!isa<LoadInst>(I))
-      return false;
+  while (!ItemsToConvert.empty()) {
+    IP I = ItemsToConvert.pop_back_val();
+    Value *NewInst = CloneInstInParamAS(I);
+
+    if (NewInst && NewInst != I.OldInstruction) {
+      // We've created a new instruction. Queue users of the old instruction to
+      // be converted and the instruction itself to be deleted. We can't delete
+      // the old instruction yet, because it's still in use by a load somewhere.
+      llvm::for_each(
+          I.OldInstruction->users(), [NewInst, &ItemsToConvert](Value *V) {
+            ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+          });
+
+      InstructionsToDelete.push_back(I.OldInstruction);
+    }
   }
-  return true;
+
+  // Now we know that all argument loads are using addresses in parameter space
+  // and we can finally remove the old instructions in generic AS.  Instructions
+  // scheduled for removal should be processed in reverse order so the ones
+  // closest to the load are deleted first. Otherwise they may still be in use.
+  // E.g. if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
+  // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
+  // the BitCast.
+  llvm::for_each(reverse(InstructionsToDelete),
+                 [](Instruction *I) { I->eraseFromParent(); });
 }
 
 void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
@@ -211,9 +233,35 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
 
   Type *StructType = PType->getElementType();
 
-  if (llvm::all_of(Arg->users(), isALoadChain)) {
-    // Replace all loads with the loads in param AS. This allows loading the Arg
-    // directly from parameter AS, without making a temporary copy.
+  auto IsALoadChain = [Arg](Value *Start) {
+    SmallVector<Value *, 16> ValuesToCheck = {Start};
+    auto IsALoadChainInstr = [](Value *V) -> bool {
+      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
+        return true;
+      // ASC to param space are OK, too -- we'll just strip them.
+      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
+        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
+          return true;
+      }
+      return false;
+    };
+
+    while (!ValuesToCheck.empty()) {
+      Value *V = ValuesToCheck.pop_back_val();
+      if (!IsALoadChainInstr(V)) {
+        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
+                          << "\n");
+        return false;
+      }
+      if (!isa<LoadInst>(V))
+        llvm::append_range(ValuesToCheck, V->users());
+    }
+    return true;
+  };
+
+  if (llvm::all_of(Arg->users(), IsALoadChain)) {
+    // Convert all loads and intermediate operations to use parameter AS and
+    // skip creation of a local copy of the argument.
     SmallVector<User *, 16> UsersToUpdate(Arg->users());
     Value *ArgInParamAS = new AddrSpaceCastInst(
         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
@@ -221,6 +269,7 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
     llvm::for_each(UsersToUpdate, [ArgInParamAS](Value *V) {
       convertToParamAS(V, ArgInParamAS);
     });
+    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
     return;
   }
 
@@ -297,6 +346,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
     }
   }
 
+  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
   for (Argument &Arg : F.args()) {
     if (Arg.getType()->isPointerTy()) {
       if (Arg.hasByValAttr())
@@ -310,6 +360,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
 
 // Device functions only need to copy byval args into local memory.
 bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
+  LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
   for (Argument &Arg : F.args())
     if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
       handleByValParam(&Arg);

diff  --git a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
index 455eb37e5a175..4d74b44bc100e 100644
--- a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
@@ -6,6 +6,7 @@ target triple = "nvptx64-nvidia-cuda"
 
 ; // Verify that load with static offset into parameter is done directly.
 ; CHECK-LABEL: .visible .entry static_offset
+; CHECK-NOT:   .local
 ; CHECK: ld.param.u64    [[result_addr:%rd[0-9]+]], [{{.*}}_param_0]
 ; CHECK: mov.b64         %[[param_addr:rd[0-9]+]], {{.*}}_param_1
 ; CHECK: mov.u64         %[[param_addr1:rd[0-9]+]], %[[param_addr]]
@@ -30,6 +31,7 @@ bb6:                                              ; preds = %bb3, %bb
 
 ; // Verify that load with dynamic offset into parameter is also done directly.
 ; CHECK-LABEL: .visible .entry dynamic_offset
+; CHECK-NOT:   .local
 ; CHECK: ld.param.u64    [[result_addr:%rd[0-9]+]], [{{.*}}_param_0]
 ; CHECK: mov.b64         %[[param_addr:rd[0-9]+]], {{.*}}_param_1
 ; CHECK: mov.u64         %[[param_addr1:rd[0-9]+]], %[[param_addr]]
@@ -48,6 +50,48 @@ bb:
   ret void
 }
 
+; Same as above, but with a bitcast present in the chain
+; CHECK-LABEL:.visible .entry gep_bitcast
+; CHECK-NOT: .local
+; CHECK-DAG: ld.param.u64    [[out:%rd[0-9]+]], [gep_bitcast_param_0]
+; CHECK-DAG: mov.b64         {{%rd[0-9]+}}, gep_bitcast_param_1
+; CHECK-DAG: ld.param.u32    {{%r[0-9]+}}, [gep_bitcast_param_2]
+; CHECK:     ld.param.u8     [[value:%rs[0-9]+]], [{{%rd[0-9]+}}]
+; CHECK:     st.global.u8    [{{%rd[0-9]+}}], [[value]];
+;
+; Function Attrs: nofree norecurse nounwind willreturn mustprogress
+define dso_local void @gep_bitcast(i8* nocapture %out,  %struct.ham* nocapture readonly byval(%struct.ham) align 4 %in, i32 %n) local_unnamed_addr #0 {
+bb:
+  %n64 = sext i32 %n to i64
+  %gep = getelementptr inbounds %struct.ham, %struct.ham* %in, i64 0, i32 0, i64 %n64
+  %bc = bitcast i32* %gep to i8*
+  %load = load i8, i8* %bc, align 4
+  store i8 %load, i8* %out, align 4
+  ret void
+}
+
+; Same as above, but with an ASC(101) present in the chain
+; CHECK-LABEL:.visible .entry gep_bitcast_asc
+; CHECK-NOT: .local
+; CHECK-DAG: ld.param.u64    [[out:%rd[0-9]+]], [gep_bitcast_asc_param_0]
+; CHECK-DAG: mov.b64         {{%rd[0-9]+}}, gep_bitcast_asc_param_1
+; CHECK-DAG: ld.param.u32    {{%r[0-9]+}}, [gep_bitcast_asc_param_2]
+; CHECK:     ld.param.u8     [[value:%rs[0-9]+]], [{{%rd[0-9]+}}]
+; CHECK:     st.global.u8    [{{%rd[0-9]+}}], [[value]];
+;
+; Function Attrs: nofree norecurse nounwind willreturn mustprogress
+define dso_local void @gep_bitcast_asc(i8* nocapture %out,  %struct.ham* nocapture readonly byval(%struct.ham) align 4 %in, i32 %n) local_unnamed_addr #0 {
+bb:
+  %n64 = sext i32 %n to i64
+  %gep = getelementptr inbounds %struct.ham, %struct.ham* %in, i64 0, i32 0, i64 %n64
+  %bc = bitcast i32* %gep to i8*
+  %asc = addrspacecast i8* %bc to i8 addrspace(101)*
+  %load = load i8, i8 addrspace(101)* %asc, align 4
+  store i8 %load, i8* %out, align 4
+  ret void
+}
+
+
 ; Verify that if the pointer escapes, then we do fall back onto using a temp copy.
 ; CHECK-LABEL: .visible .entry pointer_escapes
 ; CHECK: .local .align 8 .b8     __local_depot{{.*}}
@@ -82,7 +126,7 @@ declare dso_local i32* @escape(i32*) local_unnamed_addr
 
 
 !llvm.module.flags = !{!0, !1, !2}
-!nvvm.annotations = !{!3, !4, !5}
+!nvvm.annotations = !{!3, !4, !5, !6, !7}
 
 !0 = !{i32 2, !"SDK Version", [2 x i32] [i32 9, i32 1]}
 !1 = !{i32 1, !"wchar_size", i32 4}
@@ -90,3 +134,5 @@ declare dso_local i32* @escape(i32*) local_unnamed_addr
 !3 = !{void (i32*, %struct.ham*, i32)* @static_offset, !"kernel", i32 1}
 !4 = !{void (i32*, %struct.ham*, i32)* @dynamic_offset, !"kernel", i32 1}
 !5 = !{void (i32*, %struct.ham*, i32)* @pointer_escapes, !"kernel", i32 1}
+!6 = !{void (i8*, %struct.ham*, i32)* @gep_bitcast, !"kernel", i32 1}
+!7 = !{void (i8*, %struct.ham*, i32)* @gep_bitcast_asc, !"kernel", i32 1}
