[llvm] [VectorCombine] Optimize vector combine in fold binop of reduction (PR #179416)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 6 01:19:25 PST 2026


https://github.com/Anjian-Wen updated https://github.com/llvm/llvm-project/pull/179416

>From 8b345aabfb01136b19ca093dd2ca8f5c463660f2 Mon Sep 17 00:00:00 2001
From: Anjian-Wen <wenanjian at bytedance.com>
Date: Tue, 3 Feb 2026 16:19:04 +0800
Subject: [PATCH 1/5] [VectorCombine] optimize vector combine in fold binop of
 reduction

Reassociate the surrounding add/sub so that two reduction results become
operands of the same binop, using the commutativity and associativity of
integer addition (e.g. (a + b) - c == a + (b - c)). In some cases this lets
the existing fold combine the two reductions into one, ultimately saving a
reduction instruction.
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 34 +++++++++++++++++++
 .../VectorCombine/fold-binop-of-reductions.ll | 30 ++++++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 1746d3e4b06f4..1dca47e97988d 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1754,6 +1754,40 @@ bool VectorCombine::foldBinopOfReductions(Instruction &I) {
     return nullptr;
   };
 
+  // sub (add (a, vector_reduce_add b), vector_reduce_add c) ->
+  // add (a, sub (vector_reduce_add b, vector_reduce_add c))
+  // sub (add (vector_reduce_add b, a), vector_reduce_add c) ->
+  // add (a, sub (vector_reduce_add b, vector_reduce_add c))
+  if (BinOpOpc == Instruction::Sub) {
+    auto *II = dyn_cast<BinaryOperator>(I.getOperand(0));
+    if (II && II->getOpcode() == Instruction::Add) {
+      Value *V1 =
+          checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
+      if (V1) {
+        Instruction *I0 = dyn_cast<Instruction>(I.getOperand(0));
+        Value *V00 =
+            checkIntrinsicAndGetItsArgument(I0->getOperand(0), ReductionIID);
+        Value *V01 =
+            checkIntrinsicAndGetItsArgument(I0->getOperand(1), ReductionIID);
+        if (V00) {
+          Value *NewSub =
+              Builder.CreateBinOp(BinOpOpc, I0->getOperand(0), I.getOperand(1));
+          Value *NewAdd =
+              Builder.CreateBinOp(Instruction::Add, I0->getOperand(1), NewSub);
+          replaceValue(I, *NewAdd);
+          return true;
+        } else if (V01) {
+          Value *NewSub =
+              Builder.CreateBinOp(BinOpOpc, I0->getOperand(1), I.getOperand(1));
+          Value *NewAdd =
+              Builder.CreateBinOp(Instruction::Add, I0->getOperand(0), NewSub);
+          replaceValue(I, *NewAdd);
+          return true;
+        }
+      }
+    }
+  }
+
   Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
   if (!V0)
     return false;
diff --git a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
index 5f29af9de5a39..002f98990e197 100644
--- a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
+++ b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
@@ -205,3 +205,33 @@ define i32 @element_counts_do_not_match_vscale(<vscale x 16 x i32> %v0, <vscale
   %res = add i32 %v0_red, %v1_red
   ret i32 %res
 }
+
+define i32 @sub_add_reduction_s_reduction(<vscale x 8 x i32> %v0, <vscale x 8 x i32> %v1, i32 %s1) {
+; CHECK-LABEL: define i32 @sub_add_reduction_s_reduction(
+; CHECK-SAME: <vscale x 8 x i32> [[V0:%.*]], <vscale x 8 x i32> [[V1:%.*]], i32 [[S1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <vscale x 8 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[RES:%.*]] = add i32 [[TMP2]], [[S1]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %v0_red = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %v0)
+  %v1_red = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %v1)
+  %add1 = add i32 %v0_red, %s1
+  %res = sub i32 %add1, %v1_red
+  ret i32 %res
+}
+
+define i32 @sub_add_s_reduction_reduction(<vscale x 8 x i32> %v0, <vscale x 8 x i32> %v1, i32 %s1) {
+; CHECK-LABEL: define i32 @sub_add_s_reduction_reduction(
+; CHECK-SAME: <vscale x 8 x i32> [[V0:%.*]], <vscale x 8 x i32> [[V1:%.*]], i32 [[S1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <vscale x 8 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[RES:%.*]] = add i32 [[S1]], [[TMP2]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %v0_red = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %v0)
+  %v1_red = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %v1)
+  %add1 = add i32 %s1, %v0_red
+  %res = sub i32 %add1, %v1_red
+  ret i32 %res
+}

>From 0913941580820fe63b9f8870c7ce8e4b6b3dd369 Mon Sep 17 00:00:00 2001
From: Anjian-Wen <wenanjian at bytedance.com>
Date: Mon, 9 Feb 2026 15:36:57 +0800
Subject: [PATCH 2/5] add more patterns and more tests

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 63 ++++++++++++-------
 .../VectorCombine/fold-binop-of-reductions.ll | 30 +++++++++
 2 files changed, 70 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 1dca47e97988d..74bf337c1767a 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1754,34 +1754,51 @@ bool VectorCombine::foldBinopOfReductions(Instruction &I) {
     return nullptr;
   };
 
-  // sub (add (a, vector_reduce_add b), vector_reduce_add c) ->
-  // add (a, sub (vector_reduce_add b, vector_reduce_add c))
   // sub (add (vector_reduce_add b, a), vector_reduce_add c) ->
+  // add (sub (vector_reduce_add b, vector_reduce_add c), a)
+  //
+  // sub (sub (vector_reduce_add b, a), vector_reduce_add c) ->
+  // sub (sub (vector_reduce_add b, vector_reduce_add c), a)
+  //
+  // sub (sub (a, vector_reduce_add b), vector_reduce_add c) ->
+  // sub (a, add (vector_reduce_add b, vector_reduce_add c))
+  //
+  // sub (add (a, vector_reduce_add b), vector_reduce_add c) ->
   // add (a, sub (vector_reduce_add b, vector_reduce_add c))
   if (BinOpOpc == Instruction::Sub) {
     auto *II = dyn_cast<BinaryOperator>(I.getOperand(0));
-    if (II && II->getOpcode() == Instruction::Add) {
-      Value *V1 =
-          checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
-      if (V1) {
-        Instruction *I0 = dyn_cast<Instruction>(I.getOperand(0));
-        Value *V00 =
-            checkIntrinsicAndGetItsArgument(I0->getOperand(0), ReductionIID);
-        Value *V01 =
-            checkIntrinsicAndGetItsArgument(I0->getOperand(1), ReductionIID);
-        if (V00) {
-          Value *NewSub =
-              Builder.CreateBinOp(BinOpOpc, I0->getOperand(0), I.getOperand(1));
-          Value *NewAdd =
-              Builder.CreateBinOp(Instruction::Add, I0->getOperand(1), NewSub);
-          replaceValue(I, *NewAdd);
+    Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
+
+    if (II && V1 &&
+        (II->getOpcode() == Instruction::Add ||
+         II->getOpcode() == Instruction::Sub)) {
+      Instruction *I0 = dyn_cast<Instruction>(I.getOperand(0));
+      Value *V00 =
+          checkIntrinsicAndGetItsArgument(I0->getOperand(0), ReductionIID);
+      Value *V01 =
+          checkIntrinsicAndGetItsArgument(I0->getOperand(1), ReductionIID);
+
+      if (V00 && !V01) {
+        Value *CombineNode = Builder.CreateBinOp(
+            Instruction::Sub, I0->getOperand(0), I.getOperand(1));
+        Value *NewBinNode = Builder.CreateBinOp(II->getOpcode(), CombineNode,
+                                                I0->getOperand(1));
+        replaceValue(I, *NewBinNode);
+        return true;
+      } else if (V01 && !V00) {
+        if (II->getOpcode() == Instruction::Sub) {
+          Value *CombineNode = Builder.CreateBinOp(
+              Instruction::Add, I0->getOperand(1), I.getOperand(1));
+          Value *NewBinNode = Builder.CreateBinOp(
+              Instruction::Sub, I0->getOperand(0), CombineNode);
+          replaceValue(I, *NewBinNode);
           return true;
-        } else if (V01) {
-          Value *NewSub =
-              Builder.CreateBinOp(BinOpOpc, I0->getOperand(1), I.getOperand(1));
-          Value *NewAdd =
-              Builder.CreateBinOp(Instruction::Add, I0->getOperand(0), NewSub);
-          replaceValue(I, *NewAdd);
+        } else if (II->getOpcode() == Instruction::Add) {
+          Value *CombineNode = Builder.CreateBinOp(
+              Instruction::Sub, I0->getOperand(1), I.getOperand(1));
+          Value *NewBinNode = Builder.CreateBinOp(
+              Instruction::Add, I0->getOperand(0), CombineNode);
+          replaceValue(I, *NewBinNode);
           return true;
         }
       }
diff --git a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
index 002f98990e197..22960119ce056 100644
--- a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
+++ b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
@@ -235,3 +235,33 @@ define i32 @sub_add_s_reduction_reduction(<vscale x 8 x i32> %v0, <vscale x 8 x
   %res = sub i32 %add1, %v1_red
   ret i32 %res
 }
+
+define i32 @sub_sub_reduction_s_reduction(<vscale x 8 x i32> %v0, <vscale x 8 x i32> %v1, i32 %s1) {
+; CHECK-LABEL: define i32 @sub_sub_reduction_s_reduction(
+; CHECK-SAME: <vscale x 8 x i32> [[V0:%.*]], <vscale x 8 x i32> [[V1:%.*]], i32 [[S1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <vscale x 8 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[RES:%.*]] = sub i32 [[TMP2]], [[S1]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %v0_red = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %v0)
+  %v1_red = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %v1)
+  %sub1 = sub i32 %v0_red, %s1
+  %res = sub i32 %sub1, %v1_red
+  ret i32 %res
+}
+
+define i32 @sub_sub_s_reduction_reduction(<vscale x 8 x i32> %v0, <vscale x 8 x i32> %v1, i32 %s1) {
+; CHECK-LABEL: define i32 @sub_sub_s_reduction_reduction(
+; CHECK-SAME: <vscale x 8 x i32> [[V0:%.*]], <vscale x 8 x i32> [[V1:%.*]], i32 [[S1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add <vscale x 8 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[RES:%.*]] = sub i32 [[S1]], [[TMP2]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %v0_red = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %v0)
+  %v1_red = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %v1)
+  %sub1 = sub i32 %s1, %v0_red
+  %res = sub i32 %sub1, %v1_red
+  ret i32 %res
+}

>From 443404379b325e028139ff9dd9fde81bc58c135a Mon Sep 17 00:00:00 2001
From: Anjian-Wen <wenanjian at bytedance.com>
Date: Fri, 13 Feb 2026 15:39:27 +0800
Subject: [PATCH 3/5] extend the matcher to handle every combination of outer
 and inner add/sub around a reduction

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 233 +++++++++++++-----
 .../PhaseOrdering/AArch64/udotabd.ll          |  28 +--
 .../fold-binop-of-reductions-add.ll           |  54 ++++
 .../fold-binop-of-reductions-sub.ll           | 132 ++++++++++
 .../VectorCombine/fold-binop-of-reductions.ll |   2 +-
 5 files changed, 377 insertions(+), 72 deletions(-)
 create mode 100644 llvm/test/Transforms/VectorCombine/fold-binop-of-reductions-add.ll
 create mode 100644 llvm/test/Transforms/VectorCombine/fold-binop-of-reductions-sub.ll

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 74bf337c1767a..6e0de0393db31 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -47,6 +47,10 @@
 using namespace llvm;
 using namespace llvm::PatternMatch;
 
+static cl::opt<bool> EnableFoldBinopOfReductions(
+    "vector-combine-fold-binop-of-reductions", cl::Hidden, cl::init(true),
+    cl::desc("Enable folding of binary operations of reductions"));
+
 STATISTIC(NumVecLoad, "Number of vector loads formed");
 STATISTIC(NumVecCmp, "Number of vector compares formed");
 STATISTIC(NumVecBO, "Number of vector binops formed");
@@ -1736,6 +1740,189 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
                                                       std::nullopt, CostKind);
 }

+/// If \p V is a single-use call to the intrinsic \p IID, return the call's
+/// first argument; otherwise return nullptr.
+static Value *checkIntrinsicAndGetItsArgument(Value *V, Intrinsic::ID IID) {
+  auto *II = dyn_cast<IntrinsicInst>(V);
+  if (!II)
+    return nullptr;
+  if (II->getIntrinsicID() == IID && II->hasOneUse())
+    return II->getArgOperand(0);
+  return nullptr;
+}
+
+/// Reassociate an add/sub expression so that two single-use reductions end up
+/// as the operands of one scalar binop, which the caller may then merge into a
+/// single reduction of a vector binop.
+///
+/// \p I is `Op0 <BinOpOpc> C`, or `C <BinOpOpc> Op0` when
+/// \p ScalarReduceCIsLeft is set. C (\p ScalarReduceC) must be a single-use
+/// reduction and \p Op0 a single-use add/sub with exactly one single-use
+/// reduction operand X (the other operand is Y). On success, I is replaced
+/// through \p ReplaceValue and true is returned.
+///
+/// The rewrite table only encodes integer add/sub identities, so both the
+/// outer and the inner operation must be Add or Sub.
+template <typename IRBuilderTy>
+static bool matchAssociativeReduction(
+    Instruction &I, Instruction::BinaryOps BinOpOpc, Intrinsic::ID ReductionIID,
+    IRBuilderTy &Builder, BinaryOperator *Op0, Value *ScalarReduceC,
+    bool ScalarReduceCIsLeft,
+    function_ref<void(Instruction &, Value &)> ReplaceValue) {
+  // getReductionForBinop may also map other opcodes (e.g. and/or/xor) to a
+  // reduction; reassociating those with the add/sub table below would
+  // miscompile, so bail out.
+  if (BinOpOpc != Instruction::Add && BinOpOpc != Instruction::Sub)
+    return false;
+
+  if (Op0->getOpcode() != Instruction::Add &&
+      Op0->getOpcode() != Instruction::Sub)
+    return false;
+
+  // Op0 must become dead after the rewrite. Otherwise X keeps a use through
+  // Op0, which blocks merging it with C (the reductions must be single-use),
+  // and we would only add scalar instructions.
+  if (!Op0->hasOneUse())
+    return false;
+
+  Value *ReduceCVector =
+      checkIntrinsicAndGetItsArgument(ScalarReduceC, ReductionIID);
+  if (!ReduceCVector)
+    return false;
+
+  Value *ReduceLVector =
+      checkIntrinsicAndGetItsArgument(Op0->getOperand(0), ReductionIID);
+  Value *ReduceRVector =
+      checkIntrinsicAndGetItsArgument(Op0->getOperand(1), ReductionIID);
+
+  // Exactly one operand of Op0 must be a reduction.
+  if ((ReduceLVector && ReduceRVector) || (!ReduceLVector && !ReduceRVector))
+    return false;
+
+  // Reductions of different vector types cannot be merged into a single
+  // vector binop, so the rewrite would only reshuffle scalar code.
+  Value *ReduceXVector = ReduceLVector ? ReduceLVector : ReduceRVector;
+  if (ReduceXVector->getType() != ReduceCVector->getType())
+    return false;
+
+  Value *ScalarReduceX =
+      ReduceLVector ? Op0->getOperand(0) : Op0->getOperand(1);
+  Value *OtherVal = ReduceLVector ? Op0->getOperand(1) : Op0->getOperand(0);
+  bool IsReduceXOnLeft = (ReduceLVector != nullptr);
+
+  // The reductions are combined as LHS_Reduce <NewReduceOp> RHS_Reduce; that
+  // result is then combined with OtherVal by NewBinOp, with OtherVal on the
+  // right when LHS_Bin is null and on the left otherwise.
+  Instruction::BinaryOps OuterOp = BinOpOpc;
+  Instruction::BinaryOps InnerOp = Op0->getOpcode();
+  Instruction::BinaryOps NewReduceOp;
+  Instruction::BinaryOps NewBinOp;
+  Value *LHS_Reduce, *RHS_Reduce;
+  Value *LHS_Bin, *RHS_Bin;
+
+  if (OuterOp == Instruction::Add) {
+    // Add is commutative, so the position of C (ScalarReduceCIsLeft) is
+    // irrelevant here.
+    if (InnerOp == Instruction::Add) {
+      // (X + Y) + C -> (X + C) + Y
+      NewReduceOp = Instruction::Add;
+      NewBinOp = Instruction::Add;
+      LHS_Reduce = ScalarReduceX;
+      RHS_Reduce = ScalarReduceC;
+      LHS_Bin = nullptr;
+      RHS_Bin = OtherVal;
+    } else { // Inner == Sub
+      if (IsReduceXOnLeft) {
+        // (X - Y) + C -> (X + C) - Y
+        NewReduceOp = Instruction::Add;
+        NewBinOp = Instruction::Sub;
+        LHS_Reduce = ScalarReduceX;
+        RHS_Reduce = ScalarReduceC;
+        LHS_Bin = nullptr;
+        RHS_Bin = OtherVal;
+      } else {
+        // (Y - X) + C -> Y - (X - C)
+        NewReduceOp = Instruction::Sub;
+        NewBinOp = Instruction::Sub;
+        LHS_Reduce = ScalarReduceX;
+        RHS_Reduce = ScalarReduceC;
+        LHS_Bin = OtherVal;
+        RHS_Bin = nullptr;
+      }
+    }
+  } else { // Outer == Sub
+    if (!ScalarReduceCIsLeft) { // (Op0 - C)
+      if (InnerOp == Instruction::Add) {
+        // (X + Y) - C -> (X - C) + Y
+        NewReduceOp = Instruction::Sub;
+        NewBinOp = Instruction::Add;
+        LHS_Reduce = ScalarReduceX;
+        RHS_Reduce = ScalarReduceC;
+        LHS_Bin = nullptr;
+        RHS_Bin = OtherVal;
+      } else { // Inner == Sub
+        if (IsReduceXOnLeft) {
+          // (X - Y) - C -> (X - C) - Y
+          NewReduceOp = Instruction::Sub;
+          NewBinOp = Instruction::Sub;
+          LHS_Reduce = ScalarReduceX;
+          RHS_Reduce = ScalarReduceC;
+          LHS_Bin = nullptr;
+          RHS_Bin = OtherVal;
+        } else {
+          // (Y - X) - C -> Y - (X + C)
+          NewReduceOp = Instruction::Add;
+          NewBinOp = Instruction::Sub;
+          LHS_Reduce = ScalarReduceX;
+          RHS_Reduce = ScalarReduceC;
+          LHS_Bin = OtherVal;
+          RHS_Bin = nullptr;
+        }
+      }
+    } else { // (C - Op0)
+      if (InnerOp == Instruction::Add) {
+        // C - (X + Y) -> (C - X) - Y
+        NewReduceOp = Instruction::Sub;
+        NewBinOp = Instruction::Sub;
+        LHS_Reduce = ScalarReduceC;
+        RHS_Reduce = ScalarReduceX;
+        LHS_Bin = nullptr;
+        RHS_Bin = OtherVal;
+      } else { // Inner == Sub
+        if (IsReduceXOnLeft) {
+          // C - (X - Y) -> (C - X) + Y
+          NewReduceOp = Instruction::Sub;
+          NewBinOp = Instruction::Add;
+          LHS_Reduce = ScalarReduceC;
+          RHS_Reduce = ScalarReduceX;
+          LHS_Bin = nullptr;
+          RHS_Bin = OtherVal;
+        } else {
+          // C - (Y - X) -> (C + X) - Y
+          NewReduceOp = Instruction::Add;
+          NewBinOp = Instruction::Sub;
+          LHS_Reduce = ScalarReduceC;
+          RHS_Reduce = ScalarReduceX;
+          LHS_Bin = nullptr;
+          RHS_Bin = OtherVal;
+        }
+      }
+    }
+  }
+
+  Value *CombineNode =
+      Builder.CreateBinOp(NewReduceOp, LHS_Reduce, RHS_Reduce);
+
+  Value *NewBinNode;
+  if (LHS_Bin == nullptr)
+    NewBinNode = Builder.CreateBinOp(NewBinOp, CombineNode, RHS_Bin);
+  else
+    NewBinNode = Builder.CreateBinOp(NewBinOp, LHS_Bin, CombineNode);
+
+  ReplaceValue(I, *NewBinNode);
+  return true;
+}
+
 bool VectorCombine::foldBinopOfReductions(Instruction &I) {
   Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
   Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
@@ -1744,67 +1894,36 @@ bool VectorCombine::foldBinopOfReductions(Instruction &I) {
   if (ReductionIID == Intrinsic::not_intrinsic)
     return false;
 
-  auto checkIntrinsicAndGetItsArgument = [](Value *V,
-                                            Intrinsic::ID IID) -> Value * {
-    auto *II = dyn_cast<IntrinsicInst>(V);
-    if (!II)
-      return nullptr;
-    if (II->getIntrinsicID() == IID && II->hasOneUse())
-      return II->getArgOperand(0);
-    return nullptr;
-  };
+  if (!EnableFoldBinopOfReductions)
+    return false;
 
-  // sub (add (vector_reduce_add b, a), vector_reduce_add c) ->
-  // add (sub (vector_reduce_add b, vector_reduce_add c), a)
-  //
-  // sub (sub (vector_reduce_add b, a), vector_reduce_add c) ->
-  // sub (sub (vector_reduce_add b, vector_reduce_add c), a)
-  //
-  // sub (sub (a, vector_reduce_add b), vector_reduce_add c) ->
-  // sub (a, add (vector_reduce_add b, vector_reduce_add c))
+  auto ReplaceValue = [&](Instruction &I, Value &V) { replaceValue(I, V); };
+
+  // Reduce the number of reductions by folding a binop of a reduction and a
+  // scalar (which might be another reduction) into a single reduction of a
+  // vector binop. This leverages associativity and commutativity of the
+  // binary operation (Add/Sub) to group reductions together.
   //
-  // sub (add (a, vector_reduce_add b), vector_reduce_add c) ->
-  // add (a, sub (vector_reduce_add b, vector_reduce_add c))
-  if (BinOpOpc == Instruction::Sub) {
-    auto *II = dyn_cast<BinaryOperator>(I.getOperand(0));
-    Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
-
-    if (II && V1 &&
-        (II->getOpcode() == Instruction::Add ||
-         II->getOpcode() == Instruction::Sub)) {
-      Instruction *I0 = dyn_cast<Instruction>(I.getOperand(0));
-      Value *V00 =
-          checkIntrinsicAndGetItsArgument(I0->getOperand(0), ReductionIID);
-      Value *V01 =
-          checkIntrinsicAndGetItsArgument(I0->getOperand(1), ReductionIID);
-
-      if (V00 && !V01) {
-        Value *CombineNode = Builder.CreateBinOp(
-            Instruction::Sub, I0->getOperand(0), I.getOperand(1));
-        Value *NewBinNode = Builder.CreateBinOp(II->getOpcode(), CombineNode,
-                                                I0->getOperand(1));
-        replaceValue(I, *NewBinNode);
-        return true;
-      } else if (V01 && !V00) {
-        if (II->getOpcode() == Instruction::Sub) {
-          Value *CombineNode = Builder.CreateBinOp(
-              Instruction::Add, I0->getOperand(1), I.getOperand(1));
-          Value *NewBinNode = Builder.CreateBinOp(
-              Instruction::Sub, I0->getOperand(0), CombineNode);
-          replaceValue(I, *NewBinNode);
-          return true;
-        } else if (II->getOpcode() == Instruction::Add) {
-          Value *CombineNode = Builder.CreateBinOp(
-              Instruction::Sub, I0->getOperand(1), I.getOperand(1));
-          Value *NewBinNode = Builder.CreateBinOp(
-              Instruction::Add, I0->getOperand(0), CombineNode);
-          replaceValue(I, *NewBinNode);
-          return true;
-        }
-      }
-    }
+  // Examples:
+  //   (Reduce(X) + Y) + Reduce(Z)  -> Reduce(X + Z) + Y
+  //   (Reduce(X) - Y) - Reduce(Z)  -> Reduce(X - Z) - Y
+  //   Reduce(Z) - (Reduce(X) + Y)  -> Reduce(Z - X) - Y
+  if (auto *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0))) {
+    if (matchAssociativeReduction(I, BinOpOpc, ReductionIID, Builder, Op0,
+                                  I.getOperand(1), /*ScalarReduceCIsLeft=*/false,
+                                  ReplaceValue))
+      return true;
   }
 
+  if (auto *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1))) {
+    if (matchAssociativeReduction(I, BinOpOpc, ReductionIID, Builder, Op1,
+                                  I.getOperand(0), /*ScalarReduceCIsLeft=*/true,
+                                  ReplaceValue))
+      return true;
+  }
+
+
+
   Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
   if (!V0)
     return false;
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
index e2f7f8f7e5cac..27309e6359837 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
@@ -219,8 +219,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP16:%.*]] = sub nsw <16 x i16> [[TMP13]], [[TMP15]]
 ; CHECK-LTO-NEXT:    [[TMP17:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP16]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP68:%.*]] = zext nneg <16 x i16> [[TMP17]] to <16 x i32>
-; CHECK-LTO-NEXT:    [[TMP76:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP68]])
-; CHECK-LTO-NEXT:    [[OP_RDX_2:%.*]] = add nuw nsw i32 [[OP_RDX_1]], [[TMP76]]
+; CHECK-LTO-NEXT:    [[OP_RDX_2:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP68]])
 ; CHECK-LTO-NEXT:    [[ADD_PTR_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_1]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_1]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP18:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -232,6 +231,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP84:%.*]] = zext nneg <16 x i16> [[TMP23]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP92:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP84]])
 ; CHECK-LTO-NEXT:    [[OP_RDX_3:%.*]] = add nuw nsw i32 [[OP_RDX_2]], [[TMP92]]
+; CHECK-LTO-NEXT:    [[OP_RDX_6:%.*]] = add nuw nsw i32 [[OP_RDX_3]], [[OP_RDX_1]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_2]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_2]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP24:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -241,8 +241,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP28:%.*]] = sub nsw <16 x i16> [[TMP25]], [[TMP27]]
 ; CHECK-LTO-NEXT:    [[TMP29:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP28]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP100:%.*]] = zext nneg <16 x i16> [[TMP29]] to <16 x i32>
-; CHECK-LTO-NEXT:    [[TMP108:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP100]])
-; CHECK-LTO-NEXT:    [[OP_RDX_4:%.*]] = add nuw nsw i32 [[OP_RDX_3]], [[TMP108]]
+; CHECK-LTO-NEXT:    [[OP_RDX_4:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP100]])
 ; CHECK-LTO-NEXT:    [[ADD_PTR_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_3]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_3]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP30:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -254,6 +253,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP116:%.*]] = zext nneg <16 x i16> [[TMP35]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP117:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP116]])
 ; CHECK-LTO-NEXT:    [[OP_RDX_5:%.*]] = add nuw nsw i32 [[OP_RDX_4]], [[TMP117]]
+; CHECK-LTO-NEXT:    [[OP_RDX_8:%.*]] = add nuw nsw i32 [[OP_RDX_5]], [[OP_RDX_6]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_4]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_4]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP37:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -264,7 +264,6 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP42:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP41]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP43:%.*]] = zext nneg <16 x i16> [[TMP42]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP118:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP43]])
-; CHECK-LTO-NEXT:    [[OP_RDX_6:%.*]] = add i32 [[OP_RDX_5]], [[TMP118]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_6:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_5]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_6:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_5]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP45:%.*]] = load <16 x i8>, ptr [[ADD_PTR_6]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -275,7 +274,8 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP50:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP49]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP51:%.*]] = zext nneg <16 x i16> [[TMP50]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP120:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP51]])
-; CHECK-LTO-NEXT:    [[OP_RDX_7:%.*]] = add i32 [[OP_RDX_6]], [[TMP120]]
+; CHECK-LTO-NEXT:    [[TMP76:%.*]] = add nuw nsw i32 [[TMP118]], [[TMP120]]
+; CHECK-LTO-NEXT:    [[OP_RDX_7:%.*]] = add nuw nsw i32 [[TMP76]], [[OP_RDX_8]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_7:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_6]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_7:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_6]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP53:%.*]] = load <16 x i8>, ptr [[ADD_PTR_7]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -286,7 +286,6 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP58:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP57]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP59:%.*]] = zext nneg <16 x i16> [[TMP58]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP121:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP59]])
-; CHECK-LTO-NEXT:    [[OP_RDX_8:%.*]] = add i32 [[OP_RDX_7]], [[TMP121]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_8:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_7]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_8:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_7]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP61:%.*]] = load <16 x i8>, ptr [[ADD_PTR_8]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -297,7 +296,8 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP66:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP65]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP67:%.*]] = zext nneg <16 x i16> [[TMP66]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP122:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP67]])
-; CHECK-LTO-NEXT:    [[OP_RDX_9:%.*]] = add i32 [[OP_RDX_8]], [[TMP122]]
+; CHECK-LTO-NEXT:    [[TMP108:%.*]] = add nuw nsw i32 [[TMP121]], [[TMP122]]
+; CHECK-LTO-NEXT:    [[TMP124:%.*]] = add nuw nsw i32 [[TMP108]], [[OP_RDX_7]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_9:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_8]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_9:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_8]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP69:%.*]] = load <16 x i8>, ptr [[ADD_PTR_9]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -308,7 +308,6 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP74:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP73]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP75:%.*]] = zext nneg <16 x i16> [[TMP74]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP123:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP75]])
-; CHECK-LTO-NEXT:    [[OP_RDX_10:%.*]] = add i32 [[OP_RDX_9]], [[TMP123]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_10:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_9]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_10:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_9]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP77:%.*]] = load <16 x i8>, ptr [[ADD_PTR_10]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -318,7 +317,8 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP81:%.*]] = sub nsw <16 x i16> [[TMP78]], [[TMP80]]
 ; CHECK-LTO-NEXT:    [[TMP82:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP81]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP83:%.*]] = zext nneg <16 x i16> [[TMP82]] to <16 x i32>
-; CHECK-LTO-NEXT:    [[TMP124:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP83]])
+; CHECK-LTO-NEXT:    [[TMP128:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP83]])
+; CHECK-LTO-NEXT:    [[OP_RDX_10:%.*]] = add nuw nsw i32 [[TMP123]], [[TMP128]]
 ; CHECK-LTO-NEXT:    [[OP_RDX_11:%.*]] = add i32 [[OP_RDX_10]], [[TMP124]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_11:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_10]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_11:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_10]], i64 [[IDX_EXT8]]
@@ -330,7 +330,6 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP90:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP89]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP91:%.*]] = zext nneg <16 x i16> [[TMP90]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP125:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP91]])
-; CHECK-LTO-NEXT:    [[OP_RDX_12:%.*]] = add i32 [[OP_RDX_11]], [[TMP125]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_12:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_11]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_12:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_11]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP93:%.*]] = load <16 x i8>, ptr [[ADD_PTR_12]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -341,7 +340,8 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP98:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP97]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP99:%.*]] = zext nneg <16 x i16> [[TMP98]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP126:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP99]])
-; CHECK-LTO-NEXT:    [[OP_RDX_13:%.*]] = add i32 [[OP_RDX_12]], [[TMP126]]
+; CHECK-LTO-NEXT:    [[TMP129:%.*]] = add nuw nsw i32 [[TMP125]], [[TMP126]]
+; CHECK-LTO-NEXT:    [[TMP127:%.*]] = add i32 [[TMP129]], [[OP_RDX_11]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_13:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_12]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_13:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_12]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP101:%.*]] = load <16 x i8>, ptr [[ADD_PTR_13]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -352,7 +352,6 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP106:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP105]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP107:%.*]] = zext nneg <16 x i16> [[TMP106]] to <16 x i32>
 ; CHECK-LTO-NEXT:    [[TMP119:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP107]])
-; CHECK-LTO-NEXT:    [[OP_RDX_14:%.*]] = add i32 [[OP_RDX_13]], [[TMP119]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR_14:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_13]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT:    [[ADD_PTR9_14:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_13]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT:    [[TMP109:%.*]] = load <16 x i8>, ptr [[ADD_PTR_14]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -362,7 +361,8 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT:    [[TMP113:%.*]] = sub nsw <16 x i16> [[TMP110]], [[TMP112]]
 ; CHECK-LTO-NEXT:    [[TMP114:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP113]], i1 true)
 ; CHECK-LTO-NEXT:    [[TMP115:%.*]] = zext nneg <16 x i16> [[TMP114]] to <16 x i32>
-; CHECK-LTO-NEXT:    [[TMP127:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP115]])
+; CHECK-LTO-NEXT:    [[TMP133:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP115]])
+; CHECK-LTO-NEXT:    [[OP_RDX_14:%.*]] = add nuw nsw i32 [[TMP119]], [[TMP133]]
 ; CHECK-LTO-NEXT:    [[OP_RDX_15:%.*]] = add i32 [[OP_RDX_14]], [[TMP127]]
 ; CHECK-LTO-NEXT:    ret i32 [[OP_RDX_15]]
 ;
diff --git a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions-add.ll b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions-add.ll
new file mode 100644
index 0000000000000..db34481932828
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions-add.ll
@@ -0,0 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; Test case for foldBinopOfReductions with Add operator
+; RUN: opt -S -passes=vector-combine %s | FileCheck %s
+
+define i32 @test_add_add_reduction(i32 %a, <4 x i32> %b, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_add_add_reduction(
+; CHECK-SAME: i32 [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[B]], [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[TMP2]], [[A]]
+; CHECK-NEXT:    ret i32 [[ADD2]]
+;
+  ; Test case 1: (reduce(b) + a) + reduce(c) --> reduce(b + c) + a
+  %reduce_b = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %b)
+  %add1 = add i32 %reduce_b, %a
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  %add2 = add i32 %add1, %reduce_c
+  ret i32 %add2
+}
+
+define i32 @test_add_add_reduction_reverse(i32 %a, <4 x i32> %b, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_add_add_reduction_reverse(
+; CHECK-SAME: i32 [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[B]], [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[TMP2]], [[A]]
+; CHECK-NEXT:    ret i32 [[ADD2]]
+;
+  ; Test case 2: (a + reduce(b)) + reduce(c) --> reduce(b + c) + a
+  %reduce_b = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %b)
+  %add1 = add i32 %a, %reduce_b
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  %add2 = add i32 %add1, %reduce_c
+  ret i32 %add2
+}
+
+define i32 @test_add_reduction_add(i32 %c, <4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: define i32 @test_add_reduction_add(
+; CHECK-SAME: i32 [[C:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[B]], [[A]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[TMP2]], [[C]]
+; CHECK-NEXT:    ret i32 [[ADD2]]
+;
+  ; Test case 3: reduce(a) + (reduce(b) + c) --> reduce(b + a) + c
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_b = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %b)
+  %add1 = add i32 %reduce_b, %c
+  %add2 = add i32 %reduce_a, %add1
+  ret i32 %add2
+}
+
+; Declare the vector reduce add intrinsic
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
diff --git a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions-sub.ll b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions-sub.ll
new file mode 100644
index 0000000000000..c6d9771d32713
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions-sub.ll
@@ -0,0 +1,132 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=vector-combine %s | FileCheck %s
+
+define i32 @test_sub_add_reduction_1(i32 %b, <4 x i32> %a, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_sub_add_reduction_1(
+; CHECK-SAME: i32 [[B:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <4 x i32> [[A]], [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[TMP2]], [[B]]
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  ; (ReduceA + B) - ReduceC --> Reduce(A - C) + B
+  %add = add i32 %reduce_a, %b
+  %sub = sub i32 %add, %reduce_c
+  ret i32 %sub
+}
+
+define i32 @test_sub_add_reduction_2(i32 %b, <4 x i32> %a, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_sub_add_reduction_2(
+; CHECK-SAME: i32 [[B:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <4 x i32> [[A]], [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[TMP2]], [[B]]
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  ; (B + ReduceA) - ReduceC --> Reduce(A - C) + B
+  %add = add i32 %b, %reduce_a
+  %sub = sub i32 %add, %reduce_c
+  ret i32 %sub
+}
+
+define i32 @test_sub_sub_reduction_1(i32 %b, <4 x i32> %a, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_sub_sub_reduction_1(
+; CHECK-SAME: i32 [[B:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <4 x i32> [[A]], [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[TMP2]], [[B]]
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  ; (ReduceA - B) - ReduceC --> Reduce(A - C) - B
+  %op0 = sub i32 %reduce_a, %b
+  %sub = sub i32 %op0, %reduce_c
+  ret i32 %sub
+}
+
+define i32 @test_sub_sub_reduction_2(i32 %b, <4 x i32> %a, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_sub_sub_reduction_2(
+; CHECK-SAME: i32 [[B:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[A]], [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[B]], [[TMP2]]
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  ; (B - ReduceA) - ReduceC --> B - Reduce(A + C)
+  %op0 = sub i32 %b, %reduce_a
+  %sub = sub i32 %op0, %reduce_c
+  ret i32 %sub
+}
+
+define i32 @test_sub_add_reduction_3(i32 %b, <4 x i32> %a, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_sub_add_reduction_3(
+; CHECK-SAME: i32 [[B:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <4 x i32> [[C]], [[A]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[TMP2]], [[B]]
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  ; ReduceC - (ReduceA + B) --> Reduce(C - A) - B
+  %op0 = add i32 %reduce_a, %b
+  %sub = sub i32 %reduce_c, %op0
+  ret i32 %sub
+}
+
+define i32 @test_sub_add_reduction_4(i32 %b, <4 x i32> %a, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_sub_add_reduction_4(
+; CHECK-SAME: i32 [[B:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <4 x i32> [[C]], [[A]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[TMP2]], [[B]]
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  ; ReduceC - (B + ReduceA) --> Reduce(C - A) - B
+  %op0 = add i32 %b, %reduce_a
+  %sub = sub i32 %reduce_c, %op0
+  ret i32 %sub
+}
+
+define i32 @test_sub_sub_reduction_3(i32 %b, <4 x i32> %a, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_sub_sub_reduction_3(
+; CHECK-SAME: i32 [[B:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <4 x i32> [[C]], [[A]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[TMP2]], [[B]]
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  ; ReduceC - (ReduceA - B) --> Reduce(C - A) + B
+  %op0 = sub i32 %reduce_a, %b
+  %sub = sub i32 %reduce_c, %op0
+  ret i32 %sub
+}
+
+define i32 @test_sub_sub_reduction_4(i32 %b, <4 x i32> %a, <4 x i32> %c) {
+; CHECK-LABEL: define i32 @test_sub_sub_reduction_4(
+; CHECK-SAME: i32 [[B:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[C]], [[A]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[TMP2]], [[B]]
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+  %reduce_a = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
+  %reduce_c = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
+  ; ReduceC - (B - ReduceA) --> Reduce(C + A) - B
+  %op0 = sub i32 %b, %reduce_a
+  %sub = sub i32 %reduce_c, %op0
+  ret i32 %sub
+}
+
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
diff --git a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
index 22960119ce056..e2ddc8725bdd5 100644
--- a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
+++ b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
@@ -226,7 +226,7 @@ define i32 @sub_add_s_reduction_reduction(<vscale x 8 x i32> %v0, <vscale x 8 x
 ; CHECK-SAME: <vscale x 8 x i32> [[V0:%.*]], <vscale x 8 x i32> [[V1:%.*]], i32 [[S1:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = sub <vscale x 8 x i32> [[V0]], [[V1]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP1]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[S1]], [[TMP2]]
+; CHECK-NEXT:    [[RES:%.*]] = add i32 [[TMP2]], [[S1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.add.v8i32(<vscale x 8 x i32> %v0)

>From 3315b61beb8df393297c1a7fba653b8ff42ec87b Mon Sep 17 00:00:00 2001
From: Anjian-Wen <wenanjian at bytedance.com>
Date: Mon, 2 Mar 2026 14:28:08 +0800
Subject: [PATCH 4/5] Convert matchAssociativeReduction into a member method of
 VectorCombine

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 70c7c5a6c67cb..55dfd5ea1f8e9 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -206,6 +206,13 @@ class VectorCombine {
       }
     }
   }
+
+  template <typename IRBuilderTy>
+  bool matchAssociativeReduction(
+      Instruction &I, Instruction::BinaryOps BinOpOpc,
+      Intrinsic::ID ReductionIID, IRBuilderTy &Builder, BinaryOperator *Op0,
+      Value *ScalarReduceC, bool ScalarReduceCIsLeft,
+      function_ref<void(Instruction &, Value &)> ReplaceValue);
 };
 } // namespace
 
@@ -1750,7 +1757,7 @@ static Value *checkIntrinsicAndGetItsArgument(Value *V, Intrinsic::ID IID) {
 }
 
 template <typename IRBuilderTy>
-static bool matchAssociativeReduction(
+bool VectorCombine::matchAssociativeReduction(
     Instruction &I, Instruction::BinaryOps BinOpOpc, Intrinsic::ID ReductionIID,
     IRBuilderTy &Builder, BinaryOperator *Op0, Value *ScalarReduceC,
     bool ScalarReduceCIsLeft,
@@ -1813,7 +1820,7 @@ static bool matchAssociativeReduction(
         RHS_Bin = nullptr;
       }
     }
-  } else { // Outer == Sub
+  } else {                      // Outer == Sub
     if (!ScalarReduceCIsLeft) { // (Op0 - C)
       if (InnerOp == Instruction::Add) {
         // (X + Y) - C -> (X - C) + Y
@@ -1873,8 +1880,7 @@ static bool matchAssociativeReduction(
     }
   }
 
-  Value *CombineNode =
-      Builder.CreateBinOp(NewReduceOp, LHS_Reduce, RHS_Reduce);
+  Value *CombineNode = Builder.CreateBinOp(NewReduceOp, LHS_Reduce, RHS_Reduce);
 
   Value *NewBinNode;
   if (LHS_Bin == nullptr)
@@ -1910,8 +1916,8 @@ bool VectorCombine::foldBinopOfReductions(Instruction &I) {
   //   Reduce(Z) - (Reduce(X) + Y)  -> Reduce(Z - X) - Y
   if (auto *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0))) {
     if (matchAssociativeReduction(I, BinOpOpc, ReductionIID, Builder, Op0,
-                                  I.getOperand(1), /*ScalarReduceCIsLeft=*/false,
-                                  ReplaceValue))
+                                  I.getOperand(1),
+                                  /*ScalarReduceCIsLeft=*/false, ReplaceValue))
       return true;
   }
 
@@ -1922,8 +1928,6 @@ bool VectorCombine::foldBinopOfReductions(Instruction &I) {
       return true;
   }
 
-
-
   Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
   if (!V0)
     return false;

>From 9f09cc9a096f3361864789b438246b30ae2d50c1 Mon Sep 17 00:00:00 2001
From: Anjian-Wen <wenanjian at bytedance.com>
Date: Fri, 6 Mar 2026 17:17:56 +0800
Subject: [PATCH 5/5] update comments

---
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 55dfd5ea1f8e9..2d61dfc5a29ed 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1905,15 +1905,16 @@ bool VectorCombine::foldBinopOfReductions(Instruction &I) {
 
   auto ReplaceValue = [&](Instruction &I, Value &V) { replaceValue(I, V); };
 
-  // Reduce the number of reductions by folding a binop of a reduction and a
-  // scalar (which might be another reduction) into a single reduction of a
-  // vector binop. This leverages associativity and commutativity of the
-  // binary operation (Add/Sub) to group reductions together.
+  // Reassociate so that the two reductions become direct operands of a single
+  // Add/Sub, which gives the binop-of-reductions fold a chance to merge them.
+  // Only a direct Add/Sub between the two reductions is targeted; whether they
+  // are ultimately folded depends on the later type and cost-model checks.
   //
   // Examples:
-  //   (Reduce(X) + Y) + Reduce(Z)  -> Reduce(X + Z) + Y
-  //   (Reduce(X) - Y) - Reduce(Z)  -> Reduce(X - Z) - Y
-  //   Reduce(Z) - (Reduce(X) + Y)  -> Reduce(Z - X) - Y
+  //   Reduce(X) + Y + Reduce(Z)   -> Reduce(X) + Reduce(Z) + Y
+  //   Reduce(X) - Y - Reduce(Z)   -> Reduce(X) - Reduce(Z) - Y
+  //   Reduce(Z) - (Reduce(X) + Y) -> Reduce(Z) - Reduce(X) - Y
+  //
   if (auto *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0))) {
     if (matchAssociativeReduction(I, BinOpOpc, ReductionIID, Builder, Op0,
                                   I.getOperand(1),



More information about the llvm-commits mailing list