[llvm] r264747 - [SCEV] Extract out a MatchBinaryOp; NFCI

Sanjoy Das via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 29 09:40:44 PDT 2016


Author: sanjoy
Date: Tue Mar 29 11:40:44 2016
New Revision: 264747

URL: http://llvm.org/viewvc/llvm-project?rev=264747&view=rev
Log:
[SCEV] Extract out a MatchBinaryOp; NFCI

MatchBinaryOp abstracts out the IR instructions from the operations they
represent.  While this change is NFC, we will use this factoring later
to map things like `(extractvalue 0 (sadd.with.overflow X Y))` to `(add
X Y)`.

Modified:
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=264747&r1=264746&r2=264747&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Tue Mar 29 11:40:44 2016
@@ -4723,6 +4723,83 @@ SCEV::NoWrapFlags ScalarEvolution::getNo
   return SCEV::FlagAnyWrap;
 }
 
+namespace {
+/// Represents an abstract binary operation.  This may exist as a
+/// normal instruction or constant expression, or may have been
+/// derived from an expression tree.
+struct BinaryOp {
+  unsigned Opcode;
+  Value *LHS;
+  Value *RHS;
+
+  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
+  /// constant expression.
+  Operator *Op;
+
+  explicit BinaryOp(Operator *Op)
+      : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
+        Op(Op) {}
+
+  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS)
+      : Opcode(Opcode), LHS(LHS), RHS(RHS), Op(nullptr) {}
+};
+}
+
+
+/// Try to map \p V into a BinaryOp, and return \c None on failure.
+static Optional<BinaryOp> MatchBinaryOp(Value *V) {
+  auto *Op = dyn_cast<Operator>(V);
+  if (!Op)
+    return None;
+
+  // Implementation detail: all the cleverness here should happen without
+  // creating new SCEV expressions -- our caller knows tricks to avoid creating
+  // SCEV expressions when possible, and we should not break that.
+
+  switch (Op->getOpcode()) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::UDiv:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::AShr:
+  case Instruction::Shl:
+    return BinaryOp(Op);
+
+  case Instruction::Xor:
+    if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
+      // If the RHS of the xor is a signbit, then this is just an add.
+      // Instcombine turns add of signbit into xor as a strength reduction step.
+      if (RHSC->getValue().isSignBit())
+        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
+    return BinaryOp(Op);
+
+  case Instruction::LShr:
+    // Turn logical shift right of a constant into an unsigned divide.
+    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
+      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
+
+      // If the shift count is not less than the bitwidth, the result of
+      // the shift is undefined. Don't try to analyze it, because the
+      // resolution chosen here may differ from the resolution chosen in
+      // other parts of the compiler.
+      if (SA->getValue().ult(BitWidth)) {
+        Constant *X =
+            ConstantInt::get(SA->getContext(),
+                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+        return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
+      }
+    }
+    return BinaryOp(Op);
+
+  default:
+    break;
+  }
+
+  return None;
+}
+
 /// createSCEV - We know that there is no SCEV for the specified value.  Analyze
 /// the expression.
 ///
@@ -4747,198 +4824,201 @@ const SCEV *ScalarEvolution::createSCEV(
     return getUnknown(V);
 
   Operator *U = cast<Operator>(V);
-  switch (U->getOpcode()) {
-  case Instruction::Add: {
-    // The simple thing to do would be to just call getSCEV on both operands
-    // and call getAddExpr with the result. However if we're looking at a
-    // bunch of things all added together, this can be quite inefficient,
-    // because it leads to N-1 getAddExpr calls for N ultimate operands.
-    // Instead, gather up all the operands and make a single getAddExpr call.
-    // LLVM IR canonical form means we need only traverse the left operands.
-    SmallVector<const SCEV *, 4> AddOps;
-    for (Value *Op = U;; Op = U->getOperand(0)) {
-      U = dyn_cast<Operator>(Op);
-      unsigned Opcode = U ? U->getOpcode() : 0;
-      if (!U || (Opcode != Instruction::Add && Opcode != Instruction::Sub)) {
-        assert(Op != V && "V should be an add");
-        AddOps.push_back(getSCEV(Op));
-        break;
-      }
+  if (auto BO = MatchBinaryOp(U)) {
+    switch (BO->Opcode) {
+    case Instruction::Add: {
+      // The simple thing to do would be to just call getSCEV on both operands
+      // and call getAddExpr with the result. However if we're looking at a
+      // bunch of things all added together, this can be quite inefficient,
+      // because it leads to N-1 getAddExpr calls for N ultimate operands.
+      // Instead, gather up all the operands and make a single getAddExpr call.
+      // LLVM IR canonical form means we need only traverse the left operands.
+      SmallVector<const SCEV *, 4> AddOps;
+      do {
+        if (BO->Op) {
+          if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
+            AddOps.push_back(OpSCEV);
+            break;
+          }
 
-      if (auto *OpSCEV = getExistingSCEV(U)) {
-        AddOps.push_back(OpSCEV);
-        break;
-      }
+          // If a NUW or NSW flag can be applied to the SCEV for this
+          // addition, then compute the SCEV for this addition by itself
+          // with a separate call to getAddExpr. We need to do that
+          // instead of pushing the operands of the addition onto AddOps,
+          // since the flags are only known to apply to this particular
+          // addition - they may not apply to other additions that can be
+          // formed with operands from AddOps.
+          const SCEV *RHS = getSCEV(BO->RHS);
+          SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
+          if (Flags != SCEV::FlagAnyWrap) {
+            const SCEV *LHS = getSCEV(BO->LHS);
+            if (BO->Opcode == Instruction::Sub)
+              AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
+            else
+              AddOps.push_back(getAddExpr(LHS, RHS, Flags));
+            break;
+          }
+        }
 
-      // If a NUW or NSW flag can be applied to the SCEV for this
-      // addition, then compute the SCEV for this addition by itself
-      // with a separate call to getAddExpr. We need to do that
-      // instead of pushing the operands of the addition onto AddOps,
-      // since the flags are only known to apply to this particular
-      // addition - they may not apply to other additions that can be
-      // formed with operands from AddOps.
-      const SCEV *RHS = getSCEV(U->getOperand(1));
-      SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
-      if (Flags != SCEV::FlagAnyWrap) {
-        const SCEV *LHS = getSCEV(U->getOperand(0));
-        if (Opcode == Instruction::Sub)
-          AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
+        if (BO->Opcode == Instruction::Sub)
+          AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
         else
-          AddOps.push_back(getAddExpr(LHS, RHS, Flags));
-        break;
-      }
+          AddOps.push_back(getSCEV(BO->RHS));
+
+        auto NewBO = MatchBinaryOp(BO->LHS);
+        if (!NewBO || (NewBO->Opcode != Instruction::Add &&
+                       NewBO->Opcode != Instruction::Sub)) {
+          AddOps.push_back(getSCEV(BO->LHS));
+          break;
+        }
+        BO = NewBO;
+      } while (true);
 
-      if (Opcode == Instruction::Sub)
-        AddOps.push_back(getNegativeSCEV(RHS));
-      else
-        AddOps.push_back(RHS);
+      return getAddExpr(AddOps);
     }
-    return getAddExpr(AddOps);
-  }
 
-  case Instruction::Mul: {
-    SmallVector<const SCEV *, 4> MulOps;
-    for (Value *Op = U;; Op = U->getOperand(0)) {
-      U = dyn_cast<Operator>(Op);
-      if (!U || U->getOpcode() != Instruction::Mul) {
-        assert(Op != V && "V should be a mul");
-        MulOps.push_back(getSCEV(Op));
-        break;
-      }
+    case Instruction::Mul: {
+      SmallVector<const SCEV *, 4> MulOps;
+      do {
+        if (BO->Op) {
+          if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
+            MulOps.push_back(OpSCEV);
+            break;
+          }
 
-      if (auto *OpSCEV = getExistingSCEV(U)) {
-        MulOps.push_back(OpSCEV);
-        break;
-      }
+          SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
+          if (Flags != SCEV::FlagAnyWrap) {
+            MulOps.push_back(
+                getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags));
+            break;
+          }
+        }
 
-      SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
-      if (Flags != SCEV::FlagAnyWrap) {
-        MulOps.push_back(getMulExpr(getSCEV(U->getOperand(0)),
-                                    getSCEV(U->getOperand(1)), Flags));
-        break;
-      }
+        MulOps.push_back(getSCEV(BO->RHS));
+        auto NewBO = MatchBinaryOp(BO->LHS);
+        if (!NewBO || NewBO->Opcode != Instruction::Mul) {
+          MulOps.push_back(getSCEV(BO->LHS));
+          break;
+        }
+	BO = NewBO;
+      } while (true);
 
-      MulOps.push_back(getSCEV(U->getOperand(1)));
+      return getMulExpr(MulOps);
     }
-    return getMulExpr(MulOps);
-  }
-  case Instruction::UDiv:
-    return getUDivExpr(getSCEV(U->getOperand(0)),
-                       getSCEV(U->getOperand(1)));
-  case Instruction::Sub:
-    return getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)),
-                        getNoWrapFlagsFromUB(U));
-  case Instruction::And:
-    // For an expression like x&255 that merely masks off the high bits,
-    // use zext(trunc(x)) as the SCEV expression.
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
-      if (CI->isNullValue())
-        return getSCEV(U->getOperand(1));
-      if (CI->isAllOnesValue())
-        return getSCEV(U->getOperand(0));
-      const APInt &A = CI->getValue();
-
-      // Instcombine's ShrinkDemandedConstant may strip bits out of
-      // constants, obscuring what would otherwise be a low-bits mask.
-      // Use computeKnownBits to compute what ShrinkDemandedConstant
-      // knew about to reconstruct a low-bits mask value.
-      unsigned LZ = A.countLeadingZeros();
-      unsigned TZ = A.countTrailingZeros();
-      unsigned BitWidth = A.getBitWidth();
-      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-      computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(),
-                       0, &AC, nullptr, &DT);
-
-      APInt EffectiveMask =
-          APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
-      if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) {
-        const SCEV *MulCount = getConstant(
-            ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ)));
-        return getMulExpr(
-            getZeroExtendExpr(
-                getTruncateExpr(
-                    getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount),
-                    IntegerType::get(getContext(), BitWidth - LZ - TZ)),
-                U->getType()),
-            MulCount);
-      }
+    case Instruction::UDiv:
+      return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
+    case Instruction::Sub: {
+      SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+      if (BO->Op)
+        Flags = getNoWrapFlagsFromUB(BO->Op);
+      return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags);
     }
-    break;
+    case Instruction::And:
+      // For an expression like x&255 that merely masks off the high bits,
+      // use zext(trunc(x)) as the SCEV expression.
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+        if (CI->isNullValue())
+          return getSCEV(BO->RHS);
+        if (CI->isAllOnesValue())
+          return getSCEV(BO->LHS);
+        const APInt &A = CI->getValue();
+
+        // Instcombine's ShrinkDemandedConstant may strip bits out of
+        // constants, obscuring what would otherwise be a low-bits mask.
+        // Use computeKnownBits to compute what ShrinkDemandedConstant
+        // knew about to reconstruct a low-bits mask value.
+        unsigned LZ = A.countLeadingZeros();
+        unsigned TZ = A.countTrailingZeros();
+        unsigned BitWidth = A.getBitWidth();
+        APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+        computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(),
+                         0, &AC, nullptr, &DT);
+
+        APInt EffectiveMask =
+            APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
+        if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) {
+          const SCEV *MulCount = getConstant(ConstantInt::get(
+              getContext(), APInt::getOneBitSet(BitWidth, TZ)));
+          return getMulExpr(
+              getZeroExtendExpr(
+                  getTruncateExpr(
+                      getUDivExactExpr(getSCEV(BO->LHS), MulCount),
+                      IntegerType::get(getContext(), BitWidth - LZ - TZ)),
+                  BO->LHS->getType()),
+              MulCount);
+        }
+      }
+      break;
 
-  case Instruction::Or:
-    // If the RHS of the Or is a constant, we may have something like:
-    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
-    // optimizations will transparently handle this case.
-    //
-    // In order for this transformation to be safe, the LHS must be of the
-    // form X*(2^n) and the Or constant must be less than 2^n.
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
-      const SCEV *LHS = getSCEV(U->getOperand(0));
-      const APInt &CIVal = CI->getValue();
-      if (GetMinTrailingZeros(LHS) >=
-          (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
-        // Build a plain add SCEV.
-        const SCEV *S = getAddExpr(LHS, getSCEV(CI));
-        // If the LHS of the add was an addrec and it has no-wrap flags,
-        // transfer the no-wrap flags, since an or won't introduce a wrap.
-        if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
-          const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
-          const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
-            OldAR->getNoWrapFlags());
+    case Instruction::Or:
+      // If the RHS of the Or is a constant, we may have something like:
+      // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
+      // optimizations will transparently handle this case.
+      //
+      // In order for this transformation to be safe, the LHS must be of the
+      // form X*(2^n) and the Or constant must be less than 2^n.
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+        const SCEV *LHS = getSCEV(BO->LHS);
+        const APInt &CIVal = CI->getValue();
+        if (GetMinTrailingZeros(LHS) >=
+            (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
+          // Build a plain add SCEV.
+          const SCEV *S = getAddExpr(LHS, getSCEV(CI));
+          // If the LHS of the add was an addrec and it has no-wrap flags,
+          // transfer the no-wrap flags, since an or won't introduce a wrap.
+          if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
+            const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
+            const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
+                OldAR->getNoWrapFlags());
+          }
+          return S;
         }
-        return S;
       }
-    }
-    break;
-  case Instruction::Xor:
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
-      // If the RHS of the xor is a signbit, then this is just an add.
-      // Instcombine turns add of signbit into xor as a strength reduction step.
-      if (CI->getValue().isSignBit())
-        return getAddExpr(getSCEV(U->getOperand(0)),
-                          getSCEV(U->getOperand(1)));
-
-      // If the RHS of xor is -1, then this is a not operation.
-      if (CI->isAllOnesValue())
-        return getNotSCEV(getSCEV(U->getOperand(0)));
-
-      // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
-      // This is a variant of the check for xor with -1, and it handles
-      // the case where instcombine has trimmed non-demanded bits out
-      // of an xor with -1.
-      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
-        if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
-          if (BO->getOpcode() == Instruction::And &&
-              LCI->getValue() == CI->getValue())
-            if (const SCEVZeroExtendExpr *Z =
-                  dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
-              Type *UTy = U->getType();
-              const SCEV *Z0 = Z->getOperand();
-              Type *Z0Ty = Z0->getType();
-              unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
-
-              // If C is a low-bits mask, the zero extend is serving to
-              // mask off the high bits. Complement the operand and
-              // re-apply the zext.
-              if (APIntOps::isMask(Z0TySize, CI->getValue()))
-                return getZeroExtendExpr(getNotSCEV(Z0), UTy);
-
-              // If C is a single bit, it may be in the sign-bit position
-              // before the zero-extend. In this case, represent the xor
-              // using an add, which is equivalent, and re-apply the zext.
-              APInt Trunc = CI->getValue().trunc(Z0TySize);
-              if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
-                  Trunc.isSignBit())
-                return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
-                                         UTy);
-            }
-    }
-    break;
+      break;
+
+    case Instruction::Xor:
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+        // If the RHS of xor is -1, then this is a not operation.
+        if (CI->isAllOnesValue())
+          return getNotSCEV(getSCEV(BO->LHS));
+
+        // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
+        // This is a variant of the check for xor with -1, and it handles
+        // the case where instcombine has trimmed non-demanded bits out
+        // of an xor with -1.
+        if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
+          if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
+            if (LBO->getOpcode() == Instruction::And &&
+                LCI->getValue() == CI->getValue())
+              if (const SCEVZeroExtendExpr *Z =
+                      dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
+                Type *UTy = BO->LHS->getType();
+                const SCEV *Z0 = Z->getOperand();
+                Type *Z0Ty = Z0->getType();
+                unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
+
+                // If C is a low-bits mask, the zero extend is serving to
+                // mask off the high bits. Complement the operand and
+                // re-apply the zext.
+                if (APIntOps::isMask(Z0TySize, CI->getValue()))
+                  return getZeroExtendExpr(getNotSCEV(Z0), UTy);
+
+                // If C is a single bit, it may be in the sign-bit position
+                // before the zero-extend. In this case, represent the xor
+                // using an add, which is equivalent, and re-apply the zext.
+                APInt Trunc = CI->getValue().trunc(Z0TySize);
+                if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
+                    Trunc.isSignBit())
+                  return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
+                                           UTy);
+              }
+      }
+      break;
 
   case Instruction::Shl:
     // Turn shift left of a constant amount into a multiply.
-    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
-      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
+    if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
+      uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
 
       // If the shift count is not less than the bitwidth, the result of
       // the shift is undefined. Don't try to analyze it, because the
@@ -4954,58 +5034,43 @@ const SCEV *ScalarEvolution::createSCEV(
       // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
       // and http://reviews.llvm.org/D8890 .
       auto Flags = SCEV::FlagAnyWrap;
-      if (SA->getValue().ult(BitWidth - 1)) Flags = getNoWrapFlagsFromUB(U);
+      if (BO->Op && SA->getValue().ult(BitWidth - 1))
+        Flags = getNoWrapFlagsFromUB(BO->Op);
 
       Constant *X = ConstantInt::get(getContext(),
         APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
-      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X), Flags);
+      return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
     }
     break;
 
-  case Instruction::LShr:
-    // Turn logical shift right of a constant into a unsigned divide.
-    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
-      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
-
-      // If the shift count is not less than the bitwidth, the result of
-      // the shift is undefined. Don't try to analyze it, because the
-      // resolution chosen here may differ from the resolution chosen in
-      // other parts of the compiler.
-      if (SA->getValue().uge(BitWidth))
-        break;
-
-      Constant *X = ConstantInt::get(getContext(),
-        APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
-      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
+    case Instruction::AShr:
+      // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS))
+        if (Operator *L = dyn_cast<Operator>(BO->LHS))
+          if (L->getOpcode() == Instruction::Shl &&
+              L->getOperand(1) == BO->RHS) {
+            uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
+
+            // If the shift count is not less than the bitwidth, the result of
+            // the shift is undefined. Don't try to analyze it, because the
+            // resolution chosen here may differ from the resolution chosen in
+            // other parts of the compiler.
+            if (CI->getValue().uge(BitWidth))
+              break;
+
+            uint64_t Amt = BitWidth - CI->getZExtValue();
+            if (Amt == BitWidth)
+              return getSCEV(L->getOperand(0)); // shift by zero --> noop
+            return getSignExtendExpr(
+                getTruncateExpr(getSCEV(L->getOperand(0)),
+                                IntegerType::get(getContext(), Amt)),
+                BO->LHS->getType());
+          }
+      break;
     }
-    break;
-
-  case Instruction::AShr:
-    // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
-      if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
-        if (L->getOpcode() == Instruction::Shl &&
-            L->getOperand(1) == U->getOperand(1)) {
-          uint64_t BitWidth = getTypeSizeInBits(U->getType());
-
-          // If the shift count is not less than the bitwidth, the result of
-          // the shift is undefined. Don't try to analyze it, because the
-          // resolution chosen here may differ from the resolution chosen in
-          // other parts of the compiler.
-          if (CI->getValue().uge(BitWidth))
-            break;
-
-          uint64_t Amt = BitWidth - CI->getZExtValue();
-          if (Amt == BitWidth)
-            return getSCEV(L->getOperand(0));       // shift by zero --> noop
-          return
-            getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
-                                              IntegerType::get(getContext(),
-                                                               Amt)),
-                              U->getType());
-        }
-    break;
+  }
 
+  switch (U->getOpcode()) {
   case Instruction::Trunc:
     return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
 
@@ -5040,9 +5105,6 @@ const SCEV *ScalarEvolution::createSCEV(
     if (isa<Instruction>(U))
       return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
                                       U->getOperand(1), U->getOperand(2));
-
-  default: // We cannot analyze this expression.
-    break;
   }
 
   return getUnknown(V);




More information about the llvm-commits mailing list