[llvm] [InstCombine] Fix pointer replacement in `foldSelectValueEquivalence` (PR #161701)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 3 07:39:15 PDT 2025


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/161701

>From 33495201d22773f8f7b26181222686a025c5a74f Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Thu, 2 Oct 2025 20:40:54 +0800
Subject: [PATCH 1/4] [InstCombine] Add pre-commit tests. NFC.

---
 .../test/Transforms/InstCombine/select-gep.ll | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/select-gep.ll b/llvm/test/Transforms/InstCombine/select-gep.ll
index dd8dffba11b05..7aba1421b9ad2 100644
--- a/llvm/test/Transforms/InstCombine/select-gep.ll
+++ b/llvm/test/Transforms/InstCombine/select-gep.ll
@@ -286,3 +286,29 @@ define <2 x ptr> @test7(<2 x ptr> %p1, i64 %idx, <2 x i1> %cc) {
   %select = select <2 x i1> %cc, <2 x ptr> %p1, <2 x ptr> %gep
   ret <2 x ptr> %select
 }
+
+define ptr @ptr_eq_replace_freeze1(ptr %p, ptr %q) {
+; CHECK-LABEL: @ptr_eq_replace_freeze1(
+; CHECK-NEXT:    [[Q_FR:%.*]] = freeze ptr [[Q:%.*]]
+; CHECK-NEXT:    ret ptr [[Q_FR]]
+;
+  %p.fr = freeze ptr %p
+  %q.fr = freeze ptr %q
+  %cmp = icmp eq ptr %p.fr, %q.fr
+  %select = select i1 %cmp, ptr %p.fr, ptr %q.fr
+  ret ptr %select
+}
+
+define ptr @ptr_eq_replace_freeze2(ptr %p, ptr %q) {
+; CHECK-LABEL: @ptr_eq_replace_freeze2(
+; CHECK-NEXT:    [[P_FR:%.*]] = freeze ptr [[P:%.*]]
+; CHECK-NEXT:    [[SELECT:%.*]] = getelementptr i8, ptr [[P_FR]], i64 16
+; CHECK-NEXT:    ret ptr [[SELECT]]
+;
+  %gep1 = getelementptr i32, ptr %p, i64 4
+  %gep2 = getelementptr i32, ptr %q, i64 4
+  %cmp = icmp eq ptr %p, %q
+  %cmp.fr = freeze i1 %cmp
+  %select = select i1 %cmp.fr, ptr %gep1, ptr %gep2
+  ret ptr %select
+}

>From 18f21c6b5451e6bd54561586db9740e398615002 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Thu, 2 Oct 2025 20:50:24 +0800
Subject: [PATCH 2/4] [InstCombine] Fix pointer replacement in
 `foldSelectValueEquivalence`

---
 .../InstCombine/InstCombineSelect.cpp           | 17 +++++++++++------
 llvm/test/Transforms/InstCombine/select-gep.ll  | 10 ++++++++--
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 87000a1c36eef..5ea79b92633d9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -17,6 +17,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/OverflowInstAnalysis.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
@@ -1411,6 +1412,8 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     // in the cmp and in f(Y).
     if (TrueVal == OldOp && (isa<Constant>(OldOp) || !isa<Constant>(NewOp)))
       return nullptr;
+    if (!canReplacePointersIfEqual(OldOp, NewOp, DL))
+      return nullptr;
 
     if (Value *V = simplifyWithOpReplaced(TrueVal, OldOp, NewOp, SQ,
                                           /* AllowRefinement=*/true)) {
@@ -1466,12 +1469,14 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
   SmallVector<Instruction *> DropFlags;
-  if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
-                             /* AllowRefinement */ false,
-                             &DropFlags) == TrueVal ||
-      simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
-                             /* AllowRefinement */ false,
-                             &DropFlags) == TrueVal) {
+  if ((canReplacePointersIfEqual(CmpLHS, CmpRHS, DL) &&
+       simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
+                              /* AllowRefinement */ false,
+                              &DropFlags) == TrueVal) ||
+      (canReplacePointersIfEqual(CmpRHS, CmpLHS, DL) &&
+       simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
+                              /* AllowRefinement */ false,
+                              &DropFlags) == TrueVal)) {
     for (Instruction *I : DropFlags) {
       I->dropPoisonGeneratingAnnotations();
       Worklist.add(I);
diff --git a/llvm/test/Transforms/InstCombine/select-gep.ll b/llvm/test/Transforms/InstCombine/select-gep.ll
index 7aba1421b9ad2..718133699a8a7 100644
--- a/llvm/test/Transforms/InstCombine/select-gep.ll
+++ b/llvm/test/Transforms/InstCombine/select-gep.ll
@@ -290,7 +290,10 @@ define <2 x ptr> @test7(<2 x ptr> %p1, i64 %idx, <2 x i1> %cc) {
 define ptr @ptr_eq_replace_freeze1(ptr %p, ptr %q) {
 ; CHECK-LABEL: @ptr_eq_replace_freeze1(
 ; CHECK-NEXT:    [[Q_FR:%.*]] = freeze ptr [[Q:%.*]]
-; CHECK-NEXT:    ret ptr [[Q_FR]]
+; CHECK-NEXT:    [[Q_FR1:%.*]] = freeze ptr [[Q1:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq ptr [[Q_FR]], [[Q_FR1]]
+; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[CMP]], ptr [[Q_FR]], ptr [[Q_FR1]]
+; CHECK-NEXT:    ret ptr [[SELECT]]
 ;
   %p.fr = freeze ptr %p
   %q.fr = freeze ptr %q
@@ -302,7 +305,10 @@ define ptr @ptr_eq_replace_freeze1(ptr %p, ptr %q) {
 define ptr @ptr_eq_replace_freeze2(ptr %p, ptr %q) {
 ; CHECK-LABEL: @ptr_eq_replace_freeze2(
 ; CHECK-NEXT:    [[P_FR:%.*]] = freeze ptr [[P:%.*]]
-; CHECK-NEXT:    [[SELECT:%.*]] = getelementptr i8, ptr [[P_FR]], i64 16
+; CHECK-NEXT:    [[P_FR1:%.*]] = freeze ptr [[P1:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq ptr [[P_FR1]], [[P_FR]]
+; CHECK-NEXT:    [[SELECT_V:%.*]] = select i1 [[CMP]], ptr [[P_FR1]], ptr [[P_FR]]
+; CHECK-NEXT:    [[SELECT:%.*]] = getelementptr i8, ptr [[SELECT_V]], i64 16
 ; CHECK-NEXT:    ret ptr [[SELECT]]
 ;
   %gep1 = getelementptr i32, ptr %p, i64 4

>From 44f9269415ed9d7e7708e1c122e04c71c24ee8d5 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 3 Oct 2025 02:39:47 +0800
Subject: [PATCH 3/4] [InstCombine] Cache query results

---
 .../InstCombine/InstCombineSelect.cpp         | 52 ++++++++++++++++---
 1 file changed, 45 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 5ea79b92633d9..aca52c702ce4e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -43,6 +43,7 @@
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include <cassert>
+#include <optional>
 #include <utility>
 
 #define DEBUG_TYPE "instcombine"
@@ -1401,9 +1402,43 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     return nullptr;
   }
 
+  std::optional<bool> CanReplacePointersIfEqualCache[2] = {std::nullopt,
+                                                           std::nullopt};
+  std::optional<bool> ShareSameUnderlyingObject = std::nullopt;
+  auto CanReplacePointersIfEqual = [&](Value *From, Value *To,
+                                       std::optional<bool> &Cache) -> bool {
+    if (Cache.has_value())
+      return *Cache;
+
+    assert(From->getType() == To->getType() &&
+           "values must have matching types");
+    // Not a pointer, just return true.
+    if (!From->getType()->isPointerTy()) {
+      Cache = true;
+      return true;
+    }
+
+    if (isa<ConstantPointerNull>(To)) {
+      Cache = true;
+      return true;
+    }
+    if (isa<Constant>(To) &&
+        isDereferenceablePointer(To, Type::getInt8Ty(To->getContext()), DL)) {
+      Cache = true;
+      return true;
+    }
+
+    if (!ShareSameUnderlyingObject.has_value())
+      ShareSameUnderlyingObject = getUnderlyingObjectAggressive(From) ==
+                                  getUnderlyingObjectAggressive(To);
+
+    Cache = *ShareSameUnderlyingObject;
+    return *ShareSameUnderlyingObject;
+  };
+
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
-  auto ReplaceOldOpWithNewOp = [&](Value *OldOp,
-                                   Value *NewOp) -> Instruction * {
+  auto ReplaceOldOpWithNewOp = [&](Value *OldOp, Value *NewOp,
+                                   uint32_t Direction) -> Instruction * {
     // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
     // Take care to avoid replacing X == Y ? X : Z with X == Y ? Y : Z, as that
     // would lead to an infinite replacement cycle.
@@ -1412,7 +1447,8 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     // in the cmp and in f(Y).
     if (TrueVal == OldOp && (isa<Constant>(OldOp) || !isa<Constant>(NewOp)))
       return nullptr;
-    if (!canReplacePointersIfEqual(OldOp, NewOp, DL))
+    if (!CanReplacePointersIfEqual(OldOp, NewOp,
+                                   CanReplacePointersIfEqualCache[Direction]))
       return nullptr;
 
     if (Value *V = simplifyWithOpReplaced(TrueVal, OldOp, NewOp, SQ,
@@ -1451,9 +1487,9 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     return nullptr;
   };
 
-  if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS))
+  if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS, 0))
     return R;
-  if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS))
+  if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS, 1))
     return R;
 
   auto *FalseInst = dyn_cast<Instruction>(FalseVal);
@@ -1469,11 +1505,13 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
   SmallVector<Instruction *> DropFlags;
-  if ((canReplacePointersIfEqual(CmpLHS, CmpRHS, DL) &&
+  if ((CanReplacePointersIfEqual(CmpLHS, CmpRHS,
+                                 CanReplacePointersIfEqualCache[0]) &&
        simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
                               /* AllowRefinement */ false,
                               &DropFlags) == TrueVal) ||
-      (canReplacePointersIfEqual(CmpRHS, CmpLHS, DL) &&
+      (CanReplacePointersIfEqual(CmpRHS, CmpLHS,
+                                 CanReplacePointersIfEqualCache[1]) &&
        simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
                               /* AllowRefinement */ false,
                               &DropFlags) == TrueVal)) {

>From 9a0f0302efc30580136d191e66bac929f08ee25f Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 3 Oct 2025 22:38:52 +0800
Subject: [PATCH 4/4] [InstCombine] Address review comments. NFC.

---
 .../InstCombine/InstCombineSelect.cpp         | 57 ++++---------------
 1 file changed, 12 insertions(+), 45 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index aca52c702ce4e..b0c06363ddf45 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1402,40 +1402,6 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     return nullptr;
   }
 
-  std::optional<bool> CanReplacePointersIfEqualCache[2] = {std::nullopt,
-                                                           std::nullopt};
-  std::optional<bool> ShareSameUnderlyingObject = std::nullopt;
-  auto CanReplacePointersIfEqual = [&](Value *From, Value *To,
-                                       std::optional<bool> &Cache) -> bool {
-    if (Cache.has_value())
-      return *Cache;
-
-    assert(From->getType() == To->getType() &&
-           "values must have matching types");
-    // Not a pointer, just return true.
-    if (!From->getType()->isPointerTy()) {
-      Cache = true;
-      return true;
-    }
-
-    if (isa<ConstantPointerNull>(To)) {
-      Cache = true;
-      return true;
-    }
-    if (isa<Constant>(To) &&
-        isDereferenceablePointer(To, Type::getInt8Ty(To->getContext()), DL)) {
-      Cache = true;
-      return true;
-    }
-
-    if (!ShareSameUnderlyingObject.has_value())
-      ShareSameUnderlyingObject = getUnderlyingObjectAggressive(From) ==
-                                  getUnderlyingObjectAggressive(To);
-
-    Cache = *ShareSameUnderlyingObject;
-    return *ShareSameUnderlyingObject;
-  };
-
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
   auto ReplaceOldOpWithNewOp = [&](Value *OldOp, Value *NewOp,
                                    uint32_t Direction) -> Instruction * {
@@ -1447,9 +1413,6 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     // in the cmp and in f(Y).
     if (TrueVal == OldOp && (isa<Constant>(OldOp) || !isa<Constant>(NewOp)))
       return nullptr;
-    if (!CanReplacePointersIfEqual(OldOp, NewOp,
-                                   CanReplacePointersIfEqualCache[Direction]))
-      return nullptr;
 
     if (Value *V = simplifyWithOpReplaced(TrueVal, OldOp, NewOp, SQ,
                                           /* AllowRefinement=*/true)) {
@@ -1487,10 +1450,16 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     return nullptr;
   };
 
-  if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS, 0))
-    return R;
-  if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS, 1))
-    return R;
+  bool CanReplaceCmpLHSWithRHS = canReplacePointersIfEqual(CmpLHS, CmpRHS, DL);
+  if (CanReplaceCmpLHSWithRHS) {
+    if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS, 0))
+      return R;
+  }
+  bool CanReplaceCmpRHSWithLHS = canReplacePointersIfEqual(CmpRHS, CmpLHS, DL);
+  if (CanReplaceCmpRHSWithLHS) {
+    if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS, 1))
+      return R;
+  }
 
   auto *FalseInst = dyn_cast<Instruction>(FalseVal);
   if (!FalseInst)
@@ -1505,13 +1474,11 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
   SmallVector<Instruction *> DropFlags;
-  if ((CanReplacePointersIfEqual(CmpLHS, CmpRHS,
-                                 CanReplacePointersIfEqualCache[0]) &&
+  if ((CanReplaceCmpLHSWithRHS &&
        simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
                               /* AllowRefinement */ false,
                               &DropFlags) == TrueVal) ||
-      (CanReplacePointersIfEqual(CmpRHS, CmpLHS,
-                                 CanReplacePointersIfEqualCache[1]) &&
+      (CanReplaceCmpRHSWithLHS &&
        simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
                               /* AllowRefinement */ false,
                               &DropFlags) == TrueVal)) {



More information about the llvm-commits mailing list