[llvm] [InstCombine] fold (A + B - C == B) -> (A == C). (PR #76129)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 20 23:57:46 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Chia (sun-jacobi)

<details>
<summary>Changes</summary>

This patch closes #<!-- -->72512. 

We implemented a work-list algorithm that checks whether the LHS and RHS share a common value and, if they do, eliminates it.

This implementation also covers the case `(A + B == B) -> (A == 0)`, which is a pre-existing optimization.


---
Full diff: https://github.com/llvm/llvm-project/pull/76129.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+42) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+2) 
- (added) llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll (+121) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0ad87eeb4c91a4..399af8d1133f08 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4547,6 +4547,43 @@ static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
   return nullptr;
 }
 
+// Cancel RHS from LBO's add/sub tree: ((A + B) - C == B) -> (A - C == 0)
+Instruction *InstCombinerImpl::foldICmpWithCommonFactors(ICmpInst &Cmp,
+                                                         BinaryOperator *LBO,
+                                                         Value *RHS) {
+  const CmpInst::Predicate Pred = Cmp.getPredicate();
+  if (!ICmpInst::isEquality(Pred))
+    return nullptr;
+
+  SmallVector<BinaryOperator *, 16> worklist(1, LBO);
+  Constant *Zero = Constant::getNullValue(LBO->getType());
+
+  while (!worklist.empty()) {
+    BinaryOperator *BO = worklist.pop_back_val();
+
+    if (Value * A; match(BO, m_OneUse(m_c_Add(m_Value(A), m_Specific(RHS))))) {
+      if (BO != LBO) {
+        replaceInstUsesWith(*BO, A);
+        eraseInstFromFunction(*BO);
+      }
+      return new ICmpInst(Pred, LBO, Zero);
+    }
+
+    unsigned Opc = BO->getOpcode();
+    if (Opc == Instruction::Add || Opc == Instruction::Sub) {
+      auto AddNextBO = [&](Value *Op) {
+        if (BinaryOperator *Next = dyn_cast<BinaryOperator>(Op))
+          worklist.push_back(Next);
+      };
+
+      AddNextBO(BO->getOperand(0));
+      AddNextBO(BO->getOperand(1));
+    }
+  }
+
+  return nullptr;
+}
+
 /// Try to fold icmp (binop), X or icmp X, (binop).
 /// TODO: A large part of this logic is duplicated in InstSimplify's
 /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
@@ -4565,6 +4602,11 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
   if (Instruction *NewICmp = foldICmpXNegX(I, Builder))
     return NewICmp;
 
+  if (BO0) {
+    if (Instruction *NewICmp = foldICmpWithCommonFactors(I, BO0, Op1))
+      return NewICmp;
+  }
+
   const CmpInst::Predicate Pred = I.getPredicate();
   Value *X;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index f86db698ef8f12..9081373c0157d6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -632,6 +632,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
                                                   const APInt &C);
+  Instruction *foldICmpWithCommonFactors(ICmpInst &Cmp, BinaryOperator *LBO,
+                                         Value *RHS);
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
   Instruction *foldICmpWithMinMaxImpl(Instruction &I, MinMaxIntrinsic *MinMax,
                                       Value *Z, ICmpInst::Predicate Pred);
diff --git a/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll b/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
new file mode 100644
index 00000000000000..e28c0910c9cae2
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
@@ -0,0 +1,121 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+; A + B - C == B
+define i1 @icmp_common_add_sub(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = add i32 %a, %b
+  %sub = sub i32 %add, %c
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A + B - C != B
+define i1 @icmp_common_add_sub_ne(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_sub_ne(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = add i32 %a, %b
+  %sub = sub i32 %add, %c
+  %cmp = icmp ne i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A * B - C == B
+define i1 @icmp_common_mul_sub(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_mul_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[A]], [[B]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[MUL]], [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[SUB]], [[B]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %mul = mul i32 %a, %b
+  %sub = sub i32 %mul, %c
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A + B + C == B
+define i1 @icmp_common_add_add(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = sub i32 0, [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[A]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %cmp = icmp eq i32 %add2, %b
+  ret i1 %cmp
+}
+
+; A + B + C + D == B
+define i1 @icmp_common_add_add_add(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[A]], [[C]]
+; CHECK-NEXT:    [[ADD3:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[ADD3]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %add3 = add i32 %add2, %d
+  %cmp = icmp eq i32 %add3, %b
+  ret i1 %cmp
+}
+
+; A + B + C + D == C
+define i1 @icmp_common_add_add_add_2(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_add_2(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[A]], [[B]]
+; CHECK-NEXT:    [[ADD3:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD1]], [[ADD3]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %add3 = add i32 %add2, %d
+  %cmp = icmp eq i32 %add3, %c
+  ret i1 %cmp
+}
+
+; A + B + C - D == B
+define i1 @icmp_common_add_add_sub(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[A]], [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[D]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %sub = sub i32 %add2, %d
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+
+; A + B - C + D == B
+define i1 @icmp_common_add_sub_add(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_sub_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[A]], [[C]]
+; CHECK-NEXT:    [[ADD2:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[SUB]], [[ADD2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %sub = sub i32 %add1, %c
+  %add2 = add i32 %sub, %d
+  %cmp = icmp eq i32 %add2, %b
+  ret i1 %cmp
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/76129


More information about the llvm-commits mailing list