[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 19 07:53:57 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Vedant Paranjape (vedantparanjape-amd)

<details>
<summary>Changes</summary>

Attempts to fix llvm#1259. This implementation fails the pr42028-recursion.ll test: the compiler crashes because it cannot handle a call graph of the following shape (A calls B, B calls C, and C calls back into B):

```
A -> B -> C
     ↑____|
```

---
Full diff: https://github.com/llvm/llvm-project/pull/78735.diff


1 Files Affected:

- (modified) llvm/lib/Transforms/IPO/ArgumentPromotion.cpp (+54-1) 


``````````diff
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 8058282c422503..072a8c46f1e54c 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -609,13 +609,66 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
       // unknown users
     }
 
+    auto *CI = dyn_cast<CallInst>(V);
+    if (IsRecursive && CI && CI->getFunction() == Arg->getParent()) {
+      dbgs() << "Found recursive call\n";
+      Type *Ty = CI->getType();
+      Value *Ptr = CI->getArgOperand(U->getOperandNo());
+      Align PtrAlign = Ptr->getPointerAlignment(DL);
+      APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
+      Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset, /* AllowNonInbounds */ true);
+      if (Ptr != Arg)
+        return false;
+
+      if (Offset.getSignificantBits() >= 64)
+        return false;
+
+      TypeSize Size = DL.getTypeStoreSize(Ty);
+      // Don't try to promote scalable types.
+      if (Size.isScalable())
+        return false;
+
+      int64_t Off = Offset.getSExtValue();
+      auto Pair = ArgParts.try_emplace(
+          Off, ArgPart{Ty, PtrAlign, nullptr});
+      ArgPart &Part = Pair.first->second;
+
+      // We limit promotion to only promoting up to a fixed number of elements of
+      // the aggregate.
+      if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+                          << "more than " << MaxElements << " parts\n");
+        return false;
+      }
+
+      Part.Alignment = std::max(Part.Alignment, PtrAlign);
+      continue;
+    }
     // Unknown user.
     LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                       << "unknown user " << *V << "\n");
     return false;
   }
 
-  if (NeededDerefBytes || NeededAlign > 1) {
+  // Incase of functions with recursive calls, this check will fail when it
+  // tries to look at the first caller of this function. The caller may or may
+  // not have a load, incase it doesn't load the pointer being passed, this
+  // check will fail. So, it's safe to skip the check incase we know that we
+  // are dealing with a recursive call.
+  //
+  // def fun(ptr %a) {
+  //   ...
+  //   %loadres = load i32, ptr %a, align 4
+  //   %res = call i32 @fun(ptr %a)
+  //   ...
+  // }
+  //
+  // def bar(ptr %x) {
+  //   ...
+  //   %resbar = call i32 @fun(ptr %x)
+  //   ...
+  // }
+  if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
     // Try to prove a required deref / aligned requirement.
     if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
                                                NeededDerefBytes)) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/78735


More information about the llvm-commits mailing list