[llvm] 23f0f06 - [InstCombine] Fold icmps comparing uadd_sat with a constant

Dhruv Chawla via llvm-commits llvm-commits at lists.llvm.org
Sat Jul 8 01:04:42 PDT 2023


Author: Dhruv Chawla
Date: 2023-07-08T12:50:39+05:30
New Revision: 23f0f061c399a51b9c846a7aaab1c15ce039e1a3

URL: https://github.com/llvm/llvm-project/commit/23f0f061c399a51b9c846a7aaab1c15ce039e1a3
DIFF: https://github.com/llvm/llvm-project/commit/23f0f061c399a51b9c846a7aaab1c15ce039e1a3.diff

LOG: [InstCombine] Fold icmps comparing uadd_sat with a constant

This patch is a continuation of D154206. It introduces a fold for the
operation "uadd_sat(X, C) pred C2" where "C" and "C2" are constants. The
fold is:

uadd_sat(X, C) pred C2
=> (X >= ~C) || ((X + C) pred C2) -> when (UINT_MAX pred C2) is true
=> (X < ~C)  && ((X + C) pred C2) -> when (UINT_MAX pred C2) is false
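
A worked example, taken from the icmp_ule_basic test updated below:

  %add = call i32 @llvm.uadd.sat.i32(i32 %arg, i32 2)
  %cmp = icmp ult i32 %add, 4

Here (UINT_MAX ult 4) is false, so the fold gives
(%arg < ~2) && ((%arg + 2) ult 4), which the constant-range logic
collapses into a single compare:

  %cmp = icmp ult i32 %arg, 2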

This patch also generalizes the existing usub_sat fold so that it works
with any saturating intrinsic whose saturating value is known.
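
For instance, usub_sat saturates at 0 rather than UINT_MAX, and the same
scheme applies. An illustrative sketch (the constants here are chosen for
the example, not taken from the tests):

  %sub = call i8 @llvm.usub.sat.i8(i8 %x, i8 5)
  %cmp = icmp ult i8 %sub, 3

(0 ult 3) is true, so this becomes (%x < 5) || ((%x - 5) ult 3), which
reduces to:

  %cmp = icmp ult i8 %x, 8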

Proofs: https://alive2.llvm.org/ce/z/wWeirP

Differential Revision: https://reviews.llvm.org/D154565

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/icmp-uadd-sat.ll
    llvm/test/Transforms/InstCombine/icmp-usub-sat.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0a9aaf0dea7376..4c1272b0a8b3b3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3580,48 +3580,73 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
 }
 
 static Instruction *
-foldICmpUSubSatWithConstant(ICmpInst::Predicate Pred, IntrinsicInst *II,
-                            const APInt &C, InstCombiner::BuilderTy &Builder) {
+foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred,
+                                     SaturatingInst *II, const APInt &C,
+                                     InstCombiner::BuilderTy &Builder) {
   // This transform may end up producing more than one instruction for the
   // intrinsic, so limit it to one user of the intrinsic.
   if (!II->hasOneUse())
     return nullptr;
 
-  // Let Y = usub_sat(X, C) pred C2
-  //  => Y = (X < C ? 0 : (X - C)) pred C2
-  //  => Y = (X < C) ? (0 pred C2) : ((X - C) pred C2)
+  // Let Y        = [add/sub]_sat(X, C) pred C2
+  //     SatVal   = The saturating value for the operation
+  //     WillWrap = Whether or not the operation will underflow / overflow
+  // => Y = (WillWrap ? SatVal : (X binop C)) pred C2
+  // => Y = WillWrap ? (SatVal pred C2) : ((X binop C) pred C2)
   //
-  // When (0 pred C2) is true, then
-  //    Y = (X < C) ? true : ((X - C) pred C2)
-  // => Y = (X < C) || ((X - C) pred C2)
+  // When (SatVal pred C2) is true, then
+  //    Y = WillWrap ? true : ((X binop C) pred C2)
+  // => Y = WillWrap || ((X binop C) pred C2)
   // else
-  //    Y = (X < C) ? false : ((X - C) pred C2)
-  // => Y = !(X < C) && ((X - C) pred C2)
-  // => Y = (X >= C) && ((X - C) pred C2)
+  //    Y =  WillWrap ? false : ((X binop C) pred C2)
+  // => Y = !WillWrap ?  ((X binop C) pred C2) : false
+  // => Y = !WillWrap && ((X binop C) pred C2)
   Value *Op0 = II->getOperand(0);
   Value *Op1 = II->getOperand(1);
 
-  // Check (0 pred C2)
-  auto [NewPred, LogicalOp] =
-      ICmpInst::compare(APInt::getZero(C.getBitWidth()), C, Pred)
-          ? std::make_pair(ICmpInst::ICMP_ULT, Instruction::BinaryOps::Or)
-          : std::make_pair(ICmpInst::ICMP_UGE, Instruction::BinaryOps::And);
-
   const APInt *COp1;
-  // This transform only works when the usub_sat has an integral constant or
+  // This transform only works when the intrinsic has an integral constant or
   // splat vector as the second operand.
   if (!match(Op1, m_APInt(COp1)))
     return nullptr;
 
-  ConstantRange C1 = ConstantRange::makeExactICmpRegion(NewPred, *COp1);
-  // Convert '(X - C) pred C2' into 'X pred C2' shifted by C.
+  APInt SatVal;
+  switch (II->getIntrinsicID()) {
+  default:
+    llvm_unreachable(
+        "This function only works with usub_sat and uadd_sat for now!");
+  case Intrinsic::uadd_sat:
+    SatVal = APInt::getAllOnes(C.getBitWidth());
+    break;
+  case Intrinsic::usub_sat:
+    SatVal = APInt::getZero(C.getBitWidth());
+    break;
+  }
+
+  // Check (SatVal pred C2)
+  bool SatValCheck = ICmpInst::compare(SatVal, C, Pred);
+
+  // !WillWrap.
+  ConstantRange C1 = ConstantRange::makeExactNoWrapRegion(
+      II->getBinaryOp(), *COp1, II->getNoWrapKind());
+
+  // WillWrap.
+  if (SatValCheck)
+    C1 = C1.inverse();
+
   ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C);
-  C2 = C2.add(*COp1);
+  if (II->getBinaryOp() == Instruction::Add)
+    C2 = C2.sub(*COp1);
+  else
+    C2 = C2.add(*COp1);
+
+  Instruction::BinaryOps CombiningOp =
+      SatValCheck ? Instruction::BinaryOps::Or : Instruction::BinaryOps::And;
 
   std::optional<ConstantRange> Combination;
-  if (LogicalOp == Instruction::BinaryOps::Or)
+  if (CombiningOp == Instruction::BinaryOps::Or)
     Combination = C1.exactUnionWith(C2);
-  else /* LogicalOp == Instruction::BinaryOps::And */
+  else /* CombiningOp == Instruction::BinaryOps::And */
     Combination = C1.exactIntersectWith(C2);
 
   if (!Combination)
@@ -3649,8 +3674,10 @@ Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
   switch (II->getIntrinsicID()) {
   default:
     break;
+  case Intrinsic::uadd_sat:
   case Intrinsic::usub_sat:
-    if (auto *Folded = foldICmpUSubSatWithConstant(Pred, II, C, Builder))
+    if (auto *Folded = foldICmpUSubSatOrUAddSatWithConstant(
+            Pred, cast<SaturatingInst>(II), C, Builder))
       return Folded;
     break;
   }

diff --git a/llvm/test/Transforms/InstCombine/icmp-uadd-sat.ll b/llvm/test/Transforms/InstCombine/icmp-uadd-sat.ll
index e5ce97a6e34c67..a61feac024a9c3 100644
--- a/llvm/test/Transforms/InstCombine/icmp-uadd-sat.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-uadd-sat.ll
@@ -10,8 +10,7 @@
 define i1 @icmp_eq_basic(i8 %arg) {
 ; CHECK-LABEL: define i1 @icmp_eq_basic
 ; CHECK-SAME: (i8 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i8 @llvm.uadd.sat.i8(i8 [[ARG]], i8 2)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[ADD]], 5
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[ARG]], 3
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i8 @llvm.uadd.sat.i8(i8 %arg, i8 2)
@@ -22,8 +21,7 @@ define i1 @icmp_eq_basic(i8 %arg) {
 define i1 @icmp_ne_basic(i16 %arg) {
 ; CHECK-LABEL: define i1 @icmp_ne_basic
 ; CHECK-SAME: (i16 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i16 @llvm.uadd.sat.i16(i16 [[ARG]], i16 8)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i16 [[ADD]], 9
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i16 [[ARG]], 1
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i16 @llvm.uadd.sat.i16(i16 %arg, i16 8)
@@ -34,8 +32,7 @@ define i1 @icmp_ne_basic(i16 %arg) {
 define i1 @icmp_ule_basic(i32 %arg) {
 ; CHECK-LABEL: define i1 @icmp_ule_basic
 ; CHECK-SAME: (i32 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[ARG]], i32 2)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[ADD]], 4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[ARG]], 2
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i32 @llvm.uadd.sat.i32(i32 %arg, i32 2)
@@ -46,8 +43,7 @@ define i1 @icmp_ule_basic(i32 %arg) {
 define i1 @icmp_ult_basic(i64 %arg) {
 ; CHECK-LABEL: define i1 @icmp_ult_basic
 ; CHECK-SAME: (i64 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i64 @llvm.uadd.sat.i64(i64 [[ARG]], i64 5)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[ADD]], 20
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[ARG]], 15
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i64 @llvm.uadd.sat.i64(i64 %arg, i64 5)
@@ -58,8 +54,7 @@ define i1 @icmp_ult_basic(i64 %arg) {
 define i1 @icmp_uge_basic(i8 %arg) {
 ; CHECK-LABEL: define i1 @icmp_uge_basic
 ; CHECK-SAME: (i8 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i8 @llvm.uadd.sat.i8(i8 [[ARG]], i8 4)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[ADD]], 7
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[ARG]], 3
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i8 @llvm.uadd.sat.i8(i8 %arg, i8 4)
@@ -70,8 +65,7 @@ define i1 @icmp_uge_basic(i8 %arg) {
 define i1 @icmp_ugt_basic(i16 %arg) {
 ; CHECK-LABEL: define i1 @icmp_ugt_basic
 ; CHECK-SAME: (i16 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i16 @llvm.uadd.sat.i16(i16 [[ARG]], i16 1)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i16 [[ADD]], 3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i16 [[ARG]], 2
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i16 @llvm.uadd.sat.i16(i16 %arg, i16 1)
@@ -82,8 +76,7 @@ define i1 @icmp_ugt_basic(i16 %arg) {
 define i1 @icmp_sle_basic(i32 %arg) {
 ; CHECK-LABEL: define i1 @icmp_sle_basic
 ; CHECK-SAME: (i32 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[ARG]], i32 10)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[ADD]], 9
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[ARG]], 2147483637
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i32 @llvm.uadd.sat.i32(i32 %arg, i32 10)
@@ -94,8 +87,7 @@ define i1 @icmp_sle_basic(i32 %arg) {
 define i1 @icmp_slt_basic(i64 %arg) {
 ; CHECK-LABEL: define i1 @icmp_slt_basic
 ; CHECK-SAME: (i64 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i64 @llvm.uadd.sat.i64(i64 [[ARG]], i64 24)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[ADD]], 5
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i64 [[ARG]], 9223372036854775783
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i64 @llvm.uadd.sat.i64(i64 %arg, i64 24)
@@ -106,8 +98,8 @@ define i1 @icmp_slt_basic(i64 %arg) {
 define i1 @icmp_sge_basic(i8 %arg) {
 ; CHECK-LABEL: define i1 @icmp_sge_basic
 ; CHECK-SAME: (i8 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i8 @llvm.uadd.sat.i8(i8 [[ARG]], i8 1)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[ADD]], 3
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[ARG]], -3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[TMP1]], 124
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i8 @llvm.uadd.sat.i8(i8 %arg, i8 1)
@@ -118,8 +110,8 @@ define i1 @icmp_sge_basic(i8 %arg) {
 define i1 @icmp_sgt_basic(i16 %arg) {
 ; CHECK-LABEL: define i1 @icmp_sgt_basic
 ; CHECK-SAME: (i16 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call i16 @llvm.uadd.sat.i16(i16 [[ARG]], i16 2)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i16 [[ADD]], 5
+; CHECK-NEXT:    [[TMP1:%.*]] = add i16 [[ARG]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[TMP1]], 32762
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %add = call i16 @llvm.uadd.sat.i16(i16 %arg, i16 2)
@@ -150,8 +142,7 @@ define i1 @icmp_eq_multiuse(i8 %arg) {
 define <2 x i1> @icmp_eq_vector_equal(<2 x i8> %arg) {
 ; CHECK-LABEL: define <2 x i1> @icmp_eq_vector_equal
 ; CHECK-SAME: (<2 x i8> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> [[ARG]], <2 x i8> <i8 2, i8 2>)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i8> [[ADD]], <i8 5, i8 5>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i8> [[ARG]], <i8 3, i8 3>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %add = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> %arg, <2 x i8> <i8 2, i8 2>)
@@ -174,8 +165,7 @@ define <2 x i1> @icmp_eq_vector_unequal(<2 x i8> %arg) {
 define <2 x i1> @icmp_ne_vector_equal(<2 x i16> %arg) {
 ; CHECK-LABEL: define <2 x i1> @icmp_ne_vector_equal
 ; CHECK-SAME: (<2 x i16> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16> [[ARG]], <2 x i16> <i16 3, i16 3>)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i16> [[ADD]], <i16 5, i16 5>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i16> [[ARG]], <i16 2, i16 2>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %add = call <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16> %arg, <2 x i16> <i16 3, i16 3>)
@@ -198,8 +188,7 @@ define <2 x i1> @icmp_ne_vector_unequal(<2 x i16> %arg) {
 define <2 x i1> @icmp_ule_vector_equal(<2 x i32> %arg) {
 ; CHECK-LABEL: define <2 x i1> @icmp_ule_vector_equal
 ; CHECK-SAME: (<2 x i32> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i32> @llvm.uadd.sat.v2i32(<2 x i32> [[ARG]], <2 x i32> <i32 3, i32 3>)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[ADD]], <i32 5, i32 5>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[ARG]], <i32 2, i32 2>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %add = call <2 x i32> @llvm.uadd.sat.v2i32(<2 x i32> %arg, <2 x i32> <i32 3, i32 3>)
@@ -222,8 +211,7 @@ define <2 x i1> @icmp_ule_vector_unequal(<2 x i32> %arg) {
 define <2 x i1> @icmp_sgt_vector_equal(<2 x i64> %arg) {
 ; CHECK-LABEL: define <2 x i1> @icmp_sgt_vector_equal
 ; CHECK-SAME: (<2 x i64> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[ADD:%.*]] = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> [[ARG]], <2 x i64> <i64 409623, i64 409623>)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <2 x i64> [[ADD]], <i64 1234, i64 1234>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i64> [[ARG]], <i64 9223372036854366185, i64 9223372036854366185>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %add = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %arg, <2 x i64> <i64 409623, i64 409623>)

diff --git a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
index 8c517801e72279..87257e40ac6a2c 100644
--- a/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-usub-sat.ll
@@ -1,7 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
-; Tests for InstCombineCompares.cpp::foldICmpUSubSatWithConstant
+; Tests for InstCombineCompares.cpp::foldICmpUSubSatOrUAddSatWithConstant
+; - usub_sat case
 ; https://github.com/llvm/llvm-project/issues/58342
 
 ; ==============================================================================


        

