[llvm] [STLExtras] Introduce BinaryPredicateFunctor, [not_]equal_to (PR #175056)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 8 11:34:21 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-adt

@llvm/pr-subscribers-llvm-ir

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

We introduce a BinaryPredicateFunctor that generalizes the match functors in PatternMatch and VPlanPatternMatch, and rewrite those functors in terms of this more general functor. To demonstrate that BinaryPredicateFunctor is widely applicable, we use it to shorten a common idiom — applying an STL algorithm such as all_of or any_of to check each member of a range against a single value — by introducing llvm::{equal_to, not_equal_to} in terms of the general functor.

---
Full diff: https://github.com/llvm/llvm-project/pull/175056.diff


4 Files Affected:

- (modified) llvm/include/llvm/ADT/STLExtras.h (+24) 
- (modified) llvm/include/llvm/IR/PatternMatch.h (+5-10) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+11-10) 
- (modified) llvm/unittests/ADT/STLExtrasTest.cpp (+22) 


``````````diff
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 7b0304fd99463..339704dfeba66 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2157,6 +2157,35 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
   return all_equal<std::initializer_list<T>>(std::move(Values));
 }
 
+/// Adapts the binary predicate \p P into a unary predicate by fixing its
+/// second operand to \p Ref: invoking the functor on Val returns P(Val, Ref).
+///
+/// The predicate is stored by value (a function type decays to a function
+/// pointer), so the functor remains valid after the full-expression that
+/// created it. Holding a reference instead would dangle whenever \p P is a
+/// temporary, as it is in equal_to and not_equal_to below.
+template <typename ValT, typename RefT, typename BinaryPredicate>
+struct BinaryPredicateFunctor {
+  RefT Ref;
+  std::decay_t<BinaryPredicate> P;
+  BinaryPredicateFunctor(RefT Ref, const BinaryPredicate &P) : Ref(Ref), P(P) {}
+  bool operator()(ValT Val) const { return P(Val, Ref); }
+};
+
+/// Functor variant of std::equal_to that can be used as a UnaryPredicate in
+/// functional algorithms like all_of; it returns true iff Val == Ref.
+template <typename T>
+BinaryPredicateFunctor<T, T, std::equal_to<T>> equal_to(const T &Ref) {
+  return {Ref, std::equal_to<T>()};
+}
+
+/// Functor variant of std::not_equal_to that can be used as a UnaryPredicate in
+/// functional algorithms like all_of; it returns true iff Val != Ref.
+template <typename T>
+BinaryPredicateFunctor<T, T, std::not_equal_to<T>> not_equal_to(const T &Ref) {
+  return {Ref, std::not_equal_to<T>()};
+}
+
 /// Provide a container algorithm similar to C++ Library Fundamentals v2's
 /// `erase_if` which is equivalent to:
 ///
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 88aef4a368f29..a3fd5e7cb1d80 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -46,21 +46,19 @@
 namespace llvm {
 namespace PatternMatch {
 
-template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
+template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
   return P.match(V);
 }
 
-template <typename Val, typename Pattern> struct MatchFunctor {
-  const Pattern &P;
-  MatchFunctor(const Pattern &P) : P(P) {}
-  bool operator()(Val *V) const { return P.match(V); }
-};
-
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
-template <typename Val = const Value, typename Pattern>
-MatchFunctor<Val, Pattern> match_fn(const Pattern &P) {
-  return P;
+///
+/// \p Val is the complete argument type (by default `const Value *`), not the
+/// pointee type, and the returned functor stores its own copy of \p P.
+template <typename Val = const Value *, typename Pattern>
+BinaryPredicateFunctor<Val, Pattern, decltype(match<Val, Pattern>)>
+match_fn(const Pattern &P) {
+  return {P, match<Val, Pattern>};
 }
 
 template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 3732d009b9537..0127cd7bad2bd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -19,7 +19,7 @@
 
 namespace llvm::VPlanPatternMatch {
 
-template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
+template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
   return P.match(V);
 }
 
@@ -32,17 +32,25 @@ template <typename Pattern> bool match(VPSingleDefRecipe *R, const Pattern &P) {
   return P.match(static_cast<const VPRecipeBase *>(R));
 }
 
-template <typename Val, typename Pattern> struct VPMatchFunctor {
-  const Pattern &P;
-  VPMatchFunctor(const Pattern &P) : P(P) {}
-  bool operator()(Val *V) const { return match(V, P); }
-};
-
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
-template <typename Val = VPUser, typename Pattern>
-VPMatchFunctor<Val, Pattern> match_fn(const Pattern &P) {
-  return P;
+///
+/// \p Val is the complete argument type (e.g. `VPValue *`). The bound
+/// predicate is the most specialized match() overload whose first parameter
+/// type is exactly \p Val, as a direct call match(V, P) would pick for such an
+/// argument, so overloads like the VPSingleDefRecipe one above are honored
+/// rather than always binding the generic match<Val, Pattern>.
+template <typename Val, typename Pattern>
+BinaryPredicateFunctor<Val, Pattern, bool(Val, const Pattern &)>
+match_fn(const Pattern &P) {
+  return {P, static_cast<bool (&)(Val, const Pattern &)>(match)};
+}
+
+/// Convenience overload for the common case of matching VPUser pointers.
+template <typename Pattern>
+BinaryPredicateFunctor<VPUser *, Pattern, bool(VPUser *, const Pattern &)>
+match_fn(const Pattern &P) {
+  return match_fn<VPUser *>(P);
 }
 
 template <typename Class> struct class_match {
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index c712df11d382e..62abbbd88f9f7 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -1055,6 +1055,36 @@ TEST(STLExtrasTest, to_address) {
   EXPECT_EQ(V1, llvm::to_address(V3));
 }
 
+TEST(STLExtrasTest, EqualToNotEqualTo) {
+  std::vector<int> V;
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(1)));
+
+  V.push_back(1);
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(2)));
+
+  V.push_back(1);
+  V.push_back(1);
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(2)));
+  EXPECT_TRUE(none_of(V, equal_to(2)));
+
+  V.push_back(2);
+  EXPECT_FALSE(all_of(V, equal_to(1)));
+  EXPECT_FALSE(all_of(V, not_equal_to(1)));
+  EXPECT_TRUE(any_of(V, equal_to(2)));
+  EXPECT_TRUE(any_of(V, not_equal_to(2)));
+
+  // The returned functors are ordinary callables and can be invoked directly.
+  auto IsTwo = equal_to(2);
+  auto IsNotTwo = not_equal_to(2);
+  EXPECT_TRUE(IsTwo(2));
+  EXPECT_FALSE(IsTwo(1));
+  EXPECT_FALSE(IsNotTwo(2));
+  EXPECT_TRUE(IsNotTwo(1));
+}
+
 TEST(STLExtrasTest, partition_point) {
   std::vector<int> V = {1, 3, 5, 7, 9};
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/175056


More information about the llvm-commits mailing list