[llvm] [InstCombine] Combine ptrauth constants into ptrauth intrinsics. (PR #94705)

Ahmed Bougacha via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 26 18:49:09 PDT 2024


https://github.com/ahmedbougacha updated https://github.com/llvm/llvm-project/pull/94705

>From 9dfe0109cc6c3c74edb2a3ad4e7de7aea3e615e1 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Mon, 27 Sep 2021 08:00:00 -0700
Subject: [PATCH 1/3] [InstCombine] Combine ptrauth constants into ptrauth
 intrinsics.

When we encounter two consecutive ptrauth intrinsics, we can
already combine the inner matching sign + auth pair, e.g.:
  resign(sign(p,ks,ds),ks,ds,kr,dr) -> sign(p,kr,dr)

We can generalize that to ptrauth constants, which are effectively
constant equivalents to ptrauth.sign, i.e.:
  resign(ptrauth(p,ks,ds),ks,ds,kr,dr) -> ptrauth(p,kr,dr)
  auth(ptrauth(p,k,d),k,d) -> p

While there, clean up a redundant return after eraseInstFromFunction
in the shared (intrinsic|constant)->intrinsic folding code.
---
 .../InstCombine/InstCombineCalls.cpp          | 27 ++++++-
 .../InstCombine/ptrauth-intrinsics.ll         | 73 +++++++++++++++++++
 2 files changed, 97 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 436cdbff75669..310514bea3ec1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2643,13 +2643,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     // (sign|resign) + (auth|resign) can be folded by omitting the middle
     // sign+auth component if the key and discriminator match.
     bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign;
+    Value *Ptr = II->getArgOperand(0);
     Value *Key = II->getArgOperand(1);
     Value *Disc = II->getArgOperand(2);
 
     // AuthKey will be the key we need to end up authenticating against in
     // whatever we replace this sequence with.
     Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr;
-    if (auto CI = dyn_cast<CallBase>(II->getArgOperand(0))) {
+    if (auto *CI = dyn_cast<CallBase>(Ptr)) {
       BasePtr = CI->getArgOperand(0);
       if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) {
         if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc)
@@ -2661,6 +2662,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
         AuthDisc = CI->getArgOperand(2);
       } else
         break;
+    } else if (auto *PtrToInt = dyn_cast<PtrToIntOperator>(Ptr)) {
+      // ptrauth constants are equivalent to a call to @llvm.ptrauth.sign for
+      // our purposes, so check for that too.
+      auto *CPA = dyn_cast<ConstantPtrAuth>(PtrToInt->getOperand(0));
+      if (!CPA || !CPA->isKnownCompatibleWith(Key, Disc, DL))
+        break;
+
+      // resign(ptrauth(p,ks,ds),ks,ds,kr,dr) -> ptrauth(p,kr,dr)
+      if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
+        auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
+        auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
+        auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
+        auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
+                                            SignDisc, SignAddrDisc);
+        replaceInstUsesWith(
+            *II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
+        return eraseInstFromFunction(*II);
+      }
+
+      // auth(ptrauth(p,k,d),k,d) -> p
+      BasePtr = Builder.CreatePtrToInt(CPA->getPointer(), II->getType());
     } else
       break;
 
@@ -2677,8 +2699,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     } else {
       // sign(0) + auth(0) = nop
       replaceInstUsesWith(*II, BasePtr);
-      eraseInstFromFunction(*II);
-      return nullptr;
+      return eraseInstFromFunction(*II);
     }
 
     SmallVector<Value *, 4> CallArgs;
diff --git a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
index da0f724abfde4..3e894739f4e34 100644
--- a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
@@ -12,6 +12,26 @@ define i64 @test_ptrauth_nop(ptr %p) {
   ret i64 %authed
 }
 
+declare void @foo()
+
+define i64 @test_ptrauth_nop_constant() {
+; CHECK-LABEL: @test_ptrauth_nop_constant(
+; CHECK-NEXT:    ret i64 ptrtoint (ptr @foo to i64)
+;
+  %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 1234)
+  ret i64 %authed
+}
+
+define i64 @test_ptrauth_nop_constant_addrdisc() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc(
+; CHECK-NEXT:    ret i64 ptrtoint (ptr @foo to i64)
+;
+  %addr = ptrtoint void()* @foo to i64
+  %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234)
+  %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
+  ret i64 %authed
+}
+
 define i64 @test_ptrauth_nop_mismatch(ptr %p) {
 ; CHECK-LABEL: @test_ptrauth_nop_mismatch(
 ; CHECK-NEXT:    [[TMP0:%.*]] = ptrtoint ptr [[P:%.*]] to i64
@@ -87,6 +107,59 @@ define i64 @test_ptrauth_resign_auth_mismatch(ptr %p) {
   ret i64 %authed
 }
 
+define i64 @test_ptrauth_nop_constant_mismatch() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_mismatch(
+; CHECK-NEXT:    [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 12)
+; CHECK-NEXT:    ret i64 [[AUTHED]]
+;
+  %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 12)
+  ret i64 %authed
+}
+
+define i64 @test_ptrauth_nop_constant_mismatch_key() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_mismatch_key(
+; CHECK-NEXT:    [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234) to i64), i32 0, i64 1234)
+; CHECK-NEXT:    ret i64 [[AUTHED]]
+;
+  %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 0, i64 1234)
+  ret i64 %authed
+}
+
+define i64 @test_ptrauth_nop_constant_addrdisc_mismatch() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch(
+; CHECK-NEXT:    [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @foo to i64), i64 12)
+; CHECK-NEXT:    [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]])
+; CHECK-NEXT:    ret i64 [[AUTHED]]
+;
+  %addr = ptrtoint ptr @foo to i64
+  %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 12)
+  %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
+  ret i64 %authed
+}
+
+define i64 @test_ptrauth_nop_constant_addrdisc_mismatch2() {
+; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch2(
+; CHECK-NEXT:    [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @test_ptrauth_nop to i64), i64 1234)
+; CHECK-NEXT:    [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]])
+; CHECK-NEXT:    ret i64 [[AUTHED]]
+;
+  %addr = ptrtoint ptr @test_ptrauth_nop to i64
+  %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234)
+  %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
+  ret i64 %authed
+}
+
+define i64 @test_ptrauth_resign_ptrauth_constant(ptr %p) {
+; CHECK-LABEL: @test_ptrauth_resign_ptrauth_constant(
+; CHECK-NEXT:    ret i64 ptrtoint (ptr ptrauth (ptr @foo, i32 0, i64 42) to i64)
+;
+
+  %tmp0 = ptrtoint ptr %p to i64
+  %authed = call i64 @llvm.ptrauth.resign(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 1234, i32 0, i64 42)
+  ret i64 %authed
+}
+
 declare i64 @llvm.ptrauth.auth(i64, i32, i64)
 declare i64 @llvm.ptrauth.sign(i64, i32, i64)
 declare i64 @llvm.ptrauth.resign(i64, i32, i64, i32, i64)
+declare i64 @llvm.ptrauth.blend(i64, i64)

>From 4943d02d4ec5ea311ce9db55858ae895cc313601 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Thu, 13 Jun 2024 13:34:45 -0700
Subject: [PATCH 2/3] Replace typed pointers in casts with ptr.

---
 llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
index 3e894739f4e34..609abbe3a0d72 100644
--- a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
@@ -26,7 +26,7 @@ define i64 @test_ptrauth_nop_constant_addrdisc() {
 ; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc(
 ; CHECK-NEXT:    ret i64 ptrtoint (ptr @foo to i64)
 ;
-  %addr = ptrtoint void()* @foo to i64
+  %addr = ptrtoint ptr @foo to i64
   %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234)
   %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
   ret i64 %authed

>From b08be11498e0e786606f9da20cfa7f813c9c3953 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed at bougacha.org>
Date: Wed, 26 Jun 2024 18:47:45 -0700
Subject: [PATCH 3/3] Address review feedback.

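- don't reuse an existing test function (@test_ptrauth_nop) as the
  address discriminator in a test; declare a separate @bar instead
- make local pointer variables const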
---
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp   | 6 +++---
 llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll | 5 +++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 310514bea3ec1..b42f0ca296fc5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2650,7 +2650,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     // AuthKey will be the key we need to end up authenticating against in
     // whatever we replace this sequence with.
     Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr;
-    if (auto *CI = dyn_cast<CallBase>(Ptr)) {
+    if (const auto *CI = dyn_cast<CallBase>(Ptr)) {
       BasePtr = CI->getArgOperand(0);
       if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) {
         if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc)
@@ -2662,10 +2662,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
         AuthDisc = CI->getArgOperand(2);
       } else
         break;
-    } else if (auto *PtrToInt = dyn_cast<PtrToIntOperator>(Ptr)) {
+    } else if (const auto *PtrToInt = dyn_cast<PtrToIntOperator>(Ptr)) {
       // ptrauth constants are equivalent to a call to @llvm.ptrauth.sign for
       // our purposes, so check for that too.
-      auto *CPA = dyn_cast<ConstantPtrAuth>(PtrToInt->getOperand(0));
+      const auto *CPA = dyn_cast<ConstantPtrAuth>(PtrToInt->getOperand(0));
       if (!CPA || !CPA->isKnownCompatibleWith(Key, Disc, DL))
         break;
 
diff --git a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
index 609abbe3a0d72..208e162ac9416 100644
--- a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll
@@ -13,6 +13,7 @@ define i64 @test_ptrauth_nop(ptr %p) {
 }
 
 declare void @foo()
+declare void @bar()
 
 define i64 @test_ptrauth_nop_constant() {
 ; CHECK-LABEL: @test_ptrauth_nop_constant(
@@ -139,11 +140,11 @@ define i64 @test_ptrauth_nop_constant_addrdisc_mismatch() {
 
 define i64 @test_ptrauth_nop_constant_addrdisc_mismatch2() {
 ; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch2(
-; CHECK-NEXT:    [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @test_ptrauth_nop to i64), i64 1234)
+; CHECK-NEXT:    [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @bar to i64), i64 1234)
 ; CHECK-NEXT:    [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]])
 ; CHECK-NEXT:    ret i64 [[AUTHED]]
 ;
-  %addr = ptrtoint ptr @test_ptrauth_nop to i64
+  %addr = ptrtoint ptr @bar to i64
   %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234)
   %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
   ret i64 %authed



More information about the llvm-commits mailing list