[llvm] [SimplifyCFG] Switch to use `paramHasNonNullAttr` (PR #125383)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Sat Feb 1 21:36:33 PST 2025
https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/125383
(No PR description provided.)
>From 7e90465568b3993a67b75dce60b22b0684e1d803 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 2 Feb 2025 13:28:49 +0800
Subject: [PATCH] [SimplifyCFG] Switch to use `paramHasNonNullAttr`
---
llvm/include/llvm/IR/InstrTypes.h | 12 ++++++++++--
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 18 ++++++++++--------
.../SimplifyCFG/UnreachableEliminate.ll | 11 +++++------
3 files changed, 25 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 26be02d4b193de..9196db8a7632ff 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1823,7 +1823,11 @@ class CallBase : public Instruction {
/// Extract the number of dereferenceable bytes for a call or
/// parameter (0=unknown).
uint64_t getParamDereferenceableBytes(unsigned i) const {
- return Attrs.getParamDereferenceableBytes(i);
+ uint64_t Bytes = Attrs.getParamDereferenceableBytes(i);
+ if (const Function *F = getCalledFunction())
+ Bytes =
+ std::max(Bytes, F->getAttributes().getParamDereferenceableBytes(i));
+ return Bytes;
}
/// Extract the number of dereferenceable_or_null bytes for a call
@@ -1841,7 +1845,11 @@ class CallBase : public Instruction {
/// Extract the number of dereferenceable_or_null bytes for a
/// parameter (0=unknown).
uint64_t getParamDereferenceableOrNullBytes(unsigned i) const {
- return Attrs.getParamDereferenceableOrNullBytes(i);
+ uint64_t Bytes = Attrs.getParamDereferenceableOrNullBytes(i);
+ if (const Function *F = getCalledFunction())
+ Bytes = std::max(
+ Bytes, F->getAttributes().getParamDereferenceableOrNullBytes(i));
+ return Bytes;
}
/// Extract a test mask for disallowed floating-point value classes for the
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 12dd49da279b9c..52e0bff13d08ae 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -8276,15 +8276,17 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu
return true;
if (C->isNullValue()) {
- for (const llvm::Use &Arg : CB->args())
- if (Arg == I) {
- unsigned ArgIdx = CB->getArgOperandNo(&Arg);
- if (CB->isPassingUndefUB(ArgIdx) &&
- CB->paramHasAttr(ArgIdx, Attribute::NonNull)) {
- // Passing null to a nonnnull+noundef argument is undefined.
- return !PtrValueMayBeModified;
+ if (C->getType()->isPointerTy()) {
+ for (const llvm::Use &Arg : CB->args())
+ if (Arg == I) {
+ unsigned ArgIdx = CB->getArgOperandNo(&Arg);
+ if (CB->paramHasNonNullAttr(ArgIdx,
+ /*AllowUndefOrPoison=*/false)) {
+ // Passing null to a nonnull+noundef or dereferenceable argument is undefined.
+ return !PtrValueMayBeModified;
+ }
}
- }
+ }
} else if (isa<UndefValue>(C)) {
// Passing undef to a noundef argument is undefined.
for (const llvm::Use &Arg : CB->args())
diff --git a/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll b/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
index aae1ab032f36e4..2da5d18b63f497 100644
--- a/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
+++ b/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
@@ -238,7 +238,7 @@ else:
}
declare ptr @fn_nonnull_noundef_arg(ptr nonnull noundef %p)
-declare ptr @fn_nonnull_deref_arg(ptr nonnull dereferenceable(4) %p)
+declare ptr @fn_deref_arg(ptr dereferenceable(4) %p)
declare ptr @fn_nonnull_deref_or_null_arg(ptr nonnull dereferenceable_or_null(4) %p)
declare ptr @fn_nonnull_arg(ptr nonnull %p)
declare ptr @fn_noundef_arg(ptr noundef %p)
@@ -271,7 +271,7 @@ define void @test9_deref(i1 %X, ptr %Y) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = xor i1 [[X:%.*]], true
; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
-; CHECK-NEXT: [[TMP1:%.*]] = call ptr @fn_nonnull_deref_arg(ptr [[Y:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = call ptr @fn_deref_arg(ptr [[Y:%.*]])
; CHECK-NEXT: ret void
;
entry:
@@ -282,7 +282,7 @@ if:
else:
%phi = phi ptr [ %Y, %entry ], [ null, %if ]
- call ptr @fn_nonnull_deref_arg(ptr %phi)
+ call ptr @fn_deref_arg(ptr %phi)
ret void
}
@@ -290,9 +290,8 @@ else:
define void @test9_deref_or_null(i1 %X, ptr %Y) {
; CHECK-LABEL: @test9_deref_or_null(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP0:%.*]] = xor i1 [[X:%.*]], true
-; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
-; CHECK-NEXT: [[TMP1:%.*]] = call ptr @fn_nonnull_deref_or_null_arg(ptr [[Y:%.*]])
+; CHECK-NEXT: [[Y:%.*]] = select i1 [[X:%.*]], ptr null, ptr [[Y1:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call ptr @fn_nonnull_deref_or_null_arg(ptr [[Y]])
; CHECK-NEXT: ret void
;
entry:
More information about the llvm-commits mailing list