[llvm] [InstCombine] fold (A + B - C == B) -> (A == C). (PR #76129)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 21 04:18:28 PST 2023


https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/76129

>From 98c373005e84368758128d051d89f260ecfab65c Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 21 Dec 2023 14:12:04 +0900
Subject: [PATCH 1/5] [InstCombine] fold (A + B - C == B) -> (A - C == 0)

---
 .../InstCombine/InstCombineCompares.cpp       | 44 +++++++++++++++++++
 .../InstCombine/InstCombineInternal.h         |  2 +
 2 files changed, 46 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0ad87eeb4c91a4..ca30e5b6ad779b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4547,6 +4547,45 @@ static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
   return nullptr;
 }
 
+// extract common factors like ((A + B) - C == B) -> (A - C == 0)
+Value *InstCombinerImpl::foldICmpWithCommonFactors(ICmpInst &Cmp,
+                                                   BinaryOperator *LBO,
+                                                   Value *RHS) {
+  const CmpInst::Predicate Pred = Cmp.getPredicate();
+  if (!ICmpInst::isEquality(Pred))
+    return nullptr;
+
+  SmallVector<BinaryOperator *, 16> WorkList(1, LBO);
+
+  while (!WorkList.empty()) {
+    BinaryOperator *BO = WorkList.pop_back_val();
+
+    Value *A;
+    if (match(BO, m_OneUse(m_c_Add(m_Value(A), m_Specific(RHS))))) {
+      if (BO == LBO)
+        return Builder.CreateICmp(Pred, A,
+                                  Constant::getNullValue(LBO->getType()));
+      replaceInstUsesWith(*BO, A);
+      eraseInstFromFunction(*BO);
+      return Builder.CreateICmp(Pred, LBO,
+                                Constant::getNullValue(LBO->getType()));
+    }
+
+    unsigned Opc = BO->getOpcode();
+    if (Opc == Instruction::Add || Opc == Instruction::Sub) {
+      auto AddNextBO = [&](Value *Op) {
+        if (BinaryOperator *Next = dyn_cast<BinaryOperator>(Op))
+          WorkList.push_back(Next);
+      };
+
+      AddNextBO(BO->getOperand(0));
+      AddNextBO(BO->getOperand(1));
+    }
+  }
+
+  return nullptr;
+}
+
 /// Try to fold icmp (binop), X or icmp X, (binop).
 /// TODO: A large part of this logic is duplicated in InstSimplify's
 /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
@@ -4565,6 +4604,11 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
   if (Instruction *NewICmp = foldICmpXNegX(I, Builder))
     return NewICmp;
 
+  if (BO0) {
+    if (Value *V = foldICmpWithCommonFactors(I, BO0, Op1))
+      return replaceInstUsesWith(I, V);
+  }
+
   const CmpInst::Predicate Pred = I.getPredicate();
   Value *X;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index f86db698ef8f12..54a9dd9942d09c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -632,6 +632,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
                                                   const APInt &C);
+  Value *foldICmpWithCommonFactors(ICmpInst &Cmp, BinaryOperator *LBO,
+                                   Value *RHS);
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
   Instruction *foldICmpWithMinMaxImpl(Instruction &I, MinMaxIntrinsic *MinMax,
                                       Value *Z, ICmpInst::Predicate Pred);

>From 6b272c60951fa0132dbe9a8866d42ba7b414fee4 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 21 Dec 2023 16:29:33 +0900
Subject: [PATCH 2/5] [InstCombine] refactor foldICmpWithCommonFactors

---
 .../InstCombine/InstCombineCompares.cpp       | 34 +++++++++----------
 .../InstCombine/InstCombineInternal.h         |  4 +--
 2 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index ca30e5b6ad779b..5482b044deadc9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4548,34 +4548,32 @@ static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
 }
 
 // extract common factors like ((A + B) - C == B) -> (A - C == 0)
-Value *InstCombinerImpl::foldICmpWithCommonFactors(ICmpInst &Cmp,
-                                                   BinaryOperator *LBO,
-                                                   Value *RHS) {
+Instruction *InstCombinerImpl::foldICmpWithCommonFactors(ICmpInst &Cmp,
+                                                         BinaryOperator *LBO,
+                                                         Value *RHS) {
   const CmpInst::Predicate Pred = Cmp.getPredicate();
   if (!ICmpInst::isEquality(Pred))
     return nullptr;
 
-  SmallVector<BinaryOperator *, 16> WorkList(1, LBO);
+  SmallVector<BinaryOperator *, 16> worklist(1, LBO);
+  Constant *Zero = Constant::getNullValue(LBO->getType());
 
-  while (!WorkList.empty()) {
-    BinaryOperator *BO = WorkList.pop_back_val();
+  while (!worklist.empty()) {
+    BinaryOperator *BO = worklist.pop_back_val();
 
-    Value *A;
-    if (match(BO, m_OneUse(m_c_Add(m_Value(A), m_Specific(RHS))))) {
-      if (BO == LBO)
-        return Builder.CreateICmp(Pred, A,
-                                  Constant::getNullValue(LBO->getType()));
-      replaceInstUsesWith(*BO, A);
-      eraseInstFromFunction(*BO);
-      return Builder.CreateICmp(Pred, LBO,
-                                Constant::getNullValue(LBO->getType()));
+    if (Value * A; match(BO, m_OneUse(m_c_Add(m_Value(A), m_Specific(RHS))))) {
+      if (BO != LBO) {
+        replaceInstUsesWith(*BO, A);
+        eraseInstFromFunction(*BO);
+      }
+      return new ICmpInst(Pred, A, Zero);
     }
 
     unsigned Opc = BO->getOpcode();
     if (Opc == Instruction::Add || Opc == Instruction::Sub) {
       auto AddNextBO = [&](Value *Op) {
         if (BinaryOperator *Next = dyn_cast<BinaryOperator>(Op))
-          WorkList.push_back(Next);
+          worklist.push_back(Next);
       };
 
       AddNextBO(BO->getOperand(0));
@@ -4605,8 +4603,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
     return NewICmp;
 
   if (BO0) {
-    if (Value *V = foldICmpWithCommonFactors(I, BO0, Op1))
-      return replaceInstUsesWith(I, V);
+    if (Instruction *NewICmp = foldICmpWithCommonFactors(I, BO0, Op1))
+      return NewICmp;
   }
 
   const CmpInst::Predicate Pred = I.getPredicate();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 54a9dd9942d09c..9081373c0157d6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -632,8 +632,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
                                                   const APInt &C);
-  Value *foldICmpWithCommonFactors(ICmpInst &Cmp, BinaryOperator *LBO,
-                                   Value *RHS);
+  Instruction *foldICmpWithCommonFactors(ICmpInst &Cmp, BinaryOperator *LBO,
+                                         Value *RHS);
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
   Instruction *foldICmpWithMinMaxImpl(Instruction &I, MinMaxIntrinsic *MinMax,
                                       Value *Z, ICmpInst::Predicate Pred);

>From e3999c39c007cf5c2f8125512aab4e479a8256a1 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 21 Dec 2023 16:50:15 +0900
Subject: [PATCH 3/5] [InstCombine] fix bugs in foldICmpWithCommonFactors

---
 llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 5482b044deadc9..399af8d1133f08 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4566,7 +4566,7 @@ Instruction *InstCombinerImpl::foldICmpWithCommonFactors(ICmpInst &Cmp,
         replaceInstUsesWith(*BO, A);
         eraseInstFromFunction(*BO);
       }
-      return new ICmpInst(Pred, A, Zero);
+      return new ICmpInst(Pred, LBO, Zero);
     }
 
     unsigned Opc = BO->getOpcode();

>From 88b44ceef7813f295a1ac071a4dab899bb5d452f Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 21 Dec 2023 16:51:04 +0900
Subject: [PATCH 4/5] [InstCombine] add icmp-eq-common-factor.ll test

---
 .../InstCombine/icmp-eq-common-factor.ll      | 121 ++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll

diff --git a/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll b/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
new file mode 100644
index 00000000000000..e28c0910c9cae2
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
@@ -0,0 +1,121 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+; A + B - C == B
+define i1 @icmp_common_add_sub(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = add i32 %a, %b
+  %sub = sub i32 %add, %c
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A + B - C == B
+define i1 @icmp_common_add_sub_ne(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_sub_ne(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = add i32 %a, %b
+  %sub = sub i32 %add, %c
+  %cmp = icmp ne i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A * B - C == B
+define i1 @icmp_common_mul_sub(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_mul_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[A]], [[B]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[MUL]], [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[SUB]], [[B]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %mul = mul i32 %a, %b
+  %sub = sub i32 %mul, %c
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A + B + C == B
+define i1 @icmp_common_add_add(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = sub i32 0, [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[A]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %cmp = icmp eq i32 %add2, %b
+  ret i1 %cmp
+}
+
+; A + B + C + D == B
+define i1 @icmp_common_add_add_add(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[A]], [[C]]
+; CHECK-NEXT:    [[ADD3:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[ADD3]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %add3 = add i32 %add2, %d
+  %cmp = icmp eq i32 %add3, %b
+  ret i1 %cmp
+}
+
+; A + B + C + D == C
+define i1 @icmp_common_add_add_add_2(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_add_2(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[A]], [[B]]
+; CHECK-NEXT:    [[ADD3:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD1]], [[ADD3]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %add3 = add i32 %add2, %d
+  %cmp = icmp eq i32 %add3, %c
+  ret i1 %cmp
+}
+
+; A + B + C - D == B
+define i1 @icmp_common_add_add_sub(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[A]], [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[D]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %sub = sub i32 %add2, %d
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+
+; A + B - C + D == B
+define i1 @icmp_common_add_sub_add(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_sub_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[A]], [[C]]
+; CHECK-NEXT:    [[ADD2:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[SUB]], [[ADD2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %sub = sub i32 %add1, %c
+  %add2 = add i32 %sub, %d
+  %cmp = icmp eq i32 %add2, %b
+  ret i1 %cmp
+}

>From c8c76567e58fee1432df1199a20d87b658d4fe50 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 21 Dec 2023 21:18:04 +0900
Subject: [PATCH 5/5] [InstCombine] fix bug in foldICmpWithCommonFactors

---
 .../InstCombine/InstCombineCompares.cpp       | 56 +++++++++++++------
 .../InstCombine/icmp-eq-common-factor.ll      |  2 +-
 2 files changed, 39 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 399af8d1133f08..03c04f6d7b6b13 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4555,30 +4555,50 @@ Instruction *InstCombinerImpl::foldICmpWithCommonFactors(ICmpInst &Cmp,
   if (!ICmpInst::isEquality(Pred))
     return nullptr;
 
-  SmallVector<BinaryOperator *, 16> worklist(1, LBO);
-  Constant *Zero = Constant::getNullValue(LBO->getType());
+  auto IsOneUseAddOrSub = [](BinaryOperator *Op) {
+    return Op->hasOneUse() && (Op->getOpcode() == Instruction::Add ||
+                               Op->getOpcode() == Instruction::Sub);
+  };
+
+  // The add/sub tree under LBO is rewritten in place, so LBO and every node
+  // on the path down to the matched add must have no other users.
+  if (!IsOneUseAddOrSub(LBO))
+    return nullptr;
+
+  SmallVector<BinaryOperator *, 16> Pending;
+
+  // Queue the operands of Parent that contribute to LBO with a positive sign:
+  // both operands of an add, but only the minuend of a sub. Removing RHS from
+  // the subtrahend of a sub would change LBO by +RHS instead of -RHS.
+  auto PushPositiveOperands = [&](BinaryOperator *Parent) {
+    unsigned NumOps = Parent->getOpcode() == Instruction::Add ? 2 : 1;
+    for (unsigned OpIdx = 0; OpIdx != NumOps; ++OpIdx)
+      if (auto *Next = dyn_cast<BinaryOperator>(Parent->getOperand(OpIdx)))
+        if (IsOneUseAddOrSub(Next))
+          Pending.push_back(Next);
+  };
+
+  PushPositiveOperands(LBO);
 
-  while (!worklist.empty()) {
-    BinaryOperator *BO = worklist.pop_back_val();
+  while (!Pending.empty()) {
+    BinaryOperator *BO = Pending.pop_back_val();
 
     if (Value * A; match(BO, m_OneUse(m_c_Add(m_Value(A), m_Specific(RHS))))) {
-      if (BO != LBO) {
-        replaceInstUsesWith(*BO, A);
-        eraseInstFromFunction(*BO);
-      }
+      // Every node from BO up to LBO has exactly one user, so following
+      // user_back() visits exactly the instructions whose value changes.
+      // Their nsw/nuw flags may not hold once RHS is removed; drop them.
+      Instruction *I = BO;
+      do {
+        I = cast<Instruction>(I->user_back());
+        I->dropPoisonGeneratingFlags();
+      } while (I != LBO);
+      replaceInstUsesWith(*BO, A);
+      eraseInstFromFunction(*BO);
+      Constant *Zero = Constant::getNullValue(LBO->getType());
       return new ICmpInst(Pred, LBO, Zero);
     }
 
-    unsigned Opc = BO->getOpcode();
-    if (Opc == Instruction::Add || Opc == Instruction::Sub) {
-      auto AddNextBO = [&](Value *Op) {
-        if (BinaryOperator *Next = dyn_cast<BinaryOperator>(Op))
-          worklist.push_back(Next);
-      };
-
-      AddNextBO(BO->getOperand(0));
-      AddNextBO(BO->getOperand(1));
-    }
+    PushPositiveOperands(BO);
   }
 
   return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll b/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
index e28c0910c9cae2..d6c3c8c8dfad35 100644
--- a/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
@@ -14,7 +14,7 @@ define i1 @icmp_common_add_sub(i32 %a, i32 %b, i32 %c){
   ret i1 %cmp
 }
 
-; A + B - C == B
+; A + B - C != B
 define i1 @icmp_common_add_sub_ne(i32 %a, i32 %b, i32 %c){
 ; CHECK-LABEL: define i1 @icmp_common_add_sub_ne(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {



More information about the llvm-commits mailing list