[llvm] [InstCombine] Fold zext(icmp (A, xxx)) == lshr(A, BW - 1) => not(trunc(xor(zext(icmp), lshr))) (PR #68244)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 4 11:35:56 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

<details>
<summary>Changes</summary>

Resolves #<!-- -->67916.
This patch extends `foldICmpEquality` to fold `zext(icmp (A, xxx)) == lshr(A, BW - 1)` into `not(trunc(xor(zext(icmp), lshr)))`; for `!=`, the `not` is omitted.
Since both operands are effectively `i1` values, an `xor` is preferable here to an `icmp eq`.
[Alive2](https://alive2.llvm.org/ce/z/t7UXuG).

---
Full diff: https://github.com/llvm/llvm-project/pull/68244.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+29-14) 
- (modified) llvm/test/Transforms/InstCombine/icmp-shr.ll (+3-4) 
- (modified) llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll (+80) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 9f034aba874a8c4..a0a45c73695f5c9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5311,11 +5311,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
       return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
   }
 
-  // Test if 2 values have different or same signbits:
-  // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0
-  // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1
-  // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0
-  // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1
+  // Signbit test
   Instruction *ExtI;
   if (match(Op1, m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(A)))) &&
       (Op0->hasOneUse() || Op1->hasOneUse())) {
@@ -5325,17 +5321,36 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
     ICmpInst::Predicate Pred2;
     if (match(Op0, m_CombineAnd(m_Instruction(ShiftI),
                                 m_Shr(m_Value(X),
-                                      m_SpecificIntAllowUndef(OpWidth - 1)))) &&
-        match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) &&
-        Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) {
+                                      m_SpecificIntAllowUndef(OpWidth - 1))))) {
+      // Test if 2 values have different or same signbits:
+      // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0
+      // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1
+      // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0
+      // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1
       unsigned ExtOpc = ExtI->getOpcode();
       unsigned ShiftOpc = ShiftI->getOpcode();
-      if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) ||
-          (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) {
-        Value *Xor = Builder.CreateXor(X, Y, "xor.signbits");
-        Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor)
-                                               : Builder.CreateIsNotNeg(Xor);
-        return replaceInstUsesWith(I, R);
+
+      if (match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) &&
+          Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) {
+        if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) ||
+            (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) {
+          Value *Xor = Builder.CreateXor(X, Y, "xor.signbits");
+          Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor)
+                                                 : Builder.CreateIsNotNeg(Xor);
+          return replaceInstUsesWith(I, R);
+        }
+      }
+
+      // Transform (X < 0 ==/!= icmp(X)) into (not) xor(X < 0, icmp(X))
+      if (match(A, m_c_ICmp(Pred2, m_Value(X), m_Value())) &&
+          ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) {
+
+        Value *Xor = Builder.CreateXor(Op0, Op1, "xor.ne");
+        Value *Trunc = Builder.CreateSExtOrTrunc(Xor, A->getType(), "eq.trunc");
+        Value *Not = (Pred == ICmpInst::ICMP_EQ)
+                         ? Builder.CreateNot(Trunc, "eq.not")
+                         : Trunc;
+        return replaceInstUsesWith(I, Not);
       }
     }
   }
diff --git a/llvm/test/Transforms/InstCombine/icmp-shr.ll b/llvm/test/Transforms/InstCombine/icmp-shr.ll
index f4dfa2edfa17710..b0ecd5ad6a01b2f 100644
--- a/llvm/test/Transforms/InstCombine/icmp-shr.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-shr.ll
@@ -1397,11 +1397,10 @@ define <2 x i1> @same_signbit_poison_elts(<2 x i8> %x, <2 x i8> %y) {
 
 define i1 @same_signbit_wrong_type(i8 %x, i32 %y) {
 ; CHECK-LABEL: @same_signbit_wrong_type(
-; CHECK-NEXT:    [[XSIGN:%.*]] = lshr i8 [[X:%.*]], 7
 ; CHECK-NEXT:    [[YPOS:%.*]] = icmp sgt i32 [[Y:%.*]], -1
-; CHECK-NEXT:    [[YPOSZ:%.*]] = zext i1 [[YPOS]] to i8
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[XSIGN]], [[YPOSZ]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = xor i1 [[TMP1]], [[YPOS]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %xsign = lshr i8 %x, 7
   %ypos = icmp sgt i32 %y, -1
diff --git a/llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll b/llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll
index 29a18ebbdd94e16..f4286023779a5a7 100644
--- a/llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll
@@ -217,3 +217,83 @@ define <2 x i1> @negative_simplify_splat(<4 x i8> %x) {
   ret <2 x i1> %c
 }
 
+
+define i1 @slt_zero_eq_ne_0(i32 %a) {
+; CHECK-LABEL: @slt_zero_eq_ne_0(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[A:%.*]], 1
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %cmp = icmp ne i32 %a, 0
+  %conv = zext i1 %cmp to i32
+  %cmp1 = lshr i32 %a, 31
+  %cmp2 = icmp eq i32 %conv, %cmp1
+  ret i1 %cmp2
+}
+
+define i1 @slt_zero_ne_ne_0(i32 %a) {
+; CHECK-LABEL: @slt_zero_ne_ne_0(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt i32 [[A:%.*]], 0
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %cmp = icmp ne i32 %a, 0
+  %conv = zext i1 %cmp to i32
+  %cmp1 = lshr i32 %a, 31
+  %cmp2 = icmp ne i32 %conv, %cmp1
+  ret i1 %cmp2
+}
+
+define <4 x i1> @slt_zero_eq_ne_0_vec(<4 x i32> %a) {
+; CHECK-LABEL: @slt_zero_eq_ne_0_vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <4 x i32> [[A:%.*]], <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    ret <4 x i1> [[TMP1]]
+;
+  %cmp = icmp ne <4 x i32> %a, zeroinitializer
+  %conv = zext <4 x i1> %cmp to <4 x i32>
+  %cmp1 = lshr <4 x i32> %a, <i32 31, i32 31, i32 31, i32 31>
+  %cmp2 = icmp eq <4 x i32> %conv, %cmp1
+  ret <4 x i1> %cmp2
+}
+
+define i1 @slt_zero_ne_ne_b(i32 %a, i32 %b) {
+; CHECK-LABEL: @slt_zero_ne_ne_b(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[A]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = xor i1 [[TMP1]], [[CMP]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
+;
+  %cmp = icmp ne i32 %a, %b
+  %conv = zext i1 %cmp to i32
+  %cmp1 = lshr i32 %a, 31
+  %cmp2 = icmp ne i32 %conv, %cmp1
+  ret i1 %cmp2
+}
+
+define i1 @slt_zero_eq_ne_0_fail1(i32 %a) {
+; CHECK-LABEL: @slt_zero_eq_ne_0_fail1(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A:%.*]], 0
+; CHECK-NEXT:    [[CONV:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP1:%.*]] = ashr i32 [[A]], 31
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[CMP1]], [[CONV]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp ne i32 %a, 0
+  %conv = zext i1 %cmp to i32
+  %cmp1 = ashr i32 %a, 31
+  %cmp2 = icmp eq i32 %conv, %cmp1
+  ret i1 %cmp2
+}
+
+define i1 @slt_zero_eq_ne_0_fail2(i32 %a) {
+; CHECK-LABEL: @slt_zero_eq_ne_0_fail2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A:%.*]], 0
+; CHECK-NEXT:    [[CONV:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP1:%.*]] = lshr i32 [[A]], 30
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[CMP1]], [[CONV]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp ne i32 %a, 0
+  %conv = zext i1 %cmp to i32
+  %cmp1 = lshr i32 %a, 30
+  %cmp2 = icmp eq i32 %conv, %cmp1
+  ret i1 %cmp2
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/68244


More information about the llvm-commits mailing list