[clang] [clang][Interp] Simplify and fix variable scope handling (PR #101788)

Timm Baeder via cfe-commits cfe-commits at lists.llvm.org
Fri Aug 2 21:18:08 PDT 2024


Timm Bäder <tbaeder at redhat.com>
Message-ID: <llvm.org/llvm/llvm-project/pull/101788 at github.com>
In-Reply-To:


https://github.com/tbaederr created https://github.com/llvm/llvm-project/pull/101788

None

>From 02ba0cd327c91d65ac6d4ee7699a590cea6a92ec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timm=20B=C3=A4der?= <tbaeder at redhat.com>
Date: Fri, 2 Aug 2024 17:45:19 +0200
Subject: [PATCH 1/2] [clang][Interp] Enhance CodePtr

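Add more relational operators (operator<= and operator>=) and an
operator+ taking an offset; also make the bool conversion explicit and
default-initialize Ptr in-class.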
---
 clang/lib/AST/Interp/Source.h | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/clang/lib/AST/Interp/Source.h b/clang/lib/AST/Interp/Source.h
index c28b488ff554d..88b5ec7740df5 100644
--- a/clang/lib/AST/Interp/Source.h
+++ b/clang/lib/AST/Interp/Source.h
@@ -29,7 +29,7 @@ class Function;
 /// Pointer into the code segment.
 class CodePtr final {
 public:
-  CodePtr() : Ptr(nullptr) {}
+  CodePtr() = default;
 
   CodePtr &operator+=(int32_t Offset) {
     Ptr += Offset;
@@ -45,11 +45,16 @@ class CodePtr final {
     assert(Ptr != nullptr && "Invalid code pointer");
     return CodePtr(Ptr - RHS);
   }
+  CodePtr operator+(ssize_t RHS) const {
+    assert(Ptr != nullptr && "Invalid code pointer");
+    return CodePtr(Ptr + RHS);
+  }
 
   bool operator!=(const CodePtr &RHS) const { return Ptr != RHS.Ptr; }
   const std::byte *operator*() const { return Ptr; }
-
-  operator bool() const { return Ptr; }
+  explicit operator bool() const { return Ptr; }
+  bool operator<=(const CodePtr &RHS) const { return Ptr <= RHS.Ptr; }
+  bool operator>=(const CodePtr &RHS) const { return Ptr >= RHS.Ptr; }
 
   /// Reads data and advances the pointer.
   template <typename T> std::enable_if_t<!std::is_pointer<T>::value, T> read() {
@@ -65,7 +70,7 @@ class CodePtr final {
   /// Constructor used by Function to generate pointers.
   CodePtr(const std::byte *Ptr) : Ptr(Ptr) {}
   /// Pointer into the code owned by a function.
-  const std::byte *Ptr;
+  const std::byte *Ptr = nullptr;
 };
 
 /// Describes the statement/declaration an opcode was generated from.

>From b7a48211f3d7a0661a029382486e47f52cbd8a27 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timm=20B=C3=A4der?= <tbaeder at redhat.com>
Date: Fri, 2 Aug 2024 19:05:04 +0200
Subject: [PATCH 2/2] [clang][Interp] Simplify and fix variable scope handling

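Change scope handling to allow multiple Destroy calls for a given scope,
provided each one is preceded by an InitScope call. Local variables are
now constructed by the new InitScope op instead of when the InterpFrame
is created. This is necessary to properly allow nested scopes in loops,
where the same scope is entered and left again on every iteration.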
---
 clang/lib/AST/Interp/Compiler.cpp    | 63 +++++++---------------------
 clang/lib/AST/Interp/Compiler.h      | 44 +++++++++----------
 clang/lib/AST/Interp/Interp.h        |  5 +++
 clang/lib/AST/Interp/InterpFrame.cpp | 14 +++++--
 clang/lib/AST/Interp/InterpFrame.h   |  1 +
 clang/lib/AST/Interp/Opcodes.td      |  3 ++
 clang/test/AST/Interp/if.cpp         | 18 ++++++++
 clang/test/AST/Interp/loops.cpp      | 28 +++++++++++++
 8 files changed, 101 insertions(+), 75 deletions(-)

diff --git a/clang/lib/AST/Interp/Compiler.cpp b/clang/lib/AST/Interp/Compiler.cpp
index 258e4ed645254..f600d9b5b80f8 100644
--- a/clang/lib/AST/Interp/Compiler.cpp
+++ b/clang/lib/AST/Interp/Compiler.cpp
@@ -2226,7 +2226,7 @@ bool Compiler<Emitter>::VisitExprWithCleanups(const ExprWithCleanups *E) {
 
   assert(E->getNumObjects() == 0 && "TODO: Implement cleanups");
 
-  return this->delegate(SubExpr) && ES.destroyLocals();
+  return this->delegate(SubExpr) && ES.destroyLocals(E);
 }
 
 template <class Emitter>
@@ -2537,13 +2537,8 @@ bool Compiler<Emitter>::VisitCXXConstructExpr(const CXXConstructExpr *E) {
         return false;
     }
 
-    // Immediately call the destructor if we have to.
-    if (DiscardResult) {
-      if (!this->emitRecordDestruction(getRecord(E->getType())))
-        return false;
-      if (!this->emitPopPtr(E))
-        return false;
-    }
+    if (DiscardResult)
+      return this->emitPopPtr(E);
     return true;
   }
 
@@ -4222,22 +4217,6 @@ template <class Emitter> bool Compiler<Emitter>::visitStmt(const Stmt *S) {
   }
 }
 
-/// Visits the given statment without creating a variable
-/// scope for it in case it is a compound statement.
-template <class Emitter> bool Compiler<Emitter>::visitLoopBody(const Stmt *S) {
-  if (isa<NullStmt>(S))
-    return true;
-
-  if (const auto *CS = dyn_cast<CompoundStmt>(S)) {
-    for (const auto *InnerStmt : CS->body())
-      if (!visitStmt(InnerStmt))
-        return false;
-    return true;
-  }
-
-  return this->visitStmt(S);
-}
-
 template <class Emitter>
 bool Compiler<Emitter>::visitCompoundStmt(const CompoundStmt *S) {
   BlockScope<Emitter> Scope(this);
@@ -4300,8 +4279,6 @@ bool Compiler<Emitter>::visitReturnStmt(const ReturnStmt *RS) {
 }
 
 template <class Emitter> bool Compiler<Emitter>::visitIfStmt(const IfStmt *IS) {
-  BlockScope<Emitter> IfScope(this);
-
   if (IS->isNonNegatedConsteval())
     return visitStmt(IS->getThen());
   if (IS->isNegatedConsteval())
@@ -4340,7 +4317,7 @@ template <class Emitter> bool Compiler<Emitter>::visitIfStmt(const IfStmt *IS) {
     this->emitLabel(LabelEnd);
   }
 
-  return IfScope.destroyLocals();
+  return true;
 }
 
 template <class Emitter>
@@ -4364,12 +4341,8 @@ bool Compiler<Emitter>::visitWhileStmt(const WhileStmt *S) {
   if (!this->jumpFalse(EndLabel))
     return false;
 
-  LocalScope<Emitter> Scope(this);
-  {
-    DestructorScope<Emitter> DS(Scope);
-    if (!this->visitLoopBody(Body))
-      return false;
-  }
+  if (!this->visitStmt(Body))
+    return false;
 
   if (!this->jump(CondLabel))
     return false;
@@ -4387,14 +4360,11 @@ template <class Emitter> bool Compiler<Emitter>::visitDoStmt(const DoStmt *S) {
   LabelTy EndLabel = this->getLabel();
   LabelTy CondLabel = this->getLabel();
   LoopScope<Emitter> LS(this, EndLabel, CondLabel);
-  LocalScope<Emitter> Scope(this);
 
   this->fallthrough(StartLabel);
   this->emitLabel(StartLabel);
   {
-    DestructorScope<Emitter> DS(Scope);
-
-    if (!this->visitLoopBody(Body))
+    if (!this->visitStmt(Body))
       return false;
     this->fallthrough(CondLabel);
     this->emitLabel(CondLabel);
@@ -4421,10 +4391,10 @@ bool Compiler<Emitter>::visitForStmt(const ForStmt *S) {
   LabelTy CondLabel = this->getLabel();
   LabelTy IncLabel = this->getLabel();
   LoopScope<Emitter> LS(this, EndLabel, IncLabel);
-  LocalScope<Emitter> Scope(this);
 
   if (Init && !this->visitStmt(Init))
     return false;
+
   this->fallthrough(CondLabel);
   this->emitLabel(CondLabel);
 
@@ -4440,10 +4410,9 @@ bool Compiler<Emitter>::visitForStmt(const ForStmt *S) {
   }
 
   {
-    DestructorScope<Emitter> DS(Scope);
-
-    if (Body && !this->visitLoopBody(Body))
+    if (Body && !this->visitStmt(Body))
       return false;
+
     this->fallthrough(IncLabel);
     this->emitLabel(IncLabel);
     if (Inc && !this->discard(Inc))
@@ -4495,13 +4464,11 @@ bool Compiler<Emitter>::visitCXXForRangeStmt(const CXXForRangeStmt *S) {
     return false;
 
   // Body.
-  LocalScope<Emitter> Scope(this);
   {
-    DestructorScope<Emitter> DS(Scope);
-
-    if (!this->visitLoopBody(Body))
+    if (!this->visitStmt(Body))
       return false;
-  this->fallthrough(IncLabel);
+
+    this->fallthrough(IncLabel);
     this->emitLabel(IncLabel);
     if (!this->discard(Inc))
       return false;
@@ -4520,7 +4487,7 @@ bool Compiler<Emitter>::visitBreakStmt(const BreakStmt *S) {
   if (!BreakLabel)
     return false;
 
-  this->VarScope->emitDestructors();
+  this->emitCleanup();
   return this->jump(*BreakLabel);
 }
 
@@ -4529,7 +4496,7 @@ bool Compiler<Emitter>::visitContinueStmt(const ContinueStmt *S) {
   if (!ContinueLabel)
     return false;
 
-  this->VarScope->emitDestructors();
+  this->emitCleanup();
   return this->jump(*ContinueLabel);
 }
 
diff --git a/clang/lib/AST/Interp/Compiler.h b/clang/lib/AST/Interp/Compiler.h
index 6bc9985fe7232..eabde051b4278 100644
--- a/clang/lib/AST/Interp/Compiler.h
+++ b/clang/lib/AST/Interp/Compiler.h
@@ -202,7 +202,6 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
 
   // Statements.
   bool visitCompoundStmt(const CompoundStmt *S);
-  bool visitLoopBody(const Stmt *S);
   bool visitDeclStmt(const DeclStmt *DS);
   bool visitReturnStmt(const ReturnStmt *RS);
   bool visitIfStmt(const IfStmt *IS);
@@ -452,11 +451,15 @@ template <class Emitter> class VariableScope {
     }
 
     // Use the parent scope.
-    addExtended(Local);
+    if (this->Parent)
+      this->Parent->addLocal(Local);
+    else
+      this->addLocal(Local);
   }
 
   virtual void emitDestruction() {}
-  virtual bool emitDestructors() { return true; }
+  virtual bool emitDestructors(const Expr *E = nullptr) { return true; }
+  virtual bool destroyLocals(const Expr *E = nullptr) { return true; }
   VariableScope *getParent() const { return Parent; }
 
 protected:
@@ -483,16 +486,21 @@ template <class Emitter> class LocalScope : public VariableScope<Emitter> {
   }
 
   /// Overriden to support explicit destruction.
-  void emitDestruction() override { destroyLocals(); }
+  void emitDestruction() override {
+    if (!Idx)
+      return;
+
+    this->emitDestructors();
+    this->Ctx->emitDestroy(*Idx, SourceInfo{});
+  }
 
   /// Explicit destruction of local variables.
-  bool destroyLocals() {
+  bool destroyLocals(const Expr *E = nullptr) override {
     if (!Idx)
       return true;
 
-    bool Success = this->emitDestructors();
-    this->Ctx->emitDestroy(*Idx, SourceInfo{});
-    removeStoredOpaqueValues();
+    bool Success = this->emitDestructors(E);
+    this->Ctx->emitDestroy(*Idx, E);
     this->Idx = std::nullopt;
     return Success;
   }
@@ -501,25 +509,26 @@ template <class Emitter> class LocalScope : public VariableScope<Emitter> {
     if (!Idx) {
       Idx = this->Ctx->Descriptors.size();
       this->Ctx->Descriptors.emplace_back();
+      this->Ctx->emitInitScope(*Idx, {});
     }
 
     this->Ctx->Descriptors[*Idx].emplace_back(Local);
   }
 
-  bool emitDestructors() override {
+  bool emitDestructors(const Expr *E = nullptr) override {
     if (!Idx)
       return true;
     // Emit destructor calls for local variables of record
     // type with a destructor.
     for (Scope::Local &Local : this->Ctx->Descriptors[*Idx]) {
       if (!Local.Desc->isPrimitive() && !Local.Desc->isPrimitiveArray()) {
-        if (!this->Ctx->emitGetPtrLocal(Local.Offset, SourceInfo{}))
+        if (!this->Ctx->emitGetPtrLocal(Local.Offset, E))
           return false;
 
         if (!this->Ctx->emitDestruction(Local.Desc))
           return false;
 
-        if (!this->Ctx->emitPopPtr(SourceInfo{}))
+        if (!this->Ctx->emitPopPtr(E))
           return false;
         removeIfStoredOpaqueValue(Local);
       }
@@ -549,19 +558,6 @@ template <class Emitter> class LocalScope : public VariableScope<Emitter> {
   std::optional<unsigned> Idx;
 };
 
-/// Emits the destructors of the variables of \param OtherScope
-/// when this scope is destroyed. Does not create a Scope in the bytecode at
-/// all, this is just a RAII object to emit destructors.
-template <class Emitter> class DestructorScope final {
-public:
-  DestructorScope(LocalScope<Emitter> &OtherScope) : OtherScope(OtherScope) {}
-
-  ~DestructorScope() { OtherScope.emitDestructors(); }
-
-private:
-  LocalScope<Emitter> &OtherScope;
-};
-
 /// Scope for storage declared in a compound statement.
 template <class Emitter> class BlockScope final : public LocalScope<Emitter> {
 public:
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index 63e9966b831db..b54635b9988e2 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -2012,6 +2012,11 @@ inline bool Destroy(InterpState &S, CodePtr OpPC, uint32_t I) {
   return true;
 }
 
+inline bool InitScope(InterpState &S, CodePtr OpPC, uint32_t I) {
+  S.Current->initScope(I);
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Cast, CastFP
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/AST/Interp/InterpFrame.cpp b/clang/lib/AST/Interp/InterpFrame.cpp
index 1c37450ae1c6e..1925ffeabb1d7 100644
--- a/clang/lib/AST/Interp/InterpFrame.cpp
+++ b/clang/lib/AST/Interp/InterpFrame.cpp
@@ -37,9 +37,9 @@ InterpFrame::InterpFrame(InterpState &S, const Function *Func,
   Locals = std::make_unique<char[]>(FrameSize);
   for (auto &Scope : Func->scopes()) {
     for (auto &Local : Scope.locals()) {
-      Block *B =
-          new (localBlock(Local.Offset)) Block(S.Ctx.getEvalID(), Local.Desc);
-      B->invokeCtor();
+      new (localBlock(Local.Offset)) Block(S.Ctx.getEvalID(), Local.Desc);
+      // Note that we are NOT calling invokeCtor() here, since that is done
+      // via the InitScope op.
       new (localInlineDesc(Local.Offset)) InlineDescriptor(Local.Desc);
     }
   }
@@ -83,6 +83,14 @@ InterpFrame::~InterpFrame() {
   }
 }
 
+void InterpFrame::initScope(unsigned Idx) {
+  if (!Func)
+    return;
+  for (auto &Local : Func->getScope(Idx).locals()) {
+    localBlock(Local.Offset)->invokeCtor();
+  }
+}
+
 void InterpFrame::destroy(unsigned Idx) {
   for (auto &Local : Func->getScope(Idx).locals()) {
     S.deallocate(localBlock(Local.Offset));
diff --git a/clang/lib/AST/Interp/InterpFrame.h b/clang/lib/AST/Interp/InterpFrame.h
index 4a312a71bcf1c..91b9b41b5d334 100644
--- a/clang/lib/AST/Interp/InterpFrame.h
+++ b/clang/lib/AST/Interp/InterpFrame.h
@@ -44,6 +44,7 @@ class InterpFrame final : public Frame {
 
   /// Invokes the destructors for a scope.
   void destroy(unsigned Idx);
+  void initScope(unsigned Idx);
 
   /// Pops the arguments off the stack.
   void popArgs();
diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index eeb9cb2e933a6..3e830f89754dc 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -232,6 +232,9 @@ def Destroy : Opcode {
   let Args = [ArgUint32];
   let HasCustomEval = 1;
 }
+def InitScope : Opcode {
+  let Args = [ArgUint32];
+}
 
 //===----------------------------------------------------------------------===//
 // Constants
diff --git a/clang/test/AST/Interp/if.cpp b/clang/test/AST/Interp/if.cpp
index 37289d69d3255..540cb76fbaac3 100644
--- a/clang/test/AST/Interp/if.cpp
+++ b/clang/test/AST/Interp/if.cpp
@@ -58,3 +58,21 @@ constexpr char g(char const (&x)[2]) {
     ;
 }
 static_assert(g("x") == 'x');
+
+namespace IfScope {
+  struct Inc {
+    int &a;
+    constexpr Inc(int &a) : a(a) {}
+    constexpr ~Inc() { ++a; }
+  };
+
+  constexpr int foo() {
+    int a= 0;
+    int b = 12;
+    if (Inc{a}; true) {
+      b += a;
+    }
+    return b;
+  }
+  static_assert(foo() == 13, "");
+}
diff --git a/clang/test/AST/Interp/loops.cpp b/clang/test/AST/Interp/loops.cpp
index 2e235123af76e..38ab5613e1cbd 100644
--- a/clang/test/AST/Interp/loops.cpp
+++ b/clang/test/AST/Interp/loops.cpp
@@ -324,3 +324,31 @@ namespace RangeForLoop {
                                                // ref-note {{semicolon on a separate line}}
   }
 }
+
+namespace Scopes {
+  constexpr int foo() {
+    int n = 0;
+    {
+      int m = 12;
+    for (int i = 0;i < 10;++i) {
+
+      {
+        int a  = 10;
+        {
+          int b = 20;
+          {
+            int c = 30;
+            continue;
+          }
+        }
+      }
+    }
+    ++m;
+    n = m;
+    }
+
+    ++n;
+    return n;
+  }
+  static_assert(foo() == 14, "");
+}



More information about the cfe-commits mailing list