[llvm] [InstCombine] Add support for GEPs in `simplifyNonNullOperand` (PR #128365)

via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 22 22:30:38 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

Alive2: https://alive2.llvm.org/ce/z/2KE8zG

---
Full diff: https://github.com/llvm/llvm-project/pull/128365.diff


8 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+6-2) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+4-1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp (+22-4) 
- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+4-2) 
- (modified) llvm/test/Transforms/InstCombine/load.ll (+12) 
- (modified) llvm/test/Transforms/InstCombine/nonnull-select.ll (+98) 
- (modified) llvm/test/Transforms/InstCombine/store.ll (+12) 
- (modified) llvm/test/Transforms/PhaseOrdering/memset-combine.ll (+2-4) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 54f777ab20a7a..63f2fd0a733ce 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3996,8 +3996,12 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
     if (V->getType()->isPointerTy()) {
       // Simplify the nonnull operand if the parameter is known to be nonnull.
       // Otherwise, try to infer nonnull for it.
-      if (Call.paramHasNonNullAttr(ArgNo, /*AllowUndefOrPoison=*/true)) {
-        if (Value *Res = simplifyNonNullOperand(V)) {
+      bool HasDereferenceable = Call.getParamDereferenceableBytes(ArgNo) > 0;
+      if (Call.paramHasAttr(ArgNo, Attribute::NonNull) ||
+          (HasDereferenceable &&
+           !NullPointerIsDefined(Call.getFunction(),
+                                 V->getType()->getPointerAddressSpace()))) {
+        if (Value *Res = simplifyNonNullOperand(V, HasDereferenceable)) {
           replaceOperand(Call, ArgNo, Res);
           Changed = true;
         }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 71c80d4c401f8..5b2af39e69f2c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -457,7 +457,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   /// Simplify \p V given that it is known to be non-null.
   /// Returns the simplified value if possible, otherwise returns nullptr.
-  Value *simplifyNonNullOperand(Value *V);
+  /// If \p HasDereferenceable is true, non-inbounds GEPs are looked through
+  /// as well (the inbounds same-object guarantee is not required).
+  Value *simplifyNonNullOperand(Value *V, bool HasDereferenceable,
+                                unsigned Depth = 0);
 
 public:
   /// Create and insert the idiom we use to indicate a block is unreachable
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 89fc1051b18dc..622884ea1eb46 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -982,8 +982,9 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) {
   return false;
 }
 
-/// TODO: Recursively simplify nonnull value to handle one-use inbounds GEPs.
-Value *InstCombinerImpl::simplifyNonNullOperand(Value *V) {
+Value *InstCombinerImpl::simplifyNonNullOperand(Value *V,
+                                                bool HasDereferenceable,
+                                                unsigned Depth) {
   if (auto *Sel = dyn_cast<SelectInst>(V)) {
     if (isa<ConstantPointerNull>(Sel->getOperand(1)))
       return Sel->getOperand(2);
@@ -992,6 +993,23 @@ Value *InstCombinerImpl::simplifyNonNullOperand(Value *V) {
       return Sel->getOperand(1);
   }
 
+  if (!V->hasOneUse())
+    return nullptr;
+
+  if (Depth == 1)
+    return nullptr;
+
+  if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
+    if (HasDereferenceable || GEP->isInBounds()) {
+      if (auto *Res = simplifyNonNullOperand(GEP->getPointerOperand(),
+                                             HasDereferenceable, Depth + 1)) {
+        replaceOperand(*GEP, 0, Res);
+        addToWorklist(GEP);
+        return nullptr;
+      }
+    }
+  }
+
   return nullptr;
 }
 
@@ -1076,7 +1094,7 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
   }
 
   if (!NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))
-    if (Value *V = simplifyNonNullOperand(Op))
+    if (Value *V = simplifyNonNullOperand(Op, /*HasDereferenceable=*/true))
       return replaceOperand(LI, 0, V);
 
   return nullptr;
@@ -1444,7 +1462,7 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
     return eraseInstFromFunction(SI);
 
   if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace()))
-    if (Value *V = simplifyNonNullOperand(Ptr))
+    if (Value *V = simplifyNonNullOperand(Ptr, /*HasDereferenceable=*/true))
       return replaceOperand(SI, 1, V);
 
   return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index b7748f59a0cfc..81b057c10b484 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3593,10 +3593,12 @@ Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
   Function *F = RI.getFunction();
   Type *RetTy = RetVal->getType();
   if (RetTy->isPointerTy()) {
+    bool HasDereferenceable =
+        F->getAttributes().getRetDereferenceableBytes() > 0;
     if (F->hasRetAttribute(Attribute::NonNull) ||
-        (F->getAttributes().getRetDereferenceableBytes() > 0 &&
+        (HasDereferenceable &&
          !NullPointerIsDefined(F, RetTy->getPointerAddressSpace()))) {
-      if (Value *V = simplifyNonNullOperand(RetVal))
+      if (Value *V = simplifyNonNullOperand(RetVal, HasDereferenceable))
         return replaceOperand(RI, 0, V);
     }
   }
diff --git a/llvm/test/Transforms/InstCombine/load.ll b/llvm/test/Transforms/InstCombine/load.ll
index 6c087aa87845f..a5ad1e0c21526 100644
--- a/llvm/test/Transforms/InstCombine/load.ll
+++ b/llvm/test/Transforms/InstCombine/load.ll
@@ -439,3 +439,15 @@ define i4 @test_vector_load_i4_non_byte_sized() {
   %res0 = load i4, ptr %ptr0, align 1
   ret i4 %res0
 }
+
+define i32 @load_select_with_null_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @load_select_with_null_gep(
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[SEL:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[GEP]], align 4
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %sel = select i1 %cond, ptr %p, ptr null
+  %gep = getelementptr i8, ptr %sel, i64 %off
+  %res = load i32, ptr %gep, align 4
+  ret i32 %res
+}
diff --git a/llvm/test/Transforms/InstCombine/nonnull-select.ll b/llvm/test/Transforms/InstCombine/nonnull-select.ll
index cc000b4c88164..929919f9c42c7 100644
--- a/llvm/test/Transforms/InstCombine/nonnull-select.ll
+++ b/llvm/test/Transforms/InstCombine/nonnull-select.ll
@@ -86,4 +86,102 @@ define void @nonnull_noundef_call2(i1 %cond, ptr %p) {
   ret void
 }
 
+define void @nonnull_call_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_call_gep(
+; CHECK-NEXT:    [[PTR:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT:    ret void
+;
+  %ptr = select i1 %cond, ptr null, ptr %p
+  %gep = getelementptr i8, ptr %ptr, i64 %off
+  call void @f(ptr nonnull %gep)
+  ret void
+}
+
+define void @nonnull_call_gep_multiuse(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_call_gep_multiuse(
+; CHECK-NEXT:    [[PTR:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT:    call void @f(ptr [[GEP]])
+; CHECK-NEXT:    ret void
+;
+  %ptr = select i1 %cond, ptr null, ptr %p
+  %gep = getelementptr inbounds i8, ptr %ptr, i64 %off
+  call void @f(ptr nonnull %gep)
+  call void @f(ptr %gep)
+  ret void
+}
+
+define void @all_nonnull_call_gep_multiuse(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @all_nonnull_call_gep_multiuse(
+; CHECK-NEXT:    [[PTR:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT:    ret void
+;
+  %ptr = select i1 %cond, ptr null, ptr %p
+  %gep = getelementptr inbounds i8, ptr %ptr, i64 %off
+  call void @f(ptr nonnull %gep)
+  call void @f(ptr nonnull %gep)
+  ret void
+}
+
+define void @nonnull_call_gep_inbounds(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_call_gep_inbounds(
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT:    ret void
+;
+  %ptr = select i1 %cond, ptr null, ptr %p
+  %gep = getelementptr inbounds i8, ptr %ptr, i64 %off
+  call void @f(ptr nonnull %gep)
+  ret void
+}
+
+define void @nonnull_dereferenceable_call_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_dereferenceable_call_gep(
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[PTR:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    call void @f(ptr dereferenceable(1) [[GEP]])
+; CHECK-NEXT:    ret void
+;
+  %ptr = select i1 %cond, ptr null, ptr %p
+  %gep = getelementptr i8, ptr %ptr, i64 %off
+  call void @f(ptr dereferenceable(1) %gep)
+  ret void
+}
+
+define nonnull ptr @nonnull_ret_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_ret_gep(
+; CHECK-NEXT:    [[PTR:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    ret ptr [[GEP]]
+;
+  %ptr = select i1 %cond, ptr null, ptr %p
+  %gep = getelementptr i8, ptr %ptr, i64 %off
+  ret ptr %gep
+}
+
+define nonnull ptr @nonnull_ret_gep_inbounds(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_ret_gep_inbounds(
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    ret ptr [[GEP]]
+;
+  %ptr = select i1 %cond, ptr null, ptr %p
+  %gep = getelementptr inbounds i8, ptr %ptr, i64 %off
+  ret ptr %gep
+}
+
+define dereferenceable(1) ptr @nonnull_dereferenceable_ret_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_dereferenceable_ret_gep(
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[PTR:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    ret ptr [[GEP]]
+;
+  %ptr = select i1 %cond, ptr null, ptr %p
+  %gep = getelementptr i8, ptr %ptr, i64 %off
+  ret ptr %gep
+}
+
 declare void @f(ptr)
diff --git a/llvm/test/Transforms/InstCombine/store.ll b/llvm/test/Transforms/InstCombine/store.ll
index 0a2b0a5ee7987..daa40da1828b5 100644
--- a/llvm/test/Transforms/InstCombine/store.ll
+++ b/llvm/test/Transforms/InstCombine/store.ll
@@ -387,6 +387,18 @@ define void @store_select_with_unknown(i1 %cond, ptr %p, ptr %p2) {
   ret void
 }
 
+define void @store_select_with_null_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @store_select_with_null_gep(
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[SEL:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT:    store i32 0, ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %sel = select i1 %cond, ptr %p, ptr null
+  %gep = getelementptr i8, ptr %sel, i64 %off
+  store i32 0, ptr %gep, align 4
+  ret void
+}
+
 !0 = !{!4, !4, i64 0}
 !1 = !{!"omnipotent char", !2}
 !2 = !{!"Simple C/C++ TBAA"}
diff --git a/llvm/test/Transforms/PhaseOrdering/memset-combine.ll b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
index d1de11258ed91..b1a5881bcaa9c 100644
--- a/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
+++ b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
@@ -6,10 +6,8 @@
 define void @merge_memset(ptr %p, i1 %cond) {
 ; CHECK-LABEL: define void @merge_memset(
 ; CHECK-SAME: ptr [[P:%.*]], i1 [[COND:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], ptr null, ptr [[P]]
-; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(4096) [[P]], i8 0, i64 4096, i1 false)
-; CHECK-NEXT:    [[OFF:%.*]] = getelementptr inbounds nuw i8, ptr [[SEL]], i64 4096
-; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(768) [[OFF]], i8 0, i64 768, i1 false)
+; CHECK-NEXT:    [[OFF:%.*]] = getelementptr inbounds nuw i8, ptr [[P]], i64 4096
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 1 [[P]], i8 0, i64 4864, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %sel = select i1 %cond, ptr null, ptr %p

``````````

</details>


https://github.com/llvm/llvm-project/pull/128365


More information about the llvm-commits mailing list