[llvm] [VectorCombine] Fold binary op of reductions. (PR #121567)
Mikhail Gudim via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 11 01:18:47 PST 2025
https://github.com/mgudim updated https://github.com/llvm/llvm-project/pull/121567
From a569a5ff32b85ff6a41a8ab9c96e0432e4c6efe0 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at ventanamicro.com>
Date: Thu, 2 Jan 2025 08:26:22 -0800
Subject: [PATCH 1/2] [InstCombine] Fold binary op of reductions.
Replace a binary op of two reductions with one reduction of the binary op
applied to the vector operands. For example:
```
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
%res = add i32 %v0_red, %v1_red
```
gets transformed to:
```
%1 = add <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```
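The subtract case is folded through the add reduction in the same way, since a
difference of sums is the sum of differences. A sketch of the expected IR for
that case (mirroring the `sub_of_reduce_add` test):
```
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
%res = sub i32 %v0_red, %v1_red
```
gets transformed to:
```
%1 = sub <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```
The fold only fires when each reduction call has a single use. A `disjoint`
flag on an `or` is carried over to the vector op; `nuw`/`nsw` flags are
dropped, since they need not hold element-wise.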
---
.../InstCombine/InstCombineAddSub.cpp | 18 ++----
.../InstCombine/InstCombineAndOrXor.cpp | 9 +++
.../InstCombine/InstCombineInternal.h | 1 +
.../InstCombine/InstCombineMulDivRem.cpp | 3 +
.../InstCombine/InstructionCombining.cpp | 57 +++++++++++++++++++
.../VectorCombine/fold-binop-of-reductions.ll | 40 ++++++-------
6 files changed, 91 insertions(+), 37 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 658bbbc56976601..29cf75a6e318da0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1528,6 +1528,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -2387,19 +2390,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
}
- auto m_AddRdx = [](Value *&Vec) {
- return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
- };
- Value *V0, *V1;
- if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
- V0->getType() == V1->getType()) {
- // Difference of sums is sum of differences:
- // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
- Value *Sub = Builder.CreateSub(V0, V1);
- Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
- {Sub->getType()}, {Sub});
- return replaceInstUsesWith(I, Rdx);
- }
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
if (Constant *C = dyn_cast<Constant>(Op0)) {
Value *X;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ca8a20b4b7312d1..b36eaf312837ba5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2380,6 +2380,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -3560,6 +3563,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -4671,6 +4677,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 83e1da98deeda09..40a03cc24817dec 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -594,6 +594,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// Canonicalize the position of binops relative to shufflevector.
Instruction *foldVectorBinop(BinaryOperator &Inst);
+ Instruction *foldBinopOfReductions(BinaryOperator &Inst);
Instruction *foldVectorSelect(SelectInst &Sel);
Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index c8bdf029dd71c37..19e49f41f5b09b6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -199,6 +199,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index a64c188575e6c37..a5e066d8d478874 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2317,6 +2317,63 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
return nullptr;
}
+static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
+ switch (Opc) {
+ default:
+ break;
+ case Instruction::Add:
+ return Intrinsic::vector_reduce_add;
+ case Instruction::Mul:
+ return Intrinsic::vector_reduce_mul;
+ case Instruction::And:
+ return Intrinsic::vector_reduce_and;
+ case Instruction::Or:
+ return Intrinsic::vector_reduce_or;
+ case Instruction::Xor:
+ return Intrinsic::vector_reduce_xor;
+ }
+ return Intrinsic::not_intrinsic;
+}
+
+Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
+ Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
+ Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+ if (BinOpOpc == Instruction::Sub)
+ ReductionIID = Intrinsic::vector_reduce_add;
+ if (ReductionIID == Intrinsic::not_intrinsic)
+ return nullptr;
+
+ auto checkIntrinsicAndGetItsArgument = [](Value *V,
+ Intrinsic::ID IID) -> Value * {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
+ if (!II)
+ return nullptr;
+ if (II->getIntrinsicID() == IID && II->hasOneUse())
+ return II->getArgOperand(0);
+ return nullptr;
+ };
+
+ Value *V0 = checkIntrinsicAndGetItsArgument(Inst.getOperand(0), ReductionIID);
+ if (!V0)
+ return nullptr;
+ Value *V1 = checkIntrinsicAndGetItsArgument(Inst.getOperand(1), ReductionIID);
+ if (!V1)
+ return nullptr;
+
+ Type *VTy = V0->getType();
+ if (V1->getType() != VTy)
+ return nullptr;
+
+ Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
+
+ if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&Inst))
+ if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
+ PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
+
+ Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
+ return Rdx;
+}
+
/// Try to narrow the width of a binop if at least 1 operand is an extend of
/// of a value. This requires a potentially expensive known bits check to make
/// sure the narrow op does not overflow.
diff --git a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
index 86f17cdfb79b426..cc88db03fbe2cdd 100644
--- a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
+++ b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
@@ -4,9 +4,8 @@
define i32 @add_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
; CHECK-LABEL: define i32 @add_of_reduce_add(
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT: [[RES:%.*]] = add i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT: [[TMP1:%.*]] = add <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[RES]]
;
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
@@ -31,9 +30,8 @@ define i32 @sub_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
define i32 @mul_of_reduce_mul(<16 x i32> %v0, <16 x i32> %v1) {
; CHECK-LABEL: define i32 @mul_of_reduce_mul(
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT: [[RES:%.*]] = mul i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT: [[TMP1:%.*]] = mul <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[RES]]
;
%v0_red = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> %v0)
@@ -45,9 +43,8 @@ define i32 @mul_of_reduce_mul(<16 x i32> %v0, <16 x i32> %v1) {
define i32 @and_of_reduce_and(<16 x i32> %v0, <16 x i32> %v1) {
; CHECK-LABEL: define i32 @and_of_reduce_and(
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT: [[RES:%.*]] = and i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[RES]]
;
%v0_red = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> %v0)
@@ -59,9 +56,8 @@ define i32 @and_of_reduce_and(<16 x i32> %v0, <16 x i32> %v1) {
define i32 @or_of_reduce_or(<16 x i32> %v0, <16 x i32> %v1) {
; CHECK-LABEL: define i32 @or_of_reduce_or(
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT: [[RES:%.*]] = or i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT: [[TMP1:%.*]] = or <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[RES]]
;
%v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
@@ -73,9 +69,8 @@ define i32 @or_of_reduce_or(<16 x i32> %v0, <16 x i32> %v1) {
define i32 @xor_of_reduce_xor(<16 x i32> %v0, <16 x i32> %v1) {
; CHECK-LABEL: define i32 @xor_of_reduce_xor(
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT: [[RES:%.*]] = xor i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT: [[TMP1:%.*]] = xor <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[RES]]
;
%v0_red = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> %v0)
@@ -161,9 +156,8 @@ define i32 @multiple_use_of_reduction_1(<16 x i32> %v0, <16 x i32> %v1, ptr %p)
define i32 @do_not_preserve_overflow_flags(<16 x i32> %v0, <16 x i32> %v1) {
; CHECK-LABEL: define i32 @do_not_preserve_overflow_flags(
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT: [[TMP1:%.*]] = add <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[RES]]
;
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
@@ -175,9 +169,8 @@ define i32 @do_not_preserve_overflow_flags(<16 x i32> %v0, <16 x i32> %v1) {
define i32 @preserve_disjoint_flags(<16 x i32> %v0, <16 x i32> %v1) {
; CHECK-LABEL: define i32 @preserve_disjoint_flags(
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT: [[RES:%.*]] = or disjoint i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT: [[TMP1:%.*]] = or disjoint <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[RES]]
;
%v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
@@ -189,9 +182,8 @@ define i32 @preserve_disjoint_flags(<16 x i32> %v0, <16 x i32> %v1) {
define i32 @add_of_reduce_add_vscale(<vscale x 16 x i32> %v0, <vscale x 16 x i32> %v1) {
; CHECK-LABEL: define i32 @add_of_reduce_add_vscale(
; CHECK-SAME: <vscale x 16 x i32> [[V0:%.*]], <vscale x 16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[V0]])
-; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[V1]])
-; CHECK-NEXT: [[RES:%.*]] = add i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT: [[TMP1:%.*]] = add <vscale x 16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[RES]]
;
%v0_red = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %v0)
From 6b25c35303d87f2fba39d8afd2bdd4523e6a11ac Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at ventanamicro.com>
Date: Wed, 29 Jan 2025 03:58:19 -0800
Subject: [PATCH 2/2] Improve cost model

Move the fold from InstCombine to VectorCombine so it can be gated on a
TTI cost comparison between the original pair of reductions and the
folded form.
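Roughly, the comparison condenses to the following sketch (variable names
follow the patch; see foldBinopOfReductions below):

```
InstructionCost OldCost = CostOfRed0 + CostOfRed1 +
                          TTI.getInstructionCost(&I, CostKind);
InstructionCost NewCost = CostOfRedOperand0 + CostOfRedOperand1 +
                          TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
                          TTI.getArithmeticReductionCost(ReductionOpc, VTy,
                                                         std::nullopt, CostKind);
// Fold only when the folded form is strictly cheaper (NewCost < OldCost).
```

The per-operand costs come from analyzeCostOfVecReduction, which recognizes
zext/sext operands and ext(mul(ext, ext)) chains, so targets with cheap
extended or multiply-accumulate reductions keep the original form.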
---
.../include/llvm/Transforms/Utils/LoopUtils.h | 2 +
.../InstCombine/InstCombineAddSub.cpp | 18 ++-
.../InstCombine/InstCombineAndOrXor.cpp | 9 --
.../InstCombine/InstCombineInternal.h | 1 -
.../InstCombine/InstCombineMulDivRem.cpp | 3 -
.../InstCombine/InstructionCombining.cpp | 57 --------
llvm/lib/Transforms/Utils/LoopUtils.cpp | 20 +++
.../Transforms/Vectorize/VectorCombine.cpp | 130 ++++++++++++++++++
.../VectorCombine/fold-binop-of-reductions.ll | 2 +-
9 files changed, 166 insertions(+), 76 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index b4cd52fef70fd24..1007b9d48fb726d 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -365,6 +365,8 @@ constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);
/// Returns the arithmetic instruction opcode used when expanding a reduction.
unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);
+/// Returns the reduction intrinsic id corresponding to the binary operation.
+Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc);
/// Returns the min/max intrinsic used when expanding a min/max reduction.
Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 29cf75a6e318da0..658bbbc56976601 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1528,9 +1528,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
- if (Instruction *X = foldBinopOfReductions(I))
- return replaceInstUsesWith(I, X);
-
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -2390,8 +2387,19 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
}
- if (Instruction *X = foldBinopOfReductions(I))
- return replaceInstUsesWith(I, X);
+ auto m_AddRdx = [](Value *&Vec) {
+ return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
+ };
+ Value *V0, *V1;
+ if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
+ V0->getType() == V1->getType()) {
+ // Difference of sums is sum of differences:
+ // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
+ Value *Sub = Builder.CreateSub(V0, V1);
+ Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
+ {Sub->getType()}, {Sub});
+ return replaceInstUsesWith(I, Rdx);
+ }
if (Constant *C = dyn_cast<Constant>(Op0)) {
Value *X;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index b36eaf312837ba5..ca8a20b4b7312d1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2380,9 +2380,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
- if (Instruction *X = foldBinopOfReductions(I))
- return replaceInstUsesWith(I, X);
-
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -3563,9 +3560,6 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
- if (Instruction *X = foldBinopOfReductions(I))
- return replaceInstUsesWith(I, X);
-
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -4677,9 +4671,6 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
- if (Instruction *X = foldBinopOfReductions(I))
- return replaceInstUsesWith(I, X);
-
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 40a03cc24817dec..83e1da98deeda09 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -594,7 +594,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// Canonicalize the position of binops relative to shufflevector.
Instruction *foldVectorBinop(BinaryOperator &Inst);
- Instruction *foldBinopOfReductions(BinaryOperator &Inst);
Instruction *foldVectorSelect(SelectInst &Sel);
Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 19e49f41f5b09b6..c8bdf029dd71c37 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -199,9 +199,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
- if (Instruction *X = foldBinopOfReductions(I))
- return replaceInstUsesWith(I, X);
-
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index a5e066d8d478874..a64c188575e6c37 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2317,63 +2317,6 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
return nullptr;
}
-static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
- switch (Opc) {
- default:
- break;
- case Instruction::Add:
- return Intrinsic::vector_reduce_add;
- case Instruction::Mul:
- return Intrinsic::vector_reduce_mul;
- case Instruction::And:
- return Intrinsic::vector_reduce_and;
- case Instruction::Or:
- return Intrinsic::vector_reduce_or;
- case Instruction::Xor:
- return Intrinsic::vector_reduce_xor;
- }
- return Intrinsic::not_intrinsic;
-}
-
-Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
- Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
- Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
- if (BinOpOpc == Instruction::Sub)
- ReductionIID = Intrinsic::vector_reduce_add;
- if (ReductionIID == Intrinsic::not_intrinsic)
- return nullptr;
-
- auto checkIntrinsicAndGetItsArgument = [](Value *V,
- Intrinsic::ID IID) -> Value * {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
- if (!II)
- return nullptr;
- if (II->getIntrinsicID() == IID && II->hasOneUse())
- return II->getArgOperand(0);
- return nullptr;
- };
-
- Value *V0 = checkIntrinsicAndGetItsArgument(Inst.getOperand(0), ReductionIID);
- if (!V0)
- return nullptr;
- Value *V1 = checkIntrinsicAndGetItsArgument(Inst.getOperand(1), ReductionIID);
- if (!V1)
- return nullptr;
-
- Type *VTy = V0->getType();
- if (V1->getType() != VTy)
- return nullptr;
-
- Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
-
- if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&Inst))
- if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
- PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
-
- Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
- return Rdx;
-}
-
/// Try to narrow the width of a binop if at least 1 operand is an extend of
/// of a value. This requires a potentially expensive known bits check to make
/// sure the narrow op does not overflow.
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 45915c10107b2eb..0506ea915a23f7f 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -957,6 +957,7 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
}
}
+// This is the inverse to getReductionForBinop
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
switch (RdxID) {
case Intrinsic::vector_reduce_fadd:
@@ -986,6 +987,25 @@ unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
}
}
+// This is the inverse to getArithmeticReductionInstruction
+Intrinsic::ID llvm::getReductionForBinop(Instruction::BinaryOps Opc) {
+ switch (Opc) {
+ default:
+ break;
+ case Instruction::Add:
+ return Intrinsic::vector_reduce_add;
+ case Instruction::Mul:
+ return Intrinsic::vector_reduce_mul;
+ case Instruction::And:
+ return Intrinsic::vector_reduce_and;
+ case Instruction::Or:
+ return Intrinsic::vector_reduce_or;
+ case Instruction::Xor:
+ return Intrinsic::vector_reduce_xor;
+ }
+ return Intrinsic::not_intrinsic;
+}
+
Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
switch (RdxID) {
default:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 59920b5a4dd20ab..65ab05e66cbde2b 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -113,6 +113,7 @@ class VectorCombine {
bool scalarizeBinopOrCmp(Instruction &I);
bool scalarizeVPIntrinsic(Instruction &I);
bool foldExtractedCmps(Instruction &I);
+ bool foldBinopOfReductions(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
bool foldConcatOfBoolMasks(Instruction &I);
@@ -1182,6 +1183,134 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
return true;
}
+static void analyzeCostOfVecReduction(const IntrinsicInst &II,
+ TTI::TargetCostKind CostKind,
+ const TargetTransformInfo &TTI,
+ InstructionCost &CostBeforeReduction,
+ InstructionCost &CostAfterReduction) {
+ using namespace llvm::PatternMatch;
+ Instruction *Op0, *Op1;
+ Instruction *RedOp = dyn_cast<Instruction>(II.getOperand(0));
+ VectorType *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
+ unsigned ReductionOpc =
+ getArithmeticReductionInstruction(II.getIntrinsicID());
+ if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
+ bool IsUnsigned = isa<ZExtInst>(RedOp);
+ VectorType *ExtType =
+ cast<VectorType>(RedOp->getOperand(0)->getType());
+
+ CostBeforeReduction =
+ TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
+ TTI::CastContextHint::None, CostKind, RedOp);
+ CostAfterReduction =
+ TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
+ ExtType, FastMathFlags(), CostKind);
+ return;
+ }
+ if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
+ match(RedOp,
+ m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
+ match(Op0, m_ZExtOrSExt(m_Value())) &&
+ Op0->getOpcode() == Op1->getOpcode() &&
+ Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
+ (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
+ // Matched reduce.add(ext(mul(ext(A), ext(B))))
+ bool IsUnsigned = isa<ZExtInst>(Op0);
+ VectorType *ExtType =
+ cast<VectorType>(Op0->getOperand(0)->getType());
+ VectorType *MulType = VectorType::get(Op0->getType(), VecRedTy);
+
+ InstructionCost ExtCost =
+ TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
+ TTI::CastContextHint::None, CostKind, Op0);
+ InstructionCost MulCost =
+ TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
+ InstructionCost Ext2Cost =
+ TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
+ TTI::CastContextHint::None, CostKind, RedOp);
+
+ CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
+ CostAfterReduction =
+ TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
+ return;
+ }
+ CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
+ std::nullopt, CostKind);
+ return;
+}
+
+bool VectorCombine::foldBinopOfReductions(Instruction &I) {
+ Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
+ Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+ if (BinOpOpc == Instruction::Sub)
+ ReductionIID = Intrinsic::vector_reduce_add;
+ if (ReductionIID == Intrinsic::not_intrinsic)
+ return false;
+
+ auto checkIntrinsicAndGetItsArgument = [](Value *V,
+ Intrinsic::ID IID) -> Value * {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
+ if (!II)
+ return nullptr;
+ if (II->getIntrinsicID() == IID && II->hasOneUse())
+ return II->getArgOperand(0);
+ return nullptr;
+ };
+
+ Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
+ if (!V0)
+ return false;
+ Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
+ if (!V1)
+ return false;
+
+ VectorType *VTy = cast<VectorType>(V0->getType());
+ if (V1->getType() != VTy)
+ return false;
+ const IntrinsicInst &II0 = *cast<IntrinsicInst>(I.getOperand(0));
+ const IntrinsicInst &II1 = *cast<IntrinsicInst>(I.getOperand(1));
+ unsigned ReductionOpc =
+ getArithmeticReductionInstruction(II0.getIntrinsicID());
+
+ InstructionCost OldCost = 0;
+ InstructionCost NewCost = 0;
+ InstructionCost CostOfRedOperand0 = 0;
+ InstructionCost CostOfRed0 = 0;
+ InstructionCost CostOfRedOperand1 = 0;
+ InstructionCost CostOfRed1 = 0;
+ analyzeCostOfVecReduction(II0, CostKind, TTI, CostOfRedOperand0, CostOfRed0);
+ analyzeCostOfVecReduction(II1, CostKind, TTI, CostOfRedOperand1, CostOfRed1);
+ OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(&I, CostKind);
+ NewCost =
+ CostOfRedOperand0 + CostOfRedOperand1 +
+ TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
+ TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind);
+ LLVM_DEBUG(
+ dbgs() << "CostOfRed0: " << CostOfRed0 << "\n";
+ dbgs() << "CostOfRed1: " << CostOfRed1 << "\n";
+ dbgs() << "Cost of scalar op: " << TTI.getInstructionCost(&I, CostKind) << "\n";
+ dbgs() << "CostOfRedOperand0: " << CostOfRedOperand0 << "\n";
+ dbgs() << "CostOfRedOperand1: " << CostOfRedOperand1 << "\n";
+ dbgs() << "Cost of vector op: " << TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) << "\n";
+ dbgs() << "Cost of final reduction: "<< TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind) << "\n";
+ dbgs() << "OldCost: " << OldCost << ", NewCost: " << NewCost << "\n";
+ );
+ if (NewCost >= OldCost || !NewCost.isValid())
+ return false;
+
+ LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
+ << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
+ << "\n");
+ Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
+ if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&I))
+ if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
+ PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
+
+ Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
+ replaceValue(I, *Rdx);
+ return true;
+}
+
// Check if memory loc modified between two instrs in the same BB
static bool isMemModifiedBetween(BasicBlock::iterator Begin,
BasicBlock::iterator End,
@@ -3241,6 +3370,7 @@ bool VectorCombine::run() {
if (Instruction::isBinaryOp(Opcode)) {
MadeChange |= foldExtractExtract(I);
MadeChange |= foldExtractedCmps(I);
+ MadeChange |= foldBinopOfReductions(I);
}
break;
}
diff --git a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
index cc88db03fbe2cdd..5f29af9de5a39a7 100644
--- a/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
+++ b/llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll
@@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=vector-combine -S | FileCheck %s
define i32 @add_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
; CHECK-LABEL: define i32 @add_of_reduce_add(