[llvm] [LVI] Handle icmp of ashr. (PR #68010)

Amara Emerson via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 20 00:00:31 PDT 2023


https://github.com/aemerson updated https://github.com/llvm/llvm-project/pull/68010

>From 59777eaab063592954d5914191f410cce1291c1f Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Thu, 1 Jun 2023 10:52:27 -0700
Subject: [PATCH 1/2] [LVI] Handle icmp of ashr.

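Teach LazyValueInfo to derive a range for X from a signed comparison of
(ashr X, ShAmtC) against a constant C, mirroring the InstCombine fold

icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1

so that this information is still available when InstCombine has not
performed the fold.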

Proof of the original combine: https://alive2.llvm.org/ce/z/SfpsvX

Differential Revision: https://reviews.llvm.org/D151911
---
 llvm/lib/Analysis/LazyValueInfo.cpp           |  52 +++++++++
 .../CorrelatedValuePropagation/icmp.ll        | 104 ++++++++++++++++++
 2 files changed, 156 insertions(+)

diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 0892aa9d75fb417..3a7b798178db278 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
@@ -1148,6 +1149,57 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
           CR.getUnsignedMin().zext(BitWidth), APInt(BitWidth, 0)));
   }
 
+  // Recognize:
+  // icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1
+  // and friends.
+  // Preconditions: (C != SIGNED_MAX) &&
+  //                ((C+1) << ShAmtC != SIGNED_MIN) &&
+  //                (((C+1) << ShAmtC) >> ShAmtC) == (C+1)
+  const APInt *ShAmtC;
+  if (CmpInst::isSigned(EdgePred) &&
+      match(LHS, m_AShr(m_Specific(Val), m_APInt(ShAmtC))) &&
+      match(RHS, m_APInt(C))) {
+    APInt New = ((*C + 1) << *ShAmtC) - 1;
+    APInt MaxSigned = APInt::getSignedMaxValue(New.getBitWidth());
+    APInt MinSigned = APInt::getSignedMinValue(New.getBitWidth());
+    auto CheckPreConds = [&]() {
+      if (*C == MaxSigned)
+        return false;
+      APInt Shifted = (*C + 1) << *ShAmtC;
+      if (Shifted == MinSigned)
+        return false;
+      if ((Shifted.ashr(*ShAmtC)) != (*C + 1))
+        return false;
+      return true;
+    };
+    if (!CheckPreConds())
+      return ValueLatticeElement::getOverdefined();
+    APInt Lower, Upper;
+    switch (EdgePred) {
+    default:
+      llvm_unreachable("Unknown signed predicate!");
+    case ICmpInst::ICMP_SGT:
+      Lower = New + 1;
+      Upper = MaxSigned;
+      break;
+    case ICmpInst::ICMP_SLE:
+      Lower = MinSigned;
+      Upper = New + 1;
+      break;
+    case ICmpInst::ICMP_SGE:
+      Lower = New;
+      Upper = MaxSigned;
+      break;
+    case ICmpInst::ICMP_SLT:
+      Lower = MinSigned;
+      Upper = New;
+      break;
+    }
+
+    return ValueLatticeElement::getRange(
+        ConstantRange::getNonEmpty(Lower, Upper));
+  }
+
   return ValueLatticeElement::getOverdefined();
 }
 
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll b/llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll
index c4f0ade39942a76..29225595281fb40 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll
@@ -1240,6 +1240,110 @@ define <2 x i1> @non_const_range_minmax_vec(<2 x i8> %a, <2 x i8> %b) {
   ret <2 x i1> %cmp1
 }
 
+define void @ashr_sgt(i8 %x) {
+; CHECK-LABEL: @ashr_sgt(
+; CHECK-NEXT:    [[S:%.*]] = ashr i8 [[X:%.*]], 2
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i8 [[S]], 0
+; CHECK-NEXT:    br i1 [[C]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    call void @check1(i1 true)
+; CHECK-NEXT:    [[C3:%.*]] = icmp ugt i8 [[X]], 4
+; CHECK-NEXT:    call void @check1(i1 [[C3]])
+; CHECK-NEXT:    ret void
+; CHECK:       else:
+; CHECK-NEXT:    ret void
+;
+  %s = ashr i8 %x, 2
+  %c = icmp sgt i8 %s, 0
+  br i1 %c, label %if, label %else
+if:
+  %c2 = icmp sgt i8 %x, 3
+  call void @check1(i1 %c2)
+  %c3 = icmp sgt i8 %x, 4
+  call void @check1(i1 %c3)
+  ret void
+else:
+  ret void
+}
+
+define void @ashr_sge(i8 %x) {
+; CHECK-LABEL: @ashr_sge(
+; CHECK-NEXT:    [[S:%.*]] = ashr i8 [[X:%.*]], 2
+; CHECK-NEXT:    [[C:%.*]] = icmp sge i8 [[S]], 0
+; CHECK-NEXT:    br i1 [[C]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    call void @check1(i1 true)
+; CHECK-NEXT:    [[C3:%.*]] = icmp uge i8 [[X]], 4
+; CHECK-NEXT:    call void @check1(i1 [[C3]])
+; CHECK-NEXT:    ret void
+; CHECK:       else:
+; CHECK-NEXT:    ret void
+;
+  %s = ashr i8 %x, 2
+  %c = icmp sge i8 %s, 0
+  br i1 %c, label %if, label %else
+if:
+  %c2 = icmp sge i8 %x, 3
+  call void @check1(i1 %c2)
+  %c3 = icmp sge i8 %x, 4
+  call void @check1(i1 %c3)
+  ret void
+else:
+  ret void
+}
+
+define void @ashr_slt(i8 %x) {
+; CHECK-LABEL: @ashr_slt(
+; CHECK-NEXT:    [[S:%.*]] = ashr i8 [[X:%.*]], 2
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i8 [[S]], 0
+; CHECK-NEXT:    br i1 [[C]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    call void @check1(i1 true)
+; CHECK-NEXT:    [[C3:%.*]] = icmp slt i8 [[X]], 2
+; CHECK-NEXT:    call void @check1(i1 [[C3]])
+; CHECK-NEXT:    ret void
+; CHECK:       else:
+; CHECK-NEXT:    ret void
+;
+  %s = ashr i8 %x, 2
+  %c = icmp slt i8 %s, 0
+  br i1 %c, label %if, label %else
+if:
+  %c2 = icmp slt i8 %x, 3
+  call void @check1(i1 %c2)
+  %c3 = icmp slt i8 %x, 2
+  call void @check1(i1 %c3)
+  ret void
+else:
+  ret void
+}
+
+define void @ashr_sle(i8 %x) {
+; CHECK-LABEL: @ashr_sle(
+; CHECK-NEXT:    [[S:%.*]] = ashr i8 [[X:%.*]], 2
+; CHECK-NEXT:    [[C:%.*]] = icmp sle i8 [[S]], 0
+; CHECK-NEXT:    br i1 [[C]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    call void @check1(i1 true)
+; CHECK-NEXT:    [[C3:%.*]] = icmp sle i8 [[X]], 2
+; CHECK-NEXT:    call void @check1(i1 [[C3]])
+; CHECK-NEXT:    ret void
+; CHECK:       else:
+; CHECK-NEXT:    ret void
+;
+  %s = ashr i8 %x, 2
+  %c = icmp sle i8 %s, 0
+  br i1 %c, label %if, label %else
+if:
+  %c2 = icmp sle i8 %x, 3
+  call void @check1(i1 %c2)
+  %c3 = icmp sle i8 %x, 2
+  call void @check1(i1 %c3)
+  ret void
+else:
+  ret void
+}
+
 declare i8 @llvm.umin.i8(i8, i8)
 declare i8 @llvm.umax.i8(i8, i8)
 declare <2 x i8> @llvm.umin.v2i8(<2 x i8>, <2 x i8>)

>From 6f204392a996ffbd9bcef82b78fd11880dc14bd0 Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Fri, 20 Oct 2023 00:00:01 -0700
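Subject: [PATCH 2/2] [LVI] Rework ashr icmp handling via getRangeViaSLT

Normalize every signed predicate to SLT with the new getRangeViaSLT()
helper: SGT and SGE are handled as the complement of SLE and SLT, and
"X sle C" is rewritten as "X slt C + 1" (bailing out when C is
SIGNED_MAX). For "icmp slt (ashr X, ShAmtC), C" the range of X is then
[SIGNED_MIN, C << ShAmtC), provided (C << ShAmtC) ashr ShAmtC == C.

This fixes the previous version of the patch, which derived the SGE and
SLT bounds from ((C + 1) << ShAmtC) - 1 instead of C << ShAmtC, and used
SIGNED_MAX as an exclusive upper bound for SGT/SGE. The resulting SGE and
SGT ranges were unsound. For example, under "ashr i8 %x, 2 sge 0" the
old code folded "icmp sge i8 %x, 3" to true, even though %x = 0
satisfies the guard. The SLT range was also wider than necessary. The
ashr_sge and ashr_slt tests are updated accordingly.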

---
 llvm/lib/Analysis/LazyValueInfo.cpp           | 76 ++++++++-----------
 .../CorrelatedValuePropagation/icmp.ll        | 12 +--
 2 files changed, 38 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 3a7b798178db278..d9feeef49336d95 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -1084,6 +1084,26 @@ static ValueLatticeElement getValueFromSimpleICmpCondition(
   return ValueLatticeElement::getRange(TrueValues.subtract(Offset));
 }
 
+static std::optional<ConstantRange> // Signed icmp range, normalized to SLT.
+getRangeViaSLT(CmpInst::Predicate Pred, APInt RHS,
+               function_ref<std::optional<ConstantRange>(const APInt &)> Fn) {
+  bool Invert = false; // Set when Pred is flipped to SLE/SLT below.
+  if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
+    Pred = ICmpInst::getInversePredicate(Pred); // SGT->SLE, SGE->SLT.
+    Invert = true; // Complement Fn's range before returning.
+  }
+  if (Pred == ICmpInst::ICMP_SLE) { // "sle C" is "slt C + 1".
+    Pred = ICmpInst::ICMP_SLT;
+    if (RHS.isMaxSignedValue())
+      return std::nullopt; // C + 1 would wrap; full/empty also possible.
+    ++RHS;
+  }
+  assert(Pred == ICmpInst::ICMP_SLT && "Must be signed predicate");
+  if (auto CR = Fn(RHS)) // Fn: values for which the SLT compare holds.
+    return Invert ? CR->inverse() : CR;
+  return std::nullopt; // Fn could not describe the SLT range.
+}
+
 static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
                                                      bool isTrueDest) {
   Value *LHS = ICI->getOperand(0);
@@ -1150,54 +1170,22 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
   }
 
   // Recognize:
-  // icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1
-  // and friends.
-  // Preconditions: (C != SIGNED_MAX) &&
-  //                ((C+1) << ShAmtC != SIGNED_MIN) &&
-  //                (((C+1) << ShAmtC) >> ShAmtC) == (C+1)
+  // icmp slt (ashr X, ShAmtC), C --> icmp slt X, C << ShAmtC
+  // Preconditions: (C << ShAmtC) >> ShAmtC == C
   const APInt *ShAmtC;
   if (CmpInst::isSigned(EdgePred) &&
       match(LHS, m_AShr(m_Specific(Val), m_APInt(ShAmtC))) &&
       match(RHS, m_APInt(C))) {
-    APInt New = ((*C + 1) << *ShAmtC) - 1;
-    APInt MaxSigned = APInt::getSignedMaxValue(New.getBitWidth());
-    APInt MinSigned = APInt::getSignedMinValue(New.getBitWidth());
-    auto CheckPreConds = [&]() {
-      if (*C == MaxSigned)
-        return false;
-      APInt Shifted = (*C + 1) << *ShAmtC;
-      if (Shifted == MinSigned)
-        return false;
-      if ((Shifted.ashr(*ShAmtC)) != (*C + 1))
-        return false;
-      return true;
-    };
-    if (!CheckPreConds())
-      return ValueLatticeElement::getOverdefined();
-    APInt Lower, Upper;
-    switch (EdgePred) {
-    default:
-      llvm_unreachable("Unknown signed predicate!");
-    case ICmpInst::ICMP_SGT:
-      Lower = New + 1;
-      Upper = MaxSigned;
-      break;
-    case ICmpInst::ICMP_SLE:
-      Lower = MinSigned;
-      Upper = New + 1;
-      break;
-    case ICmpInst::ICMP_SGE:
-      Lower = New;
-      Upper = MaxSigned;
-      break;
-    case ICmpInst::ICMP_SLT:
-      Lower = MinSigned;
-      Upper = New;
-      break;
-    }
-
-    return ValueLatticeElement::getRange(
-        ConstantRange::getNonEmpty(Lower, Upper));
+    auto CR = getRangeViaSLT(
+        EdgePred, *C, [&](const APInt &RHS) -> std::optional<ConstantRange> {
+          APInt New = RHS << *ShAmtC;
+          if ((New.ashr(*ShAmtC)) != RHS)
+            return std::nullopt;
+          return ConstantRange::getNonEmpty(
+              APInt::getSignedMinValue(New.getBitWidth()), New);
+        });
+    if (CR)
+      return ValueLatticeElement::getRange(*CR);
   }
 
   return ValueLatticeElement::getOverdefined();
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll b/llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll
index 29225595281fb40..41a0505bbc09b5a 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll
@@ -1273,7 +1273,7 @@ define void @ashr_sge(i8 %x) {
 ; CHECK-NEXT:    br i1 [[C]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       if:
 ; CHECK-NEXT:    call void @check1(i1 true)
-; CHECK-NEXT:    [[C3:%.*]] = icmp uge i8 [[X]], 4
+; CHECK-NEXT:    [[C3:%.*]] = icmp uge i8 [[X]], 1
 ; CHECK-NEXT:    call void @check1(i1 [[C3]])
 ; CHECK-NEXT:    ret void
 ; CHECK:       else:
@@ -1283,9 +1283,9 @@ define void @ashr_sge(i8 %x) {
   %c = icmp sge i8 %s, 0
   br i1 %c, label %if, label %else
 if:
-  %c2 = icmp sge i8 %x, 3
+  %c2 = icmp sge i8 %x, 0
   call void @check1(i1 %c2)
-  %c3 = icmp sge i8 %x, 4
+  %c3 = icmp sge i8 %x, 1
   call void @check1(i1 %c3)
   ret void
 else:
@@ -1299,7 +1299,7 @@ define void @ashr_slt(i8 %x) {
 ; CHECK-NEXT:    br i1 [[C]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       if:
 ; CHECK-NEXT:    call void @check1(i1 true)
-; CHECK-NEXT:    [[C3:%.*]] = icmp slt i8 [[X]], 2
+; CHECK-NEXT:    [[C3:%.*]] = icmp ult i8 [[X]], -1
 ; CHECK-NEXT:    call void @check1(i1 [[C3]])
 ; CHECK-NEXT:    ret void
 ; CHECK:       else:
@@ -1309,9 +1309,9 @@ define void @ashr_slt(i8 %x) {
   %c = icmp slt i8 %s, 0
   br i1 %c, label %if, label %else
 if:
-  %c2 = icmp slt i8 %x, 3
+  %c2 = icmp slt i8 %x, 0
   call void @check1(i1 %c2)
-  %c3 = icmp slt i8 %x, 2
+  %c3 = icmp slt i8 %x, -1
   call void @check1(i1 %c3)
   ret void
 else:



More information about the llvm-commits mailing list