[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)

Vedant Paranjape via llvm-commits llvm-commits at lists.llvm.org
Fri May 31 05:58:01 PDT 2024


https://github.com/vedantparanjape-amd updated https://github.com/llvm/llvm-project/pull/78735

>From b8b89bb93c8e253a9577e2beb204f293479ea1b6 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Thu, 18 Jan 2024 19:52:06 +0000
Subject: [PATCH] [ArgPromotion] Handle pointer arguments of recursive calls

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 60 +++++++++++++++++--
 1 file changed, 56 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 3aa8ea3f51471..05211adf9f1a1 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -445,7 +445,7 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg,
 /// Determine that this argument is safe to promote, and find the argument
 /// parts it can be promoted into.
 static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
-                         unsigned MaxElements, bool IsRecursive,
+                         unsigned MaxElements, bool IsRecursive, bool IsSelfRecursive,
                          SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
   // Quick exit for unused arguments
   if (Arg->use_empty())
@@ -610,13 +610,59 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
       // unknown users
     }
 
+    if (IsSelfRecursive && isa<CallBase>(V)) {
+      Value *Ptr = dyn_cast<Value>(U);
+      Type *PtrTy = Ptr->getType();
+      Align PtrAlign = Ptr->getPointerAlignment(DL);
+      APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
+      Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
+                                                    /* AllowNonInbounds */ true);
+      if (Ptr != Arg)
+        return false;
+
+      if (Offset.getSignificantBits() >= 64)
+        return false;
+
+      int64_t Off = Offset.getSExtValue();
+      auto Pair = ArgParts.try_emplace(Off, ArgPart{PtrTy, PtrAlign, nullptr});
+      ArgPart &Part = Pair.first->second;
+
+      // We limit promotion to only promoting up to a fixed number of elements
+      // of the aggregate.
+      if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+                          << "more than " << MaxElements << " parts\n");
+        return false;
+      }
+
+      Part.Alignment = std::max(Part.Alignment, PtrAlign);
+      continue;
+    }
     // Unknown user.
     LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                       << "unknown user " << *V << "\n");
     return false;
   }
 
-  if (NeededDerefBytes || NeededAlign > 1) {
+  // In case of functions with recursive calls, this check will fail when it
+  // tries to look at the first caller of this function. The caller may or may
+  // not have a load; in case it doesn't load the pointer being passed, this
+  // check will fail. So, it's safe to skip the check in case we know that we
+  // are dealing with a recursive call.
+  //
+  // def fun(ptr %a) {
+  //   ...
+  //   %loadres = load i32, ptr %a, align 4
+  //   %res = call i32 @fun(ptr %a)
+  //   ...
+  // }
+  //
+  // def bar(ptr %x) {
+  //   ...
+  //   %resbar = call i32 @fun(ptr %x)
+  //   ...
+  // }
+  if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
     // Try to prove a required deref / aligned requirement.
     if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
                                                NeededDerefBytes)) {
@@ -699,6 +745,10 @@ static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
 /// calls the DoPromotion method.
 static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
                                   unsigned MaxElements, bool IsRecursive) {
+  // Due to the complexity of handling cases where the SCC has more than one
+  // component, we want to limit argument promotion of recursive calls to
+  // just functions that directly call themselves.
+  bool IsSelfRecursive = false;
   // Don't perform argument promotion for naked functions; otherwise we can end
   // up removing parameters that are seemingly 'not used' as they are referred
   // to in the assembly.
@@ -744,8 +794,10 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
     if (CB->isMustTailCall())
       return nullptr;
 
-    if (CB->getFunction() == F)
+    if (CB->getFunction() == F) {
       IsRecursive = true;
+      IsSelfRecursive = true;
+    }
   }
 
   // Can't change signature of musttail caller
@@ -779,7 +831,7 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
     // If we can promote the pointer to its value.
     SmallVector<OffsetAndArgPart, 4> ArgParts;
 
-    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
+    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, IsSelfRecursive, ArgParts)) {
       SmallVector<Type *, 4> Types;
       for (const auto &Pair : ArgParts)
         Types.push_back(Pair.second.Ty);



More information about the llvm-commits mailing list