[clang] 8c5e9cf - [clang][Interp] Implement nullability argument checking

Timm Bäder via cfe-commits cfe-commits at lists.llvm.org
Sun Feb 25 21:27:20 PST 2024


Author: Timm Bäder
Date: 2024-02-26T06:19:25+01:00
New Revision: 8c5e9cf737138aba22a4a8f64ef2c5efc80dd7f9

URL: https://github.com/llvm/llvm-project/commit/8c5e9cf737138aba22a4a8f64ef2c5efc80dd7f9
DIFF: https://github.com/llvm/llvm-project/commit/8c5e9cf737138aba22a4a8f64ef2c5efc80dd7f9.diff

LOG: [clang][Interp] Implement nullability argument checking

Implement constexpr checking for null pointers being passed to
arguments annotated as nonnull.

Added: 
    clang/lib/AST/Interp/InterpShared.cpp
    clang/lib/AST/Interp/InterpShared.h
    clang/test/AST/Interp/nullable.cpp

Modified: 
    clang/lib/AST/CMakeLists.txt
    clang/lib/AST/Interp/ByteCodeExprGen.cpp
    clang/lib/AST/Interp/Function.h
    clang/lib/AST/Interp/Interp.cpp
    clang/lib/AST/Interp/Interp.h
    clang/lib/AST/Interp/Opcodes.td
    clang/test/Sema/attr-nonnull.c
    clang/test/SemaCXX/attr-nonnull.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt
index d793c3ed0410a8..6ea1ca3e76cf33 100644
--- a/clang/lib/AST/CMakeLists.txt
+++ b/clang/lib/AST/CMakeLists.txt
@@ -88,6 +88,7 @@ add_clang_library(clangAST
   Interp/Record.cpp
   Interp/Source.cpp
   Interp/State.cpp
+  Interp/InterpShared.cpp
   ItaniumCXXABI.cpp
   ItaniumMangle.cpp
   JSONNodeDumper.cpp

diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index 7f97d8ce9fb804..eb5a1b536b7798 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -13,8 +13,10 @@
 #include "Context.h"
 #include "Floating.h"
 #include "Function.h"
+#include "InterpShared.h"
 #include "PrimType.h"
 #include "Program.h"
+#include "clang/AST/Attr.h"
 
 using namespace clang;
 using namespace clang::interp;
@@ -2656,6 +2658,7 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
   QualType ReturnType = E->getCallReturnType(Ctx.getASTContext());
   std::optional<PrimType> T = classify(ReturnType);
   bool HasRVO = !ReturnType->isVoidType() && !T;
+  const FunctionDecl *FuncDecl = E->getDirectCallee();
 
   if (HasRVO) {
     if (DiscardResult) {
@@ -2673,17 +2676,16 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
     }
   }
 
-  auto Args = E->arguments();
+  auto Args = llvm::ArrayRef(E->getArgs(), E->getNumArgs());
   // Calling a static operator will still
   // pass the instance, but we don't need it.
   // Discard it here.
   if (isa<CXXOperatorCallExpr>(E)) {
-    if (const auto *MD =
-            dyn_cast_if_present<CXXMethodDecl>(E->getDirectCallee());
+    if (const auto *MD = dyn_cast_if_present<CXXMethodDecl>(FuncDecl);
         MD && MD->isStatic()) {
       if (!this->discard(E->getArg(0)))
         return false;
-      Args = drop_begin(Args, 1);
+      Args = Args.drop_front();
     }
   }
 
@@ -2693,13 +2695,25 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
       return false;
   }
 
+  llvm::BitVector NonNullArgs = collectNonNullArgs(FuncDecl, Args);
   // Put arguments on the stack.
+  unsigned ArgIndex = 0;
   for (const auto *Arg : Args) {
     if (!this->visit(Arg))
       return false;
+
+    // If we know the callee already, check its parameters for nullability.
+    if (FuncDecl && NonNullArgs[ArgIndex]) {
+      PrimType ArgT = classify(Arg).value_or(PT_Ptr);
+      if (ArgT == PT_Ptr || ArgT == PT_FnPtr) {
+        if (!this->emitCheckNonNullArg(ArgT, Arg))
+          return false;
+      }
+    }
+    ++ArgIndex;
   }
 
-  if (const FunctionDecl *FuncDecl = E->getDirectCallee()) {
+  if (FuncDecl) {
     const Function *Func = getFunction(FuncDecl);
     if (!Func)
       return false;
@@ -2748,7 +2762,7 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
     if (!this->visit(E->getCallee()))
       return false;
 
-    if (!this->emitCallPtr(ArgSize, E))
+    if (!this->emitCallPtr(ArgSize, E, E))
       return false;
   }
 

diff  --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h
index b19d64f9371e3c..0be4564e1e9ec4 100644
--- a/clang/lib/AST/Interp/Function.h
+++ b/clang/lib/AST/Interp/Function.h
@@ -15,9 +15,10 @@
 #ifndef LLVM_CLANG_AST_INTERP_FUNCTION_H
 #define LLVM_CLANG_AST_INTERP_FUNCTION_H
 
-#include "Source.h"
 #include "Descriptor.h"
+#include "Source.h"
 #include "clang/AST/ASTLambda.h"
+#include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -108,6 +109,8 @@ class Function final {
   /// Checks if the first argument is a RVO pointer.
   bool hasRVO() const { return HasRVO; }
 
+  bool hasNonNullAttr() const { return getDecl()->hasAttr<NonNullAttr>(); } // True if the declaration has a NonNullAttr.
+
   /// Range over the scope blocks.
   llvm::iterator_range<llvm::SmallVector<Scope, 2>::const_iterator>
   scopes() const {

diff  --git a/clang/lib/AST/Interp/Interp.cpp b/clang/lib/AST/Interp/Interp.cpp
index b2fe70dc14f9d5..5670888c245eb1 100644
--- a/clang/lib/AST/Interp/Interp.cpp
+++ b/clang/lib/AST/Interp/Interp.cpp
@@ -7,10 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "Interp.h"
-#include <limits>
-#include <vector>
 #include "Function.h"
 #include "InterpFrame.h"
+#include "InterpShared.h"
 #include "InterpStack.h"
 #include "Opcode.h"
 #include "PrimType.h"
@@ -22,6 +21,10 @@
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "llvm/ADT/APSInt.h"
+#include <limits>
+#include <vector>
+
+using namespace clang;
 
 using namespace clang;
 using namespace clang::interp;
@@ -622,6 +625,28 @@ bool CheckDeclRef(InterpState &S, CodePtr OpPC, const DeclRefExpr *DR) {
   return false;
 }
 
+bool CheckNonNullArgs(InterpState &S, CodePtr OpPC, const Function *F,
+                      const CallExpr *CE, unsigned ArgSize) { // Returns false at the first null nonnull argument.
+  auto Args = llvm::ArrayRef(CE->getArgs(), CE->getNumArgs());
+  auto NonNullArgs = collectNonNullArgs(F->getDecl(), Args); // One bit per call argument.
+  unsigned Offset = 0; // Sum of the aligned sizes of the arguments walked so far.
+  unsigned Index = 0; // Position of Arg within Args and NonNullArgs.
+  for (const Expr *Arg : Args) {
+    if (NonNullArgs[Index] && Arg->getType()->isPointerType()) { // NOTE(review): confirm PT_FnPtr args never get here.
+      const Pointer &ArgPtr = S.Stk.peek<Pointer>(ArgSize - Offset); // Addressed relative to the full argument size.
+      if (ArgPtr.isZero()) {
+        const SourceLocation &Loc = S.Current->getLocation(OpPC);
+        S.CCEDiag(Loc, diag::note_non_null_attribute_failed);
+        return false;
+      }
+    }
+
+    Offset += align(primSize(S.Ctx.classify(Arg).value_or(PT_Ptr))); // Unclassified arguments are sized as PT_Ptr.
+    ++Index;
+  }
+  return true; // No nonnull pointer argument was null.
+}
+
 bool Interpret(InterpState &S, APValue &Result) {
   // The current stack frame when we started Interpret().
   // This is being used by the ops to determine wheter

diff  --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index d885d19ce7064f..7994550cc7b97e 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -113,6 +113,10 @@ bool CheckThis(InterpState &S, CodePtr OpPC, const Pointer &This);
 /// Checks if a method is pure virtual.
 bool CheckPure(InterpState &S, CodePtr OpPC, const CXXMethodDecl *MD);
 
+/// Checks if all the arguments annotated as 'nonnull' are in fact not null.
+bool CheckNonNullArgs(InterpState &S, CodePtr OpPC, const Function *F,
+                      const CallExpr *CE, unsigned ArgSize);
+
 /// Sets the given integral value to the pointer, which is of
 /// a std::{weak,partial,strong}_ordering type.
 bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
@@ -1980,6 +1984,7 @@ inline bool CallVar(InterpState &S, CodePtr OpPC, const Function *Func,
 
   return false;
 }
+
 inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func,
                  uint32_t VarArgSize) {
   if (Func->hasThisPointer()) {
@@ -2083,7 +2088,8 @@ inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func,
   return false;
 }
 
-inline bool CallPtr(InterpState &S, CodePtr OpPC, uint32_t ArgSize) {
+inline bool CallPtr(InterpState &S, CodePtr OpPC, uint32_t ArgSize,
+                    const CallExpr *CE) {
   const FunctionPointer &FuncPtr = S.Stk.pop<FunctionPointer>();
 
   const Function *F = FuncPtr.getFunction();
@@ -2095,6 +2101,12 @@ inline bool CallPtr(InterpState &S, CodePtr OpPC, uint32_t ArgSize) {
   }
   assert(F);
 
+  // Check argument nullability state.
+  if (F->hasNonNullAttr()) {
+    if (!CheckNonNullArgs(S, OpPC, F, CE, ArgSize))
+      return false;
+  }
+
   assert(ArgSize >= F->getWrittenArgSize());
   uint32_t VarArgSize = ArgSize - F->getWrittenArgSize();
 
@@ -2151,6 +2163,18 @@ inline bool OffsetOf(InterpState &S, CodePtr OpPC, const OffsetOfExpr *E) {
   return true;
 }
 
+template <PrimType Name, class T = typename PrimConv<Name>::T>
+inline bool CheckNonNullArg(InterpState &S, CodePtr OpPC) { // Fails with a CCE note if the peeked stack value is zero.
+  const T &Arg = S.Stk.peek<T>(); // VisitCallExpr emits this op right after visiting the argument.
+  if (!Arg.isZero())
+    return true; // Non-null: nothing to diagnose.
+
+  const SourceLocation &Loc = S.Current->getLocation(OpPC);
+  S.CCEDiag(Loc, diag::note_non_null_attribute_failed);
+
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Read opcode arguments
 //===----------------------------------------------------------------------===//

diff  --git a/clang/lib/AST/Interp/InterpShared.cpp b/clang/lib/AST/Interp/InterpShared.cpp
new file mode 100644
index 00000000000000..6af03691f1b20b
--- /dev/null
+++ b/clang/lib/AST/Interp/InterpShared.cpp
@@ -0,0 +1,42 @@
+//===--- InterpShared.cpp ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "InterpShared.h"
+#include "clang/AST/Attr.h"
+#include "llvm/ADT/BitVector.h"
+
+namespace clang {
+namespace interp {
+
+llvm::BitVector collectNonNullArgs(const FunctionDecl *F,
+                                   const llvm::ArrayRef<const Expr *> &Args) { // Set bits mark the nonnull entries of Args.
+  llvm::BitVector NonNullArgs;
+  if (!F)
+    return NonNullArgs; // No declaration: the result is never resized, so it has no bits.
+
+  assert(F); // Always holds here; F was null-checked just above.
+  NonNullArgs.resize(Args.size());
+
+  for (const auto *Attr : F->specific_attrs<NonNullAttr>()) {
+    if (!Attr->args_size()) { // Attribute without an index list: every argument is nonnull.
+      NonNullArgs.set();
+      break; // Every bit is set already; further attributes cannot add any.
+    } else
+      for (auto Idx : Attr->args()) {
+        unsigned ASTIdx = Idx.getASTIndex();
+        if (ASTIdx >= Args.size()) // Index has no matching call argument; ignore it.
+          continue;
+        NonNullArgs[ASTIdx] = true;
+      }
+  }
+
+  return NonNullArgs;
+}
+
+} // namespace interp
+} // namespace clang

diff  --git a/clang/lib/AST/Interp/InterpShared.h b/clang/lib/AST/Interp/InterpShared.h
new file mode 100644
index 00000000000000..8c5e0bee22c92a
--- /dev/null
+++ b/clang/lib/AST/Interp/InterpShared.h
@@ -0,0 +1,26 @@
+//===--- InterpShared.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_AST_INTERP_SHARED_H
+#define LLVM_CLANG_LIB_AST_INTERP_SHARED_H
+
+#include "llvm/ADT/BitVector.h"
+
+namespace clang {
+class FunctionDecl;
+class Expr;
+
+namespace interp {
+
+llvm::BitVector collectNonNullArgs(const FunctionDecl *F,
+                                   const llvm::ArrayRef<const Expr *> &Args);
+
+} // namespace interp
+} // namespace clang
+
+#endif

diff  --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index 5add723842d2b2..e36c42d450fc9f 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -206,7 +206,7 @@ def CallBI : Opcode {
 }
 
 def CallPtr : Opcode {
-  let Args = [ArgUint32];
+  let Args = [ArgUint32, ArgCallExpr];
   let Types = [];
 }
 
@@ -706,3 +706,8 @@ def InvalidDeclRef : Opcode {
 }
 
 def ArrayDecay : Opcode;
+
+def CheckNonNullArg : Opcode { // Emitted by VisitCallExpr for nonnull-annotated PT_Ptr/PT_FnPtr arguments.
+  let Types = [PtrTypeClass];
+  let HasGroup = 1;
+}

diff  --git a/clang/test/AST/Interp/nullable.cpp b/clang/test/AST/Interp/nullable.cpp
new file mode 100644
index 00000000000000..3bc2595fb8f006
--- /dev/null
+++ b/clang/test/AST/Interp/nullable.cpp
@@ -0,0 +1,77 @@
+// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify=expected,both %s
+// RUN: %clang_cc1 -verify=ref,both %s
+
+
+constexpr int dummy = 1;
+constexpr const int *null = nullptr;
+
+namespace simple {
+  __attribute__((nonnull))
+  constexpr int simple1(const int*) {
+    return 1;
+  }
+  static_assert(simple1(&dummy) == 1, "");
+  static_assert(simple1(nullptr) == 1, ""); // both-error {{not an integral constant expression}} \
+                                            // both-note {{null passed to a callee}}
+  static_assert(simple1(null) == 1, ""); // both-error {{not an integral constant expression}} \
+                                         // both-note {{null passed to a callee}}
+
+  __attribute__((nonnull)) // both-warning {{applied to function with no pointer arguments}}
+  constexpr int simple2(const int &a) {
+    return 12;
+  }
+  static_assert(simple2(1) == 12, "");
+}
+
+namespace methods {
+  struct S {
+    __attribute__((nonnull(2))) // both-warning {{only applies to pointer arguments}}
+    __attribute__((nonnull(3)))
+    constexpr int foo(int a, const void *p) const {
+      return 12;
+    }
+
+    __attribute__((nonnull(3)))
+    constexpr int foo2(...) const {
+      return 12;
+    }
+
+    __attribute__((nonnull))
+    constexpr int foo3(...) const {
+      return 12;
+    }
+  };
+
+  constexpr S s{};
+  static_assert(s.foo(8, &dummy) == 12, "");
+
+  static_assert(s.foo2(nullptr) == 12, "");
+  static_assert(s.foo2(1, nullptr) == 12, ""); // both-error {{not an integral constant expression}} \
+                                               // both-note {{null passed to a callee}}
+
+  constexpr S *s2 = nullptr;
+  static_assert(s2->foo3() == 12, ""); // both-error {{not an integral constant expression}} \
+                                       // both-note {{member call on dereferenced null pointer}}
+}
+
+namespace fnptrs {
+  __attribute__((nonnull))
+  constexpr int add(int a, const void *p) {
+    return a + 1;
+  }
+  __attribute__((nonnull(3)))
+  constexpr int applyBinOp(int a, int b, int (*op)(int, const void *)) {
+    return op(a, nullptr); // both-note {{null passed to a callee}}
+  }
+  static_assert(applyBinOp(10, 20, add) == 11, ""); // both-error {{not an integral constant expression}} \
+                                                    // both-note {{in call to}}
+
+  static_assert(applyBinOp(10, 20, nullptr) == 11, ""); // both-error {{not an integral constant expression}} \
+                                                        // both-note {{null passed to a callee}}
+}
+
+namespace lambdas {
+  auto lstatic = [](const void *P) __attribute__((nonnull)) { return 3; };
+  static_assert(lstatic(nullptr) == 3, ""); // both-error {{not an integral constant expression}} \
+                                            // both-note {{null passed to a callee}}
+}

diff  --git a/clang/test/Sema/attr-nonnull.c b/clang/test/Sema/attr-nonnull.c
index f8de31716a80c7..865348daef10e0 100644
--- a/clang/test/Sema/attr-nonnull.c
+++ b/clang/test/Sema/attr-nonnull.c
@@ -1,4 +1,5 @@
 // RUN: %clang_cc1 %s -verify -fsyntax-only
+// RUN: %clang_cc1 %s -verify -fsyntax-only -fexperimental-new-constant-interpreter
 
 void f1(int *a1, int *a2, int *a3, int *a4, int *a5, int *a6, int *a7,
         int *a8, int *a9, int *a10, int *a11, int *a12, int *a13, int *a14,

diff  --git a/clang/test/SemaCXX/attr-nonnull.cpp b/clang/test/SemaCXX/attr-nonnull.cpp
index 21eedcf376d5b6..6f9119b519d093 100644
--- a/clang/test/SemaCXX/attr-nonnull.cpp
+++ b/clang/test/SemaCXX/attr-nonnull.cpp
@@ -1,4 +1,5 @@
 // RUN: %clang_cc1 -fsyntax-only -verify %s
+// RUN: %clang_cc1 -fsyntax-only -verify %s -fexperimental-new-constant-interpreter
 struct S {
   S(const char *) __attribute__((nonnull(2)));
 
@@ -84,4 +85,4 @@ constexpr int i4 = f4(&c, 0, 0); //expected-error {{constant expression}} expect
 constexpr int i42 = f4(0, &c, 1); //expected-error {{constant expression}} expected-note {{null passed}}
 constexpr int i43 = f4(&c, &c, 0);
 
-}
\ No newline at end of file
+}


        


More information about the cfe-commits mailing list