[clang] 976d8b4 - [clang][Interp] Virtual function calls

Timm Bäder via cfe-commits cfe-commits at lists.llvm.org
Thu Jun 15 04:34:07 PDT 2023


Author: Timm Bäder
Date: 2023-06-15T13:33:43+02:00
New Revision: 976d8b40cccf4678fe8c414210ce82170049b715

URL: https://github.com/llvm/llvm-project/commit/976d8b40cccf4678fe8c414210ce82170049b715
DIFF: https://github.com/llvm/llvm-project/commit/976d8b40cccf4678fe8c414210ce82170049b715.diff

LOG: [clang][Interp] Virtual function calls

Add a CallVirt opcode and implement virtual function calls this way.

Differential Revision: https://reviews.llvm.org/D142630

Added: 
    

Modified: 
    clang/lib/AST/Interp/ByteCodeExprGen.cpp
    clang/lib/AST/Interp/Context.cpp
    clang/lib/AST/Interp/Context.h
    clang/lib/AST/Interp/Descriptor.cpp
    clang/lib/AST/Interp/Function.h
    clang/lib/AST/Interp/Interp.h
    clang/lib/AST/Interp/InterpState.h
    clang/lib/AST/Interp/Opcodes.td
    clang/lib/AST/Interp/Pointer.h
    clang/test/AST/Interp/records.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index 1be131be66e3b..94db8b868758e 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -1711,12 +1711,30 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
 
     assert(HasRVO == Func->hasRVO());
 
+    // A call that names its target with a qualifier (`B::foo()`,
+    // `this->A::foo()`) calls exactly that function and is never dispatched
+    // virtually. Look through parentheses so that `(this->A::foo)()` is
+    // recognized as a qualified call, too.
+    bool HasQualifier = false;
+    if (const auto *ME = dyn_cast<MemberExpr>(E->getCallee()->IgnoreParens()))
+      HasQualifier = ME->hasQualifier();
+
+    bool IsVirtual = false;
+    if (const auto *MD = dyn_cast<CXXMethodDecl>(FuncDecl))
+      IsVirtual = MD->isVirtual();
+
     // In any case call the function. The return value will end up on the stack
     // and if the function has RVO, we already have the pointer on the stack to
     // write the result into.
-    if (!this->emitCall(Func, E))
-      return false;
-
+    if (IsVirtual && !HasQualifier) {
+      // Unqualified call of a virtual function: CallVirt picks the final
+      // overrider based on the dynamic type of the object.
+      if (!this->emitCallVirt(Func, E))
+        return false;
+    } else {
+      if (!this->emitCall(Func, E))
+        return false;
+    }
   } else {
     // Indirect call. Visit the callee, which will leave a FunctionPointer on
     // the stack. Cleanup of the returned value if necessary will be done after

diff  --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp
index ed7ed41b1b24d..67fb69e663382 100644
--- a/clang/lib/AST/Interp/Context.cpp
+++ b/clang/lib/AST/Interp/Context.cpp
@@ -158,3 +158,49 @@ bool Context::Check(State &Parent, llvm::Expected<bool> &&Flag) {
   });
   return false;
 }
+
+/// Starting at \p DynamicDecl, walks up the class hierarchy towards
+/// \p StaticDecl and returns the first method that overrides (or is)
+/// \p InitialFunction.
+// TODO: Virtual bases?
+const CXXMethodDecl *
+Context::getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                               const CXXRecordDecl *StaticDecl,
+                               const CXXMethodDecl *InitialFunction) const {
+  assert(DynamicDecl);
+  assert(StaticDecl);
+  assert(InitialFunction);
+
+  const CXXRecordDecl *CurRecord = DynamicDecl;
+  const CXXMethodDecl *FoundFunction = InitialFunction;
+  for (;;) {
+    const CXXMethodDecl *Overrider =
+        FoundFunction->getCorrespondingMethodDeclaredInClass(CurRecord, false);
+    if (Overrider)
+      return Overrider;
+
+    // Common case of only one base class.
+    if (CurRecord->getNumBases() == 1) {
+      CurRecord = CurRecord->bases_begin()->getType()->getAsCXXRecordDecl();
+      continue;
+    }
+
+    // Otherwise, go to the base class that will lead to the StaticDecl.
+    const CXXRecordDecl *NextRecord = nullptr;
+    for (const CXXBaseSpecifier &Spec : CurRecord->bases()) {
+      const CXXRecordDecl *Base = Spec.getType()->getAsCXXRecordDecl();
+      if (Base == StaticDecl || Base->isDerivedFrom(StaticDecl)) {
+        NextRecord = Base;
+        break;
+      }
+    }
+
+    // If no base leads to StaticDecl (including a record without any bases),
+    // CurRecord would never change and this loop would spin forever. Fail
+    // loudly instead.
+    if (!NextRecord)
+      llvm_unreachable(
+          "Couldn't find an overriding function in the class hierarchy?");
+    CurRecord = NextRecord;
+  }
+}

diff  --git a/clang/lib/AST/Interp/Context.h b/clang/lib/AST/Interp/Context.h
index cbae7fcf2860a..107bb75a46247 100644
--- a/clang/lib/AST/Interp/Context.h
+++ b/clang/lib/AST/Interp/Context.h
@@ -63,6 +63,14 @@ class Context final {
   /// Classifies an expression.
   std::optional<PrimType> classify(QualType T) const;
 
+  /// Returns the function a virtual call to \p InitialFunction resolves to
+  /// for an object whose dynamic type is \p DynamicDecl. \p StaticDecl is the
+  /// class declaring \p InitialFunction; only bases leading to it are walked.
+  const CXXMethodDecl *
+  getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                        const CXXRecordDecl *StaticDecl,
+                        const CXXMethodDecl *InitialFunction) const;
+
 private:
   /// Runs a function.
   bool Run(State &Parent, Function *Func, APValue &Result);

diff  --git a/clang/lib/AST/Interp/Descriptor.cpp b/clang/lib/AST/Interp/Descriptor.cpp
index a6bef77bf8c16..565c4b2003847 100644
--- a/clang/lib/AST/Interp/Descriptor.cpp
+++ b/clang/lib/AST/Interp/Descriptor.cpp
@@ -274,6 +274,8 @@ QualType Descriptor::getType() const {
     return E->getType();
   if (auto *D = asValueDecl())
     return D->getType();
+  if (const auto *TD = dyn_cast<TypeDecl>(asDecl()))
+    return QualType(TD->getTypeForDecl(), /*Quals=*/0);
   llvm_unreachable("Invalid descriptor type");
 }
 

diff  --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h
index 005cda7379c2d..6fde5a616dec0 100644
--- a/clang/lib/AST/Interp/Function.h
+++ b/clang/lib/AST/Interp/Function.h
@@ -137,6 +137,13 @@ class Function final {
   /// Checks if the function is a destructor.
   bool isDestructor() const { return isa<CXXDestructorDecl>(F); }
 
+  /// Returns the class this function is a member of, or nullptr if the
+  /// function is not a C++ method.
+  const CXXRecordDecl *getParentDecl() const {
+    const auto *Method = dyn_cast<CXXMethodDecl>(F);
+    return Method ? Method->getParent() : nullptr;
+  }
+
   /// Checks if the function is fully done compiling.
   bool isFullyCompiled() const { return IsFullyCompiled; }
 

diff  --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index fd5ce3c325961..b54ef9542010b 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -1622,6 +1622,59 @@ inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func) {
   return false;
 }
 
+/// Calls \p Func with virtual dispatch: the final overrider is looked up from
+/// the dynamic type of the object the `this` pointer on the stack points
+/// into, and the call is then forwarded to Call().
+inline bool CallVirt(InterpState &S, CodePtr OpPC, const Function *Func) {
+  assert(Func->hasThisPointer());
+  assert(Func->isVirtual());
+  // NOTE(review): mirrors the offset computation in Call(). If getArgSize()
+  // already includes the RVO pointer, the RVO size should be subtracted here
+  // (and in Call()) rather than added; confirm.
+  size_t ThisOffset =
+      Func->getArgSize() + (Func->hasRVO() ? primSize(PT_Ptr) : 0);
+  Pointer &ThisPtr = S.Stk.peek<Pointer>(ThisOffset);
+
+  const CXXRecordDecl *DynamicDecl =
+      ThisPtr.getDeclDesc()->getType()->getAsCXXRecordDecl();
+  const auto *StaticDecl = cast<CXXRecordDecl>(Func->getParentDecl());
+  const auto *InitialFunction = cast<CXXMethodDecl>(Func->getDecl());
+
+  // The declaration's type is only usable as the dynamic type if it is a
+  // class that is, or derives from, the class declaring the called function.
+  // That is not the case e.g. for a call on a member subobject or an array
+  // element; bail out instead of searching an unrelated hierarchy.
+  if (!DynamicDecl ||
+      (DynamicDecl->getCanonicalDecl() != StaticDecl->getCanonicalDecl() &&
+       !DynamicDecl->isDerivedFrom(StaticDecl))) {
+    S.FFDiag(S.Current->getSource(OpPC));
+    return false;
+  }
+
+  const CXXMethodDecl *Overrider = S.getContext().getOverridingFunction(
+      DynamicDecl, StaticDecl, InitialFunction);
+
+  if (Overrider != InitialFunction) {
+    Func = S.P.getFunction(Overrider);
+    // getFunction() presumably yields null for a not-yet-compiled overrider.
+    if (!Func) {
+      S.FFDiag(S.Current->getSource(OpPC));
+      return false;
+    }
+
+    const CXXRecordDecl *ThisFieldDecl =
+        ThisPtr.getFieldDesc()->getType()->getAsCXXRecordDecl();
+    if (Func->getParentDecl()->isDerivedFrom(ThisFieldDecl)) {
+      // If the function we call is further DOWN the hierarchy than the
+      // FieldDesc of our pointer, just get the DeclDesc instead, which
+      // is the furthest we might go up in the hierarchy.
+      ThisPtr = ThisPtr.getDeclPtr();
+    }
+  }
+
+  return Call(S, OpPC, Func);
+}
+
 inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func) {
   auto NewFrame = std::make_unique<InterpFrame>(S, Func, PC);
 

diff  --git a/clang/lib/AST/Interp/InterpState.h b/clang/lib/AST/Interp/InterpState.h
index 74c4667bb0196..fc28c74002d9d 100644
--- a/clang/lib/AST/Interp/InterpState.h
+++ b/clang/lib/AST/Interp/InterpState.h
@@ -89,6 +89,9 @@ class InterpState final : public State, public SourceMapper {
     return M ? M->getSource(F, PC) : F->getSource(PC);
   }
 
+  /// Returns the interpreter Context (`Ctx`) of this state.
+  Context &getContext() const { return Ctx; }
+
 private:
   /// AST Walker state.
   State &Parent;

diff  --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index 15f7312ad00e6..28074a350d05f 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -181,6 +181,13 @@ def Call : Opcode {
   let Types = [];
 }
 
+// Like Call, but resolves the final overrider of the function from the
+// dynamic type of the `this` object before calling it.
+def CallVirt : Opcode {
+  let Args = [ArgFunction];
+  let Types = [];
+}
+
 def CallBI : Opcode {
   let Args = [ArgFunction];
   let Types = [];

diff  --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h
index 863d8b3bae0d6..ab196beb93aa3 100644
--- a/clang/lib/AST/Interp/Pointer.h
+++ b/clang/lib/AST/Interp/Pointer.h
@@ -200,6 +200,10 @@ class Pointer {
   /// Returns the type of the innermost field.
   QualType getType() const { return getFieldDesc()->getType(); }
 
+  /// Returns a pointer built from the same block (`Pointee`), i.e. one that
+  /// refers to the declaration as a whole rather than the innermost field.
+  Pointer getDeclPtr() const { return Pointer(Pointee); }
+
   /// Returns the element size of the innermost field.
   size_t elemSize() const {
     if (Base == RootPtrMark)

diff  --git a/clang/test/AST/Interp/records.cpp b/clang/test/AST/Interp/records.cpp
index a874bf3781504..68d902a66e3ca 100644
--- a/clang/test/AST/Interp/records.cpp
+++ b/clang/test/AST/Interp/records.cpp
@@ -1,8 +1,10 @@
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify %s
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++14 -verify %s
+// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++20 -verify %s
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -triple i686 -verify %s
 // RUN: %clang_cc1 -verify=ref %s
 // RUN: %clang_cc1 -verify=ref -std=c++14 %s
+// RUN: %clang_cc1 -verify=ref -std=c++20 %s
 // RUN: %clang_cc1 -verify=ref -triple i686 %s
 
 struct BoolPair {
@@ -380,6 +382,7 @@ namespace MI {
 };
 
 namespace DeriveFailures {
+#if __cplusplus < 202002L
   struct Base { // ref-note 2{{declared here}} expected-note {{declared here}}
     int Val;
   };
@@ -397,10 +400,12 @@ namespace DeriveFailures {
                            // ref-note {{declared here}} \
                            // expected-error {{must be initialized by a constant expression}} \
                            // expected-note {{in call to 'Derived(12)'}}
+
   static_assert(D.Val == 0, ""); // ref-error {{not an integral constant expression}} \
                                  // ref-note {{initializer of 'D' is not a constant expression}} \
                                  // expected-error {{not an integral constant expression}} \
                                  // expected-note {{read of object outside its lifetime}}
+#endif
 
   struct AnotherBase {
     int Val;
@@ -488,3 +493,201 @@ namespace DeclRefs {
   //static_assert(b.a.m == 100, "");
   //static_assert(b.a.f == 100, "");
 }
+
+#if __cplusplus >= 202002L
+namespace VirtualCalls {
+namespace Obvious {
+
+  class A {
+  public:
+    constexpr A(){}
+    constexpr virtual int foo() {
+      return 3;
+    }
+  };
+  class B : public A {
+  public:
+    constexpr int foo() override {
+      return 6;
+    }
+  };
+
+  constexpr int getFooB(bool b) {
+    A *a;
+    A myA;
+    B myB;
+
+    if (b)
+      a = &myA;
+    else
+      a = &myB;
+
+    return a->foo();
+  }
+  static_assert(getFooB(true) == 3, "");
+  static_assert(getFooB(false) == 6, "");
+}
+
+namespace MultipleBases {
+  class A {
+  public:
+    constexpr virtual int getInt() const { return 10; }
+  };
+  class B {
+  public:
+  };
+  class C : public A, public B {
+  public:
+    constexpr int getInt() const override { return 20; }
+  };
+
+  constexpr int callGetInt(const A& a) { return a.getInt(); }
+  static_assert(callGetInt(C()) == 20, "");
+  static_assert(callGetInt(A()) == 10, "");
+}
+
+namespace Destructors {
+  class Base {
+  public:
+    int i;
+    constexpr Base(int &i) : i(i) {i++;}
+    constexpr virtual ~Base() {i--;}
+  };
+
+  class Derived : public Base {
+  public:
+    constexpr Derived(int &i) : Base(i) {}
+    constexpr virtual ~Derived() {i--;}
+  };
+
+  constexpr int test() {
+    int i = 0;
+    Derived d(i);
+    return i;
+  }
+  static_assert(test() == 1);
+}
+
+
+namespace VirtualDtors {
+  class A {
+  public:
+    unsigned &v;
+    constexpr A(unsigned &v) : v(v) {}
+    constexpr virtual ~A() {
+      v |= (1 << 0);
+    }
+  };
+  class B : public A {
+  public:
+    constexpr B(unsigned &v) : A(v) {}
+    constexpr virtual ~B() {
+      v |= (1 << 1);
+    }
+  };
+  class C : public B {
+  public:
+    constexpr C(unsigned &v) : B(v) {}
+    constexpr virtual ~C() {
+      v |= (1 << 2);
+    }
+  };
+
+  constexpr bool foo() {
+    unsigned a = 0;
+    {
+      C c(a);
+    }
+    return ((a & (1 << 0)) && (a & (1 << 1)) && (a & (1 << 2)));
+  }
+
+  static_assert(foo());
+
+
+}
+
+namespace QualifiedCalls {
+  class A {
+      public:
+      constexpr virtual int foo() const {
+          return 5;
+      }
+  };
+  class B : public A {};
+  class C : public B {
+      public:
+      constexpr int foo() const override {
+          return B::foo(); // B doesn't have a foo(), so this should call A::foo().
+      }
+      constexpr int foo2() const {
+        return this->A::foo();
+      }
+  };
+  constexpr C c;
+  static_assert(c.foo() == 5);
+  static_assert(c.foo2() == 5);
+
+
+  struct S {
+    int _c = 0;
+    virtual constexpr int foo() const { return 1; }
+  };
+
+  struct SS : S {
+    int a;
+    constexpr SS() {
+      a = S::foo();
+    }
+    constexpr int foo() const override {
+      return S::foo();
+    }
+  };
+
+  constexpr SS ss;
+  static_assert(ss.a == 1);
+}
+
+namespace CtorDtor {
+  struct Base {
+    int i = 0;
+    int j = 0;
+
+    constexpr Base() : i(func()) {
+      j = func();
+    }
+    constexpr Base(int i) : i(i), j(i) {}
+
+    constexpr virtual int func() const { return 1; }
+  };
+
+  struct Derived : Base {
+    constexpr Derived() {}
+    constexpr Derived(int i) : Base(i) {}
+    constexpr int func() const override { return 2; }
+  };
+
+  struct Derived2 : Derived {
+    constexpr Derived2() : Derived(func()) {} // ref-note {{subexpression not valid in a constant expression}}
+    constexpr int func() const override { return 3; }
+  };
+
+  constexpr Base B;
+  static_assert(B.i == 1 && B.j == 1, "");
+
+  constexpr Derived D;
+  static_assert(D.i == 1, ""); // expected-error {{static assertion failed}} \
+                               // expected-note {{2 == 1}}
+  static_assert(D.j == 1, ""); // expected-error {{static assertion failed}} \
+                               // expected-note {{2 == 1}}
+
+  constexpr Derived2 D2; // ref-error {{must be initialized by a constant expression}} \
+                         // ref-note {{in call to 'Derived2()'}} \
+                         // ref-note 2{{declared here}}
+  static_assert(D2.i == 3, ""); // ref-error {{not an integral constant expression}} \
+                                // ref-note {{initializer of 'D2' is not a constant expression}}
+  static_assert(D2.j == 3, ""); // ref-error {{not an integral constant expression}} \
+                                // ref-note {{initializer of 'D2' is not a constant expression}}
+
+}
+}
+#endif


        


More information about the cfe-commits mailing list