[llvm] 126016b - [InstCombine] Simplify nonnull pointers (#128111)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 21 23:30:08 PST 2025


Author: Yingwei Zheng
Date: 2025-02-22T15:30:04+08:00
New Revision: 126016b6620e1e64e1685e06f4c6bb1e5a2508ed

URL: https://github.com/llvm/llvm-project/commit/126016b6620e1e64e1685e06f4c6bb1e5a2508ed
DIFF: https://github.com/llvm/llvm-project/commit/126016b6620e1e64e1685e06f4c6bb1e5a2508ed.diff

LOG: [InstCombine] Simplify nonnull pointers (#128111)

This patch is a follow-up to
https://github.com/llvm/llvm-project/pull/127979. It introduces a helper,
`simplifyNonNullOperand`, to avoid duplicating logic. It also addresses the
one-use issue in `visitLoadInst`, as discussed in
https://github.com/llvm/llvm-project/pull/127979#issuecomment-2671013972.
The `nonnull` attribute on call arguments and return values is also supported. Proof:
https://alive2.llvm.org/ce/z/MCKgT9

Added: 
    llvm/test/Transforms/PhaseOrdering/memset-combine.ll

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
    llvm/test/Transforms/InstCombine/nonnull-select.ll
    llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 400ebcf493713..54f777ab20a7a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3993,10 +3993,19 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
   unsigned ArgNo = 0;
 
   for (Value *V : Call.args()) {
-    if (V->getType()->isPointerTy() &&
-        !Call.paramHasAttr(ArgNo, Attribute::NonNull) &&
-        isKnownNonZero(V, getSimplifyQuery().getWithInstruction(&Call)))
-      ArgNos.push_back(ArgNo);
+    if (V->getType()->isPointerTy()) {
+      // Simplify the nonnull operand if the parameter is known to be nonnull.
+      // Otherwise, try to infer nonnull for it.
+      if (Call.paramHasNonNullAttr(ArgNo, /*AllowUndefOrPoison=*/true)) {
+        if (Value *Res = simplifyNonNullOperand(V)) {
+          replaceOperand(Call, ArgNo, Res);
+          Changed = true;
+        }
+      } else if (isKnownNonZero(V,
+                                getSimplifyQuery().getWithInstruction(&Call))) {
+        ArgNos.push_back(ArgNo);
+      }
+    }
     ArgNo++;
   }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 83e1da98deeda..71c80d4c401f8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -455,6 +455,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   Instruction *hoistFNegAboveFMulFDiv(Value *FNegOp, Instruction &FMFSource);
 
+  /// Simplify \p V given that it is known to be non-null.
+  /// Returns the simplified value if possible, otherwise returns nullptr.
+  Value *simplifyNonNullOperand(Value *V);
+
 public:
   /// Create and insert the idiom we use to indicate a block is unreachable
   /// without having to rewrite the CFG from within InstCombine.

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index d5534c15cca76..89fc1051b18dc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -982,6 +982,19 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) {
   return false;
 }
 
+/// Returns the other arm if \p V is a select with a null arm, else nullptr.
+Value *InstCombinerImpl::simplifyNonNullOperand(Value *V) {
+  if (auto *Sel = dyn_cast<SelectInst>(V)) {
+    if (isa<ConstantPointerNull>(Sel->getOperand(1)))
+      return Sel->getOperand(2); // select(c, null, P) -> P
+    // Both folds rely on the caller's guarantee that V is non-null.
+    if (isa<ConstantPointerNull>(Sel->getOperand(2)))
+      return Sel->getOperand(1); // select(c, P, null) -> P
+  }
+  // TODO: Recursively simplify nonnull value to handle one-use inbounds GEPs.
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
   Value *Op = LI.getOperand(0);
   if (Value *Res = simplifyLoadInst(&LI, Op, SQ.getWithInstruction(&LI)))
@@ -1059,20 +1072,13 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
         V2->copyMetadata(LI, Metadata::PoisonGeneratingIDs);
         return SelectInst::Create(SI->getCondition(), V1, V2);
       }
-
-      // load (select (cond, null, P)) -> load P
-      if (isa<ConstantPointerNull>(SI->getOperand(1)) &&
-          !NullPointerIsDefined(SI->getFunction(),
-                                LI.getPointerAddressSpace()))
-        return replaceOperand(LI, 0, SI->getOperand(2));
-
-      // load (select (cond, P, null)) -> load P
-      if (isa<ConstantPointerNull>(SI->getOperand(2)) &&
-          !NullPointerIsDefined(SI->getFunction(),
-                                LI.getPointerAddressSpace()))
-        return replaceOperand(LI, 0, SI->getOperand(1));
     }
   }
+
+  if (!NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))
+    if (Value *V = simplifyNonNullOperand(Op))
+      return replaceOperand(LI, 0, V);
+
   return nullptr;
 }
 
@@ -1437,19 +1443,9 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
   if (isa<UndefValue>(Val))
     return eraseInstFromFunction(SI);
 
-  // TODO: Add a helper to simplify the pointer operand for all memory
-  // instructions.
-  // store val, (select (cond, null, P)) -> store val, P
-  // store val, (select (cond, P, null)) -> store val, P
-  if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())) {
-    if (SelectInst *Sel = dyn_cast<SelectInst>(Ptr)) {
-      if (isa<ConstantPointerNull>(Sel->getOperand(1)))
-        return replaceOperand(SI, 1, Sel->getOperand(2));
-
-      if (isa<ConstantPointerNull>(Sel->getOperand(2)))
-        return replaceOperand(SI, 1, Sel->getOperand(1));
-    }
-  }
+  if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace()))
+    if (Value *V = simplifyNonNullOperand(Ptr))
+      return replaceOperand(SI, 1, V);
 
   return nullptr;
 }

diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 5621511570b58..b7748f59a0cfc 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3587,10 +3587,23 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) {
 
 Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
   Value *RetVal = RI.getReturnValue();
-  if (!RetVal || !AttributeFuncs::isNoFPClassCompatibleType(RetVal->getType()))
+  if (!RetVal)
     return nullptr;
 
   Function *F = RI.getFunction();
+  Type *RetTy = RetVal->getType();
+  if (RetTy->isPointerTy()) {
+    if (F->hasRetAttribute(Attribute::NonNull) ||
+        (F->getAttributes().getRetDereferenceableBytes() > 0 &&
+         !NullPointerIsDefined(F, RetTy->getPointerAddressSpace()))) {
+      if (Value *V = simplifyNonNullOperand(RetVal))
+        return replaceOperand(RI, 0, V);
+    }
+  }
+
+  if (!AttributeFuncs::isNoFPClassCompatibleType(RetTy))
+    return nullptr;
+
   FPClassTest ReturnClass = F->getAttributes().getRetNoFPClass();
   if (ReturnClass == fcNone)
     return nullptr;

diff  --git a/llvm/test/Transforms/InstCombine/nonnull-select.ll b/llvm/test/Transforms/InstCombine/nonnull-select.ll
index 3fab2dfb41a42..cc000b4c88164 100644
--- a/llvm/test/Transforms/InstCombine/nonnull-select.ll
+++ b/llvm/test/Transforms/InstCombine/nonnull-select.ll
@@ -5,10 +5,7 @@
 
 define nonnull ptr @pr48975(ptr %.0) {
 ; CHECK-LABEL: @pr48975(
-; CHECK-NEXT:    [[DOT1:%.*]] = load ptr, ptr [[DOT0:%.*]], align 8
-; CHECK-NEXT:    [[DOT2:%.*]] = icmp eq ptr [[DOT1]], null
-; CHECK-NEXT:    [[DOT4:%.*]] = select i1 [[DOT2]], ptr null, ptr [[DOT0]]
-; CHECK-NEXT:    ret ptr [[DOT4]]
+; CHECK-NEXT:    ret ptr [[DOT4:%.*]]
 ;
   %.1 = load ptr, ptr %.0, align 8
   %.2 = icmp eq ptr %.1, null
@@ -18,8 +15,7 @@ define nonnull ptr @pr48975(ptr %.0) {
 
 define nonnull ptr @nonnull_ret(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_ret(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr %p, ptr null
   ret ptr %res
@@ -27,8 +23,7 @@ define nonnull ptr @nonnull_ret(i1 %cond, ptr %p) {
 
 define nonnull ptr @nonnull_ret2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_ret2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr null, ptr %p
   ret ptr %res
@@ -36,8 +31,7 @@ define nonnull ptr @nonnull_ret2(i1 %cond, ptr %p) {
 
 define nonnull noundef ptr @nonnull_noundef_ret(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_ret(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr %p, ptr null
   ret ptr %res
@@ -45,8 +39,7 @@ define nonnull noundef ptr @nonnull_noundef_ret(i1 %cond, ptr %p) {
 
 define nonnull noundef ptr @nonnull_noundef_ret2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_ret2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr null, ptr %p
   ret ptr %res
@@ -55,8 +48,7 @@ define nonnull noundef ptr @nonnull_noundef_ret2(i1 %cond, ptr %p) {
 
 define void @nonnull_call(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_call(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    call void @f(ptr nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr %p, ptr null
@@ -66,8 +58,7 @@ define void @nonnull_call(i1 %cond, ptr %p) {
 
 define void @nonnull_call2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_call2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    call void @f(ptr nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr null, ptr %p
@@ -77,8 +68,7 @@ define void @nonnull_call2(i1 %cond, ptr %p) {
 
 define void @nonnull_noundef_call(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_call(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr %p, ptr null
@@ -88,8 +78,7 @@ define void @nonnull_noundef_call(i1 %cond, ptr %p) {
 
 define void @nonnull_noundef_call2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_call2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr null, ptr %p

diff  --git a/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll b/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
index d8ef0723cf09e..f6bf57a678786 100644
--- a/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
+++ b/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
@@ -1,24 +1,21 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -passes='instcombine,early-cse<memssa>' -S %s | FileCheck %s
 
-; FIXME: We can remove the store instruction in the exit block
 define i32 @load_store_sameval(ptr %p, i1 %cond1, i1 %cond2) {
 ; CHECK-LABEL: define i32 @load_store_sameval(
 ; CHECK-SAME: ptr [[P:%.*]], i1 [[COND1:%.*]], i1 [[COND2:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[SPEC_SELECT:%.*]] = select i1 [[COND1]], ptr null, ptr [[P]]
-; CHECK-NEXT:    [[PRE:%.*]] = load i32, ptr [[SPEC_SELECT]], align 4
+; CHECK-NEXT:    [[PRE:%.*]] = load i32, ptr [[P]], align 4
 ; CHECK-NEXT:    br label %[[BLOCK:.*]]
 ; CHECK:       [[BLOCK]]:
 ; CHECK-NEXT:    br label %[[BLOCK2:.*]]
 ; CHECK:       [[BLOCK2]]:
 ; CHECK-NEXT:    br i1 [[COND2]], label %[[BLOCK3:.*]], label %[[EXIT:.*]]
 ; CHECK:       [[BLOCK3]]:
-; CHECK-NEXT:    [[LOAD:%.*]] = load double, ptr [[SPEC_SELECT]], align 8
+; CHECK-NEXT:    [[LOAD:%.*]] = load double, ptr [[P]], align 8
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp une double [[LOAD]], 0.000000e+00
 ; CHECK-NEXT:    br i1 [[CMP]], label %[[BLOCK]], label %[[BLOCK2]]
 ; CHECK:       [[EXIT]]:
-; CHECK-NEXT:    store i32 [[PRE]], ptr [[P]], align 4
 ; CHECK-NEXT:    ret i32 0
 ;
 entry:

diff  --git a/llvm/test/Transforms/PhaseOrdering/memset-combine.ll b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
new file mode 100644
index 0000000000000..d1de11258ed91
--- /dev/null
+++ b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
@@ -0,0 +1,20 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt < %s -passes=instcombine,memcpyopt -S | FileCheck %s
+
+; FIXME: These two memset calls should be merged into a single one. The GEP feeding the second memset is still based on %sel rather than %p, which presumably blocks memcpyopt.
+define void @merge_memset(ptr %p, i1 %cond) {
+; CHECK-LABEL: define void @merge_memset(
+; CHECK-SAME: ptr [[P:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], ptr null, ptr [[P]]
+; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(4096) [[P]], i8 0, i64 4096, i1 false)
+; CHECK-NEXT:    [[OFF:%.*]] = getelementptr inbounds nuw i8, ptr [[SEL]], i64 4096
+; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(768) [[OFF]], i8 0, i64 768, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %sel = select i1 %cond, ptr null, ptr %p
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull %sel, i8 0, i64 4096, i1 false)
+  %off = getelementptr inbounds nuw i8, ptr %sel, i64 4096
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull %off, i8 0, i64 768, i1 false)
+  ret void
+}


        


More information about the llvm-commits mailing list