[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)

Vedant Paranjape via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 12 12:52:43 PDT 2024


https://github.com/vedantparanjape-amd updated https://github.com/llvm/llvm-project/pull/78735

>From e9a1cd4b41edff16ca2f2e3d1a89743839980350 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Thu, 18 Jan 2024 19:52:06 +0000
Subject: [PATCH 1/8] [ArgPromotion] Handle pointer arguments of recursive
 calls

Argument promotion currently does not promote pointer arguments of
functions that call themselves recursively. This patch adds support for
self-recursive functions, i.e. functions whose SCC has size 1. Because
value tracking through recursive calls is more complex when the SCC
contains more than one function, we bail out in those cases.
---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 62 +++++++++++++++++-
 .../argpromotion-recursion-pr1259.ll          | 65 +++++++++++++++++++
 2 files changed, 124 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 3aa8ea3f51471..8f5d0ea4af397 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -446,6 +446,7 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg,
 /// parts it can be promoted into.
 static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
                          unsigned MaxElements, bool IsRecursive,
+                         bool IsSelfRecursive,
                          SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
   // Quick exit for unused arguments
   if (Arg->use_empty())
@@ -610,13 +611,61 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
       // unknown users
     }
 
+    auto *CB = dyn_cast<CallBase>(V);
+    Value *PtrArg = dyn_cast<Value>(U);
+    if (IsSelfRecursive && CB && PtrArg) {
+      Type *PtrTy = PtrArg->getType();
+      Align PtrAlign = PtrArg->getPointerAlignment(DL);
+      APInt Offset(DL.getIndexTypeSizeInBits(PtrArg->getType()), 0);
+      PtrArg = PtrArg->stripAndAccumulateConstantOffsets(
+          DL, Offset,
+          /* AllowNonInbounds= */ true);
+      if (PtrArg != Arg)
+        return false;
+
+      if (Offset.getSignificantBits() >= 64)
+        return false;
+
+      int64_t Off = Offset.getSExtValue();
+      auto Pair = ArgParts.try_emplace(Off, ArgPart{PtrTy, PtrAlign, nullptr});
+      ArgPart &Part = Pair.first->second;
+
+      // We limit promotion to only promoting up to a fixed number of elements
+      // of the aggregate.
+      if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+                          << "more than " << MaxElements << " parts\n");
+        return false;
+      }
+
+      Part.Alignment = std::max(Part.Alignment, PtrAlign);
+      continue;
+    }
     // Unknown user.
     LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                       << "unknown user " << *V << "\n");
     return false;
   }
 
-  if (NeededDerefBytes || NeededAlign > 1) {
+  // Incase of functions with recursive calls, this check will fail when it
+  // tries to look at the first caller of this function. The caller may or may
+  // not have a load, incase it doesn't load the pointer being passed, this
+  // check will fail. So, it's safe to skip the check incase we know that we
+  // are dealing with a recursive call.
+  //
+  // def fun(ptr %a) {
+  //   ...
+  //   %loadres = load i32, ptr %a, align 4
+  //   %res = call i32 @fun(ptr %a)
+  //   ...
+  // }
+  //
+  // def bar(ptr %x) {
+  //   ...
+  //   %resbar = call i32 @fun(ptr %x)
+  //   ...
+  // }
+  if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
     // Try to prove a required deref / aligned requirement.
     if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
                                                NeededDerefBytes)) {
@@ -699,6 +748,10 @@ static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
 /// calls the DoPromotion method.
 static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
                                   unsigned MaxElements, bool IsRecursive) {
+  // Due to complexity of handling cases where the SCC has more than one
+  // component. We want to limit argument promotion of recursive calls to
+  // just functions that directly call themselves.
+  bool IsSelfRecursive = false;
   // Don't perform argument promotion for naked functions; otherwise we can end
   // up removing parameters that are seemingly 'not used' as they are referred
   // to in the assembly.
@@ -744,8 +797,10 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
     if (CB->isMustTailCall())
       return nullptr;
 
-    if (CB->getFunction() == F)
+    if (CB->getFunction() == F) {
       IsRecursive = true;
+      IsSelfRecursive = true;
+    }
   }
 
   // Can't change signature of musttail caller
@@ -779,7 +834,8 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
     // If we can promote the pointer to its value.
     SmallVector<OffsetAndArgPart, 4> ArgParts;
 
-    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
+    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, IsSelfRecursive,
+                     ArgParts)) {
       SmallVector<Type *, 4> Types;
       for (const auto &Pair : ArgParts)
         Types.push_back(Pair.second.Ty);
diff --git a/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll b/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll
new file mode 100644
index 0000000000000..19bb4492171fc
--- /dev/null
+++ b/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
+define internal i32 @foo(ptr %x, i32 %n, i32 %m) {
+; CHECK-LABEL: define internal i32 @foo(
+; CHECK-SAME: i32 [[X_0_VAL:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
+; CHECK:       [[COND_TRUE]]:
+; CHECK-NEXT:    br label %[[RETURN:.*]]
+; CHECK:       [[COND_FALSE]]:
+; CHECK-NEXT:    [[SUBVAL:%.*]] = sub i32 [[N]], 1
+; CHECK-NEXT:    [[CALLRET:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL]], i32 [[X_0_VAL]])
+; CHECK-NEXT:    [[SUBVAL2:%.*]] = sub i32 [[N]], 2
+; CHECK-NEXT:    [[CALLRET2:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL2]], i32 [[M]])
+; CHECK-NEXT:    [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
+; CHECK-NEXT:    br label %[[RETURN]]
+; CHECK:       [[COND_NEXT:.*]]:
+; CHECK-NEXT:    br label %[[RETURN]]
+; CHECK:       [[RETURN]]:
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = phi i32 [ [[X_0_VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ undef, %[[COND_NEXT]] ]
+; CHECK-NEXT:    ret i32 [[RETVAL_0]]
+;
+entry:
+  %cmp = icmp ne i32 %n, 0
+  br i1 %cmp, label %cond_true, label %cond_false
+
+cond_true:                                        ; preds = %entry
+  %val = load i32, ptr %x, align 4
+  br label %return
+
+cond_false:                                       ; preds = %entry
+  %val2 = load i32, ptr %x, align 4
+  %subval = sub i32 %n, 1
+  %callret = call i32 @foo(ptr %x, i32 %subval, i32 %val2)
+  %subval2 = sub i32 %n, 2
+  %callret2 = call i32 @foo(ptr %x, i32 %subval2, i32 %m)
+  %cmp2 = add i32 %callret, %callret2
+  br label %return
+
+cond_next:                                        ; No predecessors!
+  br label %return
+
+return:                                           ; preds = %cond_next, %cond_false, %cond_true
+  %retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ undef, %cond_next ]
+  ret i32 %retval.0
+}
+
+define i32 @bar(ptr %x, i32 %n, i32 %m) {
+; CHECK-LABEL: define i32 @bar(
+; CHECK-SAME: ptr [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[X_VAL:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    [[CALLRET3:%.*]] = call i32 @foo(i32 [[X_VAL]], i32 [[N]], i32 [[M]])
+; CHECK-NEXT:    br label %[[RETURN:.*]]
+; CHECK:       [[RETURN]]:
+; CHECK-NEXT:    ret i32 [[CALLRET3]]
+;
+entry:
+  %callret3 = call i32 @foo(ptr %x, i32 %n, i32 %m)
+  br label %return
+
+return:                                           ; preds = %entry
+  ret i32 %callret3
+}

>From dd18202a560aa25cd6c6bcbfb01b5d6315a55f6f Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Mon, 3 Jun 2024 08:23:17 +0000
Subject: [PATCH 2/8] Add a check to stop recursive promotion of ptr args

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 31 ++++++++++++-------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 8f5d0ea4af397..47da9b7f5a882 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -446,7 +446,6 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg,
 /// parts it can be promoted into.
 static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
                          unsigned MaxElements, bool IsRecursive,
-                         bool IsSelfRecursive,
                          SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
   // Quick exit for unused arguments
   if (Arg->use_empty())
@@ -613,7 +612,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
 
     auto *CB = dyn_cast<CallBase>(V);
     Value *PtrArg = dyn_cast<Value>(U);
-    if (IsSelfRecursive && CB && PtrArg) {
+    if (IsRecursive && CB && PtrArg) {
       Type *PtrTy = PtrArg->getType();
       Align PtrAlign = PtrArg->getPointerAlignment(DL);
       APInt Offset(DL.getIndexTypeSizeInBits(PtrArg->getType()), 0);
@@ -626,6 +625,23 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
       if (Offset.getSignificantBits() >= 64)
         return false;
 
+      // If this is a recursive function and one of the argument types is a
+      // pointer that isn't loaded to a non pointer type, it can lead to
+      // recursive promotion. Look for any Load candidates above the function
+      // call that load a non pointer type from this argument pointer. If we
+      // don't find even one such use, return false. For reference, you can
+      // refer to Transforms/ArgumentPromotion/pr42028-recursion.ll and
+      // Transforms/ArgumentPromotion/2008-09-08-CGUpdateSelfEdge.ll
+      // testcases.
+      bool doesPointerResolve = false;
+      for (auto Load : Loads)
+        if (Load->getPointerOperand() == PtrArg &&
+            !Load->getType()->isPointerTy())
+          doesPointerResolve = true;
+
+      if (!doesPointerResolve)
+        return false;
+
       int64_t Off = Offset.getSExtValue();
       auto Pair = ArgParts.try_emplace(Off, ArgPart{PtrTy, PtrAlign, nullptr});
       ArgPart &Part = Pair.first->second;
@@ -748,10 +764,6 @@ static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
 /// calls the DoPromotion method.
 static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
                                   unsigned MaxElements, bool IsRecursive) {
-  // Due to complexity of handling cases where the SCC has more than one
-  // component. We want to limit argument promotion of recursive calls to
-  // just functions that directly call themselves.
-  bool IsSelfRecursive = false;
   // Don't perform argument promotion for naked functions; otherwise we can end
   // up removing parameters that are seemingly 'not used' as they are referred
   // to in the assembly.
@@ -797,10 +809,8 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
     if (CB->isMustTailCall())
       return nullptr;
 
-    if (CB->getFunction() == F) {
+    if (CB->getFunction() == F)
       IsRecursive = true;
-      IsSelfRecursive = true;
-    }
   }
 
   // Can't change signature of musttail caller
@@ -834,8 +844,7 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
     // If we can promote the pointer to its value.
     SmallVector<OffsetAndArgPart, 4> ArgParts;
 
-    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, IsSelfRecursive,
-                     ArgParts)) {
+    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
       SmallVector<Type *, 4> Types;
       for (const auto &Pair : ArgParts)
         Types.push_back(Pair.second.Ty);

>From 5bc21021f4993040940b5e743b19ce5d163ad74d Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Fri, 7 Jun 2024 02:55:01 +0000
Subject: [PATCH 3/8] Address review comments

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 47da9b7f5a882..9291f845fe9f9 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -611,7 +611,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
     }
 
     auto *CB = dyn_cast<CallBase>(V);
-    Value *PtrArg = dyn_cast<Value>(U);
+    Value *PtrArg = cast<Value>(U);
     if (IsRecursive && CB && PtrArg) {
       Type *PtrTy = PtrArg->getType();
       Align PtrAlign = PtrArg->getPointerAlignment(DL);
@@ -663,7 +663,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
     return false;
   }
 
-  // Incase of functions with recursive calls, this check will fail when it
+  // In case of functions with recursive calls, this check will fail when it
   // tries to look at the first caller of this function. The caller may or may
   // not have a load, incase it doesn't load the pointer being passed, this
   // check will fail. So, it's safe to skip the check incase we know that we

>From bc1c6c6bbb5d09d7b5b3f0ad6a5ae8667c8170f1 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Mon, 10 Jun 2024 05:37:06 +0000
Subject: [PATCH 4/8] Address second round of review comments

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 21 ++++++++++++-------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 9291f845fe9f9..7c557e2657b28 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -422,13 +422,16 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
 
 /// Return true if we can prove that all callees pass in a valid pointer for the
 /// specified function argument.
-static bool allCallersPassValidPointerForArgument(Argument *Arg,
-                                                  Align NeededAlign,
-                                                  uint64_t NeededDerefBytes) {
+static bool allCallersPassValidPointerForArgument(
+    Argument *Arg, SmallPtrSet<CallBase *, 4> &RecursiveCalls,
+    Align NeededAlign, uint64_t NeededDerefBytes) {
   Function *Callee = Arg->getParent();
   const DataLayout &DL = Callee->getParent()->getDataLayout();
   APInt Bytes(64, NeededDerefBytes);
 
+  if (RecursiveCalls.size())
+    return true;
+
   // Check if the argument itself is marked dereferenceable and aligned.
   if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
     return true;
@@ -570,6 +573,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
   SmallVector<const Use *, 16> Worklist;
   SmallPtrSet<const Use *, 16> Visited;
   SmallVector<LoadInst *, 16> Loads;
+  SmallPtrSet<CallBase *, 4> RecursiveCalls;
   auto AppendUses = [&](const Value *V) {
     for (const Use &U : V->uses())
       if (Visited.insert(&U).second)
@@ -643,8 +647,9 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
         return false;
 
       int64_t Off = Offset.getSExtValue();
-      auto Pair = ArgParts.try_emplace(Off, ArgPart{PtrTy, PtrAlign, nullptr});
-      ArgPart &Part = Pair.first->second;
+      if (Off)
+        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+                          << "pointer offset is not equal to zero\n");
 
       // We limit promotion to only promoting up to a fixed number of elements
       // of the aggregate.
@@ -654,7 +659,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
         return false;
       }
 
-      Part.Alignment = std::max(Part.Alignment, PtrAlign);
+      RecursiveCalls.insert(CB);
       continue;
     }
     // Unknown user.
@@ -681,9 +686,9 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
   //   %resbar = call i32 @fun(ptr %x)
   //   ...
   // }
-  if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
+  if (NeededDerefBytes || NeededAlign > 1) {
     // Try to prove a required deref / aligned requirement.
-    if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
+    if (!allCallersPassValidPointerForArgument(Arg, RecursiveCalls, NeededAlign,
                                                NeededDerefBytes)) {
       LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                         << "not dereferenceable or aligned\n");

>From 3d3aaeae4ee4e7ff520a13a0ed25de08186ab1ef Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Mon, 10 Jun 2024 05:43:01 +0000
Subject: [PATCH 5/8] minor nitpick

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 7c557e2657b28..ff6dc17ab721a 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -618,8 +618,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
     Value *PtrArg = cast<Value>(U);
     if (IsRecursive && CB && PtrArg) {
       Type *PtrTy = PtrArg->getType();
-      Align PtrAlign = PtrArg->getPointerAlignment(DL);
-      APInt Offset(DL.getIndexTypeSizeInBits(PtrArg->getType()), 0);
+      APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), 0);
       PtrArg = PtrArg->stripAndAccumulateConstantOffsets(
           DL, Offset,
           /* AllowNonInbounds= */ true);

>From c3a1a27fe7fa8cf66f199d031a971157bdd568fe Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Wed, 12 Jun 2024 10:28:03 +0000
Subject: [PATCH 6/8] Address third round of review comments

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index ff6dc17ab721a..0e9e0b5eaf2a1 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -429,8 +429,8 @@ static bool allCallersPassValidPointerForArgument(
   const DataLayout &DL = Callee->getParent()->getDataLayout();
   APInt Bytes(64, NeededDerefBytes);
 
-  if (RecursiveCalls.size())
-    return true;
+  // if (RecursiveCalls.size())
+  //   return true;
 
   // Check if the argument itself is marked dereferenceable and aligned.
   if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
@@ -440,6 +440,13 @@ static bool allCallersPassValidPointerForArgument(
   // direct callees.
   return all_of(Callee->users(), [&](User *U) {
     CallBase &CB = cast<CallBase>(*U);
+    if (RecursiveCalls.contains(&CB))
+      return true;
+
+    // if (RecursiveCalls.size() &&
+    //     CB.getCalledFunction()->getName() == Callee->getName())
+    //   return true;
+
     return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()),
                                               NeededAlign, Bytes, DL);
   });
@@ -646,9 +653,11 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
         return false;
 
       int64_t Off = Offset.getSExtValue();
-      if (Off)
+      if (Off) {
         LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                           << "pointer offset is not equal to zero\n");
+        return false;
+      }
 
       // We limit promotion to only promoting up to a fixed number of elements
       // of the aggregate.

>From 661c6d9bcbdbe2930303ca4996ca69061311f645 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Wed, 12 Jun 2024 19:05:55 +0000
Subject: [PATCH 7/8] Address fourth round of review comments

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 0e9e0b5eaf2a1..1aef4d463ea17 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -659,6 +659,13 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
         return false;
       }
 
+      unsigned int ArgNo = Arg->getArgNo();
+      if (CB->getArgOperand(ArgNo) != Arg) {
+        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+                          << "arg position is different in callee\n");
+        return false;
+      }
+
       // We limit promotion to only promoting up to a fixed number of elements
       // of the aggregate.
       if (MaxElements > 0 && ArgParts.size() > MaxElements) {

>From 5cc9ff541ddcd0a76c3cb6a33899c43c66faffa8 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Wed, 12 Jun 2024 19:52:08 +0000
Subject: [PATCH 8/8] Refactor and add comments, also address review comments

---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 85 ++++++++++---------
 .../argpromotion-recursion-pr1259.ll          |  4 +-
 2 files changed, 47 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 1aef4d463ea17..cb248ec7c59f5 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -420,6 +420,26 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
   return NF;
 }
 
+/// Returns true if \p Ptr is the pointer operand of at least one load in
+/// \p Loads that produces a non-pointer value.
+///
+/// If this is a recursive function and one of its pointer arguments is never
+/// loaded to a non-pointer type, promoting it can lead to recursive
+/// promotion, so the caller bails out when this returns false. For
+/// reference, see the Transforms/ArgumentPromotion/pr42028-recursion.ll and
+/// Transforms/ArgumentPromotion/2008-09-08-CGUpdateSelfEdge.ll testcases.
+///
+/// Only the loads present in \p Loads at the time of the call are considered;
+/// loads appended later are not seen, so this is not a program-order query
+/// for "some load above the call".
+static bool checkIfPointerIsDereferenced(ArrayRef<LoadInst *> Loads,
+                                         const Value *Ptr) {
+  // any_of stops at the first matching load instead of scanning everything.
+  return any_of(Loads, [Ptr](const LoadInst *Load) {
+    return Load->getPointerOperand() == Ptr && !Load->getType()->isPointerTy();
+  });
+}
+
 /// Return true if we can prove that all callees pass in a valid pointer for the
 /// specified function argument.
 static bool allCallersPassValidPointerForArgument(
@@ -429,9 +449,6 @@ static bool allCallersPassValidPointerForArgument(
   const DataLayout &DL = Callee->getParent()->getDataLayout();
   APInt Bytes(64, NeededDerefBytes);
 
-  // if (RecursiveCalls.size())
-  //   return true;
-
   // Check if the argument itself is marked dereferenceable and aligned.
   if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
     return true;
@@ -440,13 +457,33 @@ static bool allCallersPassValidPointerForArgument(
   // direct callees.
   return all_of(Callee->users(), [&](User *U) {
     CallBase &CB = cast<CallBase>(*U);
+    // In case of functions with recursive calls, this check
+    // (isDereferenceableAndAlignedPointer) will fail when it tries to look at
+    // the first caller of this function. The caller may or may not have a load,
+    // incase it doesn't load the pointer being passed, this check will fail.
+    // So, it's safe to skip the check incase we know that we are dealing with a
+    // recursive call. For example we have a IR given below.
+    //
+    // def fun(ptr %a) {
+    //   ...
+    //   %loadres = load i32, ptr %a, align 4
+    //   %res = call i32 @fun(ptr %a)
+    //   ...
+    // }
+    //
+    // def bar(ptr %x) {
+    //   ...
+    //   %resbar = call i32 @fun(ptr %x)
+    //   ...
+    // }
+    //
+    // Since we record processed recursive calls, we check if the current
+    // CallBase has been processed before. If yes it means that it is a
+    // recursive call and we can skip the check just for this call. So, just
+    // return true.
     if (RecursiveCalls.contains(&CB))
       return true;
 
-    // if (RecursiveCalls.size() &&
-    //     CB.getCalledFunction()->getName() == Callee->getName())
-    //   return true;
-
     return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()),
                                               NeededAlign, Bytes, DL);
   });
@@ -635,21 +672,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
       if (Offset.getSignificantBits() >= 64)
         return false;
 
-      // If this is a recursive function and one of the argument types is a
-      // pointer that isn't loaded to a non pointer type, it can lead to
-      // recursive promotion. Look for any Load candidates above the function
-      // call that load a non pointer type from this argument pointer. If we
-      // don't find even one such use, return false. For reference, you can
-      // refer to Transforms/ArgumentPromotion/pr42028-recursion.ll and
-      // Transforms/ArgumentPromotion/2008-09-08-CGUpdateSelfEdge.ll
-      // testcases.
-      bool doesPointerResolve = false;
-      for (auto Load : Loads)
-        if (Load->getPointerOperand() == PtrArg &&
-            !Load->getType()->isPointerTy())
-          doesPointerResolve = true;
-
-      if (!doesPointerResolve)
+      if (!checkIfPointerIsDereferenced(Loads, PtrArg))
         return false;
 
       int64_t Off = Offset.getSExtValue();
@@ -683,24 +706,6 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
     return false;
   }
 
-  // In case of functions with recursive calls, this check will fail when it
-  // tries to look at the first caller of this function. The caller may or may
-  // not have a load, incase it doesn't load the pointer being passed, this
-  // check will fail. So, it's safe to skip the check incase we know that we
-  // are dealing with a recursive call.
-  //
-  // def fun(ptr %a) {
-  //   ...
-  //   %loadres = load i32, ptr %a, align 4
-  //   %res = call i32 @fun(ptr %a)
-  //   ...
-  // }
-  //
-  // def bar(ptr %x) {
-  //   ...
-  //   %resbar = call i32 @fun(ptr %x)
-  //   ...
-  // }
   if (NeededDerefBytes || NeededAlign > 1) {
     // Try to prove a required deref / aligned requirement.
     if (!allCallersPassValidPointerForArgument(Arg, RecursiveCalls, NeededAlign,
diff --git a/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll b/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll
index 19bb4492171fc..401cd9f6cbf33 100644
--- a/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll
+++ b/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll
@@ -18,7 +18,7 @@ define internal i32 @foo(ptr %x, i32 %n, i32 %m) {
 ; CHECK:       [[COND_NEXT:.*]]:
 ; CHECK-NEXT:    br label %[[RETURN]]
 ; CHECK:       [[RETURN]]:
-; CHECK-NEXT:    [[RETVAL_0:%.*]] = phi i32 [ [[X_0_VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ undef, %[[COND_NEXT]] ]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = phi i32 [ [[X_0_VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
 ; CHECK-NEXT:    ret i32 [[RETVAL_0]]
 ;
 entry:
@@ -42,7 +42,7 @@ cond_next:                                        ; No predecessors!
   br label %return
 
 return:                                           ; preds = %cond_next, %cond_false, %cond_true
-  %retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ undef, %cond_next ]
+  %retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
   ret i32 %retval.0
 }
 



More information about the llvm-commits mailing list