[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 6 09:13:53 PDT 2024


================
@@ -610,13 +610,78 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
       // unknown users
     }
 
+    auto *CB = dyn_cast<CallBase>(V);
+    Value *PtrArg = dyn_cast<Value>(U);
+    if (IsRecursive && CB && PtrArg) {
+      Type *PtrTy = PtrArg->getType();
+      Align PtrAlign = PtrArg->getPointerAlignment(DL);
+      APInt Offset(DL.getIndexTypeSizeInBits(PtrArg->getType()), 0);
+      PtrArg = PtrArg->stripAndAccumulateConstantOffsets(
+          DL, Offset,
+          /* AllowNonInbounds= */ true);
+      if (PtrArg != Arg)
+        return false;
+
+      if (Offset.getSignificantBits() >= 64)
+        return false;
+
+      // If this is a recursive function and one of the argument types is a
+      // pointer that isn't loaded to a non pointer type, it can lead to
+      // recursive promotion. Look for any Load candidates above the function
+      // call that load a non pointer type from this argument pointer. If we
+      // don't find even one such use, return false. For reference, you can
+      // refer to Transforms/ArgumentPromotion/pr42028-recursion.ll and
+      // Transforms/ArgumentPromotion/2008-09-08-CGUpdateSelfEdge.ll
+      // testcases.
+      bool doesPointerResolve = false;
+      for (auto Load : Loads)
+        if (Load->getPointerOperand() == PtrArg &&
+            !Load->getType()->isPointerTy())
+          doesPointerResolve = true;
+
+      if (!doesPointerResolve)
+        return false;
+
+      int64_t Off = Offset.getSExtValue();
+      auto Pair = ArgParts.try_emplace(Off, ArgPart{PtrTy, PtrAlign, nullptr});
+      ArgPart &Part = Pair.first->second;
+
+      // We limit promotion to only promoting up to a fixed number of elements
+      // of the aggregate.
+      if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+                          << "more than " << MaxElements << " parts\n");
+        return false;
+      }
+
+      Part.Alignment = std::max(Part.Alignment, PtrAlign);
+      continue;
+    }
     // Unknown user.
     LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                       << "unknown user " << *V << "\n");
     return false;
   }
 
-  if (NeededDerefBytes || NeededAlign > 1) {
+  // Incase of functions with recursive calls, this check will fail when it
----------------
arsenm wrote:

Typo: "Incase" should be "In case".

https://github.com/llvm/llvm-project/pull/78735


More information about the llvm-commits mailing list