[clang] [clang] Implement function pointer type discrimination (PR #96992)

Akira Hatanaka via cfe-commits cfe-commits at lists.llvm.org
Mon Jul 8 15:20:58 PDT 2024


https://github.com/ahatanak updated https://github.com/llvm/llvm-project/pull/96992

>From cf22a4be007f7e6fdc6e4c17c1f32fa70440b123 Mon Sep 17 00:00:00 2001
From: Akira Hatanaka <ahatanak at gmail.com>
Date: Wed, 26 Jun 2024 13:02:31 -0700
Subject: [PATCH 1/4] [clang] Implement function pointer type discrimination

Give users an option to sign a function pointer using a non-zero
discriminator based on the type of the destination.

Co-authored-by: John McCall <rjmccall at apple.com>
---
 clang/include/clang/AST/ASTContext.h          |   3 +
 clang/include/clang/AST/Type.h                |   3 +
 clang/include/clang/Basic/Features.def        |   1 +
 clang/include/clang/Basic/LangOptions.def     |   3 +
 clang/include/clang/CodeGen/CodeGenABITypes.h |   4 +
 clang/include/clang/Driver/Options.td         |   2 +
 clang/lib/AST/ASTContext.cpp                  | 263 ++++++++++++++++++
 clang/lib/CodeGen/CGExprConstant.cpp          |   5 +
 clang/lib/CodeGen/CGPointerAuth.cpp           |  68 ++++-
 clang/lib/CodeGen/CodeGenModule.h             |   4 +
 clang/lib/Driver/ToolChains/Clang.cpp         |   3 +
 clang/lib/Frontend/CompilerInvocation.cpp     |  11 +-
 .../ptrauth-function-type-discriminator.c     | 137 +++++++++
 clang/test/Preprocessor/ptrauth_feature.c     |  34 ++-
 14 files changed, 528 insertions(+), 13 deletions(-)
 create mode 100644 clang/test/CodeGen/ptrauth-function-type-discriminator.c

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index a99f2dc6eb3f2..8ba0c943a9c86 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1283,6 +1283,9 @@ class ASTContext : public RefCountedBase<ASTContext> {
   uint16_t
   getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD);
 
+  /// Return the "other" type-specific discriminator for the given type.
+  uint16_t getPointerAuthTypeDiscriminator(QualType T);
+
   /// Apply Objective-C protocol qualifiers to the given type.
   /// \param allowOnPointerType specifies if we can apply protocol
   /// qualifiers on ObjCObjectPointerType. It can be set to true when
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index a98899f7f4222..3aa0f05b0ab60 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2506,6 +2506,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFunctionNoProtoType() const { return getAs<FunctionNoProtoType>(); }
   bool isFunctionProtoType() const { return getAs<FunctionProtoType>(); }
   bool isPointerType() const;
+  bool isSignableType() const;
   bool isAnyPointerType() const;   // Any C pointer or ObjC object pointer
   bool isCountAttributedType() const;
   bool isBlockPointerType() const;
@@ -8001,6 +8002,8 @@ inline bool Type::isAnyPointerType() const {
   return isPointerType() || isObjCObjectPointerType();
 }
 
+inline bool Type::isSignableType() const { return isPointerType(); }
+
 inline bool Type::isBlockPointerType() const {
   return isa<BlockPointerType>(CanonicalType);
 }
diff --git a/clang/include/clang/Basic/Features.def b/clang/include/clang/Basic/Features.def
index 53f410d3cb4bd..3c9d0364880c5 100644
--- a/clang/include/clang/Basic/Features.def
+++ b/clang/include/clang/Basic/Features.def
@@ -110,6 +110,7 @@ FEATURE(ptrauth_vtable_pointer_address_discrimination, LangOpts.PointerAuthVTPtr
 FEATURE(ptrauth_vtable_pointer_type_discrimination, LangOpts.PointerAuthVTPtrTypeDiscrimination)
 FEATURE(ptrauth_member_function_pointer_type_discrimination, LangOpts.PointerAuthCalls)
 FEATURE(ptrauth_init_fini, LangOpts.PointerAuthInitFini)
+FEATURE(ptrauth_function_pointer_type_discrimination, LangOpts.FunctionPointerTypeDiscrimination)
 EXTENSION(swiftcc,
   PP.getTargetInfo().checkCallingConvention(CC_Swift) ==
   clang::TargetInfo::CCCR_OK)
diff --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def
index 6dd6b5614f44c..554f8643812b5 100644
--- a/clang/include/clang/Basic/LangOptions.def
+++ b/clang/include/clang/Basic/LangOptions.def
@@ -470,6 +470,9 @@ ENUM_LANGOPT(StrictFlexArraysLevel, StrictFlexArraysLevelKind, 2,
 
 COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0")
 
+BENIGN_LANGOPT(FunctionPointerTypeDiscrimination, 1, 0,
+               "Use type discrimination when signing function pointers")
+
 ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None,
              "Scope of return address signing")
 ENUM_LANGOPT(SignReturnAddressKey, SignReturnAddressKeyKind, 1, SignReturnAddressKeyKind::AKey,
diff --git a/clang/include/clang/CodeGen/CodeGenABITypes.h b/clang/include/clang/CodeGen/CodeGenABITypes.h
index d4822dc160082..9cbc5a8a2a3f4 100644
--- a/clang/include/clang/CodeGen/CodeGenABITypes.h
+++ b/clang/include/clang/CodeGen/CodeGenABITypes.h
@@ -108,6 +108,10 @@ unsigned getLLVMFieldNumber(CodeGenModule &CGM,
 /// Return a declaration discriminator for the given global decl.
 uint16_t getPointerAuthDeclDiscriminator(CodeGenModule &CGM, GlobalDecl GD);
 
+/// Return a type discriminator for the given function type.
+uint16_t getPointerAuthTypeDiscriminator(CodeGenModule &CGM,
+                                         QualType FunctionType);
+
 /// Given the language and code-generation options that Clang was configured
 /// with, set the default LLVM IR attributes for a function definition.
 /// The attributes set here are mostly global target-configuration and
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index dd55838dcf384..a6f960e42d23f 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4228,6 +4228,8 @@ defm ptrauth_vtable_pointer_address_discrimination :
 defm ptrauth_vtable_pointer_type_discrimination :
   OptInCC1FFlag<"ptrauth-vtable-pointer-type-discrimination", "Enable type discrimination of vtable pointers">;
 defm ptrauth_init_fini : OptInCC1FFlag<"ptrauth-init-fini", "Enable signing of function pointers in init/fini arrays">;
+defm ptrauth_function_pointer_type_discrimination : OptInCC1FFlag<"ptrauth-function-pointer-type-discrimination",
+  "Enabling type discrimination on C function pointers">;
 }
 
 def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 84deaf5429df7..74e0ae0a58e9f 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -3140,6 +3140,269 @@ ASTContext::getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD) {
   return llvm::getPointerAuthStableSipHash(Str);
 }
 
+/// Encode a function type for use in the discriminator of a function pointer
+/// type. We can't use the Itanium scheme for this since C has quite permissive
+/// rules for type compatibility that we need to be compatible with.
+///
+/// Formally, this function associates every function pointer type T with an
+/// encoded string E(T). Let the equivalence relation T1 ~ T2 be defined as
+/// E(T1) == E(T2). E(T) is part of the ABI of values of type T. C type
+/// compatibility requires equivalent treatment under the ABI, so
+/// CCompatible(T1, T2) must imply E(T1) == E(T2), that is, CCompatible must be
+/// a subset of ~. Crucially, however, it must be a proper subset because
+/// CCompatible is not an equivalence relation: for example, int[] is compatible
+/// with both int[1] and int[2], but the latter are not compatible with each
+/// other. Therefore this encoding function must be careful to only distinguish
+/// types if there is no third type with which they are both required to be
+/// compatible.
+static void encodeTypeForFunctionPointerAuth(ASTContext &Ctx, raw_ostream &OS,
+                                             QualType QT) {
+  // FIXME: Consider address space qualifiers.
+  const Type *T = QT.getCanonicalType().getTypePtr();
+
+  // FIXME: Consider using the C++ type mangling when we encounter a construct
+  // that is incompatible with C.
+
+  switch (T->getTypeClass()) {
+  case Type::Atomic:
+    return encodeTypeForFunctionPointerAuth(
+        Ctx, OS, cast<AtomicType>(T)->getValueType());
+
+  case Type::LValueReference:
+    OS << "R";
+    encodeTypeForFunctionPointerAuth(Ctx, OS,
+                                     cast<ReferenceType>(T)->getPointeeType());
+    return;
+  case Type::RValueReference:
+    OS << "O";
+    encodeTypeForFunctionPointerAuth(Ctx, OS,
+                                     cast<ReferenceType>(T)->getPointeeType());
+    return;
+
+  case Type::Pointer:
+    // C11 6.7.6.1p2:
+    //   For two pointer types to be compatible, both shall be identically
+    //   qualified and both shall be pointers to compatible types.
+    // FIXME: we should also consider pointee types.
+    OS << "P";
+    return;
+
+  case Type::ObjCObjectPointer:
+  case Type::BlockPointer:
+    OS << "P";
+    return;
+
+  case Type::Complex:
+    OS << "C";
+    return encodeTypeForFunctionPointerAuth(
+        Ctx, OS, cast<ComplexType>(T)->getElementType());
+
+  case Type::VariableArray:
+  case Type::ConstantArray:
+  case Type::IncompleteArray:
+  case Type::ArrayParameter:
+    // C11 6.7.6.2p6:
+    //   For two array types to be compatible, both shall have compatible
+    //   element types, and if both size specifiers are present, and are integer
+    //   constant expressions, then both size specifiers shall have the same
+    //   constant value [...]
+    //
+    // Since ElemType[N] has to be compatible with ElemType[], we can't encode
+    // the width of the array.
+    OS << "A";
+    return encodeTypeForFunctionPointerAuth(
+        Ctx, OS, cast<ArrayType>(T)->getElementType());
+
+  case Type::ObjCInterface:
+  case Type::ObjCObject:
+    OS << "<objc_object>";
+    return;
+
+  case Type::Enum:
+    // C11 6.7.2.2p4:
+    //   Each enumerated type shall be compatible with char, a signed integer
+    //   type, or an unsigned integer type.
+    //
+    // So we have to treat enum types as integers.
+    OS << "i";
+    return;
+
+  case Type::FunctionNoProto:
+  case Type::FunctionProto: {
+    // C11 6.7.6.3p15:
+    //   For two function types to be compatible, both shall specify compatible
+    //   return types. Moreover, the parameter type lists, if both are present,
+    //   shall agree in the number of parameters and in the use of the ellipsis
+    //   terminator; corresponding parameters shall have compatible types.
+    //
+    // That paragraph goes on to describe how unprototyped functions are to be
+    // handled, which we ignore here. Unprototyped function pointers are hashed
+    // as though they were prototyped nullary functions since that's probably
+    // what the user meant. This behavior is non-conforming.
+    // FIXME: If we add a "custom discriminator" function type attribute we
+    // should encode functions as their discriminators.
+    OS << "F";
+    auto *FuncType = cast<FunctionType>(T);
+    encodeTypeForFunctionPointerAuth(Ctx, OS, FuncType->getReturnType());
+    if (auto *FPT = dyn_cast<FunctionProtoType>(FuncType)) {
+      for (QualType Param : FPT->param_types()) {
+        Param = Ctx.getSignatureParameterType(Param);
+        encodeTypeForFunctionPointerAuth(Ctx, OS, Param);
+      }
+      if (FPT->isVariadic())
+        OS << "z";
+    }
+    OS << "E";
+    return;
+  }
+
+  case Type::MemberPointer: {
+    OS << "M";
+    auto *MPT = T->getAs<MemberPointerType>();
+    encodeTypeForFunctionPointerAuth(Ctx, OS, QualType(MPT->getClass(), 0));
+    encodeTypeForFunctionPointerAuth(Ctx, OS, MPT->getPointeeType());
+    return;
+  }
+  case Type::ExtVector:
+  case Type::Vector:
+    OS << "Dv" << Ctx.getTypeSizeInChars(T).getQuantity();
+    break;
+
+  // Don't bother discriminating based on these types.
+  case Type::Pipe:
+  case Type::BitInt:
+  case Type::ConstantMatrix:
+    OS << "?";
+    return;
+
+  case Type::Builtin: {
+    const BuiltinType *BTy = T->getAs<BuiltinType>();
+    switch (BTy->getKind()) {
+#define SIGNED_TYPE(Id, SingletonId)                                           \
+  case BuiltinType::Id:                                                        \
+    OS << "i";                                                                 \
+    return;
+#define UNSIGNED_TYPE(Id, SingletonId)                                         \
+  case BuiltinType::Id:                                                        \
+    OS << "i";                                                                 \
+    return;
+#define PLACEHOLDER_TYPE(Id, SingletonId) case BuiltinType::Id:
+#define BUILTIN_TYPE(Id, SingletonId)
+#include "clang/AST/BuiltinTypes.def"
+      llvm_unreachable("placeholder types should not appear here.");
+
+    case BuiltinType::Half:
+      OS << "Dh";
+      return;
+    case BuiltinType::Float:
+      OS << "f";
+      return;
+    case BuiltinType::Double:
+      OS << "d";
+      return;
+    case BuiltinType::LongDouble:
+      OS << "e";
+      return;
+    case BuiltinType::Float16:
+      OS << "DF16_";
+      return;
+    case BuiltinType::Float128:
+      OS << "g";
+      return;
+
+    case BuiltinType::Void:
+      OS << "v";
+      return;
+
+    case BuiltinType::ObjCId:
+    case BuiltinType::ObjCClass:
+    case BuiltinType::ObjCSel:
+    case BuiltinType::NullPtr:
+      OS << "P";
+      return;
+
+    // Don't bother discriminating based on OpenCL types.
+    case BuiltinType::OCLSampler:
+    case BuiltinType::OCLEvent:
+    case BuiltinType::OCLClkEvent:
+    case BuiltinType::OCLQueue:
+    case BuiltinType::OCLReserveID:
+    case BuiltinType::BFloat16:
+    case BuiltinType::VectorQuad:
+    case BuiltinType::VectorPair:
+      OS << "?";
+      return;
+
+    // Don't bother discriminating based on these seldom-used types.
+    case BuiltinType::Ibm128:
+      return;
+#define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix)                   \
+  case BuiltinType::Id:                                                        \
+    return;
+#include "clang/Basic/OpenCLImageTypes.def"
+#define EXT_OPAQUE_TYPE(ExtType, Id, Ext)                                      \
+  case BuiltinType::Id:                                                        \
+    return;
+#include "clang/Basic/OpenCLExtensionTypes.def"
+#define SVE_TYPE(Name, Id, SingletonId)                                        \
+  case BuiltinType::Id:                                                        \
+    return;
+#include "clang/Basic/AArch64SVEACLETypes.def"
+    case BuiltinType::Dependent:
+      llvm_unreachable("should never get here");
+    case BuiltinType::AMDGPUBufferRsrc:
+    case BuiltinType::WasmExternRef:
+#define RVV_TYPE(Name, Id, SingletonId) case BuiltinType::Id:
+#include "clang/Basic/RISCVVTypes.def"
+      llvm_unreachable("not yet implemented");
+    }
+  }
+  case Type::Record: {
+    RecordDecl *RD = T->getAs<RecordType>()->getDecl();
+    IdentifierInfo *II = RD->getIdentifier();
+    if (!II)
+      if (const TypedefNameDecl *Typedef = RD->getTypedefNameForAnonDecl())
+        II = Typedef->getDeclName().getAsIdentifierInfo();
+
+    if (!II) {
+      OS << "<anonymous_record>";
+      return;
+    }
+    OS << II->getLength() << II->getName();
+    return;
+  }
+  case Type::DeducedTemplateSpecialization:
+  case Type::Auto:
+#define NON_CANONICAL_TYPE(Class, Base) case Type::Class:
+#define DEPENDENT_TYPE(Class, Base) case Type::Class:
+#define NON_CANONICAL_UNLESS_DEPENDENT_TYPE(Class, Base) case Type::Class:
+#define ABSTRACT_TYPE(Class, Base)
+#define TYPE(Class, Base)
+#include "clang/AST/TypeNodes.inc"
+    llvm_unreachable("unexpected non-canonical or dependent type!");
+    return;
+  }
+}
+
+uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) {
+  assert(!T->isDependentType() &&
+         "cannot compute type discriminator of a dependent type");
+
+  SmallString<256> Str;
+  llvm::raw_svector_ostream Out(Str);
+
+  if (T->isFunctionPointerType() || T->isFunctionReferenceType())
+    T = T->getPointeeType();
+
+  if (T->isFunctionType())
+    encodeTypeForFunctionPointerAuth(*this, Out, T);
+  else
+    llvm_unreachable(
+        "type discrimination of non-function type not implemented yet");
+
+  return llvm::getPointerAuthStableSipHash(Str);
+}
+
 QualType ASTContext::getObjCGCQualType(QualType T,
                                        Qualifiers::GC GCAttr) const {
   QualType CanT = getCanonicalType(T);
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index 1fec587b5c4c7..6d412af5cc994 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -2220,6 +2220,11 @@ llvm::Constant *ConstantLValueEmitter::emitPointerAuthPointer(const Expr *E) {
 
   // The assertions here are all checked by Sema.
   assert(Result.Val.isLValue());
+  auto *Base = Result.Val.getLValueBase().get<const ValueDecl *>();
+  if (auto *Decl = dyn_cast_or_null<FunctionDecl>(Base)) {
+    assert(Result.Val.getLValueOffset().isZero());
+    return CGM.getRawFunctionPointer(Decl);
+  }
   return ConstantEmitter(CGM, Emitter.CGF)
       .emitAbstract(E->getExprLoc(), Result.Val, E->getType());
 }
diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp
index 673f6e60bfc31..33b421437c74c 100644
--- a/clang/lib/CodeGen/CGPointerAuth.cpp
+++ b/clang/lib/CodeGen/CGPointerAuth.cpp
@@ -15,6 +15,7 @@
 #include "CodeGenModule.h"
 #include "clang/CodeGen/CodeGenABITypes.h"
 #include "clang/CodeGen/ConstantInitBuilder.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Support/SipHash.h"
 
 using namespace clang;
@@ -29,7 +30,9 @@ llvm::ConstantInt *CodeGenModule::getPointerAuthOtherDiscriminator(
     return nullptr;
 
   case PointerAuthSchema::Discrimination::Type:
-    llvm_unreachable("type discrimination not implemented yet");
+    assert(!Type.isNull() && "type not provided for type-discriminated schema");
+    return llvm::ConstantInt::get(
+        IntPtrTy, getContext().getPointerAuthTypeDiscriminator(Type));
 
   case PointerAuthSchema::Discrimination::Decl:
     assert(Decl.getDecl() &&
@@ -43,6 +46,11 @@ llvm::ConstantInt *CodeGenModule::getPointerAuthOtherDiscriminator(
   llvm_unreachable("bad discrimination kind");
 }
 
+uint16_t CodeGen::getPointerAuthTypeDiscriminator(CodeGenModule &CGM,
+                                                  QualType FunctionType) {
+  return CGM.getContext().getPointerAuthTypeDiscriminator(FunctionType);
+}
+
 uint16_t CodeGen::getPointerAuthDeclDiscriminator(CodeGenModule &CGM,
                                                   GlobalDecl Declaration) {
   return CGM.getPointerAuthDeclDiscriminator(Declaration);
@@ -71,12 +79,15 @@ CGPointerAuthInfo CodeGenModule::getFunctionPointerAuthInfo(QualType T) {
   assert(!Schema.isAddressDiscriminated() &&
          "function pointers cannot use address-specific discrimination");
 
-  assert(!Schema.hasOtherDiscrimination() &&
-         "function pointers don't support any discrimination yet");
+  llvm::Constant *Discriminator = nullptr;
+  if (T->isFunctionPointerType() || T->isFunctionReferenceType())
+    T = T->getPointeeType();
+  if (T->isFunctionType())
+    Discriminator = getPointerAuthOtherDiscriminator(Schema, GlobalDecl(), T);
 
   return CGPointerAuthInfo(Schema.getKey(), Schema.getAuthenticationMode(),
                            /*IsaPointer=*/false, /*AuthenticatesNull=*/false,
-                           /*Discriminator=*/nullptr);
+                           Discriminator);
 }
 
 llvm::Value *
@@ -114,6 +125,47 @@ CGPointerAuthInfo CodeGenFunction::EmitPointerAuthInfo(
                            Schema.authenticatesNullValues(), Discriminator);
 }
 
+/// Return the natural pointer authentication for values of the given
+/// pointee type.
+static CGPointerAuthInfo
+getPointerAuthInfoForPointeeType(CodeGenModule &CGM, QualType PointeeType) {
+  if (PointeeType.isNull())
+    return CGPointerAuthInfo();
+
+  // Function pointers use the function-pointer schema by default.
+  if (PointeeType->isFunctionType())
+    return CGM.getFunctionPointerAuthInfo(PointeeType);
+
+  // Normal data pointers never use direct pointer authentication by default.
+  return CGPointerAuthInfo();
+}
+
+CGPointerAuthInfo CodeGenModule::getPointerAuthInfoForPointeeType(QualType T) {
+  return ::getPointerAuthInfoForPointeeType(*this, T);
+}
+
+/// Return the natural pointer authentication for values of the given
+/// pointer type.
+static CGPointerAuthInfo getPointerAuthInfoForType(CodeGenModule &CGM,
+                                                   QualType PointerType) {
+  assert(PointerType->isSignableType());
+
+  // Block pointers are currently not signed.
+  if (PointerType->isBlockPointerType())
+    return CGPointerAuthInfo();
+
+  auto PointeeType = PointerType->getPointeeType();
+
+  if (PointeeType.isNull())
+    return CGPointerAuthInfo();
+
+  return ::getPointerAuthInfoForPointeeType(CGM, PointeeType);
+}
+
+CGPointerAuthInfo CodeGenModule::getPointerAuthInfoForType(QualType T) {
+  return ::getPointerAuthInfoForType(*this, T);
+}
+
 llvm::Constant *
 CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key,
                                         llvm::Constant *StorageAddress,
@@ -180,6 +232,14 @@ llvm::Constant *CodeGenModule::getFunctionPointer(GlobalDecl GD,
                                                   llvm::Type *Ty) {
   const auto *FD = cast<FunctionDecl>(GD.getDecl());
   QualType FuncType = FD->getType();
+
+  // Annoyingly, K&R functions have prototypes in the clang AST, but
+  // expressions referring to them are unprototyped.
+  if (!FD->hasPrototype())
+    if (const auto *Proto = FuncType->getAs<FunctionProtoType>())
+      FuncType = Context.getFunctionNoProtoType(Proto->getReturnType(),
+                                                Proto->getExtInfo());
+
   return getFunctionPointer(getRawFunctionPointer(GD, Ty), FuncType);
 }
 
diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index 22b2b314c316c..8a0a2c8731e72 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -964,6 +964,10 @@ class CodeGenModule : public CodeGenTypeCache {
 
   CGPointerAuthInfo getFunctionPointerAuthInfo(QualType T);
 
+  CGPointerAuthInfo getPointerAuthInfoForPointeeType(QualType type);
+
+  CGPointerAuthInfo getPointerAuthInfoForType(QualType type);
+
   bool shouldSignPointer(const PointerAuthSchema &Schema);
   llvm::Constant *getConstantSignedPointer(llvm::Constant *Pointer,
                                            const PointerAuthSchema &Schema,
diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp
index c0f6bc0c2e45a..35dc95c5524d3 100644
--- a/clang/lib/Driver/ToolChains/Clang.cpp
+++ b/clang/lib/Driver/ToolChains/Clang.cpp
@@ -1789,6 +1789,9 @@ void Clang::AddAArch64TargetArgs(const ArgList &Args,
       options::OPT_fno_ptrauth_vtable_pointer_type_discrimination);
   Args.addOptInFlag(CmdArgs, options::OPT_fptrauth_init_fini,
                     options::OPT_fno_ptrauth_init_fini);
+  Args.addOptInFlag(
+      CmdArgs, options::OPT_fptrauth_function_pointer_type_discrimination,
+      options::OPT_fno_ptrauth_function_pointer_type_discrimination);
 }
 
 void Clang::AddLoongArchTargetArgs(const ArgList &Args,
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index f42e28ba7e629..5301802875ab0 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -1466,8 +1466,10 @@ void CompilerInvocation::setDefaultPointerAuthOptions(
     using Key = PointerAuthSchema::ARM8_3Key;
     using Discrimination = PointerAuthSchema::Discrimination;
     // If you change anything here, be sure to update <ptrauth.h>.
-    Opts.FunctionPointers =
-        PointerAuthSchema(Key::ASIA, false, Discrimination::None);
+    Opts.FunctionPointers = PointerAuthSchema(
+        Key::ASIA, false,
+        LangOpts.FunctionPointerTypeDiscrimination ? Discrimination::Type
+                                                   : Discrimination::None);
 
     Opts.CXXVTablePointers = PointerAuthSchema(
         Key::ASDA, LangOpts.PointerAuthVTPtrAddressDiscrimination,
@@ -3398,6 +3400,8 @@ static void GeneratePointerAuthArgs(const LangOptions &Opts,
     GenerateArg(Consumer, OPT_fptrauth_vtable_pointer_type_discrimination);
   if (Opts.PointerAuthInitFini)
     GenerateArg(Consumer, OPT_fptrauth_init_fini);
+  if (Opts.FunctionPointerTypeDiscrimination)
+    GenerateArg(Consumer, OPT_fptrauth_function_pointer_type_discrimination);
 }
 
 static void ParsePointerAuthArgs(LangOptions &Opts, ArgList &Args,
@@ -4779,6 +4783,9 @@ bool CompilerInvocation::CreateFromArgsImpl(
   if (Res.getFrontendOpts().ProgramAction == frontend::RewriteObjC)
     LangOpts.ObjCExceptions = 1;
 
+  LangOpts.FunctionPointerTypeDiscrimination =
+      Args.hasArg(OPT_fptrauth_function_pointer_type_discrimination);
+
   for (auto Warning : Res.getDiagnosticOpts().Warnings) {
     if (Warning == "misexpect" &&
         !Diags.isIgnored(diag::warn_profile_data_misexpect, SourceLocation())) {
diff --git a/clang/test/CodeGen/ptrauth-function-type-discriminator.c b/clang/test/CodeGen/ptrauth-function-type-discriminator.c
new file mode 100644
index 0000000000000..5dea48fe5915b
--- /dev/null
+++ b/clang/test/CodeGen/ptrauth-function-type-discriminator.c
@@ -0,0 +1,137 @@
+// RUN: %clang_cc1 %s       -fptrauth-function-pointer-type-discrimination -triple arm64e-apple-ios13 -fptrauth-calls -fptrauth-intrinsics -disable-llvm-passes -emit-llvm -o- | FileCheck %s --check-prefix=CHECK --check-prefix=CHECKC
+// RUN: %clang_cc1 -xc++ %s -fptrauth-function-pointer-type-discrimination -triple arm64e-apple-ios13 -fptrauth-calls -fptrauth-intrinsics -disable-llvm-passes -emit-llvm -o- | FileCheck %s --check-prefix=CHECK
+// RUN: %clang_cc1 -fptrauth-function-pointer-type-discrimination -triple arm64-apple-ios -fptrauth-calls -fptrauth-intrinsics -emit-pch %s -o %t.ast
+// RUN: %clang_cc1 -fptrauth-function-pointer-type-discrimination -triple arm64-apple-ios -fptrauth-calls -fptrauth-intrinsics -emit-llvm -x ast -o - %t.ast | FileCheck -check-prefix=CHECK --check-prefix=CHECKC %s
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void f(void);
+void f2(int);
+void (*fnptr)(void);
+void *opaque;
+unsigned long uintptr;
+
+// CHECK: @test_constant_null = global ptr null
+void (*test_constant_null)(int) = 0;
+
+// CHECK: @test_constant_cast = global ptr ptrauth (ptr @f, i32 0, i64 2712)
+void (*test_constant_cast)(int) = (void (*)(int))f;
+
+// CHECK: @test_opaque = global ptr ptrauth (ptr @f, i32 0)
+void *test_opaque =
+#ifdef __cplusplus
+    (void *)
+#endif
+    (void (*)(int))(double (*)(double))f;
+
+// CHECK: @test_intptr_t = global i64 ptrtoint (ptr ptrauth (ptr @f, i32 0) to i64)
+unsigned long test_intptr_t = (unsigned long)f;
+
+// CHECK: @test_through_long = global ptr ptrauth (ptr @f, i32 0, i64 2712)
+void (*test_through_long)(int) = (void (*)(int))(long)f;
+
+// CHECK: @test_to_long = global i64 ptrtoint (ptr ptrauth (ptr @f, i32 0) to i64)
+long test_to_long = (long)(double (*)())f;
+
+extern void external_function(void);
+// CHECK: @fptr1 = global ptr ptrauth (ptr @external_function, i32 0, i64 18983)
+void (*fptr1)(void) = external_function;
+// CHECK: @fptr2 = global ptr ptrauth (ptr @external_function, i32 0, i64 18983)
+void (*fptr2)(void) = &external_function;
+
+// CHECK: @fptr3 = global ptr ptrauth (ptr @external_function, i32 2, i64 26)
+void (*fptr3)(void) = __builtin_ptrauth_sign_constant(&external_function, 2, 26);
+
+// CHECK: @fptr4 = global ptr ptrauth (ptr @external_function, i32 2, i64 26, ptr @fptr4)
+void (*fptr4)(void) = __builtin_ptrauth_sign_constant(&external_function, 2, __builtin_ptrauth_blend_discriminator(&fptr4, 26));
+
+// CHECK-LABEL: define void @test_call()
+void test_call() {
+  // CHECK:      [[T0:%.*]] = load ptr, ptr @fnptr,
+  // CHECK-NEXT: call void [[T0]]() [ "ptrauth"(i32 0, i64 18983) ]
+  fnptr();
+}
+
+// CHECK-LABEL: define ptr @test_function_pointer()
+// CHECK:  ret ptr ptrauth (ptr @external_function, i32 0, i64 18983)
+void (*test_function_pointer())(void) {
+  return external_function;
+}
+
+struct InitiallyIncomplete;
+extern struct InitiallyIncomplete returns_initially_incomplete(void);
+// CHECK-LABEL: define void @use_while_incomplete()
+void use_while_incomplete() {
+  // CHECK:      [[VAR:%.*]] = alloca ptr,
+  // CHECK-NEXT: store ptr ptrauth (ptr @returns_initially_incomplete, i32 0, i64 25106), ptr [[VAR]]
+  struct InitiallyIncomplete (*fnptr)(void) = &returns_initially_incomplete;
+}
+struct InitiallyIncomplete { int x; };
+// CHECK-LABEL: define void @use_while_complete()
+void use_while_complete() {
+  // CHECK:      [[VAR:%.*]] = alloca ptr,
+  // CHECK-NEXT: store ptr ptrauth (ptr @returns_initially_incomplete, i32 0, i64 25106), ptr [[VAR]]
+  // CHECK-NEXT: ret void
+  struct InitiallyIncomplete (*fnptr)(void) = &returns_initially_incomplete;
+}
+
+#ifndef __cplusplus
+
+void knr(param)
+  int param;
+{}
+
+// CHECKC-LABEL: define void @test_knr
+void test_knr() {
+  void (*p)() = knr;
+  p(0);
+
+  // CHECKC: [[P:%.*]] = alloca ptr
+  // CHECKC: store ptr ptrauth (ptr @knr, i32 0, i64 18983), ptr [[P]]
+  // CHECKC: [[LOAD:%.*]] = load ptr, ptr [[P]]
+  // CHECKC: call void [[LOAD]](i32 noundef 0) [ "ptrauth"(i32 0, i64 18983) ]
+}
+
+// CHECKC-LABEL: define void @test_redeclaration
+void test_redeclaration() {
+  void redecl();
+  void (*ptr)() = redecl;
+  void redecl(int);
+  void (*ptr2)(int) = redecl;
+  ptr();
+  ptr2(0);
+
+  // CHECKC: store ptr ptrauth (ptr @redecl, i32 0, i64 18983), ptr %ptr
+  // CHECKC: store ptr ptrauth (ptr @redecl, i32 0, i64 2712), ptr %ptr2
+  // CHECKC: call void {{.*}}() [ "ptrauth"(i32 0, i64 18983) ]
+  // CHECKC: call void {{.*}}(i32 noundef 0) [ "ptrauth"(i32 0, i64 2712) ]
+}
+
+void knr2(param)
+     int param;
+{}
+
+// CHECKC-LABEL: define void @test_redecl_knr
+void test_redecl_knr() {
+  void (*p)() = knr2;
+  p();
+
+  // CHECKC: store ptr ptrauth (ptr @knr2, i32 0, i64 18983)
+  // CHECKC: call void {{.*}}() [ "ptrauth"(i32 0, i64 18983) ]
+
+  void knr2(int);
+
+  void (*p2)(int) = knr2;
+  p2(0);
+
+  // CHECKC: store ptr ptrauth (ptr @knr2, i32 0, i64 2712)
+  // CHECKC: call void {{.*}}(i32 noundef 0) [ "ptrauth"(i32 0, i64 2712) ]
+}
+
+#endif
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/clang/test/Preprocessor/ptrauth_feature.c b/clang/test/Preprocessor/ptrauth_feature.c
index 80e239110ffc7..88b6982c01657 100644
--- a/clang/test/Preprocessor/ptrauth_feature.c
+++ b/clang/test/Preprocessor/ptrauth_feature.c
@@ -5,7 +5,7 @@
 // RUN:   -fptrauth-vtable-pointer-address-discrimination \
 // RUN:   -fptrauth-vtable-pointer-type-discrimination \
 // RUN:   -fptrauth-init-fini | \
-// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI
+// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI,NOFUNC
 
 // RUN: %clang_cc1 -E %s -triple=aarch64 \
 // RUN:   -fptrauth-calls \
@@ -13,7 +13,7 @@
 // RUN:   -fptrauth-vtable-pointer-address-discrimination \
 // RUN:   -fptrauth-vtable-pointer-type-discrimination \
 // RUN:   -fptrauth-init-fini | \
-// RUN:   FileCheck %s --check-prefixes=NOINTRIN,CALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI
+// RUN:   FileCheck %s --check-prefixes=NOINTRIN,CALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI,NOFUNC
 
 // RUN: %clang_cc1 -E %s -triple=aarch64 \
 // RUN:   -fptrauth-intrinsics \
@@ -21,7 +21,7 @@
 // RUN:   -fptrauth-vtable-pointer-address-discrimination \
 // RUN:   -fptrauth-vtable-pointer-type-discrimination \
 // RUN:   -fptrauth-init-fini | \
-// RUN:   FileCheck %s --check-prefixes=INTRIN,NOCALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI
+// RUN:   FileCheck %s --check-prefixes=INTRIN,NOCALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI,NOFUNC
 
 // RUN: %clang_cc1 -E %s -triple=aarch64 \
 // RUN:   -fptrauth-intrinsics \
@@ -29,7 +29,7 @@
 // RUN:   -fptrauth-vtable-pointer-address-discrimination \
 // RUN:   -fptrauth-vtable-pointer-type-discrimination \
 // RUN:   -fptrauth-init-fini | \
-// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,NORETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI
+// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,NORETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI,NOFUNC
 
 // RUN: %clang_cc1 -E %s -triple=aarch64 \
 // RUN:   -fptrauth-intrinsics \
@@ -37,7 +37,7 @@
 // RUN:   -fptrauth-returns \
 // RUN:   -fptrauth-vtable-pointer-type-discrimination \
 // RUN:   -fptrauth-init-fini | \
-// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,NOVPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI
+// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,NOVPTR_ADDR_DISCR,VPTR_TYPE_DISCR,INITFINI,NOFUNC
 
 // RUN: %clang_cc1 -E %s -triple=aarch64 \
 // RUN:   -fptrauth-intrinsics \
@@ -45,7 +45,7 @@
 // RUN:   -fptrauth-returns \
 // RUN:   -fptrauth-vtable-pointer-address-discrimination \
 // RUN:   -fptrauth-init-fini | \
-// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,VPTR_ADDR_DISCR,NOVPTR_TYPE_DISCR,INITFINI
+// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,VPTR_ADDR_DISCR,NOVPTR_TYPE_DISCR,INITFINI,NOFUNC
 
 // RUN: %clang_cc1 -E %s -triple=aarch64 \
 // RUN:   -fptrauth-intrinsics \
@@ -53,7 +53,17 @@
 // RUN:   -fptrauth-returns \
 // RUN:   -fptrauth-vtable-pointer-address-discrimination \
 // RUN:   -fptrauth-vtable-pointer-type-discrimination | \
-// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,NOINITFINI
+// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,NOINITFINI,NOFUNC
+
+// RUN: %clang_cc1 -E %s -triple=aarch64 \
+// RUN:   -fptrauth-intrinsics \
+// RUN:   -fptrauth-calls \
+// RUN:   -fptrauth-returns \
+// RUN:   -fptrauth-vtable-pointer-address-discrimination \
+// RUN:   -fptrauth-vtable-pointer-type-discrimination \
+// RUN:   -fptrauth-function-pointer-type-discrimination | \
+// RUN:   FileCheck %s --check-prefixes=INTRIN,CALLS,RETS,VPTR_ADDR_DISCR,VPTR_TYPE_DISCR,NOINITFINI,FUNC
+
 
 #if __has_feature(ptrauth_intrinsics)
 // INTRIN: has_ptrauth_intrinsics
@@ -111,3 +121,13 @@ void has_ptrauth_init_fini() {}
 // NOINITFINI: no_ptrauth_init_fini
 void no_ptrauth_init_fini() {}
 #endif
+
+#include <ptrauth.h>
+
+#if __has_feature(ptrauth_function_pointer_type_discrimination)
+// FUNC: has_ptrauth_function_pointer_type_discrimination
+void has_ptrauth_function_pointer_type_discrimination() {}
+#else
+// NOFUNC: no_ptrauth_function_pointer_type_discrimination
+void no_ptrauth_function_pointer_type_discrimination() {}
+#endif

>From 1d5757a4ba1aad6cedaae239b62128c9f9ab0c43 Mon Sep 17 00:00:00 2001
From: Akira Hatanaka <ahatanak at gmail.com>
Date: Fri, 28 Jun 2024 08:02:36 -0700
Subject: [PATCH 2/4] Rename langopt

---
 clang/include/clang/Basic/Features.def    | 2 +-
 clang/include/clang/Basic/LangOptions.def | 2 +-
 clang/lib/Frontend/CompilerInvocation.cpp | 8 ++++----
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/clang/include/clang/Basic/Features.def b/clang/include/clang/Basic/Features.def
index 3c9d0364880c5..2f864ff1c0edf 100644
--- a/clang/include/clang/Basic/Features.def
+++ b/clang/include/clang/Basic/Features.def
@@ -110,7 +110,7 @@ FEATURE(ptrauth_vtable_pointer_address_discrimination, LangOpts.PointerAuthVTPtr
 FEATURE(ptrauth_vtable_pointer_type_discrimination, LangOpts.PointerAuthVTPtrTypeDiscrimination)
 FEATURE(ptrauth_member_function_pointer_type_discrimination, LangOpts.PointerAuthCalls)
 FEATURE(ptrauth_init_fini, LangOpts.PointerAuthInitFini)
-FEATURE(ptrauth_function_pointer_type_discrimination, LangOpts.FunctionPointerTypeDiscrimination)
+FEATURE(ptrauth_function_pointer_type_discrimination, LangOpts.PointerAuthFunctionTypeDiscrimination)
 EXTENSION(swiftcc,
   PP.getTargetInfo().checkCallingConvention(CC_Swift) ==
   clang::TargetInfo::CCCR_OK)
diff --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def
index 554f8643812b5..7521ab85a9b70 100644
--- a/clang/include/clang/Basic/LangOptions.def
+++ b/clang/include/clang/Basic/LangOptions.def
@@ -470,7 +470,7 @@ ENUM_LANGOPT(StrictFlexArraysLevel, StrictFlexArraysLevelKind, 2,
 
 COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0")
 
-BENIGN_LANGOPT(FunctionPointerTypeDiscrimination, 1, 0,
+BENIGN_LANGOPT(PointerAuthFunctionTypeDiscrimination, 1, 0,
                "Use type discrimination when signing function pointers")
 
 ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None,
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index 5301802875ab0..b441baa414900 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -1468,8 +1468,8 @@ void CompilerInvocation::setDefaultPointerAuthOptions(
     // If you change anything here, be sure to update <ptrauth.h>.
     Opts.FunctionPointers = PointerAuthSchema(
         Key::ASIA, false,
-        LangOpts.FunctionPointerTypeDiscrimination ? Discrimination::Type
-                                                   : Discrimination::None);
+        LangOpts.PointerAuthFunctionTypeDiscrimination ? Discrimination::Type
+                                                       : Discrimination::None);
 
     Opts.CXXVTablePointers = PointerAuthSchema(
         Key::ASDA, LangOpts.PointerAuthVTPtrAddressDiscrimination,
@@ -3400,7 +3400,7 @@ static void GeneratePointerAuthArgs(const LangOptions &Opts,
     GenerateArg(Consumer, OPT_fptrauth_vtable_pointer_type_discrimination);
   if (Opts.PointerAuthInitFini)
     GenerateArg(Consumer, OPT_fptrauth_init_fini);
-  if (Opts.FunctionPointerTypeDiscrimination)
+  if (Opts.PointerAuthFunctionTypeDiscrimination)
     GenerateArg(Consumer, OPT_fptrauth_function_pointer_type_discrimination);
 }
 
@@ -4783,7 +4783,7 @@ bool CompilerInvocation::CreateFromArgsImpl(
   if (Res.getFrontendOpts().ProgramAction == frontend::RewriteObjC)
     LangOpts.ObjCExceptions = 1;
 
-  LangOpts.FunctionPointerTypeDiscrimination =
+  LangOpts.PointerAuthFunctionTypeDiscrimination =
       Args.hasArg(OPT_fptrauth_function_pointer_type_discrimination);
 
   for (auto Warning : Res.getDiagnosticOpts().Warnings) {

>From 4b1ad0b9fb2837a3b8327fa8b606efc3e6e6c70f Mon Sep 17 00:00:00 2001
From: Akira Hatanaka <ahatanak at gmail.com>
Date: Fri, 28 Jun 2024 11:05:17 -0700
Subject: [PATCH 3/4] Address review comments

---
 clang/include/clang/AST/ASTContext.h      |  2 +-
 clang/include/clang/Basic/LangOptions.def |  5 ++---
 clang/include/clang/Driver/Options.td     |  2 +-
 clang/lib/AST/ASTContext.cpp              | 18 +++++++++---------
 clang/lib/CodeGen/CGExprConstant.cpp      |  4 ++--
 clang/lib/CodeGen/CGPointerAuth.cpp       |  1 -
 clang/lib/Frontend/CompilerInvocation.cpp |  5 ++---
 7 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index 8ba0c943a9c86..57022e75073fe 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1284,7 +1284,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
   getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD);
 
   /// Return the "other" type-specific discriminator for the given type.
-  uint16_t getPointerAuthTypeDiscriminator(QualType T);
+  uint16_t getPointerAuthTypeDiscriminator(QualType T) const;
 
   /// Apply Objective-C protocol qualifiers to the given type.
   /// \param allowOnPointerType specifies if we can apply protocol
diff --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def
index 7521ab85a9b70..3c9ab542fb8e2 100644
--- a/clang/include/clang/Basic/LangOptions.def
+++ b/clang/include/clang/Basic/LangOptions.def
@@ -168,6 +168,8 @@ LANGOPT(PointerAuthAuthTraps, 1, 0, "pointer authentication failure traps")
 LANGOPT(PointerAuthVTPtrAddressDiscrimination, 1, 0, "incorporate address discrimination in authenticated vtable pointers")
 LANGOPT(PointerAuthVTPtrTypeDiscrimination, 1, 0, "incorporate type discrimination in authenticated vtable pointers")
 LANGOPT(PointerAuthInitFini, 1, 0, "sign function pointers in init/fini arrays")
+BENIGN_LANGOPT(PointerAuthFunctionTypeDiscrimination, 1, 0,
+               "Use type discrimination when signing function pointers")
 
 LANGOPT(DoubleSquareBracketAttributes, 1, 0, "'[[]]' attributes extension for all language standard modes")
 LANGOPT(ExperimentalLateParseAttributes, 1, 0, "experimental late parsing of attributes")
@@ -470,9 +472,6 @@ ENUM_LANGOPT(StrictFlexArraysLevel, StrictFlexArraysLevelKind, 2,
 
 COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0")
 
-BENIGN_LANGOPT(PointerAuthFunctionTypeDiscrimination, 1, 0,
-               "Use type discrimination when signing function pointers")
-
 ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None,
              "Scope of return address signing")
 ENUM_LANGOPT(SignReturnAddressKey, SignReturnAddressKeyKind, 1, SignReturnAddressKeyKind::AKey,
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index a6f960e42d23f..c586ee09d7372 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4229,7 +4229,7 @@ defm ptrauth_vtable_pointer_type_discrimination :
   OptInCC1FFlag<"ptrauth-vtable-pointer-type-discrimination", "Enable type discrimination of vtable pointers">;
 defm ptrauth_init_fini : OptInCC1FFlag<"ptrauth-init-fini", "Enable signing of function pointers in init/fini arrays">;
 defm ptrauth_function_pointer_type_discrimination : OptInCC1FFlag<"ptrauth-function-pointer-type-discrimination",
-  "Enabling type discrimination on C function pointers">;
+  "Enable type discrimination on C function pointers">;
 }
 
 def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 74e0ae0a58e9f..42604658095ad 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -3155,8 +3155,8 @@ ASTContext::getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD) {
 /// other. Therefore this encoding function must be careful to only distinguish
 /// types if there is no third type with which they are both required to be
 /// compatible.
-static void encodeTypeForFunctionPointerAuth(ASTContext &Ctx, raw_ostream &OS,
-                                             QualType QT) {
+static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
+                                             raw_ostream &OS, QualType QT) {
   // FIXME: Consider address space qualifiers.
   const Type *T = QT.getCanonicalType().getTypePtr();
 
@@ -3242,9 +3242,9 @@ static void encodeTypeForFunctionPointerAuth(ASTContext &Ctx, raw_ostream &OS,
     // FIXME: If we add a "custom discriminator" function type attribute we
     // should encode functions as their discriminators.
     OS << "F";
-    auto *FuncType = cast<FunctionType>(T);
+    const auto *FuncType = cast<FunctionType>(T);
     encodeTypeForFunctionPointerAuth(Ctx, OS, FuncType->getReturnType());
-    if (auto *FPT = dyn_cast<FunctionProtoType>(FuncType)) {
+    if (const auto *FPT = dyn_cast<FunctionProtoType>(FuncType)) {
       for (QualType Param : FPT->param_types()) {
         Param = Ctx.getSignatureParameterType(Param);
         encodeTypeForFunctionPointerAuth(Ctx, OS, Param);
@@ -3258,7 +3258,7 @@ static void encodeTypeForFunctionPointerAuth(ASTContext &Ctx, raw_ostream &OS,
 
   case Type::MemberPointer: {
     OS << "M";
-    auto *MPT = T->getAs<MemberPointerType>();
+    const auto *MPT = T->getAs<MemberPointerType>();
     encodeTypeForFunctionPointerAuth(Ctx, OS, QualType(MPT->getClass(), 0));
     encodeTypeForFunctionPointerAuth(Ctx, OS, MPT->getPointeeType());
     return;
@@ -3276,7 +3276,7 @@ static void encodeTypeForFunctionPointerAuth(ASTContext &Ctx, raw_ostream &OS,
     return;
 
   case Type::Builtin: {
-    const BuiltinType *BTy = T->getAs<BuiltinType>();
+    const auto *BTy = T->getAs<BuiltinType>();
     switch (BTy->getKind()) {
 #define SIGNED_TYPE(Id, SingletonId)                                           \
   case BuiltinType::Id:                                                        \
@@ -3358,8 +3358,8 @@ static void encodeTypeForFunctionPointerAuth(ASTContext &Ctx, raw_ostream &OS,
     }
   }
   case Type::Record: {
-    RecordDecl *RD = T->getAs<RecordType>()->getDecl();
-    IdentifierInfo *II = RD->getIdentifier();
+    const RecordDecl *RD = T->getAs<RecordType>()->getDecl();
+    const IdentifierInfo *II = RD->getIdentifier();
     if (!II)
       if (const TypedefNameDecl *Typedef = RD->getTypedefNameForAnonDecl())
         II = Typedef->getDeclName().getAsIdentifierInfo();
@@ -3384,7 +3384,7 @@ static void encodeTypeForFunctionPointerAuth(ASTContext &Ctx, raw_ostream &OS,
   }
 }
 
-uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) {
+uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
   assert(!T->isDependentType() &&
          "cannot compute type discriminator of a dependent type");
 
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index 6d412af5cc994..f09e7304481a9 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -2220,8 +2220,8 @@ llvm::Constant *ConstantLValueEmitter::emitPointerAuthPointer(const Expr *E) {
 
   // The assertions here are all checked by Sema.
   assert(Result.Val.isLValue());
-  auto *Base = Result.Val.getLValueBase().get<const ValueDecl *>();
-  if (auto *Decl = dyn_cast_or_null<FunctionDecl>(Base)) {
+  const auto *Base = Result.Val.getLValueBase().get<const ValueDecl *>();
+  if (const auto *Decl = dyn_cast_or_null<FunctionDecl>(Base)) {
     assert(Result.Val.getLValueOffset().isZero());
     return CGM.getRawFunctionPointer(Decl);
   }
diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp
index 33b421437c74c..621d567dde721 100644
--- a/clang/lib/CodeGen/CGPointerAuth.cpp
+++ b/clang/lib/CodeGen/CGPointerAuth.cpp
@@ -15,7 +15,6 @@
 #include "CodeGenModule.h"
 #include "clang/CodeGen/CodeGenABITypes.h"
 #include "clang/CodeGen/ConstantInitBuilder.h"
-#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Support/SipHash.h"
 
 using namespace clang;
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index b441baa414900..ef2b2aae91a18 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -3415,6 +3415,8 @@ static void ParsePointerAuthArgs(LangOptions &Opts, ArgList &Args,
   Opts.PointerAuthVTPtrTypeDiscrimination =
       Args.hasArg(OPT_fptrauth_vtable_pointer_type_discrimination);
   Opts.PointerAuthInitFini = Args.hasArg(OPT_fptrauth_init_fini);
+  Opts.PointerAuthFunctionTypeDiscrimination =
+      Args.hasArg(OPT_fptrauth_function_pointer_type_discrimination);
 }
 
 /// Check if input file kind and language standard are compatible.
@@ -4783,9 +4785,6 @@ bool CompilerInvocation::CreateFromArgsImpl(
   if (Res.getFrontendOpts().ProgramAction == frontend::RewriteObjC)
     LangOpts.ObjCExceptions = 1;
 
-  LangOpts.PointerAuthFunctionTypeDiscrimination =
-      Args.hasArg(OPT_fptrauth_function_pointer_type_discrimination);
-
   for (auto Warning : Res.getDiagnosticOpts().Warnings) {
     if (Warning == "misexpect" &&
         !Diags.isIgnored(diag::warn_profile_data_misexpect, SourceLocation())) {

>From 06e4a149f7c5c815ef629e33e67802739de1d008 Mon Sep 17 00:00:00 2001
From: Akira Hatanaka <ahatanak at gmail.com>
Date: Mon, 8 Jul 2024 14:43:57 -0700
Subject: [PATCH 4/4] Address review comments

---
 clang/lib/AST/ASTContext.cpp         |  4 ++--
 clang/lib/CodeGen/CGExprConstant.cpp | 35 +++++++++++++++++-----------
 clang/lib/CodeGen/ConstantEmitter.h  |  9 ++++---
 3 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 42604658095ad..a5cc00cccc2f4 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -3224,8 +3224,8 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
     //   type, or an unsigned integer type.
     //
     // So we have to treat enum types as integers.
-    OS << "i";
-    return;
+    return encodeTypeForFunctionPointerAuth(
+        Ctx, OS, cast<EnumType>(T)->getDecl()->getIntegerType());
 
   case Type::FunctionNoProto:
   case Type::FunctionProto: {
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index f09e7304481a9..00a5a7e6898a8 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1542,9 +1542,11 @@ ConstantEmitter::emitAbstract(const Expr *E, QualType destType) {
 
 llvm::Constant *
 ConstantEmitter::emitAbstract(SourceLocation loc, const APValue &value,
-                              QualType destType) {
+                              QualType destType,
+                              bool EnablePtrAuthFunctionTypeDiscrimination) {
   auto state = pushAbstract();
-  auto C = tryEmitPrivate(value, destType);
+  auto C =
+      tryEmitPrivate(value, destType, EnablePtrAuthFunctionTypeDiscrimination);
   C = validateAndPopAbstract(C, state);
   if (!C) {
     CGM.Error(loc,
@@ -1938,14 +1940,18 @@ class ConstantLValueEmitter : public ConstStmtVisitor<ConstantLValueEmitter,
   ConstantEmitter &Emitter;
   const APValue &Value;
   QualType DestType;
+  bool EnablePtrAuthFunctionTypeDiscrimination;
 
   // Befriend StmtVisitorBase so that we don't have to expose Visit*.
   friend StmtVisitorBase;
 
 public:
   ConstantLValueEmitter(ConstantEmitter &emitter, const APValue &value,
-                        QualType destType)
-    : CGM(emitter.CGM), Emitter(emitter), Value(value), DestType(destType) {}
+                        QualType destType,
+                        bool EnablePtrAuthFunctionTypeDiscrimination = true)
+      : CGM(emitter.CGM), Emitter(emitter), Value(value), DestType(destType),
+        EnablePtrAuthFunctionTypeDiscrimination(
+            EnablePtrAuthFunctionTypeDiscrimination) {}
 
   llvm::Constant *tryEmit();
 
@@ -2069,7 +2075,10 @@ ConstantLValueEmitter::tryEmitBase(const APValue::LValueBase &base) {
       return CGM.GetWeakRefReference(D).getPointer();
 
     auto PtrAuthSign = [&](llvm::Constant *C) {
-      CGPointerAuthInfo AuthInfo = CGM.getFunctionPointerAuthInfo(DestType);
+      CGPointerAuthInfo AuthInfo;
+
+      if (EnablePtrAuthFunctionTypeDiscrimination)
+        AuthInfo = CGM.getFunctionPointerAuthInfo(DestType);
 
       if (AuthInfo) {
         if (hasNonZeroOffset())
@@ -2220,13 +2229,10 @@ llvm::Constant *ConstantLValueEmitter::emitPointerAuthPointer(const Expr *E) {
 
   // The assertions here are all checked by Sema.
   assert(Result.Val.isLValue());
-  const auto *Base = Result.Val.getLValueBase().get<const ValueDecl *>();
-  if (const auto *Decl = dyn_cast_or_null<FunctionDecl>(Base)) {
+  if (isa<FunctionDecl>(Result.Val.getLValueBase().get<const ValueDecl *>()))
     assert(Result.Val.getLValueOffset().isZero());
-    return CGM.getRawFunctionPointer(Decl);
-  }
   return ConstantEmitter(CGM, Emitter.CGF)
-      .emitAbstract(E->getExprLoc(), Result.Val, E->getType());
+      .emitAbstract(E->getExprLoc(), Result.Val, E->getType(), false);
 }
 
 unsigned ConstantLValueEmitter::emitPointerAuthKey(const Expr *E) {
@@ -2283,15 +2289,18 @@ ConstantLValueEmitter::VisitMaterializeTemporaryExpr(
   return CGM.GetAddrOfGlobalTemporary(E, Inner);
 }
 
-llvm::Constant *ConstantEmitter::tryEmitPrivate(const APValue &Value,
-                                                QualType DestType) {
+llvm::Constant *
+ConstantEmitter::tryEmitPrivate(const APValue &Value, QualType DestType,
+                                bool EnablePtrAuthFunctionTypeDiscrimination) {
   switch (Value.getKind()) {
   case APValue::None:
   case APValue::Indeterminate:
     // Out-of-lifetime and indeterminate values can be modeled as 'undef'.
     return llvm::UndefValue::get(CGM.getTypes().ConvertType(DestType));
   case APValue::LValue:
-    return ConstantLValueEmitter(*this, Value, DestType).tryEmit();
+    return ConstantLValueEmitter(*this, Value, DestType,
+                                 EnablePtrAuthFunctionTypeDiscrimination)
+        .tryEmit();
   case APValue::Int:
     return llvm::ConstantInt::get(CGM.getLLVMContext(), Value.getInt());
   case APValue::FixedPoint:
diff --git a/clang/lib/CodeGen/ConstantEmitter.h b/clang/lib/CodeGen/ConstantEmitter.h
index eff0a8d52891f..581b05ae87add 100644
--- a/clang/lib/CodeGen/ConstantEmitter.h
+++ b/clang/lib/CodeGen/ConstantEmitter.h
@@ -103,8 +103,9 @@ class ConstantEmitter {
   /// expression is known to be a constant expression with either a fairly
   /// simple type or a known simple form.
   llvm::Constant *emitAbstract(const Expr *E, QualType T);
-  llvm::Constant *emitAbstract(SourceLocation loc, const APValue &value,
-                               QualType T);
+  llvm::Constant *
+  emitAbstract(SourceLocation loc, const APValue &value, QualType T,
+               bool EnablePtrAuthFunctionTypeDiscrimination = true);
 
   /// Try to emit the result of the given expression as an abstract constant.
   llvm::Constant *tryEmitAbstract(const Expr *E, QualType T);
@@ -138,7 +139,9 @@ class ConstantEmitter {
   llvm::Constant *tryEmitPrivate(const Expr *E, QualType T);
   llvm::Constant *tryEmitPrivateForMemory(const Expr *E, QualType T);
 
-  llvm::Constant *tryEmitPrivate(const APValue &value, QualType T);
+  llvm::Constant *
+  tryEmitPrivate(const APValue &value, QualType T,
+                 bool EnablePtrAuthFunctionTypeDiscrimination = true);
   llvm::Constant *tryEmitPrivateForMemory(const APValue &value, QualType T);
 
   /// Get the address of the current location.  This is a constant



More information about the cfe-commits mailing list