[clang] [clang][Interp] Allow adding offsets to function pointers (PR #105641)

Timm Baeder via cfe-commits cfe-commits at lists.llvm.org
Thu Aug 22 04:05:18 PDT 2024


https://github.com/tbaederr created https://github.com/llvm/llvm-project/pull/105641

Convert function pointers to Pointers, do the offset calculation on those, and then convert the result back to a function pointer.

>From af18bbdf3440b3d13077dbcfd22a04fcf0b1564b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timm=20B=C3=A4der?= <tbaeder at redhat.com>
Date: Wed, 21 Aug 2024 09:49:54 +0200
Subject: [PATCH] [clang][Interp] Allow adding offsets to function pointers

Convert them to Pointers, do the offset calculation and then
convert them back to function pointers.
---
 clang/lib/AST/ByteCode/Compiler.cpp        | 40 ++++++++++++++++----
 clang/lib/AST/ByteCode/FunctionPointer.cpp | 44 ++++++++++++++++++++++
 clang/lib/AST/ByteCode/FunctionPointer.h   | 41 +++++---------------
 clang/lib/AST/ByteCode/Interp.h            | 37 +++++++++++++++---
 clang/lib/AST/ByteCode/Pointer.h           |  6 ++-
 clang/lib/AST/CMakeLists.txt               |  1 +
 clang/test/AST/ByteCode/c.c                | 16 ++++++++
 7 files changed, 140 insertions(+), 45 deletions(-)
 create mode 100644 clang/lib/AST/ByteCode/FunctionPointer.cpp

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 10f3222726fd43..a91a3447230049 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -885,12 +885,21 @@ bool Compiler<Emitter>::VisitPointerArithBinOp(const BinaryOperator *E) {
   if (!LT || !RT)
     return false;
 
+  // Visit the given pointer expression and optionally convert to a PT_Ptr.
+  auto visitAsPointer = [&](const Expr *E, PrimType T) -> bool {
+    if (!this->visit(E))
+      return false;
+    if (T != PT_Ptr)
+      return this->emitDecayPtr(T, PT_Ptr, E);
+    return true;
+  };
+
   if (LHS->getType()->isPointerType() && RHS->getType()->isPointerType()) {
     if (Op != BO_Sub)
       return false;
 
     assert(E->getType()->isIntegerType());
-    if (!visit(RHS) || !visit(LHS))
+    if (!visitAsPointer(RHS, *RT) || !visitAsPointer(LHS, *LT))
       return false;
 
     return this->emitSubPtr(classifyPrim(E->getType()), E);
@@ -898,21 +907,38 @@ bool Compiler<Emitter>::VisitPointerArithBinOp(const BinaryOperator *E) {
 
   PrimType OffsetType;
   if (LHS->getType()->isIntegerType()) {
-    if (!visit(RHS) || !visit(LHS))
+    if (!visitAsPointer(RHS, *RT))
+      return false;
+    if (!this->visit(LHS))
       return false;
     OffsetType = *LT;
   } else if (RHS->getType()->isIntegerType()) {
-    if (!visit(LHS) || !visit(RHS))
+    if (!visitAsPointer(LHS, *LT))
+      return false;
+    if (!this->visit(RHS))
       return false;
     OffsetType = *RT;
   } else {
     return false;
   }
 
-  if (Op == BO_Add)
-    return this->emitAddOffset(OffsetType, E);
-  else if (Op == BO_Sub)
-    return this->emitSubOffset(OffsetType, E);
+  // Do the offset arithmetic on PT_Ptr, then convert the result back to
+  // the expression's own pointer type (e.g. a function pointer) if needed.
+  if (Op == BO_Add) {
+    if (!this->emitAddOffset(OffsetType, E))
+      return false;
+
+    if (classifyPrim(E) != PT_Ptr)
+      return this->emitDecayPtr(PT_Ptr, classifyPrim(E), E);
+    return true;
+  } else if (Op == BO_Sub) {
+    if (!this->emitSubOffset(OffsetType, E))
+      return false;
+
+    if (classifyPrim(E) != PT_Ptr)
+      return this->emitDecayPtr(PT_Ptr, classifyPrim(E), E);
+    return true;
+  }
 
   return false;
 }
diff --git a/clang/lib/AST/ByteCode/FunctionPointer.cpp b/clang/lib/AST/ByteCode/FunctionPointer.cpp
new file mode 100644
index 00000000000000..4d6ca2e0f8ae2e
--- /dev/null
+++ b/clang/lib/AST/ByteCode/FunctionPointer.cpp
@@ -0,0 +1,44 @@
+//===----------------------- FunctionPointer.cpp ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FunctionPointer.h"
+
+namespace clang {
+namespace interp {
+
+APValue FunctionPointer::toAPValue(const ASTContext &) const {
+  // Null pointers and pointers created from integers get a null LValue base.
+  if (!Func)
+    return APValue(static_cast<Expr *>(nullptr), CharUnits::Zero(), {},
+                   /*OnePastTheEnd=*/false, /*IsNull=*/true);
+
+  if (!Valid)
+    return APValue(static_cast<Expr *>(nullptr),
+                   CharUnits::fromQuantity(getIntegerRepresentation()), {},
+                   /*OnePastTheEnd=*/false, /*IsNull=*/false);
+
+  if (Func->getDecl())
+    return APValue(Func->getDecl(), CharUnits::fromQuantity(Offset), {},
+                   /*OnePastTheEnd=*/false, /*IsNull=*/false);
+  return APValue(Func->getExpr(), CharUnits::fromQuantity(Offset), {},
+                 /*OnePastTheEnd=*/false, /*IsNull=*/false);
+}
+
+void FunctionPointer::print(llvm::raw_ostream &OS) const {
+  OS << "FnPtr(";
+  if (Func && Valid)
+    OS << Func->getName();
+  else if (Func)
+    OS << reinterpret_cast<uintptr_t>(Func);
+  else
+    OS << "nullptr";
+  OS << ") + " << Offset;
+}
+
+} // namespace interp
+} // namespace clang
diff --git a/clang/lib/AST/ByteCode/FunctionPointer.h b/clang/lib/AST/ByteCode/FunctionPointer.h
index c9bdfbee55441a..e2b45b2344fdce 100644
--- a/clang/lib/AST/ByteCode/FunctionPointer.h
+++ b/clang/lib/AST/ByteCode/FunctionPointer.h
@@ -11,25 +11,29 @@
 
 #include "Function.h"
 #include "Primitives.h"
-#include "clang/AST/APValue.h"
 
 namespace clang {
 class ASTContext;
+class APValue;
 namespace interp {
 
 class FunctionPointer final {
 private:
   const Function *Func;
+  uint64_t Offset; // Offset added to Func by pointer arithmetic.
   bool Valid;
 
 public:
   FunctionPointer() = default;
-  FunctionPointer(const Function *Func) : Func(Func), Valid(true) {}
+  FunctionPointer(const Function *Func, uint64_t Offset = 0)
+      : Func(Func), Offset(Offset), Valid(true) {}
 
   FunctionPointer(uintptr_t IntVal, const Descriptor *Desc = nullptr)
-      : Func(reinterpret_cast<const Function *>(IntVal)), Valid(false) {}
+      : Func(reinterpret_cast<const Function *>(IntVal)), Offset(0),
+        Valid(false) {}
 
   const Function *getFunction() const { return Func; }
+  uint64_t getOffset() const { return Offset; }
   bool isZero() const { return !Func; }
   bool isValid() const { return Valid; }
   bool isWeak() const {
@@ -39,33 +43,8 @@ class FunctionPointer final {
     return Func->getDecl()->isWeak();
   }
 
-  APValue toAPValue(const ASTContext &) const {
-    if (!Func)
-      return APValue(static_cast<Expr *>(nullptr), CharUnits::Zero(), {},
-                     /*OnePastTheEnd=*/false, /*IsNull=*/true);
-
-    if (!Valid)
-      return APValue(static_cast<Expr *>(nullptr),
-                     CharUnits::fromQuantity(getIntegerRepresentation()), {},
-                     /*OnePastTheEnd=*/false, /*IsNull=*/false);
-
-    if (Func->getDecl())
-      return APValue(Func->getDecl(), CharUnits::Zero(), {},
-                     /*OnePastTheEnd=*/false, /*IsNull=*/false);
-    return APValue(Func->getExpr(), CharUnits::Zero(), {},
-                   /*OnePastTheEnd=*/false, /*IsNull=*/false);
-  }
-
-  void print(llvm::raw_ostream &OS) const {
-    OS << "FnPtr(";
-    if (Func && Valid)
-      OS << Func->getName();
-    else if (Func)
-      OS << reinterpret_cast<uintptr_t>(Func);
-    else
-      OS << "nullptr";
-    OS << ")";
-  }
+  APValue toAPValue(const ASTContext &) const;
+  void print(llvm::raw_ostream &OS) const;
 
   std::string toDiagnosticString(const ASTContext &Ctx) const {
     if (!Func)
@@ -79,7 +58,7 @@ class FunctionPointer final {
   }
 
   ComparisonCategoryResult compare(const FunctionPointer &RHS) const {
-    if (Func == RHS.Func)
+    if (Func == RHS.Func && Offset == RHS.Offset)
       return ComparisonCategoryResult::Equal;
     return ComparisonCategoryResult::Unordered;
   }
diff --git a/clang/lib/AST/ByteCode/Interp.h b/clang/lib/AST/ByteCode/Interp.h
index d8629881abc685..fd4406c0db2b88 100644
--- a/clang/lib/AST/ByteCode/Interp.h
+++ b/clang/lib/AST/ByteCode/Interp.h
@@ -1857,8 +1857,23 @@ bool OffsetHelper(InterpState &S, CodePtr OpPC, const T &Offset,
     else
       S.Stk.push<Pointer>(V - O, Ptr.asIntPointer().Desc);
     return true;
+  } else if (Ptr.isFunctionPointer()) {
+    uint64_t O = static_cast<uint64_t>(Offset);
+    uint64_t N;
+    if constexpr (Op == ArithOp::Add)
+      N = Ptr.getByteOffset() + O;
+    else
+      N = Ptr.getByteOffset() - O;
+
+    if (N > 1) // Treat the function as a single, non-array object.
+      S.CCEDiag(S.Current->getSource(OpPC), diag::note_constexpr_array_index)
+          << N << /*non-array*/ true << 0;
+    S.Stk.push<Pointer>(Ptr.asFunctionPointer().getFunction(), N);
+    return true;
   }
 
+  assert(Ptr.isBlockPointer()); // Integral/function pointers returned above.
+
   uint64_t MaxIndex = static_cast<uint64_t>(Ptr.getNumElems());
   uint64_t Index;
   if (Ptr.isOnePastEnd())
@@ -2024,10 +2039,15 @@ inline bool SubPtr(InterpState &S, CodePtr OpPC) {
     return true;
   }
 
-  T A = LHS.isElementPastEnd() ? T::from(LHS.getNumElems())
-                               : T::from(LHS.getIndex());
-  T B = RHS.isElementPastEnd() ? T::from(RHS.getNumElems())
-                               : T::from(RHS.getIndex());
+  T A = LHS.isBlockPointer()
+            ? (LHS.isElementPastEnd() ? T::from(LHS.getNumElems())
+                                      : T::from(LHS.getIndex()))
+            : T::from(LHS.getIntegerRepresentation());
+  T B = RHS.isBlockPointer()
+            ? (RHS.isElementPastEnd() ? T::from(RHS.getNumElems())
+                                      : T::from(RHS.getIndex()))
+            : T::from(RHS.getIntegerRepresentation());
+
   return AddSubMulHelper<T, T::sub, std::minus>(S, OpPC, A.bitWidth(), A, B);
 }
 
@@ -2905,8 +2925,15 @@ inline bool DecayPtr(InterpState &S, CodePtr OpPC) {
 
   if constexpr (std::is_same_v<FromT, FunctionPointer> &&
                 std::is_same_v<ToT, Pointer>) {
-    S.Stk.push<Pointer>(OldPtr.getFunction());
+    S.Stk.push<Pointer>(OldPtr.getFunction(), OldPtr.getOffset());
     return true;
+  } else if constexpr (std::is_same_v<FromT, Pointer> &&
+                       std::is_same_v<ToT, FunctionPointer>) {
+    if (OldPtr.isFunctionPointer()) {
+      S.Stk.push<FunctionPointer>(OldPtr.asFunctionPointer().getFunction(),
+                                  OldPtr.getByteOffset());
+      return true;
+    }
   }
 
   S.Stk.push<ToT>(ToT(OldPtr.getIntegerRepresentation(), nullptr));
diff --git a/clang/lib/AST/ByteCode/Pointer.h b/clang/lib/AST/ByteCode/Pointer.h
index ba30449977376b..27ac33616f5a8b 100644
--- a/clang/lib/AST/ByteCode/Pointer.h
+++ b/clang/lib/AST/ByteCode/Pointer.h
@@ -137,7 +137,7 @@ class Pointer {
     if (isIntegralPointer())
       return asIntPointer().Value + (Offset * elemSize());
     if (isFunctionPointer())
-      return asFunctionPointer().getIntegerRepresentation();
+      return asFunctionPointer().getIntegerRepresentation() + Offset;
     return reinterpret_cast<uint64_t>(asBlockPointer().Pointee) + Offset;
   }
 
@@ -551,7 +551,7 @@ class Pointer {
   }
 
   /// Returns the byte offset from the start.
-  unsigned getByteOffset() const {
+  uint64_t getByteOffset() const {
     if (isIntegralPointer())
       return asIntPointer().Value + Offset;
     if (isOnePastEnd())
@@ -614,6 +614,8 @@ class Pointer {
 
   /// Checks if the pointer is pointing to a zero-size array.
   bool isZeroSizeArray() const {
+    if (isFunctionPointer()) // A function is never an array.
+      return false;
     if (const auto *Desc = getFieldDesc())
       return Desc->isZeroSizeArray();
     return false;
diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt
index 041252b6830e0a..6195a16c2c68db 100644
--- a/clang/lib/AST/CMakeLists.txt
+++ b/clang/lib/AST/CMakeLists.txt
@@ -72,6 +72,7 @@ add_clang_library(clangAST
   ByteCode/EvalEmitter.cpp
   ByteCode/Frame.cpp
   ByteCode/Function.cpp
+  ByteCode/FunctionPointer.cpp
   ByteCode/InterpBuiltin.cpp
   ByteCode/Floating.cpp
   ByteCode/EvaluationResult.cpp
diff --git a/clang/test/AST/ByteCode/c.c b/clang/test/AST/ByteCode/c.c
index b38259d41130eb..60f4d6ad1b2967 100644
--- a/clang/test/AST/ByteCode/c.c
+++ b/clang/test/AST/ByteCode/c.c
@@ -297,3 +297,19 @@ void T1(void) {
 
 enum teste1 test1f(void), (*test1)(void) = test1f; // pedantic-warning {{ISO C forbids forward references to 'enum' types}}
 enum teste1 { TEST1 };
+
+// Arithmetic on function pointers, accepted by Clang as an extension.
+void func(void) {
+  _Static_assert(func + 1 - func == 1, ""); // pedantic-warning {{arithmetic on a pointer to the function type}} \
+                                            // pedantic-warning {{arithmetic on pointers to the function type}} \
+                                            // pedantic-warning {{not an integer constant expression}}
+  _Static_assert(func + 0xdead000000000000UL - 0xdead000000000000UL == func, ""); // pedantic-warning 2{{arithmetic on a pointer to the function type}} \
+                                                                                  // pedantic-warning {{not an integer constant expression}} \
+                                                                                  // pedantic-note {{cannot refer to element 16045481047390945280 of non-array object in a constant expression}}
+  _Static_assert(func + 1 != func, ""); // pedantic-warning {{arithmetic on a pointer to the function type}} \
+                                        // pedantic-warning {{expression is not an integer constant expression}}
+  func + 0xdead000000000000UL; // all-warning {{expression result unused}} \
+                               // pedantic-warning {{arithmetic on a pointer to the function type}}
+  func - 0xdead000000000000UL; // all-warning {{expression result unused}} \
+                               // pedantic-warning {{arithmetic on a pointer to the function type}}
+}



More information about the cfe-commits mailing list