[llvm] [ValueTracking] Support horizontal vector add in computeKnownBits (PR #174410)

Valeriy Savchenko via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 7 01:52:46 PST 2026


https://github.com/SavchenkoValeriy updated https://github.com/llvm/llvm-project/pull/174410

>From 68bec97755ca2982b5dc72ceeeaf303e470795c6 Mon Sep 17 00:00:00 2001
From: Valeriy Savchenko <vsavchenko at apple.com>
Date: Mon, 5 Jan 2026 14:23:13 +0000
Subject: [PATCH] [ValueTracking] Support horizontal vector add in
 computeKnownBits

---
 llvm/include/llvm/Support/KnownBits.h         |  5 +++
 llvm/lib/Analysis/ValueTracking.cpp           |  8 ++++
 llvm/lib/Support/KnownBits.cpp                | 40 +++++++++++++++++
 .../vector-reduce-add-known-bits.ll           | 45 +++++++++++++++++++
 .../PhaseOrdering/AArch64/udotabd.ll          | 20 ++++-----
 llvm/unittests/Support/KnownBitsTest.cpp      | 34 ++++++++++++++
 6 files changed, 142 insertions(+), 10 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/vector-reduce-add-known-bits.ll

diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index bff944325880b..ad901be01f0ae 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -511,6 +511,11 @@ struct KnownBits {
   /// Compute known bits for the absolute value.
   LLVM_ABI KnownBits abs(bool IntMinIsPoison = false) const;
 
+  /// Compute known bits for horizontal add for a vector with NumElts
+  /// elements, where each element has the known bits represented by this
+  /// object.
+  LLVM_ABI KnownBits reduceAdd(unsigned NumElts) const;
+
   KnownBits byteSwap() const {
     return KnownBits(Zero.byteSwap(), One.byteSwap());
   }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9cb6f19b9340c..3d2315811b7af 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2132,6 +2132,14 @@ static void computeKnownBitsFromOperator(const Operator *I,
           Known.One.clearAllBits();
         break;
       }
+      case Intrinsic::vector_reduce_add: {
+        auto *VecTy = dyn_cast<FixedVectorType>(I->getOperand(0)->getType());
+        if (!VecTy)
+          break;
+        computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+        Known = Known.reduceAdd(VecTy->getNumElements());
+        break;
+      }
       case Intrinsic::umin:
         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 7db8e1641462e..2c4841bce5b53 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -601,6 +601,46 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const {
   return KnownAbs;
 }
 
+KnownBits KnownBits::reduceAdd(unsigned NumElts) const {
+  if (NumElts == 0)
+    return KnownBits(getBitWidth());
+
+  unsigned BitWidth = getBitWidth();
+  KnownBits Result(BitWidth);
+
+  if (isConstant())
+    // If all elements are the same constant, we can simply compute it
+    return KnownBits::makeConstant(NumElts * getConstant());
+
+  // The main idea is as follows.
+  //
+  // If KnownBits for each element has L leading zeros then
+  // X_i < 2^(W - L) for every i in [1, N].
+  //
+  //   ADD X_i <= ADD max(X_i) = N * max(X_i)
+  //           <  N * 2^(W - L)
+  //           <= 2^(W - L + ceil(log2(N)))
+  //
+  // As a result, we can conclude that
+  //
+  //   L' = L - ceil(log2(N)) = L - bit_width(N - 1)
+  //
+  // Similar logic can be applied to leading ones.
+  unsigned LostBits = NumElts > 1 ? llvm::bit_width(NumElts - 1) : 0;
+
+  if (isNonNegative()) {
+    unsigned LeadingZeros = countMinLeadingZeros();
+    LeadingZeros = LeadingZeros > LostBits ? LeadingZeros - LostBits : 0;
+    Result.Zero.setHighBits(LeadingZeros);
+  } else if (isNegative()) {
+    unsigned LeadingOnes = countMinLeadingOnes();
+    LeadingOnes = LeadingOnes > LostBits ? LeadingOnes - LostBits : 0;
+    Result.One.setHighBits(LeadingOnes);
+  }
+
+  return Result;
+}
+
 static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                      const KnownBits &LHS,
                                      const KnownBits &RHS) {
diff --git a/llvm/test/Transforms/InstCombine/vector-reduce-add-known-bits.ll b/llvm/test/Transforms/InstCombine/vector-reduce-add-known-bits.ll
new file mode 100644
index 0000000000000..60b898b492063
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/vector-reduce-add-known-bits.ll
@@ -0,0 +1,45 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i32 @reduce_add_eliminate_mask(ptr %p) {
+; CHECK-LABEL: define i32 @reduce_add_eliminate_mask(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[VEC:%.*]] = load <4 x i32>, ptr [[P]], align 16
+; CHECK-NEXT:    [[AND:%.*]] = and <4 x i32> [[VEC]], splat (i32 268435455)
+; CHECK-NEXT:    [[SUM:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[AND]])
+; CHECK-NEXT:    ret i32 [[SUM]]
+;
+  %vec = load <4 x i32>, ptr %p
+  %and = and <4 x i32> %vec, splat (i32 268435455)
+  %sum = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %and)
+  %masked = and i32 %sum, 1073741823
+  ret i32 %masked
+}
+
+define i1 @reduce_add_simplify_comparison(ptr %p) {
+; CHECK-LABEL: define i1 @reduce_add_simplify_comparison(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT:    ret i1 true
+;
+  %vec = load <8 x i32>, ptr %p
+  %and = and <8 x i32> %vec, splat (i32 16777215)
+  %sum = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %and)
+  %cmp = icmp ult i32 %sum, 134217728
+  ret i1 %cmp
+}
+
+define i64 @reduce_add_sext(ptr %p) {
+; CHECK-LABEL: define i64 @reduce_add_sext(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT:    [[VEC:%.*]] = load <2 x i32>, ptr [[P]], align 8
+; CHECK-NEXT:    [[AND:%.*]] = and <2 x i32> [[VEC]], splat (i32 4194303)
+; CHECK-NEXT:    [[SUM:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[AND]])
+; CHECK-NEXT:    [[EXT:%.*]] = zext nneg i32 [[SUM]] to i64
+; CHECK-NEXT:    ret i64 [[EXT]]
+;
+  %vec = load <2 x i32>, ptr %p
+  %and = and <2 x i32> %vec, splat (i32 4194303)
+  %sum = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %and)
+  %ext = sext i32 %sum to i64
+  ret i64 %ext
+}
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
index 4c7e39d31b5c6..e2f7f8f7e5cac 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
@@ -29,7 +29,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT:    [[TMP13:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP12]], i1 false)
 ; CHECK-O3-NEXT:    [[TMP14:%.*]] = zext <16 x i16> [[TMP13]] to <16 x i32>
 ; CHECK-O3-NEXT:    [[TMP15:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP14]])
-; CHECK-O3-NEXT:    [[OP_RDX_1:%.*]] = add i32 [[TMP15]], [[TMP7]]
+; CHECK-O3-NEXT:    [[OP_RDX_1:%.*]] = add nuw nsw i32 [[TMP15]], [[TMP7]]
 ; CHECK-O3-NEXT:    [[ADD_PTR_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT:    [[ADD_PTR9_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT:    [[TMP16:%.*]] = load <16 x i8>, ptr [[ADD_PTR_1]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -40,7 +40,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT:    [[TMP21:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP20]], i1 false)
 ; CHECK-O3-NEXT:    [[TMP22:%.*]] = zext <16 x i16> [[TMP21]] to <16 x i32>
 ; CHECK-O3-NEXT:    [[TMP23:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP22]])
-; CHECK-O3-NEXT:    [[OP_RDX_2:%.*]] = add i32 [[TMP23]], [[OP_RDX_1]]
+; CHECK-O3-NEXT:    [[OP_RDX_2:%.*]] = add nuw nsw i32 [[TMP23]], [[OP_RDX_1]]
 ; CHECK-O3-NEXT:    [[ADD_PTR_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_1]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT:    [[ADD_PTR9_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_1]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT:    [[TMP24:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -51,7 +51,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT:    [[TMP29:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP28]], i1 false)
 ; CHECK-O3-NEXT:    [[TMP30:%.*]] = zext <16 x i16> [[TMP29]] to <16 x i32>
 ; CHECK-O3-NEXT:    [[TMP31:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP30]])
-; CHECK-O3-NEXT:    [[OP_RDX_3:%.*]] = add i32 [[TMP31]], [[OP_RDX_2]]
+; CHECK-O3-NEXT:    [[OP_RDX_3:%.*]] = add nuw nsw i32 [[TMP31]], [[OP_RDX_2]]
 ; CHECK-O3-NEXT:    [[ADD_PTR_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_2]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT:    [[ADD_PTR9_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_2]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT:    [[TMP32:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -62,7 +62,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT:    [[TMP37:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP36]], i1 false)
 ; CHECK-O3-NEXT:    [[TMP38:%.*]] = zext <16 x i16> [[TMP37]] to <16 x i32>
 ; CHECK-O3-NEXT:    [[TMP39:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP38]])
-; CHECK-O3-NEXT:    [[OP_RDX_4:%.*]] = add i32 [[TMP39]], [[OP_RDX_3]]
+; CHECK-O3-NEXT:    [[OP_RDX_4:%.*]] = add nuw nsw i32 [[TMP39]], [[OP_RDX_3]]
 ; CHECK-O3-NEXT:    [[ADD_PTR_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_3]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT:    [[ADD_PTR9_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_3]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT:    [[TMP40:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -73,7 +73,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT:    [[TMP45:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP44]], i1 false)
 ; CHECK-O3-NEXT:    [[TMP46:%.*]] = zext <16 x i16> [[TMP45]] to <16 x i32>
 ; CHECK-O3-NEXT:    [[TMP47:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP46]])
-; CHECK-O3-NEXT:    [[OP_RDX_5:%.*]] = add i32 [[TMP47]], [[OP_RDX_4]]
+; CHECK-O3-NEXT:    [[OP_RDX_5:%.*]] = add nuw nsw i32 [[TMP47]], [[OP_RDX_4]]
 ; CHECK-O3-NEXT:    [[ADD_PTR_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_4]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT:    [[ADD_PTR9_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_4]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT:    [[TMP48:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -209,7 +209,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP11:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP10]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP52:%.*]] = zext nneg <16 x i16> [[TMP11]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP60:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP52]])
-; CHECK-LTO-NEXT:    [[OP_RDX_1:%.*]] = add i32 [[TMP60]], [[TMP44]]
+; CHECK-LTO-NEXT:    [[OP_RDX_1:%.*]] = add nuw nsw i32 [[TMP60]], [[TMP44]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP12:%.*]] = load <16 x i8>, ptr [[ADD_PTR_1]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -220,7 +220,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP17:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP16]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP68:%.*]] = zext nneg <16 x i16> [[TMP17]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP76:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP68]])
-; CHECK-LTO-NEXT:    [[OP_RDX_2:%.*]] = add i32 [[OP_RDX_1]], [[TMP76]]
+; CHECK-LTO-NEXT:    [[OP_RDX_2:%.*]] = add nuw nsw i32 [[OP_RDX_1]], [[TMP76]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_1]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_1]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP18:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -231,7 +231,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP23:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP22]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP84:%.*]] = zext nneg <16 x i16> [[TMP23]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP92:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP84]])
-; CHECK-LTO-NEXT:    [[OP_RDX_3:%.*]] = add i32 [[OP_RDX_2]], [[TMP92]]
+; CHECK-LTO-NEXT:    [[OP_RDX_3:%.*]] = add nuw nsw i32 [[OP_RDX_2]], [[TMP92]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_2]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_2]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP24:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -242,7 +242,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP29:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP28]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP100:%.*]] = zext nneg <16 x i16> [[TMP29]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP108:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP100]])
-; CHECK-LTO-NEXT:    [[OP_RDX_4:%.*]] = add i32 [[OP_RDX_3]], [[TMP108]]
+; CHECK-LTO-NEXT:    [[OP_RDX_4:%.*]] = add nuw nsw i32 [[OP_RDX_3]], [[TMP108]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_3]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_3]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP30:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -253,7 +253,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP35:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP34]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP116:%.*]] = zext nneg <16 x i16> [[TMP35]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP117:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP116]])
-; CHECK-LTO-NEXT:    [[OP_RDX_5:%.*]] = add i32 [[OP_RDX_4]], [[TMP117]]
+; CHECK-LTO-NEXT:    [[OP_RDX_5:%.*]] = add nuw nsw i32 [[OP_RDX_4]], [[TMP117]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_4]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_4]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP37:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index ce0bf86e39dd7..9e9ffca966f6d 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -845,4 +845,38 @@ TEST(KnownBitsTest, MulExhaustive) {
   }
 }
 
+TEST(KnownBitsTest, ReduceAddExhaustive) {
+  unsigned Bits = 4;
+  for (unsigned NumElts : {2, 4}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &EltKnown) {
+      KnownBits Computed = EltKnown.reduceAdd(NumElts);
+      KnownBits Exact(Bits);
+      Exact.Zero.setAllBits();
+      Exact.One.setAllBits();
+
+      llvm::function_ref<void(unsigned, APInt)> EnumerateCombinations;
+      auto EnumerateCombinationsImpl = [&](unsigned Depth, APInt CurrentSum) {
+        if (Depth == NumElts) {
+          Exact.One &= CurrentSum;
+          Exact.Zero &= ~CurrentSum;
+          return;
+        }
+        ForeachNumInKnownBits(EltKnown, [&](const APInt &Elt) {
+          EnumerateCombinations(Depth + 1, CurrentSum + Elt);
+        });
+      };
+      EnumerateCombinations = EnumerateCombinationsImpl;
+
+      // Here we recursively enumerate all combinations of NumElts element
+      // values matching the known bits and collect exact known bits of the sums.
+      EnumerateCombinations(0, APInt(Bits, 0));
+
+      if (!Exact.hasConflict()) {
+        EXPECT_TRUE(checkResult("reduceAdd", Exact, Computed, {EltKnown},
+                                /*CheckOptimality=*/false));
+      }
+    });
+  }
+}
+
 } // end anonymous namespace



More information about the llvm-commits mailing list