[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)

Vedant Paranjape via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 19 07:57:18 PST 2024


https://github.com/vedantparanjape-amd updated https://github.com/llvm/llvm-project/pull/78735

>From 9a6065d730833677a4ac574416b2b08e076a790e Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Thu, 18 Jan 2024 19:52:06 +0000
Subject: [PATCH] [ArgPromotion] Handle pointer arguments of recursive calls

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 55 ++++++++++++++++++-
 1 file changed, 54 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 8058282c422503..3891afbe019de9 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -609,13 +609,66 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
       // unknown users
     }
 
+    auto *CI = dyn_cast<CallInst>(V);
+    if (IsRecursive && CI && CI->getFunction() == Arg->getParent()) {
+      dbgs() << "Found recursive call\n";
+      Type *Ty = CI->getType();
+      Value *Ptr = CI->getArgOperand(U->getOperandNo());
+      Align PtrAlign = Ptr->getPointerAlignment(DL);
+      APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
+      Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
+                                                   /* AllowNonInbounds */ true);
+      if (Ptr != Arg)
+        return false;
+
+      if (Offset.getSignificantBits() >= 64)
+        return false;
+
+      TypeSize Size = DL.getTypeStoreSize(Ty);
+      // Don't try to promote scalable types.
+      if (Size.isScalable())
+        return false;
+
+      int64_t Off = Offset.getSExtValue();
+      auto Pair = ArgParts.try_emplace(Off, ArgPart{Ty, PtrAlign, nullptr});
+      ArgPart &Part = Pair.first->second;
+
+      // We limit promotion to only promoting up to a fixed number of elements
+      // of the aggregate.
+      if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+                          << "more than " << MaxElements << " parts\n");
+        return false;
+      }
+
+      Part.Alignment = std::max(Part.Alignment, PtrAlign);
+      continue;
+    }
     // Unknown user.
     LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                       << "unknown user " << *V << "\n");
     return false;
   }
 
-  if (NeededDerefBytes || NeededAlign > 1) {
+  // In case of functions with recursive calls, this check will fail when it
+  // tries to look at the first caller of this function. The caller may or may
+  // not have a load; in case it doesn't load the pointer being passed, this
+  // check will fail. So, it's safe to skip the check in case we know that we
+  // are dealing with a recursive call.
+  //
+  // def fun(ptr %a) {
+  //   ...
+  //   %loadres = load i32, ptr %a, align 4
+  //   %res = call i32 @fun(ptr %a)
+  //   ...
+  // }
+  //
+  // def bar(ptr %x) {
+  //   ...
+  //   %resbar = call i32 @fun(ptr %x)
+  //   ...
+  // }
+  if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
     // Try to prove a required deref / aligned requirement.
     if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
                                                NeededDerefBytes)) {



More information about the llvm-commits mailing list