[llvm] b76f1f1 - [SCEV] Keep common NUW flags when inlining Add operands.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 9 09:21:45 PDT 2021


Author: Florian Hahn
Date: 2021-06-09T17:13:21+01:00
New Revision: b76f1f120285fe60b347220e705f0e6008d8cf65

URL: https://github.com/llvm/llvm-project/commit/b76f1f120285fe60b347220e705f0e6008d8cf65
DIFF: https://github.com/llvm/llvm-project/commit/b76f1f120285fe60b347220e705f0e6008d8cf65.diff

LOG: [SCEV] Keep common NUW flags when inlining Add operands.

Currently, NoWrapFlags are dropped whenever the operands of a nested
SCEVAddExpr operand are inlined into the outer add. As a consequence, we
always drop flags when building expressions like
`getAddExpr(A, getAddExpr(B, C, NUW), NUW)`.

We should be able to retain the NUW flag when it is common to the
original flags and to all inlined SCEVAddExprs.

Reviewed By: nikic, mkazantsev

Differential Revision: https://reviews.llvm.org/D103877

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll
    llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 57a6c83ddd46b..699be27f6f06b 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2544,6 +2544,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   // If there are add operands they would be next.
   if (Idx < Ops.size()) {
     bool DeletedAdd = false;
+    // If the original flags and all inlined SCEVAddExprs are NUW, use the
+    // common NUW flag for the expression after inlining. Other flags cannot
+    // be preserved, because they may depend on the original order of operations.
+    SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
     while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
       if (Ops.size() > AddOpsInlineThreshold ||
           Add->getNumOperands() > AddOpsInlineThreshold)
@@ -2553,13 +2557,14 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       Ops.erase(Ops.begin()+Idx);
       Ops.append(Add->op_begin(), Add->op_end());
       DeletedAdd = true;
+      CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
     }
 
     // If we deleted at least one add, we added operands to the end of the list,
     // and they are not necessarily sorted.  Recurse to resort and resimplify
     // any operands we just acquired.
     if (DeletedAdd)
-      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
+      return getAddExpr(Ops, CommonFlags, Depth + 1);
   }
 
   // Skip over the add expression until we get to a multiply.

diff  --git a/llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll b/llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll
index fd34306861ea0..1bbbe8e373510 100644
--- a/llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll
+++ b/llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll
@@ -212,9 +212,9 @@ define void @f3(i8* %x_addr, i8* %y_addr, i32* %tmp_addr) {
 ; CHECK-NEXT:    %sunkaddr4 = getelementptr inbounds i8, i8* bitcast ({ %union, [2000 x i8] }* @tmp_addr to i8*), i64 %sunkaddr3
 ; CHECK-NEXT:    --> ((4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr)<nuw> U: [0,-3) S: [-9223372036854775808,9223372036854775805)
 ; CHECK-NEXT:    %sunkaddr5 = getelementptr inbounds i8, i8* %sunkaddr4, i64 4096
-; CHECK-NEXT:    --> (4096 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805)
+; CHECK-NEXT:    --> (4096 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr)<nuw> U: [4096,-3) S: [-9223372036854775808,9223372036854775805)
 ; CHECK-NEXT:    %addr4.cast = bitcast i8* %sunkaddr5 to i32*
-; CHECK-NEXT:    --> (4096 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805)
+; CHECK-NEXT:    --> (4096 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr)<nuw> U: [4096,-3) S: [-9223372036854775808,9223372036854775805)
 ; CHECK-NEXT:    %addr4.incr = getelementptr i32, i32* %addr4.cast, i64 1
 ; CHECK-NEXT:    --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805)
 ; CHECK-NEXT:    %add5 = add i32 %mul, 5
@@ -224,11 +224,11 @@ define void @f3(i8* %x_addr, i8* %y_addr, i32* %tmp_addr) {
 ; CHECK-NEXT:    %sunkaddr0 = mul i64 %add5.zext, 4
 ; CHECK-NEXT:    --> (4 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw>)<nuw><nsw> U: [4,17179869173) S: [4,17179869185)
 ; CHECK-NEXT:    %sunkaddr1 = getelementptr inbounds i8, i8* bitcast ({ %union, [2000 x i8] }* @tmp_addr to i8*), i64 %sunkaddr0
-; CHECK-NEXT:    --> (4 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805)
+; CHECK-NEXT:    --> (4 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr)<nuw> U: [4,-3) S: [-9223372036854775808,9223372036854775805)
 ; CHECK-NEXT:    %sunkaddr2 = getelementptr inbounds i8, i8* %sunkaddr1, i64 4096
-; CHECK-NEXT:    --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805)
+; CHECK-NEXT:    --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr)<nuw> U: [0,-3) S: [-9223372036854775808,9223372036854775805)
 ; CHECK-NEXT:    %addr5.cast = bitcast i8* %sunkaddr2 to i32*
-; CHECK-NEXT:    --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805)
+; CHECK-NEXT:    --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))<nuw>) to i64))<nuw><nsw> + @tmp_addr)<nuw> U: [0,-3) S: [-9223372036854775808,9223372036854775805)
 ; CHECK-NEXT:  Determining loop execution counts for: @f3
 ;
   entry:

diff  --git a/llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll b/llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll
index 7b09482a775f8..92f12a47a37dd 100644
--- a/llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll
+++ b/llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll
@@ -99,12 +99,12 @@ define void @pointer_iv_nowrap(i8* %startptr, i8* %endptr) local_unnamed_addr {
 ; CHECK-NEXT:  %iv = phi i8* [ %init, %entry ], [ %iv.next, %loop ]
 ; CHECK-NEXT:  -->  {(2000 + %startptr)<nuw>,+,1}<nuw><%loop> U: [2000,0) S: [2000,0)
 ; CHECK-NEXT:  %iv.next = getelementptr inbounds i8, i8* %iv, i64 1
-; CHECK-NEXT:  -->  {(2001 + %startptr),+,1}<nuw><%loop> U: full-set S: full-set
+; CHECK-NEXT:  -->  {(2001 + %startptr)<nuw>,+,1}<nuw><%loop> U: [2001,0) S: [2001,0)
 
 ; CHECK-NEXT:Determining loop execution counts for: @pointer_iv_nowrap
-; CHECK-NEXT:Loop %loop: Unpredictable backedge-taken count.
-; CHECK-NEXT:Loop %loop: Unpredictable max backedge-taken count.
-; CHECK-NEXT:Loop %loop: Unpredictable predicated backedge-taken count.
+; CHECK-NEXT:Loop %loop: backedge-taken count is (-2000 + (-1 * %startptr) + ((2000 + %startptr)<nuw> umax %endptr))
+; CHECK-NEXT:Loop %loop: max backedge-taken count is -2001
+; CHECK-NEXT:Loop %loop: Predicated backedge-taken count is (-2000 + (-1 * %startptr) + ((2000 + %startptr)<nuw> umax %endptr))
 ;
 entry:
   %init = getelementptr inbounds i8, i8* %startptr, i64 2000

diff  --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 3014fa4cb3792..060b0456bae47 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -391,7 +391,7 @@ TEST_F(ScalarEvolutionsTest, CompareValueComplexity) {
 
 TEST_F(ScalarEvolutionsTest, SCEVAddExpr) {
   Type *Ty32 = Type::getInt32Ty(Context);
-  Type *ArgTys[] = {Type::getInt64Ty(Context), Ty32};
+  Type *ArgTys[] = {Type::getInt64Ty(Context), Ty32, Ty32, Ty32, Ty32, Ty32};
 
   FunctionType *FTy =
       FunctionType::get(Type::getVoidTy(Context), ArgTys, false);
@@ -419,6 +419,45 @@ TEST_F(ScalarEvolutionsTest, SCEVAddExpr) {
   ReturnInst::Create(Context, nullptr, EntryBB);
   ScalarEvolution SE = buildSE(*F);
   EXPECT_NE(nullptr, SE.getSCEV(Mul1));
+
+  Argument *A3 = &*(std::next(F->arg_begin(), 2));
+  Argument *A4 = &*(std::next(F->arg_begin(), 3));
+  Argument *A5 = &*(std::next(F->arg_begin(), 4));
+  Argument *A6 = &*(std::next(F->arg_begin(), 5));
+
+  auto *AddWithNUW = cast<SCEVAddExpr>(SE.getAddExpr(
+      SE.getAddExpr(SE.getSCEV(A2), SE.getSCEV(A3), SCEV::FlagNUW),
+      SE.getConstant(APInt(/*numBits=*/32, 5)), SCEV::FlagNUW));
+  EXPECT_EQ(AddWithNUW->getNumOperands(), 3u);
+  EXPECT_EQ(AddWithNUW->getNoWrapFlags(), SCEV::FlagNUW);
+
+  auto *AddWithAnyWrap =
+      SE.getAddExpr(SE.getSCEV(A3), SE.getSCEV(A4), SCEV::FlagAnyWrap);
+  auto *AddWithAnyWrapNUW = cast<SCEVAddExpr>(
+      SE.getAddExpr(AddWithAnyWrap, SE.getSCEV(A5), SCEV::FlagNUW));
+  EXPECT_EQ(AddWithAnyWrapNUW->getNumOperands(), 3u);
+  EXPECT_EQ(AddWithAnyWrapNUW->getNoWrapFlags(), SCEV::FlagAnyWrap);
+
+  auto *AddWithNSW = SE.getAddExpr(
+      SE.getSCEV(A2), SE.getConstant(APInt(32, 99)), SCEV::FlagNSW);
+  auto *AddWithNSW_NUW = cast<SCEVAddExpr>(
+      SE.getAddExpr(AddWithNSW, SE.getSCEV(A5), SCEV::FlagNUW));
+  EXPECT_EQ(AddWithNSW_NUW->getNumOperands(), 3u);
+  EXPECT_EQ(AddWithNSW_NUW->getNoWrapFlags(), SCEV::FlagAnyWrap);
+
+  auto *AddWithNSWNUW =
+      SE.getAddExpr(SE.getSCEV(A2), SE.getSCEV(A4),
+                    ScalarEvolution::setFlags(SCEV::FlagNUW, SCEV::FlagNSW));
+  auto *AddWithNSWNUW_NUW = cast<SCEVAddExpr>(
+      SE.getAddExpr(AddWithNSWNUW, SE.getSCEV(A5), SCEV::FlagNUW));
+  EXPECT_EQ(AddWithNSWNUW_NUW->getNumOperands(), 3u);
+  EXPECT_EQ(AddWithNSWNUW_NUW->getNoWrapFlags(), SCEV::FlagNUW);
+
+  auto *AddWithNSW_NSWNUW = cast<SCEVAddExpr>(
+      SE.getAddExpr(AddWithNSW, SE.getSCEV(A6),
+                    ScalarEvolution::setFlags(SCEV::FlagNUW, SCEV::FlagNSW)));
+  EXPECT_EQ(AddWithNSW_NSWNUW->getNumOperands(), 3u);
+  EXPECT_EQ(AddWithNSW_NSWNUW->getNoWrapFlags(), SCEV::FlagAnyWrap);
 }
 
 static Instruction &GetInstByName(Function &F, StringRef Name) {


        


More information about the llvm-commits mailing list