[llvm] [SCEV] Optimize away variable with constant recurrence in loop (PR #87343)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 2 05:57:00 PDT 2024
https://github.com/mrdaybird created https://github.com/llvm/llvm-project/pull/87343
Fixes #75331.
Based on the idea suggested by @nikic in the issue discussion.
**This is my first PR; any thoughts and suggestions are greatly appreciated!**
**Explanation:**
The basic idea is to optimize away recurrences that are constant throughout the loop.
Consider a recurrence of the form `y(k) = y(k-1)*p + q`, where:
1. `y(k)` is the value of y in the k-th iteration,
2. `y(0) = a` is the start value (the value before the loop),
3. *p* and *q* are loop invariant.

If `y(k) = a` for all k, then the phi node associated with *y* can be replaced with *a*.
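As a rough illustration of this condition, here is a minimal C++ sketch (not part of the patch). It assumes 32-bit wrapping arithmetic, which is how SCEV models integer expressions when no no-wrap flags are involved; `staysAtStart` and `isFixedPoint` are hypothetical helper names. It iterates the recurrence directly and compares the outcome with the single-step test `p*a + q == a`.

```cpp
#include <cstdint>

// Hypothetical helper: run y(k) = p*y(k-1) + q for `iters` steps starting
// from y(0) = a, with unsigned 32-bit arithmetic that wraps modulo 2^32,
// and report whether y(k) == a held at every step.
bool staysAtStart(std::uint32_t a, std::uint32_t p, std::uint32_t q,
                  unsigned iters) {
  std::uint32_t y = a;
  for (unsigned k = 1; k <= iters; ++k) {
    y = p * y + q;  // wraps modulo 2^32, never undefined for unsigned
    if (y != a)
      return false;
  }
  return true;
}

// Hypothetical helper: the one-step condition. If a is a fixed point of
// y -> p*y + q, then y(1) = a and, by induction, y(k) = a for every k.
bool isFixedPoint(std::uint32_t a, std::uint32_t p, std::uint32_t q) {
  return p * a + q == a;
}
```

For the `simple` example (`p = 2`, `q = -a`), `isFixedPoint(a, 2u, 0u - a)` holds for every `a`, even when `2*a` wraps.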
**Implementation:**
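This is achieved by substituting the start value for the phi node in the backedge value and checking whether the result equals the start value itself.
This single check suffices: if `p*a + q = a`, then `y(1) = a`, and by induction `y(k) = a` for every k.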
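The check only works because the simplified expression is compared in canonical form: SCEV uniques its canonicalized expressions, which is what lets the patch below compare `Resultant == StartVal` directly. The following C++ sketch is a toy stand-in for that idea, not SCEV itself: it represents loop-invariant expressions as polynomials over symbolic values (the `Poly` type and the `var`, `constant`, `add`, `mul` and `recurrenceIsConstant` names are all hypothetical) and uses plain 64-bit coefficients rather than modular arithmetic.

```cpp
#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// A monomial is a sorted list of variable names ({} is the constant term);
// a polynomial maps each monomial to a nonzero coefficient, so two equal
// polynomials always have identical representations.
using Monomial = std::vector<std::string>;
using Poly = std::map<Monomial, std::int64_t>;

Poly var(const std::string &name) { return Poly{{Monomial{name}, 1}}; }
Poly constant(std::int64_t c) { return c == 0 ? Poly{} : Poly{{Monomial{}, c}}; }

Poly add(const Poly &x, const Poly &y) {
  Poly r = x;
  for (const auto &[m, c] : y) {
    r[m] += c;
    if (r[m] == 0)
      r.erase(m);  // drop cancelled terms to keep the form canonical
  }
  return r;
}

Poly mul(const Poly &x, const Poly &y) {
  Poly r;
  for (const auto &[mx, cx] : x)
    for (const auto &[my, cy] : y) {
      Monomial m = mx;
      m.insert(m.end(), my.begin(), my.end());
      std::sort(m.begin(), m.end());  // a*b and b*a become the same key
      r = add(r, Poly{{m, cx * cy}});
    }
  return r;
}

// Replace the phi with the start value a in p*y + q and test whether the
// result folds back to exactly a.
bool recurrenceIsConstant(const Poly &a, const Poly &p, const Poly &q) {
  return add(mul(p, a), q) == a;
}
```

For the `intermediate` example, `p = b` and `q = -(b-1)*a`; the `a*b` terms cancel, so `recurrenceIsConstant` returns `true` without knowing the values of `a` or `b`.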
Examples (https://godbolt.org/z/Px66abnva):
```c
int simple(int a, unsigned int k) {
  int j = a;
  for (unsigned i = 0; i < k; i++) {
    j = 2*j - a;
  }
  return j;
}

int intermediate(int a, int b, unsigned int k) {
  int j = a;
  for (unsigned i = 0; i < k; i++) {
    j = b*j - (b-1)*a;
  }
  return j;
}
```
Both functions should return `a`. Currently, GCC seems to optimize the first function but not the second, and LLVM optimizes neither.
This PR will optimize both functions.
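The examples above use `int`, where signed overflow is undefined in C. The C++ sketch below uses hypothetical unsigned variants, `simpleWrap` and `intermediateWrap` (not from the PR), so that wrapping is well defined; it illustrates that each loop body maps `a` back to `a` modulo 2^32 (for instance `b*a - (b-1)*a = a`), even when the intermediate products overflow.

```cpp
#include <cstdint>

// Unsigned variant of `simple`: 2*j - a == a whenever j == a, mod 2^32.
std::uint32_t simpleWrap(std::uint32_t a, std::uint32_t k) {
  std::uint32_t j = a;
  for (std::uint32_t i = 0; i < k; i++)
    j = 2 * j - a;
  return j;
}

// Unsigned variant of `intermediate`: b*j - (b-1)*a == a whenever j == a,
// mod 2^32, for any b (including b == 0, where b - 1 wraps).
std::uint32_t intermediateWrap(std::uint32_t a, std::uint32_t b,
                               std::uint32_t k) {
  std::uint32_t j = a;
  for (std::uint32_t i = 0; i < k; i++)
    j = b * j - (b - 1) * a;
  return j;
}
```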
**Alive2:** https://alive2.llvm.org/ce/z/nnaTu6
Tasks:
- [x] Add optimization to *createAddRecFromPHI* in *ScalarEvolution.cpp*
- [ ] Add tests
From 8008c62a73039b767d948da5ef758bbd2340afc0 Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Tue, 2 Apr 2024 17:31:45 +0530
Subject: [PATCH] Add missed optimization in SCEV
---
llvm/lib/Analysis/ScalarEvolution.cpp | 67 +++++++++++++++++++++++++++
1 file changed, 67 insertions(+)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 515b9d0744f6e3..4033dd7d5b4439 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5925,6 +5925,73 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
return PHISCEV;
}
+ } else {
+ // Try to match y(k) = p*y(k-1) + q, where p and q are loop invariant.
+ // If it matches, check whether p*a + q simplifies to a for start value a.
+ unsigned FoundSNIdx = Add->getNumOperands();
+ for (unsigned i = 0, ei = Add->getNumOperands(); i != ei; ++i) {
+ if (const SCEVMulExpr* Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(i))) {
+ for (unsigned j = 0, ej = Mul->getNumOperands(); j != ej; ++j) {
+ if (Mul->getOperand(j) == SymbolicName){
+ FoundSNIdx = i;
+ break;
+ }
+ }
+ if (FoundSNIdx != ei) break;
+ }
+ }
+ if (FoundSNIdx != Add->getNumOperands()) {
+ // Check that q is loop invariant; q may be a sum of several operands.
+ // Collect the operands of q along the way.
+ bool IsQInvariant = true;
+ SmallVector<const SCEV*, 4> QValues;
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) {
+ if (i != FoundSNIdx) {
+ auto *Op = Add->getOperand(i);
+ if (isLoopInvariant(Op, L)) {
+ QValues.push_back(Op);
+ } else {
+ IsQInvariant = false;
+ break;
+ }
+ }
+ }
+ if (IsQInvariant) {
+ // Check that p is loop invariant; p may be a product of several operands.
+ // The case p = 1 is handled by the affine AddRec branch above.
+ // Collect the operands of p along the way.
+ bool IsPInvariant = true;
+ SmallVector<const SCEV*, 4> PValues;
+ const SCEVMulExpr *SNOperand = cast<SCEVMulExpr>(Add->getOperand(FoundSNIdx));
+ for (unsigned i = 0, e = SNOperand->getNumOperands(); i != e; ++i) {
+ auto *Op = SNOperand->getOperand(i);
+ if (Op != SymbolicName) {
+ if (isLoopInvariant(Op, L)) {
+ PValues.push_back(Op);
+ } else {
+ IsPInvariant = false;
+ break;
+ }
+ }
+ }
+ if (IsPInvariant) {
+ // Both p and q are loop invariant, so compute p*StartVal + q
+ // and check whether it simplifies to StartVal.
+ auto *StartVal = getSCEV(StartValueV);
+ auto *P = getMulExpr(PValues);
+ auto *PMulSV = getMulExpr(P, StartVal);
+ QValues.push_back(PMulSV);
+ auto *Resultant = getAddExpr(QValues);
+
+ if (Resultant == StartVal) {
+ // the recurrence is always equal to StartVal
+ forgetMemoizedResults(SymbolicName);
+ insertValueToMap(PN, StartVal);
+ return StartVal;
+ }
+ }
+ }
+ }
}
} else {
// Otherwise, this could be a loop like this: