[llvm] [STLExtras] Introduce bind_{front, back}, [not_]equal_to (PR #175056)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 11 01:48:36 PST 2026


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/175056

>From dea0bba6a034c61c0833311657499239d899a2f5 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 Jan 2026 15:40:53 +0000
Subject: [PATCH 1/4] [STLExtras] Introduce BinaryPredicateFunctor,
 [not_]equal_to

We introduce a BinaryPredicateFunctor that generalizes the match
functors in PatternMatch and VPlanPatternMatch, rewriting the functors
in terms of this more general functor. To demonstrate that
BinaryPredicateFunctor is widely applicable, we use it to shorten a
common idiom, in which an STL algorithm like all_of or any_of checks the
members of a range against a value: we introduce
llvm::{equal_to, not_equal_to} in terms of the general functor.
---
 llvm/include/llvm/ADT/STLExtras.h             | 24 +++++++++++++++++++
 llvm/include/llvm/IR/PatternMatch.h           | 15 ++++--------
 .../Transforms/Vectorize/VPlanPatternMatch.h  | 21 ++++++++--------
 llvm/unittests/ADT/STLExtrasTest.cpp          | 22 +++++++++++++++++
 4 files changed, 62 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 7b0304fd99463..339704dfeba66 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2157,6 +2157,30 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
   return all_equal<std::initializer_list<T>>(std::move(Values));
 }
 
+template <typename ValT, typename RefT, typename BinaryPredicate>
+struct BinaryPredicateFunctor {
+  RefT Ref;
+  const BinaryPredicate &P;
+  BinaryPredicateFunctor(RefT Ref, const BinaryPredicate &P) : Ref(Ref), P(P) {}
+  bool operator()(ValT Val) const { return P(Val, Ref); }
+};
+
+/// Functor variant of std::equal_to that can be used as a UnaryPredicate in
+/// functional algorithms like all_of.
+template <typename T>
+BinaryPredicateFunctor<T, T, decltype(std::equal_to<T>())>
+equal_to(const T &Ref) {
+  return {Ref, std::equal_to<T>()};
+}
+
+/// Functor variant of std::not_equal_to that can be used as a UnaryPredicate in
+/// functional algorithms like all_of.
+template <typename T>
+BinaryPredicateFunctor<T, T, decltype(std::not_equal_to<T>())>
+not_equal_to(const T &Ref) {
+  return {Ref, std::not_equal_to<T>()};
+}
+
 /// Provide a container algorithm similar to C++ Library Fundamentals v2's
 /// `erase_if` which is equivalent to:
 ///
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 88aef4a368f29..a3fd5e7cb1d80 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -46,21 +46,16 @@
 namespace llvm {
 namespace PatternMatch {
 
-template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
+template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
   return P.match(V);
 }
 
-template <typename Val, typename Pattern> struct MatchFunctor {
-  const Pattern &P;
-  MatchFunctor(const Pattern &P) : P(P) {}
-  bool operator()(Val *V) const { return P.match(V); }
-};
-
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
-template <typename Val = const Value, typename Pattern>
-MatchFunctor<Val, Pattern> match_fn(const Pattern &P) {
-  return P;
+template <typename Val = const Value *, typename Pattern>
+BinaryPredicateFunctor<Val, Pattern, decltype(match<Val, Pattern>)>
+match_fn(const Pattern &P) {
+  return {P, match<Val, Pattern>};
 }
 
 template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 3732d009b9537..0127cd7bad2bd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -19,7 +19,7 @@
 
 namespace llvm::VPlanPatternMatch {
 
-template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
+template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
   return P.match(V);
 }
 
@@ -32,17 +32,18 @@ template <typename Pattern> bool match(VPSingleDefRecipe *R, const Pattern &P) {
   return P.match(static_cast<const VPRecipeBase *>(R));
 }
 
-template <typename Val, typename Pattern> struct VPMatchFunctor {
-  const Pattern &P;
-  VPMatchFunctor(const Pattern &P) : P(P) {}
-  bool operator()(Val *V) const { return match(V, P); }
-};
-
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
-template <typename Val = VPUser, typename Pattern>
-VPMatchFunctor<Val, Pattern> match_fn(const Pattern &P) {
-  return P;
+template <typename Val, typename Pattern>
+BinaryPredicateFunctor<Val, Pattern, decltype(match<Val, Pattern>)>
+match_fn(const Pattern &P) {
+  return {P, match<Val, Pattern>};
+}
+
+template <typename Pattern>
+BinaryPredicateFunctor<VPUser *, Pattern, decltype(match<VPUser *, Pattern>)>
+match_fn(const Pattern &P) {
+  return {P, match<Pattern>};
 }
 
 template <typename Class> struct class_match {
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index c712df11d382e..62abbbd88f9f7 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -1055,6 +1055,28 @@ TEST(STLExtrasTest, to_address) {
   EXPECT_EQ(V1, llvm::to_address(V3));
 }
 
+TEST(STLExtrasTest, EqualToNotEqualTo) {
+  std::vector<int> V;
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(1)));
+
+  V.push_back(1);
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(2)));
+
+  V.push_back(1);
+  V.push_back(1);
+  EXPECT_TRUE(all_of(V, equal_to(1)));
+  EXPECT_TRUE(all_of(V, not_equal_to(2)));
+  EXPECT_TRUE(none_of(V, equal_to(2)));
+
+  V.push_back(2);
+  EXPECT_FALSE(all_of(V, equal_to(1)));
+  EXPECT_FALSE(all_of(V, not_equal_to(1)));
+  EXPECT_TRUE(any_of(V, equal_to(2)));
+  EXPECT_TRUE(any_of(V, not_equal_to(2)));
+}
+
 TEST(STLExtrasTest, partition_point) {
   std::vector<int> V = {1, 3, 5, 7, 9};
 

>From c9eeb03d899b0b66982cece37b5e91cfae2942f4 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 Jan 2026 20:44:12 +0000
Subject: [PATCH 2/4] [STLExtras] Simplify implementation

Co-authored-by: Jakub Kuderski <jakub at nod-labs.com>
---
 llvm/include/llvm/ADT/STLExtras.h             | 32 ++++++++++---------
 llvm/include/llvm/IR/PatternMatch.h           |  5 ++-
 .../Transforms/Vectorize/VPlanPatternMatch.h  | 13 ++------
 3 files changed, 22 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 339704dfeba66..65c72e0a76c19 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2157,28 +2157,30 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
   return all_equal<std::initializer_list<T>>(std::move(Values));
 }
 
-template <typename ValT, typename RefT, typename BinaryPredicate>
-struct BinaryPredicateFunctor {
-  RefT Ref;
-  const BinaryPredicate &P;
-  BinaryPredicateFunctor(RefT Ref, const BinaryPredicate &P) : Ref(Ref), P(P) {}
-  bool operator()(ValT Val) const { return P(Val, Ref); }
-};
+template <typename RefT, typename PredT>
+auto bind_first(const RefT &Ref, const PredT &Pred) {
+  return [&](auto &&...Val) {
+    return Pred(Ref, std::forward<decltype(Val)>(Val)...);
+  };
+}
+
+template <typename RefT, typename PredT>
+auto bind_last(const RefT &Ref, const PredT &Pred) {
+  return [&](auto &&...Val) {
+    return Pred(std::forward<decltype(Val)>(Val)..., Ref);
+  };
+}
 
 /// Functor variant of std::equal_to that can be used as a UnaryPredicate in
 /// functional algorithms like all_of.
-template <typename T>
-BinaryPredicateFunctor<T, T, decltype(std::equal_to<T>())>
-equal_to(const T &Ref) {
-  return {Ref, std::equal_to<T>()};
+template <typename T> auto equal_to(const T &Ref) {
+  return bind_first(Ref, std::equal_to<>());
 }
 
 /// Functor variant of std::not_equal_to that can be used as a UnaryPredicate in
 /// functional algorithms like all_of.
-template <typename T>
-BinaryPredicateFunctor<T, T, decltype(std::not_equal_to<T>())>
-not_equal_to(const T &Ref) {
-  return {Ref, std::not_equal_to<T>()};
+template <typename T> auto not_equal_to(const T &Ref) {
+  return bind_first(Ref, std::not_equal_to<>());
 }
 
 /// Provide a container algorithm similar to C++ Library Fundamentals v2's
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index a3fd5e7cb1d80..de216052d8427 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -53,9 +53,8 @@ template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
 template <typename Val = const Value *, typename Pattern>
-BinaryPredicateFunctor<Val, Pattern, decltype(match<Val, Pattern>)>
-match_fn(const Pattern &P) {
-  return {P, match<Val, Pattern>};
+auto match_fn(const Pattern &P) {
+  return bind_last(P, match<Val, Pattern>);
 }
 
 template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 0127cd7bad2bd..44d20c97cdc51 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -34,16 +34,9 @@ template <typename Pattern> bool match(VPSingleDefRecipe *R, const Pattern &P) {
 
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
-template <typename Val, typename Pattern>
-BinaryPredicateFunctor<Val, Pattern, decltype(match<Val, Pattern>)>
-match_fn(const Pattern &P) {
-  return {P, match<Val, Pattern>};
-}
-
-template <typename Pattern>
-BinaryPredicateFunctor<VPUser *, Pattern, decltype(match<VPUser *, Pattern>)>
-match_fn(const Pattern &P) {
-  return {P, match<Pattern>};
+template <typename Val = VPUser *, typename Pattern>
+auto match_fn(const Pattern &P) {
+  return bind_last<Pattern, decltype(match<Val, Pattern>)>(P, match);
 }
 
 template <typename Class> struct class_match {

>From e7ab4baf4af4534056825db9d13c1536005ca948 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 9 Jan 2026 10:20:02 +0000
Subject: [PATCH 3/4] [STL] Back-port bind_{front,back}

---
 llvm/include/llvm/ADT/STLExtras.h             | 25 ++++----------
 llvm/include/llvm/ADT/STLForwardCompat.h      | 33 +++++++++++++++++++
 llvm/include/llvm/IR/PatternMatch.h           |  6 ++--
 .../Transforms/Vectorize/VPlanPatternMatch.h  | 21 +++++++-----
 4 files changed, 56 insertions(+), 29 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 65c72e0a76c19..9b2ce68c02020 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2157,30 +2157,19 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
   return all_equal<std::initializer_list<T>>(std::move(Values));
 }
 
-template <typename RefT, typename PredT>
-auto bind_first(const RefT &Ref, const PredT &Pred) {
-  return [&](auto &&...Val) {
-    return Pred(Ref, std::forward<decltype(Val)>(Val)...);
-  };
-}
-
-template <typename RefT, typename PredT>
-auto bind_last(const RefT &Ref, const PredT &Pred) {
-  return [&](auto &&...Val) {
-    return Pred(std::forward<decltype(Val)>(Val)..., Ref);
-  };
-}
-
 /// Functor variant of std::equal_to that can be used as a UnaryPredicate in
 /// functional algorithms like all_of.
-template <typename T> auto equal_to(const T &Ref) {
-  return bind_first(Ref, std::equal_to<>());
+template <typename T>
+constexpr auto equal_to(const T &Arg) { // NOLINT(readability-identifier-naming)
+  return bind_front<std::declval<std::equal_to<T>()>>(Arg);
 }
 
 /// Functor variant of std::not_equal_to that can be used as a UnaryPredicate in
 /// functional algorithms like all_of.
-template <typename T> auto not_equal_to(const T &Ref) {
-  return bind_first(Ref, std::not_equal_to<>());
+template <typename T>
+constexpr auto
+not_equal_to(const T &Arg) { // NOLINT(readability-identifier-naming)
+  return bind_front<std::declval<std::not_equal_to<T>()>>(Arg);
 }
 
 /// Provide a container algorithm similar to C++ Library Fundamentals v2's
diff --git a/llvm/include/llvm/ADT/STLForwardCompat.h b/llvm/include/llvm/ADT/STLForwardCompat.h
index b975a403cd042..133f287a0c2ec 100644
--- a/llvm/include/llvm/ADT/STLForwardCompat.h
+++ b/llvm/include/llvm/ADT/STLForwardCompat.h
@@ -17,6 +17,7 @@
 #ifndef LLVM_ADT_STLFORWARDCOMPAT_H
 #define LLVM_ADT_STLFORWARDCOMPAT_H
 
+#include <functional>
 #include <optional>
 #include <type_traits>
 #include <utility>
@@ -177,4 +178,36 @@ struct from_range_t {
 inline constexpr from_range_t from_range{};
 } // namespace llvm
 
+//===----------------------------------------------------------------------===//
+//     Features from C++26
+//===----------------------------------------------------------------------===//
+
+template <auto Fn, typename... BindArgsT>
+constexpr auto
+bind_front(BindArgsT &&...BindArgs) { // NOLINT(readability-identifier-naming)
+  using FnT = decltype(Fn);
+
+  if constexpr (std::is_pointer_v<FnT> or std::is_member_pointer_v<FnT>)
+    static_assert(Fn != nullptr);
+
+  return [&BindArgs...](auto &&...CallArgs) {
+    return std::invoke(Fn, std::forward<BindArgsT>(BindArgs)...,
+                       std::forward<decltype(CallArgs)>(CallArgs)...);
+  };
+}
+
+template <auto Fn, typename... BindArgsT>
+constexpr auto
+bind_back(BindArgsT &&...BindArgs) { // NOLINT(readability-identifier-naming)
+  using FnT = decltype(Fn);
+
+  if constexpr (std::is_pointer_v<FnT> or std::is_member_pointer_v<FnT>)
+    static_assert(Fn != nullptr);
+
+  return [&BindArgs...](auto &&...CallArgs) {
+    return std::invoke(Fn, std::forward<decltype(CallArgs)>(CallArgs)...,
+                       std::forward<BindArgsT>(BindArgs)...);
+  };
+}
+
 #endif // LLVM_ADT_STLFORWARDCOMPAT_H
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index de216052d8427..9bed9f2207d79 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -46,15 +46,15 @@
 namespace llvm {
 namespace PatternMatch {
 
-template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
+template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
   return P.match(V);
 }
 
 /// A match functor that can be used as a UnaryPredicate in functional
 /// algorithms like all_of.
-template <typename Val = const Value *, typename Pattern>
+template <typename Val = const Value, typename Pattern>
 auto match_fn(const Pattern &P) {
-  return bind_last(P, match<Val, Pattern>);
+  return bind_back<match<Val, Pattern>>(P);
 }
 
 template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 44d20c97cdc51..03506d1fcdba1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -19,24 +19,29 @@
 
 namespace llvm::VPlanPatternMatch {
 
-template <typename Val, typename Pattern> bool match(Val V, const Pattern &P) {
+template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
   return P.match(V);
 }
 
+/// A match functor that can be used as a UnaryPredicate in functional
+/// algorithms like all_of.
+template <typename Val, typename Pattern>
+constexpr auto match_fn(const Pattern &P) {
+  return bind_back<match<Val, Pattern>>(P);
+}
+
 template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
   auto *R = dyn_cast<VPRecipeBase>(U);
   return R && match(R, P);
 }
 
-template <typename Pattern> bool match(VPSingleDefRecipe *R, const Pattern &P) {
-  return P.match(static_cast<const VPRecipeBase *>(R));
+/// Match functor for VPUser.
+template <typename Pattern> constexpr auto match_fn(const Pattern &P) {
+  return bind_back<match<Pattern>>(P);
 }
 
-/// A match functor that can be used as a UnaryPredicate in functional
-/// algorithms like all_of.
-template <typename Val = VPUser *, typename Pattern>
-auto match_fn(const Pattern &P) {
-  return bind_last<Pattern, decltype(match<Val, Pattern>)>(P, match);
+template <typename Pattern> bool match(VPSingleDefRecipe *R, const Pattern &P) {
+  return P.match(static_cast<const VPRecipeBase *>(R));
 }
 
 template <typename Class> struct class_match {

>From 41d9dc0db3b0a0efb4b61e36f3968196e6393b71 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <r at artagnon.com>
Date: Sun, 11 Jan 2026 07:13:30 +0000
Subject: [PATCH 4/4] [ADT] Fix build by using Fn param

I've implemented a modified version of Yanzuo's suggestion, with a note
about perfect-forwarding and deducing this (C++23).

Ref: https://eel.is/c++draft/expr.prim.lambda#17
Ref: https://www.open-std.org/jtc1/sc22/wg21/docs/cwg_defects.html#2011
Ref: https://stackoverflow.com/q/55559308/1088790
Ref: https://vittorioromeo.info/index/blog/capturing_perfectly_forwarded_objects_in_lambdas.html

Co-authored-by: Yanzuo Liu <zwuis at outlook.com>
---
 llvm/include/llvm/ADT/STLExtras.h             |  4 +-
 llvm/include/llvm/ADT/STLForwardCompat.h      | 52 ++++++++++---------
 llvm/include/llvm/IR/PatternMatch.h           |  2 +-
 .../Transforms/Vectorize/VPlanPatternMatch.h  |  4 +-
 llvm/unittests/ADT/STLExtrasTest.cpp          |  2 +-
 llvm/unittests/ADT/STLForwardCompatTest.cpp   | 35 +++++++++++++
 6 files changed, 68 insertions(+), 31 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 9b2ce68c02020..142bd43aa407c 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -2161,7 +2161,7 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
 /// functional algorithms like all_of.
 template <typename T>
 constexpr auto equal_to(const T &Arg) { // NOLINT(readability-identifier-naming)
-  return bind_front<std::declval<std::equal_to<T>()>>(Arg);
+  return bind_front(std::equal_to<>(), Arg);
 }
 
 /// Functor variant of std::not_equal_to that can be used as a UnaryPredicate in
@@ -2169,7 +2169,7 @@ constexpr auto equal_to(const T &Arg) { // NOLINT(readability-identifier-naming)
 template <typename T>
 constexpr auto
 not_equal_to(const T &Arg) { // NOLINT(readability-identifier-naming)
-  return bind_front<std::declval<std::not_equal_to<T>()>>(Arg);
+  return bind_front(std::not_equal_to<>(), Arg);
 }
 
 /// Provide a container algorithm similar to C++ Library Fundamentals v2's
diff --git a/llvm/include/llvm/ADT/STLForwardCompat.h b/llvm/include/llvm/ADT/STLForwardCompat.h
index 133f287a0c2ec..861fed72d2dc5 100644
--- a/llvm/include/llvm/ADT/STLForwardCompat.h
+++ b/llvm/include/llvm/ADT/STLForwardCompat.h
@@ -148,6 +148,23 @@ template <class T> constexpr T *to_address(T *P) {
   return P;
 }
 
+// The behavior is the same as std::bind_front, with the following differences:
+// - BindArgs are captured by reference, not copied, so the returned lambda
+//   must not outlive them (e.g. a stored bind_front(Fn, 1) dangles).
+// - The value category and constness of the returned lambda are ignored when
+//   calling it; deducing this and std::forward_like (C++23) could handle them.
+template <typename FnT, typename... BindArgsT>
+constexpr auto bind_front(FnT &&Fn, // NOLINT(readability-identifier-naming)
+                          BindArgsT &&...BindArgs) {
+  if constexpr (std::is_pointer_v<FnT> or std::is_member_pointer_v<FnT>)
+    static_assert(Fn != nullptr);
+
+  return [&BindArgs..., Fn = std::forward<FnT>(Fn)](auto &&...CallArgs) {
+    return std::invoke(Fn, std::forward<BindArgsT>(BindArgs)...,
+                       std::forward<decltype(CallArgs)>(CallArgs)...);
+  };
+}
+
 //===----------------------------------------------------------------------===//
 //     Features from C++23
 //===----------------------------------------------------------------------===//
@@ -176,38 +193,23 @@ struct from_range_t {
   explicit from_range_t() = default;
 };
 inline constexpr from_range_t from_range{};
-} // namespace llvm
-
-//===----------------------------------------------------------------------===//
-//     Features from C++26
-//===----------------------------------------------------------------------===//
-
-template <auto Fn, typename... BindArgsT>
-constexpr auto
-bind_front(BindArgsT &&...BindArgs) { // NOLINT(readability-identifier-naming)
-  using FnT = decltype(Fn);
 
+// The behavior is the same as std::bind_back, with the following differences:
+// - BindArgs are captured by reference, not copied, so the returned lambda
+//   must not outlive them (e.g. a stored bind_back(Fn, 1) dangles).
+// - The value category and constness of the returned lambda are ignored when
+//   calling it; deducing this and std::forward_like (C++23) could handle them.
+template <typename FnT, typename... BindArgsT>
+constexpr auto bind_back(FnT &&Fn, // NOLINT(readability-identifier-naming)
+                         BindArgsT &&...BindArgs) {
   if constexpr (std::is_pointer_v<FnT> or std::is_member_pointer_v<FnT>)
     static_assert(Fn != nullptr);
 
-  return [&BindArgs...](auto &&...CallArgs) {
-    return std::invoke(Fn, std::forward<BindArgsT>(BindArgs)...,
-                       std::forward<decltype(CallArgs)>(CallArgs)...);
-  };
-}
-
-template <auto Fn, typename... BindArgsT>
-constexpr auto
-bind_back(BindArgsT &&...BindArgs) { // NOLINT(readability-identifier-naming)
-  using FnT = decltype(Fn);
-
-  if constexpr (std::is_pointer_v<FnT> or std::is_member_pointer_v<FnT>)
-    static_assert(Fn != nullptr);
-
-  return [&BindArgs...](auto &&...CallArgs) {
+  return [&BindArgs..., Fn = std::forward<FnT>(Fn)](auto &&...CallArgs) {
     return std::invoke(Fn, std::forward<decltype(CallArgs)>(CallArgs)...,
                        std::forward<BindArgsT>(BindArgs)...);
   };
 }
+} // namespace llvm
 
 #endif // LLVM_ADT_STLFORWARDCOMPAT_H
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 9bed9f2207d79..772873e428471 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -54,7 +54,7 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
 /// algorithms like all_of.
 template <typename Val = const Value, typename Pattern>
 auto match_fn(const Pattern &P) {
-  return bind_back<match<Val, Pattern>>(P);
+  return bind_back(match<Val, Pattern>, P);
 }
 
 template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 03506d1fcdba1..81fb24f8fda8a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -27,7 +27,7 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
 /// algorithms like all_of.
 template <typename Val, typename Pattern>
 constexpr auto match_fn(const Pattern &P) {
-  return bind_back<match<Val, Pattern>>(P);
+  return bind_back(match<Val, Pattern>, P);
 }
 
 template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
@@ -37,7 +37,7 @@ template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
 
 /// Match functor for VPUser.
 template <typename Pattern> constexpr auto match_fn(const Pattern &P) {
-  return bind_back<match<Pattern>>(P);
+  return bind_back(match<Pattern>, P);
 }
 
 template <typename Pattern> bool match(VPSingleDefRecipe *R, const Pattern &P) {
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 62abbbd88f9f7..fe71945e4a794 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -1055,7 +1055,7 @@ TEST(STLExtrasTest, to_address) {
   EXPECT_EQ(V1, llvm::to_address(V3));
 }
 
-TEST(STLExtrasTest, EqualToNotEqualTo) {
+TEST(STLExtras, EqualToNotEqualTo) {
   std::vector<int> V;
   EXPECT_TRUE(all_of(V, equal_to(1)));
   EXPECT_TRUE(all_of(V, not_equal_to(1)));
diff --git a/llvm/unittests/ADT/STLForwardCompatTest.cpp b/llvm/unittests/ADT/STLForwardCompatTest.cpp
index d0092fdb52b01..7fc60e4c2e6a8 100644
--- a/llvm/unittests/ADT/STLForwardCompatTest.cpp
+++ b/llvm/unittests/ADT/STLForwardCompatTest.cpp
@@ -14,6 +14,8 @@
 #include <type_traits>
 #include <utility>
 
+using namespace llvm;
+
 namespace {
 
 template <typename T>
@@ -205,4 +207,37 @@ TEST(STLForwardCompatTest, IdentityCxx20) {
   static_assert(std::is_same_v<int &&, decltype(identity(int(5)))>);
 }
 
+TEST(STLForwardCompat, BindFrontBindBack) {
+  std::vector<int> V;
+  auto MulAdd = [](int A, int B, int C) { return A * (B + C) == 12; };
+  auto Mul0 = bind_back(MulAdd, 4, 2);
+  auto Mul1 = bind_front(MulAdd, 2, 4);
+  auto Mul20 = bind_back(MulAdd, 4);
+  auto Mul21 = bind_front(MulAdd, 2);
+  EXPECT_TRUE(all_of(V, Mul0));
+  EXPECT_TRUE(all_of(V, Mul1));
+
+  V.push_back(2);
+  EXPECT_TRUE(all_of(V, Mul0));
+  EXPECT_TRUE(all_of(V, Mul1));
+
+  V.push_back(2);
+  V.push_back(2);
+  EXPECT_TRUE(all_of(V, Mul0));
+  EXPECT_TRUE(all_of(V, Mul1));
+
+  auto Spec0 = bind_front(Mul20, 2);
+  auto Spec1 = bind_back(Mul21, 4);
+  EXPECT_TRUE(all_of(V, Spec0));
+  EXPECT_TRUE(all_of(V, Spec1));
+
+  V.push_back(3);
+  EXPECT_FALSE(all_of(V, Mul0));
+  EXPECT_FALSE(all_of(V, Mul1));
+  EXPECT_FALSE(all_of(V, Spec0));
+  EXPECT_FALSE(all_of(V, Spec1));
+  EXPECT_TRUE(any_of(V, Spec0));
+  EXPECT_TRUE(any_of(V, Spec1));
+}
+
 } // namespace



More information about the llvm-commits mailing list