[llvm] 7f04ee1 - [SCEV] Move URem matching to ScalarEvolutionPatternMatch.h (#163170)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 13 12:28:58 PDT 2025
Author: Florian Hahn
Date: 2025-10-13T19:28:53Z
New Revision: 7f04ee19d21d28f7a533fff98c69c16863e6984a
URL: https://github.com/llvm/llvm-project/commit/7f04ee19d21d28f7a533fff98c69c16863e6984a
DIFF: https://github.com/llvm/llvm-project/commit/7f04ee19d21d28f7a533fff98c69c16863e6984a.diff
LOG: [SCEV] Move URem matching to ScalarEvolutionPatternMatch.h (#163170)
Move URem matching to ScalarEvolutionPatternMatch.h so it can
be re-used together with other matchers.
Depends on https://github.com/llvm/llvm-project/pull/163169
PR: https://github.com/llvm/llvm-project/pull/163170
Added:
Modified:
llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
llvm/unittests/Analysis/ScalarEvolutionTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 8876e4ed6ae4f..e5a6c8cc0a6aa 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2316,10 +2316,6 @@ class ScalarEvolution {
/// an add rec on said loop.
void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);
- /// Try to match the pattern generated by getURemExpr(A, B). If successful,
- /// Assign A and B to LHS and RHS, respectively.
- LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);
-
/// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
/// `UniqueSCEVs`. Return if found, else nullptr.
SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 07a482d4f166a..871028de3163c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -252,6 +252,80 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
}
+/// Matcher for an unsigned remainder `A urem B` in the forms produced by
+/// ScalarEvolution::getURemExpr. On success, Op0 has been matched against the
+/// dividend A and Op1 against the divisor B.
+///
+/// NOTE: match() is not a pure query: it may create new SCEV expressions
+/// through SE (getZeroExtendExpr, getConstant, getURemExpr, getNegativeSCEV).
+template <typename Op0_t, typename Op1_t> struct SCEVURem_match {
+ Op0_t Op0;
+ Op1_t Op1;
+ ScalarEvolution &SE;
+
+ SCEVURem_match(Op0_t Op0, Op1_t Op1, ScalarEvolution &SE)
+ : Op0(Op0), Op1(Op1), SE(SE) {}
+
+ /// Returns true if \p Expr is a recognized urem pattern and the sub-matchers
+ /// accept its dividend (Op0) and divisor (Op1). Pointer-typed expressions
+ /// never match.
+ bool match(const SCEV *Expr) const {
+ if (Expr->getType()->isPointerTy())
+ return false;
+
+ // Try to match 'zext (trunc A to iB) to iY', which is used
+ // for URem with constant power-of-2 second operands. A must not be wider
+ // than the whole expression; a narrower A is zero-extended to the
+ // expression's type. The divisor is the constant 1 << (bit width of iB).
+ // The call below is namespace-qualified because an unqualified `match`
+ // would resolve to this member function instead.
+ const SCEV *LHS;
+ if (SCEVPatternMatch::match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
+ Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
+ // Bail out if the type of the LHS is larger than the type of the
+ // expression for now.
+ if (SE.getTypeSizeInBits(LHS->getType()) >
+ SE.getTypeSizeInBits(Expr->getType()))
+ return false;
+ if (LHS->getType() != Expr->getType())
+ LHS = SE.getZeroExtendExpr(LHS, Expr->getType());
+ const SCEV *RHS =
+ SE.getConstant(APInt(SE.getTypeSizeInBits(Expr->getType()), 1)
+ << SE.getTypeSizeInBits(TruncTy));
+ return Op0.match(LHS) && Op1.match(RHS);
+ }
+ // Otherwise expect the expanded form: a two-operand add whose operand 0 is
+ // a multiply and whose operand 1 is the candidate dividend A.
+ const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
+ if (Add == nullptr || Add->getNumOperands() != 2)
+ return false;
+
+ const SCEV *A = Add->getOperand(1);
+ const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
+
+ if (Mul == nullptr)
+ return false;
+
+ // Accept a candidate divisor B only if rebuilding getURemExpr(A, B) yields
+ // exactly Expr (pointer comparison), and only then run the sub-matchers.
+ const auto MatchURemWithDivisor = [&](const SCEV *B) {
+ // (SomeExpr + (-(SomeExpr / B) * B)).
+ if (Expr == SE.getURemExpr(A, B))
+ return Op0.match(A) && Op1.match(B);
+ return false;
+ };
+
+ // Candidates below are tried in order; the first one that reproduces Expr
+ // and satisfies both sub-matchers wins.
+ // (SomeExpr + (-1 * (SomeExpr / B) * B)).
+ if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
+ return MatchURemWithDivisor(Mul->getOperand(1)) ||
+ MatchURemWithDivisor(Mul->getOperand(2));
+
+ // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
+ if (Mul->getNumOperands() == 2)
+ return MatchURemWithDivisor(Mul->getOperand(1)) ||
+ MatchURemWithDivisor(Mul->getOperand(0)) ||
+ MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(1))) ||
+ MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(0)));
+ return false;
+ }
+};
+
+/// Builds a matcher for an unsigned remainder `A urem B` in the shapes that
+/// ScalarEvolution::getURemExpr produces: the arithmetic expansion
+/// A - (A / B) * B (A and B arbitrary expressions), and zext (trunc A to iB)
+/// to iY for a constant power-of-2 divisor. SCEV may fold A and B together
+/// (e.g. A = X / 2 and B = 4 turn A / B into X / 8), so the matcher uses
+/// \p SE to rebuild candidate expressions. \p LHS is matched against the
+/// dividend and \p RHS against the divisor.
+template <typename Op0_t, typename Op1_t>
+inline SCEVURem_match<Op0_t, Op1_t> m_scev_URem(Op0_t LHS, Op1_t RHS,
+                                                ScalarEvolution &SE) {
+  return {LHS, RHS, SE};
+}
+
inline class_match<const Loop> m_Loop() { return class_match<const Loop>(); }
/// Match an affine SCEVAddRecExpr.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 00c3dbbf3e800..3fab6b0572cb7 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
{
const SCEV *LHS;
const SCEV *RHS;
- if (matchURem(Op, LHS, RHS))
+ if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
getZeroExtendExpr(RHS, Ty, Depth + 1));
}
@@ -2699,17 +2699,12 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
}
// Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
- if (Ops.size() == 2) {
- const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
- if (Mul && Mul->getNumOperands() == 2 &&
- Mul->getOperand(0)->isAllOnesValue()) {
- const SCEV *X;
- const SCEV *Y;
- if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
- return getMulExpr(Y, getUDivExpr(X, Y));
- }
- }
- }
+ const SCEV *Y;
+ if (Ops.size() == 2 &&
+ match(Ops[0],
+ m_scev_Mul(m_scev_AllOnes(),
+ m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
+ return getMulExpr(Y, getUDivExpr(Ops[1], Y));
// Skip past any other cast SCEVs.
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
@@ -15410,65 +15405,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
}
}
-// Match the mathematical pattern A - (A / B) * B, where A and B can be
-// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
-// for URem with constant power-of-2 second operands.
-// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
-// 4, A / B becomes X / 8).
-bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
- const SCEV *&RHS) {
- if (Expr->getType()->isPointerTy())
- return false;
-
- // Try to match 'zext (trunc A to iB) to iY', which is used
- // for URem with constant power-of-2 second operands. Make sure the size of
- // the operand A matches the size of the whole expressions.
- if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
- Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
- // Bail out if the type of the LHS is larger than the type of the
- // expression for now.
- if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType()))
- return false;
- if (LHS->getType() != Expr->getType())
- LHS = getZeroExtendExpr(LHS, Expr->getType());
- RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
- << getTypeSizeInBits(TruncTy));
- return true;
- }
- const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
- if (Add == nullptr || Add->getNumOperands() != 2)
- return false;
-
- const SCEV *A = Add->getOperand(1);
- const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
-
- if (Mul == nullptr)
- return false;
-
- const auto MatchURemWithDivisor = [&](const SCEV *B) {
- // (SomeExpr + (-(SomeExpr / B) * B)).
- if (Expr == getURemExpr(A, B)) {
- LHS = A;
- RHS = B;
- return true;
- }
- return false;
- };
-
- // (SomeExpr + (-1 * (SomeExpr / B) * B)).
- if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
- return MatchURemWithDivisor(Mul->getOperand(1)) ||
- MatchURemWithDivisor(Mul->getOperand(2));
-
- // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
- if (Mul->getNumOperands() == 2)
- return MatchURemWithDivisor(Mul->getOperand(1)) ||
- MatchURemWithDivisor(Mul->getOperand(0)) ||
- MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
- MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
- return false;
-}
-
ScalarEvolution::LoopGuards
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
BasicBlock *Header = L->getHeader();
@@ -15689,20 +15625,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
// explicitly express that.
- const SCEV *URemLHS = nullptr;
+ const SCEVUnknown *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
- if (SE.matchURem(LHS, URemLHS, URemRHS)) {
- if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
- auto I = RewriteMap.find(LHSUnknown);
- const SCEV *RewrittenLHS =
- I != RewriteMap.end() ? I->second : LHSUnknown;
- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
- const auto *Multiple =
- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
- RewriteMap[LHSUnknown] = Multiple;
- ExprsToRewrite.push_back(LHSUnknown);
- return;
- }
+ if (match(LHS,
+ m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
+ auto I = RewriteMap.find(URemLHS);
+ const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
+ RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
+ const auto *Multiple =
+ SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+ RewriteMap[URemLHS] = Multiple;
+ ExprsToRewrite.push_back(URemLHS);
+ return;
}
}
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 45cee1e7da625..9035e58a707c4 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -526,7 +526,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
// Recognize the canonical representation of an unsimplifed urem.
const SCEV *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
- if (SE.matchURem(S, URemLHS, URemRHS)) {
+ if (match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), SE))) {
Value *LHS = expand(URemLHS);
Value *RHS = expand(URemRHS);
return InsertBinop(Instruction::URem, LHS, RHS, SCEV::FlagAnyWrap,
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 1a68823b4f254..5d7eded06a760 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -11,6 +11,7 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ScalarEvolutionNormalization.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
@@ -26,6 +27,8 @@
namespace llvm {
+using namespace SCEVPatternMatch;
+
// We use this fixture to ensure that we clean up ScalarEvolution before
// deleting the PassManager.
class ScalarEvolutionsTest : public testing::Test {
@@ -64,11 +67,6 @@ static std::optional<APInt> computeConstantDifference(ScalarEvolution &SE,
return SE.computeConstantDifference(LHS, RHS);
}
- static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
- const SCEV *&RHS) {
- return SE.matchURem(Expr, LHS, RHS);
- }
-
static bool isImpliedCond(
ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
@@ -1524,7 +1522,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
auto *URemI = getInstructionByName(F, N);
auto *S = SE.getSCEV(URemI);
const SCEV *LHS, *RHS;
- EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+ EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
EXPECT_EQ(LHS->getType(), S->getType());
@@ -1537,7 +1535,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
auto *URem1 = getInstructionByName(F, "rem4");
auto *S = SE.getSCEV(Ext);
const SCEV *LHS, *RHS;
- EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+ EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0)));
// RHS and URem1->getOperand(1) have different widths, so compare the
// integer values.
More information about the llvm-commits
mailing list