[llvm] 0d454d6 - [InstCombine] Fold xor of icmps using range information (#76334)

via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 24 15:14:34 PST 2023


Author: Yingwei Zheng
Date: 2023-12-25T07:14:31+08:00
New Revision: 0d454d6e591a579f450093c4ba8c49675e1643ad

URL: https://github.com/llvm/llvm-project/commit/0d454d6e591a579f450093c4ba8c49675e1643ad
DIFF: https://github.com/llvm/llvm-project/commit/0d454d6e591a579f450093c4ba8c49675e1643ad.diff

LOG: [InstCombine] Fold xor of icmps using range information (#76334)

This patch folds an xor of icmps into a single comparison using range-based reasoning, as `foldAndOrOfICmpsUsingRanges` does.
Fixes #70928.
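
For readers following along, the range reasoning is easy to reproduce outside of InstCombine. Below is a minimal standalone sketch (illustrative only, not part of the patch; the main() harness and printing are scaffolding) that uses the same ConstantRange APIs the new code calls to compute the region where the xor is true, i.e. the symmetric difference (CR1 u CR2) \ (CR1 n CR2), for the motivating case (X s> 5) ^ (X s< 6) from #70928:

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  // Region for X s> 5 is [6, INT_MIN); region for X s< 6 is [INT_MIN, 6).
  ConstantRange CR1 =
      ConstantRange::makeExactICmpRegion(CmpInst::ICMP_SGT, APInt(32, 5));
  ConstantRange CR2 =
      ConstantRange::makeExactICmpRegion(CmpInst::ICMP_SLT, APInt(32, 6));
  // The xor of the two predicates holds on the symmetric difference:
  // (CR1 u CR2) intersected with the complement of (CR1 n CR2).
  auto U = CR1.exactUnionWith(CR2);      // full set for this pair
  auto I = CR1.exactIntersectWith(CR2);  // empty set for this pair
  if (U && I)
    if (auto CR = U->exactIntersectWith(I->inverse()))
      outs() << *CR << "\n"; // prints "full-set": the xor folds to true
}

The exact* helpers return std::nullopt whenever the true result is not representable as a single range, which is what keeps the transform conservative.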

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/test/Transforms/InstCombine/and-or-icmps.ll
    llvm/test/Transforms/InstCombine/xor-icmps.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 5e362f4117d051..63b1e0f64a8824 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3956,35 +3956,50 @@ Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   const APInt *LC, *RC;
   if (match(LHS1, m_APInt(LC)) && match(RHS1, m_APInt(RC)) &&
       LHS0->getType() == RHS0->getType() &&
-      LHS0->getType()->isIntOrIntVectorTy() &&
-      (LHS->hasOneUse() || RHS->hasOneUse())) {
+      LHS0->getType()->isIntOrIntVectorTy()) {
     // Convert xor of signbit tests to signbit test of xor'd values:
     // (X > -1) ^ (Y > -1) --> (X ^ Y) < 0
     // (X <  0) ^ (Y <  0) --> (X ^ Y) < 0
     // (X > -1) ^ (Y <  0) --> (X ^ Y) > -1
     // (X <  0) ^ (Y > -1) --> (X ^ Y) > -1
     bool TrueIfSignedL, TrueIfSignedR;
-    if (isSignBitCheck(PredL, *LC, TrueIfSignedL) &&
+    if ((LHS->hasOneUse() || RHS->hasOneUse()) &&
+        isSignBitCheck(PredL, *LC, TrueIfSignedL) &&
         isSignBitCheck(PredR, *RC, TrueIfSignedR)) {
       Value *XorLR = Builder.CreateXor(LHS0, RHS0);
       return TrueIfSignedL == TrueIfSignedR ? Builder.CreateIsNeg(XorLR) :
                                               Builder.CreateIsNotNeg(XorLR);
     }
 
-    // (X > C) ^ (X < C + 2) --> X != C + 1
-    // (X < C + 2) ^ (X > C) --> X != C + 1
-    // Considering the correctness of this pattern, we should avoid that C is
-    // non-negative and C + 2 is negative, although it will be matched by other
-    // patterns.
-    const APInt *C1, *C2;
-    if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_APInt(C1)) &&
-         PredR == CmpInst::ICMP_SLT && match(RHS1, m_APInt(C2))) ||
-        (PredL == CmpInst::ICMP_SLT && match(LHS1, m_APInt(C2)) &&
-         PredR == CmpInst::ICMP_SGT && match(RHS1, m_APInt(C1))))
-      if (LHS0 == RHS0 && *C1 + 2 == *C2 &&
-          (C1->isNegative() || C2->isNonNegative()))
-        return Builder.CreateICmpNE(LHS0,
-                                    ConstantInt::get(LHS0->getType(), *C1 + 1));
+    // Fold (icmp pred1 X, C1) ^ (icmp pred2 X, C2)
+    // into a single comparison using range-based reasoning.
+    if (LHS0 == RHS0) {
+      ConstantRange CR1 = ConstantRange::makeExactICmpRegion(PredL, *LC);
+      ConstantRange CR2 = ConstantRange::makeExactICmpRegion(PredR, *RC);
+      auto CRUnion = CR1.exactUnionWith(CR2);
+      auto CRIntersect = CR1.exactIntersectWith(CR2);
+      if (CRUnion && CRIntersect)
+        if (auto CR = CRUnion->exactIntersectWith(CRIntersect->inverse())) {
+          if (CR->isFullSet())
+            return ConstantInt::getTrue(I.getType());
+          if (CR->isEmptySet())
+            return ConstantInt::getFalse(I.getType());
+
+          CmpInst::Predicate NewPred;
+          APInt NewC, Offset;
+          CR->getEquivalentICmp(NewPred, NewC, Offset);
+
+          if ((Offset.isZero() && (LHS->hasOneUse() || RHS->hasOneUse())) ||
+              (LHS->hasOneUse() && RHS->hasOneUse())) {
+            Value *NewV = LHS0;
+            Type *Ty = LHS0->getType();
+            if (!Offset.isZero())
+              NewV = Builder.CreateAdd(NewV, ConstantInt::get(Ty, Offset));
+            return Builder.CreateICmp(NewPred, NewV,
+                                      ConstantInt::get(Ty, NewC));
+          }
+        }
+    }
   }
 
   // Instead of trying to imitate the folds for and/or, decompose this 'xor'

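A note on the Offset path in the hunk above: when the symmetric difference comes out as a wrapped range, getEquivalentICmp can only express it as a comparison against X plus a constant, so the fold costs an extra add and is only done when both compares are one-use (with a zero offset, a single one-use compare suffices). A sketch of that case, (X s> 3) ^ (X s< 6), with the same illustrative caveats as above:

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  ConstantRange CR1 =
      ConstantRange::makeExactICmpRegion(CmpInst::ICMP_SGT, APInt(32, 3));
  ConstantRange CR2 =
      ConstantRange::makeExactICmpRegion(CmpInst::ICMP_SLT, APInt(32, 6));
  auto U = CR1.exactUnionWith(CR2);      // full set
  auto I = CR1.exactIntersectWith(CR2);  // [4, 6)
  if (U && I)
    if (auto CR = U->exactIntersectWith(I->inverse())) { // wrapped range [6, 4)
      CmpInst::Predicate Pred;
      APInt RHS, Offset;
      CR->getEquivalentICmp(Pred, RHS, Offset);
      // Pred == ICMP_ULT, RHS == -2, Offset == -6, i.e. (X + -6) u< -2,
      // matching the xor_icmp_to_icmp_add test in xor-icmps.ll below.
      outs() << "offset=" << Offset << " rhs=" << RHS << "\n";
    }
}
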
diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
index 881a9b7ff129db..91ecf24760259b 100644
--- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll
+++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
@@ -3015,10 +3015,8 @@ define i32 @icmp_x_slt_0_and_icmp_y_sgt_neg1_i32_fail(i32 %x, i32 %y) {
 
 define i32 @icmp_slt_0_xor_icmp_sge_neg2_i32_fail(i32 %x) {
 ; CHECK-LABEL: @icmp_slt_0_xor_icmp_sge_neg2_i32_fail(
-; CHECK-NEXT:    [[A:%.*]] = icmp sgt i32 [[X:%.*]], -3
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[X]], 0
-; CHECK-NEXT:    [[TMP2:%.*]] = xor i1 [[TMP1]], [[A]]
-; CHECK-NEXT:    [[D:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[X:%.*]], -2
+; CHECK-NEXT:    [[D:%.*]] = zext i1 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[D]]
 ;
   %A = icmp sge i32 %x, -2

diff --git a/llvm/test/Transforms/InstCombine/xor-icmps.ll b/llvm/test/Transforms/InstCombine/xor-icmps.ll
index c85993ea9a7e0d..f104cd7fdcada5 100644
--- a/llvm/test/Transforms/InstCombine/xor-icmps.ll
+++ b/llvm/test/Transforms/InstCombine/xor-icmps.ll
@@ -171,3 +171,151 @@ define i1 @xor_icmp_ptr(ptr %c, ptr %d) {
   ret i1 %xor
 }
 
+; Tests from PR70928
+define i1 @xor_icmp_true_signed(i32 %a) {
+; CHECK-LABEL: @xor_icmp_true_signed(
+; CHECK-NEXT:    ret i1 true
+;
+  %cmp = icmp sgt i32 %a, 5
+  %cmp1 = icmp slt i32 %a, 6
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_true_signed_multiuse1(i32 %a) {
+; CHECK-LABEL: @xor_icmp_true_signed_multiuse1(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[A:%.*]], 5
+; CHECK-NEXT:    call void @use(i1 [[CMP]])
+; CHECK-NEXT:    ret i1 true
+;
+  %cmp = icmp sgt i32 %a, 5
+  call void @use(i1 %cmp)
+  %cmp1 = icmp slt i32 %a, 6
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_true_signed_multiuse2(i32 %a) {
+; CHECK-LABEL: @xor_icmp_true_signed_multiuse2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[A:%.*]], 5
+; CHECK-NEXT:    call void @use(i1 [[CMP]])
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 [[A]], 6
+; CHECK-NEXT:    call void @use(i1 [[CMP1]])
+; CHECK-NEXT:    ret i1 true
+;
+  %cmp = icmp sgt i32 %a, 5
+  call void @use(i1 %cmp)
+  %cmp1 = icmp slt i32 %a, 6
+  call void @use(i1 %cmp1)
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_true_signed_commuted(i32 %a) {
+; CHECK-LABEL: @xor_icmp_true_signed_commuted(
+; CHECK-NEXT:    ret i1 true
+;
+  %cmp = icmp sgt i32 %a, 5
+  %cmp1 = icmp slt i32 %a, 6
+  %cmp3 = xor i1 %cmp1, %cmp
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_true_unsigned(i32 %a) {
+; CHECK-LABEL: @xor_icmp_true_unsigned(
+; CHECK-NEXT:    ret i1 true
+;
+  %cmp = icmp ugt i32 %a, 5
+  %cmp1 = icmp ult i32 %a, 6
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_to_ne(i32 %a) {
+; CHECK-LABEL: @xor_icmp_to_ne(
+; CHECK-NEXT:    [[CMP3:%.*]] = icmp ne i32 [[A:%.*]], 5
+; CHECK-NEXT:    ret i1 [[CMP3]]
+;
+  %cmp = icmp sgt i32 %a, 4
+  %cmp1 = icmp slt i32 %a, 6
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_to_ne_multiuse1(i32 %a) {
+; CHECK-LABEL: @xor_icmp_to_ne_multiuse1(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[A:%.*]], 4
+; CHECK-NEXT:    call void @use(i1 [[CMP]])
+; CHECK-NEXT:    [[CMP3:%.*]] = icmp ne i32 [[A]], 5
+; CHECK-NEXT:    ret i1 [[CMP3]]
+;
+  %cmp = icmp sgt i32 %a, 4
+  call void @use(i1 %cmp)
+  %cmp1 = icmp slt i32 %a, 6
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_to_icmp_add(i32 %a) {
+; CHECK-LABEL: @xor_icmp_to_icmp_add(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[A:%.*]], -6
+; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult i32 [[TMP1]], -2
+; CHECK-NEXT:    ret i1 [[CMP3]]
+;
+  %cmp = icmp sgt i32 %a, 3
+  %cmp1 = icmp slt i32 %a, 6
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+; Negative tests
+; The result of ConstantRange::difference is not exact.
+define i1 @xor_icmp_invalid_range(i8 %x0) {
+; CHECK-LABEL: @xor_icmp_invalid_range(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X0:%.*]], -5
+; CHECK-NEXT:    [[OR_COND:%.*]] = icmp ne i8 [[TMP1]], 0
+; CHECK-NEXT:    ret i1 [[OR_COND]]
+;
+  %cmp = icmp eq i8 %x0, 0
+  %cmp4 = icmp ne i8 %x0, 4
+  %or.cond = xor i1 %cmp, %cmp4
+  ret i1 %or.cond
+}
+define i1 @xor_icmp_to_ne_multiuse2(i32 %a) {
+; CHECK-LABEL: @xor_icmp_to_ne_multiuse2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[A:%.*]], 4
+; CHECK-NEXT:    call void @use(i1 [[CMP]])
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 [[A]], 6
+; CHECK-NEXT:    call void @use(i1 [[CMP1]])
+; CHECK-NEXT:    [[CMP3:%.*]] = xor i1 [[CMP]], [[CMP1]]
+; CHECK-NEXT:    ret i1 [[CMP3]]
+;
+  %cmp = icmp sgt i32 %a, 4
+  call void @use(i1 %cmp)
+  %cmp1 = icmp slt i32 %a, 6
+  call void @use(i1 %cmp1)
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_to_icmp_add_multiuse1(i32 %a) {
+; CHECK-LABEL: @xor_icmp_to_icmp_add_multiuse1(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[A:%.*]], 3
+; CHECK-NEXT:    call void @use(i1 [[CMP]])
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 [[A]], 6
+; CHECK-NEXT:    [[CMP3:%.*]] = xor i1 [[CMP]], [[CMP1]]
+; CHECK-NEXT:    ret i1 [[CMP3]]
+;
+  %cmp = icmp sgt i32 %a, 3
+  call void @use(i1 %cmp)
+  %cmp1 = icmp slt i32 %a, 6
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
+define i1 @xor_icmp_to_icmp_add_multiuse2(i32 %a) {
+; CHECK-LABEL: @xor_icmp_to_icmp_add_multiuse2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[A:%.*]], 3
+; CHECK-NEXT:    call void @use(i1 [[CMP]])
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 [[A]], 6
+; CHECK-NEXT:    call void @use(i1 [[CMP1]])
+; CHECK-NEXT:    [[CMP3:%.*]] = xor i1 [[CMP]], [[CMP1]]
+; CHECK-NEXT:    ret i1 [[CMP3]]
+;
+  %cmp = icmp sgt i32 %a, 3
+  call void @use(i1 %cmp)
+  %cmp1 = icmp slt i32 %a, 6
+  call void @use(i1 %cmp1)
+  %cmp3 = xor i1 %cmp, %cmp1
+  ret i1 %cmp3
+}
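
On the xor_icmp_invalid_range negative test above: for (X == 0) ^ (X != 4) the exact symmetric difference is everything except 0 and 4, a set with two holes that no single ConstantRange can represent, so exactIntersectWith returns std::nullopt and the new fold backs off (the and/icmp output in the CHECK lines comes from other folds). A sketch with the same illustrative caveats as the earlier ones:

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  ConstantRange CR1 =
      ConstantRange::makeExactICmpRegion(CmpInst::ICMP_EQ, APInt(8, 0));
  ConstantRange CR2 =
      ConstantRange::makeExactICmpRegion(CmpInst::ICMP_NE, APInt(8, 4));
  auto U = CR1.exactUnionWith(CR2);      // [5, 4): everything but 4
  auto I = CR1.exactIntersectWith(CR2);  // [0, 1): just 0
  if (U && I) {
    auto CR = U->exactIntersectWith(I->inverse());
    // The true result {1..3, 5..255} has holes at 0 and 4, so no single
    // range is exact and CR is std::nullopt: no range-based fold happens.
    outs() << (CR ? "exact" : "not exact") << "\n"; // prints "not exact"
  }
}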

