[llvm] 9075edc - [InstCombine] move shl-only folds out from under commonShiftTransforms(); NFCI

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 27 09:24:33 PDT 2021


Author: Sanjay Patel
Date: 2021-09-27T12:09:47-04:00
New Revision: 9075edc89bc9b962ef0d16baf57b57b4eb83cf0f

URL: https://github.com/llvm/llvm-project/commit/9075edc89bc9b962ef0d16baf57b57b4eb83cf0f
DIFF: https://github.com/llvm/llvm-project/commit/9075edc89bc9b962ef0d16baf57b57b4eb83cf0f.diff

LOG: [InstCombine] move shl-only folds out from under commonShiftTransforms(); NFCI

This is no-functional-change-intended, but it hopefully makes things
slightly clearer and more efficient to have transforms that require
'shl' be called only from visitShl(). Further cleanup is possible.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 065a89f8e25b..87750f67d7a3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -661,14 +661,13 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
 
 Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
                                                    BinaryOperator &I) {
-  bool IsLeftShift = I.getOpcode() == Instruction::Shl;
-
   const APInt *Op1C;
   if (!match(Op1, m_APInt(Op1C)))
     return nullptr;
 
   // See if we can propagate this shift into the input, this covers the trivial
   // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
+  bool IsLeftShift = I.getOpcode() == Instruction::Shl;
   if (I.getOpcode() != Instruction::AShr &&
       canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) {
     LLVM_DEBUG(
@@ -693,118 +692,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
   if (!Op0->hasOneUse())
     return nullptr;
 
-  // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
-  // If 'shift2' is an ashr, we would have to get the sign bit into a funny
-  // place.  Don't try to do this transformation in this case.  Also, we
-  // require that the input operand is a non-poison shift-by-constant so that we
-  // have confidence that the shifts will get folded together.  We could do this
-  // xform in more cases, but it is unlikely to be profitable.
-  Instruction *TrOp;
-  const APInt *TrShiftAmt;
-  if (IsLeftShift && match(Op0, m_Trunc(m_Instruction(TrOp))) &&
-      match(TrOp, m_OneUse(m_Shift(m_Value(), m_APInt(TrShiftAmt)))) &&
-      TrShiftAmt->ult(TrOp->getType()->getScalarSizeInBits())) {
-    Type *SrcTy = TrOp->getType();
-
-    // Okay, we'll do this xform.  Make the shift of shift.
-    Constant *ShAmt = ConstantExpr::getZExt(Op1, SrcTy);
-    // (shift2 (shift1 & 0x00FF), c2)
-    Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());
-
-    // For logical shifts, the truncation has the effect of making the high
-    // part of the register be zeros.  Emulate this by inserting an AND to
-    // clear the top bits as needed.  This 'and' will usually be zapped by
-    // other xforms later if dead.
-    unsigned SrcSize = SrcTy->getScalarSizeInBits();
-    Constant *MaskV =
-        ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, TypeBits));
-
-    // The mask we constructed says what the trunc would do if occurring
-    // between the shifts.  We want to know the effect *after* the second
-    // shift.  We know that it is a logical shift by a constant, so adjust the
-    // mask as appropriate.
-    MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt);
-    // shift1 & 0x00FF
-    Value *And = Builder.CreateAnd(NSh, MaskV, Op0->getName());
-    // Return the value truncated to the interesting size.
-    return new TruncInst(And, Ty);
-  }
-
   if (auto *Op0BO = dyn_cast<BinaryOperator>(Op0)) {
-    // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
-    Value *V1;
-    const APInt *CC;
-    switch (Op0BO->getOpcode()) {
-    default:
-      break;
-    case Instruction::Add:
-    case Instruction::And:
-    case Instruction::Or:
-    case Instruction::Xor: {
-      // These operators commute.
-      // Turn (Y + (X >> C)) << C  ->  (X + (Y << C)) & (~0 << C)
-      if (IsLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
-          match(Op0BO->getOperand(1), m_Shr(m_Value(V1), m_Specific(Op1)))) {
-        Value *YS = // (Y << C)
-            Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
-        // (X + (Y << C))
-        Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
-                                       Op0BO->getOperand(1)->getName());
-        unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
-        APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
-        Constant *Mask = ConstantInt::get(Ty, Bits);
-        return BinaryOperator::CreateAnd(X, Mask);
-      }
-
-      // Turn (Y + ((X >> C) & CC)) << C  ->  ((X & (CC << C)) + (Y << C))
-      Value *Op0BOOp1 = Op0BO->getOperand(1);
-      if (IsLeftShift && Op0BOOp1->hasOneUse() &&
-          match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
-                                m_APInt(CC)))) {
-        Value *YS = // (Y << C)
-            Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
-        // X & (CC << C)
-        Value *XM = Builder.CreateAnd(
-            V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1),
-            V1->getName() + ".mask");
-        return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
-      }
-      LLVM_FALLTHROUGH;
-    }
-
-    case Instruction::Sub: {
-      // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
-      if (IsLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
-          match(Op0BO->getOperand(0), m_Shr(m_Value(V1), m_Specific(Op1)))) {
-        Value *YS = // (Y << C)
-            Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
-        // (X + (Y << C))
-        Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
-                                       Op0BO->getOperand(0)->getName());
-        unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
-        APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
-        Constant *Mask = ConstantInt::get(Ty, Bits);
-        return BinaryOperator::CreateAnd(X, Mask);
-      }
-
-      // Turn (((X >> C)&CC) + Y) << C  ->  (X + (Y << C)) & (CC << C)
-      if (IsLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
-          match(Op0BO->getOperand(0),
-                m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
-                      m_APInt(CC)))) {
-        Value *YS = // (Y << C)
-            Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
-        // X & (CC << C)
-        Value *XM = Builder.CreateAnd(
-            V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1),
-            V1->getName() + ".mask");
-        return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
-      }
-
-      break;
-    }
-    }
-
     // If the operand is a bitwise operator with a constant RHS, and the
     // shift is the only use, we can pull it out of the shift.
     const APInt *Op0C;
@@ -820,20 +708,6 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
         return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, NewRHS);
       }
     }
-
-    // If the operand is a subtract with a constant LHS, and the shift
-    // is the only use, we can pull it out of the shift.
-    // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2))
-    if (IsLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
-        match(Op0BO->getOperand(0), m_APInt(Op0C))) {
-      Constant *NewRHS = ConstantExpr::get(
-          I.getOpcode(), cast<Constant>(Op0BO->getOperand(0)), Op1);
-
-      Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1);
-      NewShift->takeName(Op0BO);
-
-      return BinaryOperator::CreateSub(NewRHS, NewShift);
-    }
   }
 
   // If we have a select that conditionally executes some binary operator,
@@ -978,6 +852,129 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
         return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
     }
 
+    // Fold shl(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
+    // If 'shift2' is an ashr, we would have to get the sign bit into a funny
+    // place.  Don't try to do this transformation in this case.  Also, we
+    // require that the input operand is a non-poison shift-by-constant so that
+    // we have confidence that the shifts will get folded together.
+    Instruction *TrOp;
+    const APInt *TrShiftAmt;
+    if (match(Op0, m_Trunc(m_Instruction(TrOp))) &&
+        match(TrOp, m_OneUse(m_Shift(m_Value(), m_APInt(TrShiftAmt)))) &&
+        TrShiftAmt->ult(TrOp->getType()->getScalarSizeInBits())) {
+      Type *SrcTy = TrOp->getType();
+
+      // Okay, we'll do this xform.  Make the shift of shift.
+      unsigned SrcSize = SrcTy->getScalarSizeInBits();
+      Constant *ShAmt = ConstantInt::get(SrcTy, C->zext(SrcSize));
+
+      // (shift2 (shift1 & 0x00FF), c2)
+      Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());
+
+      // For logical shifts, the truncation has the effect of making the high
+      // part of the register be zeros.  Emulate this by inserting an AND to
+      // clear the top bits as needed.  This 'and' will usually be zapped by
+      // other xforms later if dead.
+      Constant *MaskV =
+          ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, BitWidth));
+
+      // The mask we constructed says what the trunc would do if occurring
+      // between the shifts.  We want to know the effect *after* the second
+      // shift.  We know that it is a logical shift by a constant, so adjust the
+      // mask as appropriate.
+      MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt);
+      // shift1 & 0x00FF
+      Value *And = Builder.CreateAnd(NSh, MaskV, Op0->getName());
+      // Return the value truncated to the interesting size.
+      return new TruncInst(And, Ty);
+    }
+
+    BinaryOperator *Op0BO;
+    if (match(Op0, m_OneUse(m_BinOp(Op0BO)))) {
+      // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
+      Value *V1;
+      const APInt *CC;
+      switch (Op0BO->getOpcode()) {
+      default:
+        break;
+      case Instruction::Add:
+      case Instruction::And:
+      case Instruction::Or:
+      case Instruction::Xor: {
+        // These operators commute.
+        // Turn (Y + (X >> C)) << C  ->  (X + (Y << C)) & (~0 << C)
+        if (Op0BO->getOperand(1)->hasOneUse() &&
+            match(Op0BO->getOperand(1), m_Shr(m_Value(V1), m_Specific(Op1)))) {
+          Value *YS = // (Y << C)
+              Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
+          // (X + (Y << C))
+          Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
+                                         Op0BO->getOperand(1)->getName());
+          unsigned Op1Val = C->getLimitedValue(BitWidth);
+          APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val);
+          Constant *Mask = ConstantInt::get(Ty, Bits);
+          return BinaryOperator::CreateAnd(X, Mask);
+        }
+
+        // Turn (Y + ((X >> C) & CC)) << C  ->  ((X & (CC << C)) + (Y << C))
+        Value *Op0BOOp1 = Op0BO->getOperand(1);
+        if (Op0BOOp1->hasOneUse() &&
+            match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
+                                  m_APInt(CC)))) {
+          Value *YS = // (Y << C)
+              Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
+          // X & (CC << C)
+          Value *XM = Builder.CreateAnd(V1, ConstantInt::get(Ty, CC->shl(*C)),
+                                        V1->getName() + ".mask");
+          return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
+        }
+        LLVM_FALLTHROUGH;
+      }
+
+      case Instruction::Sub: {
+        // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
+        if (Op0BO->getOperand(0)->hasOneUse() &&
+            match(Op0BO->getOperand(0), m_Shr(m_Value(V1), m_Specific(Op1)))) {
+          Value *YS = // (Y << C)
+              Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
+          // (X + (Y << C))
+          Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
+                                         Op0BO->getOperand(0)->getName());
+          unsigned Op1Val = C->getLimitedValue(BitWidth);
+          APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val);
+          Constant *Mask = ConstantInt::get(Ty, Bits);
+          return BinaryOperator::CreateAnd(X, Mask);
+        }
+
+        // Turn (((X >> C)&CC) + Y) << C  ->  (X + (Y << C)) & (CC << C)
+        if (Op0BO->getOperand(0)->hasOneUse() &&
+            match(Op0BO->getOperand(0),
+                  m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
+                        m_APInt(CC)))) {
+          Value *YS = // (Y << C)
+              Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
+          // X & (CC << C)
+          Value *XM = Builder.CreateAnd(V1, ConstantInt::get(Ty, CC->shl(*C)),
+                                        V1->getName() + ".mask");
+          return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
+        }
+
+        break;
+      }
+      }
+
+      // If the operand is a subtract with a constant LHS, and the shift
+      // is the only use, we can pull it out of the shift.
+      // This folds (shl (sub C1, X), C) -> (sub (C1 << C), (shl X, C))
+      if (Op0BO->getOpcode() == Instruction::Sub &&
+          match(Op0BO->getOperand(0), m_APInt(C1))) {
+        Constant *NewLHS = ConstantInt::get(Ty, C1->shl(*C));
+        Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1);
+        NewShift->takeName(Op0BO);
+        return BinaryOperator::CreateSub(NewLHS, NewShift);
+      }
+    }
+
     // If the shifted-out value is known-zero, then this is a NUW shift.
     if (!I.hasNoUnsignedWrap() &&
         MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0,


        


More information about the llvm-commits mailing list