[llvm] d50366d - [InstCombine] improve matching for sext-lshr-trunc patterns, part 2

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 8 11:43:41 PDT 2020


Author: Sanjay Patel
Date: 2020-06-08T14:41:50-04:00
New Revision: d50366d29f25fb2297b7561092132ecf74a391e9

URL: https://github.com/llvm/llvm-project/commit/d50366d29f25fb2297b7561092132ecf74a391e9
DIFF: https://github.com/llvm/llvm-project/commit/d50366d29f25fb2297b7561092132ecf74a391e9.diff

LOG: [InstCombine] improve matching for sext-lshr-trunc patterns, part 2

Similar to rG42f488b63a04

This is intended to preserve the logic of the existing transform,
but remove unnecessary restrictions on uses and types.

https://rise4fun.com/Alive/oS0

  Name: narrow input
  Pre: C1 <= width(C1) - 24
  %B = sext i8 %A
  %C = lshr %B, C1
  %r = trunc %C to i24
  =>
  %s = ashr i8 %A, trunc(umin(C1, 7))
  %r = sext i8 %s to i24

  Name: wide input
  Pre: C1 <= width(C1) - 24
  %B = sext i24 %A
  %C = lshr %B, C1
  %r = trunc %C to i8
  =>
  %s = ashr i24 %A, trunc(umin(C1, 23))
  %r = trunc i24 %s to i8

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/test/Transforms/InstCombine/cast.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 7256c88f5dc3..3750f31e3cff 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -785,44 +785,27 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
   }
 
   const APInt *C;
-  if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C))) &&
-      A->getType() == DestTy) {
-    // If the shift is small enough, all zero bits created by the shift are
-    // removed by the trunc:
-    // trunc (lshr (sext A), C) --> ashr A, C
-    if (C->getZExtValue() <= SrcWidth - DestWidth) {
-      unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1);
-      return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt));
-    }
-    // TODO: Mask high bits with 'and'.
-  }
+  if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C)))) {
+    unsigned AWidth = A->getType()->getScalarSizeInBits();
+    unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth);
 
-  // More complicated: deal with mismatched sizes.
-  // FIXME: This is too restrictive for uses and doesn't work with vectors.
-  // Transform trunc(lshr (sext A), Cst) to ashr A, Cst to eliminate type
-  // conversion.
-  // It works because bits coming from sign extension have the same value as
-  // the sign bit of the original value; performing ashr instead of lshr
-  // generates bits of the same value as the sign bit.
-  if (Src->hasOneUse() &&
-      match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) {
-    Value *SExt = cast<Instruction>(Src)->getOperand(0);
-    unsigned ASize = A->getType()->getPrimitiveSizeInBits();
-    unsigned MaxAmt = SrcWidth - std::max(DestWidth, ASize);
-    unsigned ShiftAmt = Cst->getZExtValue();
-
-    // This optimization can be only performed when zero bits generated by
-    // the original lshr aren't pulled into the value after truncation, so we
-    // can only shift by values no larger than the number of extension bits.
-    // FIXME: Instead of bailing when the shift is too large, use and to clear
-    // the extra bits.
-    if (ShiftAmt <= MaxAmt) {
-      if (SExt->hasOneUse()) {
-        Value *Shift = Builder.CreateAShr(A, std::min(ShiftAmt, ASize - 1));
-        Shift->takeName(Src);
+    // If the shift is small enough, all zero bits created by the shift are
+    // removed by the trunc.
+    if (C->getZExtValue() <= MaxShiftAmt) {
+      // trunc (lshr (sext A), C) --> ashr A, C
+      if (A->getType() == DestTy) {
+        unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1);
+        return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt));
+      }
+      // The types are mismatched, so create a cast after shifting:
+      // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
+      if (Src->hasOneUse()) {
+        unsigned ShAmt = std::min((unsigned)C->getZExtValue(), AWidth - 1);
+        Value *Shift = Builder.CreateAShr(A, ShAmt);
         return CastInst::CreateIntegerCast(Shift, DestTy, true);
       }
     }
+    // TODO: Mask high bits with 'and'.
   }
 
   if (Instruction *I = narrowBinOp(Trunc))

diff  --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
index 89a294c142df..a68d81acdde9 100644
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -1423,8 +1423,8 @@ define i1 @PR23309v2(i32 %A, i32 %B) {
 
 define i16 @PR24763(i8 %V) {
 ; ALL-LABEL: @PR24763(
-; ALL-NEXT:    [[L:%.*]] = ashr i8 [[V:%.*]], 1
-; ALL-NEXT:    [[T:%.*]] = sext i8 [[L]] to i16
+; ALL-NEXT:    [[TMP1:%.*]] = ashr i8 [[V:%.*]], 1
+; ALL-NEXT:    [[T:%.*]] = sext i8 [[TMP1]] to i16
 ; ALL-NEXT:    ret i16 [[T]]
 ;
   %conv = sext i8 %V to i32
@@ -1619,8 +1619,8 @@ define i8 @trunc_lshr_overshift_sext_uses3(i8 %A) {
 
 define i8 @trunc_lshr_sext_wide_input(i16 %A) {
 ; ALL-LABEL: @trunc_lshr_sext_wide_input(
-; ALL-NEXT:    [[C:%.*]] = ashr i16 [[A:%.*]], 9
-; ALL-NEXT:    [[D:%.*]] = trunc i16 [[C]] to i8
+; ALL-NEXT:    [[TMP1:%.*]] = ashr i16 [[A:%.*]], 9
+; ALL-NEXT:    [[D:%.*]] = trunc i16 [[TMP1]] to i8
 ; ALL-NEXT:    ret i8 [[D]]
 ;
   %B = sext i16 %A to i32
@@ -1633,8 +1633,8 @@ define <2 x i8> @trunc_lshr_sext_wide_input_uses1(<2 x i16> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_wide_input_uses1(
 ; ALL-NEXT:    [[B:%.*]] = sext <2 x i16> [[A:%.*]] to <2 x i32>
 ; ALL-NEXT:    call void @use_v2i32(<2 x i32> [[B]])
-; ALL-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]], <i32 9, i32 9>
-; ALL-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; ALL-NEXT:    [[TMP1:%.*]] = ashr <2 x i16> [[A]], <i16 9, i16 9>
+; ALL-NEXT:    [[D:%.*]] = trunc <2 x i16> [[TMP1]] to <2 x i8>
 ; ALL-NEXT:    ret <2 x i8> [[D]]
 ;
   %B = sext <2 x i16> %A to <2 x i32>
@@ -1692,8 +1692,8 @@ define i8 @trunc_lshr_overshift_sext_wide_input_uses1(i16 %A) {
 ; ALL-LABEL: @trunc_lshr_overshift_sext_wide_input_uses1(
 ; ALL-NEXT:    [[B:%.*]] = sext i16 [[A:%.*]] to i32
 ; ALL-NEXT:    call void @use_i32(i32 [[B]])
-; ALL-NEXT:    [[C:%.*]] = lshr i32 [[B]], 16
-; ALL-NEXT:    [[D:%.*]] = trunc i32 [[C]] to i8
+; ALL-NEXT:    [[TMP1:%.*]] = ashr i16 [[A]], 15
+; ALL-NEXT:    [[D:%.*]] = trunc i16 [[TMP1]] to i8
 ; ALL-NEXT:    ret i8 [[D]]
 ;
   %B = sext i16 %A to i32
@@ -1737,8 +1737,8 @@ define i8 @trunc_lshr_overshift_sext_wide_input_uses3(i16 %A) {
 
 define i16 @trunc_lshr_sext_narrow_input(i8 %A) {
 ; ALL-LABEL: @trunc_lshr_sext_narrow_input(
-; ALL-NEXT:    [[C:%.*]] = ashr i8 [[A:%.*]], 6
-; ALL-NEXT:    [[D:%.*]] = sext i8 [[C]] to i16
+; ALL-NEXT:    [[TMP1:%.*]] = ashr i8 [[A:%.*]], 6
+; ALL-NEXT:    [[D:%.*]] = sext i8 [[TMP1]] to i16
 ; ALL-NEXT:    ret i16 [[D]]
 ;
   %B = sext i8 %A to i32
@@ -1751,8 +1751,8 @@ define <2 x i16> @trunc_lshr_sext_narrow_input_uses1(<2 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_narrow_input_uses1(
 ; ALL-NEXT:    [[B:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32>
 ; ALL-NEXT:    call void @use_v2i32(<2 x i32> [[B]])
-; ALL-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]], <i32 6, i32 6>
-; ALL-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i16>
+; ALL-NEXT:    [[TMP1:%.*]] = ashr <2 x i8> [[A]], <i8 6, i8 6>
+; ALL-NEXT:    [[D:%.*]] = sext <2 x i8> [[TMP1]] to <2 x i16>
 ; ALL-NEXT:    ret <2 x i16> [[D]]
 ;
   %B = sext <2 x i8> %A to <2 x i32>
@@ -1796,9 +1796,8 @@ define <2 x i16> @trunc_lshr_sext_narrow_input_uses3(<2 x i8> %A) {
 
 define <2 x i16> @trunc_lshr_overshift_narrow_input_sext(<2 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_overshift_narrow_input_sext(
-; ALL-NEXT:    [[B:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32>
-; ALL-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]], <i32 8, i32 8>
-; ALL-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i16>
+; ALL-NEXT:    [[TMP1:%.*]] = ashr <2 x i8> [[A:%.*]], <i8 7, i8 7>
+; ALL-NEXT:    [[D:%.*]] = sext <2 x i8> [[TMP1]] to <2 x i16>
 ; ALL-NEXT:    ret <2 x i16> [[D]]
 ;
   %B = sext <2 x i8> %A to <2 x i32>
@@ -1811,8 +1810,8 @@ define i16 @trunc_lshr_overshift_sext_narrow_input_uses1(i8 %A) {
 ; ALL-LABEL: @trunc_lshr_overshift_sext_narrow_input_uses1(
 ; ALL-NEXT:    [[B:%.*]] = sext i8 [[A:%.*]] to i32
 ; ALL-NEXT:    call void @use_i32(i32 [[B]])
-; ALL-NEXT:    [[C:%.*]] = lshr i32 [[B]], 8
-; ALL-NEXT:    [[D:%.*]] = trunc i32 [[C]] to i16
+; ALL-NEXT:    [[TMP1:%.*]] = ashr i8 [[A]], 7
+; ALL-NEXT:    [[D:%.*]] = sext i8 [[TMP1]] to i16
 ; ALL-NEXT:    ret i16 [[D]]
 ;
   %B = sext i8 %A to i32
@@ -1930,8 +1929,8 @@ define i8 @pr33078_1(i8 %A) {
 
 define i12 @pr33078_2(i8 %A) {
 ; ALL-LABEL: @pr33078_2(
-; ALL-NEXT:    [[C:%.*]] = ashr i8 [[A:%.*]], 4
-; ALL-NEXT:    [[D:%.*]] = sext i8 [[C]] to i12
+; ALL-NEXT:    [[TMP1:%.*]] = ashr i8 [[A:%.*]], 4
+; ALL-NEXT:    [[D:%.*]] = sext i8 [[TMP1]] to i12
 ; ALL-NEXT:    ret i12 [[D]]
 ;
   %B = sext i8 %A to i16


        


More information about the llvm-commits mailing list