[llvm] [SimplifyCFG] Switch to use `paramHasNonNullAttr` (PR #125383)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 1 09:06:10 PDT 2025


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/125383

From 5b59246712a5c9446a526503818b7f86b824f03c Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 2 Feb 2025 13:28:49 +0800
Subject: [PATCH 1/2] [SimplifyCFG] Switch to use `paramHasNonNullAttr`

---
 llvm/include/llvm/IR/InstrTypes.h                    | 12 ++++++++++--
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp            |  4 ++--
 .../Transforms/SimplifyCFG/UnreachableEliminate.ll   | 11 +++++------
 3 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 8e47e3c7b3a7c..61070aa79b15d 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1839,7 +1839,11 @@ class CallBase : public Instruction {
   /// Extract the number of dereferenceable bytes for a call or
   /// parameter (0=unknown).
   uint64_t getParamDereferenceableBytes(unsigned i) const {
-    return Attrs.getParamDereferenceableBytes(i);
+    uint64_t Bytes = Attrs.getParamDereferenceableBytes(i);
+    if (const Function *F = getCalledFunction())
+      Bytes =
+          std::max(Bytes, F->getAttributes().getParamDereferenceableBytes(i));
+    return Bytes;
   }
 
   /// Extract the number of dereferenceable_or_null bytes for a call
@@ -1857,7 +1861,11 @@ class CallBase : public Instruction {
   /// Extract the number of dereferenceable_or_null bytes for a
   /// parameter (0=unknown).
   uint64_t getParamDereferenceableOrNullBytes(unsigned i) const {
-    return Attrs.getParamDereferenceableOrNullBytes(i);
+    uint64_t Bytes = Attrs.getParamDereferenceableOrNullBytes(i);
+    if (const Function *F = getCalledFunction())
+      Bytes = std::max(
+          Bytes, F->getAttributes().getParamDereferenceableOrNullBytes(i));
+    return Bytes;
   }
 
   /// Extract a test mask for disallowed floating-point value classes for the
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 7840601d341b8..b7299e01b0c5f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -8224,8 +8224,8 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu
       if (CB->isArgOperand(&Use)) {
         unsigned ArgIdx = CB->getArgOperandNo(&Use);
         // Passing null to a nonnnull+noundef argument is undefined.
-        if (C->isNullValue() && CB->isPassingUndefUB(ArgIdx) &&
-            CB->paramHasAttr(ArgIdx, Attribute::NonNull))
+        if (isa<ConstantPointerNull>(C) &&
+            CB->paramHasNonNullAttr(ArgIdx, /*AllowUndefOrPoison=*/false))
           return !PtrValueMayBeModified;
         // Passing undef to a noundef argument is undefined.
         if (isa<UndefValue>(C) && CB->isPassingUndefUB(ArgIdx))
diff --git a/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll b/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
index aae1ab032f36e..2da5d18b63f49 100644
--- a/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
+++ b/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
@@ -238,7 +238,7 @@ else:
 }
 
 declare ptr @fn_nonnull_noundef_arg(ptr nonnull noundef %p)
-declare ptr @fn_nonnull_deref_arg(ptr nonnull dereferenceable(4) %p)
+declare ptr @fn_deref_arg(ptr dereferenceable(4) %p)
 declare ptr @fn_nonnull_deref_or_null_arg(ptr nonnull dereferenceable_or_null(4) %p)
 declare ptr @fn_nonnull_arg(ptr nonnull %p)
 declare ptr @fn_noundef_arg(ptr noundef %p)
@@ -271,7 +271,7 @@ define void @test9_deref(i1 %X, ptr %Y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = xor i1 [[X:%.*]], true
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
-; CHECK-NEXT:    [[TMP1:%.*]] = call ptr @fn_nonnull_deref_arg(ptr [[Y:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call ptr @fn_deref_arg(ptr [[Y:%.*]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -282,7 +282,7 @@ if:
 
 else:
   %phi = phi ptr [ %Y, %entry ], [ null, %if ]
-  call ptr @fn_nonnull_deref_arg(ptr %phi)
+  call ptr @fn_deref_arg(ptr %phi)
   ret void
 }
 
@@ -290,9 +290,8 @@ else:
 define void @test9_deref_or_null(i1 %X, ptr %Y) {
 ; CHECK-LABEL: @test9_deref_or_null(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = xor i1 [[X:%.*]], true
-; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
-; CHECK-NEXT:    [[TMP1:%.*]] = call ptr @fn_nonnull_deref_or_null_arg(ptr [[Y:%.*]])
+; CHECK-NEXT:    [[Y:%.*]] = select i1 [[X:%.*]], ptr null, ptr [[Y1:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call ptr @fn_nonnull_deref_or_null_arg(ptr [[Y]])
 ; CHECK-NEXT:    ret void
 ;
 entry:

From 49bc007c3efb6c0dc40b252741e8fe7c043a793e Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 2 Jun 2025 00:05:49 +0800
Subject: [PATCH 2/2] [IR] Only check attributes

---
 llvm/include/llvm/IR/InstrTypes.h | 12 ++----------
 llvm/lib/IR/Instructions.cpp      |  2 +-
 2 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 61070aa79b15d..8e47e3c7b3a7c 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1839,11 +1839,7 @@ class CallBase : public Instruction {
   /// Extract the number of dereferenceable bytes for a call or
   /// parameter (0=unknown).
   uint64_t getParamDereferenceableBytes(unsigned i) const {
-    uint64_t Bytes = Attrs.getParamDereferenceableBytes(i);
-    if (const Function *F = getCalledFunction())
-      Bytes =
-          std::max(Bytes, F->getAttributes().getParamDereferenceableBytes(i));
-    return Bytes;
+    return Attrs.getParamDereferenceableBytes(i);
   }
 
   /// Extract the number of dereferenceable_or_null bytes for a call
@@ -1861,11 +1857,7 @@ class CallBase : public Instruction {
   /// Extract the number of dereferenceable_or_null bytes for a
   /// parameter (0=unknown).
   uint64_t getParamDereferenceableOrNullBytes(unsigned i) const {
-    uint64_t Bytes = Attrs.getParamDereferenceableOrNullBytes(i);
-    if (const Function *F = getCalledFunction())
-      Bytes = std::max(
-          Bytes, F->getAttributes().getParamDereferenceableOrNullBytes(i));
-    return Bytes;
+    return Attrs.getParamDereferenceableOrNullBytes(i);
   }
 
   /// Extract a test mask for disallowed floating-point value classes for the
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index f404e11b9c0f0..e22b04e1bb0b3 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -448,7 +448,7 @@ bool CallBase::paramHasNonNullAttr(unsigned ArgNo,
       (AllowUndefOrPoison || paramHasAttr(ArgNo, Attribute::NoUndef)))
     return true;
 
-  if (getParamDereferenceableBytes(ArgNo) > 0 &&
+  if (paramHasAttr(ArgNo, Attribute::Dereferenceable) &&
       !NullPointerIsDefined(
           getCaller(),
           getArgOperand(ArgNo)->getType()->getPointerAddressSpace()))



More information about the llvm-commits mailing list