[llvm] [InstCombine] Combine ptrauth constant callee into bundle. (PR #94706)

Ahmed Bougacha via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 13 13:42:03 PDT 2024


https://github.com/ahmedbougacha updated https://github.com/llvm/llvm-project/pull/94706

>From de79c48651bbdafe6559ccc759d4d2b5df3ce212 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Mon, 27 Sep 2021 08:00:00 -0700
Subject: [PATCH 1/2] [InstCombine] Combine ptrauth constant callee into
 bundle.

Try to turn a call through a ptrauth constant, authenticated by the call's
ptrauth bundle, into a direct call:
  call(ptrauth(f)), ["ptrauth"()] ->  call f
as long as the key and discriminator are the same in the constant and the bundle.
---
 .../InstCombine/InstCombineCalls.cpp          | 32 +++++++
 .../InstCombine/InstCombineInternal.h         |  5 ++
 .../Transforms/InstCombine/ptrauth-call.ll    | 89 +++++++++++++++++++
 3 files changed, 126 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/ptrauth-call.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 436cdbff75669..64f3038d94f94 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3665,6 +3665,34 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) {
   return nullptr;
 }
 
+Instruction *InstCombinerImpl::foldPtrAuthConstantCallee(CallBase &Call) {
+  auto *CPA = dyn_cast<ConstantPtrAuth>(Call.getCalledOperand());
+  if (!CPA)
+    return nullptr;
+
+  auto *CalleeF = dyn_cast<Function>(CPA->getPointer()->stripPointerCasts());
+  // If the ptrauth constant isn't based on a function pointer, bail out.
+  if (!CalleeF)
+    return nullptr;
+
+  // Compare the call's ptrauth bundle (key, discriminator) to the constant.
+  auto PAB = Call.getOperandBundle(LLVMContext::OB_ptrauth);
+  if (!PAB)
+    return nullptr;
+
+  auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
+  Value *Discriminator = PAB->Inputs[1];
+
+  // Only fold if the bundle is known to match; otherwise auth likely fails.
+  if (!CPA->isKnownCompatibleWith(Key, Discriminator, DL))
+    return nullptr;
+
+  // Match: drop the ptrauth bundle and call the underlying function directly.
+  auto *NewCall = CallBase::removeOperandBundle(&Call, LLVMContext::OB_ptrauth);
+  NewCall->setCalledOperand(CalleeF);
+  return NewCall;
+}
+
 bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call,
                                             const TargetLibraryInfo *TLI) {
   // Note: We only handle cases which can't be driven from generic attributes
@@ -3812,6 +3840,10 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
   if (IntrinsicInst *II = findInitTrampoline(Callee))
     return transformCallThroughTrampoline(Call, *II);
 
+  // Combine calls to ptrauth constants.
+  if (Instruction *NewCall = foldPtrAuthConstantCallee(Call))
+    return NewCall;
+
   if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) {
     InlineAsm *IA = cast<InlineAsm>(Callee);
     if (!IA->canThrow()) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 984f02bcccad7..9268cbe594d90 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -282,6 +282,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *transformCallThroughTrampoline(CallBase &Call,
                                               IntrinsicInst &Tramp);
 
+  /// Try to fold a call through a ptrauth constant into a direct call:
+  ///   call(ptrauth(f)), ["ptrauth"()] ->  call f
+  /// when constant and bundle agree on key/discriminator; else nullptr.
+  Instruction *foldPtrAuthConstantCallee(CallBase &Call);
+
   // Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
   // Otherwise, return std::nullopt
   // Currently it matches:
diff --git a/llvm/test/Transforms/InstCombine/ptrauth-call.ll b/llvm/test/Transforms/InstCombine/ptrauth-call.ll
new file mode 100644
index 0000000000000..b4363b528d4e2
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/ptrauth-call.ll
@@ -0,0 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+declare i64 @f(i32)
+declare ptr @f2(i32)
+
+define i32 @test_ptrauth_call(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call(
+; CHECK-NEXT:    [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
+; CHECK-NEXT:    ret i32 [[V0]]
+;
+  %v0 = call i32 ptrauth(ptr @f, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
+  ret i32 %v0
+}
+
+define i32 @test_ptrauth_call_disc(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_disc(
+; CHECK-NEXT:    [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
+; CHECK-NEXT:    ret i32 [[V0]]
+;
+  %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 5678) ]
+  ret i32 %v0
+}
+
+@f_addr_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)
+
+define i32 @test_ptrauth_call_addr_disc(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_addr_disc(
+; CHECK-NEXT:    [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
+; CHECK-NEXT:    ret i32 [[V0]]
+;
+  %v0 = call i32 ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 ptrtoint (ptr @f_addr_disc.ref to i64)) ]
+  ret i32 %v0
+}
+
+@f_both_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)
+
+define i32 @test_ptrauth_call_blend(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_blend(
+; CHECK-NEXT:    [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
+; CHECK-NEXT:    ret i32 [[V0]]
+;
+  %v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 1234)
+  %v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
+  ret i32 %v0
+}
+
+define i64 @test_ptrauth_call_cast(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_cast(
+; CHECK-NEXT:    [[V0:%.*]] = call ptr @f2(i32 [[A0:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[V0]] to i64
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+  %v0 = call i64 ptrauth(ptr @f2, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
+  ret i64 %v0
+}
+
+define i32 @test_ptrauth_call_mismatch_key(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_mismatch_key(
+; CHECK-NEXT:    [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 0, i64 5678) ]
+; CHECK-NEXT:    ret i32 [[V0]]
+;
+  %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 0, i64 5678) ]
+  ret i32 %v0
+}
+
+define i32 @test_ptrauth_call_mismatch_disc(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_mismatch_disc(
+; CHECK-NEXT:    [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 0) ]
+; CHECK-NEXT:    ret i32 [[V0]]
+;
+  %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 0) ]
+  ret i32 %v0
+}
+
+define i32 @test_ptrauth_call_mismatch_blend(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_mismatch_blend(
+; CHECK-NEXT:    [[V:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
+; CHECK-NEXT:    [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 [[V]]) ]
+; CHECK-NEXT:    ret i32 [[V0]]
+;
+  %v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
+  %v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
+  ret i32 %v0
+}
+
+declare i64 @llvm.ptrauth.blend(i64, i64)

>From 1d00ba7d78a0390f3a1eca4fe49de03ebe9d5333 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Thu, 13 Jun 2024 13:26:55 -0700
Subject: [PATCH 2/2] Remove now-unnecessary stripPointerCasts.

---
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 64f3038d94f94..069f638cd0e45 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3670,7 +3670,7 @@ Instruction *InstCombinerImpl::foldPtrAuthConstantCallee(CallBase &Call) {
   if (!CPA)
     return nullptr;
 
-  auto *CalleeF = dyn_cast<Function>(CPA->getPointer()->stripPointerCasts());
+  auto *CalleeF = dyn_cast<Function>(CPA->getPointer());
   // If the ptrauth constant isn't based on a function pointer, bail out.
   if (!CalleeF)
     return nullptr;



More information about the llvm-commits mailing list