[llvm] [ValueTracking] Handle non-canonical operand order in `isImpliedCondICmps` (PR #85575)

via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 17 14:36:20 PDT 2024


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/85575

>From d144e08b16b47efba096c38b2822afc897fee768 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Sun, 17 Mar 2024 15:32:53 -0500
Subject: [PATCH 1/2] [ValueTracking] Add tests for implied cond with swapped
 operands; NFC

---
 .../icmp-select-implies-common-op.ll          | 93 +++++++++++++++++++
 1 file changed, 93 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/icmp-select-implies-common-op.ll

diff --git a/llvm/test/Transforms/InstCombine/icmp-select-implies-common-op.ll b/llvm/test/Transforms/InstCombine/icmp-select-implies-common-op.ll
new file mode 100644
index 00000000000000..83850164813cb7
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-select-implies-common-op.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define i1 @sgt_3_impliesF_eq_2(i8 %x, i8 %y) {
+; CHECK-LABEL: @sgt_3_impliesF_eq_2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 2, i8 [[Y:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[SEL]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp sgt i8 %x, 3
+  %sel = select i1 %cmp, i8 2, i8 %y
+  %cmp2 = icmp eq i8 %sel, %x
+  ret i1 %cmp2
+}
+
+define i1 @sgt_3_impliesT_sgt_2(i8 %x, i8 %y) {
+; CHECK-LABEL: @sgt_3_impliesT_sgt_2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 2, i8 [[Y:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i8 [[SEL]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp sgt i8 %x, 3
+  %sel = select i1 %cmp, i8 2, i8 %y
+  %cmp2 = icmp sgt i8 %sel, %x
+  ret i1 %cmp2
+}
+
+define i1 @sgt_x_impliesF_eq_smin_todo(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @sgt_x_impliesF_eq_smin_todo(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 -128, i8 [[Y:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[SEL]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp sgt i8 %x, %z
+  %sel = select i1 %cmp, i8 -128, i8 %y
+  %cmp2 = icmp eq i8 %sel, %x
+  ret i1 %cmp2
+}
+
+define i1 @slt_x_impliesT_ne_smin_todo(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @slt_x_impliesT_ne_smin_todo(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 127, i8 [[Y:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ne i8 [[SEL]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp slt i8 %x, %z
+  %sel = select i1 %cmp, i8 127, i8 %y
+  %cmp2 = icmp ne i8 %x, %sel
+  ret i1 %cmp2
+}
+
+define i1 @ult_x_impliesT_eq_umax_todo(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @ult_x_impliesT_eq_umax_todo(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[Z:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 -1, i8 [[Y:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ne i8 [[SEL]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp ugt i8 %z, %x
+  %sel = select i1 %cmp, i8 255, i8 %y
+  %cmp2 = icmp ne i8 %sel, %x
+  ret i1 %cmp2
+}
+
+define i1 @ult_1_impliesF_eq_1(i8 %x, i8 %y) {
+; CHECK-LABEL: @ult_1_impliesF_eq_1(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[SEL:%.*]], 0
+; CHECK-NEXT:    [[X:%.*]] = select i1 [[CMP]], i8 1, i8 [[Y:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[X]], [[SEL]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp ult i8 %x, 1
+  %sel = select i1 %cmp, i8 1, i8 %y
+  %cmp2 = icmp eq i8 %x, %sel
+  ret i1 %cmp2
+}
+
+define i1 @ugt_x_impliesF_eq_umin_todo(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @ugt_x_impliesF_eq_umin_todo(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 0, i8 [[Y:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[SEL]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp ugt i8 %x, %z
+  %sel = select i1 %cmp, i8 0, i8 %y
+  %cmp2 = icmp eq i8 %x, %sel
+  ret i1 %cmp2
+}

>From beb6e1a7c83ee623b653df8213bf484550ed9f0a Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Sun, 17 Mar 2024 12:07:09 -0500
Subject: [PATCH 2/2] [ValueTracking] Handle non-canonical operand order in
 `isImpliedCondICmps`

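The operands of the two compares are not always in canonical order here,
so normalize them manually: move any operand shared by both compares into
L0/R0 (swapping the corresponding predicate), and when both operand pairs
match, keep any constants in L1/R1 before trying the implication checks.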
---
 llvm/lib/Analysis/ValueTracking.cpp           | 66 ++++++++-----------
 llvm/test/Transforms/InstCombine/assume.ll    |  2 +-
 .../icmp-select-implies-common-op.ll          | 24 +++----
 llvm/test/Transforms/InstCombine/select.ll    |  6 +-
 llvm/test/Transforms/InstCombine/shift.ll     |  7 +-
 5 files changed, 45 insertions(+), 60 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index edbeede910d7f7..651628e404b7a2 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8471,26 +8471,12 @@ isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
   }
 }
 
-/// Return true if the operands of two compares (expanded as "L0 pred L1" and
-/// "R0 pred R1") match. IsSwappedOps is true when the operands match, but are
-/// swapped.
-static bool areMatchingOperands(const Value *L0, const Value *L1, const Value *R0,
-                           const Value *R1, bool &AreSwappedOps) {
-  bool AreMatchingOps = (L0 == R0 && L1 == R1);
-  AreSwappedOps = (L0 == R1 && L1 == R0);
-  return AreMatchingOps || AreSwappedOps;
-}
-
 /// Return true if "icmp1 LPred X, Y" implies "icmp2 RPred X, Y" is true.
 /// Return false if "icmp1 LPred X, Y" implies "icmp2 RPred X, Y" is false.
 /// Otherwise, return std::nullopt if we can't infer anything.
 static std::optional<bool>
 isImpliedCondMatchingOperands(CmpInst::Predicate LPred,
-                              CmpInst::Predicate RPred, bool AreSwappedOps) {
-  // Canonicalize the predicate as if the operands were not commuted.
-  if (AreSwappedOps)
-    RPred = ICmpInst::getSwappedPredicate(RPred);
-
+                              CmpInst::Predicate RPred) {
   if (CmpInst::isImpliedTrueByMatchingCmp(LPred, RPred))
     return true;
   if (CmpInst::isImpliedFalseByMatchingCmp(LPred, RPred))
@@ -8532,6 +8518,25 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
   CmpInst::Predicate LPred =
       LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate();
 
+  // We can have non-canonical operands, so try to normalize any common operand
+  // to L0/R0.
+  if (L0 == R1) {
+    std::swap(R0, R1);
+    RPred = ICmpInst::getSwappedPredicate(RPred);
+  }
+  if (R0 == L1) {
+    std::swap(L0, L1);
+    LPred = ICmpInst::getSwappedPredicate(LPred);
+  }
+  if (L1 == R1) {
+    // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
+    if (L0 != R0 || match(L0, m_ImmConstant())) {
+      std::swap(L0, L1);
+      LPred = ICmpInst::getSwappedPredicate(LPred);
+      std::swap(R0, R1);
+      RPred = ICmpInst::getSwappedPredicate(RPred);
+    }
+  }
   // Can we infer anything when the 0-operands match and the 1-operands are
   // constants (not necessarily matching)?
   const APInt *LC, *RC;
@@ -8539,32 +8544,15 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
     return isImpliedCondCommonOperandWithConstants(LPred, *LC, RPred, *RC);
 
   // Can we infer anything when the two compares have matching operands?
-  bool AreSwappedOps;
-  if (areMatchingOperands(L0, L1, R0, R1, AreSwappedOps))
-    return isImpliedCondMatchingOperands(LPred, RPred, AreSwappedOps);
+  if (L0 == R0 && L1 == R1)
+    return isImpliedCondMatchingOperands(LPred, RPred);
 
   // L0 = R0 = L1 + R1, L0 >=u L1 implies R0 >=u R1, L0 <u L1 implies R0 <u R1
-  if (ICmpInst::isUnsigned(LPred) && ICmpInst::isUnsigned(RPred)) {
-    if (L0 == R1) {
-      std::swap(R0, R1);
-      RPred = ICmpInst::getSwappedPredicate(RPred);
-    }
-    if (L1 == R0) {
-      std::swap(L0, L1);
-      LPred = ICmpInst::getSwappedPredicate(LPred);
-    }
-    if (L1 == R1) {
-      std::swap(L0, L1);
-      LPred = ICmpInst::getSwappedPredicate(LPred);
-      std::swap(R0, R1);
-      RPred = ICmpInst::getSwappedPredicate(RPred);
-    }
-    if (L0 == R0 &&
-        (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) &&
-        (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) &&
-        match(L0, m_c_Add(m_Specific(L1), m_Specific(R1))))
-      return LPred == RPred;
-  }
+  if (L0 == R0 &&
+      (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) &&
+      (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) &&
+      match(L0, m_c_Add(m_Specific(L1), m_Specific(R1))))
+    return LPred == RPred;
 
   if (LPred == RPred)
     return isImpliedCondOperands(LPred, L0, L1, R0, R1, DL, Depth);
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
index 927f0a86b0a252..87c75fb2b55592 100644
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -386,7 +386,7 @@ define i1 @nonnull5(ptr %a) {
 define i32 @assumption_conflicts_with_known_bits(i32 %a, i32 %b) {
 ; CHECK-LABEL: @assumption_conflicts_with_known_bits(
 ; CHECK-NEXT:    store i1 true, ptr poison, align 1
-; CHECK-NEXT:    ret i32 1
+; CHECK-NEXT:    ret i32 poison
 ;
   %and1 = and i32 %b, 3
   %B1 = lshr i32 %and1, %and1
diff --git a/llvm/test/Transforms/InstCombine/icmp-select-implies-common-op.ll b/llvm/test/Transforms/InstCombine/icmp-select-implies-common-op.ll
index 83850164813cb7..bacdb54f787d6a 100644
--- a/llvm/test/Transforms/InstCombine/icmp-select-implies-common-op.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-select-implies-common-op.ll
@@ -3,10 +3,10 @@
 
 define i1 @sgt_3_impliesF_eq_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @sgt_3_impliesF_eq_2(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 3
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 2, i8 [[Y:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[SEL]], [[X]]
-; CHECK-NEXT:    ret i1 [[CMP2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], 4
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[SEL:%.*]], [[X]]
+; CHECK-NEXT:    [[CMP3:%.*]] = select i1 [[CMP]], i1 [[CMP2]], i1 false
+; CHECK-NEXT:    ret i1 [[CMP3]]
 ;
   %cmp = icmp sgt i8 %x, 3
   %sel = select i1 %cmp, i8 2, i8 %y
@@ -16,10 +16,10 @@ define i1 @sgt_3_impliesF_eq_2(i8 %x, i8 %y) {
 
 define i1 @sgt_3_impliesT_sgt_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @sgt_3_impliesT_sgt_2(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 3
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 2, i8 [[Y:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i8 [[SEL]], [[X]]
-; CHECK-NEXT:    ret i1 [[CMP2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], 4
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i8 [[SEL:%.*]], [[X]]
+; CHECK-NEXT:    [[CMP3:%.*]] = select i1 [[CMP]], i1 [[CMP2]], i1 false
+; CHECK-NEXT:    ret i1 [[CMP3]]
 ;
   %cmp = icmp sgt i8 %x, 3
   %sel = select i1 %cmp, i8 2, i8 %y
@@ -68,10 +68,10 @@ define i1 @ult_x_impliesT_eq_umax_todo(i8 %x, i8 %y, i8 %z) {
 
 define i1 @ult_1_impliesF_eq_1(i8 %x, i8 %y) {
 ; CHECK-LABEL: @ult_1_impliesF_eq_1(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[SEL:%.*]], 0
-; CHECK-NEXT:    [[X:%.*]] = select i1 [[CMP]], i8 1, i8 [[Y:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[X]], [[SEL]]
-; CHECK-NEXT:    ret i1 [[CMP2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[SEL:%.*]], 0
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[X:%.*]], [[SEL]]
+; CHECK-NEXT:    [[CMP3:%.*]] = select i1 [[CMP]], i1 [[CMP2]], i1 false
+; CHECK-NEXT:    ret i1 [[CMP3]]
 ;
   %cmp = icmp ult i8 %x, 1
   %sel = select i1 %cmp, i8 1, i8 %y
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index a84904106eced4..d9734242a86891 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -2925,10 +2925,8 @@ define i8 @select_replacement_loop3(i32 noundef %x) {
 
 define i16 @select_replacement_loop4(i16 noundef %p_12) {
 ; CHECK-LABEL: @select_replacement_loop4(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i16 [[P_12:%.*]], 2
-; CHECK-NEXT:    [[AND1:%.*]] = and i16 [[P_12]], 1
-; CHECK-NEXT:    [[AND2:%.*]] = select i1 [[CMP1]], i16 [[AND1]], i16 0
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i16 [[AND2]], [[P_12]]
+; CHECK-NEXT:    [[AND1:%.*]] = and i16 [[P_12:%.*]], 1
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ult i16 [[P_12]], 2
 ; CHECK-NEXT:    [[AND3:%.*]] = select i1 [[CMP2]], i16 [[AND1]], i16 0
 ; CHECK-NEXT:    ret i16 [[AND3]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll
index 62f32c28683711..bef7fc81a7d1f9 100644
--- a/llvm/test/Transforms/InstCombine/shift.ll
+++ b/llvm/test/Transforms/InstCombine/shift.ll
@@ -1751,12 +1751,11 @@ define void @ashr_out_of_range_1(ptr %A) {
 ; CHECK-NEXT:    [[L:%.*]] = load i177, ptr [[A:%.*]], align 4
 ; CHECK-NEXT:    [[L_FROZEN:%.*]] = freeze i177 [[L]]
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i177 [[L_FROZEN]], -1
-; CHECK-NEXT:    [[B:%.*]] = select i1 [[TMP1]], i177 0, i177 [[L_FROZEN]]
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i177 [[B]] to i64
+; CHECK-NEXT:    [[TMP6:%.*]] = trunc i177 [[L_FROZEN]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP1]], i64 0, i64 [[TMP6]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i177, ptr [[A]], i64 [[TMP2]]
 ; CHECK-NEXT:    [[G11:%.*]] = getelementptr i8, ptr [[TMP3]], i64 -24
-; CHECK-NEXT:    [[C17:%.*]] = icmp sgt i177 [[B]], [[L_FROZEN]]
-; CHECK-NEXT:    [[TMP4:%.*]] = sext i1 [[C17]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = sext i1 [[TMP1]] to i64
 ; CHECK-NEXT:    [[G62:%.*]] = getelementptr i177, ptr [[G11]], i64 [[TMP4]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i177 [[L_FROZEN]], -1
 ; CHECK-NEXT:    [[B28:%.*]] = select i1 [[TMP5]], i177 0, i177 [[L_FROZEN]]



More information about the llvm-commits mailing list