[clang] 3ad1673 - [clang][Interp] Implement function pointers
Timm Bäder via cfe-commits
cfe-commits at lists.llvm.org
Thu Mar 30 06:38:01 PDT 2023
Author: Timm Bäder
Date: 2023-03-30T15:37:49+02:00
New Revision: 3ad167329aafde02e70b0327c0488602111a81ee
URL: https://github.com/llvm/llvm-project/commit/3ad167329aafde02e70b0327c0488602111a81ee
DIFF: https://github.com/llvm/llvm-project/commit/3ad167329aafde02e70b0327c0488602111a81ee.diff
LOG: [clang][Interp] Implement function pointers
Differential Revision: https://reviews.llvm.org/D141472
Added:
clang/lib/AST/Interp/FunctionPointer.h
Modified:
clang/lib/AST/Interp/ByteCodeExprGen.cpp
clang/lib/AST/Interp/Context.cpp
clang/lib/AST/Interp/Descriptor.cpp
clang/lib/AST/Interp/Interp.h
clang/lib/AST/Interp/InterpStack.h
clang/lib/AST/Interp/Opcodes.td
clang/lib/AST/Interp/PrimType.cpp
clang/lib/AST/Interp/PrimType.h
clang/test/AST/Interp/functions.cpp
Removed:
################################################################################
diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index fff2425bedf42..c6cf7f7c99a59 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -131,6 +131,11 @@ bool ByteCodeExprGen<Emitter>::VisitCastExpr(const CastExpr *CE) {
return this->emitCastFloatingIntegral(*ToT, CE);
}
+ case CK_NullToPointer:
+ if (DiscardResult)
+ return true;
+ return this->emitNull(classifyPrim(CE->getType()), CE);
+
case CK_ArrayToPointerDecay:
case CK_AtomicToNonAtomic:
case CK_ConstructorConversion:
@@ -138,7 +143,6 @@ bool ByteCodeExprGen<Emitter>::VisitCastExpr(const CastExpr *CE) {
case CK_NonAtomicToAtomic:
case CK_NoOp:
case CK_UserDefinedConversion:
- case CK_NullToPointer:
return this->visit(SubExpr);
case CK_IntegralToBoolean:
@@ -400,10 +404,7 @@ bool ByteCodeExprGen<Emitter>::VisitImplicitValueInitExpr(const ImplicitValueIni
if (!T)
return false;
- if (E->getType()->isPointerType())
- return this->emitNullPtr(E);
-
- return this->emitZero(*T, E);
+ return this->visitZeroInitializer(*T, E);
}
template <class Emitter>
@@ -950,6 +951,8 @@ bool ByteCodeExprGen<Emitter>::visitZeroInitializer(PrimType T, const Expr *E) {
return this->emitZeroUint64(E);
case PT_Ptr:
return this->emitNullPtr(E);
+ case PT_FnPtr:
+ return this->emitNullFnPtr(E);
case PT_Float:
assert(false);
}
@@ -1116,6 +1119,7 @@ bool ByteCodeExprGen<Emitter>::emitConst(T Value, const Expr *E) {
case PT_Bool:
return this->emitConstBool(Value, E);
case PT_Ptr:
+ case PT_FnPtr:
case PT_Float:
llvm_unreachable("Invalid integral type");
break;
@@ -1606,8 +1610,27 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
if (E->getBuiltinCallee())
return VisitBuiltinCallExpr(E);
- const Decl *Callee = E->getCalleeDecl();
- if (const auto *FuncDecl = dyn_cast_if_present<FunctionDecl>(Callee)) {
+ QualType ReturnType = E->getCallReturnType(Ctx.getASTContext());
+ std::optional<PrimType> T = classify(ReturnType);
+ bool HasRVO = !ReturnType->isVoidType() && !T;
+
+ if (HasRVO && DiscardResult) {
+ // If we need to discard the return value but the function returns its
+ // value via an RVO pointer, we need to create one such pointer just
+ // for this call.
+ if (std::optional<unsigned> LocalIndex = allocateLocal(E)) {
+ if (!this->emitGetPtrLocal(*LocalIndex, E))
+ return false;
+ }
+ }
+
+ // Put arguments on the stack.
+ for (const auto *Arg : E->arguments()) {
+ if (!this->visit(Arg))
+ return false;
+ }
+
+ if (const FunctionDecl *FuncDecl = E->getDirectCallee()) {
const Function *Func = getFunction(FuncDecl);
if (!Func)
return false;
@@ -1619,24 +1642,7 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
if (Func->isFullyCompiled() && !Func->isConstexpr())
return false;
- QualType ReturnType = E->getCallReturnType(Ctx.getASTContext());
- std::optional<PrimType> T = classify(ReturnType);
-
- if (Func->hasRVO() && DiscardResult) {
- // If we need to discard the return value but the function returns its
- // value via an RVO pointer, we need to create one such pointer just
- // for this call.
- if (std::optional<unsigned> LocalIndex = allocateLocal(E)) {
- if (!this->emitGetPtrLocal(*LocalIndex, E))
- return false;
- }
- }
-
- // Put arguments on the stack.
- for (const auto *Arg : E->arguments()) {
- if (!this->visit(Arg))
- return false;
- }
+ assert(HasRVO == Func->hasRVO());
// In any case call the function. The return value will end up on the stack
// and if the function has RVO, we already have the pointer on the stack to
@@ -1644,15 +1650,22 @@ bool ByteCodeExprGen<Emitter>::VisitCallExpr(const CallExpr *E) {
if (!this->emitCall(Func, E))
return false;
- if (DiscardResult && !ReturnType->isVoidType() && T)
- return this->emitPop(*T, E);
-
- return true;
} else {
- assert(false && "We don't support non-FunctionDecl callees right now.");
+ // Indirect call. Visit the callee, which will leave a FunctionPointer on
+ // the stack. Cleanup of the returned value if necessary will be done after
+ // the function call completed.
+ if (!this->visit(E->getCallee()))
+ return false;
+
+ if (!this->emitCallPtr(E))
+ return false;
}
- return false;
+ // Cleanup for discarded return values.
+ if (DiscardResult && !ReturnType->isVoidType() && T)
+ return this->emitPop(*T, E);
+
+ return true;
}
template <class Emitter>
@@ -1846,6 +1859,9 @@ bool ByteCodeExprGen<Emitter>::VisitDeclRefExpr(const DeclRefExpr *E) {
return this->emitConst(ECD->getInitVal(), E);
} else if (const auto *BD = dyn_cast<BindingDecl>(Decl)) {
return this->visit(BD->getBinding());
+ } else if (const auto *FuncDecl = dyn_cast<FunctionDecl>(Decl)) {
+ const Function *F = getFunction(FuncDecl);
+ return F && this->emitGetFnPtr(F, E);
}
return false;
diff --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp
index dcf41a5c40202..6ede05e0f4c42 100644
--- a/clang/lib/AST/Interp/Context.cpp
+++ b/clang/lib/AST/Interp/Context.cpp
@@ -78,9 +78,11 @@ bool Context::evaluateAsInitializer(State &Parent, const VarDecl *VD,
const LangOptions &Context::getLangOpts() const { return Ctx.getLangOpts(); }
std::optional<PrimType> Context::classify(QualType T) const {
- if (T->isReferenceType() || T->isPointerType()) {
+ if (T->isFunctionPointerType() || T->isFunctionReferenceType())
+ return PT_FnPtr;
+
+ if (T->isReferenceType() || T->isPointerType())
return PT_Ptr;
- }
if (T->isBooleanType())
return PT_Bool;
diff --git a/clang/lib/AST/Interp/Descriptor.cpp b/clang/lib/AST/Interp/Descriptor.cpp
index 212311cfa2ae6..31554dddb30a0 100644
--- a/clang/lib/AST/Interp/Descriptor.cpp
+++ b/clang/lib/AST/Interp/Descriptor.cpp
@@ -9,6 +9,7 @@
#include "Descriptor.h"
#include "Boolean.h"
#include "Floating.h"
+#include "FunctionPointer.h"
#include "Pointer.h"
#include "PrimType.h"
#include "Record.h"
diff --git a/clang/lib/AST/Interp/FunctionPointer.h b/clang/lib/AST/Interp/FunctionPointer.h
new file mode 100644
index 0000000000000..2d449bdb031d8
--- /dev/null
+++ b/clang/lib/AST/Interp/FunctionPointer.h
@@ -0,0 +1,57 @@
+
+
+#ifndef LLVM_CLANG_AST_INTERP_FUNCTION_POINTER_H
+#define LLVM_CLANG_AST_INTERP_FUNCTION_POINTER_H
+
+#include "Function.h"
+#include "Primitives.h"
+#include "clang/AST/APValue.h"
+
+namespace clang {
+namespace interp {
+
+class FunctionPointer final {
+private:
+ const Function *Func;
+
+public:
+ FunctionPointer() : Func(nullptr) {}
+ FunctionPointer(const Function *Func) : Func(Func) { assert(Func); }
+
+ const Function *getFunction() const { return Func; }
+
+ APValue toAPValue() const {
+ if (!Func)
+ return APValue(static_cast<Expr *>(nullptr), CharUnits::Zero(), {},
+ /*OnePastTheEnd=*/false, /*IsNull=*/true);
+
+ return APValue(Func->getDecl(), CharUnits::Zero(), {},
+ /*OnePastTheEnd=*/false, /*IsNull=*/false);
+ }
+
+ void print(llvm::raw_ostream &OS) const {
+ OS << "FnPtr(";
+ if (Func)
+ OS << Func->getName();
+ else
+ OS << "nullptr";
+ OS << ")";
+ }
+
+ ComparisonCategoryResult compare(const FunctionPointer &RHS) const {
+ if (Func == RHS.Func)
+ return ComparisonCategoryResult::Equal;
+ return ComparisonCategoryResult::Unordered;
+ }
+};
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
+ FunctionPointer FP) {
+ FP.print(OS);
+ return OS;
+}
+
+} // namespace interp
+} // namespace clang
+
+#endif
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index 98561f0a9ce07..bb34737e018b2 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -16,6 +16,7 @@
#include "Boolean.h"
#include "Floating.h"
#include "Function.h"
+#include "FunctionPointer.h"
#include "InterpFrame.h"
#include "InterpStack.h"
#include "InterpState.h"
@@ -1538,6 +1539,22 @@ inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func) {
return false;
}
+inline bool CallPtr(InterpState &S, CodePtr &PC) {
+ const FunctionPointer &FuncPtr = S.Stk.pop<FunctionPointer>();
+
+ const Function *F = FuncPtr.getFunction();
+ if (!F || !F->isConstexpr())
+ return false;
+
+ return Call(S, PC, F);
+}
+
+inline bool GetFnPtr(InterpState &S, CodePtr &PC, const Function *Func) {
+ assert(Func);
+ S.Stk.push<FunctionPointer>(Func);
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// Read opcode arguments
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/AST/Interp/InterpStack.h b/clang/lib/AST/Interp/InterpStack.h
index 4987b2c33d6de..e625ffd8e421f 100644
--- a/clang/lib/AST/Interp/InterpStack.h
+++ b/clang/lib/AST/Interp/InterpStack.h
@@ -13,6 +13,7 @@
#ifndef LLVM_CLANG_AST_INTERP_INTERPSTACK_H
#define LLVM_CLANG_AST_INTERP_INTERPSTACK_H
+#include "FunctionPointer.h"
#include "PrimType.h"
#include <memory>
#include <vector>
@@ -162,6 +163,8 @@ class InterpStack final {
return PT_Uint64;
else if constexpr (std::is_same_v<T, Floating>)
return PT_Float;
+ else if constexpr (std::is_same_v<T, FunctionPointer>)
+ return PT_FnPtr;
llvm_unreachable("unknown type push()'ed into InterpStack");
}
diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index 80d5c652d8ddc..f3662dcd6f430 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -27,6 +27,7 @@ def Sint64 : Type;
def Uint64 : Type;
def Float : Type;
def Ptr : Type;
+def FnPtr : Type;
//===----------------------------------------------------------------------===//
// Types transferred to the interpreter.
@@ -77,7 +78,7 @@ def AluTypeClass : TypeClass {
}
def PtrTypeClass : TypeClass {
- let Types = [Ptr];
+ let Types = [Ptr, FnPtr];
}
def BoolTypeClass : TypeClass {
@@ -187,6 +188,12 @@ def CallBI : Opcode {
let ChangesPC = 1;
}
+def CallPtr : Opcode {
+ let Args = [];
+ let Types = [];
+ let ChangesPC = 1;
+}
+
//===----------------------------------------------------------------------===//
// Frame management
//===----------------------------------------------------------------------===//
@@ -228,6 +235,7 @@ def Zero : Opcode {
// [] -> [Pointer]
def Null : Opcode {
let Types = [PtrTypeClass];
+ let HasGroup = 1;
}
//===----------------------------------------------------------------------===//
@@ -447,6 +455,14 @@ def DecPtr : Opcode {
let HasGroup = 0;
}
+//===----------------------------------------------------------------------===//
+// Function pointers.
+//===----------------------------------------------------------------------===//
+def GetFnPtr : Opcode {
+ let Args = [ArgFunction];
+}
+
+
//===----------------------------------------------------------------------===//
// Binary operators.
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/AST/Interp/PrimType.cpp b/clang/lib/AST/Interp/PrimType.cpp
index da07b6f0fb0c6..a9b5d8ea8cc8c 100644
--- a/clang/lib/AST/Interp/PrimType.cpp
+++ b/clang/lib/AST/Interp/PrimType.cpp
@@ -9,6 +9,7 @@
#include "PrimType.h"
#include "Boolean.h"
#include "Floating.h"
+#include "FunctionPointer.h"
#include "Pointer.h"
using namespace clang;
diff --git a/clang/lib/AST/Interp/PrimType.h b/clang/lib/AST/Interp/PrimType.h
index db9d8c3a85799..91311cf7030a7 100644
--- a/clang/lib/AST/Interp/PrimType.h
+++ b/clang/lib/AST/Interp/PrimType.h
@@ -24,6 +24,7 @@ namespace interp {
class Pointer;
class Boolean;
class Floating;
+class FunctionPointer;
/// Enumeration of the primitive types of the VM.
enum PrimType : unsigned {
@@ -38,6 +39,7 @@ enum PrimType : unsigned {
PT_Bool,
PT_Float,
PT_Ptr,
+ PT_FnPtr,
};
/// Mapping from primitive types to their representation.
@@ -53,6 +55,9 @@ template <> struct PrimConv<PT_Uint64> { using T = Integral<64, false>; };
template <> struct PrimConv<PT_Float> { using T = Floating; };
template <> struct PrimConv<PT_Bool> { using T = Boolean; };
template <> struct PrimConv<PT_Ptr> { using T = Pointer; };
+template <> struct PrimConv<PT_FnPtr> {
+ using T = FunctionPointer;
+};
/// Returns the size of a primitive type in bytes.
size_t primSize(PrimType Type);
@@ -90,6 +95,7 @@ static inline bool aligned(const void *P) {
TYPE_SWITCH_CASE(PT_Float, B) \
TYPE_SWITCH_CASE(PT_Bool, B) \
TYPE_SWITCH_CASE(PT_Ptr, B) \
+ TYPE_SWITCH_CASE(PT_FnPtr, B) \
} \
} while (0)
#define COMPOSITE_TYPE_SWITCH(Expr, B, D) \
diff --git a/clang/test/AST/Interp/functions.cpp b/clang/test/AST/Interp/functions.cpp
index 4fb5d3d98d749..48862d3997227 100644
--- a/clang/test/AST/Interp/functions.cpp
+++ b/clang/test/AST/Interp/functions.cpp
@@ -99,3 +99,66 @@ constexpr void invalid2() {
huh(); // expected-error {{use of undeclared identifier}} \
// ref-error {{use of undeclared identifier}}
}
+
+namespace FunctionPointers {
+ constexpr int add(int a, int b) {
+ return a + b;
+ }
+
+ struct S { int a; };
+ constexpr S getS() {
+ return S{12};
+ }
+
+ constexpr int applyBinOp(int a, int b, int (*op)(int, int)) {
+ return op(a, b);
+ }
+ static_assert(applyBinOp(1, 2, add) == 3, "");
+
+ constexpr int ignoreReturnValue() {
+ int (*foo)(int, int) = add;
+
+ foo(1, 2);
+ return 1;
+ }
+ static_assert(ignoreReturnValue() == 1, "");
+
+ constexpr int createS(S (*gimme)()) {
+ gimme(); // Ignored return value
+ return gimme().a;
+ }
+ static_assert(createS(getS) == 12, "");
+
+namespace FunctionReturnType {
+ typedef int (*ptr)(int*);
+ typedef ptr (*pm)();
+
+ constexpr int fun1(int* y) {
+ return *y + 10;
+ }
+ constexpr ptr fun() {
+ return &fun1;
+ }
+ static_assert(fun() == nullptr, ""); // expected-error {{static assertion failed}} \
+ // ref-error {{static assertion failed}}
+
+ constexpr int foo() {
+ int (*f)(int *) = fun();
+ int m = 0;
+
+ m = f(&m);
+
+ return m;
+ }
+ static_assert(foo() == 10);
+
+ struct S {
+ int i;
+ void (*fp)();
+ };
+
+ constexpr S s{ 12 };
+ static_assert(s.fp == nullptr); // zero-initialized function pointer.
+}
+
+}
More information about the cfe-commits mailing list