[llvm] 306b9c7 - [SCEV] Handle more add/addrec mixes in computeConstantDifference() (#101999)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 02:01:43 PDT 2024


Author: Nikita Popov
Date: 2024-08-13T11:01:39+02:00
New Revision: 306b9c7b48ade28ed10e5926b6d9f5e3acab3968

URL: https://github.com/llvm/llvm-project/commit/306b9c7b48ade28ed10e5926b6d9f5e3acab3968
DIFF: https://github.com/llvm/llvm-project/commit/306b9c7b48ade28ed10e5926b6d9f5e3acab3968.diff

LOG: [SCEV] Handle more add/addrec mixes in computeConstantDifference() (#101999)

computeConstantDifference() can currently look through addrecs with
identical steps, and then through adds with identical operands (apart
from constants).

However, it fails to handle minor variations, such as two nested
addrecs, or an outer add with an inner addrec (rather than the other way
around).

This patch supports these cases by adding a loop over the
simplifications, limited to a small number of iterations. The motivation
is the same as in #101339: to make
computeConstantDifference() powerful enough to replace existing uses of
`dyn_cast<SCEVConstant>(getMinusSCEV())` with it. As the IR test
diff shows, however, other callers may also benefit.

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index de220d68a59163..78975bee4d72c4 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11951,62 +11951,94 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
   // We avoid subtracting expressions here because this function is usually
   // fairly deep in the call stack (i.e. is called many times).
 
-  // X - X = 0.
   unsigned BW = getTypeSizeInBits(More->getType());
-  if (More == Less)
-    return APInt(BW, 0);
-
-  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
-    const auto *LAR = cast<SCEVAddRecExpr>(Less);
-    const auto *MAR = cast<SCEVAddRecExpr>(More);
-
-    if (LAR->getLoop() != MAR->getLoop())
-      return std::nullopt;
-
-    // We look at affine expressions only; not for correctness but to keep
-    // getStepRecurrence cheap.
-    if (!LAR->isAffine() || !MAR->isAffine())
-      return std::nullopt;
+  APInt Diff(BW, 0);
+  // Try various simplifications to reduce the difference to a constant. Limit
+  // the number of allowed simplifications to keep compile-time low.
+  for (unsigned I = 0; I < 4; ++I) {
+    if (More == Less)
+      return Diff;
+
+    // Reduce addrecs with identical steps to their start value.
+    if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
+      const auto *LAR = cast<SCEVAddRecExpr>(Less);
+      const auto *MAR = cast<SCEVAddRecExpr>(More);
+
+      if (LAR->getLoop() != MAR->getLoop())
+        return std::nullopt;
+
+      // We look at affine expressions only; not for correctness but to keep
+      // getStepRecurrence cheap.
+      if (!LAR->isAffine() || !MAR->isAffine())
+        return std::nullopt;
+
+      if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
+        return std::nullopt;
+
+      Less = LAR->getStart();
+      More = MAR->getStart();
+      continue;
+    }
 
-    if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
+    // Try to cancel out common factors in two add expressions.
+    SmallDenseMap<const SCEV *, int, 8> Multiplicity;
+    auto Add = [&](const SCEV *S, int Mul) {
+      if (auto *C = dyn_cast<SCEVConstant>(S)) {
+        if (Mul == 1) {
+          Diff += C->getAPInt();
+        } else {
+          assert(Mul == -1);
+          Diff -= C->getAPInt();
+        }
+      } else
+        Multiplicity[S] += Mul;
+    };
+    auto Decompose = [&](const SCEV *S, int Mul) {
+      if (isa<SCEVAddExpr>(S)) {
+        for (const SCEV *Op : S->operands())
+          Add(Op, Mul);
+      } else
+        Add(S, Mul);
+    };
+    Decompose(More, 1);
+    Decompose(Less, -1);
+
+    // Check whether all the non-constants cancel out, or reduce to new
+    // More/Less values.
+    const SCEV *NewMore = nullptr, *NewLess = nullptr;
+    for (const auto [S, Mul] : Multiplicity) {
+      if (Mul == 0)
+        continue;
+      if (Mul == 1) {
+        if (NewMore)
+          return std::nullopt;
+        NewMore = S;
+      } else if (Mul == -1) {
+        if (NewLess)
+          return std::nullopt;
+        NewLess = S;
+      } else
+        return std::nullopt;
+    }
+
+    // Values stayed the same, no point in trying further.
+    if (NewMore == More || NewLess == Less)
       return std::nullopt;
 
-    Less = LAR->getStart();
-    More = MAR->getStart();
+    More = NewMore;
+    Less = NewLess;
 
-    // fall through
-  }
+    // Reduced to constant.
+    if (!More && !Less)
+      return Diff;
 
-  // Try to cancel out common factors in two add expressions.
-  SmallDenseMap<const SCEV *, int, 8> Multiplicity;
-  APInt Diff(BW, 0);
-  auto Add = [&](const SCEV *S, int Mul) {
-    if (auto *C = dyn_cast<SCEVConstant>(S)) {
-      if (Mul == 1) {
-        Diff += C->getAPInt();
-      } else {
-        assert(Mul == -1);
-        Diff -= C->getAPInt();
-      }
-    } else
-      Multiplicity[S] += Mul;
-  };
-  auto Decompose = [&](const SCEV *S, int Mul) {
-    if (isa<SCEVAddExpr>(S)) {
-      for (const SCEV *Op : S->operands())
-        Add(Op, Mul);
-    } else
-      Add(S, Mul);
-  };
-  Decompose(More, 1);
-  Decompose(Less, -1);
-
-  // Check whether all the non-constants cancel out.
-  for (const auto &[_, Mul] : Multiplicity)
-    if (Mul != 0)
+    // Left with variable on only one side, bail out.
+    if (!More || !Less)
       return std::nullopt;
+  }
 
-  return Diff;
+  // Did not reduce to constant.
+  return std::nullopt;
 }
 
 bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(

diff  --git a/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll b/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
index db5a7105fd8c4d..f55e37c7772609 100644
--- a/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
@@ -29,13 +29,13 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 2
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_MEMCHECK:%.*]]
 ; CHECK:       vector.memcheck:
-; CHECK-NEXT:    [[UGLYGEP:%.*]] = getelementptr i8, ptr [[L_2_LCSSA]], i64 2
-; CHECK-NEXT:    [[UGLYGEP5:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 2
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[L_2_LCSSA]], i64 2
+; CHECK-NEXT:    [[SCEVGEP5:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 2
 ; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[N]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[TMP1]], 4
-; CHECK-NEXT:    [[UGLYGEP6:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 [[TMP2]]
-; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ult ptr [[L_2_LCSSA]], [[UGLYGEP6]]
-; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ult ptr [[UGLYGEP5]], [[UGLYGEP]]
+; CHECK-NEXT:    [[SCEVGEP6:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 [[TMP2]]
+; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ult ptr [[L_2_LCSSA]], [[SCEVGEP6]]
+; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ult ptr [[SCEVGEP5]], [[SCEVGEP]]
 ; CHECK-NEXT:    [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
 ; CHECK-NEXT:    br i1 [[FOUND_CONFLICT]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
@@ -48,10 +48,10 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[TMP4:%.*]] = add nuw nsw i64 [[TMP3]], 1
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[L_1]], i64 [[TMP4]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i16, ptr [[TMP5]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i16>, ptr [[TMP6]], align 2, !alias.scope !0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i16>, ptr [[TMP6]], align 2, !alias.scope [[META0:![0-9]+]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[L_2]], i64 0
 ; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x i16> [[WIDE_LOAD]], i32 1
-; CHECK-NEXT:    store i16 [[TMP8]], ptr [[TMP7]], align 2, !alias.scope !3, !noalias !0
+; CHECK-NEXT:    store i16 [[TMP8]], ptr [[TMP7]], align 2, !alias.scope [[META3:![0-9]+]], !noalias [[META0]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
 ; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
@@ -74,7 +74,7 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[LOOP_L_1:%.*]] = load i16, ptr [[GEP_1]], align 2
 ; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i16, ptr [[L_2_LCSSA]], i64 0
 ; CHECK-NEXT:    store i16 [[LOOP_L_1]], ptr [[GEP_2]], align 2
-; CHECK-NEXT:    br i1 [[C_5]], label [[LOOP_3]], label [[EXIT_LOOPEXIT]], !llvm.loop [[LOOP7:![0-9]+]]
+; CHECK-NEXT:    br i1 [[C_5]], label [[LOOP_3]], label [[EXIT_LOOPEXIT]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       exit.loopexit:
 ; CHECK-NEXT:    br label [[EXIT:%.*]]
 ; CHECK:       exit.loopexit1:
@@ -138,31 +138,17 @@ define void @test2(ptr %dst) {
 ; CHECK-NEXT:    [[INDVAR_NEXT]] = add i32 [[INDVAR]], 1
 ; CHECK-NEXT:    br i1 [[C_1]], label [[LOOP_2]], label [[LOOP_3_PH:%.*]]
 ; CHECK:       loop.3.ph:
-; CHECK-NEXT:    [[INDVAR_LCSSA1:%.*]] = phi i32 [ [[INDVAR]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[INDVAR_LCSSA:%.*]] = phi i32 [ [[INDVAR]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[IV_1_LCSSA:%.*]] = phi i64 [ [[IV_1]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = and i64 [[IV_1_LCSSA]], 4294967295
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[INDVAR_LCSSA1]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[INDVAR_LCSSA]], -1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], 1000
-; CHECK-NEXT:    [[SMIN2:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP2]], i32 1)
-; CHECK-NEXT:    [[TMP3:%.*]] = sub i32 [[TMP2]], [[SMIN2]]
+; CHECK-NEXT:    [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP2]], i32 1)
+; CHECK-NEXT:    [[TMP3:%.*]] = sub i32 [[TMP2]], [[SMIN]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
 ; CHECK-NEXT:    [[TMP5:%.*]] = add nuw nsw i64 [[TMP4]], 1
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP5]], 2
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
-; CHECK:       vector.scevcheck:
-; CHECK-NEXT:    [[TMP6:%.*]] = mul i32 [[INDVAR_LCSSA]], -1
-; CHECK-NEXT:    [[TMP7:%.*]] = add i32 [[TMP6]], 1000
-; CHECK-NEXT:    [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP7]], i32 1)
-; CHECK-NEXT:    [[TMP8:%.*]] = sub i32 [[TMP7]], [[SMIN]]
-; CHECK-NEXT:    [[TMP9:%.*]] = add i32 [[TMP6]], 999
-; CHECK-NEXT:    [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 1, i32 [[TMP8]])
-; CHECK-NEXT:    [[MUL_RESULT:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
-; CHECK-NEXT:    [[MUL_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
-; CHECK-NEXT:    [[TMP10:%.*]] = sub i32 [[TMP9]], [[MUL_RESULT]]
-; CHECK-NEXT:    [[TMP11:%.*]] = icmp ugt i32 [[TMP10]], [[TMP9]]
-; CHECK-NEXT:    [[TMP12:%.*]] = or i1 [[TMP11]], [[MUL_OVERFLOW]]
-; CHECK-NEXT:    br i1 [[TMP12]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP5]], 2
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP5]], [[N_MOD_VF]]
@@ -171,21 +157,21 @@ define void @test2(ptr %dst) {
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = sub i64 [[TMP0]], [[INDEX]]
-; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[OFFSET_IDX]], 0
-; CHECK-NEXT:    [[TMP14:%.*]] = add nsw i64 [[TMP13]], -1
-; CHECK-NEXT:    [[TMP15:%.*]] = and i64 [[TMP14]], 4294967295
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP15]]
-; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i32, ptr [[TMP16]], i32 0
-; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[TMP17]], i32 -1
-; CHECK-NEXT:    store <2 x i32> zeroinitializer, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[OFFSET_IDX]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = add nsw i64 [[TMP6]], -1
+; CHECK-NEXT:    [[TMP8:%.*]] = and i64 [[TMP7]], 4294967295
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[TMP9]], i32 0
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds i32, ptr [[TMP10]], i32 -1
+; CHECK-NEXT:    store <2 x i32> zeroinitializer, ptr [[TMP11]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP5]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[LOOP_1_LATCH:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ [[TMP0]], [[LOOP_3_PH]] ], [ [[TMP0]], [[VECTOR_SCEVCHECK]] ]
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ [[TMP0]], [[LOOP_3_PH]] ]
 ; CHECK-NEXT:    br label [[LOOP_3:%.*]]
 ; CHECK:       loop.3:
 ; CHECK-NEXT:    [[IV_2:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV_2_NEXT:%.*]], [[LOOP_3]] ]
@@ -195,7 +181,7 @@ define void @test2(ptr %dst) {
 ; CHECK-NEXT:    store i32 0, ptr [[GEP_DST]], align 4
 ; CHECK-NEXT:    [[IV_2_TRUNC:%.*]] = trunc i64 [[IV_2]] to i32
 ; CHECK-NEXT:    [[EC:%.*]] = icmp sgt i32 [[IV_2_TRUNC]], 1
-; CHECK-NEXT:    br i1 [[EC]], label [[LOOP_3]], label [[LOOP_1_LATCH]], !llvm.loop [[LOOP9:![0-9]+]]
+; CHECK-NEXT:    br i1 [[EC]], label [[LOOP_3]], label [[LOOP_1_LATCH]], !llvm.loop [[LOOP10:![0-9]+]]
 ; CHECK:       loop.1.latch:
 ; CHECK-NEXT:    [[C_2:%.*]] = call i1 @cond()
 ; CHECK-NEXT:    br i1 [[C_2]], label [[EXIT:%.*]], label [[LOOP_1_HEADER]]

diff  --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 6fc24f6796310d..42aad6ae507bf6 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1202,8 +1202,8 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
     EXPECT_EQ(diff(ScevIV, ScevIVNext), -1);
     EXPECT_EQ(diff(ScevIVNext, ScevIV), 1);
     EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
-    EXPECT_EQ(diff(ScevIV2P3, ScevIV2), std::nullopt); // TODO
-    EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), std::nullopt); // TODO
+    EXPECT_EQ(diff(ScevIV2P3, ScevIV2), 3);
+    EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), -3);
     EXPECT_EQ(diff(ScevV0, ScevIV), std::nullopt);
     EXPECT_EQ(diff(ScevIVNext, ScevV3), std::nullopt);
     EXPECT_EQ(diff(ScevYY, ScevV3), std::nullopt);


        


More information about the llvm-commits mailing list