[llvm] [STLExtras] Introduce BinaryPredicateFunctor, [not_]equal_to (PR #175056)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 8 11:33:48 PST 2026


https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/175056

We introduce a BinaryPredicateFunctor that generalizes the match functors in PatternMatch and VPlanPatternMatch, and rewrite those functors in terms of it. To demonstrate that BinaryPredicateFunctor is more widely applicable, we also use it to shorten a common idiom: calling an STL algorithm such as all_of or any_of to compare each member of a range against a single value. For this, we introduce llvm::{equal_to, not_equal_to} in terms of the general functor. Note that the Val template parameter of match_fn now names the pointer type (e.g. const Value *, VPUser *) rather than the pointee type.

>From 572443388cb4deb545e055c5980a9b1b4f89853d Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 Jan 2026 15:40:53 +0000
Subject: [PATCH] [STLExtras] Introduce BinaryPredicateFunctor, [not_]equal_to

We introduce a BinaryPredicateFunctor that generalizes the match
functors in PatternMatch and VPlanPatternMatch, and rewrite those
functors in terms of it. To demonstrate that BinaryPredicateFunctor is
more widely applicable, we also use it to shorten a common idiom:
calling an STL algorithm such as all_of or any_of to compare each
member of a range against a single value. For this, we introduce
llvm::{equal_to, not_equal_to} in terms of the general functor.

Note that the Val template parameter of match_fn now names the pointer
type (e.g. const Value *, VPUser *) rather than the pointee type.
---
 llvm/include/llvm/ADT/STLExtras.h             | 24 +++++++++++++++++++
 llvm/include/llvm/IR/PatternMatch.h           | 15 ++++--------
 .../Transforms/Vectorize/VPlanPatternMatch.h  | 21 ++++++++--------
 llvm/unittests/ADT/STLExtrasTest.cpp          | 22 +++++++++++++++++
 4 files changed, 62 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 7b0304fd99463..339704dfeba66 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2157,6 +2157,30 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
   return all_equal<std::initializer_list<T>>(std::move(Values));
 }
 
+/// Adapts the binary predicate P into a unary predicate returning
+/// P(Val, Ref); P is held by value, so a temporary predicate cannot dangle.
+template <typename ValT, typename RefT, typename BinaryPredicate>
+struct BinaryPredicateFunctor {
+  RefT Ref;
+  std::decay_t<BinaryPredicate> P; // Function types decay to pointers.
+  BinaryPredicateFunctor(RefT Ref, const BinaryPredicate &P) : Ref(Ref), P(P) {}
+  bool operator()(const ValT &Val) const { return P(Val, Ref); }
+};
+
+/// Functor variant of std::equal_to that can be used as a UnaryPredicate in
+/// functional algorithms like all_of.
+template <typename T>
+BinaryPredicateFunctor<T, T, std::equal_to<T>> equal_to(const T &Ref) {
+  return {Ref, std::equal_to<T>()};
+}
+
+/// Functor variant of std::not_equal_to that can be used as a UnaryPredicate in
+/// functional algorithms like all_of.
+template <typename T>
+BinaryPredicateFunctor<T, T, std::not_equal_to<T>> not_equal_to(const T &Ref) {
+  return {Ref, std::not_equal_to<T>()};
+}
+
 /// Provide a container algorithm similar to C++ Library Fundamentals v2's
 /// `erase_if` which is equivalent to:
 ///
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 88aef4a368f29..a3fd5e7cb1d80 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -46,21 +46,16 @@
 namespace llvm {
 namespace PatternMatch {
 
-template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
+template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
   return P.match(V);
 }
 
-template <typename Val, typename Pattern> struct MatchFunctor {
-  const Pattern &P;
-  MatchFunctor(const Pattern &P) : P(P) {}
-  bool operator()(Val *V) const { return P.match(V); }
-};
-
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
-template <typename Val = const Value, typename Pattern>
-MatchFunctor<Val, Pattern> match_fn(const Pattern &P) {
-  return P;
+template <typename Val = const Value *, typename Pattern>
+BinaryPredicateFunctor<Val, Pattern, decltype(match<Val, Pattern>)>
+match_fn(const Pattern &P) { // Val is the pointer type, e.g. Value *.
+  return {P, match<Val, Pattern>};
 }
 
 template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 3732d009b9537..0127cd7bad2bd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -19,7 +19,7 @@
 
 namespace llvm::VPlanPatternMatch {
 
-template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
+template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
   return P.match(V);
 }
 
@@ -32,17 +32,18 @@ template <typename Pattern> bool match(VPSingleDefRecipe *R, const Pattern &P) {
   return P.match(static_cast<const VPRecipeBase *>(R));
 }
 
-template <typename Val, typename Pattern> struct VPMatchFunctor {
-  const Pattern &P;
-  VPMatchFunctor(const Pattern &P) : P(P) {}
-  bool operator()(Val *V) const { return match(V, P); }
-};
-
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
-template <typename Val = VPUser, typename Pattern>
-VPMatchFunctor<Val, Pattern> match_fn(const Pattern &P) {
-  return P;
+template <typename Val, typename Pattern>
+BinaryPredicateFunctor<Val, Pattern, decltype(match<Val, Pattern>)>
+match_fn(const Pattern &P) { // Lambda keeps match()'s overload resolution.
+  return {P, *+[](Val V, const Pattern &Pat) { return match(V, Pat); }};
+}
+
+template <typename Pattern>
+BinaryPredicateFunctor<VPUser *, Pattern, decltype(match<VPUser *, Pattern>)>
+match_fn(const Pattern &P) { // Val defaults to VPUser *.
+  return match_fn<VPUser *>(P);
 }
 
 template <typename Class> struct class_match {
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index c712df11d382e..62abbbd88f9f7 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -1055,6 +1055,28 @@ TEST(STLExtrasTest, to_address) {
   EXPECT_EQ(V1, llvm::to_address(V3));
 }
 
+TEST(STLExtrasTest, EqualToNotEqualTo) {
+  std::vector<int> V; // Empty: all_of is vacuously true.
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(1)));
+
+  V.push_back(1); // V = {1}.
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(2)));
+
+  V.push_back(1);
+  V.push_back(1); // V = {1, 1, 1}.
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(2)));
+  EXPECT_TRUE(none_of(V, equal_to(2)));
+
+  V.push_back(2); // V = {1, 1, 1, 2}: mixed values.
+  EXPECT_FALSE(all_of(V, equal_to(1)));
+  EXPECT_FALSE(all_of(V, not_equal_to(1)));
+  EXPECT_TRUE(any_of(V, equal_to(2)));
+  EXPECT_TRUE(any_of(V, not_equal_to(2)));
+}
+
 TEST(STLExtrasTest, partition_point) {
   std::vector<int> V = {1, 3, 5, 7, 9};
 



More information about the llvm-commits mailing list