[llvm] [NVPTX] Improve copy avoidance during lowering. (PR #106423)

Akshay Deodhar via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 28 12:23:10 PDT 2024


================
@@ -409,49 +453,121 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
   }
 }
 
+namespace {
+struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
+  using Base = PtrUseVisitor<ArgUseChecker>;
+
+  bool IsGridConstant;
+  SmallPtrSet<Value *, 16> AllArgUsers;
+  // Set of phi/select instructions using the Arg
+  SmallPtrSet<Instruction *, 4> Conditionals;
+
+  ArgUseChecker(const DataLayout &DL, bool IsGridConstant)
+      : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}
+
+  PtrInfo visitArgPtr(Argument &A) {
+    assert(A.getType()->isPointerTy());
+    IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType(A.getType()));
+    IsOffsetKnown = false;
+    Offset = APInt(IntIdxTy->getBitWidth(), 0);
+    PI.reset();
+    AllArgUsers.clear();
+    Conditionals.clear();
+
+    LLVM_DEBUG(dbgs() << "Checking Argument " << A << "\n");
+    // Enqueue the uses of this pointer.
+    enqueueUsers(A);
+    AllArgUsers.insert(&A);
+
+    // Visit all the uses off the worklist until it is empty.
+    // Note that unlike PtrUseVisitor we intentionally do not track offset.
+    // We're only interested in how we use the pointer.
+    while (!(Worklist.empty() || PI.isAborted())) {
+      UseToVisit ToVisit = Worklist.pop_back_val();
+      U = ToVisit.UseAndIsOffsetKnown.getPointer();
+      Instruction *I = cast<Instruction>(U->getUser());
+      AllArgUsers.insert(I);
+      if (isa<PHINode>(I) || isa<SelectInst>(I))
+        Conditionals.insert(I);
+      LLVM_DEBUG(dbgs() << "Processing " << *I << "\n");
+      Base::visit(I);
+    }
+    if (PI.isEscaped())
+      LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()
+                        << "\n");
+    else if (PI.isAborted())
+      LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()
+                        << "\n");
+    LLVM_DEBUG(dbgs() << "Traversed " << AllArgUsers.size() << " with "
+                      << Conditionals.size() << " conditionals\n");
+    return PI;
+  }
+
+  void visitStoreInst(StoreInst &SI) {
+    // Storing the pointer escapes it.
+    if (U->get() == SI.getValueOperand())
+      return PI.setEscapedAndAborted(&SI);
+    // Writes to the pointer are UB w/ __grid_constant__, but do not force a
+    // copy.
+    if (!IsGridConstant)
+      return PI.setAborted(&SI);
+  }
+
+  void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
+    // ASC to param space are no-ops and do not need a copy
+    if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
+      return PI.setEscapedAndAborted(&ASC);
+    Base::visitAddrSpaceCastInst(ASC);
+  }
+
+  void visitPtrToIntInst(PtrToIntInst &I) {
+    if (IsGridConstant)
+      return;
+    Base::visitPtrToIntInst(I);
+  }
+  void visitPHINodeOrSelectInst(Instruction &I) {
+    assert(isa<PHINode>(I) || isa<SelectInst>(I));
+  }
+  // PHI and select just pass through the pointers.
+  void visitPHINode(PHINode &PN) { enqueueUsers(PN); }
+  void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }
+
+  void visitMemTransferInst(MemTransferInst &II) {
+    if (*U == II.getRawDest() && !IsGridConstant)
+      PI.setAborted(&II);
+
+    // TODO: memcpy from arg is OK as it can get unrolled into ld.param.
+    // However, memcpys are currently expected to be unrolled before we
+    // get here, so we never see them in practice, and we do not currently
+    // handle them when we convert IR to access param space directly. So,
+    // we'll mark it as an escape for now. It would still force a copy on
+    // pre-sm_70 GPUs where we can't take address of a parameter w/o a copy.
+    //
+    // PI.setEscaped(&II);
+  }
+
+  void visitMemSetInst(MemSetInst &II) {
+    if (*U == II.getRawDest() && !IsGridConstant)
+      PI.setAborted(&II);
+  }
+  // debug only helper.
+  auto &getVisitedUses() { return VisitedUses; }
+};
+} // namespace
 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                       Argument *Arg) {
-  bool IsGridConstant = isParamGridConstant(*Arg);
   Function *Func = Arg->getParent();
+  bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
+  bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
+  const DataLayout &DL = Func->getDataLayout();
   BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
   Type *StructType = Arg->getParamByValType();
   assert(StructType && "Missing byval type");
 
-  auto AreSupportedUsers = [&](Value *Start) {
-    SmallVector<Value *, 16> ValuesToCheck = {Start};
-    auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
-      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
-        return true;
-      // ASC to param space are OK, too -- we'll just strip them.
-      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
-        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
-          return true;
-      }
-      // Simple calls and stores are supported for grid_constants
-      // writes to these pointers are undefined behaviour
-      if (IsGridConstant &&
-          (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
-        return true;
-      return false;
-    };
-
-    while (!ValuesToCheck.empty()) {
-      Value *V = ValuesToCheck.pop_back_val();
-      if (!IsSupportedUse(V)) {
-        LLVM_DEBUG(dbgs() << "Need a "
-                          << (isParamGridConstant(*Arg) ? "cast " : "copy ")
-                          << "of " << *Arg << " because of " << *V << "\n");
-        (void)Arg;
-        return false;
-      }
-      if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
-          !isa<PtrToIntInst>(V))
-        llvm::append_range(ValuesToCheck, V->users());
-    }
-    return true;
-  };
-
-  if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
+  ArgUseChecker AUC(DL, IsGridConstant);
+  ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
+  // Easy case, accessing parameter directly is fine.
+  if (!(PI.isEscaped() || PI.isAborted()) && AUC.Conditionals.empty()) {
----------------
akshayrdeodhar wrote:

I see that the `Conditionals` set is used for debug logs, but for pure functionality, would it make more sense to keep the conditionals in a local variable and expose just a simple `NoConditionalUses` boolean? This is not a big deal, and I'm not sure about it.

https://github.com/llvm/llvm-project/pull/106423


More information about the llvm-commits mailing list