[llvm] [SCEV] Add initial pattern matching for SCEV constants. (NFC) (PR #119389)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 11 05:10:51 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/119389

>From 5a8ba1cddbfabb07a3aa591966991131f0051fcf Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Dec 2024 14:28:31 +0000
Subject: [PATCH 1/3] [SCEV] Add initial pattern matching for SCEV constants.
 (NFC)

Add initial pattern matchers for SCEV constants and use them in place of
some manual dyn_cast<SCEVConstant> checks in ScalarEvolution.cpp.
Follow-up patches will add matchers for other kinds of SCEV expressions.
---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 59 +++++++++++++++++++
 llvm/lib/Analysis/ScalarEvolution.cpp         | 13 ++--
 2 files changed, 66 insertions(+), 6 deletions(-)
 create mode 100644 llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
new file mode 100644
index 00000000000000..636b9f8e1544f7
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -0,0 +1,59 @@
+//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a simple and efficient mechanism for performing general
+// tree-based pattern matches on SCEVs, based on LLVM's IR pattern matchers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+
+namespace llvm {
+namespace SCEVPatternMatch {
+
+template <typename Val, typename Pattern>
+bool match(const SCEV *S, const Pattern &P) {
+  return P.match(S);
+}
+
+/// Match a specified integer value. \p BitWidth optionally specifies the
+/// bitwidth the matched constant must have. If it is 0, the matched constant
+/// can have any bitwidth.
+template <unsigned BitWidth = 0> struct specific_intval {
+  APInt Val;
+
+  specific_intval(APInt V) : Val(std::move(V)) {}
+
+  bool match(const SCEV *S) const {
+    const auto *C = dyn_cast<SCEVConstant>(S);
+    if (!C)
+      return false;
+
+    if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth)
+      return false;
+    return APInt::isSameValue(C->getAPInt(), Val);
+  }
+};
+
+inline specific_intval<0> m_scev_Zero() {
+  return specific_intval<0>(APInt(64, 0));
+}
+inline specific_intval<0> m_scev_One() {
+  return specific_intval<0>(APInt(64, 1));
+}
+inline specific_intval<0> m_scev_MinusOne() {
+  return specific_intval<0>(APInt(64, -1));
+}
+
+} // namespace SCEVPatternMatch
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index cad10486cbf3fa..741431ac8aa158 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -79,6 +79,7 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Config/llvm-config.h"
@@ -133,6 +134,7 @@
 
 using namespace llvm;
 using namespace PatternMatch;
+using namespace SCEVPatternMatch;
 
 #define DEBUG_TYPE "scalar-evolution"
 
@@ -3423,9 +3425,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
     return S;
 
   // 0 udiv Y == 0
-  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
-    if (LHSC->getValue()->isZero())
-      return LHS;
+  if (match(LHS, m_scev_Zero()))
+    return LHS;
 
   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
     if (RHSC->getValue()->isOne())
@@ -10593,7 +10594,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   // Get the initial value for the loop.
   const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
   const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
-  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
 
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
@@ -10615,8 +10615,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   // Handle unitary steps, which cannot wraparound.
   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
   //   N = Distance (as unsigned)
-  if (StepC &&
-      (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
+
+  if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) {
     APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
@@ -10668,6 +10668,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   }
 
   // Solve the general equation.
+  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
   if (!StepC || StepC->getValue()->isZero())
     return getCouldNotCompute();
   const SCEV *E = SolveLinEquationWithOverflow(

>From 7eed2650e304851416cff575ff9e35f1521a01e6 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Dec 2024 20:46:33 +0000
Subject: [PATCH 2/3] !fixup build int matcher on top of
 PatternMatch::specific_intval64

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 34 ++++++-------------
 1 file changed, 10 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 636b9f8e1544f7..a6658df32db31c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -15,6 +15,7 @@
 #define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
 
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/PatternMatch.h"
 
 namespace llvm {
 namespace SCEVPatternMatch {
@@ -24,34 +25,19 @@ bool match(const SCEV *S, const Pattern &P) {
   return P.match(S);
 }
 
-/// Match a specified integer value. \p BitWidth optionally specifies the
-/// bitwidth the matched constant must have. If it is 0, the matched constant
-/// can have any bitwidth.
-template <unsigned BitWidth = 0> struct specific_intval {
-  APInt Val;
+struct specific_intval64 : public PatternMatch::specific_intval64<false> {
+  specific_intval64(uint64_t V) : PatternMatch::specific_intval64<false>(V) {}
 
-  specific_intval(APInt V) : Val(std::move(V)) {}
-
-  bool match(const SCEV *S) const {
-    const auto *C = dyn_cast<SCEVConstant>(S);
-    if (!C)
-      return false;
-
-    if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth)
-      return false;
-    return APInt::isSameValue(C->getAPInt(), Val);
+  bool match(const SCEV *S) {
+    auto *Cast = dyn_cast<SCEVConstant>(S);
+    return Cast &&
+           PatternMatch::specific_intval64<false>::match(Cast->getValue());
   }
 };
 
-inline specific_intval<0> m_scev_Zero() {
-  return specific_intval<0>(APInt(64, 0));
-}
-inline specific_intval<0> m_scev_One() {
-  return specific_intval<0>(APInt(64, 1));
-}
-inline specific_intval<0> m_scev_MinusOne() {
-  return specific_intval<0>(APInt(64, -1));
-}
+inline specific_intval64 m_scev_Zero() { return specific_intval64(0); }
+inline specific_intval64 m_scev_One() { return specific_intval64(1); }
+inline specific_intval64 m_scev_MinusOne() { return specific_intval64(-1); }
 
 } // namespace SCEVPatternMatch
 } // namespace llvm

>From 3f91da915c9a7fb57241c935387b8192aa490629 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 11 Dec 2024 11:59:25 +0000
Subject: [PATCH 3/3] !fixup use predicate-based matchers for SCEV constants

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 42 +++++++++++++------
 llvm/lib/Analysis/ScalarEvolution.cpp         | 27 +++---------
 2 files changed, 36 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index a6658df32db31c..b4adfe2c2d8a49 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -1,5 +1,4 @@
-//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===//
-//
+//===----------------------------------------------------------------------===//
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -15,7 +14,6 @@
 #define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
 
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
-#include "llvm/IR/PatternMatch.h"
 
 namespace llvm {
 namespace SCEVPatternMatch {
@@ -25,19 +23,39 @@ bool match(const SCEV *S, const Pattern &P) {
   return P.match(S);
 }
 
-struct specific_intval64 : public PatternMatch::specific_intval64<false> {
-  specific_intval64(uint64_t V) : PatternMatch::specific_intval64<false>(V) {}
-
+template <typename Predicate> struct cst_pred_ty : public Predicate {
   bool match(const SCEV *S) {
-    auto *Cast = dyn_cast<SCEVConstant>(S);
-    return Cast &&
-           PatternMatch::specific_intval64<false>::match(Cast->getValue());
+    assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+           "no vector types expected from SCEVs");
+    auto *C = dyn_cast<SCEVConstant>(S);
+    return C && this->isValue(C->getAPInt());
+  }
+};
+
+struct is_zero {
+  template <typename ITy> bool match(ITy *S) {
+    assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+           "no vector types expected from SCEVs");
+    auto *C = dyn_cast<SCEVConstant>(S);
+    return C && C->getValue()->isNullValue();
   }
 };
+/// Match any null constant.
+inline is_zero m_scev_Zero() { return is_zero(); }
+
+struct is_one {
+  bool isValue(const APInt &C) { return C.isOne(); }
+};
+/// Match an integer 1.
+inline cst_pred_ty<is_one> m_scev_One() { return cst_pred_ty<is_one>(); }
 
-inline specific_intval64 m_scev_Zero() { return specific_intval64(0); }
-inline specific_intval64 m_scev_One() { return specific_intval64(1); }
-inline specific_intval64 m_scev_MinusOne() { return specific_intval64(-1); }
+struct is_all_ones {
+  bool isValue(const APInt &C) { return C.isAllOnes(); }
+};
+/// Match an integer with all bits set.
+inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
+  return cst_pred_ty<is_all_ones>();
+}
 
 } // namespace SCEVPatternMatch
 } // namespace llvm
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 741431ac8aa158..e18133971f5bf0 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -445,23 +445,11 @@ ArrayRef<const SCEV *> SCEV::operands() const {
   llvm_unreachable("Unknown SCEV kind!");
 }
 
-bool SCEV::isZero() const {
-  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
-    return SC->getValue()->isZero();
-  return false;
-}
+bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
 
-bool SCEV::isOne() const {
-  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
-    return SC->getValue()->isOne();
-  return false;
-}
+bool SCEV::isOne() const { return match(this, m_scev_One()); }
 
-bool SCEV::isAllOnesValue() const {
-  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
-    return SC->getValue()->isMinusOne();
-  return false;
-}
+bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
 
 bool SCEV::isNonConstantNegative() const {
   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
@@ -10616,7 +10604,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
   //   N = Distance (as unsigned)
 
-  if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) {
+  if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
     APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
@@ -15511,9 +15499,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
 
     // If we have LHS == 0, check if LHS is computing a property of some unknown
     // SCEV %v which we can rewrite %v to express explicitly.
-    const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
-    if (Predicate == CmpInst::ICMP_EQ && RHSC &&
-        RHSC->getValue()->isNullValue()) {
+    if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
       // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
       // explicitly express that.
       const SCEV *URemLHS = nullptr;
@@ -15694,8 +15680,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
           To = RHS;
         break;
       case CmpInst::ICMP_NE:
-        if (isa<SCEVConstant>(RHS) &&
-            cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
+        if (match(RHS, m_scev_Zero())) {
           const SCEV *OneAlignedUp =
               DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
           To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);



More information about the llvm-commits mailing list