[clang] 6d2e141 - [clang][Interp] Handle lambda static invokers
Timm Bäder via cfe-commits
cfe-commits at lists.llvm.org
Tue Jul 25 23:42:34 PDT 2023
Author: Timm Bäder
Date: 2023-07-26T08:42:16+02:00
New Revision: 6d2e141e5c0da9dfd2a2ea9b63aa8017ba8277e3
URL: https://github.com/llvm/llvm-project/commit/6d2e141e5c0da9dfd2a2ea9b63aa8017ba8277e3
DIFF: https://github.com/llvm/llvm-project/commit/6d2e141e5c0da9dfd2a2ea9b63aa8017ba8277e3.diff
LOG: [clang][Interp] Handle lambda static invokers
Differential Revision: https://reviews.llvm.org/D150111
Added:
Modified:
clang/lib/AST/Interp/ByteCodeEmitter.cpp
clang/lib/AST/Interp/ByteCodeStmtGen.cpp
clang/lib/AST/Interp/ByteCodeStmtGen.h
clang/lib/AST/Interp/Function.h
clang/lib/AST/Interp/Interp.h
clang/test/AST/Interp/lambda.cpp
Removed:
################################################################################
diff --git a/clang/lib/AST/Interp/ByteCodeEmitter.cpp b/clang/lib/AST/Interp/ByteCodeEmitter.cpp
index f2072f974c4084..c33c9bd37e031c 100644
--- a/clang/lib/AST/Interp/ByteCodeEmitter.cpp
+++ b/clang/lib/AST/Interp/ByteCodeEmitter.cpp
@@ -96,8 +96,15 @@ ByteCodeEmitter::compileFunc(const FunctionDecl *FuncDecl) {
if (!FuncDecl->isDefined())
return Func;
+ // Lambda static invokers are a special case that we emit custom code for.
+ bool IsEligibleForCompilation = false;
+ if (const auto *MD = dyn_cast<CXXMethodDecl>(FuncDecl))
+ IsEligibleForCompilation = MD->isLambdaStaticInvoker();
+ if (!IsEligibleForCompilation)
+ IsEligibleForCompilation = FuncDecl->isConstexpr();
+
// Compile the function body.
- if (!FuncDecl->isConstexpr() || !visitFunc(FuncDecl)) {
+ if (!IsEligibleForCompilation || !visitFunc(FuncDecl)) {
// Return a dummy function if compilation failed.
if (BailLocation)
return llvm::make_error<ByteCodeGenError>(*BailLocation);
diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
index 0c512950c292a9..59ef19be356a67 100644
--- a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
@@ -89,11 +89,67 @@ template <class Emitter> class SwitchScope final : public LabelScope<Emitter> {
} // namespace interp
} // namespace clang
+/// Emits the body of the lambda static invoker \p MD. The invoker's own body
+/// is expected to be empty (asserted below); instead, this forwards all of the
+/// invoker's parameters to the closure class's call operator and returns
+/// whatever that call produces.
+///
+/// Returns false if the call operator could not be compiled or if emitting
+/// any instruction fails.
+template <class Emitter>
+bool ByteCodeStmtGen<Emitter>::emitLambdaStaticInvokerBody(
+ const CXXMethodDecl *MD) {
+ assert(MD->isLambdaStaticInvoker());
+ assert(MD->hasBody());
+ assert(cast<CompoundStmt>(MD->getBody())->body_empty());
+
+ const CXXRecordDecl *ClosureClass = MD->getParent();
+ const CXXMethodDecl *LambdaCallOp = ClosureClass->getLambdaCallOperator();
+ // There is no closure object to read captures from here, so the lambda
+ // must be captureless. NOTE(review): this is only asserted, so it is not
+ // enforced in builds without assertions.
+ assert(ClosureClass->captures_begin() == ClosureClass->captures_end());
+ const Function *Func = this->getFunction(LambdaCallOp);
+ if (!Func)
+ return false;
+ // The call operator takes the invoker's parameters plus the instance
+ // pointer, plus the return-value pointer when it uses RVO.
+ assert(Func->hasThisPointer());
+ assert(Func->getNumParams() == (MD->getNumParams() + 1 + Func->hasRVO()));
+
+ // Push order follows the calling convention documented on Function:
+ // RVO pointer first (if any), then the instance pointer, then the arguments.
+ if (Func->hasRVO()) {
+ if (!this->emitRVOPtr(MD))
+ return false;
+ }
+
+ // The lambda call operator needs an instance pointer, but we don't have
+ // one here, and we don't need one either because the lambda cannot have
+ // any captures, as asserted above. Emit a null pointer. This is then
+ // special-cased when interpreting to not emit any misleading diagnostics.
+ if (!this->emitNullPtr(MD))
+ return false;
+
+ // Forward all arguments from the static invoker to the lambda call operator.
+ for (const ParmVarDecl *PVD : MD->parameters()) {
+ auto It = this->Params.find(PVD);
+ assert(It != this->Params.end());
+
+ // We do the lvalue-to-rvalue conversion manually here, so no need
+ // to care about references.
+ // Types that classify() cannot map to a primitive are read as PT_Ptr.
+ PrimType ParamType = this->classify(PVD->getType()).value_or(PT_Ptr);
+ if (!this->emitGetParam(ParamType, It->second, MD))
+ return false;
+ }
+
+ if (!this->emitCall(Func, LambdaCallOp))
+ return false;
+
+ this->emitCleanup();
+ // ReturnType was classified from the invoker's return type in visitFunc.
+ if (ReturnType)
+ return this->emitRet(*ReturnType, MD);
+
+ // No primitive return value. A result that needs RVO has already been
+ // written through the RVO pointer emitted above, so nothing is left to do.
+ return this->emitRetVoid(MD);
+}
+
template <class Emitter>
bool ByteCodeStmtGen<Emitter>::visitFunc(const FunctionDecl *F) {
// Classify the return type.
ReturnType = this->classify(F->getReturnType());
+ // Emit custom code if this is a lambda static invoker.
+ if (const auto *MD = dyn_cast<CXXMethodDecl>(F);
+ MD && MD->isLambdaStaticInvoker())
+ return this->emitLambdaStaticInvokerBody(MD);
+
// Constructor. Set up field initializers.
if (const auto *Ctor = dyn_cast<CXXConstructorDecl>(F)) {
const RecordDecl *RD = Ctor->getParent();
diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.h b/clang/lib/AST/Interp/ByteCodeStmtGen.h
index 8d9277a11dd7d7..bc50b977a6d04f 100644
--- a/clang/lib/AST/Interp/ByteCodeStmtGen.h
+++ b/clang/lib/AST/Interp/ByteCodeStmtGen.h
@@ -68,6 +68,8 @@ class ByteCodeStmtGen final : public ByteCodeExprGen<Emitter> {
bool visitCaseStmt(const CaseStmt *S);
bool visitDefaultStmt(const DefaultStmt *S);
+ bool emitLambdaStaticInvokerBody(const CXXMethodDecl *MD);
+
/// Type of the expression returned by the function.
std::optional<PrimType> ReturnType;
diff --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h
index 55a23ff288e846..644d4cd53b1e19 100644
--- a/clang/lib/AST/Interp/Function.h
+++ b/clang/lib/AST/Interp/Function.h
@@ -17,6 +17,7 @@
#include "Pointer.h"
#include "Source.h"
+#include "clang/AST/ASTLambda.h"
#include "clang/AST/Decl.h"
#include "llvm/Support/raw_ostream.h"
@@ -65,7 +66,7 @@ class Scope final {
/// the argument values need to be preceeded by a Pointer for the This object.
///
/// If the function uses Return Value Optimization, the arguments (and
-/// potentially the This pointer) need to be proceeded by a Pointer pointing
+/// potentially the This pointer) need to be preceded by a Pointer pointing
/// to the location to construct the returned value.
///
/// After the function has been called, it will remove all arguments,
@@ -127,7 +128,7 @@ class Function final {
SourceInfo getSource(CodePtr PC) const;
/// Checks if the function is valid to call in constexpr.
- bool isConstexpr() const { return IsValid; }
+ bool isConstexpr() const {
+ if (IsValid)
+ return true;
+ // Lambda static invokers get custom byte code, so they are reported as
+ // callable in constexpr even when IsValid is not set.
+ return isLambdaStaticInvoker();
+ }
/// Checks if the function is virtual.
bool isVirtual() const;
@@ -144,6 +145,22 @@ class Function final {
return nullptr;
}
+ /// Whether the wrapped declaration is a lambda's static invoker, i.e. the
+ /// function reached when a captureless lambda is called through a function
+ /// pointer. We emit hand-written byte code for these.
+ bool isLambdaStaticInvoker() const {
+ const auto *Method = dyn_cast<CXXMethodDecl>(F);
+ return Method && Method->isLambdaStaticInvoker();
+ }
+
+ /// Whether the wrapped declaration is the operator() of a lambda's
+ /// closure class.
+ bool isLambdaCallOperator() const {
+ const auto *Method = dyn_cast<CXXMethodDecl>(F);
+ // Qualified: this member shadows the free function from ASTLambda.h.
+ return Method && clang::isLambdaCallOperator(Method);
+ }
+
/// Checks if the function is fully done compiling.
bool isFullyCompiled() const { return IsFullyCompiled; }
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index 0acaf764353215..38caee0e7a6094 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -1632,8 +1632,16 @@ inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func) {
const Pointer &ThisPtr = S.Stk.peek<Pointer>(ThisOffset);
- if (!CheckInvoke(S, OpPC, ThisPtr))
- return false;
+ // If the current function is a lambda static invoker and
+ // the function we're about to call is a lambda call operator,
+ // skip the CheckInvoke, since the ThisPtr is a null pointer
+ // anyway.
+ if (!(S.Current->getFunction() &&
+ S.Current->getFunction()->isLambdaStaticInvoker() &&
+ Func->isLambdaCallOperator())) {
+ if (!CheckInvoke(S, OpPC, ThisPtr))
+ return false;
+ }
if (S.checkingPotentialConstantExpression())
return false;
diff --git a/clang/test/AST/Interp/lambda.cpp b/clang/test/AST/Interp/lambda.cpp
index c588da530e4d01..b913ad13500bc0 100644
--- a/clang/test/AST/Interp/lambda.cpp
+++ b/clang/test/AST/Interp/lambda.cpp
@@ -107,3 +107,58 @@ namespace LambdaParams {
static_assert(foo() == 1); // expected-error {{not an integral constant expression}}
}
+/// Calls to captureless lambdas made through the function pointer they
+/// convert to, i.e. through the lambda's static invoker.
+namespace StaticInvoker {
+ // No parameters.
+ constexpr int sv1(int i) {
+ auto l = []() { return 12; };
+ int (*fp)() = l;
+ return fp();
+ }
+ static_assert(sv1(12) == 12);
+
+ // Several by-value parameters of different types.
+ constexpr int sv2(int i) {
+ auto l = [](int m, float f, void *A) { return m; };
+ int (*fp)(int, float, void*) = l;
+ return fp(i, 4.0f, nullptr);
+ }
+ static_assert(sv2(12) == 12);
+
+ // A const reference parameter bound to a temporary.
+ constexpr int sv3(int i) {
+ auto l = [](int m, const int &n) { return m; };
+ int (*fp)(int, const int &) = l;
+ return fp(i, 3);
+ }
+ static_assert(sv3(12) == 12);
+
+ // A non-const lvalue reference parameter.
+ constexpr int sv4(int i) {
+ auto l = [](int &m) { return m; };
+ int (*fp)(int&) = l;
+ return fp(i);
+ }
+ static_assert(sv4(12) == 12);
+
+
+
+ // A record parameter passed by value.
+ /// FIXME: This is broken for lambda-unrelated reasons.
+#if 0
+ constexpr int sv5(int i) {
+ struct F { int a; float f; };
+ auto l = [](int m, F f) { return m; };
+ int (*fp)(int, F) = l;
+ return fp(i, F{12, 14.0});
+ }
+ static_assert(sv5(12) == 12);
+#endif
+
+ // A record return type (presumably returned through the RVO pointer).
+ // The result is used both to initialize a local and as a temporary.
+ constexpr int sv6(int i) {
+ struct F { int a;
+ constexpr F(int a) : a(a) {}
+ };
+
+ auto l = [](int m) { return F(12); };
+ F (*fp)(int) = l;
+ F f = fp(i);
+
+ return fp(i).a;
+ }
+ static_assert(sv6(12) == 12);
+}
More information about the cfe-commits mailing list