[llvm] [PAC][IR][AArch64] Add "ptrauth(...)" Constant to represent signed pointers. (PR #85738)

Ahmed Bougacha via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 14:54:46 PDT 2024


https://github.com/ahmedbougacha updated https://github.com/llvm/llvm-project/pull/85738

From c7779d04d14fb4756679d5d539a627a5e388e04f Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Mon, 18 Mar 2024 23:03:25 -0700
Subject: [PATCH] [IR][AArch64] Add "ptrauth(...)" Constant to represent signed
 pointers.

This defines a new kind of IR Constant that represents a ptrauth signed
pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants used thus
far in the LLVM ptrauth implementations, notably those used in the
Darwin and ELF ABIs being implemented for C/C++.  These signed pointer
constants are then lowered to ELF/MachO relocations.

These can simply be thought of as a constant `llvm.ptrauth.sign`, with
the interesting addition of discriminator computation: the `ptrauth`
constant can also represent a combined blend, when both address and
integer discriminator operands are used.  Both operands are otherwise
optional, with default values 0/null.

Co-Authored-by: Tim Northover <tnorthover at apple.com>
---
 llvm/docs/LangRef.rst                         |  34 +++++
 llvm/docs/PointerAuth.md                      |  22 ++++
 llvm/include/llvm/AsmParser/LLToken.h         |   1 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |   1 +
 llvm/include/llvm/IR/Constants.h              |  66 ++++++++++
 llvm/include/llvm/IR/Value.def                |   1 +
 llvm/lib/Analysis/ValueTracking.cpp           |   4 +
 llvm/lib/AsmParser/LLLexer.cpp                |   1 +
 llvm/lib/AsmParser/LLParser.cpp               |  54 ++++++++
 llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp   |   1 +
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     |  25 +++-
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |   6 +
 llvm/lib/IR/AsmWriter.cpp                     |  21 +++
 llvm/lib/IR/Constants.cpp                     | 121 ++++++++++++++++++
 llvm/lib/IR/ConstantsContext.h                |  47 +++++++
 llvm/lib/IR/LLVMContextImpl.h                 |   2 +
 llvm/lib/IR/Verifier.cpp                      |  23 ++++
 llvm/test/Assembler/invalid-ptrauth-const1.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const2.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const3.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const4.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const5.ll |   6 +
 llvm/test/Assembler/ptrauth-const.ll          |  24 ++++
 llvm/test/Bitcode/compatibility.ll            |   4 +
 llvm/utils/vim/syntax/llvm.vim                |   1 +
 25 files changed, 488 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const1.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const2.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const3.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const4.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const5.ll
 create mode 100644 llvm/test/Assembler/ptrauth-const.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 614dd98b013b3..7b64c477d13c7 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4754,6 +4754,40 @@ reference to the CFI jump table in the ``LowerTypeTests`` pass. These constants
 may be useful in low-level programs, such as operating system kernels, which
 need to refer to the actual function body.
 
+.. _ptrauth_constant:
+
+Pointer Authentication Constants
+--------------------------------
+
+``ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)``
+
+A '``ptrauth``' constant represents a pointer with a cryptographic
+authentication signature embedded into some bits, as described in the
+`Pointer Authentication <PointerAuth.html>`__ document.
+
+A '``ptrauth``' constant is the constant equivalent of the
+``llvm.ptrauth.sign`` intrinsic, with its discriminator computed by
+``llvm.ptrauth.blend`` when an address discriminator is present.
+
+Its type is the same as that of the first operand.  An integer constant
+discriminator and an address discriminator may optionally be specified; if
+omitted, they default to ``i64 0`` and ``ptr null``, respectively.
+
+If the address discriminator is ``null``, then the expression is equivalent to:
+
+.. code-block:: llvm
+
+    %tmp = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 DISC)
+    %val = inttoptr i64 %tmp to ptr
+
+Otherwise, the expression is equivalent to:
+
+.. code-block:: llvm
+
+    %tmp1 = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr ADDRDISC to i64), i64 DISC)
+    %tmp2 = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 %tmp1)
+    %val = inttoptr i64 %tmp2 to ptr
+
 .. _constantexprs:
 
 Constant Expressions
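The two LangRef expansions above differ only in how the discriminator passed to `llvm.ptrauth.sign` is formed. As a hedged sketch, the C++ below models that choice on raw 64-bit values. It assumes the AArch64 lowering of `llvm.ptrauth.blend`, which overwrites the top 16 bits of the address discriminator with the low 16 bits of the integer discriminator (a MOVK or bitfield insert); other targets may blend differently. `ptr null` is modeled as address 0, and both function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of the AArch64 blend (an assumption, not LLVM API):
// keep the low 48 bits of the address discriminator and place the low
// 16 bits of the integer discriminator in bits [63:48].
uint64_t blendDiscriminator(uint64_t AddrDisc, uint64_t IntDisc) {
  return (AddrDisc & 0x0000FFFFFFFFFFFFULL) | ((IntDisc & 0xFFFFULL) << 48);
}

// Discriminator fed to llvm.ptrauth.sign, following the LangRef text:
// a null address discriminator (modeled as 0) means the integer
// discriminator is used as-is; otherwise the two are blended.
uint64_t effectiveDiscriminator(uint64_t IntDisc, uint64_t AddrDisc) {
  if (AddrDisc == 0)
    return IntDisc;
  return blendDiscriminator(AddrDisc, IntDisc);
}
```

For instance, `ptrauth (ptr @var, i32 0, i64 -1)` signs with the full 64-bit value `-1`, because no blend is involved when the address discriminator is omitted.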
diff --git a/llvm/docs/PointerAuth.md b/llvm/docs/PointerAuth.md
index a8d2b4d8f5f0b..cf2cc6305f130 100644
--- a/llvm/docs/PointerAuth.md
+++ b/llvm/docs/PointerAuth.md
@@ -16,6 +16,7 @@ For more details, see the clang documentation page for
 At the IR level, it is represented using:
 
 * a [set of intrinsics](#intrinsics) (to sign/authenticate pointers)
+* a [signed pointer constant](#constant) (to sign globals)
 * a [call operand bundle](#operand-bundle) (to authenticate called pointers)
 
 The current implementation leverages the
@@ -225,6 +226,27 @@ with a pointer address discriminator, in a way that is specified by the target
 implementation.
 
 
+### Constant
+
+[Intrinsics](#intrinsics) can be used to produce signed pointers dynamically,
+in code, but not signed pointers referenced by constants, such as those in
+global initializers.
+
+The latter are represented using a
+[``ptrauth`` constant](https://llvm.org/docs/LangRef.html#ptrauth-constant),
+which describes an authenticated relocation producing a signed pointer.
+
+```llvm
+ptrauth (ptr CST, i32 KEY, i64 DISC, ptr ADDRDISC)
+```
+
+is equivalent to:
+
+```llvm
+  %disc = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr ADDRDISC to i64), i64 DISC)
+  %signedval = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 %disc)
+```
+
 ### Operand Bundle
 
 Function pointers used as indirect call targets can be signed when materialized,
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index df61ec6ed30e0..69821c22dcd61 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -346,6 +346,7 @@ enum Kind {
   kw_blockaddress,
   kw_dso_local_equivalent,
   kw_no_cfi,
+  kw_ptrauth,
 
   kw_freeze,
 
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index d3b9e96520f88..9999aee61528e 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -413,6 +413,7 @@ enum ConstantsCodes {
                                       //                 asmstr,conststr]
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
   CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
+  CST_CODE_PTRAUTH = 33,              // [ptr, key, disc, addrdisc]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index a1e5005a9d1da..86f6be7985a23 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1008,6 +1008,72 @@ struct OperandTraits<NoCFIValue> : public FixedNumOperandTraits<NoCFIValue, 1> {
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(NoCFIValue, Value)
 
+/// A signed pointer, in the ptrauth sense.
+class ConstantPtrAuth final : public Constant {
+  friend struct ConstantPtrAuthKeyType;
+  friend class Constant;
+
+  ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc,
+                  Constant *AddrDisc);
+
+  void *operator new(size_t s) { return User::operator new(s, 4); }
+
+  void destroyConstantImpl();
+  Value *handleOperandChangeImpl(Value *From, Value *To);
+
+public:
+  /// Return a pointer signed with the specified parameters.
+  static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
+                              ConstantInt *Disc, Constant *AddrDisc);
+
+  /// Produce a new ptrauth expression signing the given value using
+  /// the same schema (key and discriminators) as this one.
+  ConstantPtrAuth *getWithSameSchema(Constant *Pointer) const;
+
+  /// Transparently provide more efficient getOperand methods.
+  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Constant);
+
+  /// The pointer that is signed in this ptrauth signed pointer.
+  Constant *getPointer() const { return cast<Constant>(Op<0>().get()); }
+
+  /// The Key ID, an i32 constant.
+  ConstantInt *getKey() const { return cast<ConstantInt>(Op<1>().get()); }
+
+  /// The integer discriminator, an i64 constant, or 0.
+  ConstantInt *getDiscriminator() const {
+    return cast<ConstantInt>(Op<2>().get());
+  }
+
+  /// The address discriminator if any, or the null constant.
+  /// If present, this must be a value equivalent to the storage location of
+  /// the only global-initializer user of the ptrauth signed pointer.
+  Constant *getAddrDiscriminator() const {
+    return cast<Constant>(Op<3>().get());
+  }
+
+  /// Whether there is any non-null address discriminator.
+  bool hasAddressDiscriminator() const {
+    return !getAddrDiscriminator()->isNullValue();
+  }
+
+  /// Check whether an authentication operation with key \p Key and (possibly
+  /// blended) discriminator \p Discriminator is known to be compatible with
+  /// this ptrauth signed pointer.
+  bool isKnownCompatibleWith(const Value *Key, const Value *Discriminator,
+                             const DataLayout &DL) const;
+
+  /// Methods to support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const Value *V) {
+    return V->getValueID() == ConstantPtrAuthVal;
+  }
+};
+
+template <>
+struct OperandTraits<ConstantPtrAuth>
+    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+
+DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
+
 //===----------------------------------------------------------------------===//
 /// A constant value that is initialized with an expression using
 /// other constant values.
diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def
index 61f7a87666d09..3ece66a529e12 100644
--- a/llvm/include/llvm/IR/Value.def
+++ b/llvm/include/llvm/IR/Value.def
@@ -81,6 +81,7 @@ HANDLE_CONSTANT(BlockAddress)
 HANDLE_CONSTANT(ConstantExpr)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(DSOLocalEquivalent)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(NoCFIValue)
+HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(ConstantPtrAuth)
 
 // ConstantAggregate.
 HANDLE_CONSTANT(ConstantArray)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 3baa8ede28ffa..08138a5e2f2d9 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3140,6 +3140,10 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
       return true;
     }
 
+    // Constant ptrauth can be null, iff the base pointer can be.
+    if (auto *CPA = dyn_cast<ConstantPtrAuth>(V))
+      return isKnownNonZero(CPA->getPointer(), DemandedElts, Q, Depth);
+
     // A global variable in address space 0 is non null unless extern weak
     // or an absolute symbol reference. Other address spaces may have null as a
     // valid address for a global, so we can't assume anything.
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 20a1bd2957712..d3ab306904da1 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -710,6 +710,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(blockaddress);
   KEYWORD(dso_local_equivalent);
   KEYWORD(no_cfi);
+  KEYWORD(ptrauth);
 
   // Metadata types.
   KEYWORD(distinct);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 5d2056d208567..df0827996396e 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4046,6 +4046,60 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     ID.NoCFI = true;
     return false;
   }
+  case lltok::kw_ptrauth: {
+    // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
+    //                         (',' i64 <disc> (',' ptr addrdisc)? )? ')'
+    Lex.Lex();
+
+    Constant *Ptr, *Key;
+    Constant *Disc = nullptr, *AddrDisc = nullptr;
+
+    if (parseToken(lltok::lparen,
+                   "expected '(' in constant ptrauth expression") ||
+        parseGlobalTypeAndValue(Ptr) ||
+        parseToken(lltok::comma,
+                   "expected comma in constant ptrauth expression") ||
+        parseGlobalTypeAndValue(Key))
+      return true;
+    // If present, parse the optional disc/addrdisc.
+    if (EatIfPresent(lltok::comma))
+      if (parseGlobalTypeAndValue(Disc) ||
+          (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)))
+        return true;
+    if (parseToken(lltok::rparen,
+                   "expected ')' in constant ptrauth expression"))
+      return true;
+
+    if (!Ptr->getType()->isPointerTy())
+      return error(ID.Loc, "constant ptrauth base pointer must be a pointer");
+
+    auto *KeyC = dyn_cast<ConstantInt>(Key);
+    if (!KeyC || KeyC->getBitWidth() != 32)
+      return error(ID.Loc, "constant ptrauth key must be i32 constant");
+
+    ConstantInt *DiscC = nullptr;
+    if (Disc) {
+      DiscC = dyn_cast<ConstantInt>(Disc);
+      if (!DiscC || DiscC->getBitWidth() != 64)
+        return error(
+            ID.Loc,
+            "constant ptrauth integer discriminator must be i64 constant");
+    } else {
+      DiscC = ConstantInt::get(Type::getInt64Ty(Context), 0);
+    }
+
+    if (AddrDisc) {
+      if (!AddrDisc->getType()->isPointerTy())
+        return error(
+            ID.Loc, "constant ptrauth address discriminator must be a pointer");
+    } else {
+      AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
+    }
+
+    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc);
+    ID.Kind = ValID::t_Constant;
+    return false;
+  }
 
   case lltok::kw_trunc:
   case lltok::kw_bitcast:
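The `kw_ptrauth` case above parses every operand first and only then validates them, filling in the omitted trailing operands with defaults. The following is a simplified, self-contained C++ model of that validation and defaulting order. `Operand`, `OperandKind`, and `buildPtrAuth` are hypothetical and not part of LLParser, and the operand-count message is a made-up stand-in for the parser's token-level errors.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical operand model (not an LLVM type): each parsed "<type> <value>"
// pair is an integer constant, a pointer, or some other constant expression
// such as a ptrtoint.
enum class OperandKind { ConstantInt, Pointer, OtherConstant };

struct Operand {
  OperandKind Kind;
  unsigned BitWidth; // integer width in bits; unused for pointers
  int64_t Value;     // integer value; unused otherwise
};

struct PtrAuthOperands {
  Operand Ptr, Key, Disc, AddrDisc;
};

// Checks operands in the same order as the kw_ptrauth case and fills in the
// defaults (i64 0, ptr null). Returns "" on success, else the diagnostic.
std::string buildPtrAuth(const std::vector<Operand> &Ops, PtrAuthOperands &Out) {
  // Made-up message: the real parser reports missing ',' or ')' tokens.
  if (Ops.size() < 2 || Ops.size() > 4)
    return "expected 2 to 4 operands";
  if (Ops[0].Kind != OperandKind::Pointer)
    return "constant ptrauth base pointer must be a pointer";
  if (Ops[1].Kind != OperandKind::ConstantInt || Ops[1].BitWidth != 32)
    return "constant ptrauth key must be i32 constant";

  Out.Ptr = Ops[0];
  Out.Key = Ops[1];
  Out.Disc = {OperandKind::ConstantInt, 64, 0}; // default: i64 0
  Out.AddrDisc = {OperandKind::Pointer, 0, 0};  // default: ptr null

  if (Ops.size() >= 3) {
    if (Ops[2].Kind != OperandKind::ConstantInt || Ops[2].BitWidth != 64)
      return "constant ptrauth integer discriminator must be i64 constant";
    Out.Disc = Ops[2];
  }
  if (Ops.size() == 4) {
    if (Ops[3].Kind != OperandKind::Pointer)
      return "constant ptrauth address discriminator must be a pointer";
    Out.AddrDisc = Ops[3];
  }
  return "";
}
```

Under this model, the inputs of `invalid-ptrauth-const1.ll` through `invalid-ptrauth-const5.ll` produce the same diagnostics as the `CHECK` lines in those tests.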
diff --git a/llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp b/llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp
index c085c715179ba..b7ed9cdf63145 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp
@@ -222,6 +222,7 @@ GetCodeName(unsigned CodeID, unsigned BlockID,
       STRINGIFY_CODE(CST_CODE, CE_UNOP)
       STRINGIFY_CODE(CST_CODE, DSO_LOCAL_EQUIVALENT)
       STRINGIFY_CODE(CST_CODE, NO_CFI_VALUE)
+      STRINGIFY_CODE(CST_CODE, PTRAUTH)
     case bitc::CST_CODE_BLOCKADDRESS:
       return "CST_CODE_BLOCKADDRESS";
       STRINGIFY_CODE(CST_CODE, DATA)
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 32b9a033173e9..aee627bbde0bf 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -517,7 +517,8 @@ class BitcodeConstant final : public Value,
   static constexpr uint8_t NoCFIOpcode = 252;
   static constexpr uint8_t DSOLocalEquivalentOpcode = 251;
   static constexpr uint8_t BlockAddressOpcode = 250;
-  static constexpr uint8_t FirstSpecialOpcode = BlockAddressOpcode;
+  static constexpr uint8_t ConstantPtrAuthOpcode = 249;
+  static constexpr uint8_t FirstSpecialOpcode = ConstantPtrAuthOpcode;
 
   // Separate struct to make passing different number of parameters to
   // BitcodeConstant::create() more convenient.
@@ -1562,6 +1563,18 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
         C = ConstantExpr::get(BC->Opcode, ConstOps[0], ConstOps[1], BC->Flags);
       } else {
         switch (BC->Opcode) {
+        case BitcodeConstant::ConstantPtrAuthOpcode: {
+          auto *Key = dyn_cast<ConstantInt>(ConstOps[1]);
+          if (!Key)
+            return error("ptrauth key operand must be ConstantInt");
+
+          auto *Disc = dyn_cast<ConstantInt>(ConstOps[2]);
+          if (!Disc)
+            return error("ptrauth disc operand must be ConstantInt");
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
+          break;
+        }
         case BitcodeConstant::NoCFIOpcode: {
           auto *GV = dyn_cast<GlobalValue>(ConstOps[0]);
           if (!GV)
@@ -3644,6 +3657,16 @@ Error BitcodeReader::parseConstants() {
                                   Record[1]);
       break;
     }
+    case bitc::CST_CODE_PTRAUTH: {
+      if (Record.size() < 4)
+        return error("Invalid ptrauth record");
+      // Ptr, Key, Disc, AddrDisc
+      V = BitcodeConstant::create(Alloc, CurTy,
+                                  BitcodeConstant::ConstantPtrAuthOpcode,
+                                  {(unsigned)Record[0], (unsigned)Record[1],
+                                   (unsigned)Record[2], (unsigned)Record[3]});
+      break;
+    }
     }
 
     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 3d653fe4458f4..046dad5721c4c 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2848,6 +2848,12 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Code = bitc::CST_CODE_NO_CFI_VALUE;
       Record.push_back(VE.getTypeID(NC->getGlobalValue()->getType()));
       Record.push_back(VE.getValueID(NC->getGlobalValue()));
+    } else if (const auto *CPA = dyn_cast<ConstantPtrAuth>(C)) {
+      Code = bitc::CST_CODE_PTRAUTH;
+      Record.push_back(VE.getValueID(CPA->getPointer()));
+      Record.push_back(VE.getValueID(CPA->getKey()));
+      Record.push_back(VE.getValueID(CPA->getDiscriminator()));
+      Record.push_back(VE.getValueID(CPA->getAddrDiscriminator()));
     } else {
 #ifndef NDEBUG
       C->dump();
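The writer and reader hunks above agree on a four-entry `CST_CODE_PTRAUTH` record of value IDs, `[ptr, key, disc, addrdisc]`; the result type is not stored in the record but taken from the reader's current constant type. A rough sketch of that contract, using hypothetical helpers and plain integers as stand-ins for `ValueEnumerator` IDs:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for IDs assigned by the ValueEnumerator.
using ValueID = unsigned;

struct PtrAuthIDs {
  ValueID Ptr, Key, Disc, AddrDisc;
};

// Writer side: four value IDs, in operand order, and nothing else.
std::vector<uint64_t> writePtrAuthRecord(const PtrAuthIDs &C) {
  return {C.Ptr, C.Key, C.Disc, C.AddrDisc};
}

// Reader side: short records are rejected, mirroring the
// "Invalid ptrauth record" error.
std::optional<PtrAuthIDs> readPtrAuthRecord(const std::vector<uint64_t> &Record) {
  if (Record.size() < 4)
    return std::nullopt;
  return PtrAuthIDs{static_cast<ValueID>(Record[0]), static_cast<ValueID>(Record[1]),
                    static_cast<ValueID>(Record[2]), static_cast<ValueID>(Record[3])};
}
```

Because the record carries only IDs, the real reader must still check during materialization that the key and discriminator resolve to `ConstantInt`s, which is what the `ConstantPtrAuthOpcode` case does.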
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index ced5d78f994ab..8b1a21f962b08 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1594,6 +1594,27 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
     return;
   }
 
+  if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
+    Out << "ptrauth (";
+
+    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)
+    unsigned NumOpsToWrite = 2;
+    if (!CPA->getOperand(2)->isNullValue())
+      NumOpsToWrite = 3;
+    if (!CPA->getOperand(3)->isNullValue())
+      NumOpsToWrite = 4;
+
+    ListSeparator LS;
+    for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) {
+      Out << LS;
+      WriterCtx.TypePrinter->print(CPA->getOperand(i)->getType(), Out);
+      Out << ' ';
+      WriteAsOperandInternal(Out, CPA->getOperand(i), WriterCtx);
+    }
+    Out << ')';
+    return;
+  }
+
   if (const ConstantArray *CA = dyn_cast<ConstantArray>(CV)) {
     Type *ETy = CA->getType()->getElementType();
     Out << '[';
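The `NumOpsToWrite` logic above prints trailing operands only when they differ from their defaults. Because operands are positional, a non-null address discriminator also forces the integer discriminator to be printed, even when it is zero. A minimal C++ model of that rule over pre-rendered operand strings (hypothetical helpers, not AsmWriter API):

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <string>

// Number of operands to print: pointer and key always, then the integer
// discriminator if non-zero, and all four if the address discriminator
// is non-null.
unsigned numPtrAuthOpsToWrite(int64_t Disc, bool AddrDiscIsNull) {
  unsigned N = 2;
  if (Disc != 0)
    N = 3;
  if (!AddrDiscIsNull)
    N = 4;
  return N;
}

// Joins the first N pre-rendered "<type> <value>" strings.
std::string printPtrAuth(const std::array<std::string, 4> &Ops, int64_t Disc,
                         bool AddrDiscIsNull) {
  std::string Out = "ptrauth (";
  unsigned N = numPtrAuthOpsToWrite(Disc, AddrDiscIsNull);
  for (unsigned I = 0; I != N; ++I) {
    if (I != 0)
      Out += ", ";
    Out += Ops[I];
  }
  return Out + ")";
}
```

With `i64 65535` and `ptr null`, this yields `ptrauth (ptr @g1, i32 0, i64 65535)`, matching the `CHECK` line added to `compatibility.ll`.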
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index cfb89d557db47..119fcb4fa0346 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -550,6 +550,9 @@ void llvm::deleteConstant(Constant *C) {
   case Constant::NoCFIValueVal:
     delete static_cast<NoCFIValue *>(C);
     break;
+  case Constant::ConstantPtrAuthVal:
+    delete static_cast<ConstantPtrAuth *>(C);
+    break;
   case Constant::UndefValueVal:
     delete static_cast<UndefValue *>(C);
     break;
@@ -2015,6 +2018,124 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
   return nullptr;
 }
 
+//---- ConstantPtrAuth::get() implementations.
+//
+
+ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
+                                      ConstantInt *Disc, Constant *AddrDisc) {
+  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc};
+  ConstantPtrAuthKeyType MapKey(ArgVec);
+  LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
+  return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
+}
+
+ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
+  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator());
+}
+
+ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
+                                 ConstantInt *Disc, Constant *AddrDisc)
+    : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, &Op<0>(), 4) {
+  assert(Ptr->getType()->isPointerTy());
+  assert(Key->getBitWidth() == 32);
+  assert(Disc->getBitWidth() == 64);
+  assert(AddrDisc->getType()->isPointerTy());
+  setOperand(0, Ptr);
+  setOperand(1, Key);
+  setOperand(2, Disc);
+  setOperand(3, AddrDisc);
+}
+
+/// Remove the constant from the constant table.
+void ConstantPtrAuth::destroyConstantImpl() {
+  getType()->getContext().pImpl->ConstantPtrAuths.remove(this);
+}
+
+Value *ConstantPtrAuth::handleOperandChangeImpl(Value *From, Value *ToV) {
+  assert(isa<Constant>(ToV) && "Cannot make Constant refer to non-constant!");
+  Constant *To = cast<Constant>(ToV);
+
+  SmallVector<Constant *, 4> Values;
+  Values.reserve(getNumOperands());
+
+  unsigned NumUpdated = 0;
+
+  Use *OperandList = getOperandList();
+  unsigned OperandNo = 0;
+  for (Use *O = OperandList, *E = OperandList + getNumOperands(); O != E; ++O) {
+    Constant *Val = cast<Constant>(O->get());
+    if (Val == From) {
+      OperandNo = (O - OperandList);
+      Val = To;
+      ++NumUpdated;
+    }
+    Values.push_back(Val);
+  }
+
+  return getContext().pImpl->ConstantPtrAuths.replaceOperandsInPlace(
+      Values, this, From, To, NumUpdated, OperandNo);
+}
+
+bool ConstantPtrAuth::isKnownCompatibleWith(const Value *Key,
+                                            const Value *Discriminator,
+                                            const DataLayout &DL) const {
+  // If the keys are different, there's no chance for this to be compatible.
+  if (getKey() != Key)
+    return false;
+
+  // We can have 3 kinds of discriminators:
+  // - simple, integer-only:    `i64 x, ptr null` vs. `i64 x`
+  // - address-only:            `i64 0, ptr p` vs. `ptr p`
+  // - blended address/integer: `i64 x, ptr p` vs. `@llvm.ptrauth.blend(p, x)`
+
+  // If this constant has a simple discriminator (integer, no address), easy:
+  // it's compatible iff the provided full discriminator is also a simple
+  // discriminator, identical to our integer discriminator.
+  if (!hasAddressDiscriminator())
+    return getDiscriminator() == Discriminator;
+
+  // Otherwise, we can isolate address and integer discriminator components.
+  const Value *AddrDiscriminator = nullptr;
+
+  // This constant may or may not have an integer discriminator (instead of 0).
+  if (!getDiscriminator()->isNullValue()) {
+    // If it does, there's an implicit blend.  We need to have a matching blend
+    // intrinsic in the provided full discriminator.
+    if (!match(Discriminator,
+               m_Intrinsic<Intrinsic::ptrauth_blend>(
+                   m_Value(AddrDiscriminator), m_Specific(getDiscriminator()))))
+      return false;
+  } else {
+    // Otherwise, interpret the provided full discriminator as address-only.
+    AddrDiscriminator = Discriminator;
+  }
+
+  // Either way, we can now focus on comparing the address discriminators.
+
+  // Discriminators are i64, so the provided addr disc may be a ptrtoint.
+  if (auto *Cast = dyn_cast<PtrToIntOperator>(AddrDiscriminator))
+    AddrDiscriminator = Cast->getPointerOperand();
+
+  // Beyond that, we're only interested in compatible pointers.
+  if (getAddrDiscriminator()->getType() != AddrDiscriminator->getType())
+    return false;
+
+  // These are often the same constant GEP, making them trivially equivalent.
+  if (getAddrDiscriminator() == AddrDiscriminator)
+    return true;
+
+  // Finally, they may be equivalent base+offset expressions.
+  APInt Off1(DL.getIndexTypeSizeInBits(getAddrDiscriminator()->getType()), 0);
+  auto *Base1 = getAddrDiscriminator()->stripAndAccumulateConstantOffsets(
+      DL, Off1, /*AllowNonInbounds=*/true);
+
+  APInt Off2(DL.getIndexTypeSizeInBits(AddrDiscriminator->getType()), 0);
+  auto *Base2 = AddrDiscriminator->stripAndAccumulateConstantOffsets(
+      DL, Off2, /*AllowNonInbounds=*/true);
+
+  return Base1 == Base2 && Off1 == Off2;
+}
+
 //---- ConstantExpr::get() implementations.
 //
 
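`ConstantPtrAuth::isKnownCompatibleWith` distinguishes three discriminator shapes before comparing addresses. The sketch below restates that decision procedure over a symbolic model in which an address is a symbol name plus a constant byte offset. The types and names are hypothetical, and the pointer-type check and `ptrtoint` stripping of the real code are omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical symbolic model (not LLVM types).
enum class DiscKind { Integer, Address, Blend };

struct ProvidedDisc {
  DiscKind Kind;
  uint64_t Int;     // integer component (Integer, Blend)
  std::string Base; // address base symbol (Address, Blend)
  int64_t Offset;   // constant byte offset from Base (Address, Blend)
};

struct SignedConst {
  unsigned Key;
  uint64_t Int;         // integer discriminator, 0 if omitted
  std::string AddrBase; // address discriminator base symbol
  int64_t AddrOffset;   // constant byte offset from AddrBase
  bool HasAddr;         // false models "ptr null"
};

bool isKnownCompatible(const SignedConst &C, unsigned Key, const ProvidedDisc &D) {
  // Different keys can never be compatible.
  if (C.Key != Key)
    return false;
  // Integer-only: the provided discriminator must be the same integer.
  if (!C.HasAddr)
    return D.Kind == DiscKind::Integer && D.Int == C.Int;
  if (C.Int != 0) {
    // Implicit blend: require blend(addr, C.Int) with the same integer.
    if (D.Kind != DiscKind::Blend || D.Int != C.Int)
      return false;
  } else if (D.Kind != DiscKind::Address) {
    // Address-only: the provided discriminator must be a bare address.
    return false;
  }
  // Compare addresses as base + constant offset, in the spirit of the
  // stripAndAccumulateConstantOffsets comparison.
  return D.Base == C.AddrBase && D.Offset == C.AddrOffset;
}
```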
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index 7067d0d121117..5153880b5cab6 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -23,6 +23,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -286,6 +287,7 @@ DEFINE_TRANSPARENT_OPERAND_ACCESSORS(CompareConstantExpr, Value)
 template <class ConstantClass> struct ConstantAggrKeyType;
 struct InlineAsmKeyType;
 struct ConstantExprKeyType;
+struct ConstantPtrAuthKeyType;
 
 template <class ConstantClass> struct ConstantInfo;
 template <> struct ConstantInfo<ConstantExpr> {
@@ -308,6 +310,10 @@ template <> struct ConstantInfo<ConstantVector> {
   using ValType = ConstantAggrKeyType<ConstantVector>;
   using TypeClass = VectorType;
 };
+template <> struct ConstantInfo<ConstantPtrAuth> {
+  using ValType = ConstantPtrAuthKeyType;
+  using TypeClass = Type;
+};
 
 template <class ConstantClass> struct ConstantAggrKeyType {
   ArrayRef<Constant *> Operands;
@@ -536,6 +542,47 @@ struct ConstantExprKeyType {
   }
 };
 
+struct ConstantPtrAuthKeyType {
+  ArrayRef<Constant *> Operands;
+
+  ConstantPtrAuthKeyType(ArrayRef<Constant *> Operands) : Operands(Operands) {}
+
+  ConstantPtrAuthKeyType(ArrayRef<Constant *> Operands, const ConstantPtrAuth *)
+      : Operands(Operands) {}
+
+  ConstantPtrAuthKeyType(const ConstantPtrAuth *C,
+                         SmallVectorImpl<Constant *> &Storage) {
+    assert(Storage.empty() && "Expected empty storage");
+    for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I)
+      Storage.push_back(cast<Constant>(C->getOperand(I)));
+    Operands = Storage;
+  }
+
+  bool operator==(const ConstantPtrAuthKeyType &X) const {
+    return Operands == X.Operands;
+  }
+
+  bool operator==(const ConstantPtrAuth *C) const {
+    if (Operands.size() != C->getNumOperands())
+      return false;
+    for (unsigned I = 0, E = Operands.size(); I != E; ++I)
+      if (Operands[I] != C->getOperand(I))
+        return false;
+    return true;
+  }
+
+  unsigned getHash() const {
+    return hash_combine_range(Operands.begin(), Operands.end());
+  }
+
+  using TypeClass = typename ConstantInfo<ConstantPtrAuth>::TypeClass;
+
+  ConstantPtrAuth *create(TypeClass *Ty) const {
+    return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
+                               cast<ConstantInt>(Operands[2]), Operands[3]);
+  }
+};
+
 // Free memory for a given constant.  Assumes the constant has already been
 // removed from all relevant maps.
 void deleteConstant(Constant *C);
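`ConstantPtrAuths` is a `ConstantUniqueMap`, so `ConstantPtrAuth::get` returns a single object per operand tuple, and `getWithSameSchema` is `get` with only the pointer operand replaced. The toy table below illustrates that hash-consing behavior with integer handles in place of `Constant *` operands. It uses an ordered `std::map` instead of the hashed lookup in `ConstantsContext.h` and ignores the type component of the real lookup key, so treat it as an illustration only; all names are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>

// Hypothetical handles standing in for Constant * operands.
using OperandHandle = int;
using PtrAuthKey = std::array<OperandHandle, 4>; // {Ptr, Key, Disc, AddrDisc}

struct ToyPtrAuth {
  PtrAuthKey Ops;
};

class ToyPtrAuthTable {
  std::map<PtrAuthKey, std::unique_ptr<ToyPtrAuth>> Map;

public:
  // Returns the unique object for this operand tuple, creating it on demand.
  const ToyPtrAuth *get(OperandHandle Ptr, OperandHandle Key,
                        OperandHandle Disc, OperandHandle AddrDisc) {
    PtrAuthKey K{Ptr, Key, Disc, AddrDisc};
    std::unique_ptr<ToyPtrAuth> &Slot = Map[K];
    if (!Slot)
      Slot = std::make_unique<ToyPtrAuth>(ToyPtrAuth{K});
    return Slot.get();
  }

  // Analogue of getWithSameSchema: same key and discriminators, new pointer.
  const ToyPtrAuth *getWithSameSchema(const ToyPtrAuth *C, OperandHandle NewPtr) {
    return get(NewPtr, C->Ops[1], C->Ops[2], C->Ops[3]);
  }

  std::size_t size() const { return Map.size(); }
};
```

Uniquing is what lets `isKnownCompatibleWith` compare keys and discriminators by pointer identity.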
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index 399fe0dad26c7..392e0d16f1761 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1562,6 +1562,8 @@ class LLVMContextImpl {
 
   DenseMap<const GlobalValue *, NoCFIValue *> NoCFIValues;
 
+  ConstantUniqueMap<ConstantPtrAuth> ConstantPtrAuths;
+
   ConstantUniqueMap<ConstantExpr> ExprConstants;
 
   ConstantUniqueMap<InlineAsm> InlineAsms;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 50f8d6ec84201..684e54444621b 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -629,6 +629,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
 
   void visitConstantExprsRecursively(const Constant *EntryC);
   void visitConstantExpr(const ConstantExpr *CE);
+  void visitConstantPtrAuth(const ConstantPtrAuth *CPA);
   void verifyInlineAsmCall(const CallBase &Call);
   void verifyStatepoint(const CallBase &Call);
   void verifyFrameRecoverIndices();
@@ -2422,6 +2423,9 @@ void Verifier::visitConstantExprsRecursively(const Constant *EntryC) {
     if (const auto *CE = dyn_cast<ConstantExpr>(C))
       visitConstantExpr(CE);
 
+    if (const auto *CPA = dyn_cast<ConstantPtrAuth>(C))
+      visitConstantPtrAuth(CPA);
+
     if (const auto *GV = dyn_cast<GlobalValue>(C)) {
       // Global Values get visited separately, but we do need to make sure
       // that the global value is in the correct module
@@ -2449,6 +2453,23 @@ void Verifier::visitConstantExpr(const ConstantExpr *CE) {
           "Invalid bitcast", CE);
 }
 
+void Verifier::visitConstantPtrAuth(const ConstantPtrAuth *CPA) {
+  Check(CPA->getPointer()->getType()->isPointerTy(),
+        "signed ptrauth constant base pointer must have pointer type");
+
+  Check(CPA->getType() == CPA->getPointer()->getType(),
+        "signed ptrauth constant must have same type as its base pointer");
+
+  Check(CPA->getKey()->getBitWidth() == 32,
+        "signed ptrauth constant key must be i32 constant integer");
+
+  Check(CPA->getAddrDiscriminator()->getType()->isPointerTy(),
+        "signed ptrauth constant address discriminator must be a pointer");
+
+  Check(CPA->getDiscriminator()->getBitWidth() == 64,
+        "signed ptrauth constant discriminator must be i64 constant integer");
+}
+
 bool Verifier::verifyAttributeCount(AttributeList Attrs, unsigned Params) {
   // There shouldn't be more attribute sets than there are parameters plus the
   // function and return value.
@@ -5090,6 +5111,8 @@ void Verifier::visitInstruction(Instruction &I) {
     } else if (isa<InlineAsm>(I.getOperand(i))) {
       Check(CBI && &CBI->getCalledOperandUse() == &I.getOperandUse(i),
             "Cannot take the address of an inline asm!", &I);
+    } else if (auto *CPA = dyn_cast<ConstantPtrAuth>(I.getOperand(i))) {
+      visitConstantExprsRecursively(CPA);
     } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(I.getOperand(i))) {
       if (CE->getType()->isPtrOrPtrVectorTy()) {
         // If we have a ConstantExpr pointer, we need to see if it came from an
diff --git a/llvm/test/Assembler/invalid-ptrauth-const1.ll b/llvm/test/Assembler/invalid-ptrauth-const1.ll
new file mode 100644
index 0000000000000..fba2e23078238
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const1.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: constant ptrauth base pointer must be a pointer
+@auth_var = global ptr ptrauth (i32 42, i32 0)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const2.ll b/llvm/test/Assembler/invalid-ptrauth-const2.ll
new file mode 100644
index 0000000000000..4499c42601c99
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const2.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: constant ptrauth key must be i32 constant
+@auth_var = global ptr ptrauth (ptr @var, i32 ptrtoint (ptr @var to i32))
diff --git a/llvm/test/Assembler/invalid-ptrauth-const3.ll b/llvm/test/Assembler/invalid-ptrauth-const3.ll
new file mode 100644
index 0000000000000..3f2688d92a001
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const3.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: constant ptrauth address discriminator must be a pointer
+@auth_var = global ptr ptrauth (ptr @var, i32 2, i64 65535, i8 0)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const4.ll b/llvm/test/Assembler/invalid-ptrauth-const4.ll
new file mode 100644
index 0000000000000..843a220458a61
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const4.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: constant ptrauth integer discriminator must be i64 constant
+@auth_var = global ptr ptrauth (ptr @var, i32 2, ptr null, i64 ptrtoint (ptr @var to i64))
diff --git a/llvm/test/Assembler/invalid-ptrauth-const5.ll b/llvm/test/Assembler/invalid-ptrauth-const5.ll
new file mode 100644
index 0000000000000..9b47f6f5f423f
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const5.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: constant ptrauth integer discriminator must be i64 constant
+@auth_var = global ptr ptrauth (ptr @var, i32 2, ptr @var)
diff --git a/llvm/test/Assembler/ptrauth-const.ll b/llvm/test/Assembler/ptrauth-const.ll
new file mode 100644
index 0000000000000..94d35146d5927
--- /dev/null
+++ b/llvm/test/Assembler/ptrauth-const.ll
@@ -0,0 +1,24 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: @basic = global ptr ptrauth (ptr @var, i32 0)
+@basic = global ptr ptrauth (ptr @var, i32 0)
+
+; CHECK: @keyed = global ptr ptrauth (ptr @var, i32 3)
+@keyed = global ptr ptrauth (ptr @var, i32 3)
+
+; CHECK: @intdisc = global ptr ptrauth (ptr @var, i32 0, i64 -1)
+@intdisc = global ptr ptrauth (ptr @var, i32 0, i64 -1)
+
+; CHECK: @addrdisc = global ptr ptrauth (ptr @var, i32 2, i64 1234, ptr @addrdisc)
+@addrdisc = global ptr ptrauth (ptr @var, i32 2, i64 1234, ptr @addrdisc)
+
+
+@var1 = addrspace(1) global i32 0
+
+; CHECK: @addrspace = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 0)
+@addrspace = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 0)
+
+; CHECK: @addrspace_addrdisc = addrspace(2) global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 2, i64 1234, ptr addrspace(2) @addrspace_addrdisc)
+@addrspace_addrdisc = addrspace(2) global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 2, i64 1234, ptr addrspace(2) @addrspace_addrdisc)
diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll
index b374924516d66..2a846e036924c 100644
--- a/llvm/test/Bitcode/compatibility.ll
+++ b/llvm/test/Bitcode/compatibility.ll
@@ -217,6 +217,10 @@ declare void @g.f1()
 ; CHECK: @g.sanitize_address_dyninit = global i32 0, sanitize_address_dyninit
 ; CHECK: @g.sanitize_multiple = global i32 0, sanitize_memtag, sanitize_address_dyninit
 
+; ptrauth constant
+@auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null)
+; CHECK: @auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535)
+
 ;; Aliases
 ; Format: @<Name> = [Linkage] [Visibility] [DLLStorageClass] [ThreadLocal]
 ;                   [unnamed_addr] alias <AliaseeTy> @<Aliasee>
diff --git a/llvm/utils/vim/syntax/llvm.vim b/llvm/utils/vim/syntax/llvm.vim
index d86e3d1ddbc27..905d696400ca3 100644
--- a/llvm/utils/vim/syntax/llvm.vim
+++ b/llvm/utils/vim/syntax/llvm.vim
@@ -150,6 +150,7 @@ syn keyword llvmKeyword
       \ preallocated
       \ private
       \ protected
+      \ ptrauth
       \ ptx_device
       \ ptx_kernel
       \ readnone