[llvm] 21429cf - [InstCombine] generalize fold for (trunc (X u>> C1)) u>> C

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 27 07:58:46 PDT 2021


Author: Sanjay Patel
Date: 2021-09-27T10:57:31-04:00
New Revision: 21429cf43a41d4ff1263dd601d5e5f81a6387cd0

URL: https://github.com/llvm/llvm-project/commit/21429cf43a41d4ff1263dd601d5e5f81a6387cd0
DIFF: https://github.com/llvm/llvm-project/commit/21429cf43a41d4ff1263dd601d5e5f81a6387cd0.diff

LOG: [InstCombine] generalize fold for (trunc (X u>> C1)) u>> C

This is another step towards trying to re-apply D110170
by eliminating conflicting transforms that cause infinite loops.
a47c8e40c734 was a previous patch in this direction.

The diffs here are mostly cosmetic, but intentional:
1. The existing code that would handle this pattern in FoldShiftByConstant()
   is now limited to 'shl' only. Renaming isLeftShift to IsLeftShift (LLVM
   naming style) highlights how many transforms in that function are
   shl-only; those could be moved directly into visitShl() for efficiency,
   because they are not common to all shift opcodes.

2. The tests are regenerated to show new instruction names to prove that
   we are getting (almost) identical logic results.

3. The one case where we differ ("trunc_sandwich_small_shift1") shows that
   we now create a narrow 'and' instruction directly. Previously, we relied
   on another transform to produce the mask (as a wide 'and' before the
   trunc), but that transform is limited to legal types. That limit seems to
   be a legacy constraint from when IR analysis and codegen were less robust.

https://alive2.llvm.org/ce/z/JxyGA4

  ; Intrinsic used below to state the precondition of the source pattern.
  declare void @llvm.assume(i1)

  ; Source pattern: (trunc (X u>> C0)) u>> C1 with variable shift amounts.
  ; X is shifted in i32, truncated to i8, then shifted again in i8.
  define i8 @src(i32 %x, i32 %c0, i8 %c1) {
    ; The sum of the shifts must not overflow the source width.
    ; %c1 is widened to i32 so it can be added to %c0.
    %z1 = zext i8 %c1 to i32
    %sum = add i32 %c0, %z1
    %ov = icmp ult i32 %sum, 32
    call void @llvm.assume(i1 %ov)

    ; The wide shift, the truncation, and the narrow shift being combined.
    %sh1 = lshr i32 %x, %c0
    %tr = trunc i32 %sh1 to i8
    %sh2 = lshr i8 %tr, %c1
    ret i8 %sh2
  }

  ; Target pattern: and (trunc (X u>> (C0 + C1))), (-1 u>> C1)
  ; A single wide shift by the combined amount replaces the two shifts. The
  ; narrow mask clears the top %c1 bits, which the i8 lshr by %c1 in @src
  ; would have zeroed.
  ; No assume is repeated here: inputs that violate @src's assume make @src
  ; undefined, so the refinement check places no constraint on @tgt for them.
  define i8 @tgt(i32 %x, i32 %c0, i8 %c1) {
    ; Combined shift amount, computed in the wide (i32) type.
    %z1 = zext i8 %c1 to i32
    %sum = add i32 %c0, %z1
    ; Low-bit mask: all-ones shifted right by %c1.
    %maskc = lshr i8 -1, %c1

    ; One wide shift, truncate, then mask off the high bits.
    %s = lshr i32 %x, %sum
    %t = trunc i32 %s to i8
    %a = and i8 %t, %maskc
    ret i8 %a
  }

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
    llvm/test/Transforms/InstCombine/lshr.ll
    llvm/test/Transforms/InstCombine/shift.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index bfa37d8e98b7..065a89f8e25b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -661,7 +661,7 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
 
 Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
                                                    BinaryOperator &I) {
-  bool isLeftShift = I.getOpcode() == Instruction::Shl;
+  bool IsLeftShift = I.getOpcode() == Instruction::Shl;
 
   const APInt *Op1C;
   if (!match(Op1, m_APInt(Op1C)))
@@ -670,14 +670,14 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
   // See if we can propagate this shift into the input, this covers the trivial
   // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
   if (I.getOpcode() != Instruction::AShr &&
-      canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
+      canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) {
     LLVM_DEBUG(
         dbgs() << "ICE: GetShiftedValue propagating shift through expression"
                   " to eliminate shift:\n  IN: "
                << *Op0 << "\n  SH: " << I << "\n");
 
     return replaceInstUsesWith(
-        I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
+        I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL));
   }
 
   // See if we can simplify any instructions used by the instruction whose sole
@@ -701,7 +701,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
   // xform in more cases, but it is unlikely to be profitable.
   Instruction *TrOp;
   const APInt *TrShiftAmt;
-  if (I.isLogicalShift() && match(Op0, m_Trunc(m_Instruction(TrOp))) &&
+  if (IsLeftShift && match(Op0, m_Trunc(m_Instruction(TrOp))) &&
       match(TrOp, m_OneUse(m_Shift(m_Value(), m_APInt(TrShiftAmt)))) &&
       TrShiftAmt->ult(TrOp->getType()->getScalarSizeInBits())) {
     Type *SrcTy = TrOp->getType();
@@ -743,7 +743,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
     case Instruction::Xor: {
       // These operators commute.
       // Turn (Y + (X >> C)) << C  ->  (X + (Y << C)) & (~0 << C)
-      if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
+      if (IsLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
           match(Op0BO->getOperand(1), m_Shr(m_Value(V1), m_Specific(Op1)))) {
         Value *YS = // (Y << C)
             Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
@@ -758,7 +758,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
 
       // Turn (Y + ((X >> C) & CC)) << C  ->  ((X & (CC << C)) + (Y << C))
       Value *Op0BOOp1 = Op0BO->getOperand(1);
-      if (isLeftShift && Op0BOOp1->hasOneUse() &&
+      if (IsLeftShift && Op0BOOp1->hasOneUse() &&
           match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
                                 m_APInt(CC)))) {
         Value *YS = // (Y << C)
@@ -774,7 +774,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
 
     case Instruction::Sub: {
       // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
-      if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
+      if (IsLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
           match(Op0BO->getOperand(0), m_Shr(m_Value(V1), m_Specific(Op1)))) {
         Value *YS = // (Y << C)
             Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
@@ -788,7 +788,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
       }
 
       // Turn (((X >> C)&CC) + Y) << C  ->  (X + (Y << C)) & (CC << C)
-      if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
+      if (IsLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
           match(Op0BO->getOperand(0),
                 m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
                       m_APInt(CC)))) {
@@ -824,7 +824,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
     // If the operand is a subtract with a constant LHS, and the shift
     // is the only use, we can pull it out of the shift.
     // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2))
-    if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
+    if (IsLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
         match(Op0BO->getOperand(0), m_APInt(Op0C))) {
       Constant *NewRHS = ConstantExpr::get(
           I.getOpcode(), cast<Constant>(Op0BO->getOperand(0)), Op1);
@@ -1158,15 +1158,26 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
         return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
     }
 
-    // If the first shift covers the number of bits truncated and the combined
-    // shift fits in the source width:
-    // (trunc (X >>u C1)) >>u C --> trunc (X >>u (C1 + C))
-    if (match(Op0, m_OneUse(m_Trunc(m_LShr(m_Value(X), m_APInt(C1)))))) {
+    Instruction *TruncSrc;
+    if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) &&
+        match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) {
       unsigned SrcWidth = X->getType()->getScalarSizeInBits();
       unsigned AmtSum = ShAmtC + C1->getZExtValue();
-      if (C1->uge(SrcWidth - BitWidth) && AmtSum < SrcWidth) {
+
+      // If the combined shift fits in the source width:
+      // (trunc (X >>u C1)) >>u C --> and (trunc (X >>u (C1 + C)), MaskC
+      //
+      // If the first shift covers the number of bits truncated, then the
+      // mask instruction is eliminated (and so the use check is relaxed).
+      if (AmtSum < SrcWidth &&
+          (TruncSrc->hasOneUse() || C1->uge(SrcWidth - BitWidth))) {
         Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift");
-        return new TruncInst(SumShift, Ty);
+        Value *Trunc = Builder.CreateTrunc(SumShift, Ty, I.getName());
+
+        // If the first shift does not cover the number of bits truncated, then
+        // we require a mask to get rid of high bits in the result.
+        APInt MaskC = APInt::getAllOnes(BitWidth).lshr(ShAmtC);
+        return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, MaskC));
       }
     }
 

diff  --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index b8b143814015..2a7d8e18b6cb 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -392,9 +392,9 @@ define i32 @srem2_lshr30(i32 %x) {
 
 define i12 @trunc_sandwich(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich(
-; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 30
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    ret i12 [[R]]
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 30
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    ret i12 [[R1]]
 ;
   %sh = lshr i32 %x, 28
   %tr = trunc i32 %sh to i12
@@ -404,9 +404,9 @@ define i12 @trunc_sandwich(i32 %x) {
 
 define <2 x i12> @trunc_sandwich_splat_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @trunc_sandwich_splat_vec(
-; CHECK-NEXT:    [[SH:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 30, i32 30>
-; CHECK-NEXT:    [[R:%.*]] = trunc <2 x i32> [[SH]] to <2 x i12>
-; CHECK-NEXT:    ret <2 x i12> [[R]]
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 30, i32 30>
+; CHECK-NEXT:    [[R1:%.*]] = trunc <2 x i32> [[SUM_SHIFT]] to <2 x i12>
+; CHECK-NEXT:    ret <2 x i12> [[R1]]
 ;
   %sh = lshr <2 x i32> %x, <i32 22, i32 22>
   %tr = trunc <2 x i32> %sh to <2 x i12>
@@ -416,9 +416,9 @@ define <2 x i12> @trunc_sandwich_splat_vec(<2 x i32> %x) {
 
 define i12 @trunc_sandwich_min_shift1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_min_shift1(
-; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 21
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    ret i12 [[R]]
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 21
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    ret i12 [[R1]]
 ;
   %sh = lshr i32 %x, 20
   %tr = trunc i32 %sh to i12
@@ -428,9 +428,9 @@ define i12 @trunc_sandwich_min_shift1(i32 %x) {
 
 define i12 @trunc_sandwich_small_shift1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_small_shift1(
-; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 20
-; CHECK-NEXT:    [[TR2:%.*]] = and i32 [[SH]], 2047
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[TR2]] to i12
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 20
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    [[R:%.*]] = and i12 [[R1]], 2047
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 19
@@ -441,9 +441,9 @@ define i12 @trunc_sandwich_small_shift1(i32 %x) {
 
 define i12 @trunc_sandwich_max_sum_shift(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_max_sum_shift(
-; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 31
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    ret i12 [[R]]
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 31
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    ret i12 [[R1]]
 ;
   %sh = lshr i32 %x, 20
   %tr = trunc i32 %sh to i12
@@ -453,9 +453,9 @@ define i12 @trunc_sandwich_max_sum_shift(i32 %x) {
 
 define i12 @trunc_sandwich_max_sum_shift2(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_max_sum_shift2(
-; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 31
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    ret i12 [[R]]
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 31
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    ret i12 [[R1]]
 ;
   %sh = lshr i32 %x, 30
   %tr = trunc i32 %sh to i12
@@ -488,8 +488,8 @@ define i12 @trunc_sandwich_use1(i32 %x) {
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 28
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
 ; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 30
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
-; CHECK-NEXT:    ret i12 [[R]]
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    ret i12 [[R1]]
 ;
   %sh = lshr i32 %x, 28
   call void @use(i32 %sh)
@@ -503,8 +503,8 @@ define <3 x i9> @trunc_sandwich_splat_vec_use1(<3 x i14> %x) {
 ; CHECK-NEXT:    [[SH:%.*]] = lshr <3 x i14> [[X:%.*]], <i14 6, i14 6, i14 6>
 ; CHECK-NEXT:    call void @usevec(<3 x i14> [[SH]])
 ; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr <3 x i14> [[X]], <i14 11, i14 11, i14 11>
-; CHECK-NEXT:    [[R:%.*]] = trunc <3 x i14> [[SUM_SHIFT]] to <3 x i9>
-; CHECK-NEXT:    ret <3 x i9> [[R]]
+; CHECK-NEXT:    [[R1:%.*]] = trunc <3 x i14> [[SUM_SHIFT]] to <3 x i9>
+; CHECK-NEXT:    ret <3 x i9> [[R1]]
 ;
   %sh = lshr <3 x i14> %x, <i14 6, i14 6, i14 6>
   call void @usevec(<3 x i14> %sh)
@@ -518,8 +518,8 @@ define i12 @trunc_sandwich_min_shift1_use1(i32 %x) {
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 20
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
 ; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 21
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
-; CHECK-NEXT:    ret i12 [[R]]
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    ret i12 [[R1]]
 ;
   %sh = lshr i32 %x, 20
   call void @use(i32 %sh)
@@ -550,8 +550,8 @@ define i12 @trunc_sandwich_max_sum_shift_use1(i32 %x) {
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 20
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
 ; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
-; CHECK-NEXT:    ret i12 [[R]]
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    ret i12 [[R1]]
 ;
   %sh = lshr i32 %x, 20
   call void @use(i32 %sh)
@@ -565,8 +565,8 @@ define i12 @trunc_sandwich_max_sum_shift2_use1(i32 %x) {
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 30
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
 ; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
-; CHECK-NEXT:    ret i12 [[R]]
+; CHECK-NEXT:    [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+; CHECK-NEXT:    ret i12 [[R1]]
 ;
   %sh = lshr i32 %x, 30
   call void @use(i32 %sh)

diff  --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll
index 5ecc443d9010..f644ed9bb86e 100644
--- a/llvm/test/Transforms/InstCombine/shift.ll
+++ b/llvm/test/Transforms/InstCombine/shift.ll
@@ -444,9 +444,9 @@ bb2:
 define i32 @test29(i64 %d18) {
 ; CHECK-LABEL: @test29(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[I916:%.*]] = lshr i64 [[D18:%.*]], 63
-; CHECK-NEXT:    [[I10:%.*]] = trunc i64 [[I916]] to i32
-; CHECK-NEXT:    ret i32 [[I10]]
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i64 [[D18:%.*]], 63
+; CHECK-NEXT:    [[I101:%.*]] = trunc i64 [[SUM_SHIFT]] to i32
+; CHECK-NEXT:    ret i32 [[I101]]
 ;
 entry:
   %i916 = lshr i64 %d18, 32
@@ -458,9 +458,9 @@ entry:
 define <2 x i32> @test29_uniform(<2 x i64> %d18) {
 ; CHECK-LABEL: @test29_uniform(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[I916:%.*]] = lshr <2 x i64> [[D18:%.*]], <i64 63, i64 63>
-; CHECK-NEXT:    [[I10:%.*]] = trunc <2 x i64> [[I916]] to <2 x i32>
-; CHECK-NEXT:    ret <2 x i32> [[I10]]
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr <2 x i64> [[D18:%.*]], <i64 63, i64 63>
+; CHECK-NEXT:    [[I101:%.*]] = trunc <2 x i64> [[SUM_SHIFT]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[I101]]
 ;
 entry:
   %i916 = lshr <2 x i64> %d18, <i64 32, i64 32>


        


More information about the llvm-commits mailing list