[llvm] [InstCombine] Extend `foldICmpBinOp` to `add`-like `or`. (PR #71396)

Mikhail Gudim via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 6 06:02:19 PST 2023


https://github.com/mgudim created https://github.com/llvm/llvm-project/pull/71396

InstCombine canonicalizes `add` to `or` when possible, but this causes some optimizations that apply to `add` to be missed, because they don't recognize that the `or` is equivalent to an `add`.

In this patch we generalize `foldICmpBinOp` to handle such cases.

>From f4346fcd9be66472ca850b0efc08f3adaa797385 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Fri, 3 Nov 2023 16:15:36 -0400
Subject: [PATCH] [InstCombine] Extend `foldICmpBinOp` to `add`-like `or`.

InstCombine canonicalizes `add` to `or` when possible, but this causes
some optimizations that apply to `add` to be missed, because they don't
recognize that the `or` is equivalent to an `add`.

In this patch we generalize `foldICmpBinOp` to handle such cases.
---
 .../InstCombine/InstCombineCompares.cpp       | 59 ++++++++++++-------
 1 file changed, 39 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 55e26d09cd6e829..9b7edbe31900e95 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4564,6 +4564,17 @@ static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
   return nullptr;
 }
 
+/// True if \p I is an `add`, or an `or` known to have no common operand bits.
+static bool isAddLike(const Instruction &I, const SimplifyQuery &SQ) {
+  unsigned Opc = I.getOpcode();
+  if (Opc == Instruction::Add)
+    return true;
+  if (Opc == Instruction::Or)
+    return haveNoCommonBitsSet(I.getOperand(0), I.getOperand(1),
+                               SQ.getWithInstruction(&I));
+  return false;
+}
+
 /// Try to fold icmp (binop), X or icmp X, (binop).
 /// TODO: A large part of this logic is duplicated in InstSimplify's
 /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
@@ -4641,25 +4652,37 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
   }
 
   bool NoOp0WrapProblem = false, NoOp1WrapProblem = false;
-  if (BO0 && isa<OverflowingBinaryOperator>(BO0))
+  bool Op0HasNUW = false, Op1HasNUW = false; // Set below only for OBOs.
+  bool Op0HasNSW = false, Op1HasNSW = false;
+  if (BO0 && isa<OverflowingBinaryOperator>(BO0)) {
+    Op0HasNUW = BO0->hasNoUnsignedWrap();
+    Op0HasNSW = BO0->hasNoSignedWrap();
     NoOp0WrapProblem =
         ICmpInst::isEquality(Pred) ||
-        (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) ||
-        (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap());
-  if (BO1 && isa<OverflowingBinaryOperator>(BO1))
+        (CmpInst::isUnsigned(Pred) && Op0HasNUW) ||
+        (CmpInst::isSigned(Pred) && Op0HasNSW);
+  }
+  if (BO1 && isa<OverflowingBinaryOperator>(BO1)) {
+    Op1HasNUW = BO1->hasNoUnsignedWrap();
+    Op1HasNSW = BO1->hasNoSignedWrap();
     NoOp1WrapProblem =
         ICmpInst::isEquality(Pred) ||
-        (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) ||
-        (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap());
+        (CmpInst::isUnsigned(Pred) && Op1HasNUW) ||
+        (CmpInst::isSigned(Pred) && Op1HasNSW);
+  }
 
   // Analyze the case when either Op0 or Op1 is an add instruction.
   // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null).
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
-  if (BO0 && BO0->getOpcode() == Instruction::Add) {
+  if (BO0 && isAddLike(*BO0, SQ)) {
+    if (BO0->getOpcode() == Instruction::Or)
+      NoOp0WrapProblem = true; // No common bits => no carries.
     A = BO0->getOperand(0);
     B = BO0->getOperand(1);
   }
-  if (BO1 && BO1->getOpcode() == Instruction::Add) {
+  if (BO1 && isAddLike(*BO1, SQ)) {
+    if (BO1->getOpcode() == Instruction::Or)
+      NoOp1WrapProblem = true; // No common bits => no carries.
     C = BO1->getOperand(0);
     D = BO1->getOperand(1);
   }
@@ -4781,17 +4804,13 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
       APInt AP2Abs = AP2->abs();
       if (AP1Abs.uge(AP2Abs)) {
         APInt Diff = *AP1 - *AP2;
-        bool HasNUW = BO0->hasNoUnsignedWrap() && Diff.ule(*AP1);
-        bool HasNSW = BO0->hasNoSignedWrap();
         Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff);
-        Value *NewAdd = Builder.CreateAdd(A, C3, "", HasNUW, HasNSW);
+        Value *NewAdd = Builder.CreateAdd(A, C3, "", Op0HasNUW && Diff.ule(*AP1), Op0HasNSW);
         return new ICmpInst(Pred, NewAdd, C);
       } else {
         APInt Diff = *AP2 - *AP1;
-        bool HasNUW = BO1->hasNoUnsignedWrap() && Diff.ule(*AP2);
-        bool HasNSW = BO1->hasNoSignedWrap();
         Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff);
-        Value *NewAdd = Builder.CreateAdd(C, C3, "", HasNUW, HasNSW);
+        Value *NewAdd = Builder.CreateAdd(C, C3, "", Op1HasNUW && Diff.ule(*AP2), Op1HasNSW);
         return new ICmpInst(Pred, A, NewAdd);
       }
     }
@@ -4885,16 +4904,16 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
                   isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
         // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
         //    X * Z eq/ne Y * Z -> X eq/ne Y
-        if (NonZero && BO0 && BO1 && BO0->hasNoSignedWrap() &&
-            BO1->hasNoSignedWrap())
+        if (NonZero && BO0 && BO1 && Op0HasNSW &&
+            Op1HasNSW)
           return new ICmpInst(Pred, X, Y);
       } else
         NonZero = isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
 
       // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
       //    X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
-      if (NonZero && BO0 && BO1 && BO0->hasNoUnsignedWrap() &&
-          BO1->hasNoUnsignedWrap())
+      if (NonZero && BO0 && BO1 && Op0HasNUW &&
+          Op1HasNUW)
         return new ICmpInst(Pred, X, Y);
     }
   }
@@ -4993,8 +5012,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
       return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));
 
     case Instruction::Shl: {
-      bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap();
-      bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap();
+      bool NUW = Op0HasNUW && Op1HasNUW; // Cached BO0/BO1 wrap flags.
+      bool NSW = Op0HasNSW && Op1HasNSW;
       if (!NUW && !NSW)
         break;
       if (!NSW && I.isSigned())



More information about the llvm-commits mailing list