[llvm] [ADT] Add `sum_of` and `product_of` accumulate wrappers. (PR #162129)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 6 10:52:06 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-adt

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

Also extend the `accumulate` wrapper to accept a binary operator.

The goal is to cover the most common usages of `std::accumulate` across the codebase -- calculating either the sum or the product of all values.

---
Full diff: https://github.com/llvm/llvm-project/pull/162129.diff


2 Files Affected:

- (modified) llvm/include/llvm/ADT/STLExtras.h (+22) 
- (modified) llvm/unittests/ADT/STLExtrasTest.cpp (+39) 


``````````diff
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 4a91b061dd3b7..5b20d6bd38262 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1692,6 +1692,28 @@ template <typename R, typename E> auto accumulate(R &&Range, E &&Init) {
                          std::forward<E>(Init));
 }
 
+/// Wrapper for std::accumulate with a binary operator.
+template <typename R, typename E, typename BinaryOp>
+auto accumulate(R &&Range, E &&Init, BinaryOp &&Op) {
+  return std::accumulate(adl_begin(Range), adl_end(Range),
+                         std::forward<E>(Init), std::forward<BinaryOp>(Op));
+}
+
+/// Returns the sum of all values in `Range` with `Init` initial value.
+/// The default initial value is 0.
+template <typename R, typename E = detail::ValueOfRange<R>>
+auto sum_of(R &&Range, E Init = E{0}) {
+  return accumulate(std::forward<R>(Range), std::move(Init));
+}
+
+/// Returns the product of all values in `Range` with `Init` initial value.
+/// The default initial value is 1.
+template <typename R, typename E = detail::ValueOfRange<R>>
+auto product_of(R &&Range, E Init = E{1}) {
+  return accumulate(std::forward<R>(Range), std::move(Init),
+                    std::multiplies<>{});
+}
+
 /// Provide wrappers to std::for_each which take ranges instead of having to
 /// pass begin/end explicitly.
 template <typename R, typename UnaryFunction>
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 5020acda95b0b..52fb423713a05 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -14,6 +14,7 @@
 #include <array>
 #include <climits>
 #include <cstddef>
+#include <functional>
 #include <initializer_list>
 #include <iterator>
 #include <list>
@@ -1658,6 +1659,44 @@ TEST(STLExtrasTest, Accumulate) {
   EXPECT_EQ(accumulate(V1, 10), std::accumulate(V1.begin(), V1.end(), 10));
   EXPECT_EQ(accumulate(drop_begin(V1), 7),
             std::accumulate(V1.begin() + 1, V1.end(), 7));
+
+  EXPECT_EQ(accumulate(V1, 2, std::multiplies<>{}), 240);
+}
+
+TEST(STLExtrasTest, SumOf) {
+  EXPECT_EQ(sum_of(std::vector<int>()), 0);
+  EXPECT_EQ(sum_of(std::vector<int>(), 1), 1);
+  std::vector<int> V1 = {1, 2, 3, 4, 5};
+  static_assert(std::is_same_v<decltype(sum_of(V1)), int>);
+  static_assert(std::is_same_v<decltype(sum_of(V1, 1)), int>);
+  EXPECT_EQ(sum_of(V1), 15);
+  EXPECT_EQ(sum_of(V1, 1), 16);
+
+  std::vector<float> V2 = {1.0f, 2.0f, 4.0f};
+  static_assert(std::is_same_v<decltype(sum_of(V2)), float>);
+  static_assert(std::is_same_v<decltype(sum_of(V2, 1.0f)), float>);
+  static_assert(std::is_same_v<decltype(sum_of(V2, 1.0)), double>);
+  EXPECT_EQ(sum_of(V2), 7.0f);
+  EXPECT_EQ(sum_of(V2, 1.0f), 8.0f);
+}
+
+TEST(STLExtrasTest, ProductOf) {
+  EXPECT_EQ(product_of(std::vector<int>()), 1);
+  EXPECT_EQ(product_of(std::vector<int>(), 0), 0);
+  EXPECT_EQ(product_of(std::vector<int>(), 1), 1);
+  std::vector<int> V1 = {1, 2, 3, 4, 5};
+  static_assert(std::is_same_v<decltype(product_of(V1)), int>);
+  static_assert(std::is_same_v<decltype(product_of(V1, 1)), int>);
+  EXPECT_EQ(product_of(V1), 120);
+  EXPECT_EQ(product_of(V1, 1), 120);
+  EXPECT_EQ(product_of(V1, 2), 240);
+
+  std::vector<float> V2 = {1.0f, 2.0f, 4.0f};
+  static_assert(std::is_same_v<decltype(product_of(V2)), float>);
+  static_assert(std::is_same_v<decltype(product_of(V2, 1.0f)), float>);
+  static_assert(std::is_same_v<decltype(product_of(V2, 1.0)), double>);
+  EXPECT_EQ(product_of(V2), 8.0f);
+  EXPECT_EQ(product_of(V2, 4.0f), 32.0f);
 }
 
 struct Foo;

``````````

</details>


https://github.com/llvm/llvm-project/pull/162129


More information about the llvm-commits mailing list