[llvm] r293489 - [InstCombine] enable lshr(shl X, C1), C2 folds for vectors with splat constants

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 30 08:11:41 PST 2017


Author: spatel
Date: Mon Jan 30 10:11:40 2017
New Revision: 293489

URL: http://llvm.org/viewvc/llvm-project?rev=293489&view=rev
Log:
[InstCombine] enable lshr(shl X, C1), C2 folds for vectors with splat constants

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp
    llvm/trunk/test/Transforms/InstCombine/shift.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp?rev=293489&r1=293488&r2=293489&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp Mon Jan 30 10:11:40 2017
@@ -373,24 +373,6 @@ foldShiftByConstOfShiftByConst(BinaryOpe
   if (ShiftAmt1 < ShiftAmt2) {
     uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1;
 
-    // (X << C1) >>u C2  --> X >>u (C2-C1) & (-1 >> C2)
-    if (I.getOpcode() == Instruction::LShr &&
-        ShiftOp->getOpcode() == Instruction::Shl) {
-      ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
-      // (X <<nuw C1) >>u C2 --> X >>u (C2-C1)
-      if (ShiftOp->hasNoUnsignedWrap()) {
-        BinaryOperator *NewLShr =
-            BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst);
-        NewLShr->setIsExact(I.isExact());
-        return NewLShr;
-      }
-      Value *Shift = Builder->CreateLShr(X, ShiftDiffCst);
-
-      APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2));
-      return BinaryOperator::CreateAnd(Shift,
-                                       ConstantInt::get(I.getContext(), Mask));
-    }
-
     // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However,
     // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
     if (I.getOpcode() == Instruction::AShr &&
@@ -754,10 +736,11 @@ Instruction *InstCombiner::visitLShr(Bin
   if (Instruction *R = commonShiftTransforms(I))
     return R;
 
+  Type *Ty = I.getType();
   const APInt *ShAmtAPInt;
   if (match(Op1, m_APInt(ShAmtAPInt))) {
     unsigned ShAmt = ShAmtAPInt->getZExtValue();
-    unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+    unsigned BitWidth = Ty->getScalarSizeInBits();
     auto *II = dyn_cast<IntrinsicInst>(Op0);
     if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
         (II->getIntrinsicID() == Intrinsic::ctlz ||
@@ -767,16 +750,33 @@ Instruction *InstCombiner::visitLShr(Bin
       // cttz.i32(x)>>5  --> zext(x == 0)
       // ctpop.i32(x)>>5 --> zext(x == -1)
       bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
-      Constant *RHS = ConstantInt::getSigned(Op0->getType(), IsPop ? -1 : 0);
+      Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
       Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
-      return new ZExtInst(Cmp, II->getType());
+      return new ZExtInst(Cmp, Ty);
     }
 
-    // (X << C) >>u C --> X & (-1 >>u C)
     Value *X;
-    if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) {
-      APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
-      return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getType(), Mask));
+    const APInt *ShlAmtAPInt;
+    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShlAmtAPInt)))) {
+      unsigned ShlAmt = ShlAmtAPInt->getZExtValue();
+      if (ShlAmt == ShAmt) {
+        // (X << C) >>u C --> X & (-1 >>u C)
+        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
+      }
+      if (ShlAmt < ShAmt) {
+        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
+        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
+          // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
+          BinaryOperator *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
+          NewLShr->setIsExact(I.isExact());
+          return NewLShr;
+        }
+        // (X << C1) >>u C2  --> (X >>u (C2 - C1)) & (-1 >> C2)
+        Value *NewLShr = Builder->CreateLShr(X, ShiftDiff);
+        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+        return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
+      }
     }
 
     // If the shifted-out value is known-zero, then this is an exact shift.

Modified: llvm/trunk/test/Transforms/InstCombine/shift.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/shift.ll?rev=293489&r1=293488&r2=293489&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/shift.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/shift.ll Mon Jan 30 10:11:40 2017
@@ -924,8 +924,7 @@ define i32 @test51(i32 %x) {
 
 define <2 x i32> @test51_splat_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @test51_splat_vec(
-; CHECK-NEXT:    [[A:%.*]] = shl nuw <2 x i32> %x, <i32 1, i32 1>
-; CHECK-NEXT:    [[B:%.*]] = lshr exact <2 x i32> [[A]], <i32 3, i32 3>
+; CHECK-NEXT:    [[B:%.*]] = lshr exact <2 x i32> %x, <i32 2, i32 2>
 ; CHECK-NEXT:    ret <2 x i32> [[B]]
 ;
   %A = shl nuw <2 x i32> %x, <i32 1, i32 1>
@@ -950,8 +949,8 @@ define i32 @test51_no_nuw(i32 %x) {
 
 define <2 x i32> @test51_no_nuw_splat_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @test51_no_nuw_splat_vec(
-; CHECK-NEXT:    [[A:%.*]] = shl <2 x i32> %x, <i32 1, i32 1>
-; CHECK-NEXT:    [[B:%.*]] = lshr <2 x i32> [[A]], <i32 3, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i32> %x, <i32 2, i32 2>
+; CHECK-NEXT:    [[B:%.*]] = and <2 x i32> [[TMP1]], <i32 536870911, i32 536870911>
 ; CHECK-NEXT:    ret <2 x i32> [[B]]
 ;
   %A = shl <2 x i32> %x, <i32 1, i32 1>




More information about the llvm-commits mailing list