[clang] 976d8b4 - [clang][Interp] Virtual function calls
Timm Bäder via cfe-commits
cfe-commits at lists.llvm.org
Thu Jun 15 04:34:07 PDT 2023
Author: Timm Bäder
Date: 2023-06-15T13:33:43+02:00
New Revision: 976d8b40cccf4678fe8c414210ce82170049b715
URL: https://github.com/llvm/llvm-project/commit/976d8b40cccf4678fe8c414210ce82170049b715
DIFF: https://github.com/llvm/llvm-project/commit/976d8b40cccf4678fe8c414210ce82170049b715.diff
LOG: [clang][Interp] Virtual function calls
Add a CallVirt opcode and implement virtual function calls this way.
Differential Revision: https://reviews.llvm.org/D142630
Added:
Modified:
clang/lib/AST/Interp/ByteCodeExprGen.cpp
clang/lib/AST/Interp/Context.cpp
clang/lib/AST/Interp/Context.h
clang/lib/AST/Interp/Descriptor.cpp
clang/lib/AST/Interp/Function.h
clang/lib/AST/Interp/Interp.h
clang/lib/AST/Interp/InterpState.h
clang/lib/AST/Interp/Opcodes.td
clang/lib/AST/Interp/Pointer.h
clang/test/AST/Interp/records.cpp
Removed:
################################################################################
diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index 1be131be66e3b..94db8b868758e 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -1711,12 +1711,24 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
assert(HasRVO == Func->hasRVO());
+ bool HasQualifier = false;
+ if (const auto *ME = dyn_cast<MemberExpr>(E->getCallee()))
+ HasQualifier = ME->hasQualifier();
+
+ bool IsVirtual = false;
+ if (const auto *MD = dyn_cast<CXXMethodDecl>(FuncDecl))
+ IsVirtual = MD->isVirtual();
+
// In any case call the function. The return value will end up on the stack
// and if the function has RVO, we already have the pointer on the stack to
// write the result into.
- if (!this->emitCall(Func, E))
- return false;
-
+ if (IsVirtual && !HasQualifier) {
+ if (!this->emitCallVirt(Func, E))
+ return false;
+ } else {
+ if (!this->emitCall(Func, E))
+ return false;
+ }
} else {
// Indirect call. Visit the callee, which will leave a FunctionPointer on
// the stack. Cleanup of the returned value if necessary will be done after
diff --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp
index ed7ed41b1b24d..67fb69e663382 100644
--- a/clang/lib/AST/Interp/Context.cpp
+++ b/clang/lib/AST/Interp/Context.cpp
@@ -158,3 +158,38 @@ bool Context::Check(State &Parent, llvm::Expected<bool> &&Flag) {
});
return false;
}
+
+// TODO: Virtual bases?
+/// Returns the method that a virtual call to \p InitialFunction dispatches
+/// to when the object's dynamic type is \p DynamicDecl and the call was made
+/// through the static type \p StaticDecl.
+///
+/// Walks from \p DynamicDecl up the class hierarchy towards \p StaticDecl and
+/// returns the first method found that corresponds to \p InitialFunction.
+/// If the walk runs out of bases without finding one, the hierarchy does not
+/// match the call. In that case we stop instead of revisiting the same record
+/// forever.
+const CXXMethodDecl *
+Context::getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                               const CXXRecordDecl *StaticDecl,
+                               const CXXMethodDecl *InitialFunction) const {
+  assert(DynamicDecl);
+  assert(StaticDecl);
+  assert(InitialFunction);
+
+  const CXXRecordDecl *CurRecord = DynamicDecl;
+  while (CurRecord) {
+    const CXXMethodDecl *Overrider =
+        InitialFunction->getCorrespondingMethodDeclaredInClass(
+            CurRecord, /*MayBeBase=*/false);
+    if (Overrider)
+      return Overrider;
+
+    const CXXRecordDecl *Next = nullptr;
+    if (CurRecord->getNumBases() == 1) {
+      // Common case of only one base class.
+      Next = CurRecord->bases_begin()->getType()->getAsCXXRecordDecl();
+    } else {
+      // Otherwise, go to the base class that will lead to the StaticDecl.
+      for (const CXXBaseSpecifier &Spec : CurRecord->bases()) {
+        const CXXRecordDecl *Base = Spec.getType()->getAsCXXRecordDecl();
+        if (Base && (Base == StaticDecl || Base->isDerivedFrom(StaticDecl))) {
+          Next = Base;
+          break;
+        }
+      }
+    }
+
+    // Next is null when CurRecord has no (suitable) base. Ending the walk
+    // here is what keeps this loop from spinning on an unchanged CurRecord.
+    CurRecord = Next;
+  }
+
+  llvm_unreachable(
+      "Couldn't find an overriding function in the class hierarchy?");
+}
diff --git a/clang/lib/AST/Interp/Context.h b/clang/lib/AST/Interp/Context.h
index cbae7fcf2860a..107bb75a46247 100644
--- a/clang/lib/AST/Interp/Context.h
+++ b/clang/lib/AST/Interp/Context.h
@@ -63,6 +63,11 @@ class Context final {
/// Classifies an expression.
std::optional<PrimType> classify(QualType T) const;
+ /// Returns the override of \p InitialFunction that a virtual call
+ /// dispatches to. The search walks from \p DynamicDecl (the object's
+ /// dynamic type) up towards \p StaticDecl (the class that declares
+ /// \p InitialFunction). Virtual bases are not handled yet.
+ const CXXMethodDecl *
+ getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+ const CXXRecordDecl *StaticDecl,
+ const CXXMethodDecl *InitialFunction) const;
+
private:
/// Runs a function.
bool Run(State &Parent, Function *Func, APValue &Result);
diff --git a/clang/lib/AST/Interp/Descriptor.cpp b/clang/lib/AST/Interp/Descriptor.cpp
index a6bef77bf8c16..565c4b2003847 100644
--- a/clang/lib/AST/Interp/Descriptor.cpp
+++ b/clang/lib/AST/Interp/Descriptor.cpp
@@ -274,6 +274,8 @@ QualType Descriptor::getType() const {
return E->getType();
if (auto *D = asValueDecl())
return D->getType();
+ if (auto *T = dyn_cast<TypeDecl>(asDecl()))
+ return QualType(T->getTypeForDecl(), 0);
llvm_unreachable("Invalid descriptor type");
}
diff --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h
index 005cda7379c2d..6fde5a616dec0 100644
--- a/clang/lib/AST/Interp/Function.h
+++ b/clang/lib/AST/Interp/Function.h
@@ -137,6 +137,13 @@ class Function final {
/// Checks if the function is a destructor.
bool isDestructor() const { return isa<CXXDestructorDecl>(F); }
+  /// Returns the class this function is a member of, or null when the
+  /// underlying declaration is not a C++ method.
+  const CXXRecordDecl *getParentDecl() const {
+    const auto *Method = dyn_cast<CXXMethodDecl>(F);
+    return Method ? Method->getParent() : nullptr;
+  }
+
/// Checks if the function is fully done compiling.
bool isFullyCompiled() const { return IsFullyCompiled; }
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index fd5ce3c325961..b54ef9542010b 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -1622,6 +1622,36 @@ inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func) {
return false;
}
+/// Performs a virtual call to \p Func. The override selected by the class of
+/// the object the `this` pointer on the stack points into is called instead
+/// of \p Func. Emits a diagnostic and returns false if that override cannot
+/// be determined or has no compiled Function.
+inline bool CallVirt(InterpState &S, CodePtr OpPC, const Function *Func) {
+  assert(Func->hasThisPointer());
+  assert(Func->isVirtual());
+  // ThisOffset skips the call arguments and, if present, the RVO pointer.
+  // ThisPtr is a reference into the stack so it can be adjusted in place.
+  size_t ThisOffset =
+      Func->getArgSize() + (Func->hasRVO() ? primSize(PT_Ptr) : 0);
+  Pointer &ThisPtr = S.Stk.peek<Pointer>(ThisOffset);
+
+  // The dynamic type is taken from the declaration the pointer points into.
+  // That declaration need not be a class (presumably e.g. an array of
+  // objects). Diagnose instead of handing a null record to the lookup below.
+  const CXXRecordDecl *DynamicDecl =
+      ThisPtr.getDeclDesc()->getType()->getAsCXXRecordDecl();
+  if (!DynamicDecl) {
+    S.FFDiag(S.Current->getSource(OpPC));
+    return false;
+  }
+
+  const auto *StaticDecl = cast<CXXRecordDecl>(Func->getParentDecl());
+  const auto *InitialFunction = cast<CXXMethodDecl>(Func->getDecl());
+  const CXXMethodDecl *Overrider = S.getContext().getOverridingFunction(
+      DynamicDecl, StaticDecl, InitialFunction);
+
+  if (Overrider != InitialFunction) {
+    // The overrider may be missing, or may never have been compiled into the
+    // program (NOTE(review): presumably e.g. a non-constexpr override).
+    // Either way there is nothing we can call, so diagnose rather than
+    // dereferencing null.
+    const Function *OverriderFunc =
+        Overrider ? S.P.getFunction(Overrider) : nullptr;
+    if (!OverriderFunc) {
+      S.FFDiag(S.Current->getSource(OpPC));
+      return false;
+    }
+    Func = OverriderFunc;
+
+    const CXXRecordDecl *ThisFieldDecl =
+        ThisPtr.getFieldDesc()->getType()->getAsCXXRecordDecl();
+    if (ThisFieldDecl && Func->getParentDecl()->isDerivedFrom(ThisFieldDecl)) {
+      // If the function we call is further DOWN the hierarchy than the
+      // FieldDesc of our pointer, just get the DeclDesc instead, which
+      // is the furthest we might go up in the hierarchy.
+      ThisPtr = ThisPtr.getDeclPtr();
+    }
+  }
+
+  return Call(S, OpPC, Func);
+}
+
inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func) {
auto NewFrame = std::make_unique<InterpFrame>(S, Func, PC);
diff --git a/clang/lib/AST/Interp/InterpState.h b/clang/lib/AST/Interp/InterpState.h
index 74c4667bb0196..fc28c74002d9d 100644
--- a/clang/lib/AST/Interp/InterpState.h
+++ b/clang/lib/AST/Interp/InterpState.h
@@ -89,6 +89,8 @@ class InterpState final : public State, public SourceMapper {
return M ? M->getSource(F, PC) : F->getSource(PC);
}
+  /// Returns the Context (Ctx) this interpreter state is associated with.
+  Context &getContext() const {
+    return this->Ctx;
+  }
+
private:
/// AST Walker state.
State &Parent;
diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index 15f7312ad00e6..28074a350d05f 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -181,6 +181,11 @@ def Call : Opcode {
let Types = [];
}
+// Like Call, but dispatches on the dynamic type of the this pointer: the
+// override of the given function selected by that type is called instead.
+def CallVirt : Opcode {
+ let Args = [ArgFunction];
+ let Types = [];
+}
+
def CallBI : Opcode {
let Args = [ArgFunction];
let Types = [];
diff --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h
index 863d8b3bae0d6..ab196beb93aa3 100644
--- a/clang/lib/AST/Interp/Pointer.h
+++ b/clang/lib/AST/Interp/Pointer.h
@@ -200,6 +200,8 @@ class Pointer {
/// Returns the type of the innermost field.
QualType getType() const { return getFieldDesc()->getType(); }
+  /// Returns a pointer built from just the block this pointer points into.
+  /// This is the pointer to the declaration as a whole, with any
+  /// field/base narrowing of this pointer dropped.
+  Pointer getDeclPtr() const {
+    return Pointer(this->Pointee);
+  }
+
/// Returns the element size of the innermost field.
size_t elemSize() const {
if (Base == RootPtrMark)
diff --git a/clang/test/AST/Interp/records.cpp b/clang/test/AST/Interp/records.cpp
index a874bf3781504..68d902a66e3ca 100644
--- a/clang/test/AST/Interp/records.cpp
+++ b/clang/test/AST/Interp/records.cpp
@@ -1,8 +1,10 @@
// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify %s
// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++14 -verify %s
+// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++20 -verify %s
// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -triple i686 -verify %s
// RUN: %clang_cc1 -verify=ref %s
// RUN: %clang_cc1 -verify=ref -std=c++14 %s
+// RUN: %clang_cc1 -verify=ref -std=c++20 %s
// RUN: %clang_cc1 -verify=ref -triple i686 %s
struct BoolPair {
@@ -380,6 +382,7 @@ namespace MI {
};
namespace DeriveFailures {
+#if __cplusplus < 202002L
struct Base { // ref-note 2{{declared here}} expected-note {{declared here}}
int Val;
};
@@ -397,10 +400,12 @@ namespace DeriveFailures {
// ref-note {{declared here}} \
// expected-error {{must be initialized by a constant expression}} \
// expected-note {{in call to 'Derived(12)'}}
+
static_assert(D.Val == 0, ""); // ref-error {{not an integral constant expression}} \
// ref-note {{initializer of 'D' is not a constant expression}} \
// expected-error {{not an integral constant expression}} \
// expected-note {{read of object outside its lifetime}}
+#endif
struct AnotherBase {
int Val;
@@ -488,3 +493,201 @@ namespace DeclRefs {
//static_assert(b.a.m == 100, "");
//static_assert(b.a.f == 100, "");
}
+
+#if __cplusplus >= 202002L
+namespace VirtualCalls {
+// A call through a base-class pointer must dispatch on the dynamic type of
+// the pointee: A::foo() for myA, B::foo() for myB.
+namespace Obvious {
+
+ class A {
+ public:
+ constexpr A(){}
+ constexpr virtual int foo() {
+ return 3;
+ }
+ };
+ class B : public A {
+ public:
+ constexpr int foo() override {
+ return 6;
+ }
+ };
+
+ constexpr int getFooB(bool b) {
+ A *a;
+ A myA;
+ B myB;
+
+ if (b)
+ a = &myA;
+ else
+ a = &myB;
+
+ // Unqualified call to a virtual method: the new interpreter emits this
+ // as CallVirt.
+ return a->foo();
+ }
+ static_assert(getFooB(true) == 3, "");
+ static_assert(getFooB(false) == 6, "");
+}
+
+// C has two base classes and overrides A::getInt() itself; a call through
+// a const A& bound to a C must reach C::getInt().
+namespace MultipleBases {
+ class A {
+ public:
+ constexpr virtual int getInt() const { return 10; }
+ };
+ class B {
+ public:
+ };
+ class C : public A, public B {
+ public:
+ constexpr int getInt() const override { return 20; }
+ };
+
+ constexpr int callGetInt(const A& a) { return a.getInt(); }
+ static_assert(callGetInt(C()) == 20, "");
+ static_assert(callGetInt(A()) == 10, "");
+}
+
+// Constructs and destroys an object whose destructors are virtual.
+// NOTE(review): the constructor increments the caller's counter through its
+// int& parameter (which shadows the member), but the destructors decrement
+// the int member `i`, a copy. Also, test() returns before `d` is destroyed.
+// So only the constructor's side effect is checked here.
+namespace Destructors {
+ class Base {
+ public:
+ int i;
+ constexpr Base(int &i) : i(i) {i++;}
+ constexpr virtual ~Base() {i--;}
+ };
+
+ class Derived : public Base {
+ public:
+ constexpr Derived(int &i) : Base(i) {}
+ constexpr virtual ~Derived() {i--;}
+ };
+
+ constexpr int test() {
+ int i = 0;
+ Derived d(i);
+ return i;
+ }
+ static_assert(test() == 1);
+}
+
+
+// Destroying a C must run ~C, ~B and ~A; each one sets its own bit in the
+// unsigned that all three hold a reference to.
+namespace VirtualDtors {
+ class A {
+ public:
+ unsigned &v;
+ constexpr A(unsigned &v) : v(v) {}
+ constexpr virtual ~A() {
+ v |= (1 << 0);
+ }
+ };
+ class B : public A {
+ public:
+ constexpr B(unsigned &v) : A(v) {}
+ constexpr virtual ~B() {
+ v |= (1 << 1);
+ }
+ };
+ class C : public B {
+ public:
+ constexpr C(unsigned &v) : B(v) {}
+ constexpr virtual ~C() {
+ v |= (1 << 2);
+ }
+ };
+
+ constexpr bool foo() {
+ unsigned a = 0;
+ {
+ // c is destroyed at the end of this scope.
+ C c(a);
+ }
+ return ((a & (1 << 0)) && (a & (1 << 1)) && (a & (1 << 2)));
+ }
+
+ static_assert(foo());
+
+
+};
+
+// Qualified calls name one specific function and must not be dispatched
+// virtually; VisitCallExpr only emits CallVirt when the callee has no
+// qualifier.
+namespace QualifiedCalls {
+ class A {
+ public:
+ constexpr virtual int foo() const {
+ return 5;
+ }
+ };
+ class B : public A {};
+ class C : public B {
+ public:
+ // Dispatching B::foo() virtually would land back in C::foo() and never
+ // terminate.
+ constexpr int foo() const override {
+ return B::foo(); // B doesn't have a foo(), so this should call A::foo().
+ }
+ constexpr int foo2() const {
+ return this->A::foo();
+ }
+ };
+ constexpr C c;
+ static_assert(c.foo() == 5);
+ static_assert(c.foo2() == 5);
+
+
+ struct S {
+ int _c = 0;
+ virtual constexpr int foo() const { return 1; }
+ };
+
+ struct SS : S {
+ int a;
+ constexpr SS() {
+ a = S::foo();
+ }
+ // Same as above: S::foo() must call S's version, not SS::foo() again.
+ constexpr int foo() const override {
+ return S::foo();
+ }
+ };
+
+ constexpr SS ss;
+ static_assert(ss.a == 1);
+}
+
+// Virtual calls made while base-class subobjects are being constructed.
+namespace CtorDtor {
+ struct Base {
+ int i = 0;
+ int j = 0;
+
+ constexpr Base() : i(func()) {
+ j = func();
+ }
+ constexpr Base(int i) : i(i), j(i) {}
+
+ constexpr virtual int func() const { return 1; }
+ };
+
+ struct Derived : Base {
+ constexpr Derived() {}
+ constexpr Derived(int i) : Base(i) {}
+ constexpr int func() const override { return 2; }
+ };
+
+ // Derived2's mem-initializer calls func() before its Derived base exists.
+ // The AST-walking evaluator rejects D2 below; the new interpreter
+ // currently accepts it (it has no annotations for D2).
+ struct Derived2 : Derived {
+ constexpr Derived2() : Derived(func()) {} // ref-note {{subexpression not valid in a constant expression}}
+ constexpr int func() const override { return 3; }
+ };
+
+ constexpr Base B;
+ static_assert(B.i == 1 && B.j == 1, "");
+
+ constexpr Derived D;
+ // While Base's constructor runs, the object is still a Base ([class.cdtor]),
+ // so func() should resolve to Base::func() and both members should be 1.
+ // The new interpreter currently calls Derived::func() (yielding 2), which
+ // the annotations on the next two asserts record.
+ static_assert(D.i == 1, ""); // expected-error {{static assertion failed}} \
+ // expected-note {{2 == 1}}
+ static_assert(D.j == 1, ""); // expected-error {{static assertion failed}} \
+ // expected-note {{2 == 1}}
+
+ constexpr Derived2 D2; // ref-error {{must be initialized by a constant expression}} \
+ // ref-note {{in call to 'Derived2()'}} \
+ // ref-note 2{{declared here}}
+ static_assert(D2.i == 3, ""); // ref-error {{not an integral constant expression}} \
+ // ref-note {{initializer of 'D2' is not a constant expression}}
+ static_assert(D2.j == 3, ""); // ref-error {{not an integral constant expression}} \
+ // ref-note {{initializer of 'D2' is not a constant expression}}
+
+}
+};
+#endif
More information about the cfe-commits
mailing list