[llvm] [InstCombine] Offset both sides of an equality icmp (PR #134086)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 2 06:49:55 PDT 2025


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/134086

Closes https://github.com/llvm/llvm-project/issues/134024

>From 4d9339636f878a7e8b150d3148cb5186ab06d407 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 2 Apr 2025 21:15:52 +0800
Subject: [PATCH 1/2] [InstCombine] Offset both sides of an equality icmp

---
 .../InstCombine/InstCombineCompares.cpp       | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e75b4026d5424..a30aa7191afbd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5808,6 +5808,63 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
   return nullptr;
 }
 
+/// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
+using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
+static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets) {
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return;
+  Constant *C;
+
+  switch (Inst->getOpcode()) {
+  case Instruction::Add:
+    // add X, C: adding -C to both icmp sides cancels C on this side.
+    if (match(Inst->getOperand(1), m_ImmConstant(C)))
+      Offsets.emplace_back(Instruction::Add, ConstantExpr::getNeg(C));
+    break;
+  case Instruction::Xor:
+    Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
+    Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
+    break;
+  default:
+    break;
+  }
+}
+
+/// Offset both sides of an equality icmp to see if we can save some
+/// instructions. icmp eq/ne X, Y -> icmp eq/ne X op C, Y op C Note: This
+/// operation should not introduce poison.
+static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
+                                               InstCombiner::BuilderTy &Builder,
+                                               const SimplifyQuery &SQ) {
+  assert(I.isEquality() && "Expected an equality icmp");
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  if (!Op0->getType()->isIntOrIntVectorTy()) // Skip pointer compares.
+    return nullptr;
+
+  SmallVector<OffsetOp, 4> OffsetOps;
+  if (Op0->hasOneUse())
+    collectOffsetOp(Op0, OffsetOps);
+  if (Op1->hasOneUse())
+    collectOffsetOp(Op1, OffsetOps);
+  for (auto [BinOp, RHS] : OffsetOps) {
+    auto BinOpc = static_cast<unsigned>(BinOp);
+
+    Value *Simplified0 = simplifyBinOp(BinOpc, Op0, RHS, SQ);
+    if (!Simplified0 || Simplified0 == Op0)
+      continue;
+
+    Value *Simplified1 = simplifyBinOp(BinOpc, Op1, RHS, SQ);
+    if (!Simplified1 || Simplified1 == Op1)
+      continue;
+
+    return new ICmpInst(static_cast<ICmpInst::Predicate>(I.getPredicate()),
+                        Simplified0, Simplified1);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
   if (!I.isEquality())
     return nullptr;
@@ -6054,6 +6111,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
                                   : ConstantInt::getNullValue(A->getType()));
   }
 
+  if (auto *Res = foldICmpEqualityWithOffset(
+          I, Builder, getSimplifyQuery().getWithInstruction(&I)))
+    return Res;
+
   return nullptr;
 }
 

>From e5471c3dde74e06f45c7ba0505773a63ae7a1edb Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 2 Apr 2025 21:48:20 +0800
Subject: [PATCH 2/2] [InstCombine] Look through selects when offsetting icmp operands

---
 .../InstCombine/InstCombineCompares.cpp       | 84 +++++++++++++++++--
 1 file changed, 75 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index a30aa7191afbd..0369327798072 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5810,7 +5810,8 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
 
 /// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
 using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
-static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets) {
+static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
+                            bool AllowRecursion) {
   Instruction *Inst = dyn_cast<Instruction>(V);
   if (!Inst)
     return;
@@ -5826,11 +5827,51 @@ static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets) {
     Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
     Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
     break;
+  case Instruction::Select:
+    if (AllowRecursion) {
+      Value *TrueV = Inst->getOperand(1);
+      if (TrueV->hasOneUse())
+        collectOffsetOp(TrueV, Offsets, /*AllowRecursion=*/false);
+      Value *FalseV = Inst->getOperand(2);
+      if (FalseV->hasOneUse())
+        collectOffsetOp(FalseV, Offsets, /*AllowRecursion=*/false);
+    }
+    break;
   default:
     break;
   }
 }
 
+/// Offset outcome for one icmp operand; a select is built in materialize().
+enum class OffsetKind { Invalid, Value, Select };
+
+struct OffsetResult {
+  OffsetKind Kind;
+  Value *V0, *V1, *V2;
+
+  static OffsetResult invalid() {
+    return {OffsetKind::Invalid, nullptr, nullptr, nullptr};
+  }
+  static OffsetResult value(Value *V) {
+    return {OffsetKind::Value, V, nullptr, nullptr};
+  }
+  static OffsetResult select(Value *Cond, Value *TrueV, Value *FalseV) {
+    return {OffsetKind::Select, Cond, TrueV, FalseV};
+  }
+  bool isValid() const { return Kind != OffsetKind::Invalid; }
+  Value *materialize(InstCombiner::BuilderTy &Builder) const {
+    switch (Kind) {
+    case OffsetKind::Invalid:
+      llvm_unreachable("Invalid offset result");
+    case OffsetKind::Value:
+      return V0;
+    case OffsetKind::Select:
+      return Builder.CreateSelect(V0, V1, V2);
+    }
+    llvm_unreachable("Unknown offset result kind");
+  }
+};
+
 /// Offset both sides of an equality icmp to see if we can save some
 /// instructions. icmp eq/ne X, Y -> icmp eq/ne X op C, Y op C Note: This
 /// operation should not introduce poison.
@@ -5844,22 +5885,61 @@ static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
 
   SmallVector<OffsetOp, 4> OffsetOps;
   if (Op0->hasOneUse())
-    collectOffsetOp(Op0, OffsetOps);
+    collectOffsetOp(Op0, OffsetOps, /*AllowRecursion=*/true);
   if (Op1->hasOneUse())
-    collectOffsetOp(Op1, OffsetOps);
+    collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);
+
+  auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
+    Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
+    // Avoid infinite loops by checking if RHS is an identity for the BinOp.
+    if (!Simplified || Simplified == V)
+      return nullptr;
+    // A constant expression result does not save an instruction; reject it.
+    if (isa<Constant>(Simplified) && !match(Simplified, m_ImmConstant()))
+      return nullptr;
+    // RHS need not be an operand of V (e.g. it was collected from one arm of
+    // a select), so V op RHS may be poison where V is not. Only accept the
+    // rewrite if poison in RHS already implies poison in V.
+    if (!impliesPoison(RHS, V))
+      return nullptr;
+    return Simplified;
+  };
+
+  auto ApplyOffset = [&](Value *V, unsigned BinOpc,
+                         Value *RHS) -> OffsetResult {
+    if (auto *Sel = dyn_cast<SelectInst>(V)) {
+      // Rebuilding a multi-use select would add an instruction, not save one.
+      if (!Sel->hasOneUse())
+        return OffsetResult::invalid();
+      Value *TrueVal = ApplyOffsetImpl(Sel->getTrueValue(), BinOpc, RHS);
+      if (!TrueVal)
+        return OffsetResult::invalid();
+      Value *FalseVal = ApplyOffsetImpl(Sel->getFalseValue(), BinOpc, RHS);
+      if (!FalseVal)
+        return OffsetResult::invalid();
+      return OffsetResult::select(Sel->getCondition(), TrueVal, FalseVal);
+    }
+    if (Value *Simplified = ApplyOffsetImpl(V, BinOpc, RHS))
+      return OffsetResult::value(Simplified);
+    return OffsetResult::invalid();
+  };
+
   for (auto [BinOp, RHS] : OffsetOps) {
     auto BinOpc = static_cast<unsigned>(BinOp);
 
-    Value *Simplified0 = simplifyBinOp(BinOpc, Op0, RHS, SQ);
-    if (!Simplified0 || Simplified0 == Op0)
+    auto Op0Result = ApplyOffset(Op0, BinOpc, RHS);
+    if (!Op0Result.isValid())
       continue;
-
-    Value *Simplified1 = simplifyBinOp(BinOpc, Op1, RHS, SQ);
-    if (!Simplified1 || Simplified1 == Op1)
+    auto Op1Result = ApplyOffset(Op1, BinOpc, RHS);
+    if (!Op1Result.isValid())
       continue;
 
-    return new ICmpInst(static_cast<ICmpInst::Predicate>(I.getPredicate()),
-                        Simplified0, Simplified1);
+    // Materialize in a fixed order: C++ leaves argument evaluation order
+    // unspecified, so creating the selects inside the constructor call would
+    // make the emitted IR depend on the host compiler.
+    Value *NewLHS = Op0Result.materialize(Builder);
+    Value *NewRHS = Op1Result.materialize(Builder);
+    return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
   }
 
   return nullptr;



More information about the llvm-commits mailing list