[llvm] 8a567e5 - [ScalarEvolution] Fix pointer/int type handling converting select/phi to min/max.
Eli Friedman via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 17 14:05:35 PDT 2021
Author: Eli Friedman
Date: 2021-06-17T14:05:12-07:00
New Revision: 8a567e5f22a6d76ce076cf4305fe5c7cbff50fe0
URL: https://github.com/llvm/llvm-project/commit/8a567e5f22a6d76ce076cf4305fe5c7cbff50fe0
DIFF: https://github.com/llvm/llvm-project/commit/8a567e5f22a6d76ce076cf4305fe5c7cbff50fe0.diff
LOG: [ScalarEvolution] Fix pointer/int type handling converting select/phi to min/max.
The old version of this code blindly performed arithmetic without
checking whether the types involved were pointers or integers. This
could produce nonsensical expressions, such as a negated pointer.
Explicitly handle simple cases involving pointers, such as "x < y ? x : y".
In all other cases, coerce the operands of the comparison to integer
types. This avoids the problematic pointer arithmetic while still
handling most of the interesting cases.
Differential Revision: https://reviews.llvm.org/D103660
Added:
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/pr46786.ll
llvm/test/Transforms/IndVarSimplify/pr45835.ll
llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c5cb7c4cfe743..67989c8107af5 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5546,12 +5546,35 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
// a > b ? b+x : a+x -> min(a, b)+x
if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
bool Signed = ICI->isSigned();
- const SCEV *LS = Signed ? getNoopOrSignExtend(getSCEV(LHS), I->getType())
- : getNoopOrZeroExtend(getSCEV(LHS), I->getType());
- const SCEV *RS = Signed ? getNoopOrSignExtend(getSCEV(RHS), I->getType())
- : getNoopOrZeroExtend(getSCEV(RHS), I->getType());
const SCEV *LA = getSCEV(TrueVal);
const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LS = getSCEV(LHS);
+ const SCEV *RS = getSCEV(RHS);
+ if (LA->getType()->isPointerTy()) {
+ // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
+ // Need to make sure we can't produce weird expressions involving
+ // negated pointers.
+ if (LA == LS && RA == RS)
+ return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
+ if (LA == RS && RA == LS)
+ return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
+ }
+ auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
+ if (Op->getType()->isPointerTy()) {
+ Op = getLosslessPtrToIntExpr(Op);
+ if (isa<SCEVCouldNotCompute>(Op))
+ return Op;
+ }
+ if (Signed)
+ Op = getNoopOrSignExtend(Op, I->getType());
+ else
+ Op = getNoopOrZeroExtend(Op, I->getType());
+ return Op;
+ };
+ LS = CoerceOperand(LS);
+ RS = CoerceOperand(RS);
+ if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
+ break;
const SCEV *LDiff = getMinusSCEV(LA, LS);
const SCEV *RDiff = getMinusSCEV(RA, RS);
if (LDiff == RDiff)
diff --git a/llvm/test/Analysis/ScalarEvolution/pr46786.ll b/llvm/test/Analysis/ScalarEvolution/pr46786.ll
index f19a33cf355dc..73bbb3d418a43 100644
--- a/llvm/test/Analysis/ScalarEvolution/pr46786.ll
+++ b/llvm/test/Analysis/ScalarEvolution/pr46786.ll
@@ -16,11 +16,11 @@ define i8* @FSE_decompress_usingDTable(i8* %arg, i32 %arg1, i32 %arg2, i32 %arg3
; CHECK-NEXT: %i5 = getelementptr inbounds i8, i8* %i, i32 %i4
; CHECK-NEXT: --> ((-1 * %arg1) + %arg2 + %arg) U: full-set S: full-set
; CHECK-NEXT: %i7 = select i1 %i6, i32 %arg2, i32 %arg1
-; CHECK-NEXT: --> ((-1 * %arg) + (((-1 * %arg1) + %arg2 + %arg) umin %arg) + %arg1) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %arg to i32)) + (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32)) + %arg1) U: full-set S: full-set
; CHECK-NEXT: %i8 = sub i32 %arg3, %i7
-; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + %arg2 + %arg) umin %arg)) + (-1 * %arg1) + %arg3 + %arg) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32))) + (-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg3) U: full-set S: full-set
; CHECK-NEXT: %i9 = getelementptr inbounds i8, i8* %arg, i32 %i8
-; CHECK-NEXT: --> ((2 * %arg) + (-1 * (((-1 * %arg1) + %arg2 + %arg) umin %arg)) + (-1 * %arg1) + %arg3) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32))) + (-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg3 + %arg) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @FSE_decompress_usingDTable
;
bb:
@@ -42,11 +42,11 @@ define i8* @test_01(i8* %p) {
; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1
; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set
; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1
-; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) umax (2 + %p))) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
; CHECK-NEXT: %neg_index = sub i32 0, %index
-; CHECK-NEXT: --> ((-1 * ((1 + %p) umax (2 + %p))) + %p) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index
-; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) umax (2 + %p)))) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @test_01
;
%p1 = getelementptr i8, i8* %p, i32 2
@@ -66,11 +66,11 @@ define i8* @test_02(i8* %p) {
; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1
; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set
; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1
-; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) smax (2 + %p))) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
; CHECK-NEXT: %neg_index = sub i32 0, %index
-; CHECK-NEXT: --> ((-1 * ((1 + %p) smax (2 + %p))) + %p) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index
-; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) smax (2 + %p)))) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @test_02
;
%p1 = getelementptr i8, i8* %p, i32 2
@@ -90,11 +90,11 @@ define i8* @test_03(i8* %p) {
; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1
; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set
; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1
-; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) umin (2 + %p))) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
; CHECK-NEXT: %neg_index = sub i32 0, %index
-; CHECK-NEXT: --> ((-1 * ((1 + %p) umin (2 + %p))) + %p) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index
-; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) umin (2 + %p)))) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @test_03
;
%p1 = getelementptr i8, i8* %p, i32 2
@@ -114,11 +114,11 @@ define i8* @test_04(i8* %p) {
; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1
; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set
; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1
-; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) smin (2 + %p))) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set
; CHECK-NEXT: %neg_index = sub i32 0, %index
-; CHECK-NEXT: --> ((-1 * ((1 + %p) smin (2 + %p))) + %p) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set
; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index
-; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) smin (2 + %p)))) U: full-set S: full-set
+; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @test_04
;
%p1 = getelementptr i8, i8* %p, i32 2
diff --git a/llvm/test/Transforms/IndVarSimplify/pr45835.ll b/llvm/test/Transforms/IndVarSimplify/pr45835.ll
index d5bab5ada6269..01fba87e92b3f 100644
--- a/llvm/test/Transforms/IndVarSimplify/pr45835.ll
+++ b/llvm/test/Transforms/IndVarSimplify/pr45835.ll
@@ -10,7 +10,7 @@ target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16
define internal fastcc void @d(i8* %c) unnamed_addr #0 {
entry:
- %cmp = icmp ule i8* %c, getelementptr inbounds (i8, i8* @a, i64 65535)
+ %cmp = icmp ule i8* %c, @a
%add.ptr = getelementptr inbounds i8, i8* %c, i64 -65535
br label %while.cond
@@ -18,7 +18,7 @@ while.cond:
br i1 icmp ne (i8 0, i8 0), label %cont, label %while.end
cont:
- %a.mux = select i1 %cmp, i8* @a, i8* %add.ptr
+ %a.mux = select i1 %cmp, i8* @a, i8* %c
switch i64 0, label %while.cond [
i64 -1, label %handler.pointer_overflow.i
i64 0, label %handler.pointer_overflow.i
@@ -26,7 +26,7 @@ cont:
handler.pointer_overflow.i:
%a.mux.lcssa4 = phi i8* [ %a.mux, %cont ], [ %a.mux, %cont ]
-; ALWAYS: [ %scevgep, %cont ], [ %scevgep, %cont ]
+; ALWAYS: [ %umax, %cont ], [ %umax, %cont ]
; NEVER: [ %a.mux, %cont ], [ %a.mux, %cont ]
; In cheap mode, use either one as long as it's consistent.
; CHEAP: [ %[[VAL:.*]], %cont ], [ %[[VAL]], %cont ]
diff --git a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
index 55f3e438cbe38..d825a756d508e 100644
--- a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
+++ b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
@@ -118,20 +118,7 @@ TEST_F(ScalarEvolutionExpanderTest, ExpandPtrTypeSCEV) {
ScalarEvolution SE = buildSE(*F);
auto *S = SE.getSCEV(CastB);
- SCEVExpander Exp(SE, M.getDataLayout(), "expander");
- Value *V =
- Exp.expandCodeFor(cast<SCEVAddExpr>(S)->getOperand(1), nullptr, Br);
-
- // Expect the expansion code contains:
- // %0 = bitcast i32* %bitcast2 to i8*
- // %uglygep = getelementptr i8, i8* %0, i64 -1
- // %1 = bitcast i8* %uglygep to i32*
- EXPECT_TRUE(isa<BitCastInst>(V));
- Instruction *Gep = cast<Instruction>(V)->getPrevNode();
- EXPECT_TRUE(isa<GetElementPtrInst>(Gep));
- EXPECT_TRUE(isa<ConstantInt>(Gep->getOperand(1)));
- EXPECT_EQ(cast<ConstantInt>(Gep->getOperand(1))->getSExtValue(), -1);
- EXPECT_TRUE(isa<BitCastInst>(Gep->getPrevNode()));
+ EXPECT_TRUE(isa<SCEVUnknown>(S));
}
// Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions
More information about the llvm-commits mailing list