[llvm] [InstCombine] Combine ptrauth constant callee into bundle. (PR #94706)
Ahmed Bougacha via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 6 17:28:59 PDT 2024
https://github.com/ahmedbougacha created https://github.com/llvm/llvm-project/pull/94706
Try to turn a call through a ptrauth constant, authenticated by a matching ptrauth bundle, into a direct call:
call(ptrauth(f)), ["ptrauth"()] -> call f
as long as the key/discriminator are the same in the constant and the bundle.
From 2b1879025d45a58cb089b9d6d5857cb20842cb24 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Mon, 27 Sep 2021 08:00:00 -0700
Subject: [PATCH] [InstCombine] Combine ptrauth constant callee into bundle.
Try to turn a call through a ptrauth constant, authenticated by a matching ptrauth bundle, into a direct call:
call(ptrauth(f)), ["ptrauth"()] -> call f
as long as the key/discriminator are the same in the constant and the bundle.
---
.../InstCombine/InstCombineCalls.cpp | 32 +++++++
.../InstCombine/InstCombineInternal.h | 5 ++
.../Transforms/InstCombine/ptrauth-call.ll | 89 +++++++++++++++++++
3 files changed, 126 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/ptrauth-call.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 0632f3cfc6dd2..3f6244951b64c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3647,6 +3647,34 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) {
return nullptr;
}
+/// Fold a call whose callee is a ptrauth constant into a direct call when the
+/// call's "ptrauth" operand bundle is known to authenticate with the same key
+/// and discriminator that the constant was signed with:
+///   call(ptrauth(f)), ["ptrauth"()] -> call f
+/// Returns the replacement call for the caller to substitute for \p Call, or
+/// nullptr if the fold does not apply.
+Instruction *InstCombinerImpl::foldPtrAuthConstantCallee(CallBase &Call) {
+  auto *CPA = dyn_cast<ConstantPtrAuth>(Call.getCalledOperand());
+  if (!CPA)
+    return nullptr;
+
+  auto *CalleeF = dyn_cast<Function>(CPA->getPointer()->stripPointerCasts());
+  // If the ptrauth constant isn't based on a function pointer, bail out.
+  if (!CalleeF)
+    return nullptr;
+
+  // stripPointerCasts() also looks through addrspacecast, so CalleeF may be
+  // in a different address space than the pointer that is actually called.
+  // Calling it directly would silently change the type of the called operand,
+  // so only fold when the types agree.
+  if (CalleeF->getType() != CPA->getType())
+    return nullptr;
+
+  // Inspect the call ptrauth bundle to check it matches the ptrauth constant.
+  auto PAB = Call.getOperandBundle(LLVMContext::OB_ptrauth);
+  if (!PAB)
+    return nullptr;
+
+  // NOTE: cast<> assumes the bundle key is always a ConstantInt (presumably
+  // enforced by the IR verifier for "ptrauth" bundles).
+  auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
+  Value *Discriminator = PAB->Inputs[1];
+
+  // If the bundle doesn't match, this is probably going to fail to auth.
+  if (!CPA->isKnownCompatibleWith(Key, Discriminator, DL))
+    return nullptr;
+
+  // If the bundle matches the constant, proceed in making this a direct call:
+  // rebuild the call without the ptrauth bundle and point it at CalleeF.
+  auto *NewCall = CallBase::removeOperandBundle(&Call, LLVMContext::OB_ptrauth);
+  NewCall->setCalledOperand(CalleeF);
+  return NewCall;
+}
+
bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call,
const TargetLibraryInfo *TLI) {
// Note: We only handle cases which can't be driven from generic attributes
@@ -3794,6 +3822,10 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
if (IntrinsicInst *II = findInitTrampoline(Callee))
return transformCallThroughTrampoline(Call, *II);
+ // Combine calls to ptrauth constants.
+ if (Instruction *NewCall = foldPtrAuthConstantCallee(Call))
+ return NewCall;
+
if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) {
InlineAsm *IA = cast<InlineAsm>(Callee);
if (!IA->canThrow()) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 984f02bcccad7..9268cbe594d90 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -282,6 +282,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *transformCallThroughTrampoline(CallBase &Call,
IntrinsicInst &Tramp);
+ /// Try to fold a call whose callee is a ptrauth constant into a direct call,
+ /// using the call's "ptrauth" operand bundle:
+ /// call(ptrauth(f)), ["ptrauth"()] -> call f
+ /// as long as the key/discriminator in the bundle are known to be compatible
+ /// with (i.e. the same as) those of the constant.
+ /// Returns the replacement call, or nullptr if the fold does not apply.
+ Instruction *foldPtrAuthConstantCallee(CallBase &Call);
+
// Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
// Otherwise, return std::nullopt
// Currently it matches:
diff --git a/llvm/test/Transforms/InstCombine/ptrauth-call.ll b/llvm/test/Transforms/InstCombine/ptrauth-call.ll
new file mode 100644
index 0000000000000..b4363b528d4e2
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/ptrauth-call.ll
@@ -0,0 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+declare i64 @f(i32)
+declare ptr @f2(i32)
+
+; Constant signed with key 0 and no explicit discriminator; bundle uses key 0,
+; discriminator 0. Expected to fold into a direct call to @f.
+define i32 @test_ptrauth_call(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call(
+; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
+; CHECK-NEXT: ret i32 [[V0]]
+;
+ %v0 = call i32 ptrauth(ptr @f, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
+ ret i32 %v0
+}
+
+; Key 1 and integer discriminator 5678 are identical in the constant and the
+; bundle. Expected to fold into a direct call to @f.
+define i32 @test_ptrauth_call_disc(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_disc(
+; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
+; CHECK-NEXT: ret i32 [[V0]]
+;
+ %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 5678) ]
+ ret i32 %v0
+}
+
+; Signed pointer to @f (key 1, integer discriminator 0) whose address
+; discriminator is this global's own address.
+@f_addr_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)
+
+; Address-discriminated constant (address @f_addr_disc.ref, integer disc 0);
+; the bundle passes that same address as its discriminator. Expected to fold.
+define i32 @test_ptrauth_call_addr_disc(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_addr_disc(
+; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
+; CHECK-NEXT: ret i32 [[V0]]
+;
+ %v0 = call i32 ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 ptrtoint (ptr @f_addr_disc.ref to i64)) ]
+ ret i32 %v0
+}
+
+; Signed pointer to @f (key 1) using both an integer discriminator (1234) and
+; this global's own address as the address discriminator.
+@f_both_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)
+
+; The bundle discriminator is llvm.ptrauth.blend of the constant's own address
+; and integer discriminators. Expected to fold; the blend is then dead and removed.
+define i32 @test_ptrauth_call_blend(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_blend(
+; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
+; CHECK-NEXT: ret i32 [[V0]]
+;
+ %v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 1234)
+ %v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
+ ret i32 %v0
+}
+
+; @f2 is declared to return ptr but is called as returning i64. After the fold,
+; the direct call returns ptr and the result is converted with ptrtoint.
+define i64 @test_ptrauth_call_cast(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_cast(
+; CHECK-NEXT: [[V0:%.*]] = call ptr @f2(i32 [[A0:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[V0]] to i64
+; CHECK-NEXT: ret i64 [[TMP1]]
+;
+ %v0 = call i64 ptrauth(ptr @f2, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
+ ret i64 %v0
+}
+
+; Bundle key (0) differs from the constant's key (1): must not fold.
+define i32 @test_ptrauth_call_mismatch_key(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_mismatch_key(
+; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 0, i64 5678) ]
+; CHECK-NEXT: ret i32 [[V0]]
+;
+ %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 0, i64 5678) ]
+ ret i32 %v0
+}
+
+; Bundle discriminator (0) differs from the constant's (5678): must not fold.
+define i32 @test_ptrauth_call_mismatch_disc(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_mismatch_disc(
+; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 0) ]
+; CHECK-NEXT: ret i32 [[V0]]
+;
+ %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 0) ]
+ ret i32 %v0
+}
+
+; The bundle blends the right address with the wrong integer discriminator
+; (0 instead of 1234): must not fold.
+define i32 @test_ptrauth_call_mismatch_blend(i32 %a0) {
+; CHECK-LABEL: @test_ptrauth_call_mismatch_blend(
+; CHECK-NEXT: [[V:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
+; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 [[V]]) ]
+; CHECK-NEXT: ret i32 [[V0]]
+;
+ %v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
+ %v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
+ ret i32 %v0
+}
+
+declare i64 @llvm.ptrauth.blend(i64, i64)
More information about the llvm-commits
mailing list