[llvm] [InstCombine] fold (A + B - C == B) -> (A == C). (PR #76129)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 20 23:57:46 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Chia (sun-jacobi)
<details>
<summary>Changes</summary>
This patch closes #72512.
We implemented a search that checks whether the LHS and RHS share a common value and, if so, eliminates it from both sides.
This implementation also covers the case `(A + B == B) -> (A == 0)`, which is a pre-existing optimization.
---
Full diff: https://github.com/llvm/llvm-project/pull/76129.diff
3 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+42)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+2)
- (added) llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll (+121)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0ad87eeb4c91a4..399af8d1133f08 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4547,6 +4547,43 @@ static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
return nullptr;
}
+/// Look for \p Common as a positively-signed addend of the add/sub expression
+/// tree rooted at \p V and, if found, return a new value computing
+/// (V - Common). Returns nullptr if no such addend is found.
+///
+/// Only positive-sign positions are searched: both operands of an add and the
+/// minuend (operand 0) of a sub. Looking through the subtrahend of a sub would
+/// flip the sign of \p Common and make the rewrite unsound, e.g.
+/// (C - (A + B)) == B is *not* equivalent to (C - A) == 0.
+///
+/// Every node on the path to \p Common must have a single use, so the old
+/// instructions become dead once the caller stops using \p V. The path is
+/// rebuilt with fresh instructions instead of being mutated in place, so no
+/// existing instruction changes its value. The new instructions deliberately
+/// carry no nsw/nuw flags, because removing an addend can change whether the
+/// intermediate sums wrap. Instructions are only created on the way back up
+/// from a successful match, so a failed search leaves the IR untouched.
+static Value *removeCommonAddend(Value *V, Value *Common,
+                                 InstCombiner::BuilderTy &Builder,
+                                 unsigned Depth = 0) {
+  // Bound the search; these trees are normally shallow.
+  constexpr unsigned MaxDepth = 8;
+  if (Depth > MaxDepth)
+    return nullptr;
+
+  Value *X, *Y;
+  if (match(V, m_OneUse(m_Add(m_Value(X), m_Value(Y))))) {
+    if (Y == Common)
+      return X;
+    if (X == Common)
+      return Y;
+    if (Value *NewX = removeCommonAddend(X, Common, Builder, Depth + 1))
+      return Builder.CreateAdd(NewX, Y);
+    if (Value *NewY = removeCommonAddend(Y, Common, Builder, Depth + 1))
+      return Builder.CreateAdd(X, NewY);
+    return nullptr;
+  }
+
+  // Only the minuend of a sub has the same sign as the whole expression.
+  if (match(V, m_OneUse(m_Sub(m_Value(X), m_Value(Y)))))
+    if (Value *NewX = removeCommonAddend(X, Common, Builder, Depth + 1))
+      return Builder.CreateSub(NewX, Y);
+
+  return nullptr;
+}
+
+/// Fold an equality compare whose LHS is an add/sub tree containing the RHS as
+/// a positively-signed addend, by subtracting the RHS from both sides:
+///   icmp eq/ne ((A + B) - C), B  -->  icmp eq/ne (A - C), 0
+/// Other InstCombine folds can then simplify the result further (e.g. to
+/// icmp eq/ne A, C). This is restricted to equality predicates: subtracting
+/// the same value from both sides preserves equality modulo 2^N, but not
+/// ordering.
+/// Note: despite the name, the common value eliminated here is an addend,
+/// not a multiplicative factor.
+Instruction *InstCombinerImpl::foldICmpWithCommonFactors(ICmpInst &Cmp,
+                                                         BinaryOperator *LBO,
+                                                         Value *RHS) {
+  const CmpInst::Predicate Pred = Cmp.getPredicate();
+  if (!ICmpInst::isEquality(Pred))
+    return nullptr;
+
+  Value *Reduced = removeCommonAddend(LBO, RHS, Builder);
+  if (!Reduced)
+    return nullptr;
+
+  // When LBO itself is (A + RHS), Reduced is simply A, so the result is
+  // icmp Pred A, 0 rather than a compare that still contains RHS.
+  return new ICmpInst(Pred, Reduced, Constant::getNullValue(LBO->getType()));
+}
+
/// Try to fold icmp (binop), X or icmp X, (binop).
/// TODO: A large part of this logic is duplicated in InstSimplify's
/// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
@@ -4565,6 +4602,11 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
if (Instruction *NewICmp = foldICmpXNegX(I, Builder))
return NewICmp;
+ if (BO0) {
+ if (Instruction *NewICmp = foldICmpWithCommonFactors(I, BO0, Op1))
+ return NewICmp;
+ }
+
const CmpInst::Predicate Pred = I.getPredicate();
Value *X;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index f86db698ef8f12..9081373c0157d6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -632,6 +632,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
const APInt &C);
+ Instruction *foldICmpWithCommonFactors(ICmpInst &Cmp, BinaryOperator *LBO,
+ Value *RHS);
Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
Instruction *foldICmpWithMinMaxImpl(Instruction &I, MinMaxIntrinsic *MinMax,
Value *Z, ICmpInst::Predicate Pred);
diff --git a/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll b/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
new file mode 100644
index 00000000000000..e28c0910c9cae2
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-eq-common-factor.ll
@@ -0,0 +1,121 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+declare void @use(i32)
+
+; A + B - C == B  -->  A == C
+define i1 @icmp_common_add_sub(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = add i32 %a, %b
+  %sub = sub i32 %add, %c
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A + B - C != B  -->  A != C
+define i1 @icmp_common_add_sub_ne(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_sub_ne(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = add i32 %a, %b
+  %sub = sub i32 %add, %c
+  %cmp = icmp ne i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A * B - C == B (negative test: B is a multiplicand, not an addend)
+define i1 @icmp_common_mul_sub(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_mul_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[A]], [[B]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[MUL]], [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[SUB]], [[B]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %mul = mul i32 %a, %b
+  %sub = sub i32 %mul, %c
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A + B + C == B  -->  A + C == 0
+define i1 @icmp_common_add_add(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = sub i32 0, [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[A]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %cmp = icmp eq i32 %add2, %b
+  ret i1 %cmp
+}
+
+; A + B + C + D == B  -->  A + C + D == 0
+define i1 @icmp_common_add_add_add(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[A]], [[C]]
+; CHECK-NEXT:    [[ADD3:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[ADD3]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %add3 = add i32 %add2, %d
+  %cmp = icmp eq i32 %add3, %b
+  ret i1 %cmp
+}
+
+; A + B + C + D == C  -->  A + B + D == 0
+define i1 @icmp_common_add_add_add_2(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_add_2(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[A]], [[B]]
+; CHECK-NEXT:    [[ADD3:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD1]], [[ADD3]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %add3 = add i32 %add2, %d
+  %cmp = icmp eq i32 %add3, %c
+  ret i1 %cmp
+}
+
+; A + B + C - D == B  -->  A + C == D
+define i1 @icmp_common_add_add_sub(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_add_sub(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[A]], [[C]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[D]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  %sub = sub i32 %add2, %d
+  %cmp = icmp eq i32 %sub, %b
+  ret i1 %cmp
+}
+
+; A + B - C + D == B  -->  A - C + D == 0
+define i1 @icmp_common_add_sub_add(i32 %a, i32 %b, i32 %c, i32 %d){
+; CHECK-LABEL: define i1 @icmp_common_add_sub_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[A]], [[C]]
+; CHECK-NEXT:    [[ADD2:%.*]] = sub i32 0, [[D]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[SUB]], [[ADD2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %sub = sub i32 %add1, %c
+  %add2 = add i32 %sub, %d
+  %cmp = icmp eq i32 %add2, %b
+  ret i1 %cmp
+}
+
+; A + B == B  -->  A == 0 (must not degrade to A + B == 0)
+define i1 @icmp_common_add(i32 %a, i32 %b){
+; CHECK-LABEL: define i1 @icmp_common_add(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[A]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add = add i32 %a, %b
+  %cmp = icmp eq i32 %add, %b
+  ret i1 %cmp
+}
+
+; A + B + C == B where A + B + C has another use (negative test): the add
+; chain must not be rewritten in place, otherwise @use observes A + C.
+define i1 @icmp_common_add_add_multiuse(i32 %a, i32 %b, i32 %c){
+; CHECK-LABEL: define i1 @icmp_common_add_add_multiuse(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[A]], [[B]]
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[ADD1]], [[C]]
+; CHECK-NEXT:    call void @use(i32 [[ADD2]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD2]], [[B]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %add1 = add i32 %a, %b
+  %add2 = add i32 %add1, %c
+  call void @use(i32 %add2)
+  %cmp = icmp eq i32 %add2, %b
+  ret i1 %cmp
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/76129
More information about the llvm-commits
mailing list