[llvm] [GVN] Restrict equality propagation for pointers (PR #82458)

Usman Nadeem via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 23 19:13:25 PDT 2024


https://github.com/UsmanNadeem updated https://github.com/llvm/llvm-project/pull/82458

>From f4ebdf067e22fd50e54ac1a7158a900bfe789f8c Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Thu, 10 Aug 2023 17:33:08 -0700
Subject: [PATCH 1/7] [GVN] Restrict equality propagation for pointers

Reviving https://reviews.llvm.org/D143129

This patch does the following:

Adds the following functions:
- replaceDominatedUsesWithIf() that takes a callback.

- canReplacePointersIfEqual(...) returns true if the underlying object
is the same, and for null and constant dereferenceable pointer replacements.

- canReplacePointersInUseIfEqual(...) returns true for the above, and also
if the use is in an icmp/ptrtoint or in a phi/select that only feeds into them.

Updates GVN using the functions above so that the pointer replacements
are only made using the above API.

Change-Id: I4927ea452734458be028854ef0e5cbcd81955910
---
 llvm/include/llvm/Analysis/Loads.h         |  15 +-
 llvm/include/llvm/Transforms/Utils/Local.h |  12 ++
 llvm/lib/Analysis/Loads.cpp                |  78 ++++++--
 llvm/lib/Transforms/Scalar/GVN.cpp         |  24 ++-
 llvm/lib/Transforms/Utils/Local.cpp        |  30 ++-
 llvm/test/Transforms/GVN/condprop.ll       | 207 ++++++++++++++++++++-
 llvm/unittests/Analysis/LoadsTest.cpp      |  49 +++--
 7 files changed, 354 insertions(+), 61 deletions(-)

diff --git a/llvm/include/llvm/Analysis/Loads.h b/llvm/include/llvm/Analysis/Loads.h
index 0926093bba99de..e39a1d6cafad23 100644
--- a/llvm/include/llvm/Analysis/Loads.h
+++ b/llvm/include/llvm/Analysis/Loads.h
@@ -173,14 +173,15 @@ Value *findAvailablePtrLoadStore(const MemoryLocation &Loc, Type *AccessTy,
                                  unsigned MaxInstsToScan, BatchAAResults *AA,
                                  bool *IsLoadCSE, unsigned *NumScanedInst);
 
-/// Returns true if a pointer value \p A can be replace with another pointer
-/// value \B if they are deemed equal through some means (e.g. information from
+/// Returns true if a pointer value \p From can be replaced with another pointer
+/// value \To if they are deemed equal through some means (e.g. information from
 /// conditions).
-/// NOTE: the current implementations is incomplete and unsound. It does not
-/// reject all invalid cases yet, but will be made stricter in the future. In
-/// particular this means returning true means unknown if replacement is safe.
-bool canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
-                               Instruction *CtxI);
+/// NOTE: The current implementation allows replacement in Icmp and PtrToInt
+/// instructions, as well as when we are replacing with a null pointer.
+/// Additionally it also allows replacement of pointers when both pointers have
+/// the same underlying object.
+bool canReplacePointersIfEqual(const Value *From, const Value *To);
+bool canReplacePointersInUseIfEqual(const Use &U, const Value *To);
 }
 
 #endif
diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h
index 2df3c9049c7d62..d0dbf1951fe8b6 100644
--- a/llvm/include/llvm/Transforms/Utils/Local.h
+++ b/llvm/include/llvm/Transforms/Utils/Local.h
@@ -432,6 +432,18 @@ unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
 /// the end of the given BasicBlock. Returns the number of replacements made.
 unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
                                   const BasicBlock *BB);
+/// Replace each use of 'From' with 'To' if that use is dominated by
+/// the given edge and the callback ShouldReplace returns true. Returns the
+/// number of replacements made.
+unsigned replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Edge,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
+/// Replace each use of 'From' with 'To' if that use is dominated by
+/// the end of the given BasicBlock and the callback ShouldReplace returns true.
+/// Returns the number of replacements made.
+unsigned replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
 
 /// Return true if this call calls a gc leaf function.
 ///
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 5916d2ab48ecec..c9671577441a63 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -708,22 +708,66 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
   return Available;
 }
 
-bool llvm::canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
-                                     Instruction *CtxI) {
-  Type *Ty = A->getType();
-  assert(Ty == B->getType() && Ty->isPointerTy() &&
-         "values must have matching pointer types");
-
-  // NOTE: The checks in the function are incomplete and currently miss illegal
-  // cases! The current implementation is a starting point and the
-  // implementation should be made stricter over time.
-  if (auto *C = dyn_cast<Constant>(B)) {
-    // Do not allow replacing a pointer with a constant pointer, unless it is
-    // either null or at least one byte is dereferenceable.
-    APInt OneByte(DL.getPointerTypeSizeInBits(Ty), 1);
-    return C->isNullValue() ||
-           isDereferenceableAndAlignedPointer(B, Align(1), OneByte, DL, CtxI);
-  }
+// Returns true if a use is either in an ICmp/PtrToInt or a Phi/Select that only
+// feeds into them.
+static bool isPointerUseReplacable(const Use &U, int MaxLookup = 6) {
+  if (MaxLookup == 0)
+    return false;
+
+  const User *User = U.getUser();
+  if (isa<ICmpInst>(User))
+    return true;
+  if (isa<PtrToIntInst>(User))
+    return true;
+  if (isa<PHINode, SelectInst>(User) &&
+      all_of(User->uses(), [&](const Use &Use) {
+        return isPointerUseReplacable(Use, MaxLookup - 1);
+      }))
+    return true;
+
+  return false;
+}
+
+static const DataLayout &getDLFromVal(const Value *V) {
+  if (const Argument *A = dyn_cast<Argument>(V))
+    return A->getParent()->getParent()->getDataLayout();
+  if (const Instruction *I = dyn_cast<Instruction>(V))
+    return I->getModule()->getDataLayout();
+  if (const GlobalValue *GV = dyn_cast<GlobalValue>(V))
+    return GV->getParent()->getDataLayout();
+  llvm_unreachable("Unknown Value type");
+}
+
+// Returns true if `To` is a null pointer, constant dereferenceable pointer or
+// both pointers have the same underlying objects.
+static bool isPointerAlwaysReplacable(const Value *From, const Value *To) {
+  if (isa<ConstantPointerNull>(To))
+    return true;
+  if (isa<Constant>(To) &&
+      isDereferenceablePointer(To, Type::getInt8Ty(To->getContext()),
+                               getDLFromVal(From)))
+    return true;
+  if (getUnderlyingObject(From) == getUnderlyingObject(To))
+    return true;
+  return false;
+}
+
+bool llvm::canReplacePointersInUseIfEqual(const Use &U, const Value *To) {
+  assert(U->getType() == To->getType() && "values must have matching types");
+  // Not a pointer, just return true.
+  if (!To->getType()->isPointerTy())
+    return true;
+
+  if (isPointerAlwaysReplacable(&*U, To))
+    return true;
+  return isPointerUseReplacable(U);
+}
+
+bool llvm::canReplacePointersIfEqual(const Value *From, const Value *To) {
+  assert(From->getType() == To->getType() && "values must have matching types");
+  // Not a pointer, just return true.
+  if (!From->getType()->isPointerTy())
+    return true;
 
-  return true;
+  return isPointerAlwaysReplacable(From, To);
 }
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index dcb1ed334b6103..d4b16836d728fa 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -33,6 +33,7 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/InstructionPrecedenceTracking.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
@@ -2443,7 +2444,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     // using the leader table is about compiling faster, not optimizing better).
     // The leader table only tracks basic blocks, not edges. Only add to if we
     // have the simple case where the edge dominates the end.
-    if (RootDominatesEnd && !isa<Instruction>(RHS))
+    if (RootDominatesEnd && !isa<Instruction>(RHS) &&
+        canReplacePointersIfEqual(LHS, RHS))
       addToLeaderTable(LVN, RHS, Root.getEnd());
 
     // Replace all occurrences of 'LHS' with 'RHS' everywhere in the scope.  As
@@ -2452,14 +2454,18 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     if (!LHS->hasOneUse()) {
       unsigned NumReplacements =
           DominatesByEdge
-              ? replaceDominatedUsesWith(LHS, RHS, *DT, Root)
-              : replaceDominatedUsesWith(LHS, RHS, *DT, Root.getStart());
-
-      Changed |= NumReplacements > 0;
-      NumGVNEqProp += NumReplacements;
-      // Cached information for anything that uses LHS will be invalid.
-      if (MD)
-        MD->invalidateCachedPointerInfo(LHS);
+              ? replaceDominatedUsesWithIf(LHS, RHS, *DT, Root,
+                                           canReplacePointersInUseIfEqual)
+              : replaceDominatedUsesWithIf(LHS, RHS, *DT, Root.getStart(),
+                                           canReplacePointersInUseIfEqual);
+
+      if (NumReplacements > 0) {
+        Changed = true;
+        NumGVNEqProp += NumReplacements;
+        // Cached information for anything that uses LHS will be invalid.
+        if (MD)
+          MD->invalidateCachedPointerInfo(LHS);
+      }
     }
 
     // Now try to deduce additional equalities from this one. For example, if
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 075eeb5b19fd2b..2d2e8f9d28805b 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -3395,15 +3395,18 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
 }
 
 template <typename RootType, typename DominatesFn>
-static unsigned replaceDominatedUsesWith(Value *From, Value *To,
-                                         const RootType &Root,
-                                         const DominatesFn &Dominates) {
+static unsigned replaceDominatedUsesWith(
+    Value *From, Value *To, const RootType &Root, const DominatesFn &Dominates,
+    std::optional<function_ref<bool(const Use &U, const Value *To)>>
+        ShouldReplace) {
   assert(From->getType() == To->getType());
 
   unsigned Count = 0;
   for (Use &U : llvm::make_early_inc_range(From->uses())) {
     if (!Dominates(Root, U))
       continue;
+    if (ShouldReplace.has_value() && !ShouldReplace.value()(U, To))
+      continue;
     LLVM_DEBUG(dbgs() << "Replace dominated use of '";
                From->printAsOperand(dbgs());
                dbgs() << "' with " << *To << " in " << *U.getUser() << "\n");
@@ -3434,7 +3437,7 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
   auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
     return DT.dominates(Root, U);
   };
-  return ::replaceDominatedUsesWith(From, To, Root, Dominates);
+  return ::replaceDominatedUsesWith(From, To, Root, Dominates, std::nullopt);
 }
 
 unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
@@ -3443,9 +3446,26 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
   auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
     return DT.dominates(BB, U);
   };
-  return ::replaceDominatedUsesWith(From, To, BB, Dominates);
+  return ::replaceDominatedUsesWith(From, To, BB, Dominates, std::nullopt);
+}
+
+unsigned llvm::replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
+  auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
+    return DT.dominates(Root, U);
+  };
+  return ::replaceDominatedUsesWith(From, To, Root, Dominates, ShouldReplace);
 }
 
+unsigned llvm::replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
+  auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
+    return DT.dominates(BB, U);
+  };
+  return ::replaceDominatedUsesWith(From, To, BB, Dominates, ShouldReplace);
+}
 bool llvm::callsGCLeafFunction(const CallBase *Call,
                                const TargetLibraryInfo &TLI) {
   // Check if the function is specifically marked as a gc leaf function.
diff --git a/llvm/test/Transforms/GVN/condprop.ll b/llvm/test/Transforms/GVN/condprop.ll
index 6b1e4d10601099..2b202b1701150e 100644
--- a/llvm/test/Transforms/GVN/condprop.ll
+++ b/llvm/test/Transforms/GVN/condprop.ll
@@ -521,15 +521,16 @@ define i32 @test13(ptr %ptr1, ptr %ptr2) {
 ; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i32, ptr [[PTR2:%.*]], i32 1
 ; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr i32, ptr [[PTR2]], i32 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq ptr [[PTR1:%.*]], [[PTR2]]
-; CHECK-NEXT:    [[VAL2_PRE:%.*]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
 ; CHECK:       if:
+; CHECK-NEXT:    [[VAL1:%.*]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    br label [[END]]
 ; CHECK:       end:
-; CHECK-NEXT:    [[PHI1:%.*]] = phi ptr [ [[PTR2]], [[IF]] ], [ [[GEP1]], [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[PHI2:%.*]] = phi i32 [ [[VAL2_PRE]], [[IF]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[PHI1:%.*]] = phi ptr [ [[PTR1]], [[IF]] ], [ [[GEP1]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[PHI2:%.*]] = phi i32 [ [[VAL1]], [[IF]] ], [ 0, [[ENTRY]] ]
 ; CHECK-NEXT:    store i32 0, ptr [[PHI1]], align 4
-; CHECK-NEXT:    [[RET:%.*]] = add i32 [[PHI2]], [[VAL2_PRE]]
+; CHECK-NEXT:    [[VAL2:%.*]] = load i32, ptr [[GEP2]], align 4
+; CHECK-NEXT:    [[RET:%.*]] = add i32 [[PHI2]], [[VAL2]]
 ; CHECK-NEXT:    ret i32 [[RET]]
 ;
 entry:
@@ -574,7 +575,7 @@ define void @test14(ptr %ptr1, ptr noalias %ptr2) {
 ; CHECK:       if2:
 ; CHECK-NEXT:    br label [[LOOP_END]]
 ; CHECK:       loop.end:
-; CHECK-NEXT:    [[PHI3:%.*]] = phi ptr [ [[PTR2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ]
+; CHECK-NEXT:    [[PHI3:%.*]] = phi ptr [ [[GEP2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ]
 ; CHECK-NEXT:    [[VAL3]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    store i32 [[VAL3]], ptr [[PHI3]], align 4
 ; CHECK-NEXT:    br i1 undef, label [[LOOP]], label [[IF1]]
@@ -609,3 +610,199 @@ loop.end:
   %gep3 = getelementptr inbounds i32, ptr %ptr1, i32 1
   br i1 undef, label %loop, label %if1
 }
+
+; Make sure that the call to use_ptr does not have %p1
+define void @single_phi1(ptr %p1) {
+; CHECK-LABEL: @single_phi1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  unreachable
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  call void @use_ptr(ptr %phi1)
+  ret void
+}
+
+define void @single_phi2(ptr %p1) {
+; CHECK-LABEL: @single_phi2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  br label %bb4
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb2 ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  call void @use_ptr(ptr %phi1)
+  ret void
+}
+
+define void @multiple_phi1(ptr %p1) {
+; CHECK-LABEL: @multiple_phi1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    br label [[BB5:%.*]]
+; CHECK:       bb5:
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    br label [[BB5]]
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  unreachable
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  br label %bb5
+
+bb5:
+  %phi2 = phi ptr [ poison, %bb5 ], [ %phi1, %bb4 ]
+  call void @use_ptr(ptr %phi2)
+  br label %bb5
+}
+
+define void @multiple_phi2(ptr %p1) {
+; CHECK-LABEL: @multiple_phi2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    br label [[BB5:%.*]]
+; CHECK:       bb5:
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    br label [[BB5]]
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  br label %bb4
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb2 ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  br label %bb5
+
+bb5:
+  %phi2 = phi ptr [ poison, %bb5 ], [ %phi1, %bb4 ]
+  call void @use_ptr(ptr %phi2)
+  br label %bb5
+}
+
+declare void @use_bool(i1)
+declare void @use_ptr(ptr)
diff --git a/llvm/unittests/Analysis/LoadsTest.cpp b/llvm/unittests/Analysis/LoadsTest.cpp
index 0111cfeefa41ae..0694889596fa23 100644
--- a/llvm/unittests/Analysis/LoadsTest.cpp
+++ b/llvm/unittests/Analysis/LoadsTest.cpp
@@ -68,35 +68,48 @@ TEST(LoadsTest, CanReplacePointersIfEqual) {
                                       R"IR(
 @y = common global [1 x i32] zeroinitializer, align 4
 @x = common global [1 x i32] zeroinitializer, align 4
-
 declare void @use(i32*)
 
-define void @f(i32* %p) {
+define void @f(i32* %p1, i32* %p2, i64 %i) {
   call void @use(i32* getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0))
-  call void @use(i32* getelementptr inbounds (i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @x, i64 0, i64 0), i64 1))
+
+  %p1_idx = getelementptr inbounds i32, i32* %p1, i64 %i
+  call void @use(i32* %p1_idx)
+
+  %icmp = icmp eq i32* %p1, getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0)
+  %ptrInt = ptrtoint i32* %p1 to i64
   ret void
 }
 )IR");
-  const auto &DL = M->getDataLayout();
   auto *GV = M->getNamedValue("f");
   ASSERT_TRUE(GV);
   auto *F = dyn_cast<Function>(GV);
   ASSERT_TRUE(F);
 
-  // NOTE: the implementation of canReplacePointersIfEqual is incomplete.
-  // Currently the only the cases it returns false for are really sound and
-  // returning true means unknown.
-  Value *P = &*F->arg_begin();
+  Value *P1 = &*F->arg_begin();
+  Value *P2 = F->getArg(1);
+  Value *NullPtr = Constant::getNullValue(P1->getType());
   auto InstIter = F->front().begin();
-  Value *ConstDerefPtr = *cast<CallInst>(&*InstIter)->arg_begin();
-  // ConstDerefPtr is a constant pointer that is provably de-referenceable. We
-  // can replace an arbitrary pointer with it.
-  EXPECT_TRUE(canReplacePointersIfEqual(P, ConstDerefPtr, DL, nullptr));
+  CallInst *UserOfY = cast<CallInst>(&*InstIter);
+  Value *ConstDerefPtr = UserOfY->getArgOperand(0);
+  // We cannot replace two pointers in arbitrary instructions unless we are
+  // replacing with null, a constant dereferencable pointer or they have the
+  // same underlying object.
+  EXPECT_FALSE(canReplacePointersIfEqual(ConstDerefPtr, P1));
+  EXPECT_FALSE(canReplacePointersIfEqual(P1, P2));
+  EXPECT_TRUE(canReplacePointersIfEqual(P1, ConstDerefPtr));
+  EXPECT_TRUE(canReplacePointersIfEqual(P1, NullPtr));
+
+  GetElementPtrInst *BasedOnP1 = cast<GetElementPtrInst>(&*++InstIter);
+  EXPECT_TRUE(canReplacePointersIfEqual(BasedOnP1, P1));
+  EXPECT_FALSE(canReplacePointersIfEqual(BasedOnP1, P2));
 
-  ++InstIter;
-  Value *ConstUnDerefPtr = *cast<CallInst>(&*InstIter)->arg_begin();
-  // ConstUndDerefPtr is a constant pointer that is provably not
-  // de-referenceable. We cannot replace an arbitrary pointer with it.
-  EXPECT_FALSE(
-      canReplacePointersIfEqual(ConstDerefPtr, ConstUnDerefPtr, DL, nullptr));
+  // // We can replace two arbitrary pointers in icmp and ptrtoint instructions.
+  auto P1UseIter = P1->use_begin();
+  const Use &PtrToIntUse = *P1UseIter;
+  const Use &IcmpUse = *++P1UseIter;
+  const Use &GEPUse = *++P1UseIter;
+  EXPECT_FALSE(canReplacePointersInUseIfEqual(GEPUse, P2));
+  EXPECT_TRUE(canReplacePointersInUseIfEqual(PtrToIntUse, P2));
+  EXPECT_TRUE(canReplacePointersInUseIfEqual(IcmpUse, P2));
 }

>From 1f141c4b2c2ec718b6be1ea187805320ac0a9094 Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Mon, 26 Feb 2024 14:49:32 -0800
Subject: [PATCH 2/7] Fix typos, remove some uses of undef/poison

Change-Id: I7d0ddd8e6d2bab90de72d2c3f691cfbea9c3e257
---
 llvm/lib/Analysis/Loads.cpp          |  6 +--
 llvm/test/Transforms/GVN/condprop.ll | 80 ++++++++++++++--------------
 2 files changed, 43 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index c9671577441a63..f370dfcec709ca 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -740,7 +740,7 @@ static const DataLayout &getDLFromVal(const Value *V) {
 
 // Returns true if `To` is a null pointer, constant dereferenceable pointer or
 // both pointers have the same underlying objects.
-static bool isPointerAlwaysReplacable(const Value *From, const Value *To) {
+static bool isPointerAlwaysReplaceable(const Value *From, const Value *To) {
   if (isa<ConstantPointerNull>(To))
     return true;
   if (isa<Constant>(To) &&
@@ -758,7 +758,7 @@ bool llvm::canReplacePointersInUseIfEqual(const Use &U, const Value *To) {
   if (!To->getType()->isPointerTy())
     return true;
 
-  if (isPointerAlwaysReplacable(&*U, To))
+  if (isPointerAlwaysReplaceable(&*U, To))
     return true;
   return isPointerUseReplacable(U);
 }
@@ -769,5 +769,5 @@ bool llvm::canReplacePointersIfEqual(const Value *From, const Value *To) {
   if (!From->getType()->isPointerTy())
     return true;
 
-  return isPointerAlwaysReplacable(From, To);
+  return isPointerAlwaysReplaceable(From, To);
 }
diff --git a/llvm/test/Transforms/GVN/condprop.ll b/llvm/test/Transforms/GVN/condprop.ll
index 2b202b1701150e..6402a23157729c 100644
--- a/llvm/test/Transforms/GVN/condprop.ll
+++ b/llvm/test/Transforms/GVN/condprop.ll
@@ -214,11 +214,11 @@ define void @test4(i1 %b, i32 %x) {
 ; CHECK-NEXT:    br i1 [[B:%.*]], label [[SW:%.*]], label [[CASE3:%.*]]
 ; CHECK:       sw:
 ; CHECK-NEXT:    switch i32 [[X:%.*]], label [[DEFAULT:%.*]] [
-; CHECK-NEXT:    i32 0, label [[CASE0:%.*]]
-; CHECK-NEXT:    i32 1, label [[CASE1:%.*]]
-; CHECK-NEXT:    i32 2, label [[CASE0]]
-; CHECK-NEXT:    i32 3, label [[CASE3]]
-; CHECK-NEXT:    i32 4, label [[DEFAULT]]
+; CHECK-NEXT:      i32 0, label [[CASE0:%.*]]
+; CHECK-NEXT:      i32 1, label [[CASE1:%.*]]
+; CHECK-NEXT:      i32 2, label [[CASE0]]
+; CHECK-NEXT:      i32 3, label [[CASE3]]
+; CHECK-NEXT:      i32 4, label [[DEFAULT]]
 ; CHECK-NEXT:    ]
 ; CHECK:       default:
 ; CHECK-NEXT:    call void @bar(i32 [[X]])
@@ -553,14 +553,14 @@ end:
   ret i32 %ret
 }
 
-define void @test14(ptr %ptr1, ptr noalias %ptr2) {
+define void @test14(ptr %ptr1, ptr noalias %ptr2, i1 %b1, i1 %b2) {
 ; CHECK-LABEL: @test14(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i32, ptr [[PTR1:%.*]], i32 1
 ; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i32, ptr [[PTR1]], i32 2
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    br i1 undef, label [[LOOP_IF1_CRIT_EDGE:%.*]], label [[THEN:%.*]]
+; CHECK-NEXT:    br i1 [[B1:%.*]], label [[LOOP_IF1_CRIT_EDGE:%.*]], label [[THEN:%.*]]
 ; CHECK:       loop.if1_crit_edge:
 ; CHECK-NEXT:    [[VAL2_PRE:%.*]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    br label [[IF1:%.*]]
@@ -578,7 +578,7 @@ define void @test14(ptr %ptr1, ptr noalias %ptr2) {
 ; CHECK-NEXT:    [[PHI3:%.*]] = phi ptr [ [[GEP2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ]
 ; CHECK-NEXT:    [[VAL3]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    store i32 [[VAL3]], ptr [[PHI3]], align 4
-; CHECK-NEXT:    br i1 undef, label [[LOOP]], label [[IF1]]
+; CHECK-NEXT:    br i1 [[B2:%.*]], label [[LOOP]], label [[IF1]]
 ;
 entry:
   %gep1 = getelementptr inbounds i32, ptr %ptr1, i32 1
@@ -587,7 +587,7 @@ entry:
 
 loop:
   %phi1 = phi ptr [ %gep3, %loop.end ], [ %gep1, %entry ]
-  br i1 undef, label %if1, label %then
+  br i1 %b1, label %if1, label %then
 
 
 if1:
@@ -608,20 +608,20 @@ loop.end:
   %val3 = load i32, ptr %gep2, align 4
   store i32 %val3, ptr %phi3, align 4
   %gep3 = getelementptr inbounds i32, ptr %ptr1, i32 1
-  br i1 undef, label %loop, label %if1
+  br i1 %b2, label %loop, label %if1
 }
 
 ; Make sure that the call to use_ptr does not have %p1
-define void @single_phi1(ptr %p1) {
+define void @single_phi1(ptr %p0, ptr %p1, i8 %s) {
 ; CHECK-LABEL: @single_phi1(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P0:%.*]], align 8
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
-; CHECK-NEXT:    i8 0, label [[BB1]]
-; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    switch i8 [[S:%.*]], label [[BB2:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB1]]
+; CHECK-NEXT:      i8 1, label [[BB3:%.*]]
 ; CHECK-NEXT:    ]
 ; CHECK:       bb2:
 ; CHECK-NEXT:    unreachable
@@ -633,12 +633,12 @@ define void @single_phi1(ptr %p1) {
 ; CHECK-NEXT:    ret void
 ;
 entry:
-  %p2 = load ptr, ptr poison, align 8
+  %p2 = load ptr, ptr %p0, align 8
   %cmp1 = icmp eq ptr %p2, %p1
   br i1 %cmp1, label %bb4, label %bb1
 
 bb1:
-  switch i8 poison, label %bb2 [
+  switch i8 %s, label %bb2 [
   i8 0, label %bb1
   i8 1, label %bb3
   ]
@@ -650,23 +650,23 @@ bb3:
   br label %bb4
 
 bb4:
-  %phi1 = phi ptr [ %p2, %entry ], [ poison, %bb3 ]
+  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb3 ]
   %cmp2 = icmp eq ptr %phi1, %p1
   call void @use_bool(i1 %cmp2)
   call void @use_ptr(ptr %phi1)
   ret void
 }
 
-define void @single_phi2(ptr %p1) {
+define void @single_phi2(ptr %p0, ptr %p1, i8 %s) {
 ; CHECK-LABEL: @single_phi2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P0:%.*]], align 8
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
-; CHECK-NEXT:    i8 0, label [[BB1]]
-; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    switch i8 [[S:%.*]], label [[BB2:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB1]]
+; CHECK-NEXT:      i8 1, label [[BB3:%.*]]
 ; CHECK-NEXT:    ]
 ; CHECK:       bb2:
 ; CHECK-NEXT:    br label [[BB4]]
@@ -678,12 +678,12 @@ define void @single_phi2(ptr %p1) {
 ; CHECK-NEXT:    ret void
 ;
 entry:
-  %p2 = load ptr, ptr poison, align 8
+  %p2 = load ptr, ptr %p0, align 8
   %cmp1 = icmp eq ptr %p2, %p1
   br i1 %cmp1, label %bb4, label %bb1
 
 bb1:
-  switch i8 poison, label %bb2 [
+  switch i8 %s, label %bb2 [
   i8 0, label %bb1
   i8 1, label %bb3
   ]
@@ -695,23 +695,23 @@ bb3:
   br label %bb4
 
 bb4:
-  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb2 ], [ poison, %bb3 ]
+  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb2 ], [ %p2, %bb3 ]
   %cmp2 = icmp eq ptr %phi1, %p1
   call void @use_bool(i1 %cmp2)
   call void @use_ptr(ptr %phi1)
   ret void
 }
 
-define void @multiple_phi1(ptr %p1) {
+define void @multiple_phi1(ptr %p0, ptr %p1, i8 %s) {
 ; CHECK-LABEL: @multiple_phi1(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P0:%.*]], align 8
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
-; CHECK-NEXT:    i8 0, label [[BB1]]
-; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    switch i8 [[S:%.*]], label [[BB2:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB1]]
+; CHECK-NEXT:      i8 1, label [[BB3:%.*]]
 ; CHECK-NEXT:    ]
 ; CHECK:       bb2:
 ; CHECK-NEXT:    unreachable
@@ -725,12 +725,12 @@ define void @multiple_phi1(ptr %p1) {
 ; CHECK-NEXT:    br label [[BB5]]
 ;
 entry:
-  %p2 = load ptr, ptr poison, align 8
+  %p2 = load ptr, ptr %p0, align 8
   %cmp1 = icmp eq ptr %p2, %p1
   br i1 %cmp1, label %bb4, label %bb1
 
 bb1:
-  switch i8 poison, label %bb2 [
+  switch i8 %s, label %bb2 [
   i8 0, label %bb1
   i8 1, label %bb3
   ]
@@ -753,16 +753,16 @@ bb5:
   br label %bb5
 }
 
-define void @multiple_phi2(ptr %p1) {
+define void @multiple_phi2(ptr %p0, ptr %p1, i8 %s) {
 ; CHECK-LABEL: @multiple_phi2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P0:%.*]], align 8
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
-; CHECK-NEXT:    i8 0, label [[BB1]]
-; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    switch i8 [[S:%.*]], label [[BB2:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB1]]
+; CHECK-NEXT:      i8 1, label [[BB3:%.*]]
 ; CHECK-NEXT:    ]
 ; CHECK:       bb2:
 ; CHECK-NEXT:    br label [[BB4]]
@@ -776,12 +776,12 @@ define void @multiple_phi2(ptr %p1) {
 ; CHECK-NEXT:    br label [[BB5]]
 ;
 entry:
-  %p2 = load ptr, ptr poison, align 8
+  %p2 = load ptr, ptr %p0, align 8
   %cmp1 = icmp eq ptr %p2, %p1
   br i1 %cmp1, label %bb4, label %bb1
 
 bb1:
-  switch i8 poison, label %bb2 [
+  switch i8 %s, label %bb2 [
   i8 0, label %bb1
   i8 1, label %bb3
   ]

>From 022ddefbd30c63b86bed755478c5d2dfaeaa2b04 Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Mon, 22 Apr 2024 13:05:32 -0700
Subject: [PATCH 3/7] Use DataLayout as function argument

Change-Id: Ia70757984f307a0be2d0ef9141325be71653c321
---
 llvm/include/llvm/Analysis/Loads.h         |  6 ++-
 llvm/include/llvm/Transforms/Utils/Local.h | 19 +++++---
 llvm/lib/Analysis/Loads.cpp                | 26 ++++-------
 llvm/lib/Transforms/Scalar/EarlyCSE.cpp    |  5 ++-
 llvm/lib/Transforms/Scalar/GVN.cpp         | 14 +++---
 llvm/lib/Transforms/Scalar/LoopSink.cpp    |  2 +-
 llvm/lib/Transforms/Utils/Local.cpp        | 52 ++++++++++++----------
 llvm/unittests/Analysis/LoadsTest.cpp      | 19 ++++----
 8 files changed, 77 insertions(+), 66 deletions(-)

diff --git a/llvm/include/llvm/Analysis/Loads.h b/llvm/include/llvm/Analysis/Loads.h
index e39a1d6cafad23..a8d954b9872d97 100644
--- a/llvm/include/llvm/Analysis/Loads.h
+++ b/llvm/include/llvm/Analysis/Loads.h
@@ -180,8 +180,10 @@ Value *findAvailablePtrLoadStore(const MemoryLocation &Loc, Type *AccessTy,
 /// instructions, as well as when we are replacing with a null pointer.
 /// Additionally it also allows replacement of pointers when both pointers have
 /// the same underlying object.
-bool canReplacePointersIfEqual(const Value *From, const Value *To);
-bool canReplacePointersInUseIfEqual(const Use &U, const Value *To);
+bool canReplacePointersIfEqual(const Value *From, const Value *To,
+                               const DataLayout &DL);
+bool canReplacePointersInUseIfEqual(const Use &U, const Value *To,
+                                    const DataLayout &DL);
 }
 
 #endif
diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h
index e2143b5bfbe2fd..c73a2c0304ea29 100644
--- a/llvm/include/llvm/Transforms/Utils/Local.h
+++ b/llvm/include/llvm/Transforms/Utils/Local.h
@@ -433,24 +433,29 @@ unsigned replaceNonLocalUsesWith(Instruction *From, Value *To);
 
 /// Replace each use of 'From' with 'To' if that use is dominated by
 /// the given edge.  Returns the number of replacements made.
-unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
+unsigned replaceDominatedUsesWith(Value *From, Value *To, const DataLayout &DL,
+                                  DominatorTree &DT,
                                   const BasicBlockEdge &Edge);
 /// Replace each use of 'From' with 'To' if that use is dominated by
 /// the end of the given BasicBlock. Returns the number of replacements made.
-unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
-                                  const BasicBlock *BB);
+unsigned replaceDominatedUsesWith(Value *From, Value *To, const DataLayout &DL,
+                                  DominatorTree &DT, const BasicBlock *BB);
 /// Replace each use of 'From' with 'To' if that use is dominated by
 /// the given edge and the callback ShouldReplace returns true. Returns the
 /// number of replacements made.
 unsigned replaceDominatedUsesWithIf(
-    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Edge,
-    function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
+    Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
+    const BasicBlockEdge &Edge,
+    function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
+        ShouldReplace);
 /// Replace each use of 'From' with 'To' if that use is dominated by
 /// the end of the given BasicBlock and the callback ShouldReplace returns true.
 /// Returns the number of replacements made.
 unsigned replaceDominatedUsesWithIf(
-    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
-    function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
+    Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
+    const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
+        ShouldReplace);
 
 /// Return true if this call calls a gc leaf function.
 ///
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index ed47e38ad690ee..0b46e42e509b35 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -730,46 +730,38 @@ static bool isPointerUseReplacable(const Use &U, int MaxLookup = 6) {
   return false;
 }
 
-static const DataLayout &getDLFromVal(const Value *V) {
-  if (const Argument *A = dyn_cast<Argument>(V))
-    return A->getParent()->getParent()->getDataLayout();
-  if (const Instruction *I = dyn_cast<Instruction>(V))
-    return I->getModule()->getDataLayout();
-  if (const GlobalValue *GV = dyn_cast<GlobalValue>(V))
-    return GV->getParent()->getDataLayout();
-  llvm_unreachable("Unknown Value type");
-}
-
 // Returns true if `To` is a null pointer, constant dereferenceable pointer or
 // both pointers have the same underlying objects.
-static bool isPointerAlwaysReplaceable(const Value *From, const Value *To) {
+static bool isPointerAlwaysReplaceable(const Value *From, const Value *To,
+                                       const DataLayout &DL) {
   if (isa<ConstantPointerNull>(To))
     return true;
   if (isa<Constant>(To) &&
-      isDereferenceablePointer(To, Type::getInt8Ty(To->getContext()),
-                               getDLFromVal(From)))
+      isDereferenceablePointer(To, Type::getInt8Ty(To->getContext()), DL))
     return true;
   if (getUnderlyingObject(From) == getUnderlyingObject(To))
     return true;
   return false;
 }
 
-bool llvm::canReplacePointersInUseIfEqual(const Use &U, const Value *To) {
+bool llvm::canReplacePointersInUseIfEqual(const Use &U, const Value *To,
+                                          const DataLayout &DL) {
   assert(U->getType() == To->getType() && "values must have matching types");
   // Not a pointer, just return true.
   if (!To->getType()->isPointerTy())
     return true;
 
-  if (isPointerAlwaysReplaceable(&*U, To))
+  if (isPointerAlwaysReplaceable(&*U, To, DL))
     return true;
   return isPointerUseReplacable(U);
 }
 
-bool llvm::canReplacePointersIfEqual(const Value *From, const Value *To) {
+bool llvm::canReplacePointersIfEqual(const Value *From, const Value *To,
+                                     const DataLayout &DL) {
   assert(From->getType() == To->getType() && "values must have matching types");
   // Not a pointer, just return true.
   if (!From->getType()->isPointerTy())
     return true;
 
-  return isPointerAlwaysReplaceable(From, To);
+  return isPointerAlwaysReplaceable(From, To, DL);
 }
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index f736d429cb6381..b7148f64f5f032 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -1228,8 +1228,9 @@ bool EarlyCSE::handleBranchCondition(Instruction *CondInst,
       LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
     } else {
       // Replace all dominated uses with the known value.
-      if (unsigned Count = replaceDominatedUsesWith(Curr, TorF, DT,
-                                                    BasicBlockEdge(Pred, BB))) {
+      if (unsigned Count = replaceDominatedUsesWith(
+              Curr, TorF, Curr->getModule()->getDataLayout(), DT,
+              BasicBlockEdge(Pred, BB))) {
         NumCSECVP += Count;
         MadeChanges = true;
       }
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index a8480742600b99..2acdefca9fbeb1 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -2420,6 +2420,10 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     if (isa<Constant>(LHS) || (isa<Argument>(LHS) && !isa<Constant>(RHS)))
       std::swap(LHS, RHS);
     assert((isa<Argument>(LHS) || isa<Instruction>(LHS)) && "Unexpected value!");
+    const DataLayout &DL =
+        isa<Argument>(LHS)
+            ? cast<Argument>(LHS)->getParent()->getParent()->getDataLayout()
+            : cast<Instruction>(LHS)->getModule()->getDataLayout();
 
     // If there is no obvious reason to prefer the left-hand side over the
     // right-hand side, ensure the longest lived term is on the right-hand side,
@@ -2447,7 +2451,7 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     // The leader table only tracks basic blocks, not edges. Only add to if we
     // have the simple case where the edge dominates the end.
     if (RootDominatesEnd && !isa<Instruction>(RHS) &&
-        canReplacePointersIfEqual(LHS, RHS))
+        canReplacePointersIfEqual(LHS, RHS, DL))
       addToLeaderTable(LVN, RHS, Root.getEnd());
 
     // Replace all occurrences of 'LHS' with 'RHS' everywhere in the scope.  As
@@ -2456,9 +2460,9 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     if (!LHS->hasOneUse()) {
       unsigned NumReplacements =
           DominatesByEdge
-              ? replaceDominatedUsesWithIf(LHS, RHS, *DT, Root,
+              ? replaceDominatedUsesWithIf(LHS, RHS, DL, *DT, Root,
                                            canReplacePointersInUseIfEqual)
-              : replaceDominatedUsesWithIf(LHS, RHS, *DT, Root.getStart(),
+              : replaceDominatedUsesWithIf(LHS, RHS, DL, *DT, Root.getStart(),
                                            canReplacePointersInUseIfEqual);
 
       if (NumReplacements > 0) {
@@ -2523,8 +2527,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
         if (NotCmp && isa<Instruction>(NotCmp)) {
           unsigned NumReplacements =
               DominatesByEdge
-                  ? replaceDominatedUsesWith(NotCmp, NotVal, *DT, Root)
-                  : replaceDominatedUsesWith(NotCmp, NotVal, *DT,
+                  ? replaceDominatedUsesWith(NotCmp, NotVal, DL, *DT, Root)
+                  : replaceDominatedUsesWith(NotCmp, NotVal, DL, *DT,
                                              Root.getStart());
           Changed |= NumReplacements > 0;
           NumGVNEqProp += NumReplacements;
diff --git a/llvm/lib/Transforms/Scalar/LoopSink.cpp b/llvm/lib/Transforms/Scalar/LoopSink.cpp
index 6eedf95e7575ec..140e7c52062021 100644
--- a/llvm/lib/Transforms/Scalar/LoopSink.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopSink.cpp
@@ -259,7 +259,7 @@ static bool sinkInstruction(
       return UIToReplace->getParent() == N && !isa<PHINode>(UIToReplace);
     });
     // Replaces uses of I with IC in blocks dominated by N
-    replaceDominatedUsesWith(&I, IC, DT, N);
+    replaceDominatedUsesWith(&I, IC, I.getModule()->getDataLayout(), DT, N);
     LLVM_DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName()
                       << '\n');
     NumLoopSunkCloned++;
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 99e2c6a1293c23..90d6742af8bfc0 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -3429,18 +3429,16 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
   combineMetadataForCSE(ReplInst, I, false);
 }
 
-template <typename RootType, typename DominatesFn>
-static unsigned replaceDominatedUsesWith(
-    Value *From, Value *To, const RootType &Root, const DominatesFn &Dominates,
-    std::optional<function_ref<bool(const Use &U, const Value *To)>>
-        ShouldReplace) {
+template <typename RootType, typename ShouldReplaceFn>
+static unsigned replaceDominatedUsesWith(Value *From, Value *To,
+                                         const DataLayout &DL,
+                                         const RootType &Root,
+                                         const ShouldReplaceFn &ShouldReplace) {
   assert(From->getType() == To->getType());
 
   unsigned Count = 0;
   for (Use &U : llvm::make_early_inc_range(From->uses())) {
-    if (!Dominates(Root, U))
-      continue;
-    if (ShouldReplace.has_value() && !ShouldReplace.value()(U, To))
+    if (!ShouldReplace(Root, U))
       continue;
     LLVM_DEBUG(dbgs() << "Replace dominated use of '";
                From->printAsOperand(dbgs());
@@ -3467,39 +3465,47 @@ unsigned llvm::replaceNonLocalUsesWith(Instruction *From, Value *To) {
 }
 
 unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
-                                        DominatorTree &DT,
+                                        const DataLayout &DL, DominatorTree &DT,
                                         const BasicBlockEdge &Root) {
   auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
     return DT.dominates(Root, U);
   };
-  return ::replaceDominatedUsesWith(From, To, Root, Dominates, std::nullopt);
+  return ::replaceDominatedUsesWith(From, To, DL, Root, Dominates);
 }
 
 unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
-                                        DominatorTree &DT,
+                                        const DataLayout &DL, DominatorTree &DT,
                                         const BasicBlock *BB) {
   auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
     return DT.dominates(BB, U);
   };
-  return ::replaceDominatedUsesWith(From, To, BB, Dominates, std::nullopt);
+  return ::replaceDominatedUsesWith(From, To, DL, BB, Dominates);
 }
 
 unsigned llvm::replaceDominatedUsesWithIf(
-    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root,
-    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
-  auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
-    return DT.dominates(Root, U);
-  };
-  return ::replaceDominatedUsesWith(From, To, Root, Dominates, ShouldReplace);
+    Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
+    const BasicBlockEdge &Root,
+    function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
+        ShouldReplace) {
+  auto DominatesAndShouldReplace =
+      [ShouldReplace, To, &DT, &DL](const BasicBlockEdge &Root, const Use &U) {
+        return DT.dominates(Root, U) && ShouldReplace(U, To, DL);
+      };
+  return ::replaceDominatedUsesWith(From, To, DL, Root,
+                                    DominatesAndShouldReplace);
 }
 
 unsigned llvm::replaceDominatedUsesWithIf(
-    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
-    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
-  auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
-    return DT.dominates(BB, U);
+    Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
+    const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
+        ShouldReplace) {
+  auto DominatesAndShouldReplace = [ShouldReplace, To, &DT,
+                                    &DL](const BasicBlock *BB, const Use &U) {
+    return DT.dominates(BB, U) && ShouldReplace(U, To, DL);
   };
-  return ::replaceDominatedUsesWith(From, To, BB, Dominates, ShouldReplace);
+  return ::replaceDominatedUsesWith(From, To, DL, BB,
+                                    DominatesAndShouldReplace);
 }
 bool llvm::callsGCLeafFunction(const CallBase *Call,
                                const TargetLibraryInfo &TLI) {
diff --git a/llvm/unittests/Analysis/LoadsTest.cpp b/llvm/unittests/Analysis/LoadsTest.cpp
index 0694889596fa23..dc9d61a0d56acf 100644
--- a/llvm/unittests/Analysis/LoadsTest.cpp
+++ b/llvm/unittests/Analysis/LoadsTest.cpp
@@ -81,6 +81,7 @@ define void @f(i32* %p1, i32* %p2, i64 %i) {
   ret void
 }
 )IR");
+  const DataLayout &DL = M->getDataLayout();
   auto *GV = M->getNamedValue("f");
   ASSERT_TRUE(GV);
   auto *F = dyn_cast<Function>(GV);
@@ -95,21 +96,21 @@ define void @f(i32* %p1, i32* %p2, i64 %i) {
   // We cannot replace two pointers in arbitrary instructions unless we are
   // replacing with null, a constant dereferencable pointer or they have the
   // same underlying object.
-  EXPECT_FALSE(canReplacePointersIfEqual(ConstDerefPtr, P1));
-  EXPECT_FALSE(canReplacePointersIfEqual(P1, P2));
-  EXPECT_TRUE(canReplacePointersIfEqual(P1, ConstDerefPtr));
-  EXPECT_TRUE(canReplacePointersIfEqual(P1, NullPtr));
+  EXPECT_FALSE(canReplacePointersIfEqual(ConstDerefPtr, P1, DL));
+  EXPECT_FALSE(canReplacePointersIfEqual(P1, P2, DL));
+  EXPECT_TRUE(canReplacePointersIfEqual(P1, ConstDerefPtr, DL));
+  EXPECT_TRUE(canReplacePointersIfEqual(P1, NullPtr, DL));
 
   GetElementPtrInst *BasedOnP1 = cast<GetElementPtrInst>(&*++InstIter);
-  EXPECT_TRUE(canReplacePointersIfEqual(BasedOnP1, P1));
-  EXPECT_FALSE(canReplacePointersIfEqual(BasedOnP1, P2));
+  EXPECT_TRUE(canReplacePointersIfEqual(BasedOnP1, P1, DL));
+  EXPECT_FALSE(canReplacePointersIfEqual(BasedOnP1, P2, DL));
 
   // // We can replace two arbitrary pointers in icmp and ptrtoint instructions.
   auto P1UseIter = P1->use_begin();
   const Use &PtrToIntUse = *P1UseIter;
   const Use &IcmpUse = *++P1UseIter;
   const Use &GEPUse = *++P1UseIter;
-  EXPECT_FALSE(canReplacePointersInUseIfEqual(GEPUse, P2));
-  EXPECT_TRUE(canReplacePointersInUseIfEqual(PtrToIntUse, P2));
-  EXPECT_TRUE(canReplacePointersInUseIfEqual(IcmpUse, P2));
+  EXPECT_FALSE(canReplacePointersInUseIfEqual(GEPUse, P2, DL));
+  EXPECT_TRUE(canReplacePointersInUseIfEqual(PtrToIntUse, P2, DL));
+  EXPECT_TRUE(canReplacePointersInUseIfEqual(IcmpUse, P2, DL));
 }

>From cb54b767400defc806e38b4957a04221bb47c64f Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Tue, 23 Apr 2024 10:11:56 -0700
Subject: [PATCH 4/7] DL captured in callback, add comment

Change-Id: Ifba7e54313c5d02bcba356e0a2e5e93d521b9c4b
---
 llvm/include/llvm/Transforms/Utils/Local.h | 19 ++++-------
 llvm/lib/Analysis/Loads.cpp                |  2 ++
 llvm/lib/Transforms/Scalar/EarlyCSE.cpp    |  5 ++-
 llvm/lib/Transforms/Scalar/GVN.cpp         | 16 +++++----
 llvm/lib/Transforms/Scalar/LoopSink.cpp    |  2 +-
 llvm/lib/Transforms/Utils/Local.cpp        | 38 +++++++++-------------
 6 files changed, 38 insertions(+), 44 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h
index c73a2c0304ea29..e2143b5bfbe2fd 100644
--- a/llvm/include/llvm/Transforms/Utils/Local.h
+++ b/llvm/include/llvm/Transforms/Utils/Local.h
@@ -433,29 +433,24 @@ unsigned replaceNonLocalUsesWith(Instruction *From, Value *To);
 
 /// Replace each use of 'From' with 'To' if that use is dominated by
 /// the given edge.  Returns the number of replacements made.
-unsigned replaceDominatedUsesWith(Value *From, Value *To, const DataLayout &DL,
-                                  DominatorTree &DT,
+unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
                                   const BasicBlockEdge &Edge);
 /// Replace each use of 'From' with 'To' if that use is dominated by
 /// the end of the given BasicBlock. Returns the number of replacements made.
-unsigned replaceDominatedUsesWith(Value *From, Value *To, const DataLayout &DL,
-                                  DominatorTree &DT, const BasicBlock *BB);
+unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
+                                  const BasicBlock *BB);
 /// Replace each use of 'From' with 'To' if that use is dominated by
 /// the given edge and the callback ShouldReplace returns true. Returns the
 /// number of replacements made.
 unsigned replaceDominatedUsesWithIf(
-    Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
-    const BasicBlockEdge &Edge,
-    function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
-        ShouldReplace);
+    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Edge,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
 /// Replace each use of 'From' with 'To' if that use is dominated by
 /// the end of the given BasicBlock and the callback ShouldReplace returns true.
 /// Returns the number of replacements made.
 unsigned replaceDominatedUsesWithIf(
-    Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
-    const BasicBlock *BB,
-    function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
-        ShouldReplace);
+    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
 
 /// Return true if this call calls a gc leaf function.
 ///
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 0b46e42e509b35..5c28aa65a05770 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -734,6 +734,8 @@ static bool isPointerUseReplacable(const Use &U, int MaxLookup = 6) {
 // both pointers have the same underlying objects.
 static bool isPointerAlwaysReplaceable(const Value *From, const Value *To,
                                        const DataLayout &DL) {
+  // This is not strictly correct, but we do it for now to retain important
+  // optimizations.
   if (isa<ConstantPointerNull>(To))
     return true;
   if (isa<Constant>(To) &&
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index b7148f64f5f032..f736d429cb6381 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -1228,9 +1228,8 @@ bool EarlyCSE::handleBranchCondition(Instruction *CondInst,
       LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
     } else {
       // Replace all dominated uses with the known value.
-      if (unsigned Count = replaceDominatedUsesWith(
-              Curr, TorF, Curr->getModule()->getDataLayout(), DT,
-              BasicBlockEdge(Pred, BB))) {
+      if (unsigned Count = replaceDominatedUsesWith(Curr, TorF, DT,
+                                                    BasicBlockEdge(Pred, BB))) {
         NumCSECVP += Count;
         MadeChanges = true;
       }
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 2acdefca9fbeb1..d829e92b24440a 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -2458,12 +2458,16 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     // LHS always has at least one use that is not dominated by Root, this will
     // never do anything if LHS has only one use.
     if (!LHS->hasOneUse()) {
+      // Create a callback that captures the DL.
+      auto canReplacePointersCallBack = [&DL](const Use &U, const Value *To) {
+        return canReplacePointersInUseIfEqual(U, To, DL);
+      };
       unsigned NumReplacements =
           DominatesByEdge
-              ? replaceDominatedUsesWithIf(LHS, RHS, DL, *DT, Root,
-                                           canReplacePointersInUseIfEqual)
-              : replaceDominatedUsesWithIf(LHS, RHS, DL, *DT, Root.getStart(),
-                                           canReplacePointersInUseIfEqual);
+              ? replaceDominatedUsesWithIf(LHS, RHS, *DT, Root,
+                                           canReplacePointersCallBack)
+              : replaceDominatedUsesWithIf(LHS, RHS, *DT, Root.getStart(),
+                                           canReplacePointersCallBack);
 
       if (NumReplacements > 0) {
         Changed = true;
@@ -2527,8 +2531,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
         if (NotCmp && isa<Instruction>(NotCmp)) {
           unsigned NumReplacements =
               DominatesByEdge
-                  ? replaceDominatedUsesWith(NotCmp, NotVal, DL, *DT, Root)
-                  : replaceDominatedUsesWith(NotCmp, NotVal, DL, *DT,
+                  ? replaceDominatedUsesWith(NotCmp, NotVal, *DT, Root)
+                  : replaceDominatedUsesWith(NotCmp, NotVal, *DT,
                                              Root.getStart());
           Changed |= NumReplacements > 0;
           NumGVNEqProp += NumReplacements;
diff --git a/llvm/lib/Transforms/Scalar/LoopSink.cpp b/llvm/lib/Transforms/Scalar/LoopSink.cpp
index 140e7c52062021..6eedf95e7575ec 100644
--- a/llvm/lib/Transforms/Scalar/LoopSink.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopSink.cpp
@@ -259,7 +259,7 @@ static bool sinkInstruction(
       return UIToReplace->getParent() == N && !isa<PHINode>(UIToReplace);
     });
     // Replaces uses of I with IC in blocks dominated by N
-    replaceDominatedUsesWith(&I, IC, I.getModule()->getDataLayout(), DT, N);
+    replaceDominatedUsesWith(&I, IC, DT, N);
     LLVM_DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName()
                       << '\n');
     NumLoopSunkCloned++;
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 90d6742af8bfc0..f1f9652902be93 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -3431,7 +3431,6 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
 
 template <typename RootType, typename ShouldReplaceFn>
 static unsigned replaceDominatedUsesWith(Value *From, Value *To,
-                                         const DataLayout &DL,
                                          const RootType &Root,
                                          const ShouldReplaceFn &ShouldReplace) {
   assert(From->getType() == To->getType());
@@ -3465,48 +3464,43 @@ unsigned llvm::replaceNonLocalUsesWith(Instruction *From, Value *To) {
 }
 
 unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
-                                        const DataLayout &DL, DominatorTree &DT,
+                                        DominatorTree &DT,
                                         const BasicBlockEdge &Root) {
   auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
     return DT.dominates(Root, U);
   };
-  return ::replaceDominatedUsesWith(From, To, DL, Root, Dominates);
+  return ::replaceDominatedUsesWith(From, To, Root, Dominates);
 }
 
 unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
-                                        const DataLayout &DL, DominatorTree &DT,
+                                        DominatorTree &DT,
                                         const BasicBlock *BB) {
   auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
     return DT.dominates(BB, U);
   };
-  return ::replaceDominatedUsesWith(From, To, DL, BB, Dominates);
+  return ::replaceDominatedUsesWith(From, To, BB, Dominates);
 }
 
 unsigned llvm::replaceDominatedUsesWithIf(
-    Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
-    const BasicBlockEdge &Root,
-    function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
-        ShouldReplace) {
+    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
   auto DominatesAndShouldReplace =
-      [ShouldReplace, To, &DT, &DL](const BasicBlockEdge &Root, const Use &U) {
-        return DT.dominates(Root, U) && ShouldReplace(U, To, DL);
+      [&DT, &ShouldReplace, To](const BasicBlockEdge &Root, const Use &U) {
+        return DT.dominates(Root, U) && ShouldReplace(U, To);
       };
-  return ::replaceDominatedUsesWith(From, To, DL, Root,
-                                    DominatesAndShouldReplace);
+  return ::replaceDominatedUsesWith(From, To, Root, DominatesAndShouldReplace);
 }
 
 unsigned llvm::replaceDominatedUsesWithIf(
-    Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
-    const BasicBlock *BB,
-    function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
-        ShouldReplace) {
-  auto DominatesAndShouldReplace = [ShouldReplace, To, &DT,
-                                    &DL](const BasicBlock *BB, const Use &U) {
-    return DT.dominates(BB, U) && ShouldReplace(U, To, DL);
+    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
+  auto DominatesAndShouldReplace = [&DT, &ShouldReplace,
+                                    To](const BasicBlock *BB, const Use &U) {
+    return DT.dominates(BB, U) && ShouldReplace(U, To);
   };
-  return ::replaceDominatedUsesWith(From, To, DL, BB,
-                                    DominatesAndShouldReplace);
+  return ::replaceDominatedUsesWith(From, To, BB, DominatesAndShouldReplace);
 }
+
 bool llvm::callsGCLeafFunction(const CallBase *Call,
                                const TargetLibraryInfo &TLI) {
   // Check if the function is specifically marked as a gc leaf function.

>From 02332a9084a83cbdcc5545c7f8c7c61ba0a7e030 Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Tue, 23 Apr 2024 13:08:37 -0700
Subject: [PATCH 5/7] Make isPointerUseReplacable iterative

Change-Id: Ic203125c66a72af9439547fc592c80b459290e86
---
 llvm/lib/Analysis/Loads.cpp | 33 ++++++++++++++++++---------------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 5c28aa65a05770..295de1343ab293 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -712,22 +712,25 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
 
 // Returns true if a use is either in an ICmp/PtrToInt or a Phi/Select that only
 // feeds into them.
-static bool isPointerUseReplacable(const Use &U, int MaxLookup = 6) {
-  if (MaxLookup == 0)
-    return false;
-
-  const User *User = U.getUser();
-  if (isa<ICmpInst>(User))
-    return true;
-  if (isa<PtrToIntInst>(User))
-    return true;
-  if (isa<PHINode, SelectInst>(User) &&
-      all_of(User->uses(), [&](const Use &Use) {
-        return isPointerUseReplacable(Use, MaxLookup - 1);
-      }))
-    return true;
+static bool isPointerUseReplacable(const Use &U) {
+  unsigned Limit = 40;
+  SmallVector<const User *> Worklist({U.getUser()});
+  SmallPtrSet<const User *, 8> Visited;
+
+  while (!Worklist.empty() && --Limit) {
+    auto *User = Worklist.pop_back_val();
+    Visited.insert(User);
+    if (isa<ICmpInst, PtrToIntInst>(User))
+      continue;
+    if (isa<PHINode, SelectInst>(User)) {
+      for (const auto &Use : User->uses())
+        if (!Visited.contains(Use.getUser()))
+          Worklist.push_back(Use.getUser());
+    } else
+      return false;
+  }
 
-  return false;
+  return Limit;
 }
 
 // Returns true if `To` is a null pointer, constant dereferenceable pointer or

>From 40d81813ee813c00301cf4237009170eaaff4d4a Mon Sep 17 00:00:00 2001
From: Usman Nadeem <usmanone at gmail.com>
Date: Tue, 23 Apr 2024 13:17:27 -0700
Subject: [PATCH 6/7] Update llvm/lib/Analysis/Loads.cpp

Co-authored-by: Eli Friedman <efriedma at quicinc.com>
---
 llvm/lib/Analysis/Loads.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 295de1343ab293..b7b7b16874af60 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -730,7 +730,7 @@ static bool isPointerUseReplacable(const Use &U) {
       return false;
   }
 
-  return Limit;
+  return Limit != 0;
 }
 
 // Returns true if `To` is a null pointer, constant dereferenceable pointer or

>From 6f2533404eddab6d27d0cf937b56218eba3a6b7c Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Tue, 23 Apr 2024 19:11:52 -0700
Subject: [PATCH 7/7] address comments

Change-Id: Iba33ade6b0dca3a8c1341f8bf99f1013d17ca3f3
---
 llvm/lib/Analysis/Loads.cpp           | 11 +++++------
 llvm/unittests/Analysis/LoadsTest.cpp |  2 +-
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index b7b7b16874af60..478302d687b534 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -719,14 +719,13 @@ static bool isPointerUseReplacable(const Use &U) {
 
   while (!Worklist.empty() && --Limit) {
     auto *User = Worklist.pop_back_val();
-    Visited.insert(User);
+    if (!Visited.insert(User).second)
+      continue;
     if (isa<ICmpInst, PtrToIntInst>(User))
       continue;
-    if (isa<PHINode, SelectInst>(User)) {
-      for (const auto &Use : User->uses())
-        if (!Visited.contains(Use.getUser()))
-          Worklist.push_back(Use.getUser());
-    } else
+    if (isa<PHINode, SelectInst>(User))
+      Worklist.append(User->user_begin(), User->user_end());
+    else
       return false;
   }
 
diff --git a/llvm/unittests/Analysis/LoadsTest.cpp b/llvm/unittests/Analysis/LoadsTest.cpp
index dc9d61a0d56acf..5da3feaf762f37 100644
--- a/llvm/unittests/Analysis/LoadsTest.cpp
+++ b/llvm/unittests/Analysis/LoadsTest.cpp
@@ -105,7 +105,7 @@ define void @f(i32* %p1, i32* %p2, i64 %i) {
   EXPECT_TRUE(canReplacePointersIfEqual(BasedOnP1, P1, DL));
   EXPECT_FALSE(canReplacePointersIfEqual(BasedOnP1, P2, DL));
 
-  // // We can replace two arbitrary pointers in icmp and ptrtoint instructions.
+  // We can replace two arbitrary pointers in icmp and ptrtoint instructions.
   auto P1UseIter = P1->use_begin();
   const Use &PtrToIntUse = *P1UseIter;
   const Use &IcmpUse = *++P1UseIter;



More information about the llvm-commits mailing list