[llvm] [PAC][IR][AArch64] Add "ptrauth(...)" Constant to represent signed pointers. (PR #85738)

Ahmed Bougacha via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 10:38:14 PDT 2024


https://github.com/ahmedbougacha updated https://github.com/llvm/llvm-project/pull/85738

>From aea98844234ea0dadca0bb089afa4d07f8f399b9 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Mon, 18 Mar 2024 23:03:25 -0700
Subject: [PATCH 1/4] [IR][AArch64] Add "ptrauth(...)" Constant to represent
 signed pointers.

This defines a new kind of IR Constant that represents a ptrauth
signed pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants
used thus far in the llvm ptrauth implementations, notably those
used in the Darwin and ELF ABIs being implemented for c/c++.
These signed pointer constants are then lowered to ELF/MachO
relocations.

These can be simply thought of as a constant `llvm.ptrauth.sign`,
with the interesting addition of discriminator computation:
the `ptrauth` constant can also represent a combined blend,
when both address and integer discriminator operands are used.

Co-Authored-by: Tim Northover <tnorthover at apple.com>
---
 llvm/docs/LangRef.rst                         |  27 ++++
 llvm/include/llvm-c/Core.h                    |   1 +
 llvm/include/llvm/AsmParser/LLToken.h         |   1 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |   2 +
 llvm/include/llvm/IR/Constants.h              |  63 ++++++++++
 llvm/include/llvm/IR/Value.def                |   1 +
 llvm/lib/Analysis/ValueTracking.cpp           |   7 ++
 llvm/lib/AsmParser/LLLexer.cpp                |   1 +
 llvm/lib/AsmParser/LLParser.cpp               |  42 +++++++
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     |  30 ++++-
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |   8 ++
 llvm/lib/IR/AsmWriter.cpp                     |  15 +++
 llvm/lib/IR/ConstantFold.cpp                  |   3 +
 llvm/lib/IR/Constants.cpp                     | 115 ++++++++++++++++++
 llvm/lib/IR/ConstantsContext.h                |  47 +++++++
 llvm/lib/IR/LLVMContextImpl.h                 |   2 +
 llvm/lib/IR/Verifier.cpp                      |  20 +++
 llvm/test/Assembler/invalid-ptrauth-const1.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const2.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const3.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const4.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const5.ll |   6 +
 llvm/test/Assembler/invalid-ptrauth-const6.ll |   6 +
 llvm/test/Assembler/ptrauth-const.ll          |  13 ++
 llvm/test/Bitcode/compatibility.ll            |   4 +
 llvm/utils/vim/syntax/llvm.vim                |   1 +
 26 files changed, 438 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const1.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const2.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const3.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const4.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const5.ll
 create mode 100644 llvm/test/Assembler/invalid-ptrauth-const6.ll
 create mode 100644 llvm/test/Assembler/ptrauth-const.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e07b642285b3e..0d91d4fc3ba1e 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4748,6 +4748,33 @@ reference to the CFI jump table in the ``LowerTypeTests`` pass. These constants
 may be useful in low-level programs, such as operating system kernels, which
 need to refer to the actual function body.
 
+.. _ptrauth:
+
+Authenticated Pointers
+----------------------
+
+``ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i64 DISC)``
+
+A '``ptrauth``' constant represents a pointer with a cryptographic
+authentication signature embedded into some bits. Its type is the same as the
+first argument.
+
+
+If the address discriminator is ``null`` then the expression is equivalent to
+
+.. code-block:: llvm
+    %tmp = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 DISC)
+    %val = inttoptr i64 %tmp to ptr
+
+If the address discriminator is present, then it is
+
+.. code-block:: llvm
+    %tmp1 = call i64 @llvm.ptrauth.blend.i64(i64 ptrtoint (ptr ADDRDISC to i64), i64 DISC)
+    %tmp2 = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 %tmp1)
+    %val = inttoptr i64 %tmp2 to ptr
+
+    %tmp = call i64 @llvm.ptrauth.blend.i64
+
 .. _constantexprs:
 
 Constant Expressions
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index f56a6c961aad7..5f69a07fbed64 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -286,6 +286,7 @@ typedef enum {
   LLVMInstructionValueKind,
   LLVMPoisonValueValueKind,
   LLVMConstantTargetNoneValueKind,
+  LLVMConstantPtrAuthValueKind,
 } LLVMValueKind;
 
 typedef enum {
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 5863a8d6e8ee8..e949023463f54 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -343,6 +343,7 @@ enum Kind {
   kw_insertvalue,
   kw_blockaddress,
   kw_dso_local_equivalent,
+  kw_ptrauth,
   kw_no_cfi,
 
   kw_freeze,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 39303e6485214..747bd55c2a8c8 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -411,6 +411,8 @@ enum ConstantsCodes {
                               //                 sideeffect|alignstack|
                               //                 asmdialect|unwind,
                               //                 asmstr,conststr]
+  CST_CODE_SIGNED_PTR = 31,   // CE_SIGNED_PTR: [ptrty, ptr, key,
+                              //                 addrdiscty, addrdisc, disc]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index c0ac9a4aa6750..9cf53616cc921 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1006,6 +1006,69 @@ struct OperandTraits<NoCFIValue> : public FixedNumOperandTraits<NoCFIValue, 1> {
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(NoCFIValue, Value)
 
+/// A signed pointer
+///
+class ConstantPtrAuth final : public Constant {
+  friend struct ConstantPtrAuthKeyType;
+  friend class Constant;
+
+  ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, Constant *AddrDisc,
+                  ConstantInt *Disc);
+
+  void *operator new(size_t s) { return User::operator new(s, 4); }
+
+  void destroyConstantImpl();
+  Value *handleOperandChangeImpl(Value *From, Value *To);
+
+public:
+  /// Return a pointer authenticated with the specified parameters.
+  static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
+                              Constant *AddrDisc, ConstantInt *Disc);
+
+  /// Produce a new ptrauth expression signing the given value using
+  /// the same schema as is stored in this one.
+  ConstantPtrAuth *getWithSameSchema(Constant *Pointer) const;
+
+  /// Transparently provide more efficient getOperand methods.
+  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Constant);
+
+  /// The pointer that is authenticated in this authenticated global reference.
+  Constant *getPointer() const { return (Constant *)Op<0>().get(); }
+
+  /// The Key ID, an i32 constant.
+  ConstantInt *getKey() const { return (ConstantInt *)Op<1>().get(); }
+
+  /// The address discriminator if any, or the null constant.
+  /// If present, this must be a value equivalent to the storage location of
+  /// the only user of the authenticated ptrauth global.
+  Constant *getAddrDiscriminator() const { return (Constant *)Op<2>().get(); }
+
+  /// The discriminator.
+  ConstantInt *getDiscriminator() const { return (ConstantInt *)Op<3>().get(); }
+
+  /// Whether there is any non-null address discriminator.
+  bool hasAddressDiversity() const {
+    return !getAddrDiscriminator()->isNullValue();
+  }
+
+  /// Check whether an authentication operation with key \p Key and (possibly
+  /// blended) discriminator \p Discriminator is compatible with this
+  /// authenticated global reference.
+  bool isCompatibleWith(const Value *Key, const Value *Discriminator,
+                        const DataLayout &DL) const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const Value *V) {
+    return V->getValueID() == ConstantPtrAuthVal;
+  }
+};
+
+template <>
+struct OperandTraits<ConstantPtrAuth>
+    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+
+DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
+
 //===----------------------------------------------------------------------===//
 /// A constant value that is initialized with an expression using
 /// other constant values.
diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def
index 61f7a87666d09..31110ff05ae36 100644
--- a/llvm/include/llvm/IR/Value.def
+++ b/llvm/include/llvm/IR/Value.def
@@ -78,6 +78,7 @@ HANDLE_GLOBAL_VALUE(GlobalAlias)
 HANDLE_GLOBAL_VALUE(GlobalIFunc)
 HANDLE_GLOBAL_VALUE(GlobalVariable)
 HANDLE_CONSTANT(BlockAddress)
+HANDLE_CONSTANT(ConstantPtrAuth)
 HANDLE_CONSTANT(ConstantExpr)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(DSOLocalEquivalent)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(NoCFIValue)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index fe5d084b55bbe..d0bdaca57e47f 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2900,6 +2900,10 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
       return true;
     }
 
+    // Constant ptrauth can be null, iff the base pointer can be.
+    if (auto *CPA = dyn_cast<ConstantPtrAuth>(V))
+      return isKnownNonZero(CPA->getPointer(), DemandedElts, Depth, Q);
+
     // A global variable in address space 0 is non null unless extern weak
     // or an absolute symbol reference. Other address spaces may have null as a
     // valid address for a global, so we can't assume anything.
@@ -6993,6 +6997,9 @@ static bool isGuaranteedNotToBeUndefOrPoison(
         isa<ConstantPointerNull>(C) || isa<Function>(C))
       return true;
 
+    if (isa<ConstantPtrAuth>(C))
+      return true;
+
     if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C))
       return (!includesUndef(Kind) ? !C->containsPoisonElement()
                                    : !C->containsUndefOrPoisonElement()) &&
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 02f64fcfac4f0..e37ee0bb90a82 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -708,6 +708,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(blockaddress);
   KEYWORD(dso_local_equivalent);
   KEYWORD(no_cfi);
+  KEYWORD(ptrauth);
 
   // Metadata types.
   KEYWORD(distinct);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 2e0f5ba82220c..21039f7efb9b2 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3998,6 +3998,48 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     ID.NoCFI = true;
     return false;
   }
+  case lltok::kw_ptrauth: {
+    // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key> ','
+    //                         ptr addrdisc ',' i64 <disc> ')'
+    Lex.Lex();
+
+    Constant *Ptr, *Key, *AddrDisc, *Disc;
+
+    if (parseToken(lltok::lparen,
+                   "expected '(' in signed pointer expression") ||
+        parseGlobalTypeAndValue(Ptr) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(Key) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(AddrDisc) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(Disc) ||
+        parseToken(lltok::rparen, "expected ')' in signed pointer expression"))
+      return true;
+
+    if (!Ptr->getType()->isPointerTy())
+      return error(ID.Loc, "signed pointer must be a pointer");
+
+    auto KeyC = dyn_cast<ConstantInt>(Key);
+    if (!KeyC || KeyC->getBitWidth() != 32)
+      return error(ID.Loc, "signed pointer key must be i32 constant integer");
+
+    if (!AddrDisc->getType()->isPointerTy())
+      return error(ID.Loc,
+                   "signed pointer address discriminator must be a pointer");
+
+    auto DiscC = dyn_cast<ConstantInt>(Disc);
+    if (!DiscC || DiscC->getBitWidth() != 64)
+      return error(ID.Loc,
+                   "signed pointer discriminator must be i64 constant integer");
+
+    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, AddrDisc, DiscC);
+    ID.Kind = ValID::t_Constant;
+    return false;
+  }
 
   case lltok::kw_trunc:
   case lltok::kw_bitcast:
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index d284c9823c9ed..538200abcf6f9 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -504,7 +504,8 @@ class BitcodeConstant final : public Value,
   static constexpr uint8_t NoCFIOpcode = 252;
   static constexpr uint8_t DSOLocalEquivalentOpcode = 251;
   static constexpr uint8_t BlockAddressOpcode = 250;
-  static constexpr uint8_t FirstSpecialOpcode = BlockAddressOpcode;
+  static constexpr uint8_t ConstantPtrAuthOpcode = 249;
+  static constexpr uint8_t FirstSpecialOpcode = ConstantPtrAuthOpcode;
 
   // Separate struct to make passing different number of parameters to
   // BitcodeConstant::create() more convenient.
@@ -1528,6 +1529,18 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
         C = ConstantExpr::get(BC->Opcode, ConstOps[0], ConstOps[1], BC->Flags);
       } else {
         switch (BC->Opcode) {
+        case BitcodeConstant::ConstantPtrAuthOpcode: {
+          auto *Key = dyn_cast<ConstantInt>(ConstOps[1]);
+          if (!Key)
+            return error("ptrauth key operand must be ConstantInt");
+
+          auto *Disc = dyn_cast<ConstantInt>(ConstOps[3]);
+          if (!Disc)
+            return error("ptrauth disc operand must be ConstantInt");
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, ConstOps[2], Disc);
+          break;
+        }
         case BitcodeConstant::NoCFIOpcode: {
           auto *GV = dyn_cast<GlobalValue>(ConstOps[0]);
           if (!GV)
@@ -3596,6 +3609,21 @@ Error BitcodeReader::parseConstants() {
                                   Record[1]);
       break;
     }
+    case bitc::CST_CODE_SIGNED_PTR: {
+      if (Record.size() < 6)
+        return error("Invalid record");
+      Type *PtrTy = getTypeByID(Record[0]);
+      if (!PtrTy)
+        return error("Invalid record");
+
+      // PtrTy, Ptr, Key, AddrDiscTy, AddrDisc, Disc
+      V = BitcodeConstant::create(
+        Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
+        {(unsigned)Record[1], (unsigned)Record[2], (unsigned)Record[4],
+         (unsigned)Record[5]});
+
+      break;
+    }
     }
 
     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 6f0879a4e0ee7..74f1bd8ba49b5 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2800,6 +2800,14 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Record.push_back(VE.getTypeID(BA->getFunction()->getType()));
       Record.push_back(VE.getValueID(BA->getFunction()));
       Record.push_back(VE.getGlobalBasicBlockID(BA->getBasicBlock()));
+    } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(C)) {
+      Code = bitc::CST_CODE_SIGNED_PTR;
+      Record.push_back(VE.getTypeID(SP->getPointer()->getType()));
+      Record.push_back(VE.getValueID(SP->getPointer()));
+      Record.push_back(VE.getValueID(SP->getKey()));
+      Record.push_back(VE.getTypeID(SP->getAddrDiscriminator()->getType()));
+      Record.push_back(VE.getValueID(SP->getAddrDiscriminator()));
+      Record.push_back(VE.getValueID(SP->getDiscriminator()));
     } else if (const auto *Equiv = dyn_cast<DSOLocalEquivalent>(C)) {
       Code = bitc::CST_CODE_DSO_LOCAL_EQUIVALENT;
       Record.push_back(VE.getTypeID(Equiv->getGlobalValue()->getType()));
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 19acc89f73fb7..0e9227f0945a4 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1578,6 +1578,21 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
     return;
   }
 
+  if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(CV)) {
+    Out << "ptrauth (";
+
+    for (unsigned i = 0; i < SP->getNumOperands(); ++i) {
+      WriterCtx.TypePrinter->print(SP->getOperand(i)->getType(), Out);
+      Out << ' ';
+      WriteAsOperandInternal(Out, SP->getOperand(i), WriterCtx);
+      if (i != SP->getNumOperands() - 1)
+        Out << ", ";
+    }
+
+    Out << ')';
+    return;
+  }
+
   if (const ConstantArray *CA = dyn_cast<ConstantArray>(CV)) {
     Type *ETy = CA->getType()->getElementType();
     Out << '[';
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 034e397bc69fc..0bfec86783378 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -1154,6 +1154,9 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2) {
                                 GV->getType()->getAddressSpace()))
         return ICmpInst::ICMP_UGT;
     }
+  } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(V1)) {
+    // FIXME: ahmedbougacha: implement ptrauth cst comparison
+    return ICmpInst::BAD_ICMP_PREDICATE;
   } else {
     // Ok, the LHS is known to be a constantexpr.  The RHS can be any of a
     // constantexpr, a global, block address, or a simple constant.
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index e6b92aad392f6..1af52f9e612c2 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -550,6 +550,9 @@ void llvm::deleteConstant(Constant *C) {
   case Constant::NoCFIValueVal:
     delete static_cast<NoCFIValue *>(C);
     break;
+  case Constant::ConstantPtrAuthVal:
+    delete static_cast<ConstantPtrAuth *>(C);
+    break;
   case Constant::UndefValueVal:
     delete static_cast<UndefValue *>(C);
     break;
@@ -2015,6 +2018,118 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
   return nullptr;
 }
 
+//---- ConstantPtrAuth::get() implementations.
+//
+
+static bool areEquivalentAddrDiscriminators(const Value *V1, const Value *V2,
+                                            const DataLayout &DL) {
+  APInt V1Off(DL.getPointerSizeInBits(), 0);
+  APInt V2Off(DL.getPointerSizeInBits(), 0);
+
+  if (auto *V1Cast = dyn_cast<PtrToIntOperator>(V1))
+    V1 = V1Cast->getPointerOperand();
+  if (auto *V2Cast = dyn_cast<PtrToIntOperator>(V2))
+    V2 = V2Cast->getPointerOperand();
+  auto *V1Base = V1->stripAndAccumulateConstantOffsets(
+      DL, V1Off, /*AllowNonInbounds=*/true);
+  auto *V2Base = V2->stripAndAccumulateConstantOffsets(
+      DL, V2Off, /*AllowNonInbounds=*/true);
+  return V1Base == V2Base && V1Off == V2Off;
+}
+
+bool ConstantPtrAuth::isCompatibleWith(const Value *Key,
+                                       const Value *Discriminator,
+                                       const DataLayout &DL) const {
+  // If the keys are different, there's no chance for this to be compatible.
+  if (Key != getKey())
+    return false;
+
+  // If the discriminators are the same, this is compatible iff there is no
+  // address discriminator.
+  if (Discriminator == getDiscriminator())
+    return getAddrDiscriminator()->isNullValue();
+
+  // If we dynamically blend the discriminator with the address discriminator,
+  // this is compatible.
+  if (auto *DiscBlend = dyn_cast<IntrinsicInst>(Discriminator)) {
+    if (DiscBlend->getIntrinsicID() == Intrinsic::ptrauth_blend &&
+        DiscBlend->getOperand(1) == getDiscriminator() &&
+        areEquivalentAddrDiscriminators(DiscBlend->getOperand(0),
+                                        getAddrDiscriminator(), DL))
+      return true;
+  }
+
+  // If we don't have a non-address discriminator, we don't need a blend in
+  // the first place:  accept the address discriminator as the discriminator.
+  if (getDiscriminator()->isNullValue() &&
+      areEquivalentAddrDiscriminators(getAddrDiscriminator(), Discriminator,
+                                      DL))
+    return true;
+
+  // Otherwise, we don't know.
+  return false;
+}
+
+ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
+  return get(Pointer, getKey(), getAddrDiscriminator(), getDiscriminator());
+}
+
+ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
+                                      Constant *AddrDisc, ConstantInt *Disc) {
+  Constant *ArgVec[] = {Ptr, Key, AddrDisc, Disc};
+  ConstantPtrAuthKeyType MapKey(ArgVec);
+  LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
+  return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
+}
+
+ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
+                                 Constant *AddrDisc, ConstantInt *Disc)
+    : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, &Op<0>(), 4) {
+#ifndef NDEBUG
+  assert(Ptr->getType()->isPointerTy());
+  assert(Key->getBitWidth() == 32);
+  assert(AddrDisc->getType()->isPointerTy());
+  assert(Disc->getBitWidth() == 64);
+#endif
+  setOperand(0, Ptr);
+  setOperand(1, Key);
+  setOperand(2, AddrDisc);
+  setOperand(3, Disc);
+}
+
+/// Remove the constant from the constant table.
+void ConstantPtrAuth::destroyConstantImpl() {
+  getType()->getContext().pImpl->ConstantPtrAuths.remove(this);
+}
+
+Value *ConstantPtrAuth::handleOperandChangeImpl(Value *From, Value *ToV) {
+  assert(isa<Constant>(ToV) && "Cannot make Constant refer to non-constant!");
+  Constant *To = cast<Constant>(ToV);
+
+  SmallVector<Constant *, 8> Values;
+  Values.reserve(getNumOperands()); // Build replacement array.
+
+  // Copy the operands, substituting To for every occurrence of From, and
+  // record how many operands changed and the index of the last one replaced.
+  unsigned NumUpdated = 0;
+
+  Use *OperandList = getOperandList();
+  unsigned OperandNo = 0;
+  for (Use *O = OperandList, *E = OperandList + getNumOperands(); O != E; ++O) {
+    Constant *Val = cast<Constant>(O->get());
+    if (Val == From) {
+      OperandNo = (O - OperandList);
+      Val = To;
+      ++NumUpdated;
+    }
+    Values.push_back(Val);
+  }
+
+  // FIXME: shouldn't we check it's not already there?
+  return getContext().pImpl->ConstantPtrAuths.replaceOperandsInPlace(
+      Values, this, From, To, NumUpdated, OperandNo);
+}
+
 //---- ConstantExpr::get() implementations.
 //
 
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index 44a926b5dc58e..bd111f406687c 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -23,6 +23,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -282,6 +283,7 @@ DEFINE_TRANSPARENT_OPERAND_ACCESSORS(CompareConstantExpr, Value)
 template <class ConstantClass> struct ConstantAggrKeyType;
 struct InlineAsmKeyType;
 struct ConstantExprKeyType;
+struct ConstantPtrAuthKeyType;
 
 template <class ConstantClass> struct ConstantInfo;
 template <> struct ConstantInfo<ConstantExpr> {
@@ -304,6 +306,10 @@ template <> struct ConstantInfo<ConstantVector> {
   using ValType = ConstantAggrKeyType<ConstantVector>;
   using TypeClass = VectorType;
 };
+template <> struct ConstantInfo<ConstantPtrAuth> {
+  using ValType = ConstantPtrAuthKeyType;
+  using TypeClass = Type;
+};
 
 template <class ConstantClass> struct ConstantAggrKeyType {
   ArrayRef<Constant *> Operands;
@@ -511,6 +517,47 @@ struct ConstantExprKeyType {
   }
 };
 
+struct ConstantPtrAuthKeyType {
+  ArrayRef<Constant *> Operands;
+
+  ConstantPtrAuthKeyType(ArrayRef<Constant *> Operands) : Operands(Operands) {}
+
+  ConstantPtrAuthKeyType(ArrayRef<Constant *> Operands, const ConstantPtrAuth *)
+      : Operands(Operands) {}
+
+  ConstantPtrAuthKeyType(const ConstantPtrAuth *C,
+                         SmallVectorImpl<Constant *> &Storage) {
+    assert(Storage.empty() && "Expected empty storage");
+    for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I)
+      Storage.push_back(cast<Constant>(C->getOperand(I)));
+    Operands = Storage;
+  }
+
+  bool operator==(const ConstantPtrAuthKeyType &X) const {
+    return Operands == X.Operands;
+  }
+
+  bool operator==(const ConstantPtrAuth *C) const {
+    if (Operands.size() != C->getNumOperands())
+      return false;
+    for (unsigned I = 0, E = Operands.size(); I != E; ++I)
+      if (Operands[I] != C->getOperand(I))
+        return false;
+    return true;
+  }
+
+  unsigned getHash() const {
+    return hash_combine_range(Operands.begin(), Operands.end());
+  }
+
+  using TypeClass = typename ConstantInfo<ConstantPtrAuth>::TypeClass;
+
+  ConstantPtrAuth *create(TypeClass *Ty) const {
+    return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
+                               Operands[2], cast<ConstantInt>(Operands[3]));
+  }
+};
+
 // Free memory for a given constant.  Assumes the constant has already been
 // removed from all relevant maps.
 void deleteConstant(Constant *C);
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index b1dcb262fb657..bc0f1f2cf8bc3 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1546,6 +1546,8 @@ class LLVMContextImpl {
 
   DenseMap<const GlobalValue *, NoCFIValue *> NoCFIValues;
 
+  ConstantUniqueMap<ConstantPtrAuth> ConstantPtrAuths;
+
   ConstantUniqueMap<ConstantExpr> ExprConstants;
 
   ConstantUniqueMap<InlineAsm> InlineAsms;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 62dde2e6ad424..22977de66207e 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -626,6 +626,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
 
   void visitConstantExprsRecursively(const Constant *EntryC);
   void visitConstantExpr(const ConstantExpr *CE);
+  void visitConstantPtrAuth(const ConstantPtrAuth *SP);
   void verifyInlineAsmCall(const CallBase &Call);
   void verifyStatepoint(const CallBase &Call);
   void verifyFrameRecoverIndices();
@@ -2396,6 +2397,9 @@ void Verifier::visitConstantExprsRecursively(const Constant *EntryC) {
     if (const auto *CE = dyn_cast<ConstantExpr>(C))
       visitConstantExpr(CE);
 
+    if (const auto *SP = dyn_cast<ConstantPtrAuth>(C))
+      visitConstantPtrAuth(SP);
+
     if (const auto *GV = dyn_cast<GlobalValue>(C)) {
       // Global Values get visited separately, but we do need to make sure
       // that the global value is in the correct module
@@ -2423,6 +2427,20 @@ void Verifier::visitConstantExpr(const ConstantExpr *CE) {
           "Invalid bitcast", CE);
 }
 
+void Verifier::visitConstantPtrAuth(const ConstantPtrAuth *SP) {
+  Check(SP->getPointer()->getType()->isPointerTy(),
+        "signed pointer must be a pointer");
+
+  Check(SP->getKey()->getBitWidth() == 32,
+        "signed pointer key must be i32 constant integer");
+
+  Check(SP->getAddrDiscriminator()->getType()->isPointerTy(),
+        "signed pointer address discriminator must be a pointer");
+
+  Check(SP->getDiscriminator()->getBitWidth() == 64,
+        "signed pointer discriminator must be i64 constant integer");
+}
+
 bool Verifier::verifyAttributeCount(AttributeList Attrs, unsigned Params) {
   // There shouldn't be more attribute sets than there are parameters plus the
   // function and return value.
@@ -5039,6 +5057,8 @@ void Verifier::visitInstruction(Instruction &I) {
     } else if (isa<InlineAsm>(I.getOperand(i))) {
       Check(CBI && &CBI->getCalledOperandUse() == &I.getOperandUse(i),
             "Cannot take the address of an inline asm!", &I);
+    } else if (auto SP = dyn_cast<ConstantPtrAuth>(I.getOperand(i))) {
+      visitConstantExprsRecursively(SP);
     } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(I.getOperand(i))) {
       if (CE->getType()->isPtrOrPtrVectorTy()) {
         // If we have a ConstantExpr pointer, we need to see if it came from an
diff --git a/llvm/test/Assembler/invalid-ptrauth-const1.ll b/llvm/test/Assembler/invalid-ptrauth-const1.ll
new file mode 100644
index 0000000000000..dfe0e1c5ed3f0
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const1.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+ at var = global i32 0
+
+; CHECK: error: signed pointer must be a pointer
+ at auth_var = global ptr ptrauth (i32 42, i32 0, ptr null, i64 65535)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const2.ll b/llvm/test/Assembler/invalid-ptrauth-const2.ll
new file mode 100644
index 0000000000000..9042d8dfe7726
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const2.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+ at var = global i32 0
+
+; CHECK: error: signed pointer key must be i32 constant integer
+ at auth_var = global ptr ptrauth (ptr @var, i32 ptrtoint (ptr @var to i32), ptr null, i64 65535)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const3.ll b/llvm/test/Assembler/invalid-ptrauth-const3.ll
new file mode 100644
index 0000000000000..00bcef9db6d1b
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const3.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+ at var = global i32 0
+
+; CHECK: error: signed pointer address discriminator must be a pointer
+ at auth_var = global ptr ptrauth (ptr @var, i32 2, i8 0, i64 65535)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const4.ll b/llvm/test/Assembler/invalid-ptrauth-const4.ll
new file mode 100644
index 0000000000000..00bcef9db6d1b
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const4.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+ at var = global i32 0
+
+; CHECK: error: signed pointer address discriminator must be a pointer
+ at auth_var = global ptr ptrauth (ptr @var, i32 2, i8 0, i64 65535)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const5.ll b/llvm/test/Assembler/invalid-ptrauth-const5.ll
new file mode 100644
index 0000000000000..72910769de262
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const5.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+ at var = global i32 0
+
+; CHECK: error: signed pointer discriminator must be i64 constant integer
+ at auth_var = global ptr ptrauth (ptr @var, i32 2, ptr null, i64 ptrtoint (ptr @var to i64))
diff --git a/llvm/test/Assembler/invalid-ptrauth-const6.ll b/llvm/test/Assembler/invalid-ptrauth-const6.ll
new file mode 100644
index 0000000000000..d79d88ee845ce
--- /dev/null
+++ b/llvm/test/Assembler/invalid-ptrauth-const6.ll
@@ -0,0 +1,6 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+ at var = global i32 0
+
+; CHECK: error: signed pointer address discriminator must be a pointer
+ at auth_var = global ptr ptrauth (ptr @var, i32 2, i8 0, i32 1000000)
diff --git a/llvm/test/Assembler/ptrauth-const.ll b/llvm/test/Assembler/ptrauth-const.ll
new file mode 100644
index 0000000000000..15b71e5ef8c13
--- /dev/null
+++ b/llvm/test/Assembler/ptrauth-const.ll
@@ -0,0 +1,13 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s
+
+@var = global i32 0
+
+; CHECK: @auth_var = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
+@auth_var = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
+
+
+; CHECK: @addrdisc_var = global ptr ptrauth (ptr @var, i32 0, ptr @addrdisc_var, i64 1234)
+@addrdisc_var = global ptr ptrauth (ptr @var, i32 0, ptr @addrdisc_var, i64 1234)
+
+; CHECK: @keyed_var = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
+@keyed_var = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll
index ce6a6571ec144..ff00259192339 100644
--- a/llvm/test/Bitcode/compatibility.ll
+++ b/llvm/test/Bitcode/compatibility.ll
@@ -217,6 +217,10 @@ declare void @g.f1()
 ; CHECK: @g.sanitize_address_dyninit = global i32 0, sanitize_address_dyninit
 ; CHECK: @g.sanitize_multiple = global i32 0, sanitize_memtag, sanitize_address_dyninit
 
+; ptrauth constant
+@auth_var = global ptr ptrauth (ptr @g1, i32 0, ptr null, i64 65535)
+; CHECK: @auth_var = global ptr ptrauth (ptr @g1, i32 0, ptr null, i64 65535)
+
 ;; Aliases
 ; Format: @<Name> = [Linkage] [Visibility] [DLLStorageClass] [ThreadLocal]
 ;                   [unnamed_addr] alias <AliaseeTy> @<Aliasee>
diff --git a/llvm/utils/vim/syntax/llvm.vim b/llvm/utils/vim/syntax/llvm.vim
index d86e3d1ddbc27..905d696400ca3 100644
--- a/llvm/utils/vim/syntax/llvm.vim
+++ b/llvm/utils/vim/syntax/llvm.vim
@@ -150,6 +150,7 @@ syn keyword llvmKeyword
       \ preallocated
       \ private
       \ protected
+      \ ptrauth
       \ ptx_device
       \ ptx_kernel
       \ readnone

>From e47a75a37faf9c42e7ef796686f1a13db4ac070d Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Wed, 8 May 2024 10:25:46 -0700
Subject: [PATCH 2/4] Address review feedback.

- LangRef.rst: missing ':', stale example
- LLParser.cpp: auto *
- Bitcode: remove ptrty/addrdiscty
- AsmWriter.cpp: ListSeparator
- Constants.h: cast<>
- Constants.cpp: stale NDEBUG
- Constants.cpp: stale comments
- Constants.cpp: use getTypeSizeInBits for offset accumulation,
  to allow pointer/integer types without fiddling.
- ConstantFold.cpp: remove FIXME, should be addressed separately
- ValueTracking.cpp: remove from isGuaranteedNotToBeUndefOrPoison
---
 llvm/docs/LangRef.rst                     |  4 +---
 llvm/include/llvm/Bitcode/LLVMBitCodes.h  |  4 +---
 llvm/include/llvm/IR/Constants.h          | 12 ++++++++----
 llvm/lib/Analysis/ValueTracking.cpp       |  3 ---
 llvm/lib/AsmParser/LLParser.cpp           |  4 ++--
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp | 15 +++++----------
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp | 14 ++++++--------
 llvm/lib/IR/AsmWriter.cpp                 |  6 ++----
 llvm/lib/IR/ConstantFold.cpp              |  3 ---
 llvm/lib/IR/Constants.cpp                 | 19 +++++++------------
 10 files changed, 32 insertions(+), 52 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index c29bbda2eb2e5..e1074982ea147 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4750,7 +4750,7 @@ reference to the CFI jump table in the ``LowerTypeTests`` pass. These constants
 may be useful in low-level programs, such as operating system kernels, which
 need to refer to the actual function body.
 
-.. _ptrauth
+.. _ptrauth:
 
 Authenticated Pointers
 ----------------------
@@ -4775,8 +4775,6 @@ If the address discriminator is present, then it is
     %tmp2 = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i64  %tmp1)
     %val = inttoptr i64 %tmp2 to ptr
 
-    %tmp = call i64 @llvm.ptrauth.blend.i64
-
 .. _constantexprs:
 
 Constant Expressions
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 361068df80932..544c76bf4e3c9 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -412,9 +412,7 @@ enum ConstantsCodes {
                                       //                 asmdialect|unwind,
                                       //                 asmstr,conststr]
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
-  CST_CODE_SIGNED_PTR = 32,           // SIGNED_PTR: [ptrty, ptr, key,
-                                      //              addrdiscty, addrdisc,
-                                      //              disc]
+  CST_CODE_SIGNED_PTR = 32,           // [ptr, key, addrdisc, disc]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index c64a7a46bdd0c..15bea35a5b14a 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1034,18 +1034,22 @@ class ConstantPtrAuth final : public Constant {
   DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Constant);
 
   /// The pointer that is authenticated in this authenticated global reference.
-  Constant *getPointer() const { return (Constant *)Op<0>().get(); }
+  Constant *getPointer() const { return cast<Constant>(Op<0>().get()); }
 
   /// The Key ID, an i32 constant.
-  ConstantInt *getKey() const { return (ConstantInt *)Op<1>().get(); }
+  ConstantInt *getKey() const { return cast<ConstantInt>(Op<1>().get()); }
 
   /// The address discriminator if any, or the null constant.
   /// If present, this must be a value equivalent to the storage location of
   /// the only user of the authenticated ptrauth global.
-  Constant *getAddrDiscriminator() const { return (Constant *)Op<2>().get(); }
+  Constant *getAddrDiscriminator() const {
+    return cast<Constant>(Op<2>().get());
+  }
 
   /// The discriminator.
-  ConstantInt *getDiscriminator() const { return (ConstantInt *)Op<3>().get(); }
+  ConstantInt *getDiscriminator() const {
+    return cast<ConstantInt>(Op<3>().get());
+  }
 
   /// Whether there is any non-null address discriminator.
   bool hasAddressDiversity() const {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 7713ea8c4321b..1c850f36878cf 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7226,9 +7226,6 @@ static bool isGuaranteedNotToBeUndefOrPoison(
         isa<ConstantPointerNull>(C) || isa<Function>(C))
       return true;
 
-    if (isa<ConstantPtrAuth>(C))
-      return true;
-
     if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C))
       return (!includesUndef(Kind) ? !C->containsPoisonElement()
                                    : !C->containsUndefOrPoisonElement()) &&
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 8978fed6c4e8d..7c9ccec8cb429 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4069,7 +4069,7 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     if (!Ptr->getType()->isPointerTy())
       return error(ID.Loc, "signed pointer must be a pointer");
 
-    auto KeyC = dyn_cast<ConstantInt>(Key);
+    auto *KeyC = dyn_cast<ConstantInt>(Key);
     if (!KeyC || KeyC->getBitWidth() != 32)
       return error(ID.Loc, "signed pointer key must be i32 constant integer");
 
@@ -4077,7 +4077,7 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
       return error(ID.Loc,
                    "signed pointer address discriminator must be a pointer");
 
-    auto DiscC = dyn_cast<ConstantInt>(Disc);
+    auto *DiscC = dyn_cast<ConstantInt>(Disc);
     if (!DiscC || DiscC->getBitWidth() != 64)
       return error(ID.Loc,
                    "signed pointer discriminator must be i64 constant integer");
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index cb67fba26cae6..eef080041219d 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -3643,18 +3643,13 @@ Error BitcodeReader::parseConstants() {
       break;
     }
     case bitc::CST_CODE_SIGNED_PTR: {
-      if (Record.size() < 6)
-        return error("Invalid record");
-      Type *PtrTy = getTypeByID(Record[0]);
-      if (!PtrTy)
-        return error("Invalid record");
-
-      // PtrTy, Ptr, Key, AddrDiscTy, AddrDisc, Disc
+      if (Record.size() < 4)
+        return error("Invalid ptrauth record");
+      // Ptr, Key, AddrDisc, Disc
       V = BitcodeConstant::create(
         Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
-        {(unsigned)Record[1], (unsigned)Record[2], (unsigned)Record[4],
-         (unsigned)Record[5]});
-
+        {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2],
+         (unsigned)Record[3]});
       break;
     }
     }
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 73c1bba7399e2..71e001f733e56 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2822,14 +2822,6 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Record.push_back(VE.getTypeID(BA->getFunction()->getType()));
       Record.push_back(VE.getValueID(BA->getFunction()));
       Record.push_back(VE.getGlobalBasicBlockID(BA->getBasicBlock()));
-    } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(C)) {
-      Code = bitc::CST_CODE_SIGNED_PTR;
-      Record.push_back(VE.getTypeID(SP->getPointer()->getType()));
-      Record.push_back(VE.getValueID(SP->getPointer()));
-      Record.push_back(VE.getValueID(SP->getKey()));
-      Record.push_back(VE.getTypeID(SP->getAddrDiscriminator()->getType()));
-      Record.push_back(VE.getValueID(SP->getAddrDiscriminator()));
-      Record.push_back(VE.getValueID(SP->getDiscriminator()));
     } else if (const auto *Equiv = dyn_cast<DSOLocalEquivalent>(C)) {
       Code = bitc::CST_CODE_DSO_LOCAL_EQUIVALENT;
       Record.push_back(VE.getTypeID(Equiv->getGlobalValue()->getType()));
@@ -2838,6 +2830,12 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Code = bitc::CST_CODE_NO_CFI_VALUE;
       Record.push_back(VE.getTypeID(NC->getGlobalValue()->getType()));
       Record.push_back(VE.getValueID(NC->getGlobalValue()));
+    } else if (const auto *CPA = dyn_cast<ConstantPtrAuth>(C)) {
+      Code = bitc::CST_CODE_SIGNED_PTR;
+      Record.push_back(VE.getValueID(CPA->getPointer()));
+      Record.push_back(VE.getValueID(CPA->getKey()));
+      Record.push_back(VE.getValueID(CPA->getAddrDiscriminator()));
+      Record.push_back(VE.getValueID(CPA->getDiscriminator()));
     } else {
 #ifndef NDEBUG
       C->dump();
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index c4a2130d27992..09186dbfd36ef 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1592,15 +1592,13 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
 
   if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(CV)) {
     Out << "ptrauth (";
-
+    ListSeparator LS;
     for (unsigned i = 0; i < SP->getNumOperands(); ++i) {
+      Out << LS;
       WriterCtx.TypePrinter->print(SP->getOperand(i)->getType(), Out);
       Out << ' ';
       WriteAsOperandInternal(Out, SP->getOperand(i), WriterCtx);
-      if (i != SP->getNumOperands() - 1)
-        Out << ", ";
     }
-
     Out << ')';
     return;
   }
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 4d1d840067eb4..a766b1fe60182 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -1154,9 +1154,6 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2) {
                                 GV->getType()->getAddressSpace()))
         return ICmpInst::ICMP_UGT;
     }
-  } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(V1)) {
-    // FIXME: ahmedbougacha: implement ptrauth cst comparison
-    return ICmpInst::BAD_ICMP_PREDICATE;
   } else if (auto *CE1 = dyn_cast<ConstantExpr>(V1)) {
     // Ok, the LHS is known to be a constantexpr.  The RHS can be any of a
     // constantexpr, a global, block address, or a simple constant.
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 97a34fa3791a0..11214fc8246ab 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2023,18 +2023,18 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
 
 static bool areEquivalentAddrDiscriminators(const Value *V1, const Value *V2,
                                             const DataLayout &DL) {
-  APInt V1Off(DL.getPointerSizeInBits(), 0);
-  APInt V2Off(DL.getPointerSizeInBits(), 0);
+  APInt Off1(DL.getTypeSizeInBits(V1->getType()), 0);
+  APInt Off2(DL.getTypeSizeInBits(V2->getType()), 0);
 
   if (auto *V1Cast = dyn_cast<PtrToIntOperator>(V1))
     V1 = V1Cast->getPointerOperand();
   if (auto *V2Cast = dyn_cast<PtrToIntOperator>(V2))
     V2 = V2Cast->getPointerOperand();
   auto *V1Base = V1->stripAndAccumulateConstantOffsets(
-      DL, V1Off, /*AllowNonInbounds=*/true);
+      DL, Off1, /*AllowNonInbounds=*/true);
   auto *V2Base = V2->stripAndAccumulateConstantOffsets(
-      DL, V2Off, /*AllowNonInbounds=*/true);
-  return V1Base == V2Base && V1Off == V2Off;
+      DL, Off2, /*AllowNonInbounds=*/true);
+  return V1Base == V2Base && Off1 == Off2;
 }
 
 bool ConstantPtrAuth::isCompatibleWith(const Value *Key,
@@ -2085,12 +2085,10 @@ ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
 ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
                                  Constant *AddrDisc, ConstantInt *Disc)
     : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, &Op<0>(), 4) {
-#ifndef NDEBUG
   assert(Ptr->getType()->isPointerTy());
   assert(Key->getBitWidth() == 32);
   assert(AddrDisc->getType()->isPointerTy());
   assert(Disc->getBitWidth() == 64);
-#endif
   setOperand(0, Ptr);
   setOperand(1, Key);
   setOperand(2, AddrDisc);
@@ -2106,11 +2104,9 @@ Value *ConstantPtrAuth::handleOperandChangeImpl(Value *From, Value *ToV) {
   assert(isa<Constant>(ToV) && "Cannot make Constant refer to non-constant!");
   Constant *To = cast<Constant>(ToV);
 
-  SmallVector<Constant *, 8> Values;
-  Values.reserve(getNumOperands()); // Build replacement array.
+  SmallVector<Constant *, 4> Values;
+  Values.reserve(getNumOperands());
 
-  // Fill values with the modified operands of the constant array.  Also,
-  // compute whether this turns into an all-zeros array.
   unsigned NumUpdated = 0;
 
   Use *OperandList = getOperandList();
@@ -2125,7 +2121,6 @@ Value *ConstantPtrAuth::handleOperandChangeImpl(Value *From, Value *ToV) {
     Values.push_back(Val);
   }
 
-  // FIXME: shouldn't we check it's not already there?
   return getContext().pImpl->ConstantPtrAuths.replaceOperandsInPlace(
       Values, this, From, To, NumUpdated, OperandNo);
 }

>From 48a946ccd8eaffd660fbd18abb8e419ffd3b1da7 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Wed, 8 May 2024 10:25:54 -0700
Subject: [PATCH 3/4] Refine patch a bit.

- signed_ptr/SP -> ptrauth/CPA
- Verifier.cpp: check CPA type == base ptr type
- C API: exclude ConstantPtrAuth until needed
- BitcodeAnalyzer.cpp: print CST_CODE_PTRAUTH
- IR: refine test a bit, add addrspace for base/disc
- docs: refine a bit, remove stale overload
- docs: add to PointerAuth.md
---
 llvm/docs/LangRef.rst                       | 30 ++++++++++++--------
 llvm/docs/PointerAuth.md                    | 22 +++++++++++++++
 llvm/include/llvm-c/Core.h                  |  1 -
 llvm/include/llvm/Bitcode/LLVMBitCodes.h    |  2 +-
 llvm/include/llvm/IR/Value.def              |  2 +-
 llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp |  1 +
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp   |  2 +-
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp   |  2 +-
 llvm/lib/IR/AsmWriter.cpp                   |  8 +++---
 llvm/lib/IR/Verifier.cpp                    | 31 +++++++++++----------
 llvm/test/Assembler/ptrauth-const.ll        | 25 +++++++++++++----
 11 files changed, 85 insertions(+), 41 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e1074982ea147..cd534ac140dc8 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4750,29 +4750,35 @@ reference to the CFI jump table in the ``LowerTypeTests`` pass. These constants
 may be useful in low-level programs, such as operating system kernels, which
 need to refer to the actual function body.
 
-.. _ptrauth:
+.. _ptrauth_constant:
 
-Authenticated Pointers
-----------------------
+Pointer Authentication Constants
+--------------------------------
 
-``ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i16 DISC)
+``ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i64 DISC)``
 
 A '``ptrauth``' constant represents a pointer with a cryptographic
-authentication signature embedded into some bits. Its type is the same as the
-first argument.
+authentication signature embedded into some bits, as described in the
+`Pointer Authentication <PointerAuth.html>`__ document.
+
+A '``ptrauth``' constant is simply a constant equivalent to the
+``llvm.ptrauth.sign`` intrinsic, potentially fed by a discriminator
+``llvm.ptrauth.blend`` if needed.
 
 
 If the address disciminator is ``null`` then the expression is equivalent to
 
-.. code-block:llvm
-    %tmp = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 DISC)
+.. code-block:: llvm
+
+    %tmp = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 DISC)
     %val = inttoptr i64 %tmp to ptr
 
-If the address discriminator is present, then it is
+Otherwise, the expression is equivalent to:
+
+.. code-block:: llvm
 
-.. code-block:llvm
-    %tmp1 = call i64 @llvm.ptrauth.blend.i64(i64 ptrtoint (ptr ADDRDISC to i64), i64 DISC)
-    %tmp2 = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i64  %tmp1)
+    %tmp1 = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr ADDRDISC to i64), i64 DISC)
+    %tmp2 = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 %tmp1)
     %val = inttoptr i64 %tmp2 to ptr
 
 .. _constantexprs:
diff --git a/llvm/docs/PointerAuth.md b/llvm/docs/PointerAuth.md
index a8d2b4d8f5f0b..a52bff2663dce 100644
--- a/llvm/docs/PointerAuth.md
+++ b/llvm/docs/PointerAuth.md
@@ -16,6 +16,7 @@ For more details, see the clang documentation page for
 At the IR level, it is represented using:
 
 * a [set of intrinsics](#intrinsics) (to sign/authenticate pointers)
+* a [signed pointer constant](#constant) (to sign globals)
 * a [call operand bundle](#operand-bundle) (to authenticate called pointers)
 
 The current implementation leverages the
@@ -225,6 +226,27 @@ with a pointer address discriminator, in a way that is specified by the target
 implementation.
 
 
+### Constant
+
+[Intrinsics](#intrinsics) can be used to produce signed pointers dynamically,
+in code, but not for signed pointers referenced by constants, in, e.g., global
+initializers.
+
+The latter are represented using a
+[``ptrauth`` constant](https://llvm.org/docs/LangRef.html#ptrauth-constant),
+which describes an authenticated relocation producing a signed pointer.
+
+```llvm
+  ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i64 DISC)
+```
+
+is equivalent to:
+
+```llvm
+  %disc = call i64 @llvm.ptrauth.blend(i64 ptrtoint(ptr ADDRDISC to i64), i64 DISC)
+  %signedval = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 %disc)
+```
+
 ### Operand Bundle
 
 Function pointers used as indirect call targets can be signed when materialized,
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 97ada1e6cc4f0..ba02ca4825753 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -286,7 +286,6 @@ typedef enum {
   LLVMInstructionValueKind,
   LLVMPoisonValueValueKind,
   LLVMConstantTargetNoneValueKind,
-  LLVMConstantPtrAuthValueKind,
 } LLVMValueKind;
 
 typedef enum {
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 544c76bf4e3c9..66f058240ac16 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -412,7 +412,7 @@ enum ConstantsCodes {
                                       //                 asmdialect|unwind,
                                       //                 asmstr,conststr]
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
-  CST_CODE_SIGNED_PTR = 32,           // [ptr, key, addrdisc, disc]
+  CST_CODE_PTRAUTH = 32,              // [ptr, key, addrdisc, disc]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def
index 31110ff05ae36..3ece66a529e12 100644
--- a/llvm/include/llvm/IR/Value.def
+++ b/llvm/include/llvm/IR/Value.def
@@ -78,10 +78,10 @@ HANDLE_GLOBAL_VALUE(GlobalAlias)
 HANDLE_GLOBAL_VALUE(GlobalIFunc)
 HANDLE_GLOBAL_VALUE(GlobalVariable)
 HANDLE_CONSTANT(BlockAddress)
-HANDLE_CONSTANT(ConstantPtrAuth)
 HANDLE_CONSTANT(ConstantExpr)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(DSOLocalEquivalent)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(NoCFIValue)
+HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(ConstantPtrAuth)
 
 // ConstantAggregate.
 HANDLE_CONSTANT(ConstantArray)
diff --git a/llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp b/llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp
index c085c715179ba..b7ed9cdf63145 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeAnalyzer.cpp
@@ -222,6 +222,7 @@ GetCodeName(unsigned CodeID, unsigned BlockID,
       STRINGIFY_CODE(CST_CODE, CE_UNOP)
       STRINGIFY_CODE(CST_CODE, DSO_LOCAL_EQUIVALENT)
       STRINGIFY_CODE(CST_CODE, NO_CFI_VALUE)
+      STRINGIFY_CODE(CST_CODE, PTRAUTH)
     case bitc::CST_CODE_BLOCKADDRESS:
       return "CST_CODE_BLOCKADDRESS";
       STRINGIFY_CODE(CST_CODE, DATA)
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index eef080041219d..1d3d7343783e9 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -3642,7 +3642,7 @@ Error BitcodeReader::parseConstants() {
                                   Record[1]);
       break;
     }
-    case bitc::CST_CODE_SIGNED_PTR: {
+    case bitc::CST_CODE_PTRAUTH: {
       if (Record.size() < 4)
         return error("Invalid ptrauth record");
       // Ptr, Key, AddrDisc, Disc
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 71e001f733e56..eaa77c3198d3d 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2831,7 +2831,7 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Record.push_back(VE.getTypeID(NC->getGlobalValue()->getType()));
       Record.push_back(VE.getValueID(NC->getGlobalValue()));
     } else if (const auto *CPA = dyn_cast<ConstantPtrAuth>(C)) {
-      Code = bitc::CST_CODE_SIGNED_PTR;
+      Code = bitc::CST_CODE_PTRAUTH;
       Record.push_back(VE.getValueID(CPA->getPointer()));
       Record.push_back(VE.getValueID(CPA->getKey()));
       Record.push_back(VE.getValueID(CPA->getAddrDiscriminator()));
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 09186dbfd36ef..e345bfe1abb04 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1590,14 +1590,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
     return;
   }
 
-  if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(CV)) {
+  if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
     Out << "ptrauth (";
     ListSeparator LS;
-    for (unsigned i = 0; i < SP->getNumOperands(); ++i) {
+    for (auto *Op : CPA->operand_values()) {
       Out << LS;
-      WriterCtx.TypePrinter->print(SP->getOperand(i)->getType(), Out);
+      WriterCtx.TypePrinter->print(Op->getType(), Out);
       Out << ' ';
-      WriteAsOperandInternal(Out, SP->getOperand(i), WriterCtx);
+      WriteAsOperandInternal(Out, Op, WriterCtx);
     }
     Out << ')';
     return;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 7b727ae2a704f..6bb6d2f440437 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -629,7 +629,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
 
   void visitConstantExprsRecursively(const Constant *EntryC);
   void visitConstantExpr(const ConstantExpr *CE);
-  void visitConstantPtrAuth(const ConstantPtrAuth *SP);
+  void visitConstantPtrAuth(const ConstantPtrAuth *CPA);
   void verifyInlineAsmCall(const CallBase &Call);
   void verifyStatepoint(const CallBase &Call);
   void verifyFrameRecoverIndices();
@@ -2423,8 +2423,8 @@ void Verifier::visitConstantExprsRecursively(const Constant *EntryC) {
     if (const auto *CE = dyn_cast<ConstantExpr>(C))
       visitConstantExpr(CE);
 
-    if (const auto *SP = dyn_cast<ConstantPtrAuth>(C))
-      visitConstantPtrAuth(SP);
+    if (const auto *CPA = dyn_cast<ConstantPtrAuth>(C))
+      visitConstantPtrAuth(CPA);
 
     if (const auto *GV = dyn_cast<GlobalValue>(C)) {
       // Global Values get visited separately, but we do need to make sure
@@ -2453,18 +2453,21 @@ void Verifier::visitConstantExpr(const ConstantExpr *CE) {
           "Invalid bitcast", CE);
 }
 
-void Verifier::visitConstantPtrAuth(const ConstantPtrAuth *SP) {
-  Check(SP->getPointer()->getType()->isPointerTy(),
-        "signed pointer must be a pointer");
+void Verifier::visitConstantPtrAuth(const ConstantPtrAuth *CPA) {
+  Check(CPA->getPointer()->getType()->isPointerTy(),
+        "signed ptrauth constant base pointer must have pointer type");
 
-  Check(SP->getKey()->getBitWidth() == 32,
-        "signed pointer key must be i32 constant integer");
+  Check(CPA->getType() == CPA->getPointer()->getType(),
+        "signed ptrauth constant must have same type as its base pointer");
 
-  Check(SP->getAddrDiscriminator()->getType()->isPointerTy(),
-        "signed pointer address discriminator must be a pointer");
+  Check(CPA->getKey()->getBitWidth() == 32,
+        "signed ptrauth constant key must be i32 constant integer");
 
-  Check(SP->getDiscriminator()->getBitWidth() == 64,
-        "signed pointer discriminator must be i64 constant integer");
+  Check(CPA->getAddrDiscriminator()->getType()->isPointerTy(),
+        "signed ptrauth constant address discriminator must be a pointer");
+
+  Check(CPA->getDiscriminator()->getBitWidth() == 64,
+        "signed ptrauth constant discriminator must be i64 constant integer");
 }
 
 bool Verifier::verifyAttributeCount(AttributeList Attrs, unsigned Params) {
@@ -5108,8 +5111,8 @@ void Verifier::visitInstruction(Instruction &I) {
     } else if (isa<InlineAsm>(I.getOperand(i))) {
       Check(CBI && &CBI->getCalledOperandUse() == &I.getOperandUse(i),
             "Cannot take the address of an inline asm!", &I);
-    } else if (auto SP = dyn_cast<ConstantPtrAuth>(I.getOperand(i))) {
-      visitConstantExprsRecursively(SP);
+    } else if (auto *CPA = dyn_cast<ConstantPtrAuth>(I.getOperand(i))) {
+      visitConstantExprsRecursively(CPA);
     } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(I.getOperand(i))) {
       if (CE->getType()->isPtrOrPtrVectorTy()) {
         // If we have a ConstantExpr pointer, we need to see if it came from an
diff --git a/llvm/test/Assembler/ptrauth-const.ll b/llvm/test/Assembler/ptrauth-const.ll
index 15b71e5ef8c13..d74feb352ec09 100644
--- a/llvm/test/Assembler/ptrauth-const.ll
+++ b/llvm/test/Assembler/ptrauth-const.ll
@@ -2,12 +2,25 @@
 
 @var = global i32 0
 
-; CHECK: @auth_var = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
-@auth_var = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
+; CHECK: @basic = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 0)
+@basic = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 0)
 
+; CHECK: @keyed = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
+@keyed = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
 
-; CHECK: @addrdisc_var = global ptr ptrauth (ptr @var, i32 0, ptr @addrdisc_var, i64 1234)
-@addrdisc_var = global ptr ptrauth (ptr @var, i32 0, ptr @addrdisc_var, i64 1234)
+; CHECK: @intdisc = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
+@intdisc = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
 
-; CHECK: @keyed_var = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
-@keyed_var = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
+@addrdisc_storage = global ptr null
+; CHECK: @addrdisc = global ptr ptrauth (ptr @var, i32 2, ptr @addrdisc_storage, i64 1234)
+@addrdisc = global ptr ptrauth (ptr @var, i32 2, ptr @addrdisc_storage, i64 1234)
+
+
+@var1 = addrspace(1) global i32 0
+
+; CHECK: @addrspace = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 0, ptr null, i64 0)
+@addrspace = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 0, ptr null, i64 0)
+
+@addrspace_addrdisc_storage = addrspace(2) global ptr addrspace(1) null
+; CHECK: @addrspace_addrdisc = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 2, ptr addrspace(2) @addrspace_addrdisc_storage, i64 1234)
+@addrspace_addrdisc = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 2, ptr addrspace(2) @addrspace_addrdisc_storage, i64 1234)

>From c6638c7d80cd10a6a73a1ecdbfa6c21fa1825893 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Wed, 8 May 2024 10:26:00 -0700
Subject: [PATCH 4/4] Print disc/addrdisc optionally, and swap them in ops
 list.

We originally matched the llvm.ptrauth struct layout, but this
is nicer because:
- both are commonly 0/null, so we can omit them in the common case
- other than qualifier usage, integer discriminator alone is more
  common than an address discriminator, so omitting the latter alone
  helps also.
- key + integer discriminator looks like the bundle/intrinsics, which
  helps pattern matching when staring at IR.
- the variable-length addrdisc GEP string hides the integer disc.

This doesn't match the llvm.ptrauth struct, but that's relatively
straightforward to macro-update.

This doesn't match the C qualifier either, but there the address
discriminator is only a boolean `1`, so it's hard to mix these up.
---
 llvm/docs/LangRef.rst                         |  5 +-
 llvm/docs/PointerAuth.md                      |  2 +-
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |  2 +-
 llvm/include/llvm/IR/Constants.h              | 18 +++---
 llvm/lib/AsmParser/LLParser.cpp               | 58 +++++++++++--------
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     |  6 +-
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |  2 +-
 llvm/lib/IR/AsmWriter.cpp                     | 14 ++++-
 llvm/lib/IR/Constants.cpp                     | 14 ++---
 llvm/lib/IR/ConstantsContext.h                |  2 +-
 llvm/test/Assembler/invalid-ptrauth-const1.ll |  4 +-
 llvm/test/Assembler/invalid-ptrauth-const2.ll |  4 +-
 llvm/test/Assembler/invalid-ptrauth-const3.ll |  4 +-
 llvm/test/Assembler/invalid-ptrauth-const4.ll |  4 +-
 llvm/test/Assembler/invalid-ptrauth-const5.ll |  4 +-
 llvm/test/Assembler/invalid-ptrauth-const6.ll |  6 --
 llvm/test/Assembler/ptrauth-const.ll          | 24 ++++----
 llvm/test/Bitcode/compatibility.ll            |  4 +-
 18 files changed, 97 insertions(+), 80 deletions(-)
 delete mode 100644 llvm/test/Assembler/invalid-ptrauth-const6.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index cd534ac140dc8..326fad39ca85f 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4755,7 +4755,7 @@ need to refer to the actual function body.
 Pointer Authentication Constants
 --------------------------------
 
-``ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i64 DISC)``
+``ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)``
 
 A '``ptrauth``' constant represents a pointer with a cryptographic
 authentication signature embedded into some bits, as described in the
@@ -4765,6 +4765,9 @@ A '``ptrauth``' constant is simply a constant equivalent to the
 ``llvm.ptrauth.sign`` intrinsic, potentially fed by a discriminator
 ``llvm.ptrauth.blend`` if needed.
 
+Its type is the same as the first argument.  An integer constant discriminator
+and an address discriminator may be optionally specified.  Otherwise, they have
+values ``i64 0`` and ``ptr null``.
 
 If the address disciminator is ``null`` then the expression is equivalent to
 
diff --git a/llvm/docs/PointerAuth.md b/llvm/docs/PointerAuth.md
index a52bff2663dce..91e7d9e519f1b 100644
--- a/llvm/docs/PointerAuth.md
+++ b/llvm/docs/PointerAuth.md
@@ -237,7 +237,7 @@ The latter are represented using a
 which describes an authenticated relocation producing a signed pointer.
 
 ```llvm
-  ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i64 DISC)
+  ptrauth (ptr CST, i32 KEY, i64 DISC, ptr ADDRDISC)
 ```
 
 is equivalent to:
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 66f058240ac16..ac68602803ec8 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -412,7 +412,7 @@ enum ConstantsCodes {
                                       //                 asmdialect|unwind,
                                       //                 asmstr,conststr]
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
-  CST_CODE_PTRAUTH = 32,              // [ptr, key, addrdisc, disc]
+  CST_CODE_PTRAUTH = 32,              // [ptr, key, disc, addrdisc]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 15bea35a5b14a..bf7827f3bd080 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1013,8 +1013,8 @@ class ConstantPtrAuth final : public Constant {
   friend struct ConstantPtrAuthKeyType;
   friend class Constant;
 
-  ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, Constant *AddrDisc,
-                  ConstantInt *Disc);
+  ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc,
+                  Constant *AddrDisc);
 
   void *operator new(size_t s) { return User::operator new(s, 4); }
 
@@ -1024,7 +1024,7 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer authenticated with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              Constant *AddrDisc, ConstantInt *Disc);
+                              ConstantInt *Disc, Constant *AddrDisc);
 
   /// Produce a new ptrauth expression signing the given value using
   /// the same schema as is stored in one.
@@ -1039,16 +1039,16 @@ class ConstantPtrAuth final : public Constant {
   /// The Key ID, an i32 constant.
   ConstantInt *getKey() const { return cast<ConstantInt>(Op<1>().get()); }
 
+  /// The discriminator.
+  ConstantInt *getDiscriminator() const {
+    return cast<ConstantInt>(Op<2>().get());
+  }
+
   /// The address discriminator if any, or the null constant.
   /// If present, this must be a value equivalent to the storage location of
   /// the only user of the authenticated ptrauth global.
   Constant *getAddrDiscriminator() const {
-    return cast<Constant>(Op<2>().get());
-  }
-
-  /// The discriminator.
-  ConstantInt *getDiscriminator() const {
-    return cast<ConstantInt>(Op<3>().get());
+    return cast<Constant>(Op<3>().get());
   }
 
   /// Whether there is any non-null address discriminator.
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 7c9ccec8cb429..2b7a42e57d1bc 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4045,44 +4045,56 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     return false;
   }
   case lltok::kw_ptrauth: {
-    // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key> ','
-    //                         ptr addrdisc ',' i64 <disc> ')'
+    // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
+    //                         (',' i64 <disc> (',' ptr addrdisc)? )? ')'
     Lex.Lex();
 
-    Constant *Ptr, *Key, *AddrDisc, *Disc;
+    Constant *Ptr, *Key;
+    Constant *Disc = nullptr, *AddrDisc = nullptr;
 
     if (parseToken(lltok::lparen,
-                   "expected '(' in signed pointer expression") ||
+                   "expected '(' in constant ptrauth expression") ||
         parseGlobalTypeAndValue(Ptr) ||
         parseToken(lltok::comma,
-                   "expected comma in signed pointer expression") ||
-        parseGlobalTypeAndValue(Key) ||
-        parseToken(lltok::comma,
-                   "expected comma in signed pointer expression") ||
-        parseGlobalTypeAndValue(AddrDisc) ||
-        parseToken(lltok::comma,
-                   "expected comma in signed pointer expression") ||
-        parseGlobalTypeAndValue(Disc) ||
-        parseToken(lltok::rparen, "expected ')' in signed pointer expression"))
+                   "expected comma in constant ptrauth expression") ||
+        parseGlobalTypeAndValue(Key))
+      return true;
+    // If present, parse the optional disc/addrdisc.
+    if (EatIfPresent(lltok::comma))
+      if (parseGlobalTypeAndValue(Disc) ||
+          (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)))
+        return true;
+    if (parseToken(lltok::rparen,
+                   "expected ')' in constant ptrauth expression"))
       return true;
 
     if (!Ptr->getType()->isPointerTy())
-      return error(ID.Loc, "signed pointer must be a pointer");
+      return error(ID.Loc, "constant ptrauth base pointer must be a pointer");
 
     auto *KeyC = dyn_cast<ConstantInt>(Key);
     if (!KeyC || KeyC->getBitWidth() != 32)
-      return error(ID.Loc, "signed pointer key must be i32 constant integer");
+      return error(ID.Loc, "constant ptrauth key must be i32 constant");
 
-    if (!AddrDisc->getType()->isPointerTy())
-      return error(ID.Loc,
-                   "signed pointer address discriminator must be a pointer");
+    ConstantInt *DiscC = nullptr;
+    if (Disc) {
+      DiscC = dyn_cast<ConstantInt>(Disc);
+      if (!DiscC || DiscC->getBitWidth() != 64)
+        return error(
+            ID.Loc,
+            "constant ptrauth integer discriminator must be i64 constant");
+    } else {
+      DiscC = ConstantInt::get(Type::getInt64Ty(Context), 0);
+    }
 
-    auto *DiscC = dyn_cast<ConstantInt>(Disc);
-    if (!DiscC || DiscC->getBitWidth() != 64)
-      return error(ID.Loc,
-                   "signed pointer discriminator must be i64 constant integer");
+    if (AddrDisc) {
+      if (!AddrDisc->getType()->isPointerTy())
+        return error(
+            ID.Loc, "constant ptrauth address discriminator must be a pointer");
+    } else {
+      AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
+    }
 
-    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, AddrDisc, DiscC);
+    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc);
     ID.Kind = ValID::t_Constant;
     return false;
   }
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 1d3d7343783e9..0cd3ce4bd0b1b 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1556,11 +1556,11 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
           if (!Key)
             return error("ptrauth key operand must be ConstantInt");
 
-          auto *Disc = dyn_cast<ConstantInt>(ConstOps[3]);
+          auto *Disc = dyn_cast<ConstantInt>(ConstOps[2]);
           if (!Disc)
             return error("ptrauth disc operand must be ConstantInt");
 
-          C = ConstantPtrAuth::get(ConstOps[0], Key, ConstOps[2], Disc);
+          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
           break;
         }
         case BitcodeConstant::NoCFIOpcode: {
@@ -3645,7 +3645,7 @@ Error BitcodeReader::parseConstants() {
     case bitc::CST_CODE_PTRAUTH: {
       if (Record.size() < 4)
         return error("Invalid ptrauth record");
-      // Ptr, Key, AddrDisc, Disc
+      // Ptr, Key, Disc, AddrDisc
       V = BitcodeConstant::create(
         Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
         {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2],
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index eaa77c3198d3d..0ae5b422b9fe7 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2834,8 +2834,8 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Code = bitc::CST_CODE_PTRAUTH;
       Record.push_back(VE.getValueID(CPA->getPointer()));
       Record.push_back(VE.getValueID(CPA->getKey()));
-      Record.push_back(VE.getValueID(CPA->getAddrDiscriminator()));
       Record.push_back(VE.getValueID(CPA->getDiscriminator()));
+      Record.push_back(VE.getValueID(CPA->getAddrDiscriminator()));
     } else {
 #ifndef NDEBUG
       C->dump();
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index e345bfe1abb04..c80c7d2688276 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1592,12 +1592,20 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
 
   if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
     Out << "ptrauth (";
+
+    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)
+    unsigned NumOpsToWrite = 2;
+    if (!CPA->getOperand(2)->isNullValue())
+      NumOpsToWrite = 3;
+    if (!CPA->getOperand(3)->isNullValue())
+      NumOpsToWrite = 4;
+
     ListSeparator LS;
-    for (auto *Op : CPA->operand_values()) {
+    for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) {
       Out << LS;
-      WriterCtx.TypePrinter->print(Op->getType(), Out);
+      WriterCtx.TypePrinter->print(CPA->getOperand(i)->getType(), Out);
       Out << ' ';
-      WriteAsOperandInternal(Out, Op, WriterCtx);
+      WriteAsOperandInternal(Out, CPA->getOperand(i), WriterCtx);
     }
     Out << ')';
     return;
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 11214fc8246ab..e58675ba81a8f 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2071,28 +2071,28 @@ bool ConstantPtrAuth::isCompatibleWith(const Value *Key,
 }
 
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
-  return get(Pointer, getKey(), getAddrDiscriminator(), getDiscriminator());
+  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator());
 }
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      Constant *AddrDisc, ConstantInt *Disc) {
-  Constant *ArgVec[] = {Ptr, Key, AddrDisc, Disc};
+                                      ConstantInt *Disc, Constant *AddrDisc) {
+  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc};
   ConstantPtrAuthKeyType MapKey(ArgVec);
   LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
   return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
 }
 
 ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
-                                 Constant *AddrDisc, ConstantInt *Disc)
+                                 ConstantInt *Disc, Constant *AddrDisc)
     : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, &Op<0>(), 4) {
   assert(Ptr->getType()->isPointerTy());
   assert(Key->getBitWidth() == 32);
-  assert(AddrDisc->getType()->isPointerTy());
   assert(Disc->getBitWidth() == 64);
+  assert(AddrDisc->getType()->isPointerTy());
   setOperand(0, Ptr);
   setOperand(1, Key);
-  setOperand(2, AddrDisc);
-  setOperand(3, Disc);
+  setOperand(2, Disc);
+  setOperand(3, AddrDisc);
 }
 
 /// Remove the constant from the constant table.
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index 09b80b0a0953d..5153880b5cab6 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -579,7 +579,7 @@ struct ConstantPtrAuthKeyType {
 
   ConstantPtrAuth *create(TypeClass *Ty) const {
     return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
-                               Operands[2], cast<ConstantInt>(Operands[3]));
+                               cast<ConstantInt>(Operands[2]), Operands[3]);
   }
 };
 
diff --git a/llvm/test/Assembler/invalid-ptrauth-const1.ll b/llvm/test/Assembler/invalid-ptrauth-const1.ll
index dfe0e1c5ed3f0..fba2e23078238 100644
--- a/llvm/test/Assembler/invalid-ptrauth-const1.ll
+++ b/llvm/test/Assembler/invalid-ptrauth-const1.ll
@@ -2,5 +2,5 @@
 
 @var = global i32 0
 
-; CHECK: error: signed pointer must be a pointer
-@auth_var = global ptr ptrauth (i32 42, i32 0, ptr null, i64 65535)
+; CHECK: error: constant ptrauth base pointer must be a pointer
+@auth_var = global ptr ptrauth (i32 42, i32 0)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const2.ll b/llvm/test/Assembler/invalid-ptrauth-const2.ll
index 9042d8dfe7726..4499c42601c99 100644
--- a/llvm/test/Assembler/invalid-ptrauth-const2.ll
+++ b/llvm/test/Assembler/invalid-ptrauth-const2.ll
@@ -2,5 +2,5 @@
 
 @var = global i32 0
 
-; CHECK: error: signed pointer key must be i32 constant integer
-@auth_var = global ptr ptrauth (ptr @var, i32 ptrtoint (ptr @var to i32), ptr null, i64 65535)
+; CHECK: error: constant ptrauth key must be i32 constant
+@auth_var = global ptr ptrauth (ptr @var, i32 ptrtoint (ptr @var to i32))
diff --git a/llvm/test/Assembler/invalid-ptrauth-const3.ll b/llvm/test/Assembler/invalid-ptrauth-const3.ll
index 00bcef9db6d1b..3f2688d92a001 100644
--- a/llvm/test/Assembler/invalid-ptrauth-const3.ll
+++ b/llvm/test/Assembler/invalid-ptrauth-const3.ll
@@ -2,5 +2,5 @@
 
 @var = global i32 0
 
-; CHECK: error: signed pointer address discriminator must be a pointer
-@auth_var = global ptr ptrauth (ptr @var, i32 2, i8 0, i64 65535)
+; CHECK: error: constant ptrauth address discriminator must be a pointer
+@auth_var = global ptr ptrauth (ptr @var, i32 2, i64 65535, i8 0)
diff --git a/llvm/test/Assembler/invalid-ptrauth-const4.ll b/llvm/test/Assembler/invalid-ptrauth-const4.ll
index 00bcef9db6d1b..843a220458a61 100644
--- a/llvm/test/Assembler/invalid-ptrauth-const4.ll
+++ b/llvm/test/Assembler/invalid-ptrauth-const4.ll
@@ -2,5 +2,5 @@
 
 @var = global i32 0
 
-; CHECK: error: signed pointer address discriminator must be a pointer
-@auth_var = global ptr ptrauth (ptr @var, i32 2, i8 0, i64 65535)
+; CHECK: error: constant ptrauth integer discriminator must be i64 constant
+@auth_var = global ptr ptrauth (ptr @var, i32 2, ptr null, i64 ptrtoint (ptr @var to i64))
diff --git a/llvm/test/Assembler/invalid-ptrauth-const5.ll b/llvm/test/Assembler/invalid-ptrauth-const5.ll
index 72910769de262..9b47f6f5f423f 100644
--- a/llvm/test/Assembler/invalid-ptrauth-const5.ll
+++ b/llvm/test/Assembler/invalid-ptrauth-const5.ll
@@ -2,5 +2,5 @@
 
 @var = global i32 0
 
-; CHECK: error: signed pointer discriminator must be i64 constant integer
-@auth_var = global ptr ptrauth (ptr @var, i32 2, ptr null, i64 ptrtoint (ptr @var to i64))
+; CHECK: error: constant ptrauth integer discriminator must be i64 constant
+@auth_var = global ptr ptrauth (ptr @var, i32 2, ptr @var))
diff --git a/llvm/test/Assembler/invalid-ptrauth-const6.ll b/llvm/test/Assembler/invalid-ptrauth-const6.ll
deleted file mode 100644
index d79d88ee845ce..0000000000000
--- a/llvm/test/Assembler/invalid-ptrauth-const6.ll
+++ /dev/null
@@ -1,6 +0,0 @@
-; RUN: not llvm-as < %s 2>&1 | FileCheck %s
-
-@var = global i32 0
-
-; CHECK: error: signed pointer address discriminator must be a pointer
-@auth_var = global ptr ptrauth (ptr @var, i32 2, i8 0, i32 1000000)
diff --git a/llvm/test/Assembler/ptrauth-const.ll b/llvm/test/Assembler/ptrauth-const.ll
index d74feb352ec09..fdfece58dd863 100644
--- a/llvm/test/Assembler/ptrauth-const.ll
+++ b/llvm/test/Assembler/ptrauth-const.ll
@@ -2,25 +2,25 @@
 
 @var = global i32 0
 
-; CHECK: @basic = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 0)
-@basic = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 0)
+; CHECK: @basic = global ptr ptrauth (ptr @var, i32 0)
+@basic = global ptr ptrauth (ptr @var, i32 0)
 
-; CHECK: @keyed = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
-@keyed = global ptr ptrauth (ptr @var, i32 3, ptr null, i64 0)
+; CHECK: @keyed = global ptr ptrauth (ptr @var, i32 3)
+@keyed = global ptr ptrauth (ptr @var, i32 3)
 
-; CHECK: @intdisc = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
-@intdisc = global ptr ptrauth (ptr @var, i32 0, ptr null, i64 -1)
+; CHECK: @intdisc = global ptr ptrauth (ptr @var, i32 0, i64 -1)
+@intdisc = global ptr ptrauth (ptr @var, i32 0, i64 -1)
 
 @addrdisc_storage = global ptr null
-; CHECK: @addrdisc = global ptr ptrauth (ptr @var, i32 2, ptr @addrdisc_storage, i64 1234)
-@addrdisc = global ptr ptrauth (ptr @var, i32 2, ptr @addrdisc_storage, i64 1234)
+; CHECK: @addrdisc = global ptr ptrauth (ptr @var, i32 2, i64 1234, ptr @addrdisc_storage)
+@addrdisc = global ptr ptrauth (ptr @var, i32 2, i64 1234, ptr @addrdisc_storage)
 
 
 @var1 = addrspace(1) global i32 0
 
-; CHECK: @addrspace = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 0, ptr null, i64 0)
-@addrspace = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 0, ptr null, i64 0)
+; CHECK: @addrspace = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 0)
+@addrspace = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 0)
 
 @addrspace_addrdisc_storage = addrspace(2) global ptr addrspace(1) null
-; CHECK: @addrspace_addrdisc = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 2, ptr addrspace(2) @addrspace_addrdisc_storage, i64 1234)
-@addrspace_addrdisc = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 2, ptr addrspace(2) @addrspace_addrdisc_storage, i64 1234)
+; CHECK: @addrspace_addrdisc = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 2, i64 1234, ptr addrspace(2) @addrspace_addrdisc_storage)
+@addrspace_addrdisc = global ptr addrspace(1) ptrauth (ptr addrspace(1) @var1, i32 2, i64 1234, ptr addrspace(2) @addrspace_addrdisc_storage)
diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll
index 1400956737e1e..2a846e036924c 100644
--- a/llvm/test/Bitcode/compatibility.ll
+++ b/llvm/test/Bitcode/compatibility.ll
@@ -218,8 +218,8 @@ declare void @g.f1()
 ; CHECK: @g.sanitize_multiple = global i32 0, sanitize_memtag, sanitize_address_dyninit
 
 ; ptrauth constant
-@auth_var = global ptr ptrauth (ptr @g1, i32 0, ptr null, i64 65535)
-; CHECK: @auth_var = global ptr ptrauth (ptr @g1, i32 0, ptr null, i64 65535)
+@auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null)
+; CHECK: @auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535)
 
 ;; Aliases
 ; Format: @<Name> = [Linkage] [Visibility] [DLLStorageClass] [ThreadLocal]



More information about the llvm-commits mailing list