[llvm] [TailRecElim] Introduce support for shift accumulator optimization (PR #181331)

Federico Bruzzone via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 13 00:28:22 PST 2026


https://github.com/FedericoBruzzone created https://github.com/llvm/llvm-project/pull/181331

This PR enables Tail Recursion Elimination (TRE) for functions where the accumulator operation is a shift (`shl`, `lshr`, `ashr`) by a constant amount -- i.e., a pseudo-associative operation.

As pointed out in #178805, `InstCombine` often strength-reduces multiplications (or `f(x-1) + f(x-1)`) into `shl`.
Currently, TRE strictly requires operations to be associative and commutative:
https://github.com/llvm/llvm-project/blob/05e908609227e1e8d993659e604a63668dfd2825/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp#L377-L379
This prevents TRE from transforming recursive shifts into loops, creating a phase-ordering problem where canonicalization blocks a structural optimization.

This PR does **not** perform shift accumulator optimization when there are multiple base cases: it is reserved for future work.

Fixes #178805.

>From dc82a90c7365ff4dd8009309c07a24efd880536d Mon Sep 17 00:00:00 2001
From: FedericoBruzzone <federico.bruzzone.i at gmail.com>
Date: Fri, 13 Feb 2026 08:50:16 +0100
Subject: [PATCH] [TailRecElim] Add support for shift accumulator optimization

Signed-off-by: FedericoBruzzone <federico.bruzzone.i at gmail.com>
---
 .../Scalar/TailRecursionElimination.cpp       | 221 ++++++++++++++----
 .../TailCallElim/shl-accumulator-opt.ll       | 118 ++++++++++
 2 files changed, 299 insertions(+), 40 deletions(-)
 create mode 100644 llvm/test/Transforms/TailCallElim/shl-accumulator-opt.ll

diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 89d41f3e40de7..8af925a105aad 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -82,12 +82,13 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include <cmath>
+#include <variant>
 using namespace llvm;
 
 #define DEBUG_TYPE "tailcallelim"
 
 STATISTIC(NumEliminated, "Number of tail calls removed");
-STATISTIC(NumRetDuped,   "Number of return duplicated");
+STATISTIC(NumRetDuped, "Number of return duplicated");
 STATISTIC(NumAccumAdded, "Number of accumulators introduced");
 
 static cl::opt<bool> ForceDisableBFI(
@@ -159,7 +160,7 @@ struct AllocaDerivedValueTracker {
       case Instruction::Store: {
         if (U->getOperandNo() == 0)
           EscapePoints.insert(I);
-        continue;  // Stores have no users to analyze.
+        continue; // Stores have no users to analyze.
       }
       case Instruction::BitCast:
       case Instruction::GetElementPtr:
@@ -215,11 +216,7 @@ static bool markTails(Function &F, OptimizationRemarkEmitter *ORE) {
   // Track whether a block is reachable after an alloca has escaped. Blocks that
   // contain the escaping instruction will be marked as being visited without an
   // escaped alloca, since that is how the block began.
-  enum VisitType {
-    UNVISITED,
-    UNESCAPED,
-    ESCAPED
-  };
+  enum VisitType { UNVISITED, UNESCAPED, ESCAPED };
   DenseMap<BasicBlock *, VisitType> Visited;
 
   // We propagate the fact that an alloca has escaped from block to successor.
@@ -340,7 +337,6 @@ static bool markTails(Function &F, OptimizationRemarkEmitter *ORE) {
 /// Return true if it is safe to move the specified
 /// instruction from after the call to before the call, assuming that all
 /// instructions between the call and this instruction are movable.
-///
 static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
   if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
     if (II->getIntrinsicID() == Intrinsic::lifetime_end)
@@ -348,7 +344,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
 
   // FIXME: We can move load/store/call/free instructions above the call if the
   // call does not mod/ref the memory location being processed.
-  if (I->mayHaveSideEffects())  // This also handles volatile loads.
+  if (I->mayHaveSideEffects()) // This also handles volatile loads.
     return false;
 
   if (LoadInst *L = dyn_cast<LoadInst>(I)) {
@@ -374,28 +370,136 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
   return !is_contained(I->operands(), CI);
 }
 
-static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
-  if (!I->isAssociative() || !I->isCommutative())
+// While shifts are neither associative nor commutative, a chain of shifts by a
+// constant amount C is equivalent to a single shift by the sum of the amounts:
+//     (((Base << C) << C) ... << C) == Base << (C * Iterations)
+// This relation applies to left shifts as well as arithmetic/logical right
+// shifts when the shift amount is a constant.
+static bool isPseudoAssociative(Instruction *I) {
+  switch (I->getOpcode()) {
+  case Instruction::Shl:
+  case Instruction::AShr:
+  case Instruction::LShr:
+    break;
+  default:
+    return false;
+  }
+
+  ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
+  if (!CI)
+    return false;
+
+  return true;
+}
+
+// Find the base-case return value for function F: examine all
+// return instructions and pick return values that do not depend on a
+// recursive call to F. If there is exactly one distinct such value,
+// return it. If there are none or more than one distinct value, return
+// nullptr to indicate failure.
+//
+// FIXME: There is room for improvement here in the future, e.g., consider
+// non-constant values and multiple base cases -- e.g., we want to be able to
+// handle code like:
+// ```
+// int f(int x) {
+//  if (x == 1) return 1;
+//  if (x == 10) return 10;
+//  return f(x-1) << 1;
+// }
+// ```
+static Constant *getReturnValue(Function &F) {
+  Constant *BaseCaseVal = nullptr;
+
+  auto ValueUsesRecursiveCall = [&](Value *V) {
+    SmallVector<Value *, 8> Worklist;
+    SmallPtrSet<Value *, 8> Visited;
+    Worklist.push_back(V);
+    while (!Worklist.empty()) {
+      Value *Cur = Worklist.pop_back_val();
+      if (!Visited.insert(Cur).second)
+        continue;
+      if (Instruction *I = dyn_cast<Instruction>(Cur)) {
+        if (CallInst *CI = dyn_cast<CallInst>(I))
+          if (CI->getCalledFunction() == &F)
+            return true;
+        for (Use &U : I->operands())
+          Worklist.push_back(U.get());
+      } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Cur)) {
+        for (unsigned I = 0, E = CE->getNumOperands(); I != E; ++I)
+          Worklist.push_back(CE->getOperand(I));
+      }
+    }
+    return false;
+  };
+
+  for (BasicBlock &BB : F) {
+    if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
+      Value *RV = RI->getReturnValue();
+      if (!RV)
+        continue;
+      // It isn't so trivial to perform TRE if the return value depends on the
+      // result of a recursive call, because the return value will be different
+      // for different iterations of the recursion. So we ignore return values
+      // that depend on recursive calls.
+      if (ValueUsesRecursiveCall(RV))
+        continue;
+      if (!BaseCaseVal && isa<Constant>(RV))
+        BaseCaseVal = cast<Constant>(RV);
+      else if (BaseCaseVal != RV)
+        return nullptr;
+    }
+  }
+
+  return BaseCaseVal;
+}
+
+// This function checks whether the instruction I can be used
+// to perform accumulator recursion elimination for the
+// call instruction CI.
+// In the presence of pseudo-associative operations, it returns
+// the base case constant value to both indicate success and
+// provide the value needed to initialize the accumulator.
+static std::variant<bool, Constant *>
+canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
+  auto IsPseudoAssociative = isPseudoAssociative(I);
+  if ((!I->isAssociative() || !I->isCommutative()) && !IsPseudoAssociative)
     return false;
 
   assert(I->getNumOperands() >= 2 &&
          "Associative/commutative operations should have at least 2 args!");
 
-  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
-    // Accumulators must have an identity.
-    if (!ConstantExpr::getIntrinsicIdentity(II->getIntrinsicID(), I->getType()))
+  Constant *BaseCaseConst = nullptr;
+  if (IsPseudoAssociative) {
+    // For pseudo-associative operations, we require that the recursive call
+    // is always on the first operand.
+    if (I->getOperand(0) != CI)
       return false;
-  }
 
-  // Exactly one operand should be the result of the call instruction.
-  if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
-      (I->getOperand(0) != CI && I->getOperand(1) != CI))
-    return false;
+    BaseCaseConst = getReturnValue(*CI->getCalledFunction());
+    if (!BaseCaseConst)
+      return false;
+  } else {
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+      // Accumulators must have an identity.
+      if (!ConstantExpr::getIntrinsicIdentity(II->getIntrinsicID(),
+                                              I->getType()))
+        return false;
+    }
+
+    // Exactly one operand should be the result of the call instruction.
+    if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
+        (I->getOperand(0) != CI && I->getOperand(1) != CI))
+      return false;
+  }
 
   // The only user of this instruction we allow is a single return instruction.
   if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back()))
     return false;
 
+  if (BaseCaseConst)
+    return BaseCaseConst;
+
   return true;
 }
 
@@ -436,6 +540,10 @@ class TailRecursionEliminator {
   // The instruction doing the accumulating.
   Instruction *AccumulatorRecursionInstr = nullptr;
 
+  // The base case return value. It exists only if the accumulator
+  // operation is pseudo-associative.
+  Constant *BaseCaseValue = nullptr;
+
   TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
                           AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
                           DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
@@ -491,7 +599,7 @@ CallInst *TailRecursionEliminator::findTRECandidate(BasicBlock *BB) {
       break;
 
     if (BBI == BB->begin())
-      return nullptr;          // Didn't find a potential tail call.
+      return nullptr; // Didn't find a potential tail call.
     --BBI;
   }
 
@@ -512,7 +620,8 @@ CallInst *TailRecursionEliminator::findTRECandidate(BasicBlock *BB) {
     auto I = CI->arg_begin(), E = CI->arg_end();
     Function::arg_iterator FI = F.arg_begin(), FE = F.arg_end();
     for (; I != E && FI != FE; ++I, ++FI)
-      if (*I != &*FI) break;
+      if (*I != &*FI)
+        break;
     if (I == E && FI == FE)
       return nullptr;
   }
@@ -594,9 +703,17 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
   for (pred_iterator PI = PB; PI != PE; ++PI) {
     BasicBlock *P = *PI;
     if (P == &F.getEntryBlock()) {
-      Constant *Identity =
+      Constant *InitialValue =
           ConstantExpr::getIdentity(AccRecInstr, AccRecInstr->getType());
-      AccPN->addIncoming(Identity, P);
+      if (!InitialValue) {
+        // We are in the presence of a pseudo-associative operation like shift.
+        // We didn't pass `AllowRHSConstant = true` in getIdentity above because
+        // the identity for shifts is zero, which is not valid for the LHS.
+        // We need the value of the base case(s) of the function to
+        // initialize the accumulator.
+        InitialValue = BaseCaseValue;
+      }
+      AccPN->addIncoming(InitialValue, P);
     } else {
       AccPN->addIncoming(AccPN, P);
     }
@@ -667,15 +784,24 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
       continue;
 
     // If we can't move the instruction above the call, it might be because it
-    // is an associative and commutative operation that could be transformed
-    // using accumulator recursion elimination.  Check to see if this is the
-    // case, and if so, remember which instruction accumulates for later.
-    if (AccPN || !canTransformAccumulatorRecursion(&*BBI, CI))
+    // is an (associative and commutative) or pseudo-associative arithmetic
+    // operation that could be transformed using accumulator recursion
+    // elimination. Check to see if this is the case, and if so, remember which
+    // instruction accumulates for later.
+    std::variant<bool, Constant *> AccRecResult =
+        canTransformAccumulatorRecursion(&*BBI, CI);
+
+    if (AccPN || (std::holds_alternative<bool>(AccRecResult) &&
+                  !std::get<bool>(AccRecResult)))
       return false; // We cannot eliminate the tail recursion!
 
     // Yes, this is accumulator recursion.  Remember which instruction
     // accumulates.
     AccRecInstr = &*BBI;
+
+    // Keep track of the base case return value if any.
+    if (std::holds_alternative<Constant *>(AccRecResult))
+      BaseCaseValue = std::get<Constant *>(AccRecResult);
   }
 
   BasicBlock *BB = Ret->getParent();
@@ -752,8 +878,8 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
   BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret->getIterator());
   NewBI->setDebugLoc(CI->getDebugLoc());
 
-  Ret->eraseFromParent();  // Remove return.
-  CI->eraseFromParent();   // Remove call.
+  Ret->eraseFromParent(); // Remove return.
+  CI->eraseFromParent();  // Remove call.
   DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
   ++NumEliminated;
   if (OrigEntryBBFreq) {
@@ -814,13 +940,26 @@ void TailRecursionEliminator::cleanupAndFinalize() {
           if (!RI)
             continue;
 
-          Instruction *AccRecInstrNew = AccRecInstr->clone();
-          AccRecInstrNew->setName("accumulator.ret.tr");
-          AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN,
-                                     RI->getOperand(0));
-          AccRecInstrNew->insertBefore(RI->getIterator());
-          AccRecInstrNew->dropLocation();
-          RI->setOperand(0, AccRecInstrNew);
+          // Again, the BaseCaseValue exists only if the accumulator
+          // operation is pseudo-associative. In that case the accumulator PHI
+          // already holds the final result (it is initialized from the base
+          // case), so we should return it directly rather than applying the
+          // accumulation instruction one more time.
+          if (BaseCaseValue) {
+            RI->setOperand(0, AccPN);
+          } else {
+            // Otherwise, since the accumulator starts with the identity value,
+            // before the return we need to apply the accumulation instruction
+            // one more time to combine the last value with the result of the
+            // recursive call.
+            Instruction *AccRecInstrNew = AccRecInstr->clone();
+            AccRecInstrNew->setName("accumulator.ret.tr");
+            AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN,
+                                       RI->getOperand(0));
+            AccRecInstrNew->insertBefore(RI->getIterator());
+            AccRecInstrNew->dropLocation();
+            RI->setOperand(0, AccRecInstrNew);
+          }
         }
       }
     } else {
@@ -890,7 +1029,9 @@ bool TailRecursionEliminator::processBlock(BasicBlock &BB) {
 
     eliminateCall(CI);
     return true;
-  } else if (isa<ReturnInst>(TI)) {
+  }
+
+  if (isa<ReturnInst>(TI)) {
     CallInst *CI = findTRECandidate(&BB);
 
     if (CI)
@@ -955,9 +1096,9 @@ struct TailCallElim : public FunctionPass {
     auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
     auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
     auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr;
-    // There is no noticable performance difference here between Lazy and Eager
-    // UpdateStrategy based on some test results. It is feasible to switch the
-    // UpdateStrategy to Lazy if we find it profitable later.
+    // There is no noticeable performance difference here between Lazy and
+    // Eager UpdateStrategy based on some test results. It is feasible to
+    // switch the UpdateStrategy to Lazy if we find it profitable later.
     DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
 
     return TailRecursionEliminator::eliminate(
diff --git a/llvm/test/Transforms/TailCallElim/shl-accumulator-opt.ll b/llvm/test/Transforms/TailCallElim/shl-accumulator-opt.ll
new file mode 100644
index 0000000000000..257a885672cf6
--- /dev/null
+++ b/llvm/test/Transforms/TailCallElim/shl-accumulator-opt.ll
@@ -0,0 +1,118 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes="tailcallelim" -verify-dom-info -S | FileCheck %s
+
+; NOTE: All the following test cases are generated from the underlying C code (-O1)
+;   before the shift accumulator optimization was implemented
+
+
+
+; InstCombine strength-reduces `f(x-1) + f(x-1)` into shl:
+; int f(int x) {
+;     if (x == 1) return 7;
+;     return f(x-1) + f(x-1); // f(x-1) * 2
+; }
+define dso_local i32 @f(i32 noundef %x) local_unnamed_addr {
+; CHECK-LABEL: define dso_local i32 @f(
+; CHECK-SAME: i32 noundef [[X:%.*]]) local_unnamed_addr {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    br label %[[TAILRECURSE:.*]]
+; CHECK:       [[TAILRECURSE]]:
+; CHECK-NEXT:    [[ACCUMULATOR_TR:%.*]] = phi i32 [ 7, %[[ENTRY]] ], [ [[ADD:%.*]], %[[IF_END:.*]] ]
+; CHECK-NEXT:    [[X_TR:%.*]] = phi i32 [ [[X]], %[[ENTRY]] ], [ [[SUB:%.*]], %[[IF_END]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[X_TR]], 1
+; CHECK-NEXT:    br i1 [[CMP]], label %[[COMMON_RET:.*]], label %[[IF_END]]
+; CHECK:       [[COMMON_RET]]:
+; CHECK-NEXT:    ret i32 [[ACCUMULATOR_TR]]
+; CHECK:       [[IF_END]]:
+; CHECK-NEXT:    [[SUB]] = add nsw i32 [[X_TR]], -1
+; CHECK-NEXT:    [[ADD]] = shl nsw i32 [[ACCUMULATOR_TR]], 1
+; CHECK-NEXT:    br label %[[TAILRECURSE]]
+;
+entry:
+  %cmp = icmp eq i32 %x, 1
+  br i1 %cmp, label %common.ret, label %if.end
+
+common.ret:
+  %common.ret.op = phi i32 [ %add, %if.end ], [ 7, %entry ]
+  ret i32 %common.ret.op
+
+if.end:
+  %sub = add nsw i32 %x, -1
+  %call = tail call i32 @f(i32 noundef %sub)
+  %add = shl nsw i32 %call, 1
+  br label %common.ret
+}
+
+
+; int f2(int x) {
+;     if (x == 1) return 14;
+;     return f2(x-1) >> 1;
+; }
+define dso_local i32 @f2(i32 noundef %x) local_unnamed_addr {
+; CHECK-LABEL: define dso_local i32 @f2(
+; CHECK-SAME: i32 noundef [[X:%.*]]) local_unnamed_addr {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    br label %[[TAILRECURSE:.*]]
+; CHECK:       [[TAILRECURSE]]:
+; CHECK-NEXT:    [[ACCUMULATOR_TR:%.*]] = phi i32 [ 14, %[[ENTRY]] ], [ [[SHR:%.*]], %[[IF_END:.*]] ]
+; CHECK-NEXT:    [[X_TR:%.*]] = phi i32 [ [[X]], %[[ENTRY]] ], [ [[SUB:%.*]], %[[IF_END]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[X_TR]], 1
+; CHECK-NEXT:    br i1 [[CMP]], label %[[COMMON_RET:.*]], label %[[IF_END]]
+; CHECK:       [[COMMON_RET]]:
+; CHECK-NEXT:    ret i32 [[ACCUMULATOR_TR]]
+; CHECK:       [[IF_END]]:
+; CHECK-NEXT:    [[SUB]] = add nsw i32 [[X_TR]], -1
+; CHECK-NEXT:    [[SHR]] = ashr i32 [[ACCUMULATOR_TR]], 1
+; CHECK-NEXT:    br label %[[TAILRECURSE]]
+;
+entry:
+  %cmp = icmp eq i32 %x, 1
+  br i1 %cmp, label %common.ret, label %if.end
+
+common.ret:
+  %common.ret.op = phi i32 [ %shr, %if.end ], [ 14, %entry ]
+  ret i32 %common.ret.op
+
+if.end:
+  %sub = add nsw i32 %x, -1
+  %call = tail call i32 @f2(i32 noundef %sub)
+  %shr = ashr i32 %call, 1
+  br label %common.ret
+}
+
+
+; unsigned int f3(unsigned int x) {
+;     if (x <= 1) return 21;
+;     return f3(x - 1) >> 1;
+; }
+define dso_local i32 @f3(i32 noundef %x) local_unnamed_addr {
+; CHECK-LABEL: define dso_local i32 @f3(
+; CHECK-SAME: i32 noundef [[X:%.*]]) local_unnamed_addr {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    br label %[[TAILRECURSE:.*]]
+; CHECK:       [[TAILRECURSE]]:
+; CHECK-NEXT:    [[ACCUMULATOR_TR:%.*]] = phi i32 [ 21, %[[ENTRY]] ], [ [[SHR:%.*]], %[[IF_END:.*]] ]
+; CHECK-NEXT:    [[X_TR:%.*]] = phi i32 [ [[X]], %[[ENTRY]] ], [ [[SUB:%.*]], %[[IF_END]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[X_TR]], 2
+; CHECK-NEXT:    br i1 [[CMP]], label %[[COMMON_RET:.*]], label %[[IF_END]]
+; CHECK:       [[COMMON_RET]]:
+; CHECK-NEXT:    ret i32 [[ACCUMULATOR_TR]]
+; CHECK:       [[IF_END]]:
+; CHECK-NEXT:    [[SUB]] = add i32 [[X_TR]], -1
+; CHECK-NEXT:    [[SHR]] = lshr i32 [[ACCUMULATOR_TR]], 1
+; CHECK-NEXT:    br label %[[TAILRECURSE]]
+;
+entry:
+  %cmp = icmp ult i32 %x, 2
+  br i1 %cmp, label %common.ret, label %if.end
+
+common.ret:
+  %common.ret.op = phi i32 [ %shr, %if.end ], [ 21, %entry ]
+  ret i32 %common.ret.op
+
+if.end:
+  %sub = add i32 %x, -1
+  %call = tail call i32 @f3(i32 noundef %sub)
+  %shr = lshr i32 %call, 1
+  br label %common.ret
+}



More information about the llvm-commits mailing list