[llvm] 34e5a71 - [InstCombine] Combine ptrauth constants into ptrauth intrinsics. (#94705)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 26 18:54:44 PDT 2024
Author: Ahmed Bougacha
Date: 2024-06-26T18:54:40-07:00
New Revision: 34e5a71b3219391309eb498a55e4d49831e1f9ab
URL: https://github.com/llvm/llvm-project/commit/34e5a71b3219391309eb498a55e4d49831e1f9ab
DIFF: https://github.com/llvm/llvm-project/commit/34e5a71b3219391309eb498a55e4d49831e1f9ab.diff
LOG: [InstCombine] Combine ptrauth constants into ptrauth intrinsics. (#94705)
When we encounter two consecutive ptrauth intrinsics, we can already
combine the inner matching sign + auth pair, e.g.:
resign(sign(p,ks,ds),ks,ds,kr,dr) -> sign(p,kr,dr)
We can generalize that to ptrauth constants, which are effectively the
constant equivalent of ptrauth.sign, i.e.:
resign(ptrauth(p,ks,ds),ks,ds,kr,dr) -> ptrauth(p,kr,dr)
auth(ptrauth(p,k,d),k,d) -> p
While there, clean up the shared (intrinsic|constant)->intrinsic folding
code to return the result of eraseInstFromFunction directly, instead of
discarding it and then returning nullptr.
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 436cdbff75669..b42f0ca296fc5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2643,13 +2643,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// (sign|resign) + (auth|resign) can be folded by omitting the middle
// sign+auth component if the key and discriminator match.
bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign;
+ Value *Ptr = II->getArgOperand(0);
Value *Key = II->getArgOperand(1);
Value *Disc = II->getArgOperand(2);
// AuthKey will be the key we need to end up authenticating against in
// whatever we replace this sequence with.
Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr;
- if (auto CI = dyn_cast<CallBase>(II->getArgOperand(0))) {
+ if (const auto *CI = dyn_cast<CallBase>(Ptr)) {
BasePtr = CI->getArgOperand(0);
if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) {
if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc)
@@ -2661,6 +2662,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
AuthDisc = CI->getArgOperand(2);
} else
break;
+ } else if (const auto *PtrToInt = dyn_cast<PtrToIntOperator>(Ptr)) {
+ // ptrauth constants are equivalent to a call to @llvm.ptrauth.sign for
+ // our purposes, so check for that too.
+ const auto *CPA = dyn_cast<ConstantPtrAuth>(PtrToInt->getOperand(0));
+ if (!CPA || !CPA->isKnownCompatibleWith(Key, Disc, DL))
+ break;
+
+ // resign(ptrauth(p,ks,ds),ks,ds,kr,dr) -> ptrauth(p,kr,dr)
+ if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
+ auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
+ auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
+ auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
+ auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
+ SignDisc, SignAddrDisc);
+ replaceInstUsesWith(
+ *II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
+ return eraseInstFromFunction(*II);
+ }
+
+ // auth(ptrauth(p,k,d),k,d) -> p
+ BasePtr = Builder.CreatePtrToInt(CPA->getPointer(), II->getType());
} else
break;
@@ -2677,8 +2699,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
} else {
// sign(0) + auth(0) = nop
replaceInstUsesWith(*II, BasePtr);
- eraseInstFromFunction(*II);
- return nullptr;
+ return eraseInstFromFunction(*II);
}
SmallVector<Value *, 4> CallArgs;
diff --git a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
index da0f724abfde4..208e162ac9416 100644
--- a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
@@ -12,6 +12,27 @@ define i64 @test_ptrauth_nop(ptr %p) {
ret i64 %authed
}
+declare void @foo()
+declare void @bar()
+
+define i64 @test_ptrauth_nop_constant() {
+; CHECK-LABEL: @test_ptrauth_nop_constant(
+; CHECK-NEXT: ret i64 ptrtoint (ptr @foo to i64)
+;
+ %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 1234)
+ ret i64 %authed
+}
+
+define i64 @test_ptrauth_nop_constant_addrdisc() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc(
+; CHECK-NEXT: ret i64 ptrtoint (ptr @foo to i64)
+;
+ %addr = ptrtoint ptr @foo to i64
+ %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234)
+ %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
+ ret i64 %authed
+}
+
define i64 @test_ptrauth_nop_mismatch(ptr %p) {
; CHECK-LABEL: @test_ptrauth_nop_mismatch(
; CHECK-NEXT: [[TMP0:%.*]] = ptrtoint ptr [[P:%.*]] to i64
@@ -87,6 +108,59 @@ define i64 @test_ptrauth_resign_auth_mismatch(ptr %p) {
ret i64 %authed
}
+define i64 @test_ptrauth_nop_constant_mismatch() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_mismatch(
+; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 12)
+; CHECK-NEXT: ret i64 [[AUTHED]]
+;
+ %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 12)
+ ret i64 %authed
+}
+
+define i64 @test_ptrauth_nop_constant_mismatch_key() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_mismatch_key(
+; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234) to i64), i32 0, i64 1234)
+; CHECK-NEXT: ret i64 [[AUTHED]]
+;
+ %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 0, i64 1234)
+ ret i64 %authed
+}
+
+define i64 @test_ptrauth_nop_constant_addrdisc_mismatch() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch(
+; CHECK-NEXT: [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @foo to i64), i64 12)
+; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]])
+; CHECK-NEXT: ret i64 [[AUTHED]]
+;
+ %addr = ptrtoint ptr @foo to i64
+ %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 12)
+ %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
+ ret i64 %authed
+}
+
+define i64 @test_ptrauth_nop_constant_addrdisc_mismatch2() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch2(
+; CHECK-NEXT: [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @bar to i64), i64 1234)
+; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]])
+; CHECK-NEXT: ret i64 [[AUTHED]]
+;
+ %addr = ptrtoint ptr @bar to i64
+ %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234)
+ %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
+ ret i64 %authed
+}
+
+define i64 @test_ptrauth_resign_ptrauth_constant(ptr %p) {
+; CHECK-LABEL: @test_ptrauth_resign_ptrauth_constant(
+; CHECK-NEXT: ret i64 ptrtoint (ptr ptrauth (ptr @foo, i32 0, i64 42) to i64)
+;
+
+ %tmp0 = ptrtoint ptr %p to i64
+ %authed = call i64 @llvm.ptrauth.resign(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 1234, i32 0, i64 42)
+ ret i64 %authed
+}
+
declare i64 @llvm.ptrauth.auth(i64, i32, i64)
declare i64 @llvm.ptrauth.sign(i64, i32, i64)
declare i64 @llvm.ptrauth.resign(i64, i32, i64, i32, i64)
+declare i64 @llvm.ptrauth.blend(i64, i64)
More information about the llvm-commits
mailing list