[llvm] 8a567e5 - [ScalarEvolution] Fix pointer/int type handling converting select/phi to min/max.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 17 14:05:35 PDT 2021


Author: Eli Friedman
Date: 2021-06-17T14:05:12-07:00
New Revision: 8a567e5f22a6d76ce076cf4305fe5c7cbff50fe0

URL: https://github.com/llvm/llvm-project/commit/8a567e5f22a6d76ce076cf4305fe5c7cbff50fe0
DIFF: https://github.com/llvm/llvm-project/commit/8a567e5f22a6d76ce076cf4305fe5c7cbff50fe0.diff

LOG: [ScalarEvolution] Fix pointer/int type handling converting select/phi to min/max.

The old version of this code would blindly perform arithmetic without
paying attention to whether the types involved were pointers or
integers.  This could produce ill-formed expressions, such as a negated
pointer.

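Explicitly handle simple cases involving pointers, like "x < y ? x : y".
In all other cases, coerce the operands of the comparison to integer
types using a lossless ptrtoint conversion; if that conversion is not
possible, no min/max expression is formed.  This avoids the ill-formed
expressions while still handling most of the interesting cases.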

Differential Revision: https://reviews.llvm.org/D103660

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/pr46786.ll
    llvm/test/Transforms/IndVarSimplify/pr45835.ll
    llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c5cb7c4cfe743..67989c8107af5 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5546,12 +5546,35 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
     // a > b ? b+x : a+x  ->  min(a, b)+x
     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
       bool Signed = ICI->isSigned();
-      const SCEV *LS = Signed ? getNoopOrSignExtend(getSCEV(LHS), I->getType())
-                              : getNoopOrZeroExtend(getSCEV(LHS), I->getType());
-      const SCEV *RS = Signed ? getNoopOrSignExtend(getSCEV(RHS), I->getType())
-                              : getNoopOrZeroExtend(getSCEV(RHS), I->getType());
       const SCEV *LA = getSCEV(TrueVal);
       const SCEV *RA = getSCEV(FalseVal);
+      const SCEV *LS = getSCEV(LHS);
+      const SCEV *RS = getSCEV(RHS);
+      if (LA->getType()->isPointerTy()) {
+        // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
+        // Need to make sure we can't produce weird expressions involving
+        // negated pointers.
+        if (LA == LS && RA == RS)
+          return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
+        if (LA == RS && RA == LS)
+          return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
+      }
+      auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
+        if (Op->getType()->isPointerTy()) {
+          Op = getLosslessPtrToIntExpr(Op);
+          if (isa<SCEVCouldNotCompute>(Op))
+            return Op;
+        }
+        if (Signed)
+          Op = getNoopOrSignExtend(Op, I->getType());
+        else
+          Op = getNoopOrZeroExtend(Op, I->getType());
+        return Op;
+      };
+      LS = CoerceOperand(LS);
+      RS = CoerceOperand(RS);
+      if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
+        break;
       const SCEV *LDiff = getMinusSCEV(LA, LS);
       const SCEV *RDiff = getMinusSCEV(RA, RS);
       if (LDiff == RDiff)

diff  --git a/llvm/test/Analysis/ScalarEvolution/pr46786.ll b/llvm/test/Analysis/ScalarEvolution/pr46786.ll
index f19a33cf355dc..73bbb3d418a43 100644
--- a/llvm/test/Analysis/ScalarEvolution/pr46786.ll
+++ b/llvm/test/Analysis/ScalarEvolution/pr46786.ll
@@ -16,11 +16,11 @@ define i8* @FSE_decompress_usingDTable(i8* %arg, i32 %arg1, i32 %arg2, i32 %arg3
 ; CHECK-NEXT:    %i5 = getelementptr inbounds i8, i8* %i, i32 %i4
 ; CHECK-NEXT:    --> ((-1 * %arg1) + %arg2 + %arg) U: full-set S: full-set
 ; CHECK-NEXT:    %i7 = select i1 %i6, i32 %arg2, i32 %arg1
-; CHECK-NEXT:    --> ((-1 * %arg) + (((-1 * %arg1) + %arg2 + %arg) umin %arg) + %arg1) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * (ptrtoint i8* %arg to i32)) + (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32)) + %arg1) U: full-set S: full-set
 ; CHECK-NEXT:    %i8 = sub i32 %arg3, %i7
-; CHECK-NEXT:    --> ((-1 * (((-1 * %arg1) + %arg2 + %arg) umin %arg)) + (-1 * %arg1) + %arg3 + %arg) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32))) + (-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg3) U: full-set S: full-set
 ; CHECK-NEXT:    %i9 = getelementptr inbounds i8, i8* %arg, i32 %i8
-; CHECK-NEXT:    --> ((2 * %arg) + (-1 * (((-1 * %arg1) + %arg2 + %arg) umin %arg)) + (-1 * %arg1) + %arg3) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32))) + (-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg3 + %arg) U: full-set S: full-set
 ; CHECK-NEXT:  Determining loop execution counts for: @FSE_decompress_usingDTable
 ;
 bb:
@@ -42,11 +42,11 @@ define i8* @test_01(i8* %p) {
 ; CHECK-NEXT:    %p2 = getelementptr i8, i8* %p, i32 1
 ; CHECK-NEXT:    --> (1 + %p) U: full-set S: full-set
 ; CHECK-NEXT:    %index = select i1 %cmp, i32 2, i32 1
-; CHECK-NEXT:    --> ((-1 * %p) + ((1 + %p) umax (2 + %p))) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
 ; CHECK-NEXT:    %neg_index = sub i32 0, %index
-; CHECK-NEXT:    --> ((-1 * ((1 + %p) umax (2 + %p))) + %p) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
 ; CHECK-NEXT:    %gep = getelementptr i8, i8* %p, i32 %neg_index
-; CHECK-NEXT:    --> ((2 * %p) + (-1 * ((1 + %p) umax (2 + %p)))) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
 ; CHECK-NEXT:  Determining loop execution counts for: @test_01
 ;
   %p1 = getelementptr i8, i8* %p, i32 2
@@ -66,11 +66,11 @@ define i8* @test_02(i8* %p) {
 ; CHECK-NEXT:    %p2 = getelementptr i8, i8* %p, i32 1
 ; CHECK-NEXT:    --> (1 + %p) U: full-set S: full-set
 ; CHECK-NEXT:    %index = select i1 %cmp, i32 2, i32 1
-; CHECK-NEXT:    --> ((-1 * %p) + ((1 + %p) smax (2 + %p))) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
 ; CHECK-NEXT:    %neg_index = sub i32 0, %index
-; CHECK-NEXT:    --> ((-1 * ((1 + %p) smax (2 + %p))) + %p) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
 ; CHECK-NEXT:    %gep = getelementptr i8, i8* %p, i32 %neg_index
-; CHECK-NEXT:    --> ((2 * %p) + (-1 * ((1 + %p) smax (2 + %p)))) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
 ; CHECK-NEXT:  Determining loop execution counts for: @test_02
 ;
   %p1 = getelementptr i8, i8* %p, i32 2
@@ -90,11 +90,11 @@ define i8* @test_03(i8* %p) {
 ; CHECK-NEXT:    %p2 = getelementptr i8, i8* %p, i32 1
 ; CHECK-NEXT:    --> (1 + %p) U: full-set S: full-set
 ; CHECK-NEXT:    %index = select i1 %cmp, i32 2, i32 1
-; CHECK-NEXT:    --> ((-1 * %p) + ((1 + %p) umin (2 + %p))) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
 ; CHECK-NEXT:    %neg_index = sub i32 0, %index
-; CHECK-NEXT:    --> ((-1 * ((1 + %p) umin (2 + %p))) + %p) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
 ; CHECK-NEXT:    %gep = getelementptr i8, i8* %p, i32 %neg_index
-; CHECK-NEXT:    --> ((2 * %p) + (-1 * ((1 + %p) umin (2 + %p)))) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
 ; CHECK-NEXT:  Determining loop execution counts for: @test_03
 ;
   %p1 = getelementptr i8, i8* %p, i32 2
@@ -114,11 +114,11 @@ define i8* @test_04(i8* %p) {
 ; CHECK-NEXT:    %p2 = getelementptr i8, i8* %p, i32 1
 ; CHECK-NEXT:    --> (1 + %p) U: full-set S: full-set
 ; CHECK-NEXT:    %index = select i1 %cmp, i32 2, i32 1
-; CHECK-NEXT:    --> ((-1 * %p) + ((1 + %p) smin (2 + %p))) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
 ; CHECK-NEXT:    %neg_index = sub i32 0, %index
-; CHECK-NEXT:    --> ((-1 * ((1 + %p) smin (2 + %p))) + %p) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
 ; CHECK-NEXT:    %gep = getelementptr i8, i8* %p, i32 %neg_index
-; CHECK-NEXT:    --> ((2 * %p) + (-1 * ((1 + %p) smin (2 + %p)))) U: full-set S: full-set
+; CHECK-NEXT:    --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
 ; CHECK-NEXT:  Determining loop execution counts for: @test_04
 ;
   %p1 = getelementptr i8, i8* %p, i32 2

diff  --git a/llvm/test/Transforms/IndVarSimplify/pr45835.ll b/llvm/test/Transforms/IndVarSimplify/pr45835.ll
index d5bab5ada6269..01fba87e92b3f 100644
--- a/llvm/test/Transforms/IndVarSimplify/pr45835.ll
+++ b/llvm/test/Transforms/IndVarSimplify/pr45835.ll
@@ -10,7 +10,7 @@ target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16
 
 define internal fastcc void @d(i8* %c) unnamed_addr #0 {
 entry:
-  %cmp = icmp ule i8* %c, getelementptr inbounds (i8, i8* @a, i64 65535)
+  %cmp = icmp ule i8* %c, @a
   %add.ptr = getelementptr inbounds i8, i8* %c, i64 -65535
   br label %while.cond
 
@@ -18,7 +18,7 @@ while.cond:
   br i1 icmp ne (i8 0, i8 0), label %cont, label %while.end
 
 cont:
-  %a.mux = select i1 %cmp, i8* @a, i8* %add.ptr
+  %a.mux = select i1 %cmp, i8* @a, i8* %c
   switch i64 0, label %while.cond [
     i64 -1, label %handler.pointer_overflow.i
     i64 0, label %handler.pointer_overflow.i
@@ -26,7 +26,7 @@ cont:
 
 handler.pointer_overflow.i:
   %a.mux.lcssa4 = phi i8* [ %a.mux, %cont ], [ %a.mux, %cont ]
-; ALWAYS: [ %scevgep, %cont ], [ %scevgep, %cont ]
+; ALWAYS: [ %umax, %cont ], [ %umax, %cont ]
 ; NEVER: [ %a.mux, %cont ], [ %a.mux, %cont ]
 ; In cheap mode, use either one as long as it's consistent.
 ; CHEAP: [ %[[VAL:.*]], %cont ], [ %[[VAL]], %cont ]

diff  --git a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
index 55f3e438cbe38..d825a756d508e 100644
--- a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
+++ b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
@@ -118,20 +118,7 @@ TEST_F(ScalarEvolutionExpanderTest, ExpandPtrTypeSCEV) {
 
   ScalarEvolution SE = buildSE(*F);
   auto *S = SE.getSCEV(CastB);
-  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
-  Value *V =
-      Exp.expandCodeFor(cast<SCEVAddExpr>(S)->getOperand(1), nullptr, Br);
-
-  // Expect the expansion code contains:
-  //   %0 = bitcast i32* %bitcast2 to i8*
-  //   %uglygep = getelementptr i8, i8* %0, i64 -1
-  //   %1 = bitcast i8* %uglygep to i32*
-  EXPECT_TRUE(isa<BitCastInst>(V));
-  Instruction *Gep = cast<Instruction>(V)->getPrevNode();
-  EXPECT_TRUE(isa<GetElementPtrInst>(Gep));
-  EXPECT_TRUE(isa<ConstantInt>(Gep->getOperand(1)));
-  EXPECT_EQ(cast<ConstantInt>(Gep->getOperand(1))->getSExtValue(), -1);
-  EXPECT_TRUE(isa<BitCastInst>(Gep->getPrevNode()));
+  EXPECT_TRUE(isa<SCEVUnknown>(S));
 }
 
 // Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions


        


More information about the llvm-commits mailing list