[llvm] r228572 - Allow ScalarEvolution to catch more min/max cases

Johannes Doerfert doerfert at cs.uni-saarland.de
Mon Feb 9 04:34:24 PST 2015


Author: jdoerfert
Date: Mon Feb  9 06:34:23 2015
New Revision: 228572

URL: http://llvm.org/viewvc/llvm-project?rev=228572&view=rev
Log:
Allow ScalarEvolution to catch more min/max cases

  For the attached test case, different types are used in the ICmpInst
  and SelectInst that represent the min/max expressions. However, if the
  ICmpInst type is smaller, a comparison of the sign/zero-extended
  operands would have yielded the same result. This situation might
  arise after the instruction combining (InstCombine) pass has run.

  Differential Revision: http://reviews.llvm.org/D7338

Added:
    llvm/trunk/test/Analysis/ScalarEvolution/min-max-exprs.ll
Modified:
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=228572&r1=228571&r2=228572&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Mon Feb  9 06:34:23 2015
@@ -4297,9 +4297,10 @@ const SCEV *ScalarEvolution::createSCEV(
       case ICmpInst::ICMP_SGE:
         // a >s b ? a+x : b+x  ->  smax(a, b)+x
         // a >s b ? b+x : a+x  ->  smin(a, b)+x
-        if (LHS->getType() == U->getType()) {
-          const SCEV *LS = getSCEV(LHS);
-          const SCEV *RS = getSCEV(RHS);
+        if (getTypeSizeInBits(LHS->getType()) <=
+            getTypeSizeInBits(U->getType())) {
+          const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType());
+          const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
           const SCEV *LDiff = getMinusSCEV(LA, LS);
@@ -4320,9 +4321,10 @@ const SCEV *ScalarEvolution::createSCEV(
       case ICmpInst::ICMP_UGE:
         // a >u b ? a+x : b+x  ->  umax(a, b)+x
         // a >u b ? b+x : a+x  ->  umin(a, b)+x
-        if (LHS->getType() == U->getType()) {
-          const SCEV *LS = getSCEV(LHS);
-          const SCEV *RS = getSCEV(RHS);
+        if (getTypeSizeInBits(LHS->getType()) <=
+            getTypeSizeInBits(U->getType())) {
+          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
+          const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
           const SCEV *LDiff = getMinusSCEV(LA, LS);
@@ -4337,11 +4339,11 @@ const SCEV *ScalarEvolution::createSCEV(
         break;
       case ICmpInst::ICMP_NE:
         // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
-        if (LHS->getType() == U->getType() &&
-            isa<ConstantInt>(RHS) &&
-            cast<ConstantInt>(RHS)->isZero()) {
-          const SCEV *One = getConstant(LHS->getType(), 1);
-          const SCEV *LS = getSCEV(LHS);
+        if (getTypeSizeInBits(LHS->getType()) <=
+                getTypeSizeInBits(U->getType()) &&
+            isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+          const SCEV *One = getConstant(U->getType(), 1);
+          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
           const SCEV *LDiff = getMinusSCEV(LA, LS);
@@ -4352,11 +4354,11 @@ const SCEV *ScalarEvolution::createSCEV(
         break;
       case ICmpInst::ICMP_EQ:
         // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
-        if (LHS->getType() == U->getType() &&
-            isa<ConstantInt>(RHS) &&
-            cast<ConstantInt>(RHS)->isZero()) {
-          const SCEV *One = getConstant(LHS->getType(), 1);
-          const SCEV *LS = getSCEV(LHS);
+        if (getTypeSizeInBits(LHS->getType()) <=
+                getTypeSizeInBits(U->getType()) &&
+            isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+          const SCEV *One = getConstant(U->getType(), 1);
+          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
           const SCEV *LDiff = getMinusSCEV(LA, One);
@@ -7028,8 +7030,8 @@ ScalarEvolution::isImpliedCondOperandsHe
   return false;
 }
 
-// Verify if an linear IV with positive stride can overflow when in a 
-// less-than comparison, knowing the invariant term of the comparison, the 
+// Verify if an linear IV with positive stride can overflow when in a
+// less-than comparison, knowing the invariant term of the comparison, the
 // stride and the knowledge of NSW/NUW flags on the recurrence.
 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
                                          bool IsSigned, bool NoWrap) {
@@ -7057,7 +7059,7 @@ bool ScalarEvolution::doesIVOverflowOnLT
   return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
 }
 
-// Verify if an linear IV with negative stride can overflow when in a 
+// Verify if an linear IV with negative stride can overflow when in a
 // greater-than comparison, knowing the invariant term of the comparison,
 // the stride and the knowledge of NSW/NUW flags on the recurrence.
 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
@@ -7088,7 +7090,7 @@ bool ScalarEvolution::doesIVOverflowOnGT
 
 // Compute the backedge taken count knowing the interval difference, the
 // stride and presence of the equality in the comparison.
-const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, 
+const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
                                             bool Equality) {
   const SCEV *One = getConstant(Step->getType(), 1);
   Delta = Equality ? getAddExpr(Delta, Step)
@@ -7128,7 +7130,7 @@ ScalarEvolution::HowManyLessThans(const
 
   // Avoid proven overflow cases: this will ensure that the backedge taken count
   // will not generate any unsigned overflow. Relaxed no-overflow conditions
-  // exploit NoWrapFlags, allowing to optimize in presence of undefined 
+  // exploit NoWrapFlags, allowing to optimize in presence of undefined
   // behaviors like the case of C language.
   if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
     return getCouldNotCompute();
@@ -7208,7 +7210,7 @@ ScalarEvolution::HowManyGreaterThans(con
 
   // Avoid proven overflow cases: this will ensure that the backedge taken count
   // will not generate any unsigned overflow. Relaxed no-overflow conditions
-  // exploit NoWrapFlags, allowing to optimize in presence of undefined 
+  // exploit NoWrapFlags, allowing to optimize in presence of undefined
   // behaviors like the case of C language.
   if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
     return getCouldNotCompute();
@@ -7256,7 +7258,7 @@ ScalarEvolution::HowManyGreaterThans(con
   if (isa<SCEVConstant>(BECount))
     MaxBECount = BECount;
   else
-    MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), 
+    MaxBECount = computeBECount(getConstant(MaxStart - MinEnd),
                                 getConstant(MinStride), false);
 
   if (isa<SCEVCouldNotCompute>(MaxBECount))

Added: llvm/trunk/test/Analysis/ScalarEvolution/min-max-exprs.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Analysis/ScalarEvolution/min-max-exprs.ll?rev=228572&view=auto
==============================================================================
--- llvm/trunk/test/Analysis/ScalarEvolution/min-max-exprs.ll (added)
+++ llvm/trunk/test/Analysis/ScalarEvolution/min-max-exprs.ll Mon Feb  9 06:34:23 2015
@@ -0,0 +1,53 @@
+; RUN: opt -scalar-evolution -analyze < %s | FileCheck %s
+;
+; This checks that the min and max expressions are properly recognized by
+; ScalarEvolution even though the ICmpInst and SelectInst have different
+; types.
+;
+;    #define max(a, b) (a > b ? a : b)
+;    #define min(a, b) (a < b ? a : b)
+;
+;    void f(int *A, int N) {
+;      for (int i = 0; i < N; i++) {
+;        A[max(0, i - 3)] = A[min(N, i + 3)] * 2;
+;      }
+;    }
+;
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+define void @f(i32* %A, i32 %N) {
+bb:
+  br label %bb1
+
+bb1:                                              ; preds = %bb2, %bb
+  %i.0 = phi i32 [ 0, %bb ], [ %tmp23, %bb2 ]
+  %i.0.1 = sext i32 %i.0 to i64
+  %tmp = icmp slt i32 %i.0, %N
+  br i1 %tmp, label %bb2, label %bb24
+
+bb2:                                              ; preds = %bb1
+  %tmp3 = add nuw nsw i32 %i.0, 3
+  %tmp4 = icmp slt i32 %tmp3, %N
+  %tmp5 = sext i32 %tmp3 to i64
+  %tmp6 = sext i32 %N to i64
+  %tmp9 = select i1 %tmp4, i64 %tmp5, i64 %tmp6
+;                  min(N, i+3)
+; CHECK:           select i1 %tmp4, i64 %tmp5, i64 %tmp6
+; CHECK-NEXT:  --> (-1 + (-1 * ((-1 + (-1 * (sext i32 {3,+,1}<nw><%bb1> to i64))) smax (-1 + (-1 * (sext i32 %N to i64))))))
+  %tmp11 = getelementptr inbounds i32* %A, i64 %tmp9
+  %tmp12 = load i32* %tmp11, align 4
+  %tmp13 = shl nsw i32 %tmp12, 1
+  %tmp14 = icmp sge i32 3, %i.0
+  %tmp17 = add nsw i64 %i.0.1, -3
+  %tmp19 = select i1 %tmp14, i64 0, i64 %tmp17
+;                  max(0, i - 3)
+; CHECK:           select i1 %tmp14, i64 0, i64 %tmp17
+; CHECK-NEXT: -->  (-3 + (3 smax {0,+,1}<nuw><nsw><%bb1>))
+  %tmp21 = getelementptr inbounds i32* %A, i64 %tmp19
+  store i32 %tmp13, i32* %tmp21, align 4
+  %tmp23 = add nuw nsw i32 %i.0, 1
+  br label %bb1
+
+bb24:                                             ; preds = %bb1
+  ret void
+}





More information about the llvm-commits mailing list