[clang] 12a7897 - [clang][Interp] BaseToDerived casts

Timm Bäder via cfe-commits cfe-commits at lists.llvm.org
Tue Sep 5 01:54:17 PDT 2023


Author: Timm Bäder
Date: 2023-09-05T10:53:54+02:00
New Revision: 12a789710e2bafe9bb49adea5e634784b086eb6a

URL: https://github.com/llvm/llvm-project/commit/12a789710e2bafe9bb49adea5e634784b086eb6a
DIFF: https://github.com/llvm/llvm-project/commit/12a789710e2bafe9bb49adea5e634784b086eb6a.diff

LOG: [clang][Interp] BaseToDerived casts

We can implement these similarly to DerivedToBase casts. We just have to
walk the class hierarchy, sum the base offsets, and subtract that sum from
the pointer's current offset.

Differential Revision: https://reviews.llvm.org/D149133

Added: 
    

Modified: 
    clang/lib/AST/Interp/ByteCodeExprGen.cpp
    clang/lib/AST/Interp/ByteCodeExprGen.h
    clang/lib/AST/Interp/Interp.cpp
    clang/lib/AST/Interp/Interp.h
    clang/lib/AST/Interp/Opcodes.td
    clang/lib/AST/Interp/Pointer.h
    clang/test/AST/Interp/records.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index b2177c29d30d5af..4708799773603de 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -92,8 +92,20 @@ bool ByteCodeExprGen<Emitter>::VisitCastExpr(const CastExpr *CE) {
     if (!this->visit(SubExpr))
       return false;
 
-    return this->emitDerivedToBaseCasts(getRecordTy(SubExpr->getType()),
-                                        getRecordTy(CE->getType()), CE);
+    unsigned DerivedOffset = collectBaseOffset(getRecordTy(CE->getType()),
+                                               getRecordTy(SubExpr->getType()));
+
+    return this->emitGetPtrBasePop(DerivedOffset, CE);
+  }
+
+  case CK_BaseToDerived: {
+    if (!this->visit(SubExpr))
+      return false;
+
+    unsigned DerivedOffset = collectBaseOffset(getRecordTy(SubExpr->getType()),
+                                               getRecordTy(CE->getType()));
+
+    return this->emitGetPtrDerivedPop(DerivedOffset, CE);
   }
 
   case CK_FloatingCast: {
@@ -2262,13 +2274,15 @@ void ByteCodeExprGen<Emitter>::emitCleanup() {
 }
 
 template <class Emitter>
-bool ByteCodeExprGen<Emitter>::emitDerivedToBaseCasts(
-    const RecordType *DerivedType, const RecordType *BaseType, const Expr *E) {
-  // Pointer of derived type is already on the stack.
+unsigned
+ByteCodeExprGen<Emitter>::collectBaseOffset(const RecordType *BaseType,
+                                            const RecordType *DerivedType) {
   const auto *FinalDecl = cast<CXXRecordDecl>(BaseType->getDecl());
   const RecordDecl *CurDecl = DerivedType->getDecl();
   const Record *CurRecord = getRecord(CurDecl);
   assert(CurDecl && FinalDecl);
+
+  unsigned OffsetSum = 0;
   for (;;) {
     assert(CurRecord->getNumBases() > 0);
     // One level up
@@ -2276,21 +2290,18 @@ bool ByteCodeExprGen<Emitter>::emitDerivedToBaseCasts(
       const auto *BaseDecl = cast<CXXRecordDecl>(B.Decl);
 
       if (BaseDecl == FinalDecl || BaseDecl->isDerivedFrom(FinalDecl)) {
-        // This decl will lead us to the final decl, so emit a base cast.
-        if (!this->emitGetPtrBasePop(B.Offset, E))
-          return false;
-
+        OffsetSum += B.Offset;
         CurRecord = B.R;
         CurDecl = BaseDecl;
         break;
       }
     }
     if (CurDecl == FinalDecl)
-      return true;
+      break;
   }
 
-  llvm_unreachable("Couldn't find the base class?");
-  return false;
+  assert(OffsetSum > 0);
+  return OffsetSum;
 }
 
 /// When calling this, we have a pointer of the local-to-destroy

diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.h b/clang/lib/AST/Interp/ByteCodeExprGen.h
index 3a64a4f6fec0726..cd924e911759e01 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.h
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.h
@@ -267,6 +267,6 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
   bool emitRecordDestruction(const Descriptor *Desc);
-  bool emitDerivedToBaseCasts(const RecordType *DerivedType,
-                              const RecordType *BaseType, const Expr *E);
+  unsigned collectBaseOffset(const RecordType *BaseType,
+                             const RecordType *DerivedType);
 
 protected:
   /// Variable to storage mapping.

diff  --git a/clang/lib/AST/Interp/Interp.cpp b/clang/lib/AST/Interp/Interp.cpp
index 2b7f3bf35aa5982..23c54b3898e102b 100644
--- a/clang/lib/AST/Interp/Interp.cpp
+++ b/clang/lib/AST/Interp/Interp.cpp
@@ -213,6 +213,16 @@ bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
   return false;
 }
 
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK) {
+  if (!Ptr.isOnePastEnd())
+    return true;
+
+  const SourceInfo &Loc = S.Current->getSource(OpPC);
+  S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK;
+  return false;
+}
+
 bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr) {
   assert(Ptr.isLive() && "Pointer is not live");
   if (!Ptr.isConst())

diff  --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index a78b3921b0c7dd3..006abc0f0e94ace 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -67,6 +67,10 @@ bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
 bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
                 CheckSubobjectKind CSK);
 
+/// Checks if accessing a base or derived record of the given pointer is valid.
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK);
+
 /// Checks if a pointer points to const storage.
 bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr);
 
@@ -1157,10 +1161,22 @@ inline bool GetPtrActiveThisField(InterpState &S, CodePtr OpPC, uint32_t Off) {
   return true;
 }
 
+inline bool GetPtrDerivedPop(InterpState &S, CodePtr OpPC, uint32_t Off) {
+  const Pointer &Ptr = S.Stk.pop<Pointer>();
+  if (!CheckNull(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  S.Stk.push<Pointer>(Ptr.atFieldSub(Off));
+  return true;
+}
+
 inline bool GetPtrBase(InterpState &S, CodePtr OpPC, uint32_t Off) {
   const Pointer &Ptr = S.Stk.peek<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
+    return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
 }
@@ -1169,6 +1185,8 @@ inline bool GetPtrBasePop(InterpState &S, CodePtr OpPC, uint32_t Off) {
   const Pointer &Ptr = S.Stk.pop<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
+    return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
 }

diff  --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index f66c9492ecd1d43..8bdc4432e89b410 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -293,6 +293,10 @@ def GetPtrBasePop : Opcode {
   let Args = [ArgUint32];
 }
 
+def GetPtrDerivedPop : Opcode {
+  let Args = [ArgUint32];
+}
+
 // [Pointer] -> [Pointer]
 def GetPtrVirtBase : Opcode {
   // RecordDecl of base class.

diff  --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h
index 48aa4df1430fc1a..f2d8922f896552f 100644
--- a/clang/lib/AST/Interp/Pointer.h
+++ b/clang/lib/AST/Interp/Pointer.h
@@ -109,6 +109,14 @@ class Pointer {
     return Pointer(Pointee, Field, Field);
   }
 
+  /// Returns a pointer into the same block whose Base and Offset are both
+  /// set to this pointer's current Offset minus \p Off (asserts Off <= Offset).
+  Pointer atFieldSub(unsigned Off) const {
+    assert(Offset >= Off);
+    unsigned O = Offset - Off;
+    return Pointer(Pointee, O, O);
+  }
+
   /// Restricts the scope of an array element pointer.
   Pointer narrow() const {
     // Null pointers cannot be narrowed.

diff  --git a/clang/test/AST/Interp/records.cpp b/clang/test/AST/Interp/records.cpp
index 76882d29e201e9e..b559f1f8b95b0a2 100644
--- a/clang/test/AST/Interp/records.cpp
+++ b/clang/test/AST/Interp/records.cpp
@@ -625,6 +625,58 @@ namespace Destructors {
                                // ref-note {{in call to 'testS()'}}
 }
 
+namespace BaseToDerived {
+namespace A {
+  struct A {};
+  struct B : A { int n; };
+  struct C : B {};
+  C c = {};
+  constexpr C *pb = (C*)((A*)&c + 1); // expected-error {{must be initialized by a constant expression}} \
+                                      // expected-note {{cannot access derived class of pointer past the end of object}} \
+                                      // ref-error {{must be initialized by a constant expression}} \
+                                      // ref-note {{cannot access derived class of pointer past the end of object}}
+}
+namespace B {
+  struct A {};
+  struct Z {};
+  struct B : Z, A {
+    int n;
+   constexpr B() : n(10) {}
+  };
+  struct C : B {
+   constexpr C() : B() {}
+  };
+
+  constexpr C c = {};
+  constexpr const A *pa = &c;
+  constexpr const C *cp = (C*)pa;
+  constexpr const B *cb = (B*)cp;
+
+  static_assert(cb->n == 10);
+  static_assert(cp->n == 10);
+}
+
+namespace C {
+  struct Base { int *a; };
+  struct Base2 : Base { int f[12]; };
+
+  struct Middle1 { int b[3]; };
+  struct Middle2 : Base2 { char c; };
+  struct Middle3 : Middle2 { char g[3]; };
+  struct Middle4 { int f[3]; };
+  struct Middle5 : Middle4, Middle3 { char g2[3]; };
+
+  struct NotQuiteDerived : Middle1, Middle5 { bool d; };
+  struct Derived : NotQuiteDerived { int e; };
+
+  constexpr NotQuiteDerived NQD1 = {};
+
+  constexpr Middle5 *M4 = (Middle5*)((Base2*)&NQD1);
+  static_assert(M4->a == nullptr);
+  static_assert(M4->g2[0] == 0);
+}
+}
+
 
 namespace VirtualDtors {
   class A {


        


More information about the cfe-commits mailing list