[llvm] r249177 - [SCEV] Refactor out a createNodeForSelect
Sanjoy Das via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 2 12:39:59 PDT 2015
Author: sanjoy
Date: Fri Oct 2 14:39:59 2015
New Revision: 249177
URL: http://llvm.org/viewvc/llvm-project?rev=249177&view=rev
Log:
[SCEV] Refactor out a createNodeForSelect
Summary:
We will shortly re-use this for select-like br-phi pairs.
Reviewers: atrick, joker-eph, joker.eph
Subscribers: sanjoy, llvm-commits
Differential Revision: http://reviews.llvm.org/D13377
Modified:
llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
llvm/trunk/lib/Analysis/ScalarEvolution.cpp
Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolution.h?rev=249177&r1=249176&r2=249177&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h (original)
+++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h Fri Oct 2 14:39:59 2015
@@ -415,6 +415,13 @@ namespace llvm {
/// Provide the special handling we need to analyze PHI SCEVs.
const SCEV *createNodeForPHI(PHINode *PN);
+ /// Provide special handling for a select-like instruction (currently this
+ /// is either a select instruction or a phi node). \p I is the instruction
+ /// being processed, and it is assumed equivalent to "Cond ? TrueVal :
+ /// FalseVal".
+ const SCEV *createNodeForSelect(Instruction *I, Value *Cond, Value *TrueVal,
+ Value *FalseVal);
+
/// Provide the special handling we need to analyze GEP SCEVs.
const SCEV *createNodeForGEP(GEPOperator *GEP);
Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=249177&r1=249176&r2=249177&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Fri Oct 2 14:39:59 2015
@@ -3756,6 +3756,99 @@ const SCEV *ScalarEvolution::createNodeF
return getUnknown(PN);
}
+const SCEV *ScalarEvolution::createNodeForSelect(Instruction *I, Value *Cond,
+ Value *TrueVal,
+ Value *FalseVal) {
+ // Try to match some simple smax/smin/umax/umin patterns; else getUnknown(I).
+ auto *ICI = dyn_cast<ICmpInst>(Cond);
+ if (!ICI)
+ return getUnknown(I); // Only ICmpInst conditions are matched.
+
+ Value *LHS = ICI->getOperand(0);
+ Value *RHS = ICI->getOperand(1);
+
+ switch (ICI->getPredicate()) {
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE:
+ std::swap(LHS, RHS); // Treat "a <s b" as "b >s a" and share the code below.
+ // fall through
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ // a >s b ? a+x : b+x -> smax(a, b)+x
+ // a >s b ? b+x : a+x -> smin(a, b)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
+ const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
+ const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, RS);
+ if (LDiff == RDiff) // Same offset x on both arms.
+ return getAddExpr(getSMaxExpr(LS, RS), LDiff);
+ LDiff = getMinusSCEV(LA, RS); // Retry with the arms crossed (smin form).
+ RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getSMinExpr(LS, RS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ std::swap(LHS, RHS); // Treat "a <u b" as "b >u a" and share the code below.
+ // fall through
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ // a >u b ? a+x : b+x -> umax(a, b)+x
+ // a >u b ? b+x : a+x -> umin(a, b)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, RS);
+ if (LDiff == RDiff) // Same offset x on both arms.
+ return getAddExpr(getUMaxExpr(LS, RS), LDiff);
+ LDiff = getMinusSCEV(LA, RS); // Retry with the arms crossed (umin form).
+ RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMinExpr(LS, RS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_NE:
+ // n != 0 ? n+x : 1+x -> umax(n, 1)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
+ isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getOne(I->getType());
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, One);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(One, LS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_EQ:
+ // n == 0 ? 1+x : n+x -> umax(n, 1)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
+ isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getOne(I->getType());
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, One);
+ const SCEV *RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(One, LS), LDiff);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return getUnknown(I); // No pattern matched.
+}
+
/// createNodeForGEP - Expand GEP instructions into add and multiply
/// operations. This allows them to be analyzed by regular SCEV code.
///
@@ -4470,94 +4563,13 @@ const SCEV *ScalarEvolution::createSCEV(
return createNodeForPHI(cast<PHINode>(U));
case Instruction::Select:
- // This could be a smax or umax that was lowered earlier.
- // Try to recover it.
- if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
- Value *LHS = ICI->getOperand(0);
- Value *RHS = ICI->getOperand(1);
- switch (ICI->getPredicate()) {
- case ICmpInst::ICMP_SLT:
- case ICmpInst::ICMP_SLE:
- std::swap(LHS, RHS);
- // fall through
- case ICmpInst::ICMP_SGT:
- case ICmpInst::ICMP_SGE:
- // a >s b ? a+x : b+x -> smax(a, b)+x
- // a >s b ? b+x : a+x -> smin(a, b)+x
- if (getTypeSizeInBits(LHS->getType()) <=
- getTypeSizeInBits(U->getType())) {
- const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType());
- const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType());
- const SCEV *LA = getSCEV(U->getOperand(1));
- const SCEV *RA = getSCEV(U->getOperand(2));
- const SCEV *LDiff = getMinusSCEV(LA, LS);
- const SCEV *RDiff = getMinusSCEV(RA, RS);
- if (LDiff == RDiff)
- return getAddExpr(getSMaxExpr(LS, RS), LDiff);
- LDiff = getMinusSCEV(LA, RS);
- RDiff = getMinusSCEV(RA, LS);
- if (LDiff == RDiff)
- return getAddExpr(getSMinExpr(LS, RS), LDiff);
- }
- break;
- case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_ULE:
- std::swap(LHS, RHS);
- // fall through
- case ICmpInst::ICMP_UGT:
- case ICmpInst::ICMP_UGE:
- // a >u b ? a+x : b+x -> umax(a, b)+x
- // a >u b ? b+x : a+x -> umin(a, b)+x
- if (getTypeSizeInBits(LHS->getType()) <=
- getTypeSizeInBits(U->getType())) {
- const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
- const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType());
- const SCEV *LA = getSCEV(U->getOperand(1));
- const SCEV *RA = getSCEV(U->getOperand(2));
- const SCEV *LDiff = getMinusSCEV(LA, LS);
- const SCEV *RDiff = getMinusSCEV(RA, RS);
- if (LDiff == RDiff)
- return getAddExpr(getUMaxExpr(LS, RS), LDiff);
- LDiff = getMinusSCEV(LA, RS);
- RDiff = getMinusSCEV(RA, LS);
- if (LDiff == RDiff)
- return getAddExpr(getUMinExpr(LS, RS), LDiff);
- }
- break;
- case ICmpInst::ICMP_NE:
- // n != 0 ? n+x : 1+x -> umax(n, 1)+x
- if (getTypeSizeInBits(LHS->getType()) <=
- getTypeSizeInBits(U->getType()) &&
- isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
- const SCEV *One = getOne(U->getType());
- const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
- const SCEV *LA = getSCEV(U->getOperand(1));
- const SCEV *RA = getSCEV(U->getOperand(2));
- const SCEV *LDiff = getMinusSCEV(LA, LS);
- const SCEV *RDiff = getMinusSCEV(RA, One);
- if (LDiff == RDiff)
- return getAddExpr(getUMaxExpr(One, LS), LDiff);
- }
- break;
- case ICmpInst::ICMP_EQ:
- // n == 0 ? 1+x : n+x -> umax(n, 1)+x
- if (getTypeSizeInBits(LHS->getType()) <=
- getTypeSizeInBits(U->getType()) &&
- isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
- const SCEV *One = getOne(U->getType());
- const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
- const SCEV *LA = getSCEV(U->getOperand(1));
- const SCEV *RA = getSCEV(U->getOperand(2));
- const SCEV *LDiff = getMinusSCEV(LA, One);
- const SCEV *RDiff = getMinusSCEV(RA, LS);
- if (LDiff == RDiff)
- return getAddExpr(getUMaxExpr(One, LS), LDiff);
- }
- break;
- default:
- break;
- }
- }
+ // U can also be a select constant expr, which we let fall through. Since
+ // createNodeForSelect only works for a condition that is an `ICmpInst`, and
+ // constant expressions cannot have instructions as operands, we'd have
+ // returned getUnknown for a select constant expression anyway.
+ if (isa<Instruction>(U))
+ return createNodeForSelect(cast<Instruction>(U), U->getOperand(0),
+ U->getOperand(1), U->getOperand(2));
default: // We cannot analyze this expression.
break;
More information about the llvm-commits
mailing list