[llvm] [SCEV] Add initial pattern matching for SCEV constants. (NFC) (PR #119389)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 10 12:47:29 PST 2024
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/119389
>From 13d20aba3b5ddedbc7eb4236acdd034f4d5d4796 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Dec 2024 14:28:31 +0000
Subject: [PATCH 1/2] [SCEV] Add initial pattern matching for SCEV constants.
(NFC)
Add initial pattern matching for SCEV constants. Follow-up patches will
add additional matchers for various SCEV expressions.
---
.../Analysis/ScalarEvolutionPatternMatch.h | 59 +++++++++++++++++++
llvm/lib/Analysis/ScalarEvolution.cpp | 13 ++--
2 files changed, 66 insertions(+), 6 deletions(-)
create mode 100644 llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
new file mode 100644
index 00000000000000..636b9f8e1544f7
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -0,0 +1,66 @@
+//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a simple and efficient mechanism for performing general
+// tree-based pattern matches on SCEVs, based on LLVM's IR pattern matchers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+
+namespace llvm {
+namespace SCEVPatternMatch {
+
+/// Match the SCEV \p S against the pattern \p P by forwarding to P.match(S).
+///
+/// NOTE(review): \p Val occurs in no function parameter, so it can never be
+/// deduced and this overload is only viable when \p Val is given explicitly.
+/// Plain calls such as match(S, P) presumably bind to PatternMatch::match via
+/// `using namespace PatternMatch` in the caller (TODO: confirm). If \p Val is
+/// removed, P.match must be callable on a const Pattern, since P is const &.
+template <typename Val, typename Pattern>
+bool match(const SCEV *S, const Pattern &P) {
+ return P.match(S);
+}
+
+/// Match a specified integer value. \p BitWidth optionally specifies the
+/// bitwidth the matched constant must have. If it is 0, the matched constant
+/// can have any bitwidth.
+template <unsigned BitWidth = 0> struct specific_intval {
+ APInt Val;
+
+ specific_intval(APInt V) : Val(std::move(V)) {}
+
+ bool match(const SCEV *S) const {
+ const auto *C = dyn_cast<SCEVConstant>(S);
+ if (!C)
+ return false;
+
+ if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth)
+ return false;
+ return APInt::isSameValue(C->getAPInt(), Val);
+ }
+};
+
+inline specific_intval<0> m_scev_Zero() {
+ return specific_intval<0>(APInt(64, 0));
+}
+inline specific_intval<0> m_scev_One() {
+ return specific_intval<0>(APInt(64, 1));
+}
+inline specific_intval<0> m_scev_MinusOne() {
+ return specific_intval<0>(APInt(64, -1));
+}
+
+} // namespace SCEVPatternMatch
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index cad10486cbf3fa..741431ac8aa158 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -79,6 +79,7 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
@@ -133,6 +134,7 @@
using namespace llvm;
using namespace PatternMatch;
+using namespace SCEVPatternMatch;
#define DEBUG_TYPE "scalar-evolution"
@@ -3423,9 +3425,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
return S;
// 0 udiv Y == 0
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
- if (LHSC->getValue()->isZero())
- return LHS;
+ if (match(LHS, m_scev_Zero()))
+ return LHS;
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
if (RHSC->getValue()->isOne())
@@ -10593,7 +10594,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// Get the initial value for the loop.
const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
- const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
if (!isLoopInvariant(Step, L))
return getCouldNotCompute();
@@ -10615,8 +10615,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// Handle unitary steps, which cannot wraparound.
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
// N = Distance (as unsigned)
- if (StepC &&
- (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
+
+ if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) {
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
@@ -10668,6 +10668,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
}
// Solve the general equation.
+ const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
if (!StepC || StepC->getValue()->isZero())
return getCouldNotCompute();
const SCEV *E = SolveLinEquationWithOverflow(
>From 5478696ae403428d0050d35bf52849fef56bdda5 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Dec 2024 20:46:33 +0000
Subject: [PATCH 2/2] !fixup build int matcher on top of
PatternMatch::specific_intval64
---
.../Analysis/ScalarEvolutionPatternMatch.h | 34 ++++++-------------
1 file changed, 10 insertions(+), 24 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 636b9f8e1544f7..a6658df32db31c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -15,6 +15,7 @@
#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/PatternMatch.h"
namespace llvm {
namespace SCEVPatternMatch {
@@ -24,34 +25,38 @@ bool match(const SCEV *S, const Pattern &P) {
return P.match(S);
}
-/// Match a specified integer value. \p BitWidth optionally specifies the
-/// bitwidth the matched constant must have. If it is 0, the matched constant
-/// can have any bitwidth.
-template <unsigned BitWidth = 0> struct specific_intval {
- APInt Val;
+/// APInt predicates for cst_pred_ty below. Each one is evaluated at the
+/// matched constant's own bitwidth, like ConstantInt::isZero/isOne/isMinusOne.
+struct is_zero {
+  bool isValue(const APInt &C) const { return C.isZero(); }
+};
+struct is_one {
+  bool isValue(const APInt &C) const { return C.isOne(); }
+};
+struct is_all_ones {
+  bool isValue(const APInt &C) const { return C.isAllOnes(); }
+};
- specific_intval(APInt V) : Val(std::move(V)) {}
-
- bool match(const SCEV *S) const {
- const auto *C = dyn_cast<SCEVConstant>(S);
- if (!C)
- return false;
-
- if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth)
- return false;
- return APInt::isSameValue(C->getAPInt(), Val);
+/// Match a SCEVConstant whose APInt value satisfies \p Predicate::isValue.
+/// The predicate is applied at the constant's own bitwidth, so the result
+/// does not depend on zero-extending the value to 64 bits.
+template <typename Predicate> struct cst_pred_ty : public Predicate {
+  bool match(const SCEV *S) const {
+    const auto *C = dyn_cast<SCEVConstant>(S);
+    return C && this->isValue(C->getAPInt());
}
};
-inline specific_intval<0> m_scev_Zero() {
- return specific_intval<0>(APInt(64, 0));
-}
-inline specific_intval<0> m_scev_One() {
- return specific_intval<0>(APInt(64, 1));
-}
-inline specific_intval<0> m_scev_MinusOne() {
- return specific_intval<0>(APInt(64, -1));
-}
+/// Match an integer 0 of any bitwidth.
+inline cst_pred_ty<is_zero> m_scev_Zero() { return cst_pred_ty<is_zero>(); }
+
+/// Match an integer 1 of any bitwidth.
+inline cst_pred_ty<is_one> m_scev_One() { return cst_pred_ty<is_one>(); }
+
+/// Match an integer -1 (all bits set) of any bitwidth, not just i64.
+inline cst_pred_ty<is_all_ones> m_scev_MinusOne() {
+  return cst_pred_ty<is_all_ones>();
+}
} // namespace SCEVPatternMatch
} // namespace llvm
More information about the llvm-commits
mailing list