[llvm] [TailRecElim] Introduce support for shift accumulator optimization (PR #181331)
Federico Bruzzone via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 13 00:28:22 PST 2026
https://github.com/FedericoBruzzone created https://github.com/llvm/llvm-project/pull/181331
This PR enables Tail Recursion Elimination (TRE) for functions where the accumulator operation is a shift (`shl`, `lshr`, `ashr`) by a constant amount -- i.e., a pseudo-associative relation.
As pointed out in #178805, `InstCombine` often strength-reduces multiplications (or `f(x-1) + f(x-1)`) into `shl`.
Currently, TRE strictly requires operations to be associative and commutative:
https://github.com/llvm/llvm-project/blob/05e908609227e1e8d993659e604a63668dfd2825/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp#L377-L379
This prevents TRE from transforming recursive shifts into loops, creating a phase-ordering problem where canonicalization blocks a structural optimization.
This PR does **not** perform shift accumulator optimization when there are multiple base cases: it is reserved for future work.
Fixes #178805.
>From dc82a90c7365ff4dd8009309c07a24efd880536d Mon Sep 17 00:00:00 2001
From: FedericoBruzzone <federico.bruzzone.i at gmail.com>
Date: Fri, 13 Feb 2026 08:50:16 +0100
Subject: [PATCH] [TailRecElim] Add support for shift accumulator optimization
Signed-off-by: FedericoBruzzone <federico.bruzzone.i at gmail.com>
---
.../Scalar/TailRecursionElimination.cpp | 221 ++++++++++++++----
.../TailCallElim/shl-accumulator-opt.ll | 118 ++++++++++
2 files changed, 299 insertions(+), 40 deletions(-)
create mode 100644 llvm/test/Transforms/TailCallElim/shl-accumulator-opt.ll
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 89d41f3e40de7..8af925a105aad 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -82,12 +82,13 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cmath>
+#include <variant>
using namespace llvm;
#define DEBUG_TYPE "tailcallelim"
STATISTIC(NumEliminated, "Number of tail calls removed");
-STATISTIC(NumRetDuped, "Number of return duplicated");
+STATISTIC(NumRetDuped, "Number of return duplicated");
STATISTIC(NumAccumAdded, "Number of accumulators introduced");
static cl::opt<bool> ForceDisableBFI(
@@ -159,7 +160,7 @@ struct AllocaDerivedValueTracker {
case Instruction::Store: {
if (U->getOperandNo() == 0)
EscapePoints.insert(I);
- continue; // Stores have no users to analyze.
+ continue; // Stores have no users to analyze.
}
case Instruction::BitCast:
case Instruction::GetElementPtr:
@@ -215,11 +216,7 @@ static bool markTails(Function &F, OptimizationRemarkEmitter *ORE) {
// Track whether a block is reachable after an alloca has escaped. Blocks that
// contain the escaping instruction will be marked as being visited without an
// escaped alloca, since that is how the block began.
- enum VisitType {
- UNVISITED,
- UNESCAPED,
- ESCAPED
- };
+ enum VisitType { UNVISITED, UNESCAPED, ESCAPED };
DenseMap<BasicBlock *, VisitType> Visited;
// We propagate the fact that an alloca has escaped from block to successor.
@@ -340,7 +337,6 @@ static bool markTails(Function &F, OptimizationRemarkEmitter *ORE) {
/// Return true if it is safe to move the specified
/// instruction from after the call to before the call, assuming that all
/// instructions between the call and this instruction are movable.
-///
static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
if (II->getIntrinsicID() == Intrinsic::lifetime_end)
@@ -348,7 +344,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
// FIXME: We can move load/store/call/free instructions above the call if the
// call does not mod/ref the memory location being processed.
- if (I->mayHaveSideEffects()) // This also handles volatile loads.
+ if (I->mayHaveSideEffects()) // This also handles volatile loads.
return false;
if (LoadInst *L = dyn_cast<LoadInst>(I)) {
@@ -374,28 +370,136 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
return !is_contained(I->operands(), CI);
}
-static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
- if (!I->isAssociative() || !I->isCommutative())
+// While shifts are neither associative nor commutative, a chain of shifts by a
+// constant amount C is equivalent to a single shift by the sum of the amounts:
+// ... (Base << C) << C) ... << C == Base << (C * Iterations)
+// This relation applies to left shifts as well as arithmetic/logical right
+// shifts when the shift amount is a constant.
+static bool isPseudoAssociative(Instruction *I) {
+ switch (I->getOpcode()) {
+ case Instruction::Shl:
+ case Instruction::AShr:
+ case Instruction::LShr:
+ break;
+ default:
+ return false;
+ }
+
+ ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
+ if (!CI)
+ return false;
+
+ return true;
+}
+
+// Find the base-case return value for function F: examine all
+// return instructions and pick return values that do not depend on a
+// recursive call to F. If there is exactly one distinct such value,
+// return it. If there are none or more than one distinct value, return
+// nullptr to indicate failure.
+//
+// FIXME: There is room for improvement here in the future, e.g., consider
+// non-constant values and multiple base cases -- e.g., we want to be able to
+// handle code like:
+// ```
+// int f(int x) {
+// if (x == 1) return 1;
+// if (x == 10) return 10;
+// return f(x-1) << 1;
+// }
+// ```
+static Constant *getReturnValue(Function &F) {
+ Constant *BaseCaseVal = nullptr;
+
+ auto ValueUsesRecursiveCall = [&](Value *V) {
+ SmallVector<Value *, 8> Worklist;
+ SmallPtrSet<Value *, 8> Visited;
+ Worklist.push_back(V);
+ while (!Worklist.empty()) {
+ Value *Cur = Worklist.pop_back_val();
+ if (!Visited.insert(Cur).second)
+ continue;
+ if (Instruction *I = dyn_cast<Instruction>(Cur)) {
+ if (CallInst *CI = dyn_cast<CallInst>(I))
+ if (CI->getCalledFunction() == &F)
+ return true;
+ for (Use &U : I->operands())
+ Worklist.push_back(U.get());
+ } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Cur)) {
+ for (unsigned I = 0, E = CE->getNumOperands(); I != E; ++I)
+ Worklist.push_back(CE->getOperand(I));
+ }
+ }
+ return false;
+ };
+
+ for (BasicBlock &BB : F) {
+ if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
+ Value *RV = RI->getReturnValue();
+ if (!RV)
+ continue;
+ // It isn't so trivial to perform TRE if the return value depends on the
+ // result of a recursive call, because the return value will be different
+ // for different iterations of the recursion. So we ignore return values
+ // that depend on recursive calls.
+ if (ValueUsesRecursiveCall(RV))
+ continue;
+ if (!BaseCaseVal && isa<Constant>(RV))
+ BaseCaseVal = cast<Constant>(RV);
+ else if (BaseCaseVal != RV)
+ return nullptr;
+ }
+ }
+
+ return BaseCaseVal;
+}
+
+// This function checks whether the instruction I can be used
+// to perform accumulator recursion elimination for the
+// call instruction CI.
+// In the presence of pseudo-associative operations, it returns
+// the base case constant value to both indicate success and
+// provide the value needed to initialize the accumulator.
+static std::variant<bool, Constant *>
+canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
+ auto IsPseudoAssociative = isPseudoAssociative(I);
+ if ((!I->isAssociative() || !I->isCommutative()) && !IsPseudoAssociative)
return false;
assert(I->getNumOperands() >= 2 &&
"Associative/commutative operations should have at least 2 args!");
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
- // Accumulators must have an identity.
- if (!ConstantExpr::getIntrinsicIdentity(II->getIntrinsicID(), I->getType()))
+ Constant *BaseCaseConst = nullptr;
+ if (IsPseudoAssociative) {
+ // For pseudo-associative operations, we require that the recursive call
+ // is always on the first operand.
+ if (I->getOperand(0) != CI)
return false;
- }
- // Exactly one operand should be the result of the call instruction.
- if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
- (I->getOperand(0) != CI && I->getOperand(1) != CI))
- return false;
+ BaseCaseConst = getReturnValue(*CI->getCalledFunction());
+ if (!BaseCaseConst)
+ return false;
+ } else {
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ // Accumulators must have an identity.
+ if (!ConstantExpr::getIntrinsicIdentity(II->getIntrinsicID(),
+ I->getType()))
+ return false;
+ }
+
+ // Exactly one operand should be the result of the call instruction.
+ if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
+ (I->getOperand(0) != CI && I->getOperand(1) != CI))
+ return false;
+ }
// The only user of this instruction we allow is a single return instruction.
if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back()))
return false;
+ if (BaseCaseConst)
+ return BaseCaseConst;
+
return true;
}
@@ -436,6 +540,10 @@ class TailRecursionEliminator {
// The instruction doing the accumulating.
Instruction *AccumulatorRecursionInstr = nullptr;
+ // The base case return value. It exists only if the accumulator
+ // operation is pseudo-associative.
+ Constant *BaseCaseValue = nullptr;
+
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
@@ -491,7 +599,7 @@ CallInst *TailRecursionEliminator::findTRECandidate(BasicBlock *BB) {
break;
if (BBI == BB->begin())
- return nullptr; // Didn't find a potential tail call.
+ return nullptr; // Didn't find a potential tail call.
--BBI;
}
@@ -512,7 +620,8 @@ CallInst *TailRecursionEliminator::findTRECandidate(BasicBlock *BB) {
auto I = CI->arg_begin(), E = CI->arg_end();
Function::arg_iterator FI = F.arg_begin(), FE = F.arg_end();
for (; I != E && FI != FE; ++I, ++FI)
- if (*I != &*FI) break;
+ if (*I != &*FI)
+ break;
if (I == E && FI == FE)
return nullptr;
}
@@ -594,9 +703,17 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
for (pred_iterator PI = PB; PI != PE; ++PI) {
BasicBlock *P = *PI;
if (P == &F.getEntryBlock()) {
- Constant *Identity =
+ Constant *InitialValue =
ConstantExpr::getIdentity(AccRecInstr, AccRecInstr->getType());
- AccPN->addIncoming(Identity, P);
+ if (!InitialValue) {
+ // We are in the presence of a pseudo-associative operation like shift.
+ // We didn't pass `AllowRHSConstant = true` in getIdentity above because
+ // the identity for shifts is zero, which is not valid for the LHS.
+ // We need the value of the base case(s) of the function to
+ // initialize the accumulator.
+ InitialValue = BaseCaseValue;
+ }
+ AccPN->addIncoming(InitialValue, P);
} else {
AccPN->addIncoming(AccPN, P);
}
@@ -667,15 +784,24 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
continue;
// If we can't move the instruction above the call, it might be because it
- // is an associative and commutative operation that could be transformed
- // using accumulator recursion elimination. Check to see if this is the
- // case, and if so, remember which instruction accumulates for later.
- if (AccPN || !canTransformAccumulatorRecursion(&*BBI, CI))
+ // is an (associative and commutative) or pseudo-associative arithmetic
+ // operation that could be transformed using accumulator recursion
+ // elimination. Check to see if this is the case, and if so, remember which
+ // instruction accumulates for later.
+ std::variant<bool, Constant *> AccRecResult =
+ canTransformAccumulatorRecursion(&*BBI, CI);
+
+ if (AccPN || (std::holds_alternative<bool>(AccRecResult) &&
+ !std::get<bool>(AccRecResult)))
return false; // We cannot eliminate the tail recursion!
// Yes, this is accumulator recursion. Remember which instruction
// accumulates.
AccRecInstr = &*BBI;
+
+ // Keep track of the base case return value if any.
+ if (std::holds_alternative<Constant *>(AccRecResult))
+ BaseCaseValue = std::get<Constant *>(AccRecResult);
}
BasicBlock *BB = Ret->getParent();
@@ -752,8 +878,8 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret->getIterator());
NewBI->setDebugLoc(CI->getDebugLoc());
- Ret->eraseFromParent(); // Remove return.
- CI->eraseFromParent(); // Remove call.
+ Ret->eraseFromParent(); // Remove return.
+ CI->eraseFromParent(); // Remove call.
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
++NumEliminated;
if (OrigEntryBBFreq) {
@@ -814,13 +940,26 @@ void TailRecursionEliminator::cleanupAndFinalize() {
if (!RI)
continue;
- Instruction *AccRecInstrNew = AccRecInstr->clone();
- AccRecInstrNew->setName("accumulator.ret.tr");
- AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN,
- RI->getOperand(0));
- AccRecInstrNew->insertBefore(RI->getIterator());
- AccRecInstrNew->dropLocation();
- RI->setOperand(0, AccRecInstrNew);
+ // Again, the BaseCaseValue exists only if the accumulator
+ // operation is pseudo-associative. In that case the accumulator PHI
+ // already holds the final result (it is initialized from the base
+ // case), so we should return it directly rather than applying the
+ // accumulation instruction one more time.
+ if (BaseCaseValue) {
+ RI->setOperand(0, AccPN);
+ } else {
+ // Otherwise, since the accumulator starts with the identity value,
+ // before the return we need to apply the accumulation instruction
+ // one more time to combine the last value with the result of the
+ // recursive call.
+ Instruction *AccRecInstrNew = AccRecInstr->clone();
+ AccRecInstrNew->setName("accumulator.ret.tr");
+ AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN,
+ RI->getOperand(0));
+ AccRecInstrNew->insertBefore(RI->getIterator());
+ AccRecInstrNew->dropLocation();
+ RI->setOperand(0, AccRecInstrNew);
+ }
}
}
} else {
@@ -890,7 +1029,9 @@ bool TailRecursionEliminator::processBlock(BasicBlock &BB) {
eliminateCall(CI);
return true;
- } else if (isa<ReturnInst>(TI)) {
+ }
+
+ if (isa<ReturnInst>(TI)) {
CallInst *CI = findTRECandidate(&BB);
if (CI)
@@ -955,9 +1096,9 @@ struct TailCallElim : public FunctionPass {
auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr;
- // There is no noticable performance difference here between Lazy and Eager
- // UpdateStrategy based on some test results. It is feasible to switch the
- // UpdateStrategy to Lazy if we find it profitable later.
+ // There is no noticeable performance difference here between Lazy and
+ // Eager UpdateStrategy based on some test results. It is feasible to
+ // switch the UpdateStrategy to Lazy if we find it profitable later.
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
return TailRecursionEliminator::eliminate(
diff --git a/llvm/test/Transforms/TailCallElim/shl-accumulator-opt.ll b/llvm/test/Transforms/TailCallElim/shl-accumulator-opt.ll
new file mode 100644
index 0000000000000..257a885672cf6
--- /dev/null
+++ b/llvm/test/Transforms/TailCallElim/shl-accumulator-opt.ll
@@ -0,0 +1,118 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes="tailcallelim" -verify-dom-info -S | FileCheck %s
+
+; NOTE: All the following test cases are generated from the underlying C code (-O1)
+; before the shift accumulator optimization was implemented
+
+
+
+; InstCombine strength-reduces `f(x-1) + f(x-1)` to shl:
+; int f(int x) {
+; if (x == 1) return 7;
+; return f(x-1) + f(x-1); // f(x-1) * 2
+; }
+define dso_local i32 @f(i32 noundef %x) local_unnamed_addr {
+; CHECK-LABEL: define dso_local i32 @f(
+; CHECK-SAME: i32 noundef [[X:%.*]]) local_unnamed_addr {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: br label %[[TAILRECURSE:.*]]
+; CHECK: [[TAILRECURSE]]:
+; CHECK-NEXT: [[ACCUMULATOR_TR:%.*]] = phi i32 [ 7, %[[ENTRY]] ], [ [[ADD:%.*]], %[[IF_END:.*]] ]
+; CHECK-NEXT: [[X_TR:%.*]] = phi i32 [ [[X]], %[[ENTRY]] ], [ [[SUB:%.*]], %[[IF_END]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X_TR]], 1
+; CHECK-NEXT: br i1 [[CMP]], label %[[COMMON_RET:.*]], label %[[IF_END]]
+; CHECK: [[COMMON_RET]]:
+; CHECK-NEXT: ret i32 [[ACCUMULATOR_TR]]
+; CHECK: [[IF_END]]:
+; CHECK-NEXT: [[SUB]] = add nsw i32 [[X_TR]], -1
+; CHECK-NEXT: [[ADD]] = shl nsw i32 [[ACCUMULATOR_TR]], 1
+; CHECK-NEXT: br label %[[TAILRECURSE]]
+;
+entry:
+ %cmp = icmp eq i32 %x, 1
+ br i1 %cmp, label %common.ret, label %if.end
+
+common.ret:
+ %common.ret.op = phi i32 [ %add, %if.end ], [ 7, %entry ]
+ ret i32 %common.ret.op
+
+if.end:
+ %sub = add nsw i32 %x, -1
+ %call = tail call i32 @f(i32 noundef %sub)
+ %add = shl nsw i32 %call, 1
+ br label %common.ret
+}
+
+
+; int f2(int x) {
+; if (x == 1) return 14;
+; return f2(x-1) >> 1;
+; }
+define dso_local i32 @f2(i32 noundef %x) local_unnamed_addr {
+; CHECK-LABEL: define dso_local i32 @f2(
+; CHECK-SAME: i32 noundef [[X:%.*]]) local_unnamed_addr {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: br label %[[TAILRECURSE:.*]]
+; CHECK: [[TAILRECURSE]]:
+; CHECK-NEXT: [[ACCUMULATOR_TR:%.*]] = phi i32 [ 14, %[[ENTRY]] ], [ [[SHR:%.*]], %[[IF_END:.*]] ]
+; CHECK-NEXT: [[X_TR:%.*]] = phi i32 [ [[X]], %[[ENTRY]] ], [ [[SUB:%.*]], %[[IF_END]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X_TR]], 1
+; CHECK-NEXT: br i1 [[CMP]], label %[[COMMON_RET:.*]], label %[[IF_END]]
+; CHECK: [[COMMON_RET]]:
+; CHECK-NEXT: ret i32 [[ACCUMULATOR_TR]]
+; CHECK: [[IF_END]]:
+; CHECK-NEXT: [[SUB]] = add nsw i32 [[X_TR]], -1
+; CHECK-NEXT: [[SHR]] = ashr i32 [[ACCUMULATOR_TR]], 1
+; CHECK-NEXT: br label %[[TAILRECURSE]]
+;
+entry:
+ %cmp = icmp eq i32 %x, 1
+ br i1 %cmp, label %common.ret, label %if.end
+
+common.ret:
+ %common.ret.op = phi i32 [ %shr, %if.end ], [ 14, %entry ]
+ ret i32 %common.ret.op
+
+if.end:
+ %sub = add nsw i32 %x, -1
+ %call = tail call i32 @f2(i32 noundef %sub)
+ %shr = ashr i32 %call, 1
+ br label %common.ret
+}
+
+
+; unsigned int f3(unsigned int x) {
+; if (x <= 1) return 14;
+; return f3(x - 1) >> 1;
+; }
+define dso_local i32 @f3(i32 noundef %x) local_unnamed_addr {
+; CHECK-LABEL: define dso_local i32 @f3(
+; CHECK-SAME: i32 noundef [[X:%.*]]) local_unnamed_addr {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: br label %[[TAILRECURSE:.*]]
+; CHECK: [[TAILRECURSE]]:
+; CHECK-NEXT: [[ACCUMULATOR_TR:%.*]] = phi i32 [ 21, %[[ENTRY]] ], [ [[SHR:%.*]], %[[IF_END:.*]] ]
+; CHECK-NEXT: [[X_TR:%.*]] = phi i32 [ [[X]], %[[ENTRY]] ], [ [[SUB:%.*]], %[[IF_END]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X_TR]], 2
+; CHECK-NEXT: br i1 [[CMP]], label %[[COMMON_RET:.*]], label %[[IF_END]]
+; CHECK: [[COMMON_RET]]:
+; CHECK-NEXT: ret i32 [[ACCUMULATOR_TR]]
+; CHECK: [[IF_END]]:
+; CHECK-NEXT: [[SUB]] = add i32 [[X_TR]], -1
+; CHECK-NEXT: [[SHR]] = lshr i32 [[ACCUMULATOR_TR]], 1
+; CHECK-NEXT: br label %[[TAILRECURSE]]
+;
+entry:
+ %cmp = icmp ult i32 %x, 2
+ br i1 %cmp, label %common.ret, label %if.end
+
+common.ret:
+ %common.ret.op = phi i32 [ %shr, %if.end ], [ 21, %entry ]
+ ret i32 %common.ret.op
+
+if.end:
+ %sub = add i32 %x, -1
+ %call = tail call i32 @f3(i32 noundef %sub)
+ %shr = lshr i32 %call, 1
+ br label %common.ret
+}
More information about the llvm-commits
mailing list