[llvm] 79af689 - [SCEV] Handle more adds in computeConstantDifference() (#101339)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 2 04:43:05 PDT 2024
Author: Nikita Popov
Date: 2024-08-02T13:43:02+02:00
New Revision: 79af6892f8eee1d0fb96c243716c8b03936751c9
URL: https://github.com/llvm/llvm-project/commit/79af6892f8eee1d0fb96c243716c8b03936751c9
DIFF: https://github.com/llvm/llvm-project/commit/79af6892f8eee1d0fb96c243716c8b03936751c9.diff
LOG: [SCEV] Handle more adds in computeConstantDifference() (#101339)
Currently it only deals with the case where we're subtracting adds with
at most one non-constant operand. This patch extends it to cancel out
common operands for the subtraction of arbitrary add expressions.
The background here is that I want to replace a getMinusSCEV() call in
LAA with computeConstantDifference():
https://github.com/llvm/llvm-project/blob/93fecc2577ece0329f3bbe2719bbc5b4b9b30010/llvm/lib/Analysis/LoopAccessAnalysis.cpp#L1602-L1603
This particular call is very expensive in some cases (e.g. lencod with
LTO) and computeConstantDifference() could achieve this much more
cheaply, because it does not need to construct new SCEV expressions.
However, the current computeConstantDifference() implementation is too
weak for this and misses many basic cases. This is a step towards making
it more powerful while still keeping it pretty fast.
Added:
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/unittests/Analysis/ScalarEvolutionTest.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 06dd17dd3648b..b6c70451b9b06 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11934,8 +11934,9 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
// fairly deep in the call stack (i.e. is called many times).
// X - X = 0.
+ unsigned BW = getTypeSizeInBits(More->getType());
if (More == Less)
- return APInt(getTypeSizeInBits(More->getType()), 0);
+ return APInt(BW, 0);
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
const auto *LAR = cast<SCEVAddRecExpr>(Less);
@@ -11958,33 +11959,31 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
// fall through
}
- if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
- const auto &M = cast<SCEVConstant>(More)->getAPInt();
- const auto &L = cast<SCEVConstant>(Less)->getAPInt();
- return M - L;
- }
-
- SCEV::NoWrapFlags Flags;
- const SCEV *LLess = nullptr, *RLess = nullptr;
- const SCEV *LMore = nullptr, *RMore = nullptr;
- const SCEVConstant *C1 = nullptr, *C2 = nullptr;
- // Compare (X + C1) vs X.
- if (splitBinaryAdd(Less, LLess, RLess, Flags))
- if ((C1 = dyn_cast<SCEVConstant>(LLess)))
- if (RLess == More)
- return -(C1->getAPInt());
-
- // Compare X vs (X + C2).
- if (splitBinaryAdd(More, LMore, RMore, Flags))
- if ((C2 = dyn_cast<SCEVConstant>(LMore)))
- if (RMore == Less)
- return C2->getAPInt();
+ // Try to cancel out common operands (terms) of two add expressions.
+ SmallDenseMap<const SCEV *, int, 8> Multiplicity;
+ APInt Diff(BW, 0);
+ auto Add = [&](const SCEV *S, int Mul) {
+ if (auto *C = dyn_cast<SCEVConstant>(S))
+ Diff += C->getAPInt() * Mul;
+ else
+ Multiplicity[S] += Mul;
+ };
+ auto Decompose = [&](const SCEV *S, int Mul) {
+ if (isa<SCEVAddExpr>(S)) {
+ for (const SCEV *Op : S->operands())
+ Add(Op, Mul);
+ } else
+ Add(S, Mul);
+ };
+ Decompose(More, 1);
+ Decompose(Less, -1);
- // Compare (X + C1) vs (X + C2).
- if (C1 && C2 && RLess == RMore)
- return C2->getAPInt() - C1->getAPInt();
+ // Check whether all the non-constants cancel out.
+ for (const auto [_, Mul] : Multiplicity)
+ if (Mul != 0)
+ return std::nullopt;
- return std::nullopt;
+ return Diff;
}
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 76e6095636305..3802ae4051f42 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1117,10 +1117,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
- "define void @foo(i32 %sz, i32 %pp) { "
+ "define void @foo(i32 %sz, i32 %pp, i32 %x) { "
"entry: "
" %v0 = add i32 %pp, 0 "
" %v3 = add i32 %pp, 3 "
+ " %vx = add i32 %pp, %x "
+ " %vx3 = add i32 %vx, 3 "
" br label %loop.body "
"loop.body: "
" %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] "
@@ -1141,6 +1143,9 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp
auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp)
+ auto *ScevVX = SE.getSCEV(getInstructionByName(F, "vx")); // (%pp + %x)
+ // (%pp + %x + 3)
+ auto *ScevVX3 = SE.getSCEV(getInstructionByName(F, "vx3"));
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1}
auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1}
@@ -1162,6 +1167,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
EXPECT_EQ(diff(ScevV0, ScevV3), -3);
EXPECT_EQ(diff(ScevV0, ScevV0), 0);
EXPECT_EQ(diff(ScevV3, ScevV3), 0);
+ EXPECT_EQ(diff(ScevVX3, ScevVX), 3);
EXPECT_EQ(diff(ScevIV, ScevIV), 0);
EXPECT_EQ(diff(ScevXA, ScevXB), 0);
EXPECT_EQ(diff(ScevXA, ScevYY), -3);
More information about the llvm-commits mailing list