[clang] [clang][bytecode] Check types when loading values (PR #165385)

via cfe-commits cfe-commits at lists.llvm.org
Tue Oct 28 06:17:32 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Timm Baeder (tbaederr)

<details>
<summary>Changes</summary>

We need to allow BitCasts between pointer types whose pointee types have different primitive types, which means the resulting type mismatch has to be caught at a later stage, i.e. when loading the values.

Fixes https://github.com/llvm/llvm-project/issues/158527 Fixes https://github.com/llvm/llvm-project/issues/163778

---
Full diff: https://github.com/llvm/llvm-project/pull/165385.diff


6 Files Affected:

- (modified) clang/lib/AST/ByteCode/Compiler.cpp (+21-1) 
- (modified) clang/lib/AST/ByteCode/Compiler.h (+2) 
- (modified) clang/lib/AST/ByteCode/Interp.h (+19-8) 
- (modified) clang/lib/AST/ByteCode/PrimType.h (+4) 
- (modified) clang/lib/AST/ByteCode/Program.cpp (+11-39) 
- (modified) clang/test/AST/ByteCode/invalid.cpp (+23) 


``````````diff
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 6c088469a3ca2..18e6bbcc82bd3 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -208,6 +208,19 @@ template <class Emitter> class LocOverrideScope final {
 } // namespace interp
 } // namespace clang
 
+template <class Emitter>
+bool Compiler<Emitter>::isValidBitCast(const CastExpr *E) {
+  QualType FromTy = E->getSubExpr()->getType()->getPointeeType();
+  QualType ToTy = E->getType()->getPointeeType();
+
+  if (classify(FromTy) == classify(ToTy))
+    return true;
+
+  if (FromTy->isVoidType() || ToTy->isVoidType())
+    return true;
+  return false;
+}
+
 template <class Emitter>
 bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
   const Expr *SubExpr = CE->getSubExpr();
@@ -476,8 +489,9 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     return this->delegate(SubExpr);
 
   case CK_BitCast: {
+    QualType CETy = CE->getType();
     // Reject bitcasts to atomic types.
-    if (CE->getType()->isAtomicType()) {
+    if (CETy->isAtomicType()) {
       if (!this->discard(SubExpr))
         return false;
       return this->emitInvalidCast(CastKind::Reinterpret, /*Fatal=*/true, CE);
@@ -492,6 +506,12 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     if (!FromT || !ToT)
       return false;
 
+    if (!this->isValidBitCast(CE)) {
+      if (!this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/false,
+                                 CE))
+        return false;
+    }
+
     assert(isPtrType(*FromT));
     assert(isPtrType(*ToT));
     if (FromT == ToT) {
diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h
index 5c46f75af4da3..fac0a7f4e1886 100644
--- a/clang/lib/AST/ByteCode/Compiler.h
+++ b/clang/lib/AST/ByteCode/Compiler.h
@@ -425,6 +425,8 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
 
   bool refersToUnion(const Expr *E);
 
+  bool isValidBitCast(const CastExpr *E);
+
 protected:
   /// Variable to storage mapping.
   llvm::DenseMap<const ValueDecl *, Scope::Local> Locals;
diff --git a/clang/lib/AST/ByteCode/Interp.h b/clang/lib/AST/ByteCode/Interp.h
index 89f6fbefb1907..36f01e4e4496b 100644
--- a/clang/lib/AST/ByteCode/Interp.h
+++ b/clang/lib/AST/ByteCode/Interp.h
@@ -1914,6 +1914,9 @@ bool Load(InterpState &S, CodePtr OpPC) {
     return false;
   if (!Ptr.isBlockPointer())
     return false;
+  if (!Ptr.getFieldDesc()->isPrimitive() ||
+      Ptr.getFieldDesc()->getPrimType() != Name)
+    return false;
   S.Stk.push<T>(Ptr.deref<T>());
   return true;
 }
@@ -1925,6 +1928,9 @@ bool LoadPop(InterpState &S, CodePtr OpPC) {
     return false;
   if (!Ptr.isBlockPointer())
     return false;
+  if (!Ptr.getFieldDesc()->isPrimitive() ||
+      Ptr.getFieldDesc()->getPrimType() != Name)
+    return false;
   S.Stk.push<T>(Ptr.deref<T>());
   return true;
 }
@@ -3286,12 +3292,18 @@ inline bool InvalidCast(InterpState &S, CodePtr OpPC, CastKind Kind,
                         bool Fatal) {
   const SourceLocation &Loc = S.Current->getLocation(OpPC);
 
-  if (Kind == CastKind::Reinterpret) {
+  switch (Kind) {
+  case CastKind::Reinterpret:
     S.CCEDiag(Loc, diag::note_constexpr_invalid_cast)
-        << static_cast<unsigned>(Kind) << S.Current->getRange(OpPC);
+        << diag::ConstexprInvalidCastKind::Reinterpret
+        << S.Current->getRange(OpPC);
     return !Fatal;
-  }
-  if (Kind == CastKind::Volatile) {
+  case CastKind::ReinterpretLike:
+    S.CCEDiag(Loc, diag::note_constexpr_invalid_cast)
+        << diag::ConstexprInvalidCastKind::ThisConversionOrReinterpret
+        << S.getLangOpts().CPlusPlus << S.Current->getRange(OpPC);
+    return !Fatal;
+  case CastKind::Volatile:
     if (!S.checkingPotentialConstantExpression()) {
       const auto *E = cast<CastExpr>(S.Current->getExpr(OpPC));
       if (S.getLangOpts().CPlusPlus)
@@ -3302,14 +3314,13 @@ inline bool InvalidCast(InterpState &S, CodePtr OpPC, CastKind Kind,
     }
 
     return false;
-  }
-  if (Kind == CastKind::Dynamic) {
+  case CastKind::Dynamic:
     assert(!S.getLangOpts().CPlusPlus20);
-    S.CCEDiag(S.Current->getSource(OpPC), diag::note_constexpr_invalid_cast)
+    S.CCEDiag(Loc, diag::note_constexpr_invalid_cast)
         << diag::ConstexprInvalidCastKind::Dynamic;
     return true;
   }
-
+  llvm_unreachable("Unhandled CastKind");
   return false;
 }
 
diff --git a/clang/lib/AST/ByteCode/PrimType.h b/clang/lib/AST/ByteCode/PrimType.h
index 54fd39ac6fcc8..f0454b484ff98 100644
--- a/clang/lib/AST/ByteCode/PrimType.h
+++ b/clang/lib/AST/ByteCode/PrimType.h
@@ -101,6 +101,7 @@ inline constexpr bool isSignedType(PrimType T) {
 
 enum class CastKind : uint8_t {
   Reinterpret,
+  ReinterpretLike,
   Volatile,
   Dynamic,
 };
@@ -111,6 +112,9 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
   case interp::CastKind::Reinterpret:
     OS << "reinterpret_cast";
     break;
+  case interp::CastKind::ReinterpretLike:
+    OS << "reinterpret_like";
+    break;
   case interp::CastKind::Volatile:
     OS << "volatile";
     break;
diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp
index e0b2852f0e906..465f4d15def73 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -36,30 +36,19 @@ unsigned Program::createGlobalString(const StringLiteral *S, const Expr *Base) {
   const size_t BitWidth = CharWidth * Ctx.getCharBit();
   unsigned StringLength = S->getLength();
 
-  PrimType CharType;
-  switch (CharWidth) {
-  case 1:
-    CharType = PT_Sint8;
-    break;
-  case 2:
-    CharType = PT_Uint16;
-    break;
-  case 4:
-    CharType = PT_Uint32;
-    break;
-  default:
-    llvm_unreachable("unsupported character width");
-  }
+  OptPrimType CharType =
+      Ctx.classify(S->getType()->castAsArrayTypeUnsafe()->getElementType());
+  assert(CharType);
 
   if (!Base)
     Base = S;
 
   // Create a descriptor for the string.
-  Descriptor *Desc =
-      allocateDescriptor(Base, CharType, Descriptor::GlobalMD, StringLength + 1,
-                         /*isConst=*/true,
-                         /*isTemporary=*/false,
-                         /*isMutable=*/false);
+  Descriptor *Desc = allocateDescriptor(Base, *CharType, Descriptor::GlobalMD,
+                                        StringLength + 1,
+                                        /*isConst=*/true,
+                                        /*isTemporary=*/false,
+                                        /*isMutable=*/false);
 
   // Allocate storage for the string.
   // The byte length does not include the null terminator.
@@ -79,26 +68,9 @@ unsigned Program::createGlobalString(const StringLiteral *S, const Expr *Base) {
   } else {
     // Construct the string in storage.
     for (unsigned I = 0; I <= StringLength; ++I) {
-      const uint32_t CodePoint = I == StringLength ? 0 : S->getCodeUnit(I);
-      switch (CharType) {
-      case PT_Sint8: {
-        using T = PrimConv<PT_Sint8>::T;
-        Ptr.elem<T>(I) = T::from(CodePoint, BitWidth);
-        break;
-      }
-      case PT_Uint16: {
-        using T = PrimConv<PT_Uint16>::T;
-        Ptr.elem<T>(I) = T::from(CodePoint, BitWidth);
-        break;
-      }
-      case PT_Uint32: {
-        using T = PrimConv<PT_Uint32>::T;
-        Ptr.elem<T>(I) = T::from(CodePoint, BitWidth);
-        break;
-      }
-      default:
-        llvm_unreachable("unsupported character type");
-      }
+      uint32_t CodePoint = I == StringLength ? 0 : S->getCodeUnit(I);
+      INT_TYPE_SWITCH_NO_BOOL(*CharType,
+                              Ptr.elem<T>(I) = T::from(CodePoint, BitWidth););
     }
   }
   Ptr.initializeAllElements();
diff --git a/clang/test/AST/ByteCode/invalid.cpp b/clang/test/AST/ByteCode/invalid.cpp
index 00db27419e36b..1f2d6bc1d48eb 100644
--- a/clang/test/AST/ByteCode/invalid.cpp
+++ b/clang/test/AST/ByteCode/invalid.cpp
@@ -66,3 +66,26 @@ struct S {
 S s;
 S *sp[2] = {&s, &s};
 S *&spp = sp[1];
+
+namespace InvalidBitCast {
+  void foo() {
+    const long long int i = 1; // both-note {{declared const here}}
+    if (*(double *)&i == 2) {
+      i = 0; // both-error {{cannot assign to variable}}
+    }
+  }
+
+  struct S2 {
+    void *p;
+  };
+  struct T {
+    S2 s;
+  };
+  constexpr T t = {{nullptr}};
+  constexpr void *foo2() { return ((void **)&t)[0]; } // both-error {{never produces a constant expression}} \
+                                                      // both-note 2{{cast that performs the conversions of a reinterpret_cast}}
+  constexpr auto x = foo2(); // both-error {{must be initialized by a constant expression}} \
+                             // both-note {{in call to}}
+
+
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/165385


More information about the cfe-commits mailing list