[llvm] b3fdb7b - [InstCombine] Combine lshr of add -> (a + b < a)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 10 00:37:27 PST 2023


Author: Pierre van Houtryve
Date: 2023-01-10T03:37:23-05:00
New Revision: b3fdb7b0cba49e7f24fd8207c677b0541045755c

URL: https://github.com/llvm/llvm-project/commit/b3fdb7b0cba49e7f24fd8207c677b0541045755c
DIFF: https://github.com/llvm/llvm-project/commit/b3fdb7b0cba49e7f24fd8207c677b0541045755c.diff

LOG: [InstCombine] Combine lshr of add -> (a + b < a)

Tries to perform
  (lshr (add (zext X), (zext Y)), K)
  ->  (icmp ult (add X, Y), X)
  where
    - The add's operands are zexts from K-bit integers to a bigger type.
    - The add is only used by the shr, or by iK (or narrower) truncates.
    - The lshr type has more than 2 bits (other types are boolean math).
    - K > 1
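
For concreteness, a minimal before/after sketch in LLVM IR (hypothetical function name; it mirrors the lshr_16_add_zext_basic test updated below, where later folds canonicalize the compare further into an xor + icmp):

  define i32 @carry_out(i16 %x, i16 %y) {
    %zext.x = zext i16 %x to i32
    %zext.y = zext i16 %y to i32
    %add = add i32 %zext.x, %zext.y
    %carry = lshr i32 %add, 16          ; K = 16: the carry-out of the i16 add
    ret i32 %carry
  }

becomes

  define i32 @carry_out(i16 %x, i16 %y) {
    %add.narrowed = add i16 %x, %y
    %add.narrowed.overflow = icmp ult i16 %add.narrowed, %x
    %carry = zext i1 %add.narrowed.overflow to i32
    ret i32 %carry
  }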

This pattern seems to come only from OpenCL front-ends, so adding DAG/GISel combines doesn't seem worth the complexity.

Original patch D107552 by @abinavpp - adapted to use (a + b < a) instead of uaddo following discussion on the review.
See this issue: https://github.com/RadeonOpenCompute/ROCm/issues/488
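
When the wide add has extra users, they must all be iK (or narrower) truncates; the fold then rewires them to a zext of the narrow add. A hedged sketch of that shape (hypothetical name; the corresponding test below is lshr_32_add_zext_trunc):

  define i32 @sum_with_carry(i32 %a, i32 %b) {
    %zext.a = zext i32 %a to i64
    %zext.b = zext i32 %b to i64
    %add = add i64 %zext.a, %zext.b
    %trunc.add = trunc i64 %add to i32    ; extra use of the add: an i32 trunc
    %shr = lshr i64 %add, 32
    %trunc.shr = trunc i64 %shr to i32
    %ret = add i32 %trunc.add, %trunc.shr
    ret i32 %ret
  }

becomes, once the leftover trunc/zext pairs are cleaned up:

  define i32 @sum_with_carry(i32 %a, i32 %b) {
    %add.narrowed = add i32 %a, %b
    %add.narrowed.overflow = icmp ult i32 %add.narrowed, %a
    %carry = zext i1 %add.narrowed.overflow to i32
    %ret = add i32 %add.narrowed, %carry
    ret i32 %ret
  }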

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D138814

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
    llvm/test/Transforms/InstCombine/lshr.ll
    llvm/test/Transforms/InstCombine/shift-add.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index e99c8a0bfa3b7..a9def58f487df 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -375,6 +375,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                               bool InvertFalseVal = false);
   Value *getSelectCondition(Value *A, Value *B, bool ABIsTheSame);
 
+  Instruction *foldLShrOverflowBit(BinaryOperator &I);
   Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV);
   Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II);
   Instruction *foldFPSignBitOps(BinaryOperator &I);

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 9ac1fd7731bf1..d8c66f9a87fad 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -839,6 +839,74 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
   return nullptr;
 }
 
+// Tries to perform
+//    (lshr (add (zext X), (zext Y)), K)
+//      -> (icmp ult (add X, Y), X)
+//    where
+//      - The add's operands are zexts from K-bit integers to a bigger type.
+//      - The add is only used by the shr, or by iK (or narrower) truncates.
+//      - The lshr type has more than 2 bits (other types are boolean math).
+//      - K > 1
+//    note that
+//      - The resulting add cannot have nuw/nsw, else on overflow we get a
+//        poison value and the transform isn't legal anymore.
+Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
+  assert(I.getOpcode() == Instruction::LShr);
+
+  Value *Add = I.getOperand(0);
+  Value *ShiftAmt = I.getOperand(1);
+  Type *Ty = I.getType();
+
+  if (Ty->getScalarSizeInBits() < 3)
+    return nullptr;
+
+  const APInt *ShAmtAPInt = nullptr;
+  Value *X = nullptr, *Y = nullptr;
+  if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) ||
+      !match(Add,
+             m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y))))))
+    return nullptr;
+
+  const unsigned ShAmt = ShAmtAPInt->getZExtValue();
+  if (ShAmt == 1)
+    return nullptr;
+
+  // X/Y are zexts from `ShAmt`-sized ints.
+  if (X->getType()->getScalarSizeInBits() != ShAmt ||
+      Y->getType()->getScalarSizeInBits() != ShAmt)
+    return nullptr;
+
+  // Make sure that `Add` is only used by `I` and `ShAmt`-truncates.
+  if (!Add->hasOneUse()) {
+    for (User *U : Add->users()) {
+      if (U == &I)
+        continue;
+
+      TruncInst *Trunc = dyn_cast<TruncInst>(U);
+      if (!Trunc || Trunc->getType()->getScalarSizeInBits() > ShAmt)
+        return nullptr;
+    }
+  }
+
+  // Insert at Add so that the newly created `NarrowAdd` will dominate its
+  // users (i.e. `Add`'s users).
+  Instruction *AddInst = cast<Instruction>(Add);
+  Builder.SetInsertPoint(AddInst);
+
+  Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed");
+  Value *Overflow =
+      Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow");
+
+  // Replace the uses of the original add with a zext of the
+  // NarrowAdd's result. Note that all users at this stage are known to
+  // be ShAmt-sized truncs, or the lshr itself.
+  if (!Add->hasOneUse())
+    replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty));
+
+  // Replace the LShr with a zext of the overflow check.
+  return new ZExtInst(Overflow, Ty);
+}
+
 Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
 
@@ -1333,6 +1401,9 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
     return BinaryOperator::CreateAnd(Mask, X);
   }
 
+  if (Instruction *Overflow = foldLShrOverflowBit(I))
+    return Overflow;
+
   return nullptr;
 }
 

diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index 2d58c53e108a0..f3209dbbe4456 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -1043,10 +1043,9 @@ define i2 @bool_add_lshr(i1 %a, i1 %b) {
 
 define i4 @not_bool_add_lshr(i2 %a, i2 %b) {
 ; CHECK-LABEL: @not_bool_add_lshr(
-; CHECK-NEXT:    [[ZEXT_A:%.*]] = zext i2 [[A:%.*]] to i4
-; CHECK-NEXT:    [[ZEXT_B:%.*]] = zext i2 [[B:%.*]] to i4
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i4 [[ZEXT_A]], [[ZEXT_B]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i4 [[ADD]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i2 [[A:%.*]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i2 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i4
 ; CHECK-NEXT:    ret i4 [[LSHR]]
 ;
   %zext.a = zext i2 %a to i4

diff --git a/llvm/test/Transforms/InstCombine/shift-add.ll b/llvm/test/Transforms/InstCombine/shift-add.ll
index 6001eaebf3a99..0301b93063385 100644
--- a/llvm/test/Transforms/InstCombine/shift-add.ll
+++ b/llvm/test/Transforms/InstCombine/shift-add.ll
@@ -462,10 +462,9 @@ define i2 @ashr_2_add_zext_basic(i1 %a, i1 %b) {
 
 define i32 @lshr_16_add_zext_basic(i16 %a, i16 %b) {
 ; CHECK-LABEL: @lshr_16_add_zext_basic(
-; CHECK-NEXT:    [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i32
-; CHECK-NEXT:    [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i32
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[ZEXT_A]], [[ZEXT_B]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[ADD]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i16 [[A:%.*]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i16 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i32
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %zext.a = zext i16 %a to i32
@@ -524,10 +523,9 @@ define i32 @lshr_16_add_not_known_16_leading_zeroes(i32 %a, i32 %b) {
 
 define i64 @lshr_32_add_zext_basic(i32 %a, i32 %b) {
 ; CHECK-LABEL: @lshr_32_add_zext_basic(
-; CHECK-NEXT:    [[ZEXT_A:%.*]] = zext i32 [[A:%.*]] to i64
-; CHECK-NEXT:    [[ZEXT_B:%.*]] = zext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[ZEXT_A]], [[ZEXT_B]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i64 [[ADD]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[A:%.*]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i32 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i64
 ; CHECK-NEXT:    ret i64 [[LSHR]]
 ;
   %zext.a = zext i32 %a to i64
@@ -582,10 +580,9 @@ define i64 @lshr_33_i32_add_zext_basic(i32 %a, i32 %b) {
 
 define i64 @lshr_16_to_64_add_zext_basic(i16 %a, i16 %b) {
 ; CHECK-LABEL: @lshr_16_to_64_add_zext_basic(
-; CHECK-NEXT:    [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i64
-; CHECK-NEXT:    [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[ZEXT_A]], [[ZEXT_B]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i64 [[ADD]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i16 [[A:%.*]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i16 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i64
 ; CHECK-NEXT:    ret i64 [[LSHR]]
 ;
   %zext.a = zext i16 %a to i64
@@ -628,10 +625,9 @@ define i64 @lshr_32_add_not_known_32_leading_zeroes(i64 %a, i64 %b) {
 
 define i32 @ashr_16_add_zext_basic(i16 %a, i16 %b) {
 ; CHECK-LABEL: @ashr_16_add_zext_basic(
-; CHECK-NEXT:    [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i32
-; CHECK-NEXT:    [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i32
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[ZEXT_A]], [[ZEXT_B]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[ADD]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i16 [[A:%.*]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i16 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i32
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %zext.a = zext i16 %a to i32
@@ -643,10 +639,9 @@ define i32 @ashr_16_add_zext_basic(i16 %a, i16 %b) {
 
 define i64 @ashr_32_add_zext_basic(i32 %a, i32 %b) {
 ; CHECK-LABEL: @ashr_32_add_zext_basic(
-; CHECK-NEXT:    [[ZEXT_A:%.*]] = zext i32 [[A:%.*]] to i64
-; CHECK-NEXT:    [[ZEXT_B:%.*]] = zext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[ZEXT_A]], [[ZEXT_B]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i64 [[ADD]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[A:%.*]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i32 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i64
 ; CHECK-NEXT:    ret i64 [[LSHR]]
 ;
   %zext.a = zext i32 %a to i64
@@ -658,10 +653,9 @@ define i64 @ashr_32_add_zext_basic(i32 %a, i32 %b) {
 
 define i64 @ashr_16_to_64_add_zext_basic(i16 %a, i16 %b) {
 ; CHECK-LABEL: @ashr_16_to_64_add_zext_basic(
-; CHECK-NEXT:    [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i64
-; CHECK-NEXT:    [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[ZEXT_A]], [[ZEXT_B]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i64 [[ADD]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i16 [[A:%.*]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i16 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i64
 ; CHECK-NEXT:    ret i64 [[LSHR]]
 ;
   %zext.a = zext i16 %a to i64
@@ -673,13 +667,10 @@ define i64 @ashr_16_to_64_add_zext_basic(i16 %a, i16 %b) {
 
 define i32 @lshr_32_add_zext_trunc(i32 %a, i32 %b) {
 ; CHECK-LABEL: @lshr_32_add_zext_trunc(
-; CHECK-NEXT:    [[ZEXT_A:%.*]] = zext i32 [[A:%.*]] to i64
-; CHECK-NEXT:    [[ZEXT_B:%.*]] = zext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[ZEXT_A]], [[ZEXT_B]]
-; CHECK-NEXT:    [[TRUNC_ADD:%.*]] = trunc i64 [[ADD]] to i32
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i64 [[ADD]], 32
-; CHECK-NEXT:    [[TRUNC_SHR:%.*]] = trunc i64 [[SHR]] to i32
-; CHECK-NEXT:    [[RET:%.*]] = add i32 [[TRUNC_ADD]], [[TRUNC_SHR]]
+; CHECK-NEXT:    [[ADD_NARROWED:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i32 [[ADD_NARROWED]], [[A]]
+; CHECK-NEXT:    [[TRUNC_SHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i32
+; CHECK-NEXT:    [[RET:%.*]] = add i32 [[ADD_NARROWED]], [[TRUNC_SHR]]
 ; CHECK-NEXT:    ret i32 [[RET]]
 ;
   %zext.a = zext i32 %a to i64
@@ -695,29 +686,27 @@ define i32 @lshr_32_add_zext_trunc(i32 %a, i32 %b) {
 define <3 x i32> @add3_i96(<3 x i32> %0, <3 x i32> %1) {
 ; CHECK-LABEL: @add3_i96(
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <3 x i32> [[TMP0:%.*]], i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
-; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <3 x i32> [[TMP1:%.*]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <3 x i32> [[TMP1:%.*]], i64 0
+; CHECK-NEXT:    [[ADD_NARROWED:%.*]] = add i32 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i32 [[ADD_NARROWED]], [[TMP4]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <3 x i32> [[TMP0]], i64 1
 ; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
-; CHECK-NEXT:    [[TMP7:%.*]] = add nuw nsw i64 [[TMP6]], [[TMP4]]
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <3 x i32> [[TMP0]], i64 1
-; CHECK-NEXT:    [[TMP9:%.*]] = zext i32 [[TMP8]] to i64
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <3 x i32> [[TMP1]], i64 1
-; CHECK-NEXT:    [[TMP11:%.*]] = zext i32 [[TMP10]] to i64
-; CHECK-NEXT:    [[TMP12:%.*]] = add nuw nsw i64 [[TMP11]], [[TMP9]]
-; CHECK-NEXT:    [[TMP13:%.*]] = lshr i64 [[TMP7]], 32
-; CHECK-NEXT:    [[TMP14:%.*]] = add nuw nsw i64 [[TMP12]], [[TMP13]]
-; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <3 x i32> [[TMP0]], i64 2
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <3 x i32> [[TMP1]], i64 2
-; CHECK-NEXT:    [[TMP17:%.*]] = add i32 [[TMP16]], [[TMP15]]
-; CHECK-NEXT:    [[TMP18:%.*]] = lshr i64 [[TMP14]], 32
-; CHECK-NEXT:    [[TMP19:%.*]] = trunc i64 [[TMP18]] to i32
-; CHECK-NEXT:    [[TMP20:%.*]] = add i32 [[TMP17]], [[TMP19]]
-; CHECK-NEXT:    [[TMP21:%.*]] = trunc i64 [[TMP7]] to i32
-; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <3 x i32> undef, i32 [[TMP21]], i64 0
-; CHECK-NEXT:    [[TMP23:%.*]] = trunc i64 [[TMP14]] to i32
-; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <3 x i32> [[TMP22]], i32 [[TMP23]], i64 1
-; CHECK-NEXT:    [[TMP25:%.*]] = insertelement <3 x i32> [[TMP24]], i32 [[TMP20]], i64 2
-; CHECK-NEXT:    ret <3 x i32> [[TMP25]]
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <3 x i32> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = zext i32 [[TMP7]] to i64
+; CHECK-NEXT:    [[TMP9:%.*]] = add nuw nsw i64 [[TMP8]], [[TMP6]]
+; CHECK-NEXT:    [[TMP10:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = add nuw nsw i64 [[TMP9]], [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <3 x i32> [[TMP0]], i64 2
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <3 x i32> [[TMP1]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = add i32 [[TMP13]], [[TMP12]]
+; CHECK-NEXT:    [[TMP15:%.*]] = lshr i64 [[TMP11]], 32
+; CHECK-NEXT:    [[TMP16:%.*]] = trunc i64 [[TMP15]] to i32
+; CHECK-NEXT:    [[TMP17:%.*]] = add i32 [[TMP14]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <3 x i32> undef, i32 [[ADD_NARROWED]], i64 0
+; CHECK-NEXT:    [[TMP19:%.*]] = trunc i64 [[TMP11]] to i32
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <3 x i32> [[TMP18]], i32 [[TMP19]], i64 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <3 x i32> [[TMP20]], i32 [[TMP17]], i64 2
+; CHECK-NEXT:    ret <3 x i32> [[TMP21]]
 ;
   %3 = extractelement <3 x i32> %0, i64 0
   %4 = zext i32 %3 to i64


        

