[clang] d53515a - [clang][Interp] Fix variadic member functions

Timm Bäder via cfe-commits cfe-commits at lists.llvm.org
Wed Feb 14 21:32:13 PST 2024


Author: Timm Bäder
Date: 2024-02-15T05:59:53+01:00
New Revision: d53515afef57a3abf84daff169fbc7626a306817

URL: https://github.com/llvm/llvm-project/commit/d53515afef57a3abf84daff169fbc7626a306817
DIFF: https://github.com/llvm/llvm-project/commit/d53515afef57a3abf84daff169fbc7626a306817.diff

LOG: [clang][Interp] Fix variadic member functions

For variadic member functions, the way we calculated the instance
pointer and RVO pointer offsets on the stack was incorrect, because
Func->getArgSize() does not include the size of the variadic
arguments actually passed at the call site. When calling variadic
functions, we therefore need to pass the size of the passed
(variadic) arguments to the Call* ops, so they can use that
information to correctly locate and check the instance pointer, etc.

This patch adds a bit of code duplication in Interp.h, which I
will get rid of in later cleanup NFC patches.

Added: 
    

Modified: 
    clang/lib/AST/Interp/ByteCodeExprGen.cpp
    clang/lib/AST/Interp/ByteCodeStmtGen.cpp
    clang/lib/AST/Interp/Context.cpp
    clang/lib/AST/Interp/EvalEmitter.cpp
    clang/lib/AST/Interp/Function.h
    clang/lib/AST/Interp/Interp.cpp
    clang/lib/AST/Interp/Interp.h
    clang/lib/AST/Interp/InterpFrame.cpp
    clang/lib/AST/Interp/InterpFrame.h
    clang/lib/AST/Interp/Opcodes.td
    clang/test/AST/Interp/functions.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index 91b9985eefbd30..988765972a36e6 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -1829,8 +1829,19 @@ bool ByteCodeExprGen<Emitter>::VisitCXXConstructExpr(
         return false;
     }
 
-    if (!this->emitCall(Func, E))
-      return false;
+    if (Func->isVariadic()) {
+      uint32_t VarArgSize = 0;
+      unsigned NumParams = Func->getNumWrittenParams();
+      for (unsigned I = NumParams, N = E->getNumArgs(); I != N; ++I) {
+        VarArgSize +=
+            align(primSize(classify(E->getArg(I)->getType()).value_or(PT_Ptr)));
+      }
+      if (!this->emitCallVar(Func, VarArgSize, E))
+        return false;
+    } else {
+      if (!this->emitCall(Func, 0, E))
+        return false;
+    }
 
     // Immediately call the destructor if we have to.
     if (DiscardResult) {
@@ -1863,7 +1874,7 @@ bool ByteCodeExprGen<Emitter>::VisitCXXConstructExpr(
           return false;
       }
 
-      if (!this->emitCall(Func, E))
+      if (!this->emitCall(Func, 0, E))
         return false;
     }
     return true;
@@ -2049,7 +2060,7 @@ bool ByteCodeExprGen<Emitter>::VisitCXXInheritedCtorInitExpr(
     Offset += align(primSize(PT));
   }
 
-  return this->emitCall(F, E);
+  return this->emitCall(F, 0, E);
 }
 
 template <class Emitter>
@@ -2846,20 +2857,38 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
     // and if the function has RVO, we already have the pointer on the stack to
     // write the result into.
     if (IsVirtual && !HasQualifier) {
-      if (!this->emitCallVirt(Func, E))
+      uint32_t VarArgSize = 0;
+      unsigned NumParams = Func->getNumWrittenParams();
+      for (unsigned I = NumParams, N = E->getNumArgs(); I != N; ++I)
+        VarArgSize += align(primSize(classify(E->getArg(I)).value_or(PT_Ptr)));
+
+      if (!this->emitCallVirt(Func, VarArgSize, E))
+        return false;
+    } else if (Func->isVariadic()) {
+      uint32_t VarArgSize = 0;
+      unsigned NumParams = Func->getNumWrittenParams();
+      for (unsigned I = NumParams, N = E->getNumArgs(); I != N; ++I)
+        VarArgSize += align(primSize(classify(E->getArg(I)).value_or(PT_Ptr)));
+      if (!this->emitCallVar(Func, VarArgSize, E))
         return false;
     } else {
-      if (!this->emitCall(Func, E))
+      if (!this->emitCall(Func, 0, E))
         return false;
     }
   } else {
     // Indirect call. Visit the callee, which will leave a FunctionPointer on
     // the stack. Cleanup of the returned value if necessary will be done after
     // the function call completed.
+
+    // Sum the size of all args from the call expr.
+    uint32_t ArgSize = 0;
+    for (unsigned I = 0, N = E->getNumArgs(); I != N; ++I)
+      ArgSize += align(primSize(classify(E->getArg(I)).value_or(PT_Ptr)));
+
     if (!this->visit(E->getCallee()))
       return false;
 
-    if (!this->emitCallPtr(E))
+    if (!this->emitCallPtr(ArgSize, E))
       return false;
   }
 
@@ -3386,7 +3415,7 @@ bool ByteCodeExprGen<Emitter>::emitRecordDestruction(const Descriptor *Desc) {
       assert(DtorFunc->getNumParams() == 1);
       if (!this->emitDupPtr(SourceInfo{}))
         return false;
-      if (!this->emitCall(DtorFunc, SourceInfo{}))
+      if (!this->emitCall(DtorFunc, 0, SourceInfo{}))
         return false;
     }
   }

diff  --git a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
index bedcc78dc23555..7e2043f8de90b0 100644
--- a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
@@ -126,7 +126,7 @@ bool ByteCodeStmtGen<Emitter>::emitLambdaStaticInvokerBody(
       return false;
   }
 
-  if (!this->emitCall(Func, LambdaCallOp))
+  if (!this->emitCall(Func, 0, LambdaCallOp))
     return false;
 
   this->emitCleanup();

diff  --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp
index 5f5a6622f10f3d..7396db22943663 100644
--- a/clang/lib/AST/Interp/Context.cpp
+++ b/clang/lib/AST/Interp/Context.cpp
@@ -209,7 +209,8 @@ bool Context::Run(State &Parent, const Function *Func, APValue &Result) {
 
   {
     InterpState State(Parent, *P, Stk, *this);
-    State.Current = new InterpFrame(State, Func, /*Caller=*/nullptr, {});
+    State.Current = new InterpFrame(State, Func, /*Caller=*/nullptr, CodePtr(),
+                                    Func->getArgSize());
     if (Interpret(State, Result)) {
       assert(Stk.empty());
       return true;

diff  --git a/clang/lib/AST/Interp/EvalEmitter.cpp b/clang/lib/AST/Interp/EvalEmitter.cpp
index 945b78d7a609d7..c1e4ce3ebb0729 100644
--- a/clang/lib/AST/Interp/EvalEmitter.cpp
+++ b/clang/lib/AST/Interp/EvalEmitter.cpp
@@ -22,7 +22,7 @@ EvalEmitter::EvalEmitter(Context &Ctx, Program &P, State &Parent,
     : Ctx(Ctx), P(P), S(Parent, P, Stk, Ctx, this), EvalResult(&Ctx) {
   // Create a dummy frame for the interpreter which does not have locals.
   S.Current =
-      new InterpFrame(S, /*Func=*/nullptr, /*Caller=*/nullptr, CodePtr());
+      new InterpFrame(S, /*Func=*/nullptr, /*Caller=*/nullptr, CodePtr(), 0);
 }
 
 EvalEmitter::~EvalEmitter() {

diff  --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h
index 7c3e0f63024908..6500e0126c226f 100644
--- a/clang/lib/AST/Interp/Function.h
+++ b/clang/lib/AST/Interp/Function.h
@@ -183,6 +183,16 @@ class Function final {
 
   unsigned getNumParams() const { return ParamTypes.size(); }
 
+  /// Returns the number of parameter this function takes when it's called,
+  /// i.e excluding the instance pointer and the RVO pointer.
+  unsigned getNumWrittenParams() const {
+    assert(getNumParams() >= (hasThisPointer() + hasRVO()));
+    return getNumParams() - hasThisPointer() - hasRVO();
+  }
+  unsigned getWrittenArgSize() const {
+    return ArgSize - (align(primSize(PT_Ptr)) * (hasThisPointer() + hasRVO()));
+  }
+
   unsigned getParamOffset(unsigned ParamIndex) const {
     return ParamOffsets[ParamIndex];
   }

diff  --git a/clang/lib/AST/Interp/Interp.cpp b/clang/lib/AST/Interp/Interp.cpp
index 683151f7caf528..2338f88569db8b 100644
--- a/clang/lib/AST/Interp/Interp.cpp
+++ b/clang/lib/AST/Interp/Interp.cpp
@@ -169,16 +169,27 @@ void cleanupAfterFunctionCall(InterpState &S, CodePtr OpPC) {
     // CallExpr we're look for is at the return PC of the current function, i.e.
     // in the caller.
     // This code path should be executed very rarely.
-    const auto *CE =
-        cast<CallExpr>(S.Current->Caller->getExpr(S.Current->getRetPC()));
-    unsigned FixedParams = CurFunc->getNumParams();
-    int32_t ArgsToPop = CE->getNumArgs() - FixedParams;
-    assert(ArgsToPop >= 0);
-    for (int32_t I = ArgsToPop - 1; I >= 0; --I) {
-      const Expr *A = CE->getArg(FixedParams + I);
+    unsigned NumVarArgs;
+    const Expr *const *Args = nullptr;
+    unsigned NumArgs = 0;
+    const Expr *CallSite = S.Current->Caller->getExpr(S.Current->getRetPC());
+    if (const auto *CE = dyn_cast<CallExpr>(CallSite)) {
+      Args = CE->getArgs();
+      NumArgs = CE->getNumArgs();
+    } else if (const auto *CE = dyn_cast<CXXConstructExpr>(CallSite)) {
+      Args = CE->getArgs();
+      NumArgs = CE->getNumArgs();
+    } else
+      assert(false && "Can't get arguments from that expression type");
+
+    assert(NumArgs >= CurFunc->getNumWrittenParams());
+    NumVarArgs = NumArgs - CurFunc->getNumWrittenParams();
+    for (unsigned I = 0; I != NumVarArgs; ++I) {
+      const Expr *A = Args[NumArgs - 1 - I];
       popArg(S, A);
     }
   }
+
   // And in any case, remove the fixed parameters (the non-variadic ones)
   // at the end.
   S.Current->popArgs();

diff  --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index e2fda18e3f44d4..77c724f08e8eef 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -1915,10 +1915,60 @@ inline bool ArrayDecay(InterpState &S, CodePtr OpPC) {
   return false;
 }
 
-inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func) {
+inline bool CallVar(InterpState &S, CodePtr OpPC, const Function *Func,
+                    uint32_t VarArgSize) {
   if (Func->hasThisPointer()) {
-    size_t ThisOffset =
-        Func->getArgSize() - (Func->hasRVO() ? primSize(PT_Ptr) : 0);
+    size_t ArgSize = Func->getArgSize() + VarArgSize;
+    size_t ThisOffset = ArgSize - (Func->hasRVO() ? primSize(PT_Ptr) : 0);
+    const Pointer &ThisPtr = S.Stk.peek<Pointer>(ThisOffset);
+
+    // If the current function is a lambda static invoker and
+    // the function we're about to call is a lambda call operator,
+    // skip the CheckInvoke, since the ThisPtr is a null pointer
+    // anyway.
+    if (!(S.Current->getFunction() &&
+          S.Current->getFunction()->isLambdaStaticInvoker() &&
+          Func->isLambdaCallOperator())) {
+      if (!CheckInvoke(S, OpPC, ThisPtr))
+        return false;
+    }
+
+    if (S.checkingPotentialConstantExpression())
+      return false;
+  }
+
+  if (!CheckCallable(S, OpPC, Func))
+    return false;
+
+  if (!CheckCallDepth(S, OpPC))
+    return false;
+
+  auto NewFrame = std::make_unique<InterpFrame>(S, Func, OpPC, VarArgSize);
+  InterpFrame *FrameBefore = S.Current;
+  S.Current = NewFrame.get();
+
+  APValue CallResult;
+  // Note that we cannot assert(CallResult.hasValue()) here since
+  // Ret() above only sets the APValue if the curent frame doesn't
+  // have a caller set.
+  if (Interpret(S, CallResult)) {
+    NewFrame.release(); // Frame was delete'd already.
+    assert(S.Current == FrameBefore);
+    return true;
+  }
+
+  // Interpreting the function failed somehow. Reset to
+  // previous state.
+  S.Current = FrameBefore;
+  return false;
+
+  return false;
+}
+inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func,
+                 uint32_t VarArgSize) {
+  if (Func->hasThisPointer()) {
+    size_t ArgSize = Func->getArgSize() + VarArgSize;
+    size_t ThisOffset = ArgSize - (Func->hasRVO() ? primSize(PT_Ptr) : 0);
 
     const Pointer &ThisPtr = S.Stk.peek<Pointer>(ThisOffset);
 
@@ -1943,7 +1993,7 @@ inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func) {
   if (!CheckCallDepth(S, OpPC))
     return false;
 
-  auto NewFrame = std::make_unique<InterpFrame>(S, Func, OpPC);
+  auto NewFrame = std::make_unique<InterpFrame>(S, Func, OpPC, VarArgSize);
   InterpFrame *FrameBefore = S.Current;
   S.Current = NewFrame.get();
 
@@ -1963,11 +2013,12 @@ inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func) {
   return false;
 }
 
-inline bool CallVirt(InterpState &S, CodePtr OpPC, const Function *Func) {
+inline bool CallVirt(InterpState &S, CodePtr OpPC, const Function *Func,
+                     uint32_t VarArgSize) {
   assert(Func->hasThisPointer());
   assert(Func->isVirtual());
-  size_t ThisOffset =
-      Func->getArgSize() - (Func->hasRVO() ? primSize(PT_Ptr) : 0);
+  size_t ArgSize = Func->getArgSize() + VarArgSize;
+  size_t ThisOffset = ArgSize - (Func->hasRVO() ? primSize(PT_Ptr) : 0);
   Pointer &ThisPtr = S.Stk.peek<Pointer>(ThisOffset);
 
   const CXXRecordDecl *DynamicDecl =
@@ -1998,7 +2049,7 @@ inline bool CallVirt(InterpState &S, CodePtr OpPC, const Function *Func) {
     }
   }
 
-  return Call(S, OpPC, Func);
+  return Call(S, OpPC, Func, VarArgSize);
 }
 
 inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func,
@@ -2016,17 +2067,20 @@ inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func,
   return false;
 }
 
-inline bool CallPtr(InterpState &S, CodePtr OpPC) {
+inline bool CallPtr(InterpState &S, CodePtr OpPC, uint32_t ArgSize) {
   const FunctionPointer &FuncPtr = S.Stk.pop<FunctionPointer>();
 
   const Function *F = FuncPtr.getFunction();
   if (!F || !F->isConstexpr())
     return false;
 
+  assert(ArgSize >= F->getWrittenArgSize());
+  uint32_t VarArgSize = ArgSize - F->getWrittenArgSize();
+
   if (F->isVirtual())
-    return CallVirt(S, OpPC, F);
+    return CallVirt(S, OpPC, F, VarArgSize);
 
-  return Call(S, OpPC, F);
+  return Call(S, OpPC, F, VarArgSize);
 }
 
 inline bool GetFnPtr(InterpState &S, CodePtr OpPC, const Function *Func) {

diff  --git a/clang/lib/AST/Interp/InterpFrame.cpp b/clang/lib/AST/Interp/InterpFrame.cpp
index bf2cca733b66bb..f69ff06b5e81b5 100644
--- a/clang/lib/AST/Interp/InterpFrame.cpp
+++ b/clang/lib/AST/Interp/InterpFrame.cpp
@@ -22,10 +22,10 @@ using namespace clang;
 using namespace clang::interp;
 
 InterpFrame::InterpFrame(InterpState &S, const Function *Func,
-                         InterpFrame *Caller, CodePtr RetPC)
+                         InterpFrame *Caller, CodePtr RetPC, unsigned ArgSize)
     : Caller(Caller), S(S), Depth(Caller ? Caller->Depth + 1 : 0), Func(Func),
-      RetPC(RetPC), ArgSize(Func ? Func->getArgSize() : 0),
-      Args(static_cast<char *>(S.Stk.top())), FrameOffset(S.Stk.size()) {
+      RetPC(RetPC), ArgSize(ArgSize), Args(static_cast<char *>(S.Stk.top())),
+      FrameOffset(S.Stk.size()) {
   if (!Func)
     return;
 
@@ -43,8 +43,9 @@ InterpFrame::InterpFrame(InterpState &S, const Function *Func,
   }
 }
 
-InterpFrame::InterpFrame(InterpState &S, const Function *Func, CodePtr RetPC)
-    : InterpFrame(S, Func, S.Current, RetPC) {
+InterpFrame::InterpFrame(InterpState &S, const Function *Func, CodePtr RetPC,
+                         unsigned VarArgSize)
+    : InterpFrame(S, Func, S.Current, RetPC, Func->getArgSize() + VarArgSize) {
   // As per our calling convention, the this pointer is
   // part of the ArgSize.
   // If the function has RVO, the RVO pointer is first.

diff  --git a/clang/lib/AST/Interp/InterpFrame.h b/clang/lib/AST/Interp/InterpFrame.h
index cba4f9560bf56a..322d5dcfa698ae 100644
--- a/clang/lib/AST/Interp/InterpFrame.h
+++ b/clang/lib/AST/Interp/InterpFrame.h
@@ -32,13 +32,14 @@ class InterpFrame final : public Frame {
 
   /// Creates a new frame for a method call.
   InterpFrame(InterpState &S, const Function *Func, InterpFrame *Caller,
-              CodePtr RetPC);
+              CodePtr RetPC, unsigned ArgSize);
 
   /// Creates a new frame with the values that make sense.
   /// I.e., the caller is the current frame of S,
   /// the This() pointer is the current Pointer on the top of S's stack,
   /// and the RVO pointer is before that.
-  InterpFrame(InterpState &S, const Function *Func, CodePtr RetPC);
+  InterpFrame(InterpState &S, const Function *Func, CodePtr RetPC,
+              unsigned VarArgSize = 0);
 
   /// Destroys the frame, killing all live pointers to stack slots.
   ~InterpFrame();

diff  --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index 7f5bd7e5b44bca..f1b08944a8812e 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -191,12 +191,12 @@ def NoRet : Opcode {}
 
 
 def Call : Opcode {
-  let Args = [ArgFunction];
+  let Args = [ArgFunction, ArgUint32];
   let Types = [];
 }
 
 def CallVirt : Opcode {
-  let Args = [ArgFunction];
+  let Args = [ArgFunction, ArgUint32];
   let Types = [];
 }
 
@@ -206,7 +206,12 @@ def CallBI : Opcode {
 }
 
 def CallPtr : Opcode {
-  let Args = [];
+  let Args = [ArgUint32];
+  let Types = [];
+}
+
+def CallVar : Opcode {
+  let Args = [ArgFunction, ArgUint32];
   let Types = [];
 }
 

diff  --git a/clang/test/AST/Interp/functions.cpp b/clang/test/AST/Interp/functions.cpp
index 6e995ce704e394..7b8278cf13aa88 100644
--- a/clang/test/AST/Interp/functions.cpp
+++ b/clang/test/AST/Interp/functions.cpp
@@ -381,6 +381,81 @@ namespace Variadic {
 
   constexpr int (*VFP)(...) = variadic_function2;
   static_assert(VFP() == 12, "");
+
+  /// Member functions
+  struct Foo {
+    int a = 0;
+    constexpr void bla(...) {}
+    constexpr S bla2(...) {
+      return S{12, true};
+    }
+    constexpr Foo(...) : a(1337) {}
+    constexpr Foo(void *c, bool b, void*p, ...) : a('a' + b) {}
+    constexpr Foo(int a, const S* s, ...) : a(a) {}
+  };
+
+  constexpr int foo2() {
+    Foo f(1, nullptr);
+    auto s = f.bla2(1, 2, S{1, false});
+    return s.a + s.b;
+  }
+  static_assert(foo2() == 13, "");
+
+  constexpr Foo _f = 123;
+  static_assert(_f.a == 1337, "");
+
+  constexpr Foo __f(nullptr, false, nullptr, nullptr, 'a', Foo());
+  static_assert(__f.a ==  'a', "");
+
+
+#if __cplusplus >= 202002L
+namespace VariadicVirtual {
+  class A {
+  public:
+    constexpr virtual void foo(int &a, ...) {
+      a = 1;
+    }
+  };
+
+  class B : public A {
+  public:
+    constexpr void foo(int &a, ...) override {
+      a = 2;
+    }
+  };
+
+  constexpr int foo() {
+    B b;
+    int a;
+    b.foo(a, 1,2,nullptr);
+    return a;
+  }
+  static_assert(foo() == 2, "");
+} // VariadicVirtual
+
+namespace VariadicQualified {
+  class A {
+      public:
+      constexpr virtual int foo(...) const {
+          return 5;
+      }
+  };
+  class B : public A {};
+  class C : public B {
+      public:
+      constexpr int foo(...) const override {
+          return B::foo(1,2,3); // B doesn't have a foo(), so this should call A::foo().
+      }
+      constexpr int foo2() const {
+        return this->A::foo(1,2,3,this);
+      }
+  };
+  constexpr C c;
+  static_assert(c.foo() == 5);
+  static_assert(c.foo2() == 5);
+} // VariadicQualified
+#endif
+
 }
 
 namespace Packs {


        


More information about the cfe-commits mailing list