[llvm] [VectorCombine] Fold vector.reduce.OP(F(X)) == 0 -> OP(X) == 0 (PR #173069)

Valeriy Savchenko via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 11 06:45:38 PST 2026


================
@@ -3806,6 +3807,192 @@ bool VectorCombine::foldCastFromReductions(Instruction &I) {
   return true;
 }
 
+/// Match a non-zero and non-poison integer or a vector with all non-zero and
+/// non-poison elements.
+static cst_pred_ty<is_non_zero_int, /*AllowPoison=*/false>
+m_NonZeroNonPoisonInt() {
+  return cst_pred_ty<is_non_zero_int, /*AllowPoison=*/false>();
+}
+
+bool VectorCombine::foldICmpEqZeroVectorReduce(Instruction &I) {
+  // vector.reduce.OP f(X_i) == 0 -> vector.reduce.OP X_i == 0
+  //
+  // We can prove it for cases when:
+  //
+  //   1.  OP X_i == 0 <=> \forall i \in [1, N] X_i == 0
+  //   1'. OP X_i == 0 <=> \exists j \in [1, N] X_j == 0
+  //   2.  f(x) == 0 <=> x == 0
+  //
+  // From 1 and 2 (or 1' and 2), we can infer that
+  //
+  //   OP f(X_i) == 0 <=> OP X_i == 0.
+  //
+  //                  (1)
+  //   OP f(X_i) == 0 <=> \forall i \in [1, N] f(X_i) == 0
+  //                  (2)
+  //                  <=> \forall i \in [1, N] X_i == 0
+  //                  (1)
+  //                  <=> OP(X_i) == 0
+  //
+  // For some of the OP's and f's, we need to have domain constraints on X
+  // to ensure properties 1 (or 1') and 2.
+  CmpPredicate Pred;
+  Value *Op;
+  if (!match(&I, m_ICmp(Pred, m_Value(Op), m_Zero())) ||
+      !ICmpInst::isEquality(Pred))
+    return false;
+
+  auto *II = dyn_cast<IntrinsicInst>(Op);
+  if (!II)
+    return false;
+
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::vector_reduce_add:
+  case Intrinsic::vector_reduce_or:
+  case Intrinsic::vector_reduce_umin:
+  case Intrinsic::vector_reduce_umax:
+  case Intrinsic::vector_reduce_smin:
+  case Intrinsic::vector_reduce_smax:
+    break;
+  default:
+    return false;
+  }
+
+  Value *InnerOp = II->getArgOperand(0);
+
+  // TODO: fixed vector type might be too restrictive
+  if (!II->hasOneUse() || !InnerOp->hasOneUse() ||
+      !isa<FixedVectorType>(InnerOp->getType()))
+    return false;
+
+  Value *X = nullptr;
+
+  // Check for zero-preserving operations where f(x) == 0 <=> x == 0
+  //
+  //   1. f(x) = shl nuw x, y for arbitrary y
+  //   2. f(x) = mul nuw x, c for defined c != 0
+  //   3. f(x) = zext x
+  //   4. f(x) = sext x
+  //   5. f(x) = neg x
+  //
+  if (!(match(InnerOp, m_NUWShl(m_Value(X),
+                                m_Value())) || // Case 1
+        match(InnerOp,
+              m_NUWMul(m_Value(X), m_NonZeroNonPoisonInt())) || // Case 2
+        match(InnerOp, m_ZExt(m_Value(X))) ||                   // Case 3
+        match(InnerOp, m_SExt(m_Value(X))) ||                   // Case 4
+        match(InnerOp, m_Neg(m_Value(X)))                       // Case 5
+        ))
+    return false;
+
+  SimplifyQuery S = SQ.getWithInstruction(&I);
+  assert(isa<FixedVectorType>(X->getType()) && "Unexpected type");
+  auto *XTy = cast<FixedVectorType>(X->getType());
+
+  // Check for domain constraints for all supported reductions.
+  //
+  //  a. OR X_i   - has property 1  for every X
+  //  b. UMAX X_i - has property 1  for every X
+  //  c. UMIN X_i - has property 1' for every X
+  //  d. SMAX X_i - has property 1  for X >= 0
+  //  e. SMIN X_i - has property 1' for X >= 0
+  //  f. ADD X_i  - has property 1  for X >= 0 && ADD X_i doesn't sign wrap
+  //
+  // In order for the proof to work, we need 1 (or 1') to be true for both
+  // OP f(X_i) and OP X_i and that's why below we check constraints twice.
+  //
+  // NOTE: ADD X_i holds property 1 for a mirror case as well, i.e. when
+  //       X <= 0 && ADD X_i doesn't sign wrap. However, due to the nature
+  //       of known bits, we can't reasonably hold knowledge of "either 0
+  //       or negative".
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::vector_reduce_add: {
+    // We need to check that both X_i and f(X_i) have enough leading
+    // zeros to not overflow.
+    KnownBits KnownX = computeKnownBits(X, S);
+    KnownBits KnownFX = computeKnownBits(InnerOp, S);
+    unsigned NumElems = XTy->getNumElements();
+    // Adding N elements loses at most bit_width(N-1) leading bits.
+    unsigned LostBits = NumElems > 1 ? llvm::bit_width(NumElems - 1) : 0;
----------------
SavchenkoValeriy wrote:

Applied the suggested change, but did not act on the comment about `computeOverflowForUnsignedMul`. The approach used here is robust for all bit widths and vector sizes (it uses the same logic as in #174410).

https://github.com/llvm/llvm-project/pull/173069


More information about the llvm-commits mailing list