[llvm] [InstCombine] Combine ptrauth intrin. callee into same-key bundle. (PR #94707)

Ahmed Bougacha via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 6 17:29:41 PDT 2024


https://github.com/ahmedbougacha created https://github.com/llvm/llvm-project/pull/94707

Try to optimize a call to the result of a ptrauth intrinsic, potentially into the ptrauth call bundle:
```
  call(ptrauth.resign(p)), ["ptrauth"()] ->  call p, ["ptrauth"()]
  call(ptrauth.sign(p)),   ["ptrauth"()] ->  call p
```

as long as the key and discriminator are the same in the sign and in the auth bundle, and we don't change the key in the bundle (to a potentially invalid key).

Generating a plain call to a raw unauthenticated pointer is generally undesirable, but if we ended up seeing a naked ptrauth.sign in the first place, we already have suspicious code.  Unauthenticated calls are also easier to spot than naked signs, so let the indirect call shine.

>From 903f1497bad057039e4580e7fe33787e31f217d7 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Mon, 31 Jul 2023 07:22:40 -0700
Subject: [PATCH] [InstCombine] Combine ptrauth intrin. callee into same-key
 bundle.

Try to optimize a call to the result of a ptrauth intrinsic, potentially
into the ptrauth call bundle:
- call(ptrauth.resign(p)), ["ptrauth"()] ->  call p, ["ptrauth"()]
- call(ptrauth.sign(p)),   ["ptrauth"()] ->  call p
as long as the key/discriminator are the same in sign and auth-bundle,
and we don't change the key in the bundle (to a potentially-invalid
key.)

Generating a plain call to a raw unauthenticated pointer is generally
undesirable, but if we ended up seeing a naked ptrauth.sign in the first
place, we already have suspicious code.  Unauthenticated calls are also
easier to spot than naked signs, so let the indirect call shine.
---
 .../InstCombine/InstCombineCalls.cpp          | 76 +++++++++++++++++++
 .../InstCombine/InstCombineInternal.h         |  8 ++
 .../InstCombine/ptrauth-intrinsics-call.ll    | 76 +++++++++++++++++++
 3 files changed, 160 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/ptrauth-intrinsics-call.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 0632f3cfc6dd2..8dd2ecadf5135 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3647,6 +3647,78 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) {
   return nullptr;
 }
 
+Instruction *InstCombinerImpl::foldPtrAuthIntrinsicCallee(CallBase &Call) {
+  Value *Callee = Call.getCalledOperand();
+  auto *IPC = dyn_cast<IntToPtrInst>(Callee);
+  if (!IPC || !IPC->isNoopCast(DL))
+    return nullptr;
+
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(IPC->getOperand(0));
+  if (!II)
+    return nullptr;
+
+  // Split off the "ptrauth" bundle; all other bundles carry over unchanged.
+  std::optional<OperandBundleUse> PtrAuthBundleOrNone;
+  SmallVector<OperandBundleDef, 2> NewBundles;
+  for (unsigned BI = 0, BE = Call.getNumOperandBundles(); BI != BE; ++BI) {
+    OperandBundleUse Bundle = Call.getOperandBundleAt(BI);
+    if (Bundle.getTagID() == LLVMContext::OB_ptrauth)
+      PtrAuthBundleOrNone = Bundle;
+    else
+      NewBundles.emplace_back(Bundle);
+  }
+
+  Value *NewCallee = nullptr;
+  switch (II->getIntrinsicID()) {
+  default:
+    return nullptr;
+
+  // call(ptrauth.resign(p)), ["ptrauth"()] ->  call p, ["ptrauth"()]
+  // assuming the call bundle matches the resign's new key/discriminator.
+  case Intrinsic::ptrauth_resign: {
+    if (!PtrAuthBundleOrNone ||
+        II->getOperand(3) != PtrAuthBundleOrNone->Inputs[0] ||
+        II->getOperand(4) != PtrAuthBundleOrNone->Inputs[1])
+      return nullptr;
+
+    // Don't change the call's key; we don't know which keys are valid.
+    if (II->getOperand(1) != PtrAuthBundleOrNone->Inputs[0])
+      return nullptr;
+
+    Value *NewBundleOps[] = {II->getOperand(1), II->getOperand(2)};
+    NewBundles.emplace_back("ptrauth", NewBundleOps);
+    NewCallee = II->getOperand(0);
+    break;
+  }
+
+  // call(ptrauth.sign(p)), ["ptrauth"()] ->  call p
+  // assuming the call bundle matches the sign's key/discriminator.
+  // Non-ptrauth indirect calls are undesirable, but so is ptrauth.sign.
+  case Intrinsic::ptrauth_sign: {
+    if (!PtrAuthBundleOrNone ||
+        II->getOperand(1) != PtrAuthBundleOrNone->Inputs[0] ||
+        II->getOperand(2) != PtrAuthBundleOrNone->Inputs[1])
+      return nullptr;
+    NewCallee = II->getOperand(0);
+    break;
+  }
+  }
+
+  if (!NewCallee)
+    return nullptr;
+
+  NewCallee = Builder.CreateBitOrPointerCast(NewCallee, Callee->getType());
+  CallBase *NewCall = nullptr;
+  if (auto *CI = dyn_cast<CallInst>(&Call)) {
+    NewCall = CallInst::Create(CI, NewBundles);
+  } else {
+    auto *IKI = cast<InvokeInst>(&Call);
+    NewCall = InvokeInst::Create(IKI, NewBundles);
+  }
+  NewCall->setCalledOperand(NewCallee);
+  return NewCall;
+}
+
 bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call,
                                             const TargetLibraryInfo *TLI) {
   // Note: We only handle cases which can't be driven from generic attributes
@@ -3794,6 +3866,10 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
   if (IntrinsicInst *II = findInitTrampoline(Callee))
     return transformCallThroughTrampoline(Call, *II);
 
+  // Combine calls involving pointer authentication intrinsics.
+  if (Instruction *NewCall = foldPtrAuthIntrinsicCallee(Call))
+    return NewCall;
+
   if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) {
     InlineAsm *IA = cast<InlineAsm>(Callee);
     if (!IA->canThrow()) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 984f02bcccad7..f290f446b424e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -282,6 +282,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *transformCallThroughTrampoline(CallBase &Call,
                                               IntrinsicInst &Tramp);
 
+  /// Try to optimize a call to the result of a ptrauth intrinsic, potentially
+  /// into the ptrauth call bundle:
+  /// - call(ptrauth.resign(p)), ["ptrauth"()] ->  call p, ["ptrauth"()]
+  /// - call(ptrauth.sign(p)),   ["ptrauth"()] ->  call p
+  /// as long as the key/discriminator are the same in sign and auth-bundle,
+  /// and we don't change the key in the bundle (to a potentially-invalid key.)
+  Instruction *foldPtrAuthIntrinsicCallee(CallBase &Call);
+
   // Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
   // Otherwise, return std::nullopt
   // Currently it matches:
diff --git a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics-call.ll b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics-call.ll
new file mode 100644
index 0000000000000..900fd43f01602
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics-call.ll
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i32 @test_ptrauth_call_resign(ptr %p) {
+; CHECK-LABEL: @test_ptrauth_call_resign(
+; CHECK-NEXT:    [[V3:%.*]] = call i32 [[P:%.*]]() [ "ptrauth"(i32 1, i64 1234) ]
+; CHECK-NEXT:    ret i32 [[V3]]
+;
+  %v0 = ptrtoint ptr %p to i64
+  %v1 = call i64 @llvm.ptrauth.resign(i64 %v0, i32 1, i64 1234, i32 1, i64 5678)
+  %v2 = inttoptr i64 %v1 to i32()*
+  %v3 = call i32 %v2() [ "ptrauth"(i32 1, i64 5678) ]
+  ret i32 %v3
+}
+
+define i32 @test_ptrauth_call_resign_blend(ptr %pp) {
+; CHECK-LABEL: @test_ptrauth_call_resign_blend(
+; CHECK-NEXT:    [[V01:%.*]] = load ptr, ptr [[PP:%.*]], align 8
+; CHECK-NEXT:    [[V6:%.*]] = call i32 [[V01]]() [ "ptrauth"(i32 1, i64 1234) ]
+; CHECK-NEXT:    ret i32 [[V6]]
+;
+  %v0 = load ptr, ptr %pp, align 8
+  %v1 = ptrtoint ptr %pp to i64
+  %v2 = ptrtoint ptr %v0 to i64
+  %v3 = call i64 @llvm.ptrauth.blend(i64 %v1, i64 5678)
+  %v4 = call i64 @llvm.ptrauth.resign(i64 %v2, i32 1, i64 1234, i32 1, i64 %v3)
+  %v5 = inttoptr i64 %v4 to i32()*
+  %v6 = call i32 %v5() [ "ptrauth"(i32 1, i64 %v3) ]
+  ret i32 %v6
+}
+
+define i32 @test_ptrauth_call_resign_blend_2(ptr %pp) {
+; CHECK-LABEL: @test_ptrauth_call_resign_blend_2(
+; CHECK-NEXT:    [[V01:%.*]] = load ptr, ptr [[PP:%.*]], align 8
+; CHECK-NEXT:    [[V1:%.*]] = ptrtoint ptr [[PP]] to i64
+; CHECK-NEXT:    [[V3:%.*]] = call i64 @llvm.ptrauth.blend(i64 [[V1]], i64 5678)
+; CHECK-NEXT:    [[V6:%.*]] = call i32 [[V01]]() [ "ptrauth"(i32 0, i64 [[V3]]) ]
+; CHECK-NEXT:    ret i32 [[V6]]
+;
+  %v0 = load ptr, ptr %pp, align 8
+  %v1 = ptrtoint ptr %pp to i64
+  %v2 = ptrtoint ptr %v0 to i64
+  %v3 = call i64 @llvm.ptrauth.blend(i64 %v1, i64 5678)
+  %v4 = call i64 @llvm.ptrauth.resign(i64 %v2, i32 0, i64 %v3, i32 0, i64 1234)
+  %v5 = inttoptr i64 %v4 to i32()*
+  %v6 = call i32 %v5() [ "ptrauth"(i32 0, i64 1234) ]
+  ret i32 %v6
+}
+
+define i32 @test_ptrauth_call_sign(ptr %p) {
+; CHECK-LABEL: @test_ptrauth_call_sign(
+; CHECK-NEXT:    [[V3:%.*]] = call i32 [[P:%.*]]()
+; CHECK-NEXT:    ret i32 [[V3]]
+;
+  %v0 = ptrtoint ptr %p to i64
+  %v1 = call i64 @llvm.ptrauth.sign(i64 %v0, i32 2, i64 5678)
+  %v2 = inttoptr i64 %v1 to i32()*
+  %v3 = call i32 %v2() [ "ptrauth"(i32 2, i64 5678) ]
+  ret i32 %v3
+}
+
+define i32 @test_ptrauth_call_sign_otherbundle(ptr %p) {
+; CHECK-LABEL: @test_ptrauth_call_sign_otherbundle(
+; CHECK-NEXT:    [[V3:%.*]] = call i32 [[P:%.*]]() [ "somebundle"(ptr null), "otherbundle"(i64 0) ]
+; CHECK-NEXT:    ret i32 [[V3]]
+;
+  %v0 = ptrtoint ptr %p to i64
+  %v1 = call i64 @llvm.ptrauth.sign(i64 %v0, i32 2, i64 5678)
+  %v2 = inttoptr i64 %v1 to i32()*
+  %v3 = call i32 %v2() [ "somebundle"(ptr null), "ptrauth"(i32 2, i64 5678), "otherbundle"(i64 0) ]
+  ret i32 %v3
+}
+
+declare i64 @llvm.ptrauth.sign(i64, i32, i64)
+declare i64 @llvm.ptrauth.resign(i64, i32, i64, i32, i64)
+declare i64 @llvm.ptrauth.blend(i64, i64)



More information about the llvm-commits mailing list