[llvm] [InstCombine] Fold vector.reduce.op(vector.reverse(X)) -> vector.reduce.op(X) (PR #91743)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Wed May 15 06:54:59 PDT 2024
https://github.com/david-arm updated https://github.com/llvm/llvm-project/pull/91743
>From f8cdf5cb45223ca08b0e74a9b1516c787f358e1c Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 10 May 2024 13:43:37 +0000
Subject: [PATCH 1/2] [InstCombine] Fold vector.reduce.op(vector.reverse(X)) ->
vector.reduce.op(X)
For all of the following reductions:
vector.reduce.or
vector.reduce.and
vector.reduce.xor
vector.reduce.add
vector.reduce.mul
vector.reduce.umin
vector.reduce.umax
vector.reduce.smin
vector.reduce.smax
vector.reduce.fmin
vector.reduce.fmax
if the input operand is the result of a vector.reverse, then we
can perform the reduction on the vector.reverse input instead,
since the result is identical. If reassociation is permitted, we
can also do the same folds for these:
vector.reduce.fadd
vector.reduce.fmul
---
.../InstCombine/InstCombineCalls.cpp | 67 +++++++-
.../InstCombine/vector-logical-reductions.ll | 72 ++++++++
.../InstCombine/vector-reductions.ll | 162 ++++++++++++++++++
3 files changed, 300 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 77534e0d36131..99927b06cc181 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3223,6 +3223,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// %res = cmp eq iReduxWidth %val, 11111
Value *Arg = II->getArgOperand(0);
Value *Vect;
+ // When doing a logical reduction of a reversed operand the result is
+ // identical to reducing the unreversed operand.
+ if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+ Value *Res = IID == Intrinsic::vector_reduce_or
+ ? Builder.CreateOrReduce(Vect)
+ : Builder.CreateAndReduce(Vect);
+ return replaceInstUsesWith(CI, Res);
+ }
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3254,6 +3262,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// Trunc(ctpop(bitcast <n x i1> to in)).
Value *Arg = II->getArgOperand(0);
Value *Vect;
+ // When doing an integer add reduction of a reversed operand the result
+ // is identical to reducing the unreversed operand.
+ if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+ Value *Res = Builder.CreateAddReduce(Vect);
+ return replaceInstUsesWith(CI, Res);
+ }
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3282,6 +3296,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// ?ext(vector_reduce_add(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;
+ // When doing a xor reduction of a reversed operand the result is
+ // identical to reducing the unreversed operand.
+ if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+ Value *Res = Builder.CreateXorReduce(Vect);
+ return replaceInstUsesWith(CI, Res);
+ }
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3305,6 +3325,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// zext(vector_reduce_and(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;
+ // When doing a mul reduction of a reversed operand the result is
+ // identical to reducing the unreversed operand.
+ if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+ Value *Res = Builder.CreateMulReduce(Vect);
+ return replaceInstUsesWith(CI, Res);
+ }
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3329,6 +3355,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// ?ext(vector_reduce_{and,or}(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;
+ // When doing a min/max reduction of a reversed operand the result is
+ // identical to reducing the unreversed operand.
+ if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+ Value *Res = IID == Intrinsic::vector_reduce_umin
+ ? Builder.CreateIntMinReduce(Vect, false)
+ : Builder.CreateIntMaxReduce(Vect, false);
+ return replaceInstUsesWith(CI, Res);
+ }
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3364,6 +3398,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// zext(vector_reduce_{and,or}(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;
+ // When doing a min/max reduction of a reversed operand the result is
+ // identical to reducing the unreversed operand.
+ if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+ Value *Res = IID == Intrinsic::vector_reduce_smin
+ ? Builder.CreateIntMinReduce(Vect, true)
+ : Builder.CreateIntMaxReduce(Vect, true);
+ return replaceInstUsesWith(CI, Res);
+ }
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3395,8 +3437,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
: 0;
Value *Arg = II->getArgOperand(ArgIdx);
Value *V;
+
+ if (!CanBeReassociated)
+ break;
+
+ if (match(Arg, m_VecReverse(m_Value(V)))) {
+ Value *Res;
+ switch (IID) {
+ case Intrinsic::vector_reduce_fadd:
+ Res = Builder.CreateFAddReduce(II->getArgOperand(0), V);
+ break;
+ case Intrinsic::vector_reduce_fmul:
+ Res = Builder.CreateFMulReduce(II->getArgOperand(0), V);
+ break;
+ case Intrinsic::vector_reduce_fmin:
+ Res = Builder.CreateFPMinReduce(V);
+ break;
+ case Intrinsic::vector_reduce_fmax:
+ Res = Builder.CreateFPMaxReduce(V);
+ break;
+ }
+ return replaceInstUsesWith(CI, Res);
+ }
+
ArrayRef<int> Mask;
- if (!isa<FixedVectorType>(Arg->getType()) || !CanBeReassociated ||
+ if (!isa<FixedVectorType>(Arg->getType()) ||
!match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
!cast<ShuffleVectorInst>(Arg)->isSingleSource())
break;
diff --git a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
index 9bb307ebf71e8..da4a0ca754680 100644
--- a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
@@ -21,5 +21,77 @@ define i1 @reduction_logical_and(<4 x i1> %x) {
ret i1 %r
}
+define i1 @reduction_logical_or_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_or_reverse_nxv2i1(
+; CHECK-NEXT: [[RED:%.*]] = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT: ret i1 [[RED]]
+;
+ %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+ %red = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %rev)
+ ret i1 %red
+}
+
+define i1 @reduction_logical_or_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_or_reverse_v2i1(
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT: [[RED:%.*]] = icmp ne i2 [[TMP1]], 0
+; CHECK-NEXT: ret i1 [[RED]]
+;
+ %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+ %red = call i1 @llvm.vector.reduce.or.v2i1(<2 x i1> %rev)
+ ret i1 %red
+}
+
+define i1 @reduction_logical_and_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_and_reverse_nxv2i1(
+; CHECK-NEXT: [[RED:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT: ret i1 [[RED]]
+;
+ %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+ %red = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> %rev)
+ ret i1 %red
+}
+
+define i1 @reduction_logical_and_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_and_reverse_v2i1(
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT: [[RED:%.*]] = icmp eq i2 [[TMP1]], -1
+; CHECK-NEXT: ret i1 [[RED]]
+;
+ %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+ %red = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %rev)
+ ret i1 %red
+}
+
+define i1 @reduction_logical_xor_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_xor_reverse_nxv2i1(
+; CHECK-NEXT: [[RED:%.*]] = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT: ret i1 [[RED]]
+;
+ %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+ %red = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> %rev)
+ ret i1 %red
+}
+
+define i1 @reduction_logical_xor_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_xor_reverse_v2i1(
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT: [[TMP2:%.*]] = call range(i2 0, -1) i2 @llvm.ctpop.i2(i2 [[TMP1]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i2 [[TMP2]] to i1
+; CHECK-NEXT: ret i1 [[RED]]
+;
+ %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+ %red = call i1 @llvm.vector.reduce.xor.v2i1(<2 x i1> %rev)
+ ret i1 %red
+}
+
declare i1 @llvm.vector.reduce.or.v4i1(<4 x i1>)
+declare i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.or.v2i1(<2 x i1>)
declare i1 @llvm.vector.reduce.and.v4i1(<4 x i1>)
+declare i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.and.v2i1(<2 x i1>)
+declare i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.xor.v2i1(<2 x i1>)
+declare <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1>)
+declare <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1>)
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 2614ffd386952..3e2a23a5ef64e 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -3,12 +3,29 @@
declare float @llvm.vector.reduce.fadd.f32.v4f32(float, <4 x float>)
declare float @llvm.vector.reduce.fadd.f32.v8f32(float, <8 x float>)
+declare float @llvm.vector.reduce.fmul.f32.nxv4f32(float, <vscale x 4 x float>)
+declare float @llvm.vector.reduce.fmin.f32.v4f32(float, <4 x float>)
+declare float @llvm.vector.reduce.fmax.f32.nxv4f32(float, <vscale x 4 x float>)
declare void @use_f32(float)
declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32>)
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare void @use_i32(i32)
+declare i32 @llvm.vector.reduce.mul.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32>)
+
+declare i32 @llvm.vector.reduce.smin.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32>)
+declare i32 @llvm.vector.reduce.umin.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32>)
+
+declare <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32>)
+declare <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float>)
+declare <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32>)
+declare <4 x float> @llvm.vector.reverse.v4f32(<4 x float>)
+
define float @diff_of_sums_v4f32(float %a0, <4 x float> %v0, float %a1, <4 x float> %v1) {
; CHECK-LABEL: @diff_of_sums_v4f32(
; CHECK-NEXT: [[TMP1:%.*]] = fsub reassoc nsz <4 x float> [[V0:%.*]], [[V1:%.*]]
@@ -22,6 +39,71 @@ define float @diff_of_sums_v4f32(float %a0, <4 x float> %v0, float %a1, <4 x flo
ret float %r
}
+define float @reassoc_sum_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @reassoc_sum_of_reverse_v4f32(
+; CHECK-NEXT: [[RED:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[V0:%.*]])
+; CHECK-NEXT: ret float [[RED]]
+;
+ %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+ %red = call reassoc float @llvm.vector.reduce.fadd.v4f32(float zeroinitializer, <4 x float> %rev)
+ ret float %red
+}
+
+define float @reassoc_mul_reduction_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @reassoc_mul_reduction_of_reverse_nxv4f32(
+; CHECK-NEXT: [[RED:%.*]] = call float @llvm.vector.reduce.fmul.nxv4f32(float 1.000000e+00, <vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT: ret float [[RED]]
+;
+ %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+ %red = call reassoc float @llvm.vector.reduce.fmul.nxv4f32(float 1.0, <vscale x 4 x float> %rev)
+ ret float %red
+}
+
+define float @fmax_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @fmax_of_reverse_v4f32(
+; CHECK-NEXT: [[RED:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[V0:%.*]])
+; CHECK-NEXT: ret float [[RED]]
+;
+ %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+ %red = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> %rev)
+ ret float %red
+}
+
+define float @fmin_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @fmin_of_reverse_nxv4f32(
+; CHECK-NEXT: [[RED:%.*]] = call float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT: ret float [[RED]]
+;
+ %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+ %red = call float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float> %rev)
+ ret float %red
+}
+
+; negative test - fadd cannot be folded with reverse due to lack of reassoc
+define float @sum_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @sum_of_reverse_v4f32(
+; CHECK-NEXT: [[REV:%.*]] = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> [[V0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[REV]])
+; CHECK-NEXT: ret float [[RED]]
+;
+ %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+ %red = call float @llvm.vector.reduce.fadd.v4f32(float zeroinitializer, <4 x float> %rev)
+ ret float %red
+}
+
+; negative test - fmul cannot be folded with reverse due to lack of reassoc
+define float @mul_reduction_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @mul_reduction_of_reverse_nxv4f32(
+; CHECK-NEXT: [[REV:%.*]] = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = call float @llvm.vector.reduce.fmul.nxv4f32(float 0.000000e+00, <vscale x 4 x float> [[REV]])
+; CHECK-NEXT: ret float [[RED]]
+;
+ %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+ %red = call float @llvm.vector.reduce.fmul.nxv4f32(float zeroinitializer, <vscale x 4 x float> %rev)
+ ret float %red
+}
+
+
; negative test - fsub must allow reassociation
define float @diff_of_sums_v4f32_fmf(float %a0, <4 x float> %v0, float %a1, <4 x float> %v1) {
@@ -98,6 +180,86 @@ define i32 @diff_of_sums_v4i32(<4 x i32> %v0, <4 x i32> %v1) {
ret i32 %r
}
+define i32 @sum_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @sum_of_reverse_v4i32(
+; CHECK-NEXT: [[RED:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+ %red = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %rev)
+ ret i32 %red
+}
+
+define i32 @sum_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @sum_of_reverse_nxv4i32(
+; CHECK-NEXT: [[RED:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+ %red = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %rev)
+ ret i32 %red
+}
+
+define i32 @mul_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @mul_reduce_of_reverse_v4i32(
+; CHECK-NEXT: [[RED:%.*]] = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+ %red = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> %rev)
+ ret i32 %red
+}
+
+define i32 @mul_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @mul_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT: [[RED:%.*]] = call i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+ %red = call i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32> %rev)
+ ret i32 %red
+}
+
+define i32 @smin_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @smin_reduce_of_reverse_v4i32(
+; CHECK-NEXT: [[RED:%.*]] = call i32 @llvm.vector.reduce.smin.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+ %red = call i32 @llvm.vector.reduce.smin.v4i32(<4 x i32> %rev)
+ ret i32 %red
+}
+
+define i32 @smax_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @smax_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT: [[RED:%.*]] = call i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+ %red = call i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32> %rev)
+ ret i32 %red
+}
+
+define i32 @umin_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @umin_reduce_of_reverse_v4i32(
+; CHECK-NEXT: [[RED:%.*]] = call i32 @llvm.vector.reduce.umin.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+ %red = call i32 @llvm.vector.reduce.umin.v4i32(<4 x i32> %rev)
+ ret i32 %red
+}
+
+define i32 @umax_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @umax_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT: [[RED:%.*]] = call i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+ %red = call i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32> %rev)
+ ret i32 %red
+}
+
; negative test - extra uses could create extra instructions
define i32 @diff_of_sums_v4i32_extra_use1(<4 x i32> %v0, <4 x i32> %v1) {
>From 4a238d6d94abe68bf008aa9fe165db39eb5fefee Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Mon, 13 May 2024 12:38:56 +0000
Subject: [PATCH 2/2] Add simplifyReductionOfShuffle helper
* Let all of the vector.reduce.op variants call a new helper,
simplifyReductionOfShuffle, which handles both vector.reverse
and shufflevector operands.
---
.../InstCombine/InstCombineCalls.cpp | 161 ++++++++----------
.../InstCombine/InstCombineInternal.h | 6 +
.../InstCombine/vector-reductions.ll | 4 +-
3 files changed, 78 insertions(+), 93 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 99927b06cc181..47546a1b3fb8f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1435,6 +1435,51 @@ static Instruction *foldBitOrderCrossLogicOp(Value *V,
return nullptr;
}
+Instruction *InstCombinerImpl::simplifyReductionOfShuffle(IntrinsicInst *II) {
+ Intrinsic::ID IID = II->getIntrinsicID();
+ bool CanBeReassociated = (IID != Intrinsic::vector_reduce_fadd &&
+ IID != Intrinsic::vector_reduce_fmul) ||
+ II->hasAllowReassoc();
+
+ if (!CanBeReassociated)
+ return nullptr;
+
+ const unsigned ArgIdx = (IID == Intrinsic::vector_reduce_fadd ||
+ IID == Intrinsic::vector_reduce_fmul)
+ ? 1
+ : 0;
+ Value *Arg = II->getArgOperand(ArgIdx);
+ Value *V;
+
+ if (match(Arg, m_VecReverse(m_Value(V)))) {
+ replaceUse(II->getOperandUse(ArgIdx), V);
+ return II;
+ }
+
+ ArrayRef<int> Mask;
+ if (!isa<FixedVectorType>(Arg->getType()) ||
+ !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
+ !cast<ShuffleVectorInst>(Arg)->isSingleSource())
+ return nullptr;
+
+ int Sz = Mask.size();
+ SmallBitVector UsedIndices(Sz);
+ for (int Idx : Mask) {
+ if (Idx == PoisonMaskElem || UsedIndices.test(Idx))
+ return nullptr;
+ UsedIndices.set(Idx);
+ }
+
+ // Can remove shuffle iff just shuffled elements, no repeats, undefs, or
+ // other changes.
+ if (UsedIndices.all()) {
+ replaceUse(II->getOperandUse(ArgIdx), V);
+ return II;
+ }
+
+ return nullptr;
+}
+
/// CallInst simplification. This mostly only handles folding of intrinsic
/// instructions. For normal calls, it allows visitCallBase to do the heavy
/// lifting.
@@ -3223,14 +3268,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// %res = cmp eq iReduxWidth %val, 11111
Value *Arg = II->getArgOperand(0);
Value *Vect;
- // When doing a logical reduction of a reversed operand the result is
- // identical to reducing the unreversed operand.
- if (match(Arg, m_VecReverse(m_Value(Vect)))) {
- Value *Res = IID == Intrinsic::vector_reduce_or
- ? Builder.CreateOrReduce(Vect)
- : Builder.CreateAndReduce(Vect);
- return replaceInstUsesWith(CI, Res);
- }
+
+ if (Instruction *I = simplifyReductionOfShuffle(II))
+ return I;
+
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3262,12 +3303,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// Trunc(ctpop(bitcast <n x i1> to in)).
Value *Arg = II->getArgOperand(0);
Value *Vect;
- // When doing an integer add reduction of a reversed operand the result
- // is identical to reducing the unreversed operand.
- if (match(Arg, m_VecReverse(m_Value(Vect)))) {
- Value *Res = Builder.CreateAddReduce(Vect);
- return replaceInstUsesWith(CI, Res);
- }
+
+ if (Instruction *I = simplifyReductionOfShuffle(II))
+ return I;
+
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3296,12 +3335,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// ?ext(vector_reduce_add(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;
- // When doing a xor reduction of a reversed operand the result is
- // identical to reducing the unreversed operand.
- if (match(Arg, m_VecReverse(m_Value(Vect)))) {
- Value *Res = Builder.CreateXorReduce(Vect);
- return replaceInstUsesWith(CI, Res);
- }
+
+ if (Instruction *I = simplifyReductionOfShuffle(II))
+ return I;
+
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3325,12 +3362,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// zext(vector_reduce_and(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;
- // When doing a mul reduction of a reversed operand the result is
- // identical to reducing the unreversed operand.
- if (match(Arg, m_VecReverse(m_Value(Vect)))) {
- Value *Res = Builder.CreateMulReduce(Vect);
- return replaceInstUsesWith(CI, Res);
- }
+
+ if (Instruction *I = simplifyReductionOfShuffle(II))
+ return I;
+
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3355,14 +3390,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// ?ext(vector_reduce_{and,or}(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;
- // When doing a min/max reduction of a reversed operand the result is
- // identical to reducing the unreversed operand.
- if (match(Arg, m_VecReverse(m_Value(Vect)))) {
- Value *Res = IID == Intrinsic::vector_reduce_umin
- ? Builder.CreateIntMinReduce(Vect, false)
- : Builder.CreateIntMaxReduce(Vect, false);
- return replaceInstUsesWith(CI, Res);
- }
+
+ if (Instruction *I = simplifyReductionOfShuffle(II))
+ return I;
+
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3398,14 +3429,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// zext(vector_reduce_{and,or}(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;
- // When doing a min/max reduction of a reversed operand the result is
- // identical to reducing the unreversed operand.
- if (match(Arg, m_VecReverse(m_Value(Vect)))) {
- Value *Res = IID == Intrinsic::vector_reduce_smin
- ? Builder.CreateIntMinReduce(Vect, true)
- : Builder.CreateIntMaxReduce(Vect, true);
- return replaceInstUsesWith(CI, Res);
- }
+
+ if (Instruction *I = simplifyReductionOfShuffle(II))
+ return I;
+
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3428,56 +3455,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
case Intrinsic::vector_reduce_fmin:
case Intrinsic::vector_reduce_fadd:
case Intrinsic::vector_reduce_fmul: {
- bool CanBeReassociated = (IID != Intrinsic::vector_reduce_fadd &&
- IID != Intrinsic::vector_reduce_fmul) ||
- II->hasAllowReassoc();
- const unsigned ArgIdx = (IID == Intrinsic::vector_reduce_fadd ||
- IID == Intrinsic::vector_reduce_fmul)
- ? 1
- : 0;
- Value *Arg = II->getArgOperand(ArgIdx);
- Value *V;
-
- if (!CanBeReassociated)
- break;
-
- if (match(Arg, m_VecReverse(m_Value(V)))) {
- Value *Res;
- switch (IID) {
- case Intrinsic::vector_reduce_fadd:
- Res = Builder.CreateFAddReduce(II->getArgOperand(0), V);
- break;
- case Intrinsic::vector_reduce_fmul:
- Res = Builder.CreateFMulReduce(II->getArgOperand(0), V);
- break;
- case Intrinsic::vector_reduce_fmin:
- Res = Builder.CreateFPMinReduce(V);
- break;
- case Intrinsic::vector_reduce_fmax:
- Res = Builder.CreateFPMaxReduce(V);
- break;
- }
- return replaceInstUsesWith(CI, Res);
- }
-
- ArrayRef<int> Mask;
- if (!isa<FixedVectorType>(Arg->getType()) ||
- !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
- !cast<ShuffleVectorInst>(Arg)->isSingleSource())
- break;
- int Sz = Mask.size();
- SmallBitVector UsedIndices(Sz);
- for (int Idx : Mask) {
- if (Idx == PoisonMaskElem || UsedIndices.test(Idx))
- break;
- UsedIndices.set(Idx);
- }
- // Can remove shuffle iff just shuffled elements, no repeats, undefs, or
- // other changes.
- if (UsedIndices.all()) {
- replaceUse(II->getOperandUse(ArgIdx), V);
+ if (simplifyReductionOfShuffle(II))
return nullptr;
- }
break;
}
case Intrinsic::is_fpclass: {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index db7838bbe3c25..101a12d547dad 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -296,6 +296,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *simplifyMaskedGather(IntrinsicInst &II);
Instruction *simplifyMaskedScatter(IntrinsicInst &II);
+ // Simplify vector.reduce.op(shuffle(x)) to vector.reduce.op(x) if we can
+ // prove the answer is identical, where shuffle could be a shufflevector or
+ // vector.reverse operation. Return the simplified instruction if it can be
+ // simplified or nullptr otherwise.
+ Instruction *simplifyReductionOfShuffle(IntrinsicInst *II);
+
/// Transform (zext icmp) to bitwise / integer operations in order to
/// eliminate it.
///
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 3e2a23a5ef64e..10f4aca72dbc7 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -41,7 +41,7 @@ define float @diff_of_sums_v4f32(float %a0, <4 x float> %v0, float %a1, <4 x flo
define float @reassoc_sum_of_reverse_v4f32(<4 x float> %v0) {
; CHECK-LABEL: @reassoc_sum_of_reverse_v4f32(
-; CHECK-NEXT: [[RED:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[V0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[V0:%.*]])
; CHECK-NEXT: ret float [[RED]]
;
%rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
@@ -51,7 +51,7 @@ define float @reassoc_sum_of_reverse_v4f32(<4 x float> %v0) {
define float @reassoc_mul_reduction_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
; CHECK-LABEL: @reassoc_mul_reduction_of_reverse_nxv4f32(
-; CHECK-NEXT: [[RED:%.*]] = call float @llvm.vector.reduce.fmul.nxv4f32(float 1.000000e+00, <vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = call reassoc float @llvm.vector.reduce.fmul.nxv4f32(float 1.000000e+00, <vscale x 4 x float> [[V0:%.*]])
; CHECK-NEXT: ret float [[RED]]
;
%rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
More information about the llvm-commits
mailing list