[clang] 6dfe555 - [clang][Interp] Rework initializers

Timm Bäder via cfe-commits cfe-commits at lists.llvm.org
Sun Aug 20 04:33:27 PDT 2023


Author: Timm Bäder
Date: 2023-08-20T13:33:08+02:00
New Revision: 6dfe55569d88ff654d13e6c09267eff0cd9c9f0d

URL: https://github.com/llvm/llvm-project/commit/6dfe55569d88ff654d13e6c09267eff0cd9c9f0d
DIFF: https://github.com/llvm/llvm-project/commit/6dfe55569d88ff654d13e6c09267eff0cd9c9f0d.diff

LOG: [clang][Interp] Rework initializers

Before this patch, we had visitRecordInitializer() and
visitArrayInitializer(), which were different from the regular visit()
in that they expected a pointer on the top of the stack, which they
initialized. For example, visitArrayInitializer handled InitListExprs by
looping over the members and initializing the elements of that pointer.

However, this had a few corner cases and problems. For example, in
visitLambdaExpr() (a lambda is always of record type), it was unclear
whether the lambda should always be saved to a newly created local
variable or not. This ambiguity is why https://reviews.llvm.org/D153616
reworked that code.

This patch changes the visiting functions to:

 - visit(): Always leaves a new value on the stack. If the expression
   can be mapped to a primitive type, it's just visited and the value is
   put on the stack. If it's of composite type, this function will
   create a local variable for the expression value and call
   visitInitializer(). The pointer to the local variable will stay on
   the stack.

 - visitInitializer(): Visits the given expression, assuming there is a
   pointer on top of the stack that will be initialized by it.

 - discard(): Visit the expression for side-effects, but don't leave a
   value on the stack.

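It also adds an Initializing flag to distinguish the initializing case
from the non-initializing one, and a delegate() helper that visits a
subexpression with the current DiscardResult and Initializing flags
unchanged.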

Differential Revision: https://reviews.llvm.org/D156027

Added: 
    

Modified: 
    clang/lib/AST/Interp/ByteCodeExprGen.cpp
    clang/lib/AST/Interp/ByteCodeExprGen.h
    clang/lib/AST/Interp/Context.cpp
    clang/test/AST/Interp/lambda.cpp
    clang/test/AST/Interp/records.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index 94eb1998839f4a..d8a4ca0db12fc8 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -43,18 +43,25 @@ template <class Emitter> class DeclScope final : public VariableScope<Emitter> {
 template <class Emitter> class OptionScope final {
 public:
   /// Root constructor, compiling or discarding primitives.
-  OptionScope(ByteCodeExprGen<Emitter> *Ctx, bool NewDiscardResult)
-      : Ctx(Ctx), OldDiscardResult(Ctx->DiscardResult) {
+  OptionScope(ByteCodeExprGen<Emitter> *Ctx, bool NewDiscardResult,
+              bool NewInitializing)
+      : Ctx(Ctx), OldDiscardResult(Ctx->DiscardResult),
+        OldInitializing(Ctx->Initializing) {
     Ctx->DiscardResult = NewDiscardResult;
+    Ctx->Initializing = NewInitializing;
   }
 
-  ~OptionScope() { Ctx->DiscardResult = OldDiscardResult; }
+  ~OptionScope() {
+    Ctx->DiscardResult = OldDiscardResult;
+    Ctx->Initializing = OldInitializing;
+  }
 
 private:
   /// Parent context.
   ByteCodeExprGen<Emitter> *Ctx;
   /// Old discard flag to restore.
   bool OldDiscardResult;
+  bool OldInitializing; ///< Old initializing flag to restore.
 };
 
 } // namespace interp
@@ -144,9 +151,7 @@ bool ByteCodeExprGen<Emitter>::VisitCastExpr(const CastExpr *CE) {
   case CK_NoOp:
   case CK_UserDefinedConversion:
   case CK_BitCast:
-    if (DiscardResult)
-      return this->discard(SubExpr);
-    return this->visit(SubExpr);
+    return this->delegate(SubExpr);
 
   case CK_IntegralToBoolean:
   case CK_IntegralCast: {
@@ -245,7 +250,8 @@ bool ByteCodeExprGen<Emitter>::VisitBinaryOperator(const BinaryOperator *BO) {
       return this->discard(RHS);
 
     // Otherwise, visit RHS and optionally discard its value.
-    return Discard(this->visit(RHS));
+    return Discard(Initializing ? this->visitInitializer(RHS)
+                                : this->visit(RHS));
   }
 
   if (!LT || !RT || !T)
@@ -438,12 +444,38 @@ bool ByteCodeExprGen<Emitter>::VisitLogicalBinOp(const BinaryOperator *E) {
 
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitImplicitValueInitExpr(const ImplicitValueInitExpr *E) {
-  std::optional<PrimType> T = classify(E);
+  QualType QT = E->getType();
 
-  if (!T)
+  if (classify(QT))
+    return this->visitZeroInitializer(QT, E);
+
+  if (QT->isRecordType())
     return false;
 
-  return this->visitZeroInitializer(E->getType(), E);
+  if (QT->isArrayType()) {
+    const ArrayType *AT = QT->getAsArrayTypeUnsafe();
+    assert(AT);
+    const auto *CAT = cast<ConstantArrayType>(AT);
+    size_t NumElems = CAT->getSize().getZExtValue();
+
+    std::optional<PrimType> ElemT = classify(CAT->getElementType());
+    // Like the record case above, composite elements are not supported yet;
+    // fail instead of asserting (which returned true in NDEBUG builds).
+    if (!ElemT)
+      return false;
+    // TODO(perf): For int and bool types, we can probably just skip this
+    //   since we memset our Block*s to 0 and so we have the desired value
+    //   without this.
+    for (size_t I = 0; I != NumElems; ++I) {
+      if (!this->visitZeroInitializer(CAT->getElementType(), E))
+        return false;
+      if (!this->emitInitElem(*ElemT, I, E))
+        return false;
+    }
+    return true;
+  }
+
+  return false;
 }
 
 template <class Emitter>
@@ -469,31 +501,116 @@ bool ByteCodeExprGen<Emitter>::VisitArraySubscriptExpr(
 
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitInitListExpr(const InitListExpr *E) {
-  for (const Expr *Init : E->inits()) {
-    if (DiscardResult) {
+  // Handle discarding first.
+  if (DiscardResult) {
+    for (const Expr *Init : E->inits()) {
       if (!this->discard(Init))
         return false;
-    } else {
-      if (!this->visit(Init))
+    }
+    return true;
+  }
+
+  // Primitive values.
+  if (classify(E->getType())) {
+    assert(E->getNumInits() == 1);
+    assert(!DiscardResult);
+    return this->delegate(E->inits()[0]);
+  }
+  // Composite values: the pointer to initialize is on top of the stack.
+  QualType T = E->getType();
+  if (T->isRecordType()) {
+    const Record *R = getRecord(T);
+
+    unsigned InitIndex = 0;
+    for (const Expr *Init : E->inits()) {
+      if (!this->emitDupPtr(E))
         return false;
+
+      if (std::optional<PrimType> FieldT = classify(Init)) {
+        const Record::Field *FieldToInit = R->getField(InitIndex);
+        if (!this->visit(Init))
+          return false;
+
+        if (!this->emitInitField(*FieldT, FieldToInit->Offset, E))
+          return false;
+
+        if (!this->emitPopPtr(E))
+          return false;
+        ++InitIndex;
+      } else {
+        // Initializer for a direct base class.
+        if (const Record::Base *B = R->getBase(Init->getType())) {
+          if (!this->emitGetPtrBasePop(B->Offset, Init))
+            return false;
+
+          if (!this->visitInitializer(Init))
+            return false;
+
+          if (!this->emitPopPtr(E))
+            return false;
+          // Base initializers don't increase InitIndex, since they don't count
+          // into the Record's fields.
+        } else {
+          const Record::Field *FieldToInit = R->getField(InitIndex);
+          // Non-primitive case. Get a pointer to the field-to-initialize
+          // on the stack and recurse into visitInitializer().
+          if (!this->emitGetPtrField(FieldToInit->Offset, Init))
+            return false;
+
+          if (!this->visitInitializer(Init))
+            return false;
+
+          if (!this->emitPopPtr(E))
+            return false;
+          ++InitIndex;
+        }
+      }
     }
+    return true;
   }
-  return true;
+
+  if (T->isArrayType()) {
+    // FIXME: Array fillers (InitListExpr::getArrayFiller()) are not handled.
+    unsigned ElementIndex = 0;
+    for (const Expr *Init : E->inits()) {
+      if (std::optional<PrimType> ElemT = classify(Init->getType())) {
+        // Visit the primitive element like normal.
+        if (!this->visit(Init))
+          return false;
+        if (!this->emitInitElem(*ElemT, ElementIndex, Init))
+          return false;
+      } else {
+        // Get a pointer to element ElementIndex of the array pointer on the
+        // stack and initialize that element in place.
+        if (!this->emitConstUint32(ElementIndex, Init))
+          return false;
+        if (!this->emitArrayElemPtrUint32(Init))
+          return false;
+        if (!this->visitInitializer(Init))
+          return false;
+        if (!this->emitPopPtr(Init))
+          return false;
+      }
+
+      ++ElementIndex;
+    }
+    return true;
+  }
+
+  return false;
 }
 
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitSubstNonTypeTemplateParmExpr(
     const SubstNonTypeTemplateParmExpr *E) {
-  if (DiscardResult)
-    return this->discard(E->getReplacement());
-  return this->visit(E->getReplacement());
+  return this->delegate(E->getReplacement());
 }
 
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitConstantExpr(const ConstantExpr *E) {
   // TODO: Check if the ConstantExpr already has a value set and if so,
   //   use that instead of evaluating it again.
-  return this->visit(E->getSubExpr());
+  return this->delegate(E->getSubExpr());
 }
 
 static CharUnits AlignOfType(QualType T, const ASTContext &ASTCtx,
@@ -613,25 +730,129 @@ bool ByteCodeExprGen<Emitter>::VisitArrayInitIndexExpr(
   return this->emitConst(*ArrayIndex, E);
 }
 
+template <class Emitter>
+bool ByteCodeExprGen<Emitter>::VisitArrayInitLoopExpr(
+    const ArrayInitLoopExpr *E) {
+  assert(Initializing);
+  assert(!DiscardResult);
+  // TODO: This compiles to quite a lot of bytecode if the array is larger.
+  //   Investigate compiling this to a loop, or at least try to use
+  //   the AILE's Common expr.
+  const Expr *SubExpr = E->getSubExpr();
+  size_t Size = E->getArraySize().getZExtValue();
+  std::optional<PrimType> ElemT = classify(SubExpr->getType());
+
+  // So, every iteration, we execute an assignment here
+  // where the LHS is on the stack (the target array)
+  // and the RHS is our SubExpr.
+  for (size_t I = 0; I != Size; ++I) {
+    ArrayIndexScope<Emitter> IndexScope(this, I);
+
+    if (ElemT) {
+      if (!this->visit(SubExpr))
+        return false;
+      if (!this->emitInitElem(*ElemT, I, E))
+        return false;
+    } else {
+      // Get to our array element and recurse into visitInitializer()
+      if (!this->emitConstUint64(I, SubExpr))
+        return false;
+      if (!this->emitArrayElemPtrUint64(SubExpr))
+        return false;
+      if (!visitInitializer(SubExpr))
+        return false;
+      if (!this->emitPopPtr(E))
+        return false;
+    }
+  }
+  return true;
+}
+
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitOpaqueValueExpr(const OpaqueValueExpr *E) {
+  if (Initializing)
+    return this->visitInitializer(E->getSourceExpr());
   return this->visit(E->getSourceExpr());
 }
 
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitAbstractConditionalOperator(
     const AbstractConditionalOperator *E) {
-  return this->visitConditional(E, [this](const Expr *E) {
-    return DiscardResult ? this->discard(E) : this->visit(E);
-  });
+  const Expr *Condition = E->getCond();
+  const Expr *TrueExpr = E->getTrueExpr();
+  const Expr *FalseExpr = E->getFalseExpr();
+
+  LabelTy LabelEnd = this->getLabel();   // Label after the operator.
+  LabelTy LabelFalse = this->getLabel(); // Label for the false expr.
+
+  if (!this->visit(Condition))
+    return false;
+
+  // C special case: Convert to bool because our jump ops need that.
+  // TODO: We probably want this to be done in visitBool().
+  if (std::optional<PrimType> CondT = classify(Condition->getType());
+      CondT && CondT != PT_Bool) {
+    if (!this->emitCast(*CondT, PT_Bool, E))
+      return false;
+  }
+
+  if (!this->jumpFalse(LabelFalse))
+    return false;
+
+  if (!this->delegate(TrueExpr))
+    return false;
+  if (!this->jump(LabelEnd))
+    return false;
+
+  this->emitLabel(LabelFalse);
+
+  if (!this->delegate(FalseExpr))
+    return false;
+
+  this->fallthrough(LabelEnd);
+  this->emitLabel(LabelEnd);
+
+  return true;
 }
 
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitStringLiteral(const StringLiteral *E) {
   if (DiscardResult)
     return true;
-  unsigned StringIndex = P.createGlobalString(E);
-  return this->emitGetPtrGlobal(StringIndex, E);
+
+  if (!Initializing) {
+    unsigned StringIndex = P.createGlobalString(E);
+    return this->emitGetPtrGlobal(StringIndex, E);
+  }
+
+  // Initializing: a pointer to the target char array is on the stack.
+  const ConstantArrayType *CAT =
+      Ctx.getASTContext().getAsConstantArrayType(E->getType());
+  assert(CAT && "a string literal that's not a constant array?");
+
+  // If the initializer string is too long, a diagnostic has already been
+  // emitted. Read only the array length from the string literal.
+  unsigned N =
+      std::min(unsigned(CAT->getSize().getZExtValue()), E->getLength());
+  size_t CharWidth = E->getCharByteWidth();
+
+  for (unsigned I = 0; I != N; ++I) {
+    uint32_t CodeUnit = E->getCodeUnit(I);
+    // Propagate emitter failures instead of silently ignoring them.
+    if (CharWidth == 1) {
+      if (!this->emitConstSint8(CodeUnit, E) || !this->emitInitElemSint8(I, E))
+        return false;
+    } else if (CharWidth == 2) {
+      if (!this->emitConstUint16(CodeUnit, E) || !this->emitInitElemUint16(I, E))
+        return false;
+    } else if (CharWidth == 4) {
+      if (!this->emitConstUint32(CodeUnit, E) || !this->emitInitElemUint32(I, E))
+        return false;
+    } else {
+      llvm_unreachable("unsupported character width");
+    }
+  }
+  return true;
 }
 
 template <class Emitter>
@@ -873,23 +1094,25 @@ bool ByteCodeExprGen<Emitter>::VisitExprWithCleanups(
   const Expr *SubExpr = E->getSubExpr();
 
   assert(E->getNumObjects() == 0 && "TODO: Implement cleanups");
-  if (DiscardResult)
-    return this->discard(SubExpr);
 
-  return this->visit(SubExpr);
+  return this->delegate(SubExpr);
 }
 
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitMaterializeTemporaryExpr(
     const MaterializeTemporaryExpr *E) {
   const Expr *SubExpr = E->getSubExpr();
-  std::optional<PrimType> SubExprT = classify(SubExpr);
 
+  if (Initializing) {
+    // We already have a value, just initialize that.
+    return this->visitInitializer(SubExpr);
+  }
   // If we don't end up using the materialized temporary anyway, don't
   // bother creating it.
   if (DiscardResult)
     return this->discard(SubExpr);
 
+  std::optional<PrimType> SubExprT = classify(SubExpr);
   if (E->getStorageDuration() == SD_Static) {
     std::optional<unsigned> GlobalIndex = P.createGlobal(E);
     if (!GlobalIndex)
@@ -900,7 +1123,7 @@ bool ByteCodeExprGen<Emitter>::VisitMaterializeTemporaryExpr(
     assert(TempDecl);
 
     if (SubExprT) {
-      if (!this->visitInitializer(SubExpr))
+      if (!this->visit(SubExpr))
         return false;
       if (!this->emitInitGlobalTemp(*SubExprT, *GlobalIndex, TempDecl, E))
         return false;
@@ -919,7 +1142,7 @@ bool ByteCodeExprGen<Emitter>::VisitMaterializeTemporaryExpr(
   if (SubExprT) {
     if (std::optional<unsigned> LocalIndex = allocateLocalPrimitive(
             SubExpr, *SubExprT, /*IsConst=*/true, /*IsExtended=*/true)) {
-      if (!this->visitInitializer(SubExpr))
+      if (!this->visit(SubExpr))
         return false;
       this->emitSetLocal(*SubExprT, *LocalIndex, E);
       return this->emitGetPtrLocal(*LocalIndex, E);
@@ -938,26 +1161,21 @@ bool ByteCodeExprGen<Emitter>::VisitMaterializeTemporaryExpr(
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitCXXBindTemporaryExpr(
     const CXXBindTemporaryExpr *E) {
-
+  if (Initializing)
+    return this->visitInitializer(E->getSubExpr());
   return this->visit(E->getSubExpr());
 }
 
-template <class Emitter>
-bool ByteCodeExprGen<Emitter>::VisitCXXTemporaryObjectExpr(
-    const CXXTemporaryObjectExpr *E) {
-
-  if (std::optional<unsigned> LocalIndex =
-          allocateLocal(E, /*IsExtended=*/false)) {
-    return this->visitLocalInitializer(E, *LocalIndex);
-  }
-  return false;
-}
-
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitCompoundLiteralExpr(
     const CompoundLiteralExpr *E) {
-  std::optional<PrimType> T = classify(E->getType());
   const Expr *Init = E->getInitializer();
+  if (Initializing) {
+    // We already have a value, just initialize that.
+    return this->visitInitializer(Init);
+  }
+
+  std::optional<PrimType> T = classify(E->getType());
   if (E->isFileScope()) {
     if (std::optional<unsigned> GlobalIndex = P.createGlobal(E)) {
       if (classify(E->getType()))
@@ -971,7 +1189,7 @@ bool ByteCodeExprGen<Emitter>::VisitCompoundLiteralExpr(
   // Otherwise, use a local variable.
   if (T) {
     // For primitive types, we just visit the initializer.
-    return DiscardResult ? this->discard(Init) : this->visit(Init);
+    return this->delegate(Init);
   } else {
     if (std::optional<unsigned> LocalIndex = allocateLocal(Init)) {
       if (!this->emitGetPtrLocal(*LocalIndex, E))
@@ -996,8 +1214,7 @@ bool ByteCodeExprGen<Emitter>::VisitTypeTraitExpr(const TypeTraitExpr *E) {
 
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitLambdaExpr(const LambdaExpr *E) {
-  // XXX We assume here that a pointer-to-initialize is on the stack.
-
+  assert(Initializing);
   const Record *R = P.getOrCreateRecord(E->getLambdaClass());
 
   auto *CaptureInitIt = E->capture_init_begin();
@@ -1036,6 +1253,7 @@ bool ByteCodeExprGen<Emitter>::VisitPredefinedExpr(const PredefinedExpr *E) {
   if (DiscardResult)
     return true;
 
+  assert(!Initializing);
   return this->visit(E->getFunctionName());
 }
 
@@ -1065,74 +1283,151 @@ bool ByteCodeExprGen<Emitter>::VisitCXXNoexceptExpr(const CXXNoexceptExpr *E) {
   return this->emitConstBool(E->getValue(), E);
 }
 
+template <class Emitter>
+bool ByteCodeExprGen<Emitter>::VisitCXXConstructExpr(
+    const CXXConstructExpr *E) {
+  QualType T = E->getType();
+  assert(!classify(T));
+  // Record case: call the constructor on the object to initialize.
+  if (T->isRecordType()) {
+    const Function *Func = getFunction(E->getConstructor());
+
+    if (!Func)
+      return false;
+
+    assert(Func->hasThisPointer());
+    assert(!Func->hasRVO());
+
+    // If we're discarding a construct expression, we still need to allocate
+    // a local variable so the constructor has an object to initialize.
+    if (DiscardResult) {
+      assert(!Initializing);
+      std::optional<unsigned> LocalIndex =
+          allocateLocal(E, /*IsExtended=*/true);
+
+      if (!LocalIndex)
+        return false;
+
+      if (!this->emitGetPtrLocal(*LocalIndex, E))
+        return false;
+    }
+
+    // The This pointer is on the stack (the initialization target, or the
+    // local allocated above); dup() it so the call() below has its own copy.
+    if (!this->emitDupPtr(E))
+      return false;
+
+    // Constructor arguments.
+    for (const auto *Arg : E->arguments()) {
+      if (!this->visit(Arg))
+        return false;
+    }
+
+    if (!this->emitCall(Func, E))
+      return false;
+
+    // Pop the temporary's pointer. NOTE(review): no destructor call is emitted.
+    if (DiscardResult) {
+      if (!this->emitPopPtr(E))
+        return false;
+    }
+    return true;
+  }
+  // Array case: assumes a pointer to the target array is on the stack.
+  if (T->isArrayType()) {
+    const ConstantArrayType *CAT =
+        Ctx.getASTContext().getAsConstantArrayType(E->getType());
+    assert(CAT);
+    size_t NumElems = CAT->getSize().getZExtValue();
+    const Function *Func = getFunction(E->getConstructor());
+    if (!Func || !Func->isConstexpr())
+      return false;
+
+    // FIXME(perf): We're calling the constructor once per array element here,
+    //   in the old interpreter we had a special-case for trivial constructors.
+    for (size_t I = 0; I != NumElems; ++I) {
+      if (!this->emitConstUint64(I, E))
+        return false;
+      if (!this->emitArrayElemPtrUint64(E))
+        return false;
+
+      // Constructor arguments.
+      for (const auto *Arg : E->arguments()) {
+        if (!this->visit(Arg))
+          return false;
+      }
+
+      if (!this->emitCall(Func, E))
+        return false;
+    }
+    return true;
+  }
+
+  return false;
+}
+
 template <class Emitter> bool ByteCodeExprGen<Emitter>::discard(const Expr *E) {
   if (E->containsErrors())
     return false;
 
-  OptionScope<Emitter> Scope(this, /*NewDiscardResult=*/true);
+  OptionScope<Emitter> Scope(this, /*NewDiscardResult=*/true,
+                             /*NewInitializing=*/false);
   return this->Visit(E);
 }
 
 template <class Emitter>
-bool ByteCodeExprGen<Emitter>::visit(const Expr *E) {
+bool ByteCodeExprGen<Emitter>::delegate(const Expr *E) {
   if (E->containsErrors())
     return false;
 
-  OptionScope<Emitter> Scope(this, /*NewDiscardResult=*/false);
+  // Keep the current DiscardResult and Initializing flags, i.e. the
+  // equivalent of OptionScope<Emitter> Scope(this, DiscardResult,
+  // Initializing), which would be a no-op.
   return this->Visit(E);
 }
 
-template <class Emitter>
-bool ByteCodeExprGen<Emitter>::visitBool(const Expr *E) {
-  if (std::optional<PrimType> T = classify(E->getType())) {
-    return visit(E);
-  } else {
-    return this->bail(E);
-  }
-}
-
-/// Visit a conditional operator, i.e. `A ? B : C`.
-/// \V determines what function to call for the B and C expressions.
-template <class Emitter>
-bool ByteCodeExprGen<Emitter>::visitConditional(
-    const AbstractConditionalOperator *E,
-    llvm::function_ref<bool(const Expr *)> V) {
-
-  const Expr *Condition = E->getCond();
-  const Expr *TrueExpr = E->getTrueExpr();
-  const Expr *FalseExpr = E->getFalseExpr();
+template <class Emitter> bool ByteCodeExprGen<Emitter>::visit(const Expr *E) {
+  if (E->containsErrors())
+    return false;
 
-  LabelTy LabelEnd = this->getLabel();   // Label after the operator.
-  LabelTy LabelFalse = this->getLabel(); // Label for the false expr.
+  if (E->getType()->isVoidType())
+    return this->discard(E);
 
-  if (!this->visit(Condition))
-    return false;
+  // Composite rvalue: initialize a new local and leave a pointer to it.
+  if (!E->isGLValue() && !classify(E->getType())) {
+    std::optional<unsigned> LocalIndex = allocateLocal(E, /*IsExtended=*/true);
+    if (!LocalIndex)
+      return false;
 
-  // C special case: Convert to bool because our jump ops need that.
-  // TODO: We probably want this to be done in visitBool().
-  if (std::optional<PrimType> CondT = classify(Condition->getType());
-      CondT && CondT != PT_Bool) {
-    if (!this->emitCast(*CondT, PT_Bool, E))
+    if (!this->emitGetPtrLocal(*LocalIndex, E))
       return false;
+    return this->visitInitializer(E);
   }
 
-  if (!this->jumpFalse(LabelFalse))
-    return false;
-
-  if (!V(TrueExpr))
-    return false;
-  if (!this->jump(LabelEnd))
-    return false;
+  // Otherwise, we have a primitive value (or a glvalue); produce it
+  // directly and push it on the stack.
+  OptionScope<Emitter> Scope(this, /*NewDiscardResult=*/false,
+                             /*NewInitializing=*/false);
+  return this->Visit(E);
+}
 
-  this->emitLabel(LabelFalse);
+template <class Emitter>
+bool ByteCodeExprGen<Emitter>::visitInitializer(const Expr *E) {
+  assert(!classify(E->getType()));
 
-  if (!V(FalseExpr))
+  if (E->containsErrors())
     return false;
 
-  this->fallthrough(LabelEnd);
-  this->emitLabel(LabelEnd);
+  OptionScope<Emitter> Scope(this, /*NewDiscardResult=*/false,
+                             /*NewInitializing=*/true);
+  return this->Visit(E);
+}
 
-  return true;
+template <class Emitter>
+bool ByteCodeExprGen<Emitter>::visitBool(const Expr *E) {
+  if (std::optional<PrimType> T = classify(E->getType()))
+    return visit(E);
+  return this->bail(E);
 }
 
 template <class Emitter>
@@ -1410,270 +1705,6 @@ ByteCodeExprGen<Emitter>::allocateLocal(DeclTy &&Src, bool IsExtended) {
   return Local.Offset;
 }
 
-// NB: When calling this function, we have a pointer to the
-//   array-to-initialize on the stack.
-template <class Emitter>
-bool ByteCodeExprGen<Emitter>::visitArrayInitializer(const Expr *Initializer) {
-  assert(Initializer->getType()->isArrayType());
-
-  // TODO: Fillers?
-  if (const auto *InitList = dyn_cast<InitListExpr>(Initializer)) {
-    unsigned ElementIndex = 0;
-    for (const Expr *Init : InitList->inits()) {
-      if (std::optional<PrimType> T = classify(Init->getType())) {
-        // Visit the primitive element like normal.
-        if (!this->visit(Init))
-          return false;
-        if (!this->emitInitElem(*T, ElementIndex, Init))
-          return false;
-      } else {
-        // Advance the pointer currently on the stack to the given
-        // dimension.
-        if (!this->emitConstUint32(ElementIndex, Init))
-          return false;
-        if (!this->emitArrayElemPtrUint32(Init))
-          return false;
-        if (!visitInitializer(Init))
-          return false;
-        if (!this->emitPopPtr(Init))
-          return false;
-      }
-
-      ++ElementIndex;
-    }
-    return true;
-  } else if (const auto *DIE = dyn_cast<CXXDefaultInitExpr>(Initializer)) {
-    return this->visitInitializer(DIE->getExpr());
-  } else if (const auto *AILE = dyn_cast<ArrayInitLoopExpr>(Initializer)) {
-    // TODO: This compiles to quite a lot of bytecode if the array is larger.
-    //   Investigate compiling this to a loop, or at least try to use
-    //   the AILE's Common expr.
-    const Expr *SubExpr = AILE->getSubExpr();
-    size_t Size = AILE->getArraySize().getZExtValue();
-    std::optional<PrimType> ElemT = classify(SubExpr->getType());
-
-    // So, every iteration, we execute an assignment here
-    // where the LHS is on the stack (the target array)
-    // and the RHS is our SubExpr.
-    for (size_t I = 0; I != Size; ++I) {
-      ArrayIndexScope<Emitter> IndexScope(this, I);
-
-      if (ElemT) {
-        if (!this->visit(SubExpr))
-          return false;
-        if (!this->emitInitElem(*ElemT, I, Initializer))
-          return false;
-      } else {
-        // Get to our array element and recurse into visitInitializer()
-        if (!this->emitConstUint64(I, SubExpr))
-          return false;
-        if (!this->emitArrayElemPtrUint64(SubExpr))
-          return false;
-        if (!visitInitializer(SubExpr))
-          return false;
-        if (!this->emitPopPtr(Initializer))
-          return false;
-      }
-    }
-    return true;
-  } else if (const auto *IVIE = dyn_cast<ImplicitValueInitExpr>(Initializer)) {
-    const ArrayType *AT = IVIE->getType()->getAsArrayTypeUnsafe();
-    assert(AT);
-    const auto *CAT = cast<ConstantArrayType>(AT);
-    size_t NumElems = CAT->getSize().getZExtValue();
-
-    if (std::optional<PrimType> ElemT = classify(CAT->getElementType())) {
-      // TODO(perf): For int and bool types, we can probably just skip this
-      //   since we memset our Block*s to 0 and so we have the desired value
-      //   without this.
-      for (size_t I = 0; I != NumElems; ++I) {
-        if (!this->visitZeroInitializer(CAT->getElementType(), Initializer))
-          return false;
-        if (!this->emitInitElem(*ElemT, I, Initializer))
-          return false;
-      }
-    } else {
-      assert(false && "default initializer for non-primitive type");
-    }
-
-    return true;
-  } else if (const auto *Ctor = dyn_cast<CXXConstructExpr>(Initializer)) {
-    const ConstantArrayType *CAT =
-        Ctx.getASTContext().getAsConstantArrayType(Ctor->getType());
-    assert(CAT);
-    size_t NumElems = CAT->getSize().getZExtValue();
-    const Function *Func = getFunction(Ctor->getConstructor());
-    if (!Func || !Func->isConstexpr())
-      return false;
-
-    // FIXME(perf): We're calling the constructor once per array element here,
-    //   in the old intepreter we had a special-case for trivial constructors.
-    for (size_t I = 0; I != NumElems; ++I) {
-      if (!this->emitConstUint64(I, Initializer))
-        return false;
-      if (!this->emitArrayElemPtrUint64(Initializer))
-        return false;
-
-      // Constructor arguments.
-      for (const auto *Arg : Ctor->arguments()) {
-        if (!this->visit(Arg))
-          return false;
-      }
-
-      if (!this->emitCall(Func, Initializer))
-        return false;
-    }
-    return true;
-  } else if (const auto *SL = dyn_cast<StringLiteral>(Initializer)) {
-    const ConstantArrayType *CAT =
-        Ctx.getASTContext().getAsConstantArrayType(SL->getType());
-    assert(CAT && "a string literal that's not a constant array?");
-
-    // If the initializer string is too long, a diagnostic has already been
-    // emitted. Read only the array length from the string literal.
-    unsigned N =
-        std::min(unsigned(CAT->getSize().getZExtValue()), SL->getLength());
-    size_t CharWidth = SL->getCharByteWidth();
-
-    for (unsigned I = 0; I != N; ++I) {
-      uint32_t CodeUnit = SL->getCodeUnit(I);
-
-      if (CharWidth == 1) {
-        this->emitConstSint8(CodeUnit, SL);
-        this->emitInitElemSint8(I, SL);
-      } else if (CharWidth == 2) {
-        this->emitConstUint16(CodeUnit, SL);
-        this->emitInitElemUint16(I, SL);
-      } else if (CharWidth == 4) {
-        this->emitConstUint32(CodeUnit, SL);
-        this->emitInitElemUint32(I, SL);
-      } else {
-        llvm_unreachable("unsupported character width");
-      }
-    }
-    return true;
-  } else if (const auto *CLE = dyn_cast<CompoundLiteralExpr>(Initializer)) {
-    return visitInitializer(CLE->getInitializer());
-  } else if (const auto *EWC = dyn_cast<ExprWithCleanups>(Initializer)) {
-    return visitInitializer(EWC->getSubExpr());
-  }
-
-  assert(false && "Unknown expression for array initialization");
-  return false;
-}
-
-template <class Emitter>
-bool ByteCodeExprGen<Emitter>::visitRecordInitializer(const Expr *Initializer) {
-  Initializer = Initializer->IgnoreParenImpCasts();
-  assert(Initializer->getType()->isRecordType());
-
-  if (const auto CtorExpr = dyn_cast<CXXConstructExpr>(Initializer)) {
-    const Function *Func = getFunction(CtorExpr->getConstructor());
-
-    if (!Func)
-      return false;
-
-    // The This pointer is already on the stack because this is an initializer,
-    // but we need to dup() so the call() below has its own copy.
-    if (!this->emitDupPtr(Initializer))
-      return false;
-
-    // Constructor arguments.
-    for (const auto *Arg : CtorExpr->arguments()) {
-      if (!this->visit(Arg))
-        return false;
-    }
-
-    return this->emitCall(Func, Initializer);
-  } else if (const auto *InitList = dyn_cast<InitListExpr>(Initializer)) {
-    const Record *R = getRecord(InitList->getType());
-
-    unsigned InitIndex = 0;
-    for (const Expr *Init : InitList->inits()) {
-
-      if (!this->emitDupPtr(Initializer))
-        return false;
-
-      if (std::optional<PrimType> T = classify(Init)) {
-        const Record::Field *FieldToInit = R->getField(InitIndex);
-        if (!this->visit(Init))
-          return false;
-
-        if (!this->emitInitField(*T, FieldToInit->Offset, Initializer))
-          return false;
-
-        if (!this->emitPopPtr(Initializer))
-          return false;
-        ++InitIndex;
-      } else {
-        // Initializer for a direct base class.
-        if (const Record::Base *B = R->getBase(Init->getType())) {
-          if (!this->emitGetPtrBasePop(B->Offset, Init))
-            return false;
-
-          if (!this->visitInitializer(Init))
-            return false;
-
-          if (!this->emitPopPtr(Initializer))
-            return false;
-          // Base initializers don't increase InitIndex, since they don't count
-          // into the Record's fields.
-        } else {
-          const Record::Field *FieldToInit = R->getField(InitIndex);
-          // Non-primitive case. Get a pointer to the field-to-initialize
-          // on the stack and recurse into visitInitializer().
-          if (!this->emitGetPtrField(FieldToInit->Offset, Init))
-            return false;
-
-          if (!this->visitInitializer(Init))
-            return false;
-
-          if (!this->emitPopPtr(Initializer))
-            return false;
-          ++InitIndex;
-        }
-      }
-    }
-
-    return true;
-  } else if (const CallExpr *CE = dyn_cast<CallExpr>(Initializer)) {
-    // RVO functions expect a pointer to initialize on the stack.
-    // Dup our existing pointer so it has its own copy to use.
-    if (!this->emitDupPtr(Initializer))
-      return false;
-
-    return this->visit(CE);
-  } else if (const auto *DIE = dyn_cast<CXXDefaultInitExpr>(Initializer)) {
-    return this->visitInitializer(DIE->getExpr());
-  } else if (const auto *CE = dyn_cast<CastExpr>(Initializer)) {
-    return this->visitInitializer(CE->getSubExpr());
-  } else if (const auto *CE = dyn_cast<CXXBindTemporaryExpr>(Initializer)) {
-    return this->visitInitializer(CE->getSubExpr());
-  } else if (const auto *ACO =
-                 dyn_cast<AbstractConditionalOperator>(Initializer)) {
-    return this->visitConditional(
-        ACO, [this](const Expr *E) { return this->visitRecordInitializer(E); });
-  } else if (const auto *LE = dyn_cast<LambdaExpr>(Initializer)) {
-    return this->VisitLambdaExpr(LE);
-  }
-
-  return false;
-}
-
-template <class Emitter>
-bool ByteCodeExprGen<Emitter>::visitInitializer(const Expr *Initializer) {
-  QualType InitializerType = Initializer->getType();
-
-  if (InitializerType->isArrayType())
-    return visitArrayInitializer(Initializer);
-
-  if (InitializerType->isRecordType())
-    return visitRecordInitializer(Initializer);
-
-  // Otherwise, visit the expression like normal.
-  return this->visit(Initializer);
-}
-
 template <class Emitter>
 const RecordType *ByteCodeExprGen<Emitter>::getRecordTy(QualType Ty) {
   if (const PointerType *PT = dyn_cast<PointerType>(Ty))
@@ -1854,13 +1885,21 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
   std::optional<PrimType> T = classify(ReturnType);
   bool HasRVO = !ReturnType->isVoidType() && !T;
 
-  if (HasRVO && DiscardResult) {
-    // If we need to discard the return value but the function returns its
-    // value via an RVO pointer, we need to create one such pointer just
-    // for this call.
-    if (std::optional<unsigned> LocalIndex = allocateLocal(E)) {
-      if (!this->emitGetPtrLocal(*LocalIndex, E))
-        return false;
+  if (HasRVO) {
+    if (DiscardResult) {
+      // If we need to discard the return value but the function returns its
+      // value via an RVO pointer, we need to create one such pointer just
+      // for this call.
+      if (std::optional<unsigned> LocalIndex = allocateLocal(E)) {
+        if (!this->emitGetPtrLocal(*LocalIndex, E))
+          return false;
+      }
+    } else {
+      assert(Initializing);
+      if (!isa<CXXMemberCallExpr>(E)) {
+        if (!this->emitDupPtr(E))
+          return false;
+      }
     }
   }
 
@@ -1923,7 +1962,13 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitCXXMemberCallExpr(
     const CXXMemberCallExpr *E) {
-  // Get a This pointer on the stack.
+  if (Initializing) {
+    // If we're initializing, the top of the stack is the pointer to
+    // initialize; duplicate it so this call gets its own copy.
+    if (!this->emitDupPtr(E))
+      return false;
+  }
+
   if (!this->visit(E->getImplicitObjectArgument()))
     return false;
 
@@ -1933,6 +1978,10 @@ bool ByteCodeExprGen<Emitter>::VisitCXXMemberCallExpr(
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitCXXDefaultInitExpr(
     const CXXDefaultInitExpr *E) {
+
+  if (Initializing)
+    return this->visitInitializer(E->getExpr());
+
   assert(classify(E->getType()));
   return this->visit(E->getExpr());
 }
@@ -1940,7 +1989,13 @@ bool ByteCodeExprGen<Emitter>::VisitCXXDefaultInitExpr(
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitCXXDefaultArgExpr(
     const CXXDefaultArgExpr *E) {
-  return this->visit(E->getExpr());
+  const Expr *SubExpr = E->getExpr();
+
+  if (std::optional<PrimType> T = classify(E->getExpr()))
+    return this->visit(SubExpr);
+
+  assert(Initializing);
+  return this->visitInitializer(SubExpr);
 }
 
 template <class Emitter>

diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.h b/clang/lib/AST/Interp/ByteCodeExprGen.h
index d28e03d571b1c6..3a64a4f6fec072 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.h
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.h
@@ -83,6 +83,7 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
   bool VisitUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr *E);
   bool VisitMemberExpr(const MemberExpr *E);
   bool VisitArrayInitIndexExpr(const ArrayInitIndexExpr *E);
+  bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
   bool VisitOpaqueValueExpr(const OpaqueValueExpr *E);
   bool VisitAbstractConditionalOperator(const AbstractConditionalOperator *E);
   bool VisitStringLiteral(const StringLiteral *E);
@@ -93,7 +94,6 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
   bool VisitExprWithCleanups(const ExprWithCleanups *E);
   bool VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E);
   bool VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *E);
-  bool VisitCXXTemporaryObjectExpr(const CXXTemporaryObjectExpr *E);
   bool VisitCompoundLiteralExpr(const CompoundLiteralExpr *E);
   bool VisitTypeTraitExpr(const TypeTraitExpr *E);
   bool VisitLambdaExpr(const LambdaExpr *E);
@@ -101,6 +101,7 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
   bool VisitCXXThrowExpr(const CXXThrowExpr *E);
   bool VisitCXXReinterpretCastExpr(const CXXReinterpretCastExpr *E);
   bool VisitCXXNoexceptExpr(const CXXNoexceptExpr *E);
+  bool VisitCXXConstructExpr(const CXXConstructExpr *E);
 
 protected:
   bool visitExpr(const Expr *E) override;
@@ -136,17 +137,21 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
     }
     llvm_unreachable("not a primitive type");
   }
-
-  /// Evaluates an expression for side effects and discards the result.
-  bool discard(const Expr *E);
-  /// Evaluates an expression and places result on stack.
+  /// Evaluates an expression and places the result on the stack. If the
+  /// expression is of composite type, a local variable will be created
+  /// and a pointer to said variable will be placed on the stack.
   bool visit(const Expr *E);
-  /// Compiles an initializer.
+  /// Compiles an initializer. This is like visit(), but it never creates
+  /// a variable; instead, it relies on the variable already having been
+  /// created and on a pointer to that variable being on top of the
+  /// stack, which it then initializes.
   bool visitInitializer(const Expr *E);
-  /// Compiles an array initializer.
-  bool visitArrayInitializer(const Expr *Initializer);
-  /// Compiles a record initializer.
-  bool visitRecordInitializer(const Expr *Initializer);
+  /// Evaluates an expression for side effects and discards the result.
+  bool discard(const Expr *E);
+  /// Just passes evaluation on to \p E, leaving all the visitor flags
+  /// (e.g. DiscardResult and Initializing) intact.
+  bool delegate(const Expr *E);
+
   /// Creates and initializes a variable from the given decl.
   bool visitVarDecl(const VarDecl *VD);
 
@@ -190,9 +195,6 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
     return this->emitPopPtr(I);
   }
 
-  bool visitConditional(const AbstractConditionalOperator *E,
-                        llvm::function_ref<bool(const Expr *)> V);
-
   /// Creates a local primitive value.
   unsigned allocateLocalPrimitive(DeclTy &&Decl, PrimType Ty, bool IsConst,
                                   bool IsExtended = false);
@@ -281,6 +283,10 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
 
   /// Flag indicating if return value is to be discarded.
   bool DiscardResult = false;
+
+  /// Flag indicating whether we're initializing an already-created
+  /// variable. This is set in visitInitializer().
+  bool Initializing = false;
 };
 
 extern template class ByteCodeExprGen<ByteCodeEmitter>;

diff  --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp
index 4c4808324c3a14..6e0d949457d673 100644
--- a/clang/lib/AST/Interp/Context.cpp
+++ b/clang/lib/AST/Interp/Context.cpp
@@ -128,7 +128,7 @@ std::optional<PrimType> Context::classify(QualType T) const {
     return PT_Float;
 
   if (T->isFunctionPointerType() || T->isFunctionReferenceType() ||
-      T->isFunctionType())
+      T->isFunctionType() || T->isSpecificBuiltinType(BuiltinType::BoundMember))
     return PT_FnPtr;
 
   if (T->isReferenceType() || T->isPointerType())

diff  --git a/clang/test/AST/Interp/lambda.cpp b/clang/test/AST/Interp/lambda.cpp
index b913ad13500bc0..da1d706af1d050 100644
--- a/clang/test/AST/Interp/lambda.cpp
+++ b/clang/test/AST/Interp/lambda.cpp
@@ -103,8 +103,7 @@ namespace LambdaParams {
 
     return a;
   }
-  /// FIXME: This should work in the new interpreter.
-  static_assert(foo() == 1); // expected-error {{not an integral constant expression}}
+  static_assert(foo() == 1);
 }
 
 namespace StaticInvoker {
@@ -136,10 +135,6 @@ namespace StaticInvoker {
   }
   static_assert(sv4(12) == 12);
 
-
-
-  /// FIXME: This is broken for lambda-unrelated reasons.
-#if 0
   constexpr int sv5(int i) {
     struct F { int a; float f; };
     auto l = [](int m, F f) { return m; };
@@ -147,7 +142,6 @@ namespace StaticInvoker {
     return fp(i, F{12, 14.0});
   }
   static_assert(sv5(12) == 12);
-#endif
 
   constexpr int sv6(int i) {
     struct F { int a;
@@ -162,3 +156,26 @@ namespace StaticInvoker {
   }
   static_assert(sv6(12) == 12);
 }
+
+namespace LambdasAsParams {
+  template<typename F>
+  constexpr auto call(F f) {
+    return f();
+  }
+  static_assert(call([](){ return 1;}) == 1);
+  static_assert(call([](){ return 2;}) == 2);
+
+
+  constexpr unsigned L = call([](){ return 12;});
+  static_assert(L == 12);
+
+
+  constexpr float heh() {
+    auto a = []() {
+      return 1.0;
+    };
+
+    return static_cast<float>(a());
+  }
+  static_assert(heh() == 1.0);
+}

diff  --git a/clang/test/AST/Interp/records.cpp b/clang/test/AST/Interp/records.cpp
index c4191e3417bf17..76882d29e201e9 100644
--- a/clang/test/AST/Interp/records.cpp
+++ b/clang/test/AST/Interp/records.cpp
@@ -795,3 +795,44 @@ namespace VirtualFunctionPointers {
 
 };
 #endif
+
+namespace CompositeDefaultArgs {
+  struct Foo {
+    int a;
+    int b;
+    constexpr Foo() : a(12), b(13) {}
+  };
+
+  class Bar {
+  public:
+    bool B = false;
+
+    constexpr int someFunc(Foo F = Foo()) {
+      this->B = true;
+      return 5;
+    }
+  };
+
+  constexpr bool testMe() {
+    Bar B;
+    B.someFunc();
+    return B.B;
+  }
+  static_assert(testMe(), "");
+}
+
+constexpr bool BPand(BoolPair bp) {
+  return bp.first && bp.second;
+}
+static_assert(BPand(BoolPair{true, false}) == false, "");
+
+namespace TemporaryObjectExpr {
+  struct F {
+    int a;
+    constexpr F() : a(12) {}
+  };
+  constexpr int foo(F f) {
+    return 0;
+  }
+  static_assert(foo(F()) == 0, "");
+}


        


More information about the cfe-commits mailing list