[llvm] [IR][AArch64] Add "ptrauth(...)" Constant to represent signed pointers. (PR #85738)

Ahmed Bougacha via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 18 23:06:59 PDT 2024


https://github.com/ahmedbougacha created https://github.com/llvm/llvm-project/pull/85738

This defines a new kind of IR Constant that represents a ptrauth signed pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants used thus far in the LLVM ptrauth implementations, notably those used in the Darwin and ELF ABIs being implemented for C/C++. These signed pointer constants are then lowered to ELF/MachO relocations.

These can be simply thought of as a constant `llvm.ptrauth.sign`, with the interesting addition of discriminator computation: the `ptrauth` constant can also represent a combined blend, when both address and integer discriminator operands are used.

This also teaches some of the most common constant-folding and analysis paths about these, usually in a straightforward way. I have a couple of FIXMEs to expand some of those here, as well as to add unit tests for the ConstantPtrAuth methods.

>From aea98844234ea0dadca0bb089afa4d07f8f399b9 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Mon, 18 Mar 2024 23:03:25 -0700
Subject: [PATCH] [IR][AArch64] Add "ptrauth(...)" Constant to represent signed
 pointers.

This defines a new kind of IR Constant that represents a ptrauth
signed pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants
used thus far in the LLVM ptrauth implementations, notably those
used in the Darwin and ELF ABIs being implemented for C/C++.
These signed pointer constants are then lowered to ELF/MachO
relocations.

These can be simply thought of as a constant `llvm.ptrauth.sign`,
with the interesting addition of discriminator computation:
the `ptrauth` constant can also represent a combined blend,
when both address and integer discriminator operands are used.

Co-Authored-by: Tim Northover <tnorthover at apple.com>
---
 llvm/docs/LangRef.rst                         |  27 ++++
 llvm/include/llvm-c/Core.h                    |   1 +
 llvm/include/llvm/AsmParser/LLToken.h         |   1 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |   2 +
 llvm/include/llvm/IR/Constants.h              |  63 ++++++++++
 llvm/include/llvm/IR/Value.def                |   1 +
 llvm/lib/Analysis/ValueTracking.cpp           |   7 ++
 llvm/lib/AsmParser/LLLexer.cpp                |   1 +
 llvm/lib/AsmParser/LLParser.cpp               |  42 +++++++
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     |  30 ++++-
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |   8 ++
 llvm/lib/IR/AsmWriter.cpp                     |  15 +++
 llvm/lib/IR/ConstantFold.cpp                  |   3 +
 llvm/lib/IR/Constants.cpp                     | 115 ++++++++++++++++++
 llvm/lib/IR/ConstantsContext.h                |  47 +++++++
 llvm/lib/IR/LLVMContextImpl.h                 |   2 +
 llvm/lib/IR/Verifier.cpp                      |  20 +++
 llvm/test/Assembler/invalid-ptrauth-const1.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const2.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const3.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const4.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const5.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const6.ll |   6 +
 llvm/test/Assembler/ptrauth-const.ll          |  13 ++
 llvm/test/Bitcode/compatibility.ll            |   4 +
 llvm/utils/vim/syntax/llvm.vim                |   1 +
 26 files changed, 438 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const1.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const2.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const3.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const4.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const5.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const6.ll
 create mode 100644 llvm/test/Assembler/ptrauth-const.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e07b642285b3e6..0d91d4fc3ba1ef 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4748,6 +4748,33 @@ reference to the CFI jump table in the ``LowerTypeTests`` pass. These constants
 may be useful in low-level programs, such as operating system kernels, which
 need to refer to the actual function body.
 
+.. _ptrauth:
+
+Authenticated Pointers
+----------------------
+
+``ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i64 DISC)``
+
+A '``ptrauth``' constant represents a pointer with a cryptographic
+authentication signature embedded into some bits. Its type is the same as the
+first argument. ``KEY`` and ``DISC`` must be constant integers.
+
+If the address discriminator is ``null`` then the expression is equivalent to:
+
+.. code-block:: llvm
+
+    %tmp = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 DISC)
+    %val = inttoptr i64 %tmp to ptr
+
+Otherwise, the address discriminator is blended in first:
+
+.. code-block:: llvm
+
+    %tmp1 = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr ADDRDISC to i64), i64 DISC)
+    %tmp2 = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 %tmp1)
+    %val = inttoptr i64 %tmp2 to ptr
+
+
 .. _constantexprs:
 
 Constant Expressions
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index f56a6c961aad74..5f69a07fbed644 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -286,6 +286,7 @@ typedef enum {
   LLVMInstructionValueKind,
   LLVMPoisonValueValueKind,
   LLVMConstantTargetNoneValueKind,
+  LLVMConstantPtrAuthValueKind,
 } LLVMValueKind;

 typedef enum {
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 5863a8d6e8ee84..e949023463f54d 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -343,6 +343,7 @@ enum Kind {
   kw_insertvalue,
   kw_blockaddress,
   kw_dso_local_equivalent,
+  kw_ptrauth,
   kw_no_cfi,

   kw_freeze,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 39303e64852141..747bd55c2a8c82 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -411,6 +411,8 @@ enum ConstantsCodes {
                               //                 sideeffect|alignstack|
                               //                 asmdialect|unwind,
                               //                 asmstr,conststr]
+  CST_CODE_SIGNED_PTR = 31,   // SIGNED_PTR:    [ptrty, ptr, key,
+                              //                 addrdiscty, addrdisc, disc]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index c0ac9a4aa6750c..9cf53616cc921b 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1006,6 +1006,69 @@ struct OperandTraits<NoCFIValue> : public FixedNumOperandTraits<NoCFIValue, 1> {

 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(NoCFIValue, Value)

+/// A constant pointer signed with AArch64 pointer authentication (ptrauth).
+/// Operands: pointer, i32 key, address discriminator (ptr), i64 discriminator.
+class ConstantPtrAuth final : public Constant {
+  friend struct ConstantPtrAuthKeyType;
+  friend class Constant;
+
+  ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, Constant *AddrDisc,
+                  ConstantInt *Disc);
+
+  void *operator new(size_t s) { return User::operator new(s, 4); }
+
+  void destroyConstantImpl();
+  Value *handleOperandChangeImpl(Value *From, Value *To);
+
+public:
+  /// Return the uniqued ptrauth constant for the given operands.
+  static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
+                              Constant *AddrDisc, ConstantInt *Disc);
+
+  /// Return a ptrauth constant signing \p Pointer with the same key, address
+  /// discriminator and discriminator as this one.
+  ConstantPtrAuth *getWithSameSchema(Constant *Pointer) const;
+
+  /// Transparently provide more efficient getOperand methods.
+  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Constant);
+
+  /// The pointer that is signed by this ptrauth constant.
+  Constant *getPointer() const { return (Constant *)Op<0>().get(); }
+
+  /// The Key ID, an i32 constant.
+  ConstantInt *getKey() const { return (ConstantInt *)Op<1>().get(); }
+
+  /// The address discriminator if any, or the null constant.
+  /// If present, this must be a value equivalent to the storage location of
+  /// the only user of the authenticated ptrauth global.
+  Constant *getAddrDiscriminator() const { return (Constant *)Op<2>().get(); }
+
+  /// The integer discriminator, an i64 constant.
+  ConstantInt *getDiscriminator() const { return (ConstantInt *)Op<3>().get(); }
+
+  /// Whether there is any non-null address discriminator.
+  bool hasAddressDiversity() const {
+    return !getAddrDiscriminator()->isNullValue();
+  }
+
+  /// Check whether an authentication operation with key \p Key and (possibly
+  /// blended) discriminator \p Discriminator is compatible with this ptrauth
+  /// constant. Returns false when compatibility cannot be established.
+  bool isCompatibleWith(const Value *Key, const Value *Discriminator,
+                        const DataLayout &DL) const;
+
+  /// Methods to support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const Value *V) {
+    return V->getValueID() == ConstantPtrAuthVal;
+  }
+};
+
+template <>
+struct OperandTraits<ConstantPtrAuth>
+    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+
+DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
+
 //===----------------------------------------------------------------------===//
 /// A constant value that is initialized with an expression using
 /// other constant values.
diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def
index 61f7a87666d094..31110ff05ae368 100644
--- a/llvm/include/llvm/IR/Value.def
+++ b/llvm/include/llvm/IR/Value.def
@@ -78,6 +78,7 @@ HANDLE_GLOBAL_VALUE(GlobalAlias)
 HANDLE_GLOBAL_VALUE(GlobalIFunc)
 HANDLE_GLOBAL_VALUE(GlobalVariable)
 HANDLE_CONSTANT(BlockAddress)
+HANDLE_CONSTANT(ConstantPtrAuth) // Signed pointer; also exposed in llvm-c.
 HANDLE_CONSTANT(ConstantExpr)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(DSOLocalEquivalent)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(NoCFIValue)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index fe5d084b55bbe3..d0bdaca57e47f5 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2900,6 +2900,10 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
       return true;
     }

+    // A ptrauth constant is known non-null iff its base pointer is.
+    if (auto *CPA = dyn_cast<ConstantPtrAuth>(V))
+      return isKnownNonZero(CPA->getPointer(), DemandedElts, Depth, Q);
+
     // A global variable in address space 0 is non null unless extern weak
     // or an absolute symbol reference. Other address spaces may have null as a
     // valid address for a global, so we can't assume anything.
@@ -6993,6 +6997,9 @@ static bool isGuaranteedNotToBeUndefOrPoison(
         isa<ConstantPointerNull>(C) || isa<Function>(C))
       return true;

+    if (isa<ConstantPtrAuth>(C)) // NOTE(review): operands are not inspected.
+      return true;
+
     if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C))
       return (!includesUndef(Kind) ? !C->containsPoisonElement()
                                    : !C->containsUndefOrPoisonElement()) &&
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 02f64fcfac4f0c..e37ee0bb90a82d 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -708,6 +708,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(blockaddress);
   KEYWORD(dso_local_equivalent);
   KEYWORD(no_cfi);
+  KEYWORD(ptrauth); // ptrauth (...) signed-pointer constant.

   // Metadata types.
   KEYWORD(distinct);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 2e0f5ba82220c9..21039f7efb9b27 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3998,6 +3998,48 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     ID.NoCFI = true;
     return false;
   }
+  case lltok::kw_ptrauth: {
+    // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key> ','
+    //                         ptr addrdisc ',' i64 <disc> ')'
+    Lex.Lex();
+
+    Constant *Ptr, *Key, *AddrDisc, *Disc;
+
+    if (parseToken(lltok::lparen,
+                   "expected '(' in signed pointer expression") ||
+        parseGlobalTypeAndValue(Ptr) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(Key) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(AddrDisc) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(Disc) ||
+        parseToken(lltok::rparen, "expected ')' in signed pointer expression"))
+      return true;
+
+    if (!Ptr->getType()->isPointerTy())
+      return error(ID.Loc, "signed pointer must be a pointer");
+
+    auto *KeyC = dyn_cast<ConstantInt>(Key);
+    if (!KeyC || KeyC->getBitWidth() != 32)
+      return error(ID.Loc, "signed pointer key must be i32 constant integer");
+
+    if (!AddrDisc->getType()->isPointerTy())
+      return error(ID.Loc,
+                   "signed pointer address discriminator must be a pointer");
+
+    auto *DiscC = dyn_cast<ConstantInt>(Disc);
+    if (!DiscC || DiscC->getBitWidth() != 64)
+      return error(ID.Loc,
+                   "signed pointer discriminator must be i64 constant integer");
+
+    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, AddrDisc, DiscC);
+    ID.Kind = ValID::t_Constant;
+    return false;
+  }

   case lltok::kw_trunc:
   case lltok::kw_bitcast:
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index d284c9823c9ede..538200abcf6f97 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -504,7 +504,8 @@ class BitcodeConstant final : public Value,
   static constexpr uint8_t NoCFIOpcode = 252;
   static constexpr uint8_t DSOLocalEquivalentOpcode = 251;
   static constexpr uint8_t BlockAddressOpcode = 250;
-  static constexpr uint8_t FirstSpecialOpcode = BlockAddressOpcode;
+  static constexpr uint8_t ConstantPtrAuthOpcode = 249;
+  static constexpr uint8_t FirstSpecialOpcode = ConstantPtrAuthOpcode;

   // Separate struct to make passing different number of parameters to
   // BitcodeConstant::create() more convenient.
@@ -1528,6 +1529,18 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
         C = ConstantExpr::get(BC->Opcode, ConstOps[0], ConstOps[1], BC->Flags);
       } else {
         switch (BC->Opcode) {
+        case BitcodeConstant::ConstantPtrAuthOpcode: {
+          auto *Key = dyn_cast<ConstantInt>(ConstOps[1]);
+          if (!Key || Key->getBitWidth() != 32)
+            return error("ptrauth key operand must be i32 ConstantInt");
+
+          auto *Disc = dyn_cast<ConstantInt>(ConstOps[3]);
+          if (!Disc || Disc->getBitWidth() != 64)
+            return error("ptrauth disc operand must be i64 ConstantInt");
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, ConstOps[2], Disc);
+          break;
+        }
         case BitcodeConstant::NoCFIOpcode: {
           auto *GV = dyn_cast<GlobalValue>(ConstOps[0]);
           if (!GV)
@@ -3596,6 +3609,21 @@ Error BitcodeReader::parseConstants() {
                                   Record[1]);
       break;
     }
+    case bitc::CST_CODE_SIGNED_PTR: {
+      if (Record.size() < 6)
+        return error("Invalid record");
+      Type *PtrTy = getTypeByID(Record[0]);
+      if (!PtrTy)
+        return error("Invalid record");
+
+      // PtrTy, Ptr, Key, AddrDiscTy, AddrDisc, Disc
+      V = BitcodeConstant::create(
+          Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
+          {(unsigned)Record[1], (unsigned)Record[2], (unsigned)Record[4],
+           (unsigned)Record[5]});
+
+      break;
+    }
     }

     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 6f0879a4e0ee74..74f1bd8ba49b57 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2800,6 +2800,14 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Record.push_back(VE.getTypeID(BA->getFunction()->getType()));
       Record.push_back(VE.getValueID(BA->getFunction()));
       Record.push_back(VE.getGlobalBasicBlockID(BA->getBasicBlock()));
+    } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(C)) {
+      Code = bitc::CST_CODE_SIGNED_PTR; // Field order must match the reader.
+      Record.push_back(VE.getTypeID(SP->getPointer()->getType()));
+      Record.push_back(VE.getValueID(SP->getPointer()));
+      Record.push_back(VE.getValueID(SP->getKey()));
+      Record.push_back(VE.getTypeID(SP->getAddrDiscriminator()->getType()));
+      Record.push_back(VE.getValueID(SP->getAddrDiscriminator()));
+      Record.push_back(VE.getValueID(SP->getDiscriminator()));
     } else if (const auto *Equiv = dyn_cast<DSOLocalEquivalent>(C)) {
       Code = bitc::CST_CODE_DSO_LOCAL_EQUIVALENT;
       Record.push_back(VE.getTypeID(Equiv->getGlobalValue()->getType()));
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 19acc89f73fb7e..0e9227f0945a4d 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1578,6 +1578,21 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
     return;
   }

+  if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(CV)) {
+    Out << "ptrauth (";
+
+    for (unsigned I = 0, E = SP->getNumOperands(); I != E; ++I) {
+      WriterCtx.TypePrinter->print(SP->getOperand(I)->getType(), Out);
+      Out << ' ';
+      WriteAsOperandInternal(Out, SP->getOperand(I), WriterCtx);
+      if (I != E - 1)
+        Out << ", ";
+    }
+
+    Out << ')';
+    return;
+  }
+
   if (const ConstantArray *CA = dyn_cast<ConstantArray>(CV)) {
     Type *ETy = CA->getType()->getElementType();
     Out << '[';
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 034e397bc69fce..0bfec86783378a 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -1154,6 +1154,9 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2) {
                                 GV->getType()->getAddressSpace()))
         return ICmpInst::ICMP_UGT;
     }
+  } else if (isa<ConstantPtrAuth>(V1)) {
+    // FIXME: Fold ptrauth constant comparisons; for now the relation is unknown.
+    return ICmpInst::BAD_ICMP_PREDICATE;
   } else {
     // Ok, the LHS is known to be a constantexpr.  The RHS can be any of a
     // constantexpr, a global, block address, or a simple constant.
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index e6b92aad392f66..1af52f9e612c29 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -550,6 +550,9 @@ void llvm::deleteConstant(Constant *C) {
   case Constant::NoCFIValueVal:
     delete static_cast<NoCFIValue *>(C);
     break;
+  case Constant::ConstantPtrAuthVal:
+    delete static_cast<ConstantPtrAuth *>(C);
+    break;
   case Constant::UndefValueVal:
     delete static_cast<UndefValue *>(C);
     break;
@@ -2015,6 +2018,118 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
   return nullptr;
 }

+//---- ConstantPtrAuth::get() implementations.
+//
+/// True if V1 and V2, looking through ptrtoint, have equal base and offset.
+static bool areEquivalentAddrDiscriminators(const Value *V1, const Value *V2,
+                                            const DataLayout &DL) {
+  if (auto *V1Cast = dyn_cast<PtrToIntOperator>(V1))
+    V1 = V1Cast->getPointerOperand();
+  if (auto *V2Cast = dyn_cast<PtrToIntOperator>(V2))
+    V2 = V2Cast->getPointerOperand();
+  if (!V1->getType()->isPointerTy() || V1->getType() != V2->getType())
+    return V1 == V2; // Offsets are only comparable at a common index width.
+  APInt V1Off(DL.getIndexTypeSizeInBits(V1->getType()), 0), V2Off = V1Off;
+  auto *V1Base = V1->stripAndAccumulateConstantOffsets(
+      DL, V1Off, /*AllowNonInbounds=*/true);
+  auto *V2Base = V2->stripAndAccumulateConstantOffsets(
+      DL, V2Off, /*AllowNonInbounds=*/true);
+  return V1Base == V2Base && V1Off == V2Off;
+}
+
+bool ConstantPtrAuth::isCompatibleWith(const Value *Key,
+                                       const Value *Discriminator,
+                                       const DataLayout &DL) const {
+  // If the keys are different, there's no chance for this to be compatible.
+  if (Key != getKey())
+    return false;
+
+  // If the discriminators are the same, this is compatible iff there is no
+  // address discriminator.
+  if (Discriminator == getDiscriminator())
+    return !hasAddressDiversity();
+
+  // If we dynamically blend the discriminator with the address discriminator,
+  // this is compatible.
+  if (auto *DiscBlend = dyn_cast<IntrinsicInst>(Discriminator)) {
+    if (DiscBlend->getIntrinsicID() == Intrinsic::ptrauth_blend &&
+        DiscBlend->getOperand(1) == getDiscriminator() &&
+        areEquivalentAddrDiscriminators(DiscBlend->getOperand(0),
+                                        getAddrDiscriminator(), DL))
+      return true;
+  }
+
+  // If we don't have a non-address discriminator, we don't need a blend in
+  // the first place:  accept the address discriminator as the discriminator.
+  if (getDiscriminator()->isNullValue() &&
+      areEquivalentAddrDiscriminators(getAddrDiscriminator(), Discriminator,
+                                      DL))
+    return true;
+
+  // Otherwise, we don't know.
+  return false;
+}
+
+ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
+  return get(Pointer, getKey(), getAddrDiscriminator(), getDiscriminator());
+}
+
+ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
+                                      Constant *AddrDisc, ConstantInt *Disc) {
+  Constant *ArgVec[] = {Ptr, Key, AddrDisc, Disc};
+  ConstantPtrAuthKeyType MapKey(ArgVec);
+  LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
+  return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
+}
+
+ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
+                                 Constant *AddrDisc, ConstantInt *Disc)
+    : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, &Op<0>(), 4) {
+#ifndef NDEBUG
+  assert(Ptr->getType()->isPointerTy() && "signed value must be a pointer");
+  assert(Key->getBitWidth() == 32 && "ptrauth key must be i32");
+  assert(AddrDisc->getType()->isPointerTy() && "addr disc must be a pointer");
+  assert(Disc->getBitWidth() == 64 && "ptrauth discriminator must be i64");
+#endif
+  setOperand(0, Ptr);
+  setOperand(1, Key);
+  setOperand(2, AddrDisc);
+  setOperand(3, Disc);
+}
+
+/// Remove the constant from the constant table.
+void ConstantPtrAuth::destroyConstantImpl() {
+  getType()->getContext().pImpl->ConstantPtrAuths.remove(this);
+}
+
+Value *ConstantPtrAuth::handleOperandChangeImpl(Value *From, Value *ToV) {
+  assert(isa<Constant>(ToV) && "Cannot make Constant refer to non-constant!");
+  Constant *To = cast<Constant>(ToV);
+
+  SmallVector<Constant *, 8> Values;
+  Values.reserve(getNumOperands()); // Build replacement operand list.
+
+  // Copy the operands with From replaced by To, counting how many operands
+  // changed and recording the index of the (last) changed one.
+  unsigned NumUpdated = 0;
+
+  Use *OperandList = getOperandList();
+  unsigned OperandNo = 0;
+  for (Use *O = OperandList, *E = OperandList + getNumOperands(); O != E; ++O) {
+    Constant *Val = cast<Constant>(O->get());
+    if (Val == From) {
+      OperandNo = (O - OperandList);
+      Val = To;
+      ++NumUpdated;
+    }
+    Values.push_back(Val);
+  }
+
+  // FIXME: shouldn't we check it's not already there?
+  return getContext().pImpl->ConstantPtrAuths.replaceOperandsInPlace(
+      Values, this, From, To, NumUpdated, OperandNo);
+}
+
 //---- ConstantExpr::get() implementations.
 //
 
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index 44a926b5dc58e0..bd111f406687c1 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -23,6 +23,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -282,6 +283,7 @@ DEFINE_TRANSPARENT_OPERAND_ACCESSORS(CompareConstantExpr, Value)
 template <class ConstantClass> struct ConstantAggrKeyType;
 struct InlineAsmKeyType;
 struct ConstantExprKeyType;
+struct ConstantPtrAuthKeyType;

 template <class ConstantClass> struct ConstantInfo;
 template <> struct ConstantInfo<ConstantExpr> {
@@ -304,6 +306,10 @@ template <> struct ConstantInfo<ConstantVector> {
   using ValType = ConstantAggrKeyType<ConstantVector>;
   using TypeClass = VectorType;
 };
+template <> struct ConstantInfo<ConstantPtrAuth> {
+  using ValType = ConstantPtrAuthKeyType;
+  using TypeClass = Type;
+};

 template <class ConstantClass> struct ConstantAggrKeyType {
   ArrayRef<Constant *> Operands;
@@ -511,6 +517,47 @@ struct ConstantExprKeyType {
   }
 };

+struct ConstantPtrAuthKeyType {
+  ArrayRef<Constant *> Operands; // Ptr, Key, AddrDisc, Disc.
+
+  ConstantPtrAuthKeyType(ArrayRef<Constant *> Operands) : Operands(Operands) {}
+
+  ConstantPtrAuthKeyType(ArrayRef<Constant *> Operands, const ConstantPtrAuth *)
+      : Operands(Operands) {}
+
+  ConstantPtrAuthKeyType(const ConstantPtrAuth *C,
+                         SmallVectorImpl<Constant *> &Storage) {
+    assert(Storage.empty() && "Expected empty storage");
+    for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I)
+      Storage.push_back(cast<Constant>(C->getOperand(I)));
+    Operands = Storage;
+  }
+
+  bool operator==(const ConstantPtrAuthKeyType &X) const {
+    return Operands == X.Operands;
+  }
+
+  bool operator==(const ConstantPtrAuth *C) const {
+    if (Operands.size() != C->getNumOperands())
+      return false;
+    for (unsigned I = 0, E = Operands.size(); I != E; ++I)
+      if (Operands[I] != C->getOperand(I))
+        return false;
+    return true;
+  }
+
+  unsigned getHash() const { // Hash of all operands.
+    return hash_combine_range(Operands.begin(), Operands.end());
+  }
+
+  using TypeClass = typename ConstantInfo<ConstantPtrAuth>::TypeClass;
+
+  ConstantPtrAuth *create(TypeClass *Ty) const {
+    return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
+                               Operands[2], cast<ConstantInt>(Operands[3]));
+  }
+};
+
 // Free memory for a given constant.  Assumes the constant has already been
 // removed from all relevant maps.
 void deleteConstant(Constant *C);
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index b1dcb262fb657f..bc0f1f2cf8bc31 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1546,6 +1546,8 @@ class LLVMContextImpl {

   DenseMap<const GlobalValue *, NoCFIValue *> NoCFIValues;

+  ConstantUniqueMap<ConstantPtrAuth> ConstantPtrAuths; // ptrauth uniquing.
+
   ConstantUniqueMap<ConstantExpr> ExprConstants;

   ConstantUniqueMap<InlineAsm> InlineAsms;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 62dde2e6ad4243..22977de66207ee 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -626,6 +626,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {

   void visitConstantExprsRecursively(const Constant *EntryC);
   void visitConstantExpr(const ConstantExpr *CE);
+  void visitConstantPtrAuth(const ConstantPtrAuth *SP);
   void verifyInlineAsmCall(const CallBase &Call);
   void verifyStatepoint(const CallBase &Call);
   void verifyFrameRecoverIndices();
@@ -2396,6 +2397,9 @@ void Verifier::visitConstantExprsRecursively(const Constant *EntryC) {
     if (const auto *CE = dyn_cast<ConstantExpr>(C))
       visitConstantExpr(CE);

+    if (const auto *SP = dyn_cast<ConstantPtrAuth>(C))
+      visitConstantPtrAuth(SP);
+
     if (const auto *GV = dyn_cast<GlobalValue>(C)) {
       // Global Values get visited separately, but we do need to make sure
       // that the global value is in the correct module
@@ -2423,6 +2427,20 @@ void Verifier::visitConstantExpr(const ConstantExpr *CE) {
           "Invalid bitcast", CE);
 }

+void Verifier::visitConstantPtrAuth(const ConstantPtrAuth *SP) {
+  Check(SP->getPointer()->getType()->isPointerTy(),
+        "signed pointer must be a pointer", SP);
+
+  Check(SP->getKey()->getBitWidth() == 32,
+        "signed pointer key must be i32 constant integer", SP);
+
+  Check(SP->getAddrDiscriminator()->getType()->isPointerTy(),
+        "signed pointer address discriminator must be a pointer", SP);
+
+  Check(SP->getDiscriminator()->getBitWidth() == 64,
+        "signed pointer discriminator must be i64 constant integer", SP);
+}
+
 bool Verifier::verifyAttributeCount(AttributeList Attrs, unsigned Params) {
   // There shouldn't be more attribute sets than there are parameters plus the
   // function and return value.
@@ -5039,6 +5057,8 @@ void Verifier::visitInstruction(Instruction &I) {
     } else if (isa<InlineAsm>(I.getOperand(i))) {
       Check(CBI && &CBI->getCalledOperandUse() == &I.getOperandUse(i),
             "Cannot take the address of an inline asm!", &I);
+    } else if (auto *SP = dyn_cast<ConstantPtrAuth>(I.getOperand(i))) {
+      visitConstantExprsRecursively(SP);
     } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(I.getOperand(i))) {
       if (CE->getType()->isPtrOrPtrVectorTy()) {
         // If we have a ConstantExpr pointer, we need to see if it came from an
diff --git a/llvm/test/Assembler/invalid-ptrauth-const1.ll b/llvm/test/Assembler/invalid-ptrauth-const1.ll
new file mode 100644
index 00000000000000..dfe0e1c5ed3f05
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const1.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: signed pointer must be a pointer
+@auth_var = global ptr ptrauth (i32 42, i32 0, ptr null, i64 65535)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const2.ll b/llvm/test/Assembler/invalid-ptrauth-const2.ll
new file mode 100644
index 00000000000000..9042d8dfe77263
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const2.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: signed pointer key must be i32 constant integer
+@auth_var = global ptr ptrauth (ptr @var, i32 ptrtoint (ptr @var to i32), ptr null, i64 65535)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const3.ll b/llvm/test/Assembler/invalid-ptrauth-const3.ll
new file mode 100644
index 00000000000000..00bcef9db6d1b1
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const3.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: signed pointer address discriminator must be a pointer
+@auth_var = global ptr ptrauth (ptr @var, i32 2, i8 0, i64 65535)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const4.ll b/llvm/test/Assembler/invalid-ptrauth-const4.ll
new file mode 100644
index 00000000000000..00bcef9db6d1b1
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const4.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: signed pointer key must be i32 constant integer
+@auth_var = global ptr ptrauth (ptr @var, i64 2, ptr null, i64 65535)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const5.ll b/llvm/test/Assembler/invalid-ptrauth-const5.ll
new file mode 100644
index 00000000000000..72910769de262d
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const5.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: signed pointer discriminator must be i64 constant integer
+@auth_var = global ptr ptrauth (ptr @var, i32 2, ptr null, i64 ptrtoint (ptr @var to i64))
diff --git a/llvm/test/Assembler/invalid-ptrauth-const6.ll b/llvm/test/Assembler/invalid-ptrauth-const6.ll
new file mode 100644
index 00000000000000..d79d88ee845ce6
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const6.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: error: signed pointer discriminator must be i64 constant integer
+@auth_var = global ptr ptrauth (ptr @var, i32 2, ptr null, i32 1000000)
diff --git a/llvm/test/Assembler/ptrauth-const.ll b/llvm/test/Assembler/ptrauth-const.ll
new file mode 100644
index 00000000000000..15b71e5ef8c136
--- /dev/null
+++ b/llvm/test/Assembler/ptrauth-const.ll
@@ -0,0 +1,13 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: @auth_var = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
+@auth_var = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
+
+
+; CHECK: @addrdisc_var = global ptr ptrauth (ptr @var, i32 0, ptr @addrdisc_var, i64 1234)
+@addrdisc_var = global ptr ptrauth (ptr @var, i32 0, ptr @addrdisc_var, i64 1234)
+
+; CHECK: @keyed_var = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
+@keyed_var = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll
index ce6a6571ec144c..ff00259192339a 100644
--- a/llvm/test/Bitcode/compatibility.ll
+++ b/llvm/test/Bitcode/compatibility.ll
@@ -217,6 +217,10 @@ declare void @g.f1()
 ; CHECK: @g.sanitize_address_dyninit = global i32 0, sanitize_address_dyninit
 ; CHECK: @g.sanitize_multiple = global i32 0, sanitize_memtag, sanitize_address_dyninit

+; ptrauth constant
+@auth_var = global ptr ptrauth (ptr @g1, i32 0, ptr null, i64 65535)
+; CHECK: @auth_var = global ptr ptrauth (ptr @g1, i32 0, ptr null, i64 65535)
+
 ;; Aliases
 ; Format: @<Name> = [Linkage] [Visibility] [DLLStorageClass] [ThreadLocal]
 ;                   [unnamed_addr] alias <AliaseeTy> @<Aliasee>
diff --git a/llvm/utils/vim/syntax/llvm.vim b/llvm/utils/vim/syntax/llvm.vim
index d86e3d1ddbc27f..905d696400ca37 100644
--- a/llvm/utils/vim/syntax/llvm.vim
+++ b/llvm/utils/vim/syntax/llvm.vim
@@ -150,6 +150,7 @@ syn keyword llvmKeyword
       \ preallocated
       \ private
       \ protected
+      \ ptrauth
       \ ptx_device
       \ ptx_kernel
       \ readnone



More information about the llvm-commits mailing list