[llvm] [VectorCombine] Fold binary op of reductions. (PR #121567)

Mikhail Gudim via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 08:18:40 PST 2025


https://github.com/mgudim updated https://github.com/llvm/llvm-project/pull/121567

From 4891d75e909bd1539aa0f813240ec0b41023a9e4 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at ventanamicro.com>
Date: Thu, 2 Jan 2025 08:26:22 -0800
Subject: [PATCH 1/2] [InstCombine] Fold binary op of reductions.

Replace a binary op of two reductions with a single reduction of the binary op
applied to the vector operands. For example:

```
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
%res = add i32 %v0_red, %v1_red
```
gets transformed to:

```
%1 = add <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```
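
The same fold also applies to `sub`, because a difference of sums is a sum of
differences. A sketch of that case (operand names illustrative, mirroring the
fold that previously lived in visitSub):

```
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
%res = sub i32 %v0_red, %v1_red
```
gets transformed to:

```
%1 = sub <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```

Both reduction calls must be single-use and reduce vectors of the same type.
Overflow flags (nuw/nsw) on the scalar op are dropped; a disjoint flag on an
or is preserved.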
---
 .../InstCombine/InstCombineAddSub.cpp         | 18 ++----
 .../InstCombine/InstCombineAndOrXor.cpp       |  9 +++
 .../InstCombine/InstCombineInternal.h         |  1 +
 .../InstCombine/InstCombineMulDivRem.cpp      |  3 +
 .../InstCombine/InstructionCombining.cpp      | 57 +++++++++++++++++++
 .../InstCombine/fold-binop-of-reductions.ll   | 40 ++++++-------
 6 files changed, 91 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 658bbbc5697660..29cf75a6e318da 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1528,6 +1528,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -2387,19 +2390,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
-  auto m_AddRdx = [](Value *&Vec) {
-    return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
-  };
-  Value *V0, *V1;
-  if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
-      V0->getType() == V1->getType()) {
-    // Difference of sums is sum of differences:
-    // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
-    Value *Sub = Builder.CreateSub(V0, V1);
-    Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
-                                         {Sub->getType()}, {Sub});
-    return replaceInstUsesWith(I, Rdx);
-  }
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
 
   if (Constant *C = dyn_cast<Constant>(Op0)) {
     Value *X;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ca8a20b4b7312d..b36eaf312837ba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2380,6 +2380,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -3560,6 +3563,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -4671,6 +4677,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 83e1da98deeda0..40a03cc24817de 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -594,6 +594,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
+  Instruction *foldBinopOfReductions(BinaryOperator &Inst);
   Instruction *foldVectorSelect(SelectInst &Sel);
   Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index df5f9833a2ff92..95912991d85e9a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -199,6 +199,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index cad17c511b6d03..380ddf904d397f 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2317,6 +2317,63 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
   return nullptr;
 }
 
+static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
+  switch (Opc) {
+  default:
+    break;
+  case Instruction::Add:
+    return Intrinsic::vector_reduce_add;
+  case Instruction::Mul:
+    return Intrinsic::vector_reduce_mul;
+  case Instruction::And:
+    return Intrinsic::vector_reduce_and;
+  case Instruction::Or:
+    return Intrinsic::vector_reduce_or;
+  case Instruction::Xor:
+    return Intrinsic::vector_reduce_xor;
+  }
+  return Intrinsic::not_intrinsic;
+}
+
+Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
+  Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
+  Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+  if (BinOpOpc == Instruction::Sub)
+    ReductionIID = Intrinsic::vector_reduce_add;
+  if (ReductionIID == Intrinsic::not_intrinsic)
+    return nullptr;
+
+  auto checkIntrinsicAndGetItsArgument = [](Value *V,
+                                            Intrinsic::ID IID) -> Value * {
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
+    if (!II)
+      return nullptr;
+    if (II->getIntrinsicID() == IID && II->hasOneUse())
+      return II->getArgOperand(0);
+    return nullptr;
+  };
+
+  Value *V0 = checkIntrinsicAndGetItsArgument(Inst.getOperand(0), ReductionIID);
+  if (!V0)
+    return nullptr;
+  Value *V1 = checkIntrinsicAndGetItsArgument(Inst.getOperand(1), ReductionIID);
+  if (!V1)
+    return nullptr;
+
+  Type *VTy = V0->getType();
+  if (V1->getType() != VTy)
+    return nullptr;
+
+  Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
+
+  if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&Inst))
+    if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
+      PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
+
+  Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
+  return Rdx;
+}
+
 /// Try to narrow the width of a binop if at least 1 operand is an extend
 /// of a value. This requires a potentially expensive known bits check to make
 /// sure the narrow op does not overflow.
diff --git a/llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll b/llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll
index 86f17cdfb79b42..cc88db03fbe2cd 100644
--- a/llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll
@@ -4,9 +4,8 @@
 define i32 @add_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @add_of_reduce_add(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
@@ -31,9 +30,8 @@ define i32 @sub_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @mul_of_reduce_mul(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @mul_of_reduce_mul(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> %v0)
@@ -45,9 +43,8 @@ define i32 @mul_of_reduce_mul(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @and_of_reduce_and(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @and_of_reduce_and(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = and i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> %v0)
@@ -59,9 +56,8 @@ define i32 @and_of_reduce_and(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @or_of_reduce_or(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @or_of_reduce_or(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = or i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
@@ -73,9 +69,8 @@ define i32 @or_of_reduce_or(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @xor_of_reduce_xor(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @xor_of_reduce_xor(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = xor i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> %v0)
@@ -161,9 +156,8 @@ define i32 @multiple_use_of_reduction_1(<16 x i32> %v0, <16 x i32> %v1, ptr %p)
 define i32 @do_not_preserve_overflow_flags(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @do_not_preserve_overflow_flags(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = add nuw nsw i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
@@ -175,9 +169,8 @@ define i32 @do_not_preserve_overflow_flags(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @preserve_disjoint_flags(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @preserve_disjoint_flags(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = or disjoint i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or disjoint <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
@@ -189,9 +182,8 @@ define i32 @preserve_disjoint_flags(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @add_of_reduce_add_vscale(<vscale x 16 x i32> %v0, <vscale x 16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @add_of_reduce_add_vscale(
 ; CHECK-SAME: <vscale x 16 x i32> [[V0:%.*]], <vscale x 16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <vscale x 16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %v0)

From 776684301f88c5a5c9a68cb9c980a24e2a73f6f5 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at ventanamicro.com>
Date: Wed, 29 Jan 2025 03:58:19 -0800
Subject: [PATCH 2/2] Move the fold to VectorCombine
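
This moves the fold from InstCombine to VectorCombine, where
TargetTransformInfo is available, so the transform is now gated on a cost
comparison instead of firing unconditionally. A sketch of the check added
below (a summary of the patch, not the exact TTI API):

```
OldCost = 2 * ReductionCost + ScalarBinOpCost  ; two reductions + scalar binop
NewCost = ReductionCost + VectorBinOpCost      ; one reduction + vector binop
```

The fold fires only when NewCost < OldCost, unless the testing-only flag
-vector-combine-force-fold-binop-of-reductions is set. The original
add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1) fold is restored in
InstCombine's visitSub.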

---
 .../include/llvm/Transforms/Utils/LoopUtils.h |  2 +
 .../InstCombine/InstCombineAddSub.cpp         | 18 ++++--
 .../InstCombine/InstCombineAndOrXor.cpp       |  9 ---
 .../InstCombine/InstCombineInternal.h         |  1 -
 .../InstCombine/InstCombineMulDivRem.cpp      |  3 -
 .../InstCombine/InstructionCombining.cpp      | 57 ------------------
 llvm/lib/Transforms/Utils/LoopUtils.cpp       | 20 +++++++
 .../Transforms/Vectorize/VectorCombine.cpp    | 59 +++++++++++++++++++
 .../InstCombine/fold-binop-of-reductions.ll   |  2 +-
 9 files changed, 95 insertions(+), 76 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index b4cd52fef70fd2..1007b9d48fb726 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -365,6 +365,8 @@ constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);
 
 /// Returns the arithmetic instruction opcode used when expanding a reduction.
 unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);
+/// Returns the reduction intrinsic id corresponding to the binary operation.
+Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc);
 
 /// Returns the min/max intrinsic used when expanding a min/max reduction.
 Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 29cf75a6e318da..658bbbc5697660 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1528,9 +1528,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
-  if (Instruction *X = foldBinopOfReductions(I))
-    return replaceInstUsesWith(I, X);
-
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -2390,8 +2387,19 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
-  if (Instruction *X = foldBinopOfReductions(I))
-    return replaceInstUsesWith(I, X);
+  auto m_AddRdx = [](Value *&Vec) {
+    return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
+  };
+  Value *V0, *V1;
+  if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
+      V0->getType() == V1->getType()) {
+    // Difference of sums is sum of differences:
+    // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
+    Value *Sub = Builder.CreateSub(V0, V1);
+    Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
+                                         {Sub->getType()}, {Sub});
+    return replaceInstUsesWith(I, Rdx);
+  }
 
   if (Constant *C = dyn_cast<Constant>(Op0)) {
     Value *X;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index b36eaf312837ba..ca8a20b4b7312d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2380,9 +2380,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
-  if (Instruction *X = foldBinopOfReductions(I))
-    return replaceInstUsesWith(I, X);
-
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -3563,9 +3560,6 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
-  if (Instruction *X = foldBinopOfReductions(I))
-    return replaceInstUsesWith(I, X);
-
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -4677,9 +4671,6 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
-  if (Instruction *X = foldBinopOfReductions(I))
-    return replaceInstUsesWith(I, X);
-
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 40a03cc24817de..83e1da98deeda0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -594,7 +594,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
-  Instruction *foldBinopOfReductions(BinaryOperator &Inst);
   Instruction *foldVectorSelect(SelectInst &Sel);
   Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 95912991d85e9a..df5f9833a2ff92 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -199,9 +199,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
-  if (Instruction *X = foldBinopOfReductions(I))
-    return replaceInstUsesWith(I, X);
-
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 380ddf904d397f..cad17c511b6d03 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2317,63 +2317,6 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
   return nullptr;
 }
 
-static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
-  switch (Opc) {
-  default:
-    break;
-  case Instruction::Add:
-    return Intrinsic::vector_reduce_add;
-  case Instruction::Mul:
-    return Intrinsic::vector_reduce_mul;
-  case Instruction::And:
-    return Intrinsic::vector_reduce_and;
-  case Instruction::Or:
-    return Intrinsic::vector_reduce_or;
-  case Instruction::Xor:
-    return Intrinsic::vector_reduce_xor;
-  }
-  return Intrinsic::not_intrinsic;
-}
-
-Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
-  Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
-  Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
-  if (BinOpOpc == Instruction::Sub)
-    ReductionIID = Intrinsic::vector_reduce_add;
-  if (ReductionIID == Intrinsic::not_intrinsic)
-    return nullptr;
-
-  auto checkIntrinsicAndGetItsArgument = [](Value *V,
-                                            Intrinsic::ID IID) -> Value * {
-    IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
-    if (!II)
-      return nullptr;
-    if (II->getIntrinsicID() == IID && II->hasOneUse())
-      return II->getArgOperand(0);
-    return nullptr;
-  };
-
-  Value *V0 = checkIntrinsicAndGetItsArgument(Inst.getOperand(0), ReductionIID);
-  if (!V0)
-    return nullptr;
-  Value *V1 = checkIntrinsicAndGetItsArgument(Inst.getOperand(1), ReductionIID);
-  if (!V1)
-    return nullptr;
-
-  Type *VTy = V0->getType();
-  if (V1->getType() != VTy)
-    return nullptr;
-
-  Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
-
-  if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&Inst))
-    if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
-      PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
-
-  Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
-  return Rdx;
-}
-
 /// Try to narrow the width of a binop if at least 1 operand is an extend
 /// of a value. This requires a potentially expensive known bits check to make
 /// sure the narrow op does not overflow.
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 45915c10107b2e..0506ea915a23f7 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -957,6 +957,7 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
   }
 }
 
+// This is the inverse of getReductionForBinop.
 unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
   switch (RdxID) {
   case Intrinsic::vector_reduce_fadd:
@@ -986,6 +987,25 @@ unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
   }
 }
 
+// This is the inverse of getArithmeticReductionInstruction.
+Intrinsic::ID llvm::getReductionForBinop(Instruction::BinaryOps Opc) {
+  switch (Opc) {
+  default:
+    break;
+  case Instruction::Add:
+    return Intrinsic::vector_reduce_add;
+  case Instruction::Mul:
+    return Intrinsic::vector_reduce_mul;
+  case Instruction::And:
+    return Intrinsic::vector_reduce_and;
+  case Instruction::Or:
+    return Intrinsic::vector_reduce_or;
+  case Instruction::Xor:
+    return Intrinsic::vector_reduce_xor;
+  }
+  return Intrinsic::not_intrinsic;
+}
+
 Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
   switch (RdxID) {
   default:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 59920b5a4dd20a..7d501ebb43cf32 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -60,6 +60,12 @@ static cl::opt<unsigned> MaxInstrsToScan(
     "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
     cl::desc("Max number of instructions to scan for vector combining."));
 
+static cl::opt<bool> ForceFoldBinopOfReductions(
+    "vector-combine-force-fold-binop-of-reductions", cl::init(false),
+    cl::Hidden,
+    cl::desc("Force folding binary op of two reductions even if it is not "
+             "profitable. Should be used for testing only."));
+
 static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
 
 namespace {
@@ -113,6 +119,7 @@ class VectorCombine {
   bool scalarizeBinopOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
+  bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
@@ -1182,6 +1189,57 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::foldBinopOfReductions(Instruction &I) {
+  Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
+  Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+  if (BinOpOpc == Instruction::Sub)
+    ReductionIID = Intrinsic::vector_reduce_add;
+  if (ReductionIID == Intrinsic::not_intrinsic)
+    return false;
+
+  auto checkIntrinsicAndGetItsArgument = [](Value *V,
+                                            Intrinsic::ID IID) -> Value * {
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
+    if (!II)
+      return nullptr;
+    if (II->getIntrinsicID() == IID && II->hasOneUse())
+      return II->getArgOperand(0);
+    return nullptr;
+  };
+
+  Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
+  if (!V0)
+    return false;
+  Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
+  if (!V1)
+    return false;
+
+  VectorType *VTy = cast<VectorType>(V0->getType());
+  if (V1->getType() != VTy)
+    return false;
+
+  InstructionCost ReductionCost =
+      TTI.getArithmeticReductionCost(BinOpOpc, VTy, std::nullopt, CostKind);
+  InstructionCost OldCost =
+      2 * ReductionCost + TTI.getInstructionCost(&I, CostKind);
+  InstructionCost NewCost =
+      ReductionCost + TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind);
+  if (NewCost >= OldCost && !ForceFoldBinopOfReductions)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
+                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
+  Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
+  if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&I))
+    if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
+      PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
+
+  Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
+  replaceValue(I, *Rdx);
+  return true;
+}
+
 // Check if memory loc modified between two instrs in the same BB
 static bool isMemModifiedBetween(BasicBlock::iterator Begin,
                                  BasicBlock::iterator End,
@@ -3241,6 +3299,7 @@ bool VectorCombine::run() {
         if (Instruction::isBinaryOp(Opcode)) {
           MadeChange |= foldExtractExtract(I);
           MadeChange |= foldExtractedCmps(I);
+          MadeChange |= foldBinopOfReductions(I);
         }
         break;
       }
diff --git a/llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll b/llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll
index cc88db03fbe2cd..0119d2b06edd0a 100644
--- a/llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=vector-combine -vector-combine-force-fold-binop-of-reductions=true -S | FileCheck %s
 
 define i32 @add_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @add_of_reduce_add(
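
For reference, the updated test can be exercised directly with the new flag
(file path as in the patch):

```
opt -passes=vector-combine -vector-combine-force-fold-binop-of-reductions=true \
    -S llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll
```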


