[clang] [Webkit Checkers] Introduce a Webkit checker for memory unsafe casts (PR #114606)

Rashmi Mudduluru via cfe-commits cfe-commits at lists.llvm.org
Wed Nov 20 15:53:57 PST 2024


https://github.com/t-rasmud updated https://github.com/llvm/llvm-project/pull/114606

>From cc19550fdbaca4b77e90de57c472a31a8c3f8293 Mon Sep 17 00:00:00 2001
From: Rashmi Mudduluru <r_mudduluru at apple.com>
Date: Fri, 1 Nov 2024 14:10:50 -0700
Subject: [PATCH 1/7] [Webkit Checkers] Introduce a Webkit checker for memory
 unsafe casts

The checker warns on all downcasts from a base type to a derived type.

rdar://137766829
---
 clang/docs/analyzer/checkers.rst              |  25 +++
 clang/docs/tools/clang-formatted-files.txt    |   1 +
 .../clang/StaticAnalyzer/Checkers/Checkers.td |   4 +
 .../StaticAnalyzer/Checkers/CMakeLists.txt    |   1 +
 .../WebKit/MemoryUnsafeCastChecker.cpp        |  86 ++++++++++
 .../Checkers/WebKit/memory-unsafe-cast.cpp    | 151 ++++++++++++++++++
 .../Checkers/WebKit/memory-unsafe-cast.mm     |  29 ++++
 7 files changed, 297 insertions(+)
 create mode 100644 clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
 create mode 100644 clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
 create mode 100644 clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm

diff --git a/clang/docs/analyzer/checkers.rst b/clang/docs/analyzer/checkers.rst
index 87b03438e6e0b9..f01755ce7a236a 100644
--- a/clang/docs/analyzer/checkers.rst
+++ b/clang/docs/analyzer/checkers.rst
@@ -3452,6 +3452,31 @@ alpha.WebKit
 
 .. _alpha-webkit-NoUncheckedPtrMemberChecker:
 
+alpha.webkit.MemoryUnsafeCastChecker
+""""""""""""""""""""""""""""""""""""""
+Check for all casts from a base type to its derived type as these might be memory-unsafe.
+
+Example:
+
+.. code-block:: cpp
+
+    class Base { };
+    class Derived : public Base { };
+
+    void f(Base* base) {
+        Derived* derived = static_cast<Derived*>(base); // ERROR
+    }
+
+For all cast operations (C-style casts, static_cast, reinterpret_cast, dynamic_cast), if the source type is ``Base*`` and the destination type is ``Derived*``, where ``Derived`` inherits from ``Base``, the static analyzer should signal an error.
+
+This applies to:
+
+- C structs, C++ structs and classes, and Objective-C classes and protocols.
+- Pointers and references.
+- Inside template instantiations and macro expansions that are visible to the compiler.
+
+For types like this, instead of using built-in casts, the programmer will use helper functions that internally perform the appropriate type check and disable static analysis.
+
 alpha.webkit.NoUncheckedPtrMemberChecker
 """"""""""""""""""""""""""""""""""""""""
 Raw pointers and references to an object which supports CheckedPtr or CheckedRef can't be used as class members. Only CheckedPtr, CheckedRef, RefPtr, or Ref are allowed.
diff --git a/clang/docs/tools/clang-formatted-files.txt b/clang/docs/tools/clang-formatted-files.txt
index 67ff085144f4de..74ab155d6174fd 100644
--- a/clang/docs/tools/clang-formatted-files.txt
+++ b/clang/docs/tools/clang-formatted-files.txt
@@ -537,6 +537,7 @@ clang/lib/StaticAnalyzer/Checkers/UninitializedObject/UninitializedPointee.cpp
 clang/lib/StaticAnalyzer/Checkers/WebKit/ASTUtils.cpp
 clang/lib/StaticAnalyzer/Checkers/WebKit/ASTUtils.h
 clang/lib/StaticAnalyzer/Checkers/WebKit/DiagOutputUtils.h
+clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
 clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
 clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
 clang/lib/StaticAnalyzer/Checkers/WebKit/RefCntblBaseVirtualDtorChecker.cpp
diff --git a/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td b/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
index 9a6b35c1b9f774..445379e88ab9e3 100644
--- a/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
+++ b/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
@@ -1752,6 +1752,10 @@ def UncountedLambdaCapturesChecker : Checker<"UncountedLambdaCapturesChecker">,
 
 let ParentPackage = WebKitAlpha in {
 
+def MemoryUnsafeCastChecker : Checker<"MemoryUnsafeCastChecker">,
+  HelpText<"Check for memory unsafe casts from base type to derived type.">,
+  Documentation<HasDocumentation>;
+
 def NoUncheckedPtrMemberChecker : Checker<"NoUncheckedPtrMemberChecker">,
   HelpText<"Check for no unchecked member variables.">,
   Documentation<HasDocumentation>;
diff --git a/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt b/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
index 62aa5ff7f002a9..7e987740f9ee2d 100644
--- a/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
+++ b/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
@@ -132,6 +132,7 @@ add_clang_library(clangStaticAnalyzerCheckers
   VirtualCallChecker.cpp
   WebKit/RawPtrRefMemberChecker.cpp
   WebKit/ASTUtils.cpp
+  WebKit/MemoryUnsafeCastChecker.cpp
   WebKit/PtrTypesSemantics.cpp
   WebKit/RefCntblBaseVirtualDtorChecker.cpp
   WebKit/UncountedCallArgsChecker.cpp
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
new file mode 100644
index 00000000000000..05a5f89d28c8fe
--- /dev/null
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
@@ -0,0 +1,86 @@
+//=======- MemoryUnsafeCastChecker.cpp -------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines MemoryUnsafeCast checker, which checks for casts from a
+// base type to a derived type.
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/ASTContext.h"
+#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/Checker.h"
+#include "clang/StaticAnalyzer/Core/CheckerManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+
+using namespace clang;
+using namespace ento;
+
+namespace {
+class MemoryUnsafeCastChecker : public Checker<check::PreStmt<CastExpr>> {
+  BugType BT{this, ""};
+
+public:
+  void checkPreStmt(const CastExpr *CE, CheckerContext &C) const;
+};
+} // end namespace
+
+void emitWarning(CheckerContext &C, const CastExpr &CE, const BugType &BT,
+                 QualType FromType, QualType ToType) {
+  ExplodedNode *errorNode = C.generateNonFatalErrorNode();
+  if (!errorNode)
+    return;
+  SmallString<192> Buf;
+  llvm::raw_svector_ostream OS(Buf);
+  OS << "Memory unsafe cast from base type '";
+  QualType::print(FromType.getTypePtr(), Qualifiers(), OS, C.getLangOpts(),
+                  llvm::Twine());
+  OS << "' to derived type '";
+  QualType::print(ToType.getTypePtr(), Qualifiers(), OS, C.getLangOpts(),
+                  llvm::Twine());
+  OS << "'";
+  auto R = std::make_unique<PathSensitiveBugReport>(BT, OS.str(), errorNode);
+  R->addRange(CE.getSourceRange());
+  C.emitReport(std::move(R));
+}
+
+void MemoryUnsafeCastChecker::checkPreStmt(const CastExpr *CE,
+                                           CheckerContext &C) const {
+  auto ExpCast = dyn_cast_or_null<ExplicitCastExpr>(CE);
+  if (!ExpCast)
+    return;
+
+  auto ToDerivedQualType = ExpCast->getTypeAsWritten();
+  auto *SE = CE->getSubExprAsWritten();
+  if (ToDerivedQualType->isObjCObjectPointerType()) {
+    auto FromBaseQualType = SE->getType();
+    bool IsObjCSubType =
+        !C.getASTContext().hasSameType(ToDerivedQualType, FromBaseQualType) &&
+        C.getASTContext().canAssignObjCInterfaces(
+            FromBaseQualType->getAsObjCInterfacePointerType(),
+            ToDerivedQualType->getAsObjCInterfacePointerType());
+    if (IsObjCSubType)
+      emitWarning(C, *CE, BT, FromBaseQualType, ToDerivedQualType);
+    return;
+  }
+  auto ToDerivedType = ToDerivedQualType->getPointeeCXXRecordDecl();
+  auto FromBaseType = SE->getType()->getPointeeCXXRecordDecl();
+  if (!FromBaseType)
+    FromBaseType = SE->getType()->getAsCXXRecordDecl();
+  if (!FromBaseType)
+    return;
+  if (ToDerivedType->isDerivedFrom(FromBaseType))
+    emitWarning(C, *CE, BT, SE->getType(), ToDerivedQualType);
+}
+
+void ento::registerMemoryUnsafeCastChecker(CheckerManager &Mgr) {
+  Mgr.registerChecker<MemoryUnsafeCastChecker>();
+}
+
+bool ento::shouldRegisterMemoryUnsafeCastChecker(const CheckerManager &) {
+  return true;
+}
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
new file mode 100644
index 00000000000000..1a4ef6858d2180
--- /dev/null
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
@@ -0,0 +1,151 @@
+// RUN: %clang_analyze_cc1 -analyzer-checker=alpha.webkit.MemoryUnsafeCastChecker -verify %s
+
+class Base { };
+class Derived : public Base { };
+
+void test_pointers(Base *base) {
+  Derived *derived_static = static_cast<Derived*>(base);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  Derived *derived_reinterpret = reinterpret_cast<Derived*>(base);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  Derived *derived_c = (Derived*)base;
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+}
+
+void test_refs(Base &base) {
+  Derived &derived_static = static_cast<Derived&>(base);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  Derived &derived_reinterpret = reinterpret_cast<Derived&>(base);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  Derived &derived_c = (Derived&)base;
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &'}}
+}
+
+class BaseVirtual {
+  virtual void virtual_base_function();
+};
+
+class DerivedVirtual : public BaseVirtual {
+  void virtual_base_function() override { }
+};
+
+void test_dynamic_casts(BaseVirtual *base_ptr, BaseVirtual &base_ref) {
+  DerivedVirtual *derived_dynamic_ptr = dynamic_cast<DerivedVirtual*>(base_ptr);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseVirtual *' to derived type 'DerivedVirtual *'}}
+  DerivedVirtual &derived_dynamic_ref = dynamic_cast<DerivedVirtual&>(base_ref);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseVirtual' to derived type 'DerivedVirtual &'}}
+}
+
+struct BaseStruct { };
+struct DerivedStruct : BaseStruct { };
+
+void test_struct_pointers(struct BaseStruct *base_struct) {
+  struct DerivedStruct *derived_static = static_cast<struct DerivedStruct*>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+  struct DerivedStruct *derived_reinterpret = reinterpret_cast<struct DerivedStruct*>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+  struct DerivedStruct *derived_c = (struct DerivedStruct*)base_struct;
+  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+}
+
+typedef struct BaseStruct BStruct;
+typedef struct DerivedStruct DStruct;
+
+void test_struct_refs(BStruct &base_struct) {
+  DStruct &derived_static = static_cast<DStruct&>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  DStruct &derived_reinterpret = reinterpret_cast<DStruct&>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  DStruct &derived_c = (DStruct&)base_struct;
+  // expected-warning at -1{{Memory unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+}
+
+int counter = 0;
+void test_recursive(BStruct &base_struct) {
+  if (counter == 5)
+    return;
+  counter++;
+  DStruct &derived_static = static_cast<DStruct&>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+}
+
+template<typename T>
+class BaseTemplate { };
+
+template<typename T>
+class DerivedTemplate : public BaseTemplate<T> { };
+
+void test_templates(BaseTemplate<int> *base, BaseTemplate<int> &base_ref) {
+  DerivedTemplate<int> *derived_static = static_cast<DerivedTemplate<int>*>(base);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  DerivedTemplate<int> *derived_reinterpret = reinterpret_cast<DerivedTemplate<int>*>(base);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  DerivedTemplate<int> *derived_c = (DerivedTemplate<int>*)base;
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  DerivedTemplate<int> &derived_static_ref = static_cast<DerivedTemplate<int>&>(base_ref);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+  DerivedTemplate<int> &derived_reinterpret_ref = reinterpret_cast<DerivedTemplate<int>&>(base_ref);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+  DerivedTemplate<int> &derived_c_ref = (DerivedTemplate<int>&)base_ref;
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+}
+
+#define CAST_MACRO_STATIC(X,Y) (static_cast<Y>(X))
+#define CAST_MACRO_REINTERPRET(X,Y) (reinterpret_cast<Y>(X))
+#define CAST_MACRO_C(X,Y) ((Y)X)
+
+void test_macro_static(Base *base, Derived *derived, Base &base_ref) {
+  Derived *derived_static = CAST_MACRO_STATIC(base, Derived*);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  Derived &derived_static_ref = CAST_MACRO_STATIC(base_ref, Derived&);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  Base *base_static_same = CAST_MACRO_STATIC(base, Base*);  // no warning
+  Base *base_static_upcast = CAST_MACRO_STATIC(derived, Base*);  // no warning
+}
+
+void test_macro_reinterpret(Base *base, Derived *derived, Base &base_ref) {
+  Derived *derived_reinterpret = CAST_MACRO_REINTERPRET(base, Derived*);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  Derived &derived_reinterpret_ref = CAST_MACRO_REINTERPRET(base_ref, Derived&);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  Base *base_reinterpret_same = CAST_MACRO_REINTERPRET(base, Base*);  // no warning
+  Base *base_reinterpret_upcast = CAST_MACRO_REINTERPRET(derived, Base*);  // no warning
+}
+
+void test_macro_c(Base *base, Derived *derived, Base &base_ref) {
+  Derived *derived_c = CAST_MACRO_C(base, Derived*);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  Derived &derived_c_ref = CAST_MACRO_C(base_ref, Derived&);
+  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  Base *base_c_same = CAST_MACRO_C(base, Base*);  // no warning
+  Base *base_c_upcast = CAST_MACRO_C(derived, Base*);  // no warning
+}
+
+struct BaseStructCpp {
+  int t;
+  void increment() { t++; }
+};
+struct DerivedStructCpp : BaseStructCpp {
+  void increment_t() {increment();}
+};
+
+void test_struct_cpp_pointers(struct BaseStructCpp *base_struct) {
+  struct DerivedStructCpp *derived_static = static_cast<struct DerivedStructCpp*>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+  struct DerivedStructCpp *derived_reinterpret = reinterpret_cast<struct DerivedStructCpp*>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+  struct DerivedStructCpp *derived_c = (struct DerivedStructCpp*)base_struct;
+  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+}
+
+typedef struct BaseStructCpp BStructCpp;
+typedef struct DerivedStructCpp DStructCpp;
+
+void test_struct_cpp_refs(BStructCpp &base_struct) {
+  DStructCpp &derived_static = static_cast<DStructCpp&>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  DStructCpp &derived_reinterpret = reinterpret_cast<DStructCpp&>(base_struct);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  DStructCpp &derived_c = (DStructCpp&)base_struct;
+  // expected-warning at -1{{Memory unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+}
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
new file mode 100644
index 00000000000000..1ea2f4aa472715
--- /dev/null
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
@@ -0,0 +1,29 @@
+// RUN: %clang_analyze_cc1 -analyzer-checker=alpha.webkit.MemoryUnsafeCastChecker -verify %s
+
+ at protocol NSObject
++alloc;
+-init;
+ at end
+
+ at interface NSObject <NSObject> {}
+ at end
+
+ at interface BaseClass : NSObject
+ at end
+
+ at interface DerivedClass : BaseClass
+-(void)testCasts:(BaseClass*)base;
+ at end
+
+ at implementation DerivedClass
+-(void)testCasts:(BaseClass*)base {
+  DerivedClass *derived = (DerivedClass*)base;
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  DerivedClass *derived_static = static_cast<DerivedClass*>(base);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  DerivedClass *derived_reinterpret = reinterpret_cast<DerivedClass*>(base);
+  // expected-warning at -1{{Memory unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  base = (BaseClass*)derived;  // no warning
+  base = (BaseClass*)base;  // no warning
+}
+ at end

>From eece6157a062e76bec376beb2aada1edbb3199e7 Mon Sep 17 00:00:00 2001
From: Rashmi Mudduluru <r_mudduluru at apple.com>
Date: Tue, 5 Nov 2024 10:59:22 -0800
Subject: [PATCH 2/7] Address comments: fix tests, make checker
 path-insensitive

---
 .../WebKit/MemoryUnsafeCastChecker.cpp        | 95 ++++++++++++-------
 .../Checkers/WebKit/memory-unsafe-cast.cpp    | 71 +++++++-------
 .../Checkers/WebKit/memory-unsafe-cast.mm     |  6 +-
 3 files changed, 103 insertions(+), 69 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
index 05a5f89d28c8fe..b5d1e0d810c610 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
@@ -10,46 +10,71 @@
 // base type to a derived type.
 //===----------------------------------------------------------------------===//
 
-#include "clang/AST/ASTContext.h"
 #include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
-#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/AST/StmtVisitor.h"
+#include "clang/Analysis/AnalysisDeclContext.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugReporter.h"
 #include "clang/StaticAnalyzer/Core/Checker.h"
-#include "clang/StaticAnalyzer/Core/CheckerManager.h"
-#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/AnalysisManager.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace clang;
 using namespace ento;
 
 namespace {
-class MemoryUnsafeCastChecker : public Checker<check::PreStmt<CastExpr>> {
-  BugType BT{this, ""};
+class WalkAST : public StmtVisitor<WalkAST> {
+  BugReporter &BR;
+  const CheckerBase *Checker;
+  AnalysisDeclContext* AC;
+  ASTContext &ASTC;
 
 public:
-  void checkPreStmt(const CastExpr *CE, CheckerContext &C) const;
+  WalkAST(BugReporter &br, const CheckerBase *checker, AnalysisDeclContext *ac)
+      : BR(br), Checker(checker), AC(ac), ASTC(AC->getASTContext()) {}
+
+  // Statement visitor methods.
+  void VisitChildren(Stmt *S);
+  void VisitStmt(Stmt *S) { VisitChildren(S); }
+  void VisitCastExpr(CastExpr *CE);
 };
-} // end namespace
+} // end anonymous namespace
 
-void emitWarning(CheckerContext &C, const CastExpr &CE, const BugType &BT,
-                 QualType FromType, QualType ToType) {
-  ExplodedNode *errorNode = C.generateNonFatalErrorNode();
-  if (!errorNode)
-    return;
-  SmallString<192> Buf;
-  llvm::raw_svector_ostream OS(Buf);
-  OS << "Memory unsafe cast from base type '";
-  QualType::print(FromType.getTypePtr(), Qualifiers(), OS, C.getLangOpts(),
-                  llvm::Twine());
-  OS << "' to derived type '";
-  QualType::print(ToType.getTypePtr(), Qualifiers(), OS, C.getLangOpts(),
-                  llvm::Twine());
-  OS << "'";
-  auto R = std::make_unique<PathSensitiveBugReport>(BT, OS.str(), errorNode);
-  R->addRange(CE.getSourceRange());
-  C.emitReport(std::move(R));
+void emitWarning(QualType FromType, QualType ToType,
+                 AnalysisDeclContext *AC, BugReporter &BR,
+                 const CheckerBase *Checker,
+                 CastExpr *CE) {
+  std::string Diagnostics;
+  llvm::raw_string_ostream OS(Diagnostics);
+  OS << "Unsafe cast from base type '"
+     << FromType
+     << "' to derived type '"
+     << ToType
+     << "'",
+
+  BR.EmitBasicReport(
+    AC->getDecl(),
+    Checker,
+    /*Name=*/"Memory unsafe cast",
+    categories::SecurityError,
+    Diagnostics,
+    PathDiagnosticLocation::createBegin(CE, BR.getSourceManager(), AC),
+    CE->getSourceRange());
 }
 
-void MemoryUnsafeCastChecker::checkPreStmt(const CastExpr *CE,
-                                           CheckerContext &C) const {
+namespace {
+class MemoryUnsafeCastChecker : public Checker<check::ASTCodeBody> {
+  BugType BT{this, ""};
+public:
+  void checkASTCodeBody(const Decl *D, AnalysisManager& Mgr,
+                        BugReporter &BR) const {
+    WalkAST walker(BR, this, Mgr.getAnalysisDeclContext(D));
+    walker.Visit(D->getBody());
+  }
+};
+}
+
+void WalkAST::VisitCastExpr(CastExpr *CE) {
   auto ExpCast = dyn_cast_or_null<ExplicitCastExpr>(CE);
   if (!ExpCast)
     return;
@@ -59,12 +84,12 @@ void MemoryUnsafeCastChecker::checkPreStmt(const CastExpr *CE,
   if (ToDerivedQualType->isObjCObjectPointerType()) {
     auto FromBaseQualType = SE->getType();
     bool IsObjCSubType =
-        !C.getASTContext().hasSameType(ToDerivedQualType, FromBaseQualType) &&
-        C.getASTContext().canAssignObjCInterfaces(
+        !ASTC.hasSameType(ToDerivedQualType, FromBaseQualType) &&
+        ASTC.canAssignObjCInterfaces(
             FromBaseQualType->getAsObjCInterfacePointerType(),
             ToDerivedQualType->getAsObjCInterfacePointerType());
     if (IsObjCSubType)
-      emitWarning(C, *CE, BT, FromBaseQualType, ToDerivedQualType);
+      emitWarning(SE->getType(), ToDerivedQualType,AC, BR, Checker, CE);
     return;
   }
   auto ToDerivedType = ToDerivedQualType->getPointeeCXXRecordDecl();
@@ -74,13 +99,19 @@ void MemoryUnsafeCastChecker::checkPreStmt(const CastExpr *CE,
   if (!FromBaseType)
     return;
   if (ToDerivedType->isDerivedFrom(FromBaseType))
-    emitWarning(C, *CE, BT, SE->getType(), ToDerivedQualType);
+    emitWarning(SE->getType(), ToDerivedQualType, AC, BR, Checker, CE);
+}
+
+void WalkAST::VisitChildren(Stmt *S) {
+  for (Stmt *Child : S->children())
+    if (Child)
+      Visit(Child);
 }
 
 void ento::registerMemoryUnsafeCastChecker(CheckerManager &Mgr) {
   Mgr.registerChecker<MemoryUnsafeCastChecker>();
 }
 
-bool ento::shouldRegisterMemoryUnsafeCastChecker(const CheckerManager &) {
+bool ento::shouldRegisterMemoryUnsafeCastChecker(const CheckerManager &mgr) {
   return true;
 }
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
index 1a4ef6858d2180..ed58ac4ae0d5b7 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
@@ -5,20 +5,20 @@ class Derived : public Base { };
 
 void test_pointers(Base *base) {
   Derived *derived_static = static_cast<Derived*>(base);
-  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
   Derived *derived_reinterpret = reinterpret_cast<Derived*>(base);
-  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
   Derived *derived_c = (Derived*)base;
-  // expected-warning at -1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
 }
 
 void test_refs(Base &base) {
   Derived &derived_static = static_cast<Derived&>(base);
-  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
   Derived &derived_reinterpret = reinterpret_cast<Derived&>(base);
-  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
   Derived &derived_c = (Derived&)base;
-  // expected-warning at -1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
 }
 
 class BaseVirtual {
@@ -31,9 +31,9 @@ class DerivedVirtual : public BaseVirtual {
 
 void test_dynamic_casts(BaseVirtual *base_ptr, BaseVirtual &base_ref) {
   DerivedVirtual *derived_dynamic_ptr = dynamic_cast<DerivedVirtual*>(base_ptr);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BaseVirtual *' to derived type 'DerivedVirtual *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BaseVirtual *' to derived type 'DerivedVirtual *'}}
   DerivedVirtual &derived_dynamic_ref = dynamic_cast<DerivedVirtual&>(base_ref);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BaseVirtual' to derived type 'DerivedVirtual &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BaseVirtual' to derived type 'DerivedVirtual &'}}
 }
 
 struct BaseStruct { };
@@ -41,11 +41,11 @@ struct DerivedStruct : BaseStruct { };
 
 void test_struct_pointers(struct BaseStruct *base_struct) {
   struct DerivedStruct *derived_static = static_cast<struct DerivedStruct*>(base_struct);
-  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
   struct DerivedStruct *derived_reinterpret = reinterpret_cast<struct DerivedStruct*>(base_struct);
-  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
   struct DerivedStruct *derived_c = (struct DerivedStruct*)base_struct;
-  // expected-warning at -1{{Memory unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
 }
 
 typedef struct BaseStruct BStruct;
@@ -53,11 +53,11 @@ typedef struct DerivedStruct DStruct;
 
 void test_struct_refs(BStruct &base_struct) {
   DStruct &derived_static = static_cast<DStruct&>(base_struct);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
   DStruct &derived_reinterpret = reinterpret_cast<DStruct&>(base_struct);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
   DStruct &derived_c = (DStruct&)base_struct;
-  // expected-warning at -1{{Memory unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
 }
 
 int counter = 0;
@@ -66,7 +66,7 @@ void test_recursive(BStruct &base_struct) {
     return;
   counter++;
   DStruct &derived_static = static_cast<DStruct&>(base_struct);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
 }
 
 template<typename T>
@@ -77,17 +77,17 @@ class DerivedTemplate : public BaseTemplate<T> { };
 
 void test_templates(BaseTemplate<int> *base, BaseTemplate<int> &base_ref) {
   DerivedTemplate<int> *derived_static = static_cast<DerivedTemplate<int>*>(base);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
   DerivedTemplate<int> *derived_reinterpret = reinterpret_cast<DerivedTemplate<int>*>(base);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
   DerivedTemplate<int> *derived_c = (DerivedTemplate<int>*)base;
-  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
   DerivedTemplate<int> &derived_static_ref = static_cast<DerivedTemplate<int>&>(base_ref);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
   DerivedTemplate<int> &derived_reinterpret_ref = reinterpret_cast<DerivedTemplate<int>&>(base_ref);
-  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
   DerivedTemplate<int> &derived_c_ref = (DerivedTemplate<int>&)base_ref;
-  // expected-warning at -1{{Memory unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+  // expected-warning at -1{{Unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
 }
 
 #define CAST_MACRO_STATIC(X,Y) (static_cast<Y>(X))
@@ -96,27 +96,27 @@ void test_templates(BaseTemplate<int> *base, BaseTemplate<int> &base_ref) {
 
 void test_macro_static(Base *base, Derived *derived, Base &base_ref) {
   Derived *derived_static = CAST_MACRO_STATIC(base, Derived*);
-  // expected-warning@-1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
   Derived &derived_static_ref = CAST_MACRO_STATIC(base_ref, Derived&);
-  // expected-warning@-1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
   Base *base_static_same = CAST_MACRO_STATIC(base, Base*);  // no warning
   Base *base_static_upcast = CAST_MACRO_STATIC(derived, Base*);  // no warning
 }
 
 void test_macro_reinterpret(Base *base, Derived *derived, Base &base_ref) {
   Derived *derived_reinterpret = CAST_MACRO_REINTERPRET(base, Derived*);
-  // expected-warning@-1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
   Derived &derived_reinterpret_ref = CAST_MACRO_REINTERPRET(base_ref, Derived&);
-  // expected-warning@-1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &' [alpha.webkit.MemoryUnsafeCastChecker]}}
   Base *base_reinterpret_same = CAST_MACRO_REINTERPRET(base, Base*);  // no warning
   Base *base_reinterpret_upcast = CAST_MACRO_REINTERPRET(derived, Base*);  // no warning
 }
 
 void test_macro_c(Base *base, Derived *derived, Base &base_ref) {
   Derived *derived_c = CAST_MACRO_C(base, Derived*);
-  // expected-warning@-1{{Memory unsafe cast from base type 'Base *' to derived type 'Derived *' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *' [alpha.webkit.MemoryUnsafeCastChecker]}}
   Derived &derived_c_ref = CAST_MACRO_C(base_ref, Derived&);
-  // expected-warning@-1{{Memory unsafe cast from base type 'Base' to derived type 'Derived &' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &' [alpha.webkit.MemoryUnsafeCastChecker]}}
   Base *base_c_same = CAST_MACRO_C(base, Base*);  // no warning
   Base *base_c_upcast = CAST_MACRO_C(derived, Base*);  // no warning
 }
@@ -131,21 +131,24 @@ struct DerivedStructCpp : BaseStructCpp {
 
 void test_struct_cpp_pointers(struct BaseStructCpp *base_struct) {
   struct DerivedStructCpp *derived_static = static_cast<struct DerivedStructCpp*>(base_struct);
-  // expected-warning@-1{{Memory unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
   struct DerivedStructCpp *derived_reinterpret = reinterpret_cast<struct DerivedStructCpp*>(base_struct);
-  // expected-warning@-1{{Memory unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
   struct DerivedStructCpp *derived_c = (struct DerivedStructCpp*)base_struct;
-  // expected-warning@-1{{Memory unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
 }
 
 typedef struct BaseStructCpp BStructCpp;
 typedef struct DerivedStructCpp DStructCpp;
 
-void test_struct_cpp_refs(BStructCpp &base_struct) {
+void test_struct_cpp_refs(BStructCpp &base_struct, DStructCpp &derived_struct) {
   DStructCpp &derived_static = static_cast<DStructCpp&>(base_struct);
-  // expected-warning@-1{{Memory unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
   DStructCpp &derived_reinterpret = reinterpret_cast<DStructCpp&>(base_struct);
-  // expected-warning@-1{{Memory unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
   DStructCpp &derived_c = (DStructCpp&)base_struct;
-  // expected-warning@-1{{Memory unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  BStructCpp &base = (BStructCpp&)derived_struct; // no warning
+  BStructCpp &base_static = static_cast<BStructCpp&>(derived_struct); // no warning
+  BStructCpp &base_reinterpret = reinterpret_cast<BStructCpp&>(derived_struct); // no warning
 }
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
index 1ea2f4aa472715..61f2fffd43fbda 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
@@ -18,11 +18,11 @@ -(void)testCasts:(BaseClass*)base;
 @implementation DerivedClass
 -(void)testCasts:(BaseClass*)base {
   DerivedClass *derived = (DerivedClass*)base;
-  // expected-warning@-1{{Memory unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
   DerivedClass *derived_static = static_cast<DerivedClass*>(base);
-  // expected-warning@-1{{Memory unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
   DerivedClass *derived_reinterpret = reinterpret_cast<DerivedClass*>(base);
-  // expected-warning@-1{{Memory unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
   base = (BaseClass*)derived;  // no warning
   base = (BaseClass*)base;  // no warning
 }

>From 8e30c93bd7a9175f04cae658bb89ffa149b8b256 Mon Sep 17 00:00:00 2001
From: Rashmi Mudduluru <r_mudduluru at apple.com>
Date: Wed, 6 Nov 2024 10:56:48 -0800
Subject: [PATCH 3/7] Add tests: define downcast and ensure no warning

---
 .../Checkers/WebKit/memory-unsafe-cast.cpp         | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
index ed58ac4ae0d5b7..72a054d0652d7d 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
@@ -3,6 +3,18 @@
 class Base { };
 class Derived : public Base { };
 
+template<typename Target, typename Source>
+Target& downcast_ref(Source& source){
+  [[clang::suppress]]
+  return static_cast<Target&>(source);
+}
+
+template<typename Target, typename Source>
+Target* downcast_ptr(Source* source){
+  [[clang::suppress]]
+  return static_cast<Target*>(source);
+}
+
 void test_pointers(Base *base) {
   Derived *derived_static = static_cast<Derived*>(base);
   // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
@@ -10,6 +22,7 @@ void test_pointers(Base *base) {
   // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
   Derived *derived_c = (Derived*)base;
   // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  Derived *derived_d = downcast_ptr<Derived, Base>(base);  // no warning
 }
 
 void test_refs(Base &base) {
@@ -19,6 +32,7 @@ void test_refs(Base &base) {
   // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
   Derived &derived_c = (Derived&)base;
   // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  Derived &derived_d = downcast_ref<Derived, Base>(base);  // no warning
 }
 
 class BaseVirtual {

>From 073abfca2c856b4d19fc65da1dcc51088d4b3a31 Mon Sep 17 00:00:00 2001
From: Rashmi Mudduluru <r_mudduluru at apple.com>
Date: Thu, 7 Nov 2024 15:41:50 -0800
Subject: [PATCH 4/7] Fix crashes and add corresponding reproducer test cases

---
 .../WebKit/MemoryUnsafeCastChecker.cpp        | 21 +++++++++++-----
 .../Checkers/WebKit/memory-unsafe-cast.cpp    |  8 +++++++
 .../Checkers/WebKit/memory-unsafe-cast.mm     | 24 +++++++++++++++++++
 3 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
index b5d1e0d810c610..ff659c0e4af18f 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
@@ -64,7 +64,7 @@ void emitWarning(QualType FromType, QualType ToType,
 
 namespace {
 class MemoryUnsafeCastChecker : public Checker<check::ASTCodeBody> {
-  BugType BT{this, ""};
+  BugType BT{this, "Unsafe cast", "WebKit coding guidelines"};
 public:
   void checkASTCodeBody(const Decl *D, AnalysisManager& Mgr,
                         BugReporter &BR) const {
@@ -83,21 +83,30 @@ void WalkAST::VisitCastExpr(CastExpr *CE) {
   auto *SE = CE->getSubExprAsWritten();
   if (ToDerivedQualType->isObjCObjectPointerType()) {
     auto FromBaseQualType = SE->getType();
+    auto BaseObjCPtrType = FromBaseQualType->getAsObjCInterfacePointerType();
+    if (!BaseObjCPtrType)
+      return;
+    auto DerivedObjCPtrType = ToDerivedQualType->getAsObjCInterfacePointerType();
+    if (!DerivedObjCPtrType)
+      return;
     bool IsObjCSubType =
         !ASTC.hasSameType(ToDerivedQualType, FromBaseQualType) &&
-        ASTC.canAssignObjCInterfaces(
-            FromBaseQualType->getAsObjCInterfacePointerType(),
-            ToDerivedQualType->getAsObjCInterfacePointerType());
+        ASTC.canAssignObjCInterfaces(FromBaseQualType
+                                     ->getAsObjCInterfacePointerType(),
+                                     ToDerivedQualType
+                                     ->getAsObjCInterfacePointerType());
     if (IsObjCSubType)
       emitWarning(SE->getType(), ToDerivedQualType,AC, BR, Checker, CE);
     return;
   }
   auto ToDerivedType = ToDerivedQualType->getPointeeCXXRecordDecl();
+  if (!ToDerivedType || !ToDerivedType->hasDefinition())
+      return;
   auto FromBaseType = SE->getType()->getPointeeCXXRecordDecl();
   if (!FromBaseType)
     FromBaseType = SE->getType()->getAsCXXRecordDecl();
-  if (!FromBaseType)
-    return;
+  if (!FromBaseType || !FromBaseType->hasDefinition())
+      return;
   if (ToDerivedType->isDerivedFrom(FromBaseType))
     emitWarning(SE->getType(), ToDerivedQualType, AC, BR, Checker, CE);
 }
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
index 72a054d0652d7d..e59c8fa71ab288 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
@@ -166,3 +166,11 @@ void test_struct_cpp_refs(BStructCpp &base_struct, DStructCpp &derived_struct) {
   BStructCpp &base_static = static_cast<BStructCpp&>(derived_struct); // no warning
   BStructCpp &base_reinterpret = reinterpret_cast<BStructCpp&>(derived_struct); // no warning
 }
+
+struct stack_st { };
+
+#define STACK_OF(type) struct stack_st_##type
+
+void test_stack(stack_st *base) {
+  STACK_OF(void) *derived = (STACK_OF(void)*)base;
+}
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
index 61f2fffd43fbda..6c49304e1c2c60 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
@@ -27,3 +27,27 @@ -(void)testCasts:(BaseClass*)base {
   base = (BaseClass*)base;  // no warning
 }
 @end
+
+template <typename T>
+class WrappedObject
+{
+public:
+  T get() const { return mMetalObject; }
+  T mMetalObject = nullptr;
+};
+
+@protocol MTLCommandEncoder
+@end
+@protocol MTLRenderCommandEncoder
+@end
+class CommandEncoder : public WrappedObject<id<MTLCommandEncoder>> { };
+
+class RenderCommandEncoder final : public CommandEncoder
+{
+private:
+    // Override CommandEncoder
+    id<MTLRenderCommandEncoder> get()
+    {
+        return static_cast<id<MTLRenderCommandEncoder>>(CommandEncoder::get());
+    }
+};

>From e0455cf2ba9fae9cf2ddfb362c5c15fa2351a7af Mon Sep 17 00:00:00 2001
From: Rashmi Mudduluru <r_mudduluru at apple.com>
Date: Thu, 7 Nov 2024 16:20:23 -0800
Subject: [PATCH 5/7] Minor fixes

---
 .../Checkers/WebKit/MemoryUnsafeCastChecker.cpp            | 7 ++-----
 clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm  | 1 -
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
index ff659c0e4af18f..289cd4f2736a44 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
@@ -91,12 +91,9 @@ void WalkAST::VisitCastExpr(CastExpr *CE) {
       return;
     bool IsObjCSubType =
         !ASTC.hasSameType(ToDerivedQualType, FromBaseQualType) &&
-        ASTC.canAssignObjCInterfaces(FromBaseQualType
-                                     ->getAsObjCInterfacePointerType(),
-                                     ToDerivedQualType
-                                     ->getAsObjCInterfacePointerType());
+        ASTC.canAssignObjCInterfaces(BaseObjCPtrType, DerivedObjCPtrType);
     if (IsObjCSubType)
-      emitWarning(SE->getType(), ToDerivedQualType,AC, BR, Checker, CE);
+      emitWarning(SE->getType(), ToDerivedQualType, AC, BR, Checker, CE);
     return;
   }
   auto ToDerivedType = ToDerivedQualType->getPointeeCXXRecordDecl();
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
index 6c49304e1c2c60..72dcef7d91fe79 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
@@ -45,7 +45,6 @@ @protocol MTLRenderCommandEncoder
 class RenderCommandEncoder final : public CommandEncoder
 {
 private:
-    // Override CommandEncoder
     id<MTLRenderCommandEncoder> get()
     {
         return static_cast<id<MTLRenderCommandEncoder>>(CommandEncoder::get());

>From 2202f599c13aa64803012d1cf1a06290e9d49db1 Mon Sep 17 00:00:00 2001
From: Rashmi Mudduluru <r_mudduluru at apple.com>
Date: Wed, 13 Nov 2024 13:03:40 -0800
Subject: [PATCH 6/7] Use AST Matchers; Introduce AST matcher support for
 Objective C pointers.

---
 clang/include/clang/ASTMatchers/ASTMatchers.h |   4 +-
 .../clang/ASTMatchers/ASTMatchersInternal.h   |   2 +-
 clang/lib/ASTMatchers/ASTMatchersInternal.cpp |   3 +-
 .../WebKit/MemoryUnsafeCastChecker.cpp        | 167 +++++++++---------
 .../Checkers/WebKit/memory-unsafe-cast.cpp    |  66 +++----
 .../Checkers/WebKit/memory-unsafe-cast.mm     |   6 +-
 6 files changed, 129 insertions(+), 119 deletions(-)

diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
index 54e484d41fb1c3..c45b56755a66d1 100644
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -4033,7 +4033,7 @@ AST_POLYMORPHIC_MATCHER_P_OVERLOAD(
 AST_POLYMORPHIC_MATCHER_P_OVERLOAD(
     hasType,
     AST_POLYMORPHIC_SUPPORTED_TYPES(Expr, FriendDecl, ValueDecl,
-                                    CXXBaseSpecifier),
+                                    CXXBaseSpecifier, ObjCInterfaceDecl),
     internal::Matcher<Decl>, InnerMatcher, 1) {
   QualType QT = internal::getUnderlyingType(Node);
   if (!QT.isNull())
@@ -7433,7 +7433,7 @@ extern const AstTypeMatcher<RValueReferenceType> rValueReferenceType;
 AST_TYPELOC_TRAVERSE_MATCHER_DECL(
     pointee, getPointee,
     AST_POLYMORPHIC_SUPPORTED_TYPES(BlockPointerType, MemberPointerType,
-                                    PointerType, ReferenceType));
+                                    PointerType, ReferenceType, ObjCObjectPointerType));
 
 /// Matches typedef types.
 ///
diff --git a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
index ab8b146453e761..e980aa93ba8512 100644
--- a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -1213,7 +1213,7 @@ using HasDeclarationSupportedTypes =
              ElaboratedType, InjectedClassNameType, LabelStmt, AddrLabelExpr,
              MemberExpr, QualType, RecordType, TagType,
              TemplateSpecializationType, TemplateTypeParmType, TypedefType,
-             UnresolvedUsingType, ObjCIvarRefExpr>;
+             UnresolvedUsingType, ObjCIvarRefExpr, ObjCInterfaceDecl>;
 
 /// A Matcher that allows binding the node it matches to an id.
 ///
diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
index 46dd44e6f2b24f..b51e4483f8f2b3 100644
--- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -1097,7 +1097,8 @@ AST_TYPELOC_TRAVERSE_MATCHER_DEF(hasValueType,
 AST_TYPELOC_TRAVERSE_MATCHER_DEF(
     pointee,
     AST_POLYMORPHIC_SUPPORTED_TYPES(BlockPointerType, MemberPointerType,
-                                    PointerType, ReferenceType));
+                                    PointerType, ReferenceType, ObjCObjectPointerType));
+
 
 const internal::VariadicDynCastAllOfMatcher<Stmt, OMPExecutableDirective>
     ompExecutableDirective;
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
index 289cd4f2736a44..cc5c6554692a6b 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
@@ -11,107 +11,116 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
-#include "clang/AST/StmtVisitor.h"
-#include "clang/Analysis/AnalysisDeclContext.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/StaticAnalyzer/Core/BugReporter/BugReporter.h"
 #include "clang/StaticAnalyzer/Core/Checker.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/AnalysisManager.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/Support/raw_ostream.h"
 
 using namespace clang;
 using namespace ento;
+using namespace ast_matchers;
 
 namespace {
-class WalkAST : public StmtVisitor<WalkAST> {
-  BugReporter &BR;
-  const CheckerBase *Checker;
-  AnalysisDeclContext* AC;
-  ASTContext &ASTC;
+static constexpr const char *const BaseNode = "BaseNode";
+static constexpr const char *const DerivedNode = "DerivedNode";
+static constexpr const char *const WarnRecordDecl = "WarnRecordDecl";
 
+class MemoryUnsafeCastChecker : public Checker<check::ASTCodeBody> {
+  BugType BT{this, "Unsafe cast", "WebKit coding guidelines"};
 public:
-  WalkAST(BugReporter &br, const CheckerBase *checker, AnalysisDeclContext *ac)
-      : BR(br), Checker(checker), AC(ac), ASTC(AC->getASTContext()) {}
-
-  // Statement visitor methods.
-  void VisitChildren(Stmt *S);
-  void VisitStmt(Stmt *S) { VisitChildren(S); }
-  void VisitCastExpr(CastExpr *CE);
+  void checkASTCodeBody(const Decl *D, AnalysisManager& Mgr,
+                        BugReporter &BR) const;
 };
-} // end anonymous namespace
+}  // end namespace
+
+static void emitDiagnostics(const BoundNodes &Nodes, BugReporter &BR,
+                            AnalysisDeclContext *ADC,
+                            const MemoryUnsafeCastChecker *Checker) {
+  const auto *CE = Nodes.getNodeAs<CastExpr>(WarnRecordDecl);
+  const NamedDecl *Base = Nodes.getNodeAs<NamedDecl>(BaseNode);
+  const NamedDecl *Derived = Nodes.getNodeAs<NamedDecl>(DerivedNode);
+  assert(CE && Base && Derived);
 
-void emitWarning(QualType FromType, QualType ToType,
-                 AnalysisDeclContext *AC, BugReporter &BR,
-                 const CheckerBase *Checker,
-                 CastExpr *CE) {
   std::string Diagnostics;
   llvm::raw_string_ostream OS(Diagnostics);
-  OS << "Unsafe cast from base type '"
-     << FromType
-     << "' to derived type '"
-     << ToType
-     << "'",
+  OS << "Unsafe cast from base type '" << Base->getNameAsString()
+     << "' to derived type '" << Derived->getNameAsString() << "'";
 
   BR.EmitBasicReport(
-    AC->getDecl(),
-    Checker,
-    /*Name=*/"Memory unsafe cast",
-    categories::SecurityError,
-    Diagnostics,
-    PathDiagnosticLocation::createBegin(CE, BR.getSourceManager(), AC),
-    CE->getSourceRange());
+      ADC->getDecl(), Checker,
+      /*Name=*/"OSObject C-Style Cast", categories::SecurityError,
+      Diagnostics,
+      PathDiagnosticLocation::createBegin(CE, BR.getSourceManager(), ADC),
+      CE->getSourceRange());
 }
 
-namespace {
-class MemoryUnsafeCastChecker : public Checker<check::ASTCodeBody> {
-  BugType BT{this, "Unsafe cast", "WebKit coding guidelines"};
-public:
-  void checkASTCodeBody(const Decl *D, AnalysisManager& Mgr,
-                        BugReporter &BR) const {
-    WalkAST walker(BR, this, Mgr.getAnalysisDeclContext(D));
-    walker.Visit(D->getBody());
-  }
-};
+namespace clang {
+namespace ast_matchers {
+AST_MATCHER_P(StringLiteral, mentionsBoundType, std::string, BindingID) {
+  return Builder->removeBindings([this, &Node](const BoundNodesMap &Nodes) {
+    const auto &BN = Nodes.getNode(this->BindingID);
+    if (const auto *ND = BN.get<NamedDecl>()) {
+      return ND->getName() != Node.getString();
+    }
+    return true;
+  });
 }
+} // end namespace ast_matchers
+} // end namespace clang
 
-void WalkAST::VisitCastExpr(CastExpr *CE) {
-  auto ExpCast = dyn_cast_or_null<ExplicitCastExpr>(CE);
-  if (!ExpCast)
-    return;
-
-  auto ToDerivedQualType = ExpCast->getTypeAsWritten();
-  auto *SE = CE->getSubExprAsWritten();
-  if (ToDerivedQualType->isObjCObjectPointerType()) {
-    auto FromBaseQualType = SE->getType();
-    auto BaseObjCPtrType = FromBaseQualType->getAsObjCInterfacePointerType();
-    if (!BaseObjCPtrType)
-      return;
-    auto DerivedObjCPtrType = ToDerivedQualType->getAsObjCInterfacePointerType();
-    if (!DerivedObjCPtrType)
-      return;
-    bool IsObjCSubType =
-        !ASTC.hasSameType(ToDerivedQualType, FromBaseQualType) &&
-        ASTC.canAssignObjCInterfaces(BaseObjCPtrType, DerivedObjCPtrType);
-    if (IsObjCSubType)
-      emitWarning(SE->getType(), ToDerivedQualType, AC, BR, Checker, CE);
-    return;
-  }
-  auto ToDerivedType = ToDerivedQualType->getPointeeCXXRecordDecl();
-  if (!ToDerivedType || !ToDerivedType->hasDefinition())
-      return;
-  auto FromBaseType = SE->getType()->getPointeeCXXRecordDecl();
-  if (!FromBaseType)
-    FromBaseType = SE->getType()->getAsCXXRecordDecl();
-  if (!FromBaseType || !FromBaseType->hasDefinition())
-      return;
-  if (ToDerivedType->isDerivedFrom(FromBaseType))
-    emitWarning(SE->getType(), ToDerivedQualType, AC, BR, Checker, CE);
+static decltype(auto) hasTypePointingTo(DeclarationMatcher DeclM) {
+  return hasType(pointerType(pointee(hasDeclaration(DeclM))));
 }
 
-void WalkAST::VisitChildren(Stmt *S) {
-  for (Stmt *Child : S->children())
-    if (Child)
-      Visit(Child);
+void MemoryUnsafeCastChecker::checkASTCodeBody(const Decl *D,
+                                               AnalysisManager &AM,
+                                               BugReporter &BR) const {
+
+  AnalysisDeclContext *ADC = AM.getAnalysisDeclContext(D);
+
+  auto MatchExprPtr = allOf(
+      hasSourceExpression(hasTypePointingTo(cxxRecordDecl().bind(BaseNode))),
+      hasTypePointingTo(cxxRecordDecl(isDerivedFrom(equalsBoundNode(BaseNode)))
+                            .bind(DerivedNode)));
+  auto MatchExprPtrObjC = allOf(
+      hasSourceExpression(ignoringImpCasts(hasType(objcObjectPointerType(
+          pointee(hasDeclaration(objcInterfaceDecl().bind(BaseNode))))))),
+      ignoringImpCasts(hasType(objcObjectPointerType(pointee(hasDeclaration(
+          objcInterfaceDecl(isDerivedFrom(equalsBoundNode(BaseNode)))
+              .bind(DerivedNode)))))));
+  auto MatchExprRef =
+      allOf(hasSourceExpression(hasType(cxxRecordDecl().bind(BaseNode))),
+            hasType(cxxRecordDecl(isDerivedFrom(equalsBoundNode(BaseNode)))
+                        .bind(DerivedNode)));
+  auto MatchExprRefTypeDef =
+      allOf(hasSourceExpression(hasType(hasUnqualifiedDesugaredType(recordType(
+                 hasDeclaration(decl(cxxRecordDecl().bind(BaseNode))))))),
+            hasType(hasUnqualifiedDesugaredType(recordType(hasDeclaration(
+                decl(cxxRecordDecl(isDerivedFrom(equalsBoundNode(BaseNode)))
+                         .bind(DerivedNode)))))));
+
+  auto CastC = cStyleCastExpr(anyOf(MatchExprPtr, MatchExprRef,
+                                    MatchExprRefTypeDef, MatchExprPtrObjC))
+                   .bind(WarnRecordDecl);
+  auto CastStatic =
+      cxxStaticCastExpr(anyOf(MatchExprPtr, MatchExprRef, MatchExprRefTypeDef,
+                              MatchExprPtrObjC))
+          .bind(WarnRecordDecl);
+  auto CastReinterpret =
+      cxxReinterpretCastExpr(anyOf(MatchExprPtr, MatchExprRef,
+                                   MatchExprRefTypeDef, MatchExprPtrObjC))
+          .bind(WarnRecordDecl);
+  auto CastDynamic =
+      cxxDynamicCastExpr(anyOf(MatchExprPtr, MatchExprRef, MatchExprRefTypeDef,
+                               MatchExprPtrObjC))
+          .bind(WarnRecordDecl);
+
+  auto Cast = stmt(anyOf(CastC, CastStatic, CastReinterpret, CastDynamic));
+
+  auto Matches =
+      match(stmt(forEachDescendant(Cast)), *D->getBody(), AM.getASTContext());
+  for (BoundNodes Match : Matches)
+    emitDiagnostics(Match, BR, ADC, this);
 }
 
 void ento::registerMemoryUnsafeCastChecker(CheckerManager &Mgr) {
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
index e59c8fa71ab288..c2da099851ced7 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
@@ -17,21 +17,21 @@ Target* downcast_ptr(Source* source){
 
 void test_pointers(Base *base) {
   Derived *derived_static = static_cast<Derived*>(base);
-  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived *derived_reinterpret = reinterpret_cast<Derived*>(base);
-  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived *derived_c = (Derived*)base;
-  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived *derived_d = downcast_ptr<Derived, Base>(base);  // no warning
 }
 
 void test_refs(Base &base) {
   Derived &derived_static = static_cast<Derived&>(base);
-  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived &derived_reinterpret = reinterpret_cast<Derived&>(base);
-  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived &derived_c = (Derived&)base;
-  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived &derived_d = downcast_ref<Derived, Base>(base);  // no warning
 }
 
@@ -45,9 +45,9 @@ class DerivedVirtual : public BaseVirtual {
 
 void test_dynamic_casts(BaseVirtual *base_ptr, BaseVirtual &base_ref) {
   DerivedVirtual *derived_dynamic_ptr = dynamic_cast<DerivedVirtual*>(base_ptr);
-  // expected-warning@-1{{Unsafe cast from base type 'BaseVirtual *' to derived type 'DerivedVirtual *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseVirtual' to derived type 'DerivedVirtual'}}
   DerivedVirtual &derived_dynamic_ref = dynamic_cast<DerivedVirtual&>(base_ref);
-  // expected-warning@-1{{Unsafe cast from base type 'BaseVirtual' to derived type 'DerivedVirtual &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseVirtual' to derived type 'DerivedVirtual'}}
 }
 
 struct BaseStruct { };
@@ -55,11 +55,11 @@ struct DerivedStruct : BaseStruct { };
 
 void test_struct_pointers(struct BaseStruct *base_struct) {
   struct DerivedStruct *derived_static = static_cast<struct DerivedStruct*>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStruct' to derived type 'DerivedStruct'}}
   struct DerivedStruct *derived_reinterpret = reinterpret_cast<struct DerivedStruct*>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStruct' to derived type 'DerivedStruct'}}
   struct DerivedStruct *derived_c = (struct DerivedStruct*)base_struct;
-  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStruct *' to derived type 'struct DerivedStruct *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStruct' to derived type 'DerivedStruct'}}
 }
 
 typedef struct BaseStruct BStruct;
@@ -67,11 +67,11 @@ typedef struct DerivedStruct DStruct;
 
 void test_struct_refs(BStruct &base_struct) {
   DStruct &derived_static = static_cast<DStruct&>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStruct' to derived type 'DerivedStruct'}}
   DStruct &derived_reinterpret = reinterpret_cast<DStruct&>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStruct' to derived type 'DerivedStruct'}}
   DStruct &derived_c = (DStruct&)base_struct;
-  // expected-warning@-1{{Unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStruct' to derived type 'DerivedStruct'}}
 }
 
 int counter = 0;
@@ -80,7 +80,7 @@ void test_recursive(BStruct &base_struct) {
     return;
   counter++;
   DStruct &derived_static = static_cast<DStruct&>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'BStruct' to derived type 'DStruct &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStruct' to derived type 'DerivedStruct'}}
 }
 
 template<typename T>
@@ -91,17 +91,17 @@ class DerivedTemplate : public BaseTemplate<T> { };
 
 void test_templates(BaseTemplate<int> *base, BaseTemplate<int> &base_ref) {
   DerivedTemplate<int> *derived_static = static_cast<DerivedTemplate<int>*>(base);
-  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate' to derived type 'DerivedTemplate'}}
   DerivedTemplate<int> *derived_reinterpret = reinterpret_cast<DerivedTemplate<int>*>(base);
-  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate' to derived type 'DerivedTemplate'}}
   DerivedTemplate<int> *derived_c = (DerivedTemplate<int>*)base;
-  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate<int> *' to derived type 'DerivedTemplate<int> *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate' to derived type 'DerivedTemplate'}}
   DerivedTemplate<int> &derived_static_ref = static_cast<DerivedTemplate<int>&>(base_ref);
-  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate' to derived type 'DerivedTemplate'}}
   DerivedTemplate<int> &derived_reinterpret_ref = reinterpret_cast<DerivedTemplate<int>&>(base_ref);
-  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate' to derived type 'DerivedTemplate'}}
   DerivedTemplate<int> &derived_c_ref = (DerivedTemplate<int>&)base_ref;
-  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate<int>' to derived type 'DerivedTemplate<int> &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseTemplate' to derived type 'DerivedTemplate'}}
 }
 
 #define CAST_MACRO_STATIC(X,Y) (static_cast<Y>(X))
@@ -110,27 +110,27 @@ void test_templates(BaseTemplate<int> *base, BaseTemplate<int> &base_ref) {
 
 void test_macro_static(Base *base, Derived *derived, Base &base_ref) {
   Derived *derived_static = CAST_MACRO_STATIC(base, Derived*);
-  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived &derived_static_ref = CAST_MACRO_STATIC(base_ref, Derived&);
-  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Base *base_static_same = CAST_MACRO_STATIC(base, Base*);  // no warning
   Base *base_static_upcast = CAST_MACRO_STATIC(derived, Base*);  // no warning
 }
 
 void test_macro_reinterpret(Base *base, Derived *derived, Base &base_ref) {
   Derived *derived_reinterpret = CAST_MACRO_REINTERPRET(base, Derived*);
-  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived &derived_reinterpret_ref = CAST_MACRO_REINTERPRET(base_ref, Derived&);
-  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Base *base_reinterpret_same = CAST_MACRO_REINTERPRET(base, Base*);  // no warning
   Base *base_reinterpret_upcast = CAST_MACRO_REINTERPRET(derived, Base*);  // no warning
 }
 
 void test_macro_c(Base *base, Derived *derived, Base &base_ref) {
   Derived *derived_c = CAST_MACRO_C(base, Derived*);
-  // expected-warning@-1{{Unsafe cast from base type 'Base *' to derived type 'Derived *' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Derived &derived_c_ref = CAST_MACRO_C(base_ref, Derived&);
-  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived &' [alpha.webkit.MemoryUnsafeCastChecker]}}
+  // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
   Base *base_c_same = CAST_MACRO_C(base, Base*);  // no warning
   Base *base_c_upcast = CAST_MACRO_C(derived, Base*);  // no warning
 }
@@ -145,11 +145,11 @@ struct DerivedStructCpp : BaseStructCpp {
 
 void test_struct_cpp_pointers(struct BaseStructCpp *base_struct) {
   struct DerivedStructCpp *derived_static = static_cast<struct DerivedStructCpp*>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStructCpp' to derived type 'DerivedStructCpp'}}
   struct DerivedStructCpp *derived_reinterpret = reinterpret_cast<struct DerivedStructCpp*>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStructCpp' to derived type 'DerivedStructCpp'}}
   struct DerivedStructCpp *derived_c = (struct DerivedStructCpp*)base_struct;
-  // expected-warning@-1{{Unsafe cast from base type 'struct BaseStructCpp *' to derived type 'struct DerivedStructCpp *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStructCpp' to derived type 'DerivedStructCpp'}}
 }
 
 typedef struct BaseStructCpp BStructCpp;
@@ -157,11 +157,11 @@ typedef struct DerivedStructCpp DStructCpp;
 
 void test_struct_cpp_refs(BStructCpp &base_struct, DStructCpp &derived_struct) {
   DStructCpp &derived_static = static_cast<DStructCpp&>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStructCpp' to derived type 'DerivedStructCpp'}}
   DStructCpp &derived_reinterpret = reinterpret_cast<DStructCpp&>(base_struct);
-  // expected-warning@-1{{Unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStructCpp' to derived type 'DerivedStructCpp'}}
   DStructCpp &derived_c = (DStructCpp&)base_struct;
-  // expected-warning@-1{{Unsafe cast from base type 'BStructCpp' to derived type 'DStructCpp &'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseStructCpp' to derived type 'DerivedStructCpp'}}
   BStructCpp &base = (BStructCpp&)derived_struct; // no warning
   BStructCpp &base_static = static_cast<BStructCpp&>(derived_struct); // no warning
   BStructCpp &base_reinterpret = reinterpret_cast<BStructCpp&>(derived_struct); // no warning
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
index 72dcef7d91fe79..db8ca7da0b52c1 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
@@ -18,11 +18,11 @@ -(void)testCasts:(BaseClass*)base;
 @implementation DerivedClass
 -(void)testCasts:(BaseClass*)base {
   DerivedClass *derived = (DerivedClass*)base;
-  // expected-warning@-1{{Unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseClass' to derived type 'DerivedClass'}}
   DerivedClass *derived_static = static_cast<DerivedClass*>(base);
-  // expected-warning@-1{{Unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseClass' to derived type 'DerivedClass'}}
   DerivedClass *derived_reinterpret = reinterpret_cast<DerivedClass*>(base);
-  // expected-warning@-1{{Unsafe cast from base type 'BaseClass *' to derived type 'DerivedClass *'}}
+  // expected-warning@-1{{Unsafe cast from base type 'BaseClass' to derived type 'DerivedClass'}}
   base = (BaseClass*)derived;  // no warning
   base = (BaseClass*)base;  // no warning
 }

>From c229740c8a727585456d805a85885e7eada4a0ea Mon Sep 17 00:00:00 2001
From: Rashmi Mudduluru <r_mudduluru at apple.com>
Date: Wed, 20 Nov 2024 15:53:38 -0800
Subject: [PATCH 7/7] Add warnings for unrelated pointer casts

---
 .../WebKit/MemoryUnsafeCastChecker.cpp        | 120 ++++++++++++++----
 .../Checkers/WebKit/memory-unsafe-cast.cpp    |  59 +++++++++
 .../Checkers/WebKit/memory-unsafe-cast.mm     |  12 ++
 3 files changed, 167 insertions(+), 24 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
index cc5c6554692a6b..d699da29615768 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/MemoryUnsafeCastChecker.cpp
@@ -23,6 +23,8 @@ using namespace ast_matchers;
 namespace {
 static constexpr const char *const BaseNode = "BaseNode";
 static constexpr const char *const DerivedNode = "DerivedNode";
+static constexpr const char *const FromCastNode = "FromCast";
+static constexpr const char *const ToCastNode = "ToCast";
 static constexpr const char *const WarnRecordDecl = "WarnRecordDecl";
 
 class MemoryUnsafeCastChecker : public Checker<check::ASTCodeBody> {
@@ -31,7 +33,17 @@ class MemoryUnsafeCastChecker : public Checker<check::ASTCodeBody> {
   void checkASTCodeBody(const Decl *D, AnalysisManager& Mgr,
                         BugReporter &BR) const;
 };
-}  // end namespace
+} // end namespace
+
+static void emitReport(AnalysisDeclContext *ADC, BugReporter &BR,
+                       const MemoryUnsafeCastChecker *Checker,
+                       std::string &Diagnostics, const CastExpr *CE) {
+  BR.EmitBasicReport(
+      ADC->getDecl(), Checker,
+      /*Name=*/"Unsafe type cast", categories::SecurityError, Diagnostics,
+      PathDiagnosticLocation::createBegin(CE, BR.getSourceManager(), ADC),
+      CE->getSourceRange());
+}
 
 static void emitDiagnostics(const BoundNodes &Nodes, BugReporter &BR,
                             AnalysisDeclContext *ADC,
@@ -45,15 +57,25 @@ static void emitDiagnostics(const BoundNodes &Nodes, BugReporter &BR,
   llvm::raw_string_ostream OS(Diagnostics);
   OS << "Unsafe cast from base type '" << Base->getNameAsString()
      << "' to derived type '" << Derived->getNameAsString() << "'";
+  emitReport(ADC, BR, Checker, Diagnostics, CE);
+}
 
-  BR.EmitBasicReport(
-      ADC->getDecl(), Checker,
-      /*Name=*/"OSObject C-Style Cast", categories::SecurityError,
-      Diagnostics,
-      PathDiagnosticLocation::createBegin(CE, BR.getSourceManager(), ADC),
-      CE->getSourceRange());
+static void emitDiagnosticsUnrelated(const BoundNodes &Nodes, BugReporter &BR,
+                                     AnalysisDeclContext *ADC,
+                                     const MemoryUnsafeCastChecker *Checker) {
+  const auto *CE = Nodes.getNodeAs<CastExpr>(WarnRecordDecl);
+  const NamedDecl *FromCast = Nodes.getNodeAs<NamedDecl>(FromCastNode);
+  const NamedDecl *ToCast = Nodes.getNodeAs<NamedDecl>(ToCastNode);
+  assert(CE && FromCast && ToCast);
+
+  std::string Diagnostics;
+  llvm::raw_string_ostream OS(Diagnostics);
+  OS << "Unsafe cast from type '" << FromCast->getNameAsString()
+     << "' to an unrelated type '" << ToCast->getNameAsString() << "'";
+  emitReport(ADC, BR, Checker, Diagnostics, CE);
 }
 
+
 namespace clang {
 namespace ast_matchers {
 AST_MATCHER_P(StringLiteral, mentionsBoundType, std::string, BindingID) {
@@ -78,6 +100,7 @@ void MemoryUnsafeCastChecker::checkASTCodeBody(const Decl *D,
 
   AnalysisDeclContext *ADC = AM.getAnalysisDeclContext(D);
 
+  // Match downcasts from base type to derived type and warn
   auto MatchExprPtr = allOf(
       hasSourceExpression(hasTypePointingTo(cxxRecordDecl().bind(BaseNode))),
       hasTypePointingTo(cxxRecordDecl(isDerivedFrom(equalsBoundNode(BaseNode)))
@@ -88,32 +111,26 @@ void MemoryUnsafeCastChecker::checkASTCodeBody(const Decl *D,
       ignoringImpCasts(hasType(objcObjectPointerType(pointee(hasDeclaration(
           objcInterfaceDecl(isDerivedFrom(equalsBoundNode(BaseNode)))
               .bind(DerivedNode)))))));
-  auto MatchExprRef =
-      allOf(hasSourceExpression(hasType(cxxRecordDecl().bind(BaseNode))),
-            hasType(cxxRecordDecl(isDerivedFrom(equalsBoundNode(BaseNode)))
-                        .bind(DerivedNode)));
   auto MatchExprRefTypeDef =
       allOf(hasSourceExpression(hasType(hasUnqualifiedDesugaredType(recordType(
-                 hasDeclaration(decl(cxxRecordDecl().bind(BaseNode))))))),
+                hasDeclaration(decl(cxxRecordDecl().bind(BaseNode))))))),
             hasType(hasUnqualifiedDesugaredType(recordType(hasDeclaration(
                 decl(cxxRecordDecl(isDerivedFrom(equalsBoundNode(BaseNode)))
                          .bind(DerivedNode)))))));
 
-  auto CastC = cStyleCastExpr(anyOf(MatchExprPtr, MatchExprRef,
-                                    MatchExprRefTypeDef, MatchExprPtrObjC))
-                   .bind(WarnRecordDecl);
-  auto CastStatic =
-      cxxStaticCastExpr(anyOf(MatchExprPtr, MatchExprRef, MatchExprRefTypeDef,
-                              MatchExprPtrObjC))
+  auto CastC =
+      cStyleCastExpr(anyOf(MatchExprPtr, MatchExprRefTypeDef, MatchExprPtrObjC))
           .bind(WarnRecordDecl);
+  auto CastStatic = cxxStaticCastExpr(anyOf(MatchExprPtr, MatchExprRefTypeDef,
+                                            MatchExprPtrObjC))
+                        .bind(WarnRecordDecl);
   auto CastReinterpret =
-      cxxReinterpretCastExpr(anyOf(MatchExprPtr, MatchExprRef,
-                                   MatchExprRefTypeDef, MatchExprPtrObjC))
-          .bind(WarnRecordDecl);
-  auto CastDynamic =
-      cxxDynamicCastExpr(anyOf(MatchExprPtr, MatchExprRef, MatchExprRefTypeDef,
-                               MatchExprPtrObjC))
+      cxxReinterpretCastExpr(
+          anyOf(MatchExprPtr, MatchExprRefTypeDef, MatchExprPtrObjC))
           .bind(WarnRecordDecl);
+  auto CastDynamic = cxxDynamicCastExpr(anyOf(MatchExprPtr, MatchExprRefTypeDef,
+                                              MatchExprPtrObjC))
+                         .bind(WarnRecordDecl);
 
   auto Cast = stmt(anyOf(CastC, CastStatic, CastReinterpret, CastDynamic));
 
@@ -121,6 +138,61 @@ void MemoryUnsafeCastChecker::checkASTCodeBody(const Decl *D,
       match(stmt(forEachDescendant(Cast)), *D->getBody(), AM.getASTContext());
   for (BoundNodes Match : Matches)
     emitDiagnostics(Match, BR, ADC, this);
+
+  // Match casts between unrelated types and warn
+  auto MatchExprPtrUnrelatedTypes = allOf(
+      hasSourceExpression(
+          hasTypePointingTo(cxxRecordDecl().bind(FromCastNode))),
+      hasTypePointingTo(cxxRecordDecl().bind(ToCastNode)),
+      unless(anyOf(hasTypePointingTo(cxxRecordDecl(
+                       isSameOrDerivedFrom(equalsBoundNode(FromCastNode)))),
+                   hasSourceExpression(hasTypePointingTo(cxxRecordDecl(
+                       isSameOrDerivedFrom(equalsBoundNode(ToCastNode))))))));
+  auto MatchExprPtrObjCUnrelatedTypes = allOf(
+      hasSourceExpression(ignoringImpCasts(hasType(objcObjectPointerType(
+          pointee(hasDeclaration(objcInterfaceDecl().bind(FromCastNode))))))),
+      ignoringImpCasts(hasType(objcObjectPointerType(
+          pointee(hasDeclaration(objcInterfaceDecl().bind(ToCastNode)))))),
+      unless(anyOf(
+          ignoringImpCasts(hasType(
+              objcObjectPointerType(pointee(hasDeclaration(objcInterfaceDecl(
+                  isSameOrDerivedFrom(equalsBoundNode(FromCastNode)))))))),
+          hasSourceExpression(ignoringImpCasts(hasType(
+              objcObjectPointerType(pointee(hasDeclaration(objcInterfaceDecl(
+                  isSameOrDerivedFrom(equalsBoundNode(ToCastNode))))))))))));
+  auto MatchExprRefTypeDefUnrelated = allOf(
+      hasSourceExpression(hasType(hasUnqualifiedDesugaredType(recordType(
+          hasDeclaration(decl(cxxRecordDecl().bind(FromCastNode))))))),
+      hasType(hasUnqualifiedDesugaredType(
+          recordType(hasDeclaration(decl(cxxRecordDecl().bind(ToCastNode)))))),
+      unless(anyOf(
+          hasType(hasUnqualifiedDesugaredType(
+              recordType(hasDeclaration(decl(cxxRecordDecl(
+                  isSameOrDerivedFrom(equalsBoundNode(FromCastNode)))))))),
+          hasSourceExpression(hasType(hasUnqualifiedDesugaredType(
+              recordType(hasDeclaration(decl(cxxRecordDecl(
+                  isSameOrDerivedFrom(equalsBoundNode(ToCastNode))))))))))));
+
+  auto CastCUnrelated = cStyleCastExpr(anyOf(MatchExprPtrUnrelatedTypes,
+                                             MatchExprPtrObjCUnrelatedTypes,
+                                             MatchExprRefTypeDefUnrelated))
+                            .bind(WarnRecordDecl);
+  auto CastReinterpretUnrelated =
+      cxxReinterpretCastExpr(anyOf(MatchExprPtrUnrelatedTypes,
+                                   MatchExprPtrObjCUnrelatedTypes,
+                                   MatchExprRefTypeDefUnrelated))
+          .bind(WarnRecordDecl);
+  auto CastDynamicUnrelated =
+      cxxDynamicCastExpr(anyOf(MatchExprPtrUnrelatedTypes,
+                               MatchExprPtrObjCUnrelatedTypes,
+                               MatchExprRefTypeDefUnrelated))
+          .bind(WarnRecordDecl);
+  auto CastUnrelated = stmt(
+      anyOf(CastCUnrelated, CastReinterpretUnrelated, CastDynamicUnrelated));
+  auto MatchesUnrelatedTypes = match(stmt(forEachDescendant(CastUnrelated)),
+                                     *D->getBody(), AM.getASTContext());
+  for (BoundNodes Match : MatchesUnrelatedTypes)
+    emitDiagnosticsUnrelated(Match, BR, ADC, this);
 }
 
 void ento::registerMemoryUnsafeCastChecker(CheckerManager &Mgr) {
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
index c2da099851ced7..fc3ef7e69921ea 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.cpp
@@ -25,6 +25,10 @@ void test_pointers(Base *base) {
   Derived *derived_d = downcast_ptr<Derived, Base>(base);  // no warning
 }
 
+void test_non_pointers(Derived derived) {
+  Base base_static = static_cast<Base>(derived);  // no warning
+}
+
 void test_refs(Base &base) {
   Derived &derived_static = static_cast<Derived&>(base);
   // expected-warning@-1{{Unsafe cast from base type 'Base' to derived type 'Derived'}}
@@ -173,4 +177,59 @@ struct stack_st { };
 
 void test_stack(stack_st *base) {
   STACK_OF(void) *derived = (STACK_OF(void)*)base;
+  // expected-warning@-1{{Unsafe cast from type 'stack_st' to an unrelated type 'stack_st_void'}}
+}
+
+class Parent { };
+class Child1 : public Parent { };
+class Child2 : public Parent { };
+
+void test_common_parent(Child1 *c1, Child2 *c2) {
+  Child2 *c2_cstyle = (Child2 *)c1;
+  // expected-warning@-1{{Unsafe cast from type 'Child1' to an unrelated type 'Child2'}}
+  Child2 *c2_reinterpret = reinterpret_cast<Child2 *>(c1);
+  // expected-warning@-1{{Unsafe cast from type 'Child1' to an unrelated type 'Child2'}}
+}
+
+class Type1 { };
+class Type2 { };
+
+void test_unrelated_ref(Type1 &t1, Type2 &t2) {
+  Type2 &t2_cstyle = (Type2 &)t1;
+  // expected-warning@-1{{Unsafe cast from type 'Type1' to an unrelated type 'Type2'}}
+  Type2 &t2_reinterpret = reinterpret_cast<Type2 &>(t1);
+  // expected-warning@-1{{Unsafe cast from type 'Type1' to an unrelated type 'Type2'}}
+  Type2 &t2_same = reinterpret_cast<Type2 &>(t2); // no warning
+}
+
+
+class VirtualClass1 {
+  virtual void virtual_base_function();
+};
+
+class VirtualClass2 {
+  void virtual_base_function();
+};
+
+void test_unrelated_virtual(VirtualClass1 &v1) {
+  VirtualClass2 &v2 = dynamic_cast<VirtualClass2 &>(v1);
+  // expected-warning@-1{{Unsafe cast from type 'VirtualClass1' to an unrelated type 'VirtualClass2'}}
+}
+
+struct StructA { };
+struct StructB { };
+
+typedef struct StructA StA;
+typedef struct StructB StB;
+
+void test_struct_unrelated_refs(StA &a, StB &b) {
+  StB &b_reinterpret = reinterpret_cast<StB&>(a);
+  // expected-warning@-1{{Unsafe cast from type 'StructA' to an unrelated type 'StructB'}}
+  StB &b_c = (StB&)a;
+  // expected-warning@-1{{Unsafe cast from type 'StructA' to an unrelated type 'StructB'}}
+  StA &a_local = (StA&)b;
+  // expected-warning@-1{{Unsafe cast from type 'StructB' to an unrelated type 'StructA'}}
+  StA &a_reinterpret = reinterpret_cast<StA&>(b);
+  // expected-warning@-1{{Unsafe cast from type 'StructB' to an unrelated type 'StructA'}}
+  StA &a_same = (StA&)a; // no warning
 }
diff --git a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
index db8ca7da0b52c1..f9046d79817849 100644
--- a/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
+++ b/clang/test/Analysis/Checkers/WebKit/memory-unsafe-cast.mm
@@ -50,3 +50,15 @@ @protocol MTLRenderCommandEncoder
         return static_cast<id<MTLRenderCommandEncoder>>(CommandEncoder::get());
     }
 };
+
+@interface Class1
+@end
+
+@interface Class2
+@end
+
+void testUnrelated(Class1 *c1) {
+  Class2 *c2 = (Class2*)c1;
+  // expected-warning@-1{{Unsafe cast from type 'Class1' to an unrelated type 'Class2'}}
+  Class1 *c1_same = reinterpret_cast<Class1*>(c1); // no warning
+}



More information about the cfe-commits mailing list