[llvm] [SCEV] Handle more adds in computeConstantDifference() (PR #101339)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 31 07:17:58 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Nikita Popov (nikic)
<details>
<summary>Changes</summary>
Currently, computeConstantDifference() only handles the case where we subtract add expressions that have at most one non-constant operand. This patch extends it to cancel out common operands when subtracting arbitrary add expressions.
The background here is that I want to replace a getMinusSCEV() call in LAA (LoopAccessAnalysis) with computeConstantDifference():
https://github.com/llvm/llvm-project/blob/93fecc2577ece0329f3bbe2719bbc5b4b9b30010/llvm/lib/Analysis/LoopAccessAnalysis.cpp#L1602-L1603
This particular call is very expensive in some cases (e.g. lencod with LTO), and computeConstantDifference() could compute the same result much more cheaply because it does not need to construct new SCEV expressions.
However, the current computeConstantDifference() implementation is too weak for this and misses many basic cases. This is a step towards making it more powerful while still keeping it pretty fast.
Compile-time impact:
http://llvm-compile-time-tracker.com/compare.php?from=5d833ee6acc85bf108a8787ba233e955728868ab&to=ef3b81a63874e8f05c5b627014b516d4c59388f4&stat=instructions:u
---
Full diff: https://github.com/llvm/llvm-project/pull/101339.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+25-26)
- (modified) llvm/unittests/Analysis/ScalarEvolutionTest.cpp (+7-1)
``````````diff
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index fb56d5d436653..bdd36e7d3154f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11923,8 +11923,9 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
// fairly deep in the call stack (i.e. is called many times).
// X - X = 0.
+ unsigned BW = getTypeSizeInBits(More->getType());
if (More == Less)
- return APInt(getTypeSizeInBits(More->getType()), 0);
+ return APInt(BW, 0);
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
const auto *LAR = cast<SCEVAddRecExpr>(Less);
@@ -11947,33 +11948,31 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
// fall through
}
- if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
- const auto &M = cast<SCEVConstant>(More)->getAPInt();
- const auto &L = cast<SCEVConstant>(Less)->getAPInt();
- return M - L;
- }
-
- SCEV::NoWrapFlags Flags;
- const SCEV *LLess = nullptr, *RLess = nullptr;
- const SCEV *LMore = nullptr, *RMore = nullptr;
- const SCEVConstant *C1 = nullptr, *C2 = nullptr;
- // Compare (X + C1) vs X.
- if (splitBinaryAdd(Less, LLess, RLess, Flags))
- if ((C1 = dyn_cast<SCEVConstant>(LLess)))
- if (RLess == More)
- return -(C1->getAPInt());
-
- // Compare X vs (X + C2).
- if (splitBinaryAdd(More, LMore, RMore, Flags))
- if ((C2 = dyn_cast<SCEVConstant>(LMore)))
- if (RMore == Less)
- return C2->getAPInt();
+ // Try to cancel out common factors in two add expressions.
+ SmallDenseMap<const SCEV *, int, 8> Multiplicity;
+ APInt Diff(BW, 0);
+ auto Add = [&](const SCEV *S, int Mul) {
+ if (auto *C = dyn_cast<SCEVConstant>(S))
+ Diff += C->getAPInt() * Mul;
+ else
+ Multiplicity[S] += Mul;
+ };
+ auto Decompose = [&](const SCEV *S, int Mul) {
+ if (isa<SCEVAddExpr>(S)) {
+ for (const SCEV *Op : S->operands())
+ Add(Op, Mul);
+ } else
+ Add(S, Mul);
+ };
+ Decompose(More, 1);
+ Decompose(Less, -1);
- // Compare (X + C1) vs (X + C2).
- if (C1 && C2 && RLess == RMore)
- return C2->getAPInt() - C1->getAPInt();
+ // Check whether all the non-constants cancel out.
+ for (const auto [_, Mul] : Multiplicity)
+ if (Mul != 0)
+ return std::nullopt;
- return std::nullopt;
+ return Diff;
}
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 76e6095636305..3802ae4051f42 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1117,10 +1117,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
- "define void @foo(i32 %sz, i32 %pp) { "
+ "define void @foo(i32 %sz, i32 %pp, i32 %x) { "
"entry: "
" %v0 = add i32 %pp, 0 "
" %v3 = add i32 %pp, 3 "
+ " %vx = add i32 %pp, %x "
+ " %vx3 = add i32 %vx, 3 "
" br label %loop.body "
"loop.body: "
" %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] "
@@ -1141,6 +1143,9 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp
auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp)
+ auto *ScevVX = SE.getSCEV(getInstructionByName(F, "vx")); // (%pp + %x)
+ // (%pp + %x + 3)
+ auto *ScevVX3 = SE.getSCEV(getInstructionByName(F, "vx3"));
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1}
auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1}
@@ -1162,6 +1167,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
EXPECT_EQ(diff(ScevV0, ScevV3), -3);
EXPECT_EQ(diff(ScevV0, ScevV0), 0);
EXPECT_EQ(diff(ScevV3, ScevV3), 0);
+ EXPECT_EQ(diff(ScevVX3, ScevVX), 3);
EXPECT_EQ(diff(ScevIV, ScevIV), 0);
EXPECT_EQ(diff(ScevXA, ScevXB), 0);
EXPECT_EQ(diff(ScevXA, ScevYY), -3);
``````````
</details>
https://github.com/llvm/llvm-project/pull/101339
More information about the llvm-commits mailing list