[clang] Attribute support `[[clang::musttail]]` in `ExprConstant.cpp` (work in progress) (PR #138477)
Hana Dusíková via cfe-commits
cfe-commits at lists.llvm.org
Sun May 4 14:32:01 PDT 2025
https://github.com/hanickadot created https://github.com/llvm/llvm-project/pull/138477
This change makes `[[clang::musttail]]` work in the constant evaluator. A call in a return statement marked with this attribute no longer recurses on the system stack; instead, it is deferred and then evaluated in a loop by the nearest enclosing function call. The attribute is already very strict, and its existing checks reject all problematic cases (e.g. non-trivial destructors, references to local variables).
This PR is a work in progress.
>From f084366a545f5e2c0ec54fa7cc4dd688950c13af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hana=20Dusi=CC=81kova=CC=81?= <hanicka at hanicka.net>
Date: Sun, 4 May 2025 23:27:09 +0200
Subject: [PATCH] [clang] Attribute support [[clang::musttail]] in
ExprConstant.cpp allows guaranteed tail recursion.
---
clang/lib/AST/ExprConstant.cpp | 260 +++++++++++++++++++++++++++------
1 file changed, 219 insertions(+), 41 deletions(-)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index b79d8c197fe7d..9ef6b983d196a 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -735,6 +735,18 @@ namespace {
ScopeKind Scope)
: Value(Val, Scope), Base(Base), T(T) {}
+ /// Moving a Cleanup transfers the tracked object; the moved-from Cleanup is
+ /// left with a null value pointer (both for construction and assignment).
+ Cleanup(Cleanup &&Other) noexcept
+ : Value{std::exchange(Other.Value, {})}, Base{Other.Base}, T{Other.T} {}
+
+ Cleanup &operator=(Cleanup &&Other) noexcept {
+ Value = std::exchange(Other.Value, {});
+ Base = Other.Base;
+ T = Other.T;
+ return *this;
+ }
+
/// Determine whether this cleanup should be performed at the end of the
/// given kind of scope.
bool isDestroyedAtEndOf(ScopeKind K) const {
@@ -1006,6 +1013,24 @@ namespace {
EM_IgnoreSideEffects,
} EvalMode;
+ /// The [[clang::musttail]] return statement being evaluated, if any. The
+ /// next call whose evaluation starts claims it and is deferred as a tail call.
+ const ReturnStmt *TailRecursionReturnStmt = nullptr;
+
+ struct DeferRecursionFunctionCall { // everything needed to perform a tail call later
+ const CallExpr *E{nullptr}; // null when no tail call is pending
+ const FunctionDecl *Definition{nullptr};
+ bool HasThis{false};
+ APValue ThisVal{}; // object argument, stored as an APValue since LValue can't be used here
+ llvm::ArrayRef<const clang::Expr *> Args{};
+ CallRef Call{};
+ Stmt *Body{nullptr};
+ SmallVector<QualType, 4> CovariantAdjustmentPath{};
+ SmallVector<Cleanup, 16> ArgumentsStored{}; // cleanups of the already-evaluated arguments
+ };
+
+ DeferRecursionFunctionCall DeferFunctionCall{}; // at most one tail call is pending at a time
+
/// Are we checking whether the expression is a potential constant
/// expression?
bool checkingPotentialConstantExpression() const override {
@@ -1124,6 +1149,27 @@ namespace {
return Result;
}
+ /// Marks \p RS, a [[clang::musttail]] return statement, as pending: the
+ /// next call whose evaluation starts is deferred as a tail call.
+ void EnableTailRecursion(const ReturnStmt *RS) {
+ TailRecursionReturnStmt = RS;
+ }
+
+ /// Clears the pending [[clang::musttail]] return statement, if any.
+ void DisableTailRecursion() { TailRecursionReturnStmt = nullptr; }
+
+ /// Returns true if a tail call has been deferred and not yet picked up.
+ bool TailRecursionReady() const { return DeferFunctionCall.E != nullptr; }
+
+ /// If \p RS is the pending [[clang::musttail]] return statement, clears it
+ /// and returns true; otherwise returns false.
+ bool IsTailRecursion(const ReturnStmt *RS) {
+ if (TailRecursionReturnStmt != RS)
+ return false;
+ TailRecursionReturnStmt = nullptr;
+ return true;
+ }
+
/// Get the allocated storage for the given parameter of the given call.
APValue *getParamSlot(CallRef Call, const ParmVarDecl *PVD) {
CallStackFrame *Frame = getCallFrameAndDepth(Call.CallIndex).first;
@@ -1439,6 +1479,12 @@ namespace {
// instances of this class.
Info.CurrentCall->popTempVersion();
}
+ // Tail-call helpers that move cleanups out of and back into a call scope.
+ friend void transferFromCallScope(ScopeRAII &,
+ llvm::SmallVectorImpl<Cleanup> &);
+ friend bool transferIntoCallScope(ScopeRAII &,
+ llvm::SmallVectorImpl<Cleanup> &);
+
private:
static bool cleanup(EvalInfo &Info, bool RunDestructors,
unsigned OldStackSize) {
@@ -1457,6 +1503,10 @@ namespace {
}
}
+ compact(Info, OldStackSize);
+ return Success;
+ }
+ static void compact(EvalInfo &Info, unsigned OldStackSize) { // erase cleanups above OldStackSize that end with this kind of scope
// Compact any retained cleanups.
auto NewEnd = Info.CleanupStack.begin() + OldStackSize;
if (Kind != ScopeKind::Block)
@@ -1465,12 +1515,45 @@ namespace {
return C.isDestroyedAtEndOf(Kind);
});
Info.CleanupStack.erase(NewEnd, Info.CleanupStack.end());
- return Success;
}
};
typedef ScopeRAII<ScopeKind::Block> BlockScopeRAII;
typedef ScopeRAII<ScopeKind::FullExpression> FullExpressionRAII;
typedef ScopeRAII<ScopeKind::Call> CallScopeRAII;
+
+ /// Moves the cleanups registered in \p Scope since it was opened (those
+ /// above its OldStackSize) into \p Backup, replacing Backup's contents, and
+ /// removes them from the cleanup stack without running them. Used when a
+ /// tail call is deferred: its arguments were prepared in \p Scope but must
+ /// outlive it.
+ static void transferFromCallScope(CallScopeRAII &Scope,
+ llvm::SmallVectorImpl<Cleanup> &Backup) {
+ auto &Stack = Scope.Info.CleanupStack;
+ auto CurrentVariables =
+ MutableArrayRef<Cleanup>(Stack).slice(Scope.OldStackSize);
+
+ Backup.clear();
+ Backup.append(std::make_move_iterator(CurrentVariables.begin()),
+ std::make_move_iterator(CurrentVariables.end()));
+
+ // Drop the moved-from entries; truncating alone restores the stack to
+ // OldStackSize, so no separate compaction is needed.
+ Stack.truncate(Scope.OldStackSize);
+ }
+
+ /// Runs (with destructors) the cleanups currently registered in \p Scope,
+ /// then registers the cleanups saved in \p Backup with \p Scope and empties
+ /// \p Backup. Returns false if running the existing cleanups failed.
+ static bool transferIntoCallScope(CallScopeRAII &Scope,
+ llvm::SmallVectorImpl<Cleanup> &Backup) {
+ if (!Scope.cleanup(Scope.Info, /*RunDestructors=*/true, Scope.OldStackSize))
+ return false;
+
+ Scope.Info.CleanupStack.append(std::make_move_iterator(Backup.begin()),
+ std::make_move_iterator(Backup.end()));
+ Backup.clear();
+ return true;
+ }
}
bool SubobjectDesignator::checkSubobject(EvalInfo &Info, const Expr *E,
@@ -5614,10 +5699,14 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, EvalInfo &Info,
// We know we returned, but we don't know what the value is.
return ESR_Failed;
}
- if (RetExpr &&
- !(Result.Slot
- ? EvaluateInPlace(Result.Value, Info, *Result.Slot, RetExpr)
- : Evaluate(Result.Value, Info, RetExpr)))
+ // Drop the [[clang::musttail]] marker unless the returned expression is a CallExpr.
+ if (!RetExpr || !isa<CallExpr>(RetExpr)) {
+ Info.DisableTailRecursion();
+ }
+
+ if (RetExpr && !(Result.Slot ? EvaluateInPlace(Result.Value, Info,
+ *Result.Slot, RetExpr)
+ : Evaluate(Result.Value, Info, RetExpr)))
return ESR_Failed;
return Scope.destroy() ? ESR_Returned : ESR_Failed;
}
@@ -5869,32 +5958,37 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, EvalInfo &Info,
case Stmt::AttributedStmtClass: {
const auto *AS = cast<AttributedStmt>(S);
const auto *SS = AS->getSubStmt();
+ const auto *RS = dyn_cast<ReturnStmt>(SS);
MSConstexprContextRAII ConstexprContext(
- *Info.CurrentCall, hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) &&
- isa<ReturnStmt>(SS));
+ *Info.CurrentCall,
+ hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) && RS != nullptr);
auto LO = Info.getASTContext().getLangOpts();
- if (LO.CXXAssumptions && !LO.MSVCCompat) {
- for (auto *Attr : AS->getAttrs()) {
- auto *AA = dyn_cast<CXXAssumeAttr>(Attr);
- if (!AA)
- continue;
-
- auto *Assumption = AA->getAssumption();
- if (Assumption->isValueDependent())
- return ESR_Failed;
+ for (auto *Attr : AS->getAttrs()) {
+ if (auto *AA = dyn_cast<CXXAssumeAttr>(Attr)) {
+ // This branch handles C++'s [[assume(<EXPR>)]]
+ if (LO.CXXAssumptions && !LO.MSVCCompat) {
+ auto *Assumption = AA->getAssumption();
+ if (Assumption->isValueDependent())
+ return ESR_Failed;
- if (Assumption->HasSideEffects(Info.getASTContext()))
- continue;
+ if (Assumption->HasSideEffects(Info.getASTContext()))
+ continue;
- bool Value;
- if (!EvaluateAsBooleanCondition(Assumption, Value, Info))
- return ESR_Failed;
- if (!Value) {
- Info.CCEDiag(Assumption->getExprLoc(),
- diag::note_constexpr_assumption_failed);
- return ESR_Failed;
+ bool Value;
+ if (!EvaluateAsBooleanCondition(Assumption, Value, Info))
+ return ESR_Failed;
+ if (!Value) {
+ Info.CCEDiag(Assumption->getExprLoc(),
+ diag::note_constexpr_assumption_failed);
+ return ESR_Failed;
+ }
}
+ } else if (isa<MustTailAttr>(Attr) && RS != nullptr) {
+ // This branch handles [[clang::musttail]]: the call in this return
+ // statement is evaluated as a tail call. The attribute's strict rules are
+ // already checked before constant evaluation, so nothing is checked here.
+ Info.EnableTailRecursion(RS);
}
}
@@ -6514,16 +6608,16 @@ static bool MaybeHandleUnionActiveMemberChange(EvalInfo &Info,
static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg,
CallRef Call, EvalInfo &Info,
- bool NonNull = false) {
+ CallStackFrame &CallerFrame, bool NonNull = false) { // CallerFrame: frame in which the argument's slot is created
LValue LV;
// Create the parameter slot and register its destruction. For a vararg
// argument, create a temporary.
// FIXME: For calling conventions that destroy parameters in the callee,
// should we consider performing destruction when the function returns
// instead?
- APValue &V = PVD ? Info.CurrentCall->createParam(Call, PVD, LV)
- : Info.CurrentCall->createTemporary(Arg, Arg->getType(),
- ScopeKind::Call, LV);
+ APValue &V = PVD ? CallerFrame.createParam(Call, PVD, LV)
+ : CallerFrame.createTemporary(Arg, Arg->getType(),
+ ScopeKind::Call, LV);
if (!EvaluateInPlace(V, Info, LV, Arg))
return false;
@@ -6539,8 +6633,8 @@ static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg,
/// Evaluate the arguments to a function call.
static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call,
- EvalInfo &Info, const FunctionDecl *Callee,
- bool RightToLeft = false) {
+ EvalInfo &Info, CallStackFrame &CallerFrame, // frame that creates the argument slots
+ const FunctionDecl *Callee, bool RightToLeft = false) {
bool Success = true;
llvm::SmallBitVector ForbiddenNullArgs;
if (Callee->hasAttr<NonNullAttr>()) {
@@ -6563,7 +6657,7 @@ static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call,
const ParmVarDecl *PVD =
Idx < Callee->getNumParams() ? Callee->getParamDecl(Idx) : nullptr;
bool NonNull = !ForbiddenNullArgs.empty() && ForbiddenNullArgs[Idx];
- if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, NonNull)) {
+ if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, CallerFrame, NonNull)) { // slot is created in CallerFrame
// If we're checking for a potential constant expression, evaluate all
// initializers even if some of them fail.
if (!Info.noteFailure())
@@ -6650,6 +6744,60 @@ static bool HandleFunctionCall(SourceLocation CallLoc,
return ESR == ESR_Returned;
}
+/// Records in Info.DeferFunctionCall everything needed to perform the tail
+/// call \p E later, and moves the cleanups of its already-evaluated arguments
+/// out of \p Scope so they are not destroyed when \p Scope ends. The object
+/// argument, if any, is read through \p This.
+static void HandleTailCallTransfer(
+ EvalInfo &Info, const CallExpr *E, const FunctionDecl *Definition,
+ const LValue *This, LValue &ThisVal,
+ llvm::ArrayRef<const clang::Expr *> Args, CallRef Call, Stmt *Body,
+ SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) {
+ auto &Defer = Info.DeferFunctionCall;
+
+ Defer.E = E;
+ Defer.Definition = Definition;
+ Defer.HasThis = This != nullptr;
+ // Without an object argument ThisVal may never have been set, so it is not
+ // copied.
+ if (This)
+ This->moveInto(Defer.ThisVal);
+ else
+ Defer.ThisVal = APValue();
+ Defer.Args = Args;
+ Defer.Call = Call;
+ Defer.Body = Body;
+ Defer.CovariantAdjustmentPath = std::move(CovariantAdjustmentPath);
+
+ transferFromCallScope(Scope, Defer.ArgumentsStored);
+}
+
+/// Loads the tail call recorded by HandleTailCallTransfer into the caller's
+/// variables and marks it as taken, then runs the cleanups currently
+/// registered in \p Scope and registers the stored argument cleanups with
+/// \p Scope. Returns false if running those cleanups failed.
+static bool HandleTailCallSetup(
+ EvalInfo &Info, const CallExpr *&E, const FunctionDecl *&Definition,
+ LValue *&This, LValue &ThisVal, llvm::ArrayRef<const clang::Expr *> &Args,
+ CallRef &Call, Stmt *&Body,
+ SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) {
+ auto &Defer = Info.DeferFunctionCall;
+ assert(Defer.E != nullptr && "no tail call is pending");
+
+ E = std::exchange(Defer.E, nullptr);
+ Definition = Defer.Definition;
+ This = nullptr;
+ if (Defer.HasThis) {
+ ThisVal.setFrom(Info.Ctx, Defer.ThisVal);
+ This = &ThisVal;
+ }
+ Args = Defer.Args;
+ Call = Defer.Call;
+ Body = Defer.Body;
+ CovariantAdjustmentPath = std::move(Defer.CovariantAdjustmentPath);
+ return transferIntoCallScope(Scope, Defer.ArgumentsStored);
+}
+
/// Evaluate a constructor call.
static bool HandleConstructorCall(const Expr *E, const LValue &This,
CallRef Call,
@@ -6871,7 +7003,7 @@ static bool HandleConstructorCall(const Expr *E, const LValue &This,
EvalInfo &Info, APValue &Result) {
CallScopeRAII CallScope(Info);
CallRef Call = Info.CurrentCall->createCall(Definition);
- if (!EvaluateArgs(Args, Call, Info, Definition))
+ if (!EvaluateArgs(Args, Call, Info, *Info.CurrentCall, Definition)) // constructor calls are not deferred as tail calls
return false;
return HandleConstructorCall(E, This, Call, Definition, Info, Result) &&
@@ -8242,6 +8374,12 @@ class ExprEvaluatorBase
APValue Result;
if (!handleCallExpr(E, Result, nullptr))
return false;
+
+ // If this call was deferred as a [[clang::musttail]] tail call there is no
+ // result yet; the enclosing call that performs it will produce one.
+ if (Info.TailRecursionReady())
+ return true;
+
return DerivedSuccess(Result, E);
}
@@ -8257,6 +8396,11 @@ class ExprEvaluatorBase
auto Args = llvm::ArrayRef(E->getArgs(), E->getNumArgs());
bool HasQualifier = false;
+ // Claim a pending [[clang::musttail]] return before evaluating any
+ // subexpression, so that calls nested inside this one cannot claim it.
+ const bool TailRecursion =
+ std::exchange(Info.TailRecursionReturnStmt, nullptr) != nullptr;
+
CallRef Call;
// Extract function decl and 'this' pointer from the callee.
@@ -8317,12 +8461,15 @@ class ExprEvaluatorBase
auto *OCE = dyn_cast<CXXOperatorCallExpr>(E);
if (OCE && OCE->isAssignmentOp()) {
assert(Args.size() == 2 && "wrong number of arguments in assignment");
- Call = Info.CurrentCall->createCall(FD);
bool HasThis = false;
if (const auto *MD = dyn_cast<CXXMethodDecl>(FD))
HasThis = MD->isImplicitObjectMemberFunction();
- if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info, FD,
- /*RightToLeft=*/true))
+ // A tail call's arguments are created in the caller's frame.
+ CallStackFrame &CallOriginFrame =
+ *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall);
+ Call = CallOriginFrame.createCall(FD);
+ if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info,
+ CallOriginFrame, FD, /*RightToLeft=*/true))
return false;
}
@@ -8404,8 +8551,10 @@ class ExprEvaluatorBase
// Evaluate the arguments now if we've not already done so.
if (!Call) {
- Call = Info.CurrentCall->createCall(FD);
- if (!EvaluateArgs(Args, Call, Info, FD))
+ CallStackFrame &CallOriginFrame = // a tail call's arguments are created in the caller's frame
+ *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall);
+ Call = CallOriginFrame.createCall(FD);
+ if (!EvaluateArgs(Args, Call, Info, CallOriginFrame, FD))
return false;
}
@@ -8438,11 +8587,50 @@ class ExprEvaluatorBase
const FunctionDecl *Definition = nullptr;
Stmt *Body = FD->getBody(Definition);
- if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body) ||
- !HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
+ if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body))
+ return false;
+
+ // If this is a [[clang::musttail]] call, record everything needed to
+ // perform it and return without a result: the nearest enclosing non-tail
+ // call performs it. At most one tail call is pending at any time.
+ if (TailRecursion) {
+ HandleTailCallTransfer(Info, E, Definition, This, ThisVal, Args, Call,
+ Body, CovariantAdjustmentPath, CallScope);
+ return true;
+ }
+
+ if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
Body, Info, Result, ResultSlot))
return false;
+ // If the callee deferred a tail call, there is no result yet.
+ assert(!Info.TailRecursionReady() || Result.isAbsent());
+
+ // Covariant return adjustments of the tail calls performed below, outermost
+ // first. A tail call's result must be adjusted to the tail call's own type
+ // before the adjustment of the call containing it is applied.
+ SmallVector<SmallVector<QualType, 4>, 1> TailAdjustmentPaths;
+
+ // A tail call can defer another tail call, so loop until none is pending.
+ while (Info.TailRecursionReady()) {
+ SmallVector<QualType, 4> TailAdjustmentPath;
+ if (!HandleTailCallSetup(Info, E, Definition, This, ThisVal, Args, Call,
+ Body, TailAdjustmentPath, CallScope))
+ return false;
+ TailAdjustmentPaths.push_back(std::move(TailAdjustmentPath));
+
+ if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
+ Body, Info, Result, ResultSlot))
+ return false;
+ }
+
+ // Apply the tail calls' adjustments innermost first; this call's own
+ // adjustment (CovariantAdjustmentPath) is applied last, below.
+ for (const auto &Path : llvm::reverse(TailAdjustmentPaths))
+ if (!Path.empty() &&
+ !HandleCovariantReturnAdjustment(Info, E, Result, Path))
+ return false;
+
if (!CovariantAdjustmentPath.empty() &&
!HandleCovariantReturnAdjustment(Info, E, Result,
CovariantAdjustmentPath))
@@ -17832,7 +18010,7 @@ bool Expr::EvaluateWithSubstitution(APValue &Value, ASTContext &Ctx,
break;
const ParmVarDecl *PVD = Callee->getParamDecl(Idx);
if ((*I)->isValueDependent() ||
- !EvaluateCallArg(PVD, *I, Call, Info) ||
+ !EvaluateCallArg(PVD, *I, Call, Info, *Info.CurrentCall) || // slot created in the current frame
Info.EvalStatus.HasSideEffects) {
// If evaluation fails, throw away the argument entirely.
if (APValue *Slot = Info.getParamSlot(Call, PVD))
More information about the cfe-commits
mailing list