[llvm] a47c8e4 - [InstCombine] fold lshr(trunc(lshr X, C1)) C2

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 24 12:45:25 PDT 2021


Author: Sanjay Patel
Date: 2021-09-24T15:44:07-04:00
New Revision: a47c8e40c734429903d4000285ca45a1c3299321

URL: https://github.com/llvm/llvm-project/commit/a47c8e40c734429903d4000285ca45a1c3299321
DIFF: https://github.com/llvm/llvm-project/commit/a47c8e40c734429903d4000285ca45a1c3299321.diff

LOG: [InstCombine] fold lshr(trunc(lshr X, C1)) C2

Only the multi-use cases are changing here because there's
another fold that catches the simpler patterns.

But that other fold is the source of infinite loops when we
try to add D110170, so removing that other fold is planned as a follow-up.

Attempt to show the general proof in Alive2:
https://alive2.llvm.org/ce/z/Ns1uS2

Note that the overshift (fold-to-zero) tests are not
currently handled by InstSimplify. If they were, we
could assert that the sum of the shift amounts is less
than the source bitwidth.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
    llvm/test/Transforms/InstCombine/lshr.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 92bfae23d231..b0d328f71e32 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1149,14 +1149,26 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       }
     }
 
+    // (X >>u C1) >>u C --> X >>u (C1 + C)
     if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) {
-      unsigned AmtSum = ShAmtC + C1->getZExtValue();
       // Oversized shifts are simplified to zero in InstSimplify.
+      unsigned AmtSum = ShAmtC + C1->getZExtValue();
       if (AmtSum < BitWidth)
-        // (X >>u C1) >>u C --> X >>u (C1 + C)
         return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
     }
 
+    // If the first shift covers the number of bits truncated and the combined
+    // shift fits in the source width:
+    // (trunc (X >>u C1)) >>u C --> trunc (X >>u (C1 + C))
+    if (match(Op0, m_OneUse(m_Trunc(m_LShr(m_Value(X), m_APInt(C1)))))) {
+      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
+      unsigned AmtSum = ShAmtC + C1->getZExtValue();
+      if (C1->uge(SrcWidth - BitWidth) && AmtSum < SrcWidth) {
+        Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift");
+        return new TruncInst(SumShift, Ty);
+      }
+    }
+
     // Look for a "splat" mul pattern - it replicates bits across each half of
     // a value, so a right shift is just a mask of the low bits:
     // lshr i32 (mul nuw X, Pow2+1), 16 --> and X, Pow2-1

diff  --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index 217274996e0d..b8b143814015 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -487,8 +487,8 @@ define i12 @trunc_sandwich_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 28
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    [[R:%.*]] = lshr i12 [[TR]], 2
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 30
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 28
@@ -502,8 +502,8 @@ define <3 x i9> @trunc_sandwich_splat_vec_use1(<3 x i14> %x) {
 ; CHECK-LABEL: @trunc_sandwich_splat_vec_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr <3 x i14> [[X:%.*]], <i14 6, i14 6, i14 6>
 ; CHECK-NEXT:    call void @usevec(<3 x i14> [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc <3 x i14> [[SH]] to <3 x i9>
-; CHECK-NEXT:    [[R:%.*]] = lshr <3 x i9> [[TR]], <i9 5, i9 5, i9 5>
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr <3 x i14> [[X]], <i14 11, i14 11, i14 11>
+; CHECK-NEXT:    [[R:%.*]] = trunc <3 x i14> [[SUM_SHIFT]] to <3 x i9>
 ; CHECK-NEXT:    ret <3 x i9> [[R]]
 ;
   %sh = lshr <3 x i14> %x, <i14 6, i14 6, i14 6>
@@ -517,8 +517,8 @@ define i12 @trunc_sandwich_min_shift1_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_min_shift1_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 20
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    [[R:%.*]] = lshr i12 [[TR]], 1
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 21
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 20
@@ -528,6 +528,8 @@ define i12 @trunc_sandwich_min_shift1_use1(i32 %x) {
   ret i12 %r
 }
 
+; negative test - trunc is bigger than first shift
+
 define i12 @trunc_sandwich_small_shift1_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_small_shift1_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 19
@@ -547,8 +549,8 @@ define i12 @trunc_sandwich_max_sum_shift_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_max_sum_shift_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 20
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    [[R:%.*]] = lshr i12 [[TR]], 11
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 20
@@ -562,8 +564,8 @@ define i12 @trunc_sandwich_max_sum_shift2_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_max_sum_shift2_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 30
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    [[R:%.*]] = lshr i12 [[TR]], 1
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 30
@@ -573,6 +575,8 @@ define i12 @trunc_sandwich_max_sum_shift2_use1(i32 %x) {
   ret i12 %r
 }
 
+; negative test - but overshift is simplified to zero by another fold
+
 define i12 @trunc_sandwich_big_sum_shift1_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_big_sum_shift1_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 21
@@ -586,6 +590,8 @@ define i12 @trunc_sandwich_big_sum_shift1_use1(i32 %x) {
   ret i12 %r
 }
 
+; negative test - but overshift is simplified to zero by another fold
+
 define i12 @trunc_sandwich_big_sum_shift2_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_big_sum_shift2_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 31


        


More information about the llvm-commits mailing list