[llvm] [LSR] Clean up code using SCEVPatternMatch (NFC) (PR #145556)
Ramkumar Ramachandra via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 27 07:01:11 PDT 2025
https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/145556
>From eebabb8defb3b504a882bf27cb466997b123c7d0 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 24 Jun 2025 18:12:09 +0100
Subject: [PATCH 1/2] [LSR] Clean up code using SCEVPatternMatch (NFC)
---
.../Transforms/Scalar/LoopStrengthReduce.cpp | 111 ++++++++----------
1 file changed, 48 insertions(+), 63 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 4ba69034d6448..9ffcfb47b0a89 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -923,10 +923,12 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
/// If S involves the addition of a constant integer value, return that integer
/// value, and mutate S to point to a new SCEV with that value excluded.
static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
- if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- if (C->getAPInt().getSignificantBits() <= 64) {
- S = SE.getConstant(C->getType(), 0);
- return Immediate::getFixed(C->getValue()->getSExtValue());
+ const APInt *C;
+ const SCEV *Op1;
+ if (match(S, m_scev_APInt(C))) {
+ if (C->getSignificantBits() <= 64) {
+ S = SE.getConstant(S->getType(), 0);
+ return Immediate::getFixed(C->getSExtValue());
}
} else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
SmallVector<const SCEV *, 8> NewOps(Add->operands());
@@ -942,14 +944,11 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
// FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
SCEV::FlagAnyWrap);
return Result;
- } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
- if (EnableVScaleImmediates && M->getNumOperands() == 2) {
- if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
- if (isa<SCEVVScale>(M->getOperand(1))) {
- S = SE.getConstant(M->getType(), 0);
- return Immediate::getScalable(C->getValue()->getSExtValue());
- }
- }
+ } else if (EnableVScaleImmediates &&
+ match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) &&
+ isa<SCEVVScale>(Op1)) {
+ S = SE.getConstant(S->getType(), 0);
+ return Immediate::getScalable(C->getSExtValue());
}
return Immediate::getZero();
}
@@ -1133,23 +1132,22 @@ static bool isHighCostExpansion(const SCEV *S,
return false;
}
- if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
- if (Mul->getNumOperands() == 2) {
- // Multiplication by a constant is ok
- if (isa<SCEVConstant>(Mul->getOperand(0)))
- return isHighCostExpansion(Mul->getOperand(1), Processed, SE);
-
- // If we have the value of one operand, check if an existing
- // multiplication already generates this expression.
- if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Mul->getOperand(1))) {
- Value *UVal = U->getValue();
- for (User *UR : UVal->users()) {
- // If U is a constant, it may be used by a ConstantExpr.
- Instruction *UI = dyn_cast<Instruction>(UR);
- if (UI && UI->getOpcode() == Instruction::Mul &&
- SE.isSCEVable(UI->getType())) {
- return SE.getSCEV(UI) == Mul;
- }
+ const SCEV *Op0, *Op1;
+ if (match(S, m_scev_Mul(m_SCEV(Op0), m_SCEV(Op1)))) {
+ // Multiplication by a constant is ok
+ if (isa<SCEVConstant>(Op0))
+ return isHighCostExpansion(Op1, Processed, SE);
+
+ // If we have the value of one operand, check if an existing
+ // multiplication already generates this expression.
+ if (const auto *U = dyn_cast<SCEVUnknown>(Op1)) {
+ Value *UVal = U->getValue();
+ for (User *UR : UVal->users()) {
+ // If U is a constant, it may be used by a ConstantExpr.
+ Instruction *UI = dyn_cast<Instruction>(UR);
+ if (UI && UI->getOpcode() == Instruction::Mul &&
+ SE.isSCEVable(UI->getType())) {
+ return SE.getSCEV(UI) == S;
}
}
}
@@ -3333,14 +3331,12 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
IncOffset = Immediate::getFixed(IncConst->getValue()->getSExtValue());
} else {
// Look for mul(vscale, constant), to detect a scalable offset.
- auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
- if (!IncVScale || IncVScale->getNumOperands() != 2 ||
- !isa<SCEVVScale>(IncVScale->getOperand(1)))
- return false;
- auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
- if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
+ const APInt *C;
+ const SCEV *Op1;
+ if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) ||
+ !isa<SCEVVScale>(Op1) || C->getSignificantBits() > 64)
return false;
- IncOffset = Immediate::getScalable(Scale->getValue()->getSExtValue());
+ IncOffset = Immediate::getScalable(C->getSExtValue());
}
if (!isAddressUse(TTI, UserInst, Operand))
@@ -3818,6 +3814,8 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
return nullptr;
}
const SCEV *Start, *Step;
+ const SCEVConstant *Op0;
+ const SCEV *Op1;
if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
// Split a non-zero base out of an addrec.
if (Start->isZero())
@@ -3839,19 +3837,13 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
// FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
SCEV::FlagAnyWrap);
}
- } else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
+ } else if (match(S, m_scev_Mul(m_SCEVConstant(Op0), m_SCEV(Op1)))) {
// Break (C * (a + b + c)) into C*a + C*b + C*c.
- if (Mul->getNumOperands() != 2)
- return S;
- if (const SCEVConstant *Op0 =
- dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
- C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
- const SCEV *Remainder =
- CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1);
- if (Remainder)
- Ops.push_back(SE.getMulExpr(C, Remainder));
- return nullptr;
- }
+ C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
+ const SCEV *Remainder = CollectSubexprs(Op1, C, Ops, L, SE, Depth + 1);
+ if (Remainder)
+ Ops.push_back(SE.getMulExpr(C, Remainder));
+ return nullptr;
}
return S;
}
@@ -6478,13 +6470,10 @@ struct SCEVDbgValueBuilder {
/// Components of the expression are omitted if they are an identity function.
/// Chain (non-affine) SCEVs are not supported.
bool SCEVToValueExpr(const llvm::SCEVAddRecExpr &SAR, ScalarEvolution &SE) {
- assert(SAR.isAffine() && "Expected affine SCEV");
- // TODO: Is this check needed?
- if (isa<SCEVAddRecExpr>(SAR.getStart()))
- return false;
-
- const SCEV *Start = SAR.getStart();
- const SCEV *Stride = SAR.getStepRecurrence(SE);
+ const SCEV *Start, *Stride;
+ [[maybe_unused]] bool Match =
+ match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride)));
+ assert(Match && "Expected affine SCEV");
// Skip pushing arithmetic noops.
if (!isIdentityFunction(llvm::dwarf::DW_OP_mul, Stride)) {
@@ -6549,14 +6538,10 @@ struct SCEVDbgValueBuilder {
/// Components of the expression are omitted if they are an identity function.
bool SCEVToIterCountExpr(const llvm::SCEVAddRecExpr &SAR,
ScalarEvolution &SE) {
- assert(SAR.isAffine() && "Expected affine SCEV");
- if (isa<SCEVAddRecExpr>(SAR.getStart())) {
- LLVM_DEBUG(dbgs() << "scev-salvage: IV SCEV. Unsupported nested AddRec: "
- << SAR << '\n');
- return false;
- }
- const SCEV *Start = SAR.getStart();
- const SCEV *Stride = SAR.getStepRecurrence(SE);
+ const SCEV *Start, *Stride;
+ [[maybe_unused]] bool Match =
+ match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride)));
+ assert(Match && "Expected affine SCEV");
// Skip pushing arithmetic noops.
if (!isIdentityFunction(llvm::dwarf::DW_OP_minus, Start)) {
>From 1d0a32ec3fa30e36349a5ddac998fab625237094 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 27 Jun 2025 14:49:41 +0100
Subject: [PATCH 2/2] [LSR/SCEVPM] Introduce and use m_SCEVVScale
---
llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h | 3 +++
llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 9 +++------
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 8e9d7e0b72142..09e3945f5a8ff 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -64,6 +64,9 @@ inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
inline class_match<const SCEVConstant> m_SCEVConstant() {
  return class_match<const SCEVConstant>();
}
+inline class_match<const SCEVVScale> m_SCEVVScale() { // Match any vscale.
+  return class_match<const SCEVVScale>(); // Type check only; binds nothing.
+}
template <typename Class> struct bind_ty {
Class *&VR;
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 9ffcfb47b0a89..dcaa3a22638e0 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -924,7 +924,6 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
/// value, and mutate S to point to a new SCEV with that value excluded.
static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
const APInt *C;
- const SCEV *Op1;
if (match(S, m_scev_APInt(C))) {
if (C->getSignificantBits() <= 64) {
S = SE.getConstant(S->getType(), 0);
@@ -945,8 +944,7 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
SCEV::FlagAnyWrap);
return Result;
} else if (EnableVScaleImmediates &&
- match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) &&
- isa<SCEVVScale>(Op1)) {
+ match(S, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale()))) {
S = SE.getConstant(S->getType(), 0);
return Immediate::getScalable(C->getSExtValue());
}
@@ -3332,9 +3330,8 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
} else {
// Look for mul(vscale, constant), to detect a scalable offset.
const APInt *C;
- const SCEV *Op1;
- if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) ||
- !isa<SCEVVScale>(Op1) || C->getSignificantBits() > 64)
+ if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale())) ||
+ C->getSignificantBits() > 64)
return false;
IncOffset = Immediate::getScalable(C->getSExtValue());
}
More information about the llvm-commits
mailing list