[clang] [clang][bytecode] Fix local destructor order (PR #107951)

via cfe-commits cfe-commits at lists.llvm.org
Mon Sep 9 19:15:07 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Timm Baeder (tbaederr)

<details>
<summary>Changes</summary>

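Add the appropriate local scopes around loop and `switch` statements, track separate cleanup scopes for `break` and `continue`, and iterate over the locals in reverse order in `LocalScope::emitDestructors()` so that local variables are destroyed in the reverse order of their construction.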

---
Full diff: https://github.com/llvm/llvm-project/pull/107951.diff


3 Files Affected:

- (modified) clang/lib/AST/ByteCode/Compiler.cpp (+51-31) 
- (modified) clang/lib/AST/ByteCode/Compiler.h (+5-3) 
- (modified) clang/test/AST/ByteCode/cxx2a.cpp (+76) 


``````````diff
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index fddd6db4f4814c..3ba9732a979244 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -115,22 +115,26 @@ template <class Emitter> class LoopScope final : public LabelScope<Emitter> {
   LoopScope(Compiler<Emitter> *Ctx, LabelTy BreakLabel, LabelTy ContinueLabel)
       : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
         OldContinueLabel(Ctx->ContinueLabel),
-        OldLabelVarScope(Ctx->LabelVarScope) {
+        OldBreakVarScope(Ctx->BreakVarScope),
+        OldContinueVarScope(Ctx->ContinueVarScope) {
     this->Ctx->BreakLabel = BreakLabel;
     this->Ctx->ContinueLabel = ContinueLabel;
-    this->Ctx->LabelVarScope = this->Ctx->VarScope;
+    this->Ctx->BreakVarScope = this->Ctx->VarScope;
+    this->Ctx->ContinueVarScope = this->Ctx->VarScope;
   }
 
   ~LoopScope() {
     this->Ctx->BreakLabel = OldBreakLabel;
     this->Ctx->ContinueLabel = OldContinueLabel;
-    this->Ctx->LabelVarScope = OldLabelVarScope;
+    this->Ctx->ContinueVarScope = OldContinueVarScope;
+    this->Ctx->BreakVarScope = OldBreakVarScope;
   }
 
 private:
   OptLabelTy OldBreakLabel;
   OptLabelTy OldContinueLabel;
-  VariableScope<Emitter> *OldLabelVarScope;
+  VariableScope<Emitter> *OldBreakVarScope;
+  VariableScope<Emitter> *OldContinueVarScope;
 };
 
 // Sets the context for a switch scope, mapping labels.
@@ -145,18 +149,18 @@ template <class Emitter> class SwitchScope final : public LabelScope<Emitter> {
       : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
         OldDefaultLabel(this->Ctx->DefaultLabel),
         OldCaseLabels(std::move(this->Ctx->CaseLabels)),
-        OldLabelVarScope(Ctx->LabelVarScope) {
+        OldLabelVarScope(Ctx->BreakVarScope) {
     this->Ctx->BreakLabel = BreakLabel;
     this->Ctx->DefaultLabel = DefaultLabel;
     this->Ctx->CaseLabels = std::move(CaseLabels);
-    this->Ctx->LabelVarScope = this->Ctx->VarScope;
+    this->Ctx->BreakVarScope = this->Ctx->VarScope;
   }
 
   ~SwitchScope() {
     this->Ctx->BreakLabel = OldBreakLabel;
     this->Ctx->DefaultLabel = OldDefaultLabel;
     this->Ctx->CaseLabels = std::move(OldCaseLabels);
-    this->Ctx->LabelVarScope = OldLabelVarScope;
+    this->Ctx->BreakVarScope = OldLabelVarScope;
   }
 
 private:
@@ -4605,18 +4609,23 @@ bool Compiler<Emitter>::visitWhileStmt(const WhileStmt *S) {
   this->fallthrough(CondLabel);
   this->emitLabel(CondLabel);
 
-  if (const DeclStmt *CondDecl = S->getConditionVariableDeclStmt())
-    if (!visitDeclStmt(CondDecl))
-      return false;
+  {
+    LocalScope<Emitter> CondScope(this);
+    if (const DeclStmt *CondDecl = S->getConditionVariableDeclStmt())
+      if (!visitDeclStmt(CondDecl))
+        return false;
 
-  if (!this->visitBool(Cond))
-    return false;
-  if (!this->jumpFalse(EndLabel))
-    return false;
+    if (!this->visitBool(Cond))
+      return false;
+    if (!this->jumpFalse(EndLabel))
+      return false;
 
-  if (!this->visitStmt(Body))
-    return false;
+    if (!this->visitStmt(Body))
+      return false;
 
+    if (!CondScope.destroyLocals())
+      return false;
+  }
   if (!this->jump(CondLabel))
     return false;
   this->fallthrough(EndLabel);
@@ -4636,13 +4645,18 @@ template <class Emitter> bool Compiler<Emitter>::visitDoStmt(const DoStmt *S) {
 
   this->fallthrough(StartLabel);
   this->emitLabel(StartLabel);
+
   {
+    LocalScope<Emitter> CondScope(this);
     if (!this->visitStmt(Body))
       return false;
     this->fallthrough(CondLabel);
     this->emitLabel(CondLabel);
     if (!this->visitBool(Cond))
       return false;
+
+    if (!CondScope.destroyLocals())
+      return false;
   }
   if (!this->jumpTrue(StartLabel))
     return false;
@@ -4671,18 +4685,19 @@ bool Compiler<Emitter>::visitForStmt(const ForStmt *S) {
   this->fallthrough(CondLabel);
   this->emitLabel(CondLabel);
 
-  if (const DeclStmt *CondDecl = S->getConditionVariableDeclStmt())
-    if (!visitDeclStmt(CondDecl))
-      return false;
+  {
+    LocalScope<Emitter> CondScope(this);
+    if (const DeclStmt *CondDecl = S->getConditionVariableDeclStmt())
+      if (!visitDeclStmt(CondDecl))
+        return false;
 
-  if (Cond) {
-    if (!this->visitBool(Cond))
-      return false;
-    if (!this->jumpFalse(EndLabel))
-      return false;
-  }
+    if (Cond) {
+      if (!this->visitBool(Cond))
+        return false;
+      if (!this->jumpFalse(EndLabel))
+        return false;
+    }
 
-  {
     if (Body && !this->visitStmt(Body))
       return false;
 
@@ -4690,10 +4705,13 @@ bool Compiler<Emitter>::visitForStmt(const ForStmt *S) {
     this->emitLabel(IncLabel);
     if (Inc && !this->discard(Inc))
       return false;
-  }
 
+    if (!CondScope.destroyLocals())
+      return false;
+  }
   if (!this->jump(CondLabel))
     return false;
+
   this->fallthrough(EndLabel);
   this->emitLabel(EndLabel);
   return true;
@@ -4760,7 +4778,7 @@ bool Compiler<Emitter>::visitBreakStmt(const BreakStmt *S) {
   if (!BreakLabel)
     return false;
 
-  for (VariableScope<Emitter> *C = VarScope; C != LabelVarScope;
+  for (VariableScope<Emitter> *C = VarScope; C != BreakVarScope;
        C = C->getParent())
     C->emitDestruction();
   return this->jump(*BreakLabel);
@@ -4771,8 +4789,8 @@ bool Compiler<Emitter>::visitContinueStmt(const ContinueStmt *S) {
   if (!ContinueLabel)
     return false;
 
-  for (VariableScope<Emitter> *C = VarScope; C != LabelVarScope;
-       C = C->getParent())
+  for (VariableScope<Emitter> *C = VarScope;
+       C && C->getParent() != ContinueVarScope; C = C->getParent())
     C->emitDestruction();
   return this->jump(*ContinueLabel);
 }
@@ -4781,6 +4799,7 @@ template <class Emitter>
 bool Compiler<Emitter>::visitSwitchStmt(const SwitchStmt *S) {
   const Expr *Cond = S->getCond();
   PrimType CondT = this->classifyPrim(Cond->getType());
+  LocalScope<Emitter> LS(this);
 
   LabelTy EndLabel = this->getLabel();
   OptLabelTy DefaultLabel = std::nullopt;
@@ -4844,7 +4863,8 @@ bool Compiler<Emitter>::visitSwitchStmt(const SwitchStmt *S) {
   if (!this->visitStmt(S->getBody()))
     return false;
   this->emitLabel(EndLabel);
-  return true;
+
+  return LS.destroyLocals();
 }
 
 template <class Emitter>
diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h
index ac4c08c4d0ffb6..2dfa187713a803 100644
--- a/clang/lib/AST/ByteCode/Compiler.h
+++ b/clang/lib/AST/ByteCode/Compiler.h
@@ -409,10 +409,12 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
   /// Switch case mapping.
   CaseMap CaseLabels;
 
-  /// Scope to cleanup until when chumping to one of the labels.
-  VariableScope<Emitter> *LabelVarScope = nullptr;
+  /// Scope to cleanup until when we see a break statement.
+  VariableScope<Emitter> *BreakVarScope = nullptr;
   /// Point to break to.
   OptLabelTy BreakLabel;
+  /// Scope to cleanup until when we see a continue statement.
+  VariableScope<Emitter> *ContinueVarScope = nullptr;
   /// Point to continue to.
   OptLabelTy ContinueLabel;
   /// Default case label.
@@ -533,7 +535,7 @@ template <class Emitter> class LocalScope : public VariableScope<Emitter> {
       return true;
     // Emit destructor calls for local variables of record
     // type with a destructor.
-    for (Scope::Local &Local : this->Ctx->Descriptors[*Idx]) {
+    for (Scope::Local &Local : llvm::reverse(this->Ctx->Descriptors[*Idx])) {
       if (!Local.Desc->isPrimitive() && !Local.Desc->isPrimitiveArray()) {
         if (!this->Ctx->emitGetPtrLocal(Local.Offset, E))
           return false;
diff --git a/clang/test/AST/ByteCode/cxx2a.cpp b/clang/test/AST/ByteCode/cxx2a.cpp
index ad021b30cfd3cf..eaae978e011843 100644
--- a/clang/test/AST/ByteCode/cxx2a.cpp
+++ b/clang/test/AST/ByteCode/cxx2a.cpp
@@ -34,3 +34,79 @@ namespace Covariant {
   constexpr const Covariant1 *cb1 = &cb;
   static_assert(cb1->f()->a == 'Z');
 }
+
+namespace DtorOrder {
+  struct Buf {
+    char buf[64];
+    int n = 0;
+    constexpr void operator+=(char c) { buf[n++] = c; }
+    constexpr bool operator==(const char *str) const {
+      if (str[n] != 0)
+        return false;
+
+      for (int i = 0; i < n; ++i) {
+        if (buf[n] != str[n])
+          return false;
+      }
+      return true;
+
+      return __builtin_memcmp(str, buf, n) == 0;
+    }
+    constexpr bool operator!=(const char *str) const { return !operator==(str); }
+  };
+
+  struct A {
+    constexpr A(Buf &buf, char c) : buf(buf), c(c) { buf += c; }
+    constexpr ~A() { buf += (c - 32);}
+    constexpr operator bool() const { return true; }
+    Buf &buf;
+    char c;
+  };
+
+  constexpr void abnormal_termination(Buf &buf) {
+    struct Indestructible {
+      constexpr ~Indestructible(); // not defined
+    };
+    A a(buf, 'a');
+    A(buf, 'b');
+    int n = 0;
+
+    for (A &&c = A(buf, 'c'); A d = A(buf, 'd'); A(buf, 'e')) {
+      switch (A f(buf, 'f'); A g = A(buf, 'g')) { // both-warning {{boolean}}
+      case false: {
+        A x(buf, 'x');
+      }
+
+      case true: {
+        A h(buf, 'h');
+        switch (n++) {
+        case 0:
+          break;
+        case 1:
+          continue;
+        case 2:
+          return;
+        }
+        break;
+      }
+
+      default:
+        Indestructible indest;
+      }
+
+      A j = (A(buf, 'i'), A(buf, 'j'));
+    }
+  }
+
+  constexpr bool check_abnormal_termination() {
+    Buf buf = {};
+    abnormal_termination(buf);
+    return buf ==
+      "abBc"
+        "dfgh" /*break*/ "HGFijIJeED"
+        "dfgh" /*continue*/ "HGFeED"
+        "dfgh" /*return*/ "HGFD"
+      "CA";
+  }
+  static_assert(check_abnormal_termination());
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/107951


More information about the cfe-commits mailing list