[llvm] [InstCombine] Fold vector.reduce.op(vector.reverse(X)) -> vector.reduce.op(X) (PR #91743)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Wed May 15 06:54:59 PDT 2024


https://github.com/david-arm updated https://github.com/llvm/llvm-project/pull/91743

From f8cdf5cb45223ca08b0e74a9b1516c787f358e1c Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 10 May 2024 13:43:37 +0000
Subject: [PATCH 1/2] [InstCombine] Fold vector.reduce.op(vector.reverse(X)) ->
 vector.reduce.op(X)

For all of the following reductions:

vector.reduce.or
vector.reduce.and
vector.reduce.xor
vector.reduce.add
vector.reduce.mul
vector.reduce.umin
vector.reduce.umax
vector.reduce.smin
vector.reduce.smax
vector.reduce.fmin
vector.reduce.fmax

if the input operand is the result of a vector.reverse then we
can perform the reduction on the vector.reverse input instead,
since the answer is the same. If reassociation is permitted we
can also do the same fold for these:

vector.reduce.fadd
vector.reduce.fmul
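
For example, for the integer add case (as exercised by the tests
added below), the pair

  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v)
  %red = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %rev)

folds to a reduction of the unreversed operand:

  %red = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %v)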
---
 .../InstCombine/InstCombineCalls.cpp          |  67 +++++++-
 .../InstCombine/vector-logical-reductions.ll  |  72 ++++++++
 .../InstCombine/vector-reductions.ll          | 162 ++++++++++++++++++
 3 files changed, 300 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 77534e0d36131..99927b06cc181 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3223,6 +3223,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     // %res = cmp eq iReduxWidth %val, 11111
     Value *Arg = II->getArgOperand(0);
     Value *Vect;
+    // When doing a logical reduction of a reversed operand the result is
+    // identical to reducing the unreversed operand.
+    if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+      Value *Res = IID == Intrinsic::vector_reduce_or
+                       ? Builder.CreateOrReduce(Vect)
+                       : Builder.CreateAndReduce(Vect);
+      return replaceInstUsesWith(CI, Res);
+    }
     if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
       if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
         if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3254,6 +3262,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       // Trunc(ctpop(bitcast <n x i1> to in)).
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing an integer add reduction of a reversed operand the result
+      // is identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = Builder.CreateAddReduce(Vect);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3282,6 +3296,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   ?ext(vector_reduce_add(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing a xor reduction of a reversed operand the result is
+      // identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = Builder.CreateXorReduce(Vect);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3305,6 +3325,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   zext(vector_reduce_and(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing a mul reduction of a reversed operand the result is
+      // identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = Builder.CreateMulReduce(Vect);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3329,6 +3355,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   ?ext(vector_reduce_{and,or}(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing a min/max reduction of a reversed operand the result is
+      // identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = IID == Intrinsic::vector_reduce_umin
+                         ? Builder.CreateIntMinReduce(Vect, false)
+                         : Builder.CreateIntMaxReduce(Vect, false);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3364,6 +3398,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   zext(vector_reduce_{and,or}(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing a min/max reduction of a reversed operand the result is
+      // identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = IID == Intrinsic::vector_reduce_smin
+                         ? Builder.CreateIntMinReduce(Vect, true)
+                         : Builder.CreateIntMaxReduce(Vect, true);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3395,8 +3437,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
                                 : 0;
     Value *Arg = II->getArgOperand(ArgIdx);
     Value *V;
+
+    if (!CanBeReassociated)
+      break;
+
+    if (match(Arg, m_VecReverse(m_Value(V)))) {
+      Value *Res;
+      switch (IID) {
+      case Intrinsic::vector_reduce_fadd:
+        Res = Builder.CreateFAddReduce(II->getArgOperand(0), V);
+        break;
+      case Intrinsic::vector_reduce_fmul:
+        Res = Builder.CreateFMulReduce(II->getArgOperand(0), V);
+        break;
+      case Intrinsic::vector_reduce_fmin:
+        Res = Builder.CreateFPMinReduce(V);
+        break;
+      case Intrinsic::vector_reduce_fmax:
+        Res = Builder.CreateFPMaxReduce(V);
+        break;
+      }
+      return replaceInstUsesWith(CI, Res);
+    }
+
     ArrayRef<int> Mask;
-    if (!isa<FixedVectorType>(Arg->getType()) || !CanBeReassociated ||
+    if (!isa<FixedVectorType>(Arg->getType()) ||
         !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
         !cast<ShuffleVectorInst>(Arg)->isSingleSource())
       break;
diff --git a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
index 9bb307ebf71e8..da4a0ca754680 100644
--- a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
@@ -21,5 +21,77 @@ define i1 @reduction_logical_and(<4 x i1> %x) {
   ret i1 %r
 }
 
+define i1 @reduction_logical_or_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_or_reverse_nxv2i1(
+; CHECK-NEXT:    [[RED:%.*]] = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_or_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_or_reverse_v2i1(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT:    [[RED:%.*]] = icmp ne i2 [[TMP1]], 0
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.or.v2i1(<2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_and_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_and_reverse_nxv2i1(
+; CHECK-NEXT:    [[RED:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_and_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_and_reverse_v2i1(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT:    [[RED:%.*]] = icmp eq i2 [[TMP1]], -1
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_xor_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_xor_reverse_nxv2i1(
+; CHECK-NEXT:    [[RED:%.*]] = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_xor_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_xor_reverse_v2i1(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT:    [[TMP2:%.*]] = call range(i2 0, -1) i2 @llvm.ctpop.i2(i2 [[TMP1]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i2 [[TMP2]] to i1
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.xor.v2i1(<2 x i1> %rev)
+  ret i1 %red
+}
+
 declare i1 @llvm.vector.reduce.or.v4i1(<4 x i1>)
+declare i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.or.v2i1(<2 x i1>)
 declare i1 @llvm.vector.reduce.and.v4i1(<4 x i1>)
+declare i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.and.v2i1(<2 x i1>)
+declare i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.xor.v2i1(<2 x i1>)
+declare <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1>)
+declare <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1>)
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 2614ffd386952..3e2a23a5ef64e 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -3,12 +3,29 @@
 
 declare float @llvm.vector.reduce.fadd.f32.v4f32(float, <4 x float>)
 declare float @llvm.vector.reduce.fadd.f32.v8f32(float, <8 x float>)
+declare float @llvm.vector.reduce.fmul.f32.nxv4f32(float, <vscale x 4 x float>)
+declare float @llvm.vector.reduce.fmin.f32.v4f32(float, <4 x float>)
+declare float @llvm.vector.reduce.fmax.f32.nxv4f32(float, <vscale x 4 x float>)
 declare void @use_f32(float)
 
 declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32>)
 declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
 declare void @use_i32(i32)
 
+declare i32 @llvm.vector.reduce.mul.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32>)
+
+declare i32 @llvm.vector.reduce.smin.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32>)
+declare i32 @llvm.vector.reduce.umin.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32>)
+
+declare <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32>)
+declare <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float>)
+declare <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32>)
+declare <4 x float> @llvm.vector.reverse.v4f32(<4 x float>)
+
 define float @diff_of_sums_v4f32(float %a0, <4 x float> %v0, float %a1, <4 x float> %v1) {
 ; CHECK-LABEL: @diff_of_sums_v4f32(
 ; CHECK-NEXT:    [[TMP1:%.*]] = fsub reassoc nsz <4 x float> [[V0:%.*]], [[V1:%.*]]
@@ -22,6 +39,71 @@ define float @diff_of_sums_v4f32(float %a0, <4 x float> %v0, float %a1, <4 x flo
   ret float %r
 }
 
+define float @reassoc_sum_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @reassoc_sum_of_reverse_v4f32(
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[V0:%.*]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+  %red = call reassoc float @llvm.vector.reduce.fadd.v4f32(float zeroinitializer, <4 x float> %rev)
+  ret float %red
+}
+
+define float @reassoc_mul_reduction_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @reassoc_mul_reduction_of_reverse_nxv4f32(
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmul.nxv4f32(float 1.000000e+00, <vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+  %red = call reassoc float @llvm.vector.reduce.fmul.nxv4f32(float 1.0, <vscale x 4 x float> %rev)
+  ret float %red
+}
+
+define float @fmax_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @fmax_of_reverse_v4f32(
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[V0:%.*]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+  %red = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> %rev)
+  ret float %red
+}
+
+define float @fmin_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @fmin_of_reverse_nxv4f32(
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+  %red = call float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float> %rev)
+  ret float %red
+}
+
+; negative test - fadd cannot be folded with reverse due to lack of reassoc
+define float @sum_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @sum_of_reverse_v4f32(
+; CHECK-NEXT:    [[REV:%.*]] = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> [[V0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[REV]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+  %red = call float @llvm.vector.reduce.fadd.v4f32(float zeroinitializer, <4 x float> %rev)
+  ret float %red
+}
+
+; negative test - fmul cannot be folded with reverse due to lack of reassoc
+define float @mul_reduction_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @mul_reduction_of_reverse_nxv4f32(
+; CHECK-NEXT:    [[REV:%.*]] = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmul.nxv4f32(float 0.000000e+00, <vscale x 4 x float> [[REV]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+  %red = call float @llvm.vector.reduce.fmul.nxv4f32(float zeroinitializer, <vscale x 4 x float> %rev)
+  ret float %red
+}
+
+
 ; negative test - fsub must allow reassociation
 
 define float @diff_of_sums_v4f32_fmf(float %a0, <4 x float> %v0, float %a1, <4 x float> %v1) {
@@ -98,6 +180,86 @@ define i32 @diff_of_sums_v4i32(<4 x i32> %v0, <4 x i32> %v1) {
   ret i32 %r
 }
 
+define i32 @sum_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @sum_of_reverse_v4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @sum_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @sum_of_reverse_nxv4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @mul_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @mul_reduce_of_reverse_v4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @mul_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @mul_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @smin_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @smin_reduce_of_reverse_v4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.smin.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.smin.v4i32(<4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @smax_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @smax_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @umin_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @umin_reduce_of_reverse_v4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.umin.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.umin.v4i32(<4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @umax_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @umax_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32> %rev)
+  ret i32 %red
+}
+
 ; negative test - extra uses could create extra instructions
 
 define i32 @diff_of_sums_v4i32_extra_use1(<4 x i32> %v0, <4 x i32> %v1) {

From 4a238d6d94abe68bf008aa9fe165db39eb5fefee Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Mon, 13 May 2024 12:38:56 +0000
Subject: [PATCH 2/2] Add simplifyReductionOfShuffle helper

* Have all of the vector.reduce.op variants call a new helper,
simplifyReductionOfShuffle, which handles both vector.reverse
and shufflevector operations.
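
For the shufflevector case this means a reduction over a
single-source shuffle whose mask uses every lane exactly once
(no repeats or poison elements) can drop the shuffle as well.
An illustrative example (not one of the tests in this patch):

  %shuf = shufflevector <4 x i32> %v, <4 x i32> poison, <4 x i32> <i32 3, i32 1, i32 2, i32 0>
  %red = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %shuf)

can be simplified to:

  %red = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %v)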
---
 .../InstCombine/InstCombineCalls.cpp          | 161 ++++++++----------
 .../InstCombine/InstCombineInternal.h         |   6 +
 .../InstCombine/vector-reductions.ll          |   4 +-
 3 files changed, 78 insertions(+), 93 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 99927b06cc181..47546a1b3fb8f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1435,6 +1435,51 @@ static Instruction *foldBitOrderCrossLogicOp(Value *V,
   return nullptr;
 }
 
+Instruction *InstCombinerImpl::simplifyReductionOfShuffle(IntrinsicInst *II) {
+  Intrinsic::ID IID = II->getIntrinsicID();
+  bool CanBeReassociated = (IID != Intrinsic::vector_reduce_fadd &&
+                            IID != Intrinsic::vector_reduce_fmul) ||
+                           II->hasAllowReassoc();
+
+  if (!CanBeReassociated)
+    return nullptr;
+
+  const unsigned ArgIdx = (IID == Intrinsic::vector_reduce_fadd ||
+                           IID == Intrinsic::vector_reduce_fmul)
+                              ? 1
+                              : 0;
+  Value *Arg = II->getArgOperand(ArgIdx);
+  Value *V;
+
+  if (match(Arg, m_VecReverse(m_Value(V)))) {
+    replaceUse(II->getOperandUse(ArgIdx), V);
+    return II;
+  }
+
+  ArrayRef<int> Mask;
+  if (!isa<FixedVectorType>(Arg->getType()) ||
+      !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
+      !cast<ShuffleVectorInst>(Arg)->isSingleSource())
+    return nullptr;
+
+  int Sz = Mask.size();
+  SmallBitVector UsedIndices(Sz);
+  for (int Idx : Mask) {
+    if (Idx == PoisonMaskElem || UsedIndices.test(Idx))
+      return nullptr;
+    UsedIndices.set(Idx);
+  }
+
+  // Can remove shuffle iff just shuffled elements, no repeats, undefs, or
+  // other changes.
+  if (UsedIndices.all()) {
+    replaceUse(II->getOperandUse(ArgIdx), V);
+    return II;
+  }
+
+  return nullptr;
+}
+
 /// CallInst simplification. This mostly only handles folding of intrinsic
 /// instructions. For normal calls, it allows visitCallBase to do the heavy
 /// lifting.
@@ -3223,14 +3268,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     // %res = cmp eq iReduxWidth %val, 11111
     Value *Arg = II->getArgOperand(0);
     Value *Vect;
-    // When doing a logical reduction of a reversed operand the result is
-    // identical to reducing the unreversed operand.
-    if (match(Arg, m_VecReverse(m_Value(Vect)))) {
-      Value *Res = IID == Intrinsic::vector_reduce_or
-                       ? Builder.CreateOrReduce(Vect)
-                       : Builder.CreateAndReduce(Vect);
-      return replaceInstUsesWith(CI, Res);
-    }
+
+    if (Instruction *I = simplifyReductionOfShuffle(II))
+      return I;
+
     if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
       if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
         if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3262,12 +3303,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       // Trunc(ctpop(bitcast <n x i1> to in)).
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
-      // When doing an integer add reduction of a reversed operand the result
-      // is identical to reducing the unreversed operand.
-      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
-        Value *Res = Builder.CreateAddReduce(Vect);
-        return replaceInstUsesWith(CI, Res);
-      }
+
+      if (Instruction *I = simplifyReductionOfShuffle(II))
+        return I;
+
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3296,12 +3335,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   ?ext(vector_reduce_add(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
-      // When doing a xor reduction of a reversed operand the result is
-      // identical to reducing the unreversed operand.
-      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
-        Value *Res = Builder.CreateXorReduce(Vect);
-        return replaceInstUsesWith(CI, Res);
-      }
+
+      if (Instruction *I = simplifyReductionOfShuffle(II))
+        return I;
+
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3325,12 +3362,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   zext(vector_reduce_and(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
-      // When doing a mul reduction of a reversed operand the result is
-      // identical to reducing the unreversed operand.
-      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
-        Value *Res = Builder.CreateMulReduce(Vect);
-        return replaceInstUsesWith(CI, Res);
-      }
+
+      if (Instruction *I = simplifyReductionOfShuffle(II))
+        return I;
+
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3355,14 +3390,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   ?ext(vector_reduce_{and,or}(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
-      // When doing a min/max reduction of a reversed operand the result is
-      // identical to reducing the unreversed operand.
-      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
-        Value *Res = IID == Intrinsic::vector_reduce_umin
-                         ? Builder.CreateIntMinReduce(Vect, false)
-                         : Builder.CreateIntMaxReduce(Vect, false);
-        return replaceInstUsesWith(CI, Res);
-      }
+
+      if (Instruction *I = simplifyReductionOfShuffle(II))
+        return I;
+
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3398,14 +3429,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   zext(vector_reduce_{and,or}(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
-      // When doing a min/max reduction of a reversed operand the result is
-      // identical to reducing the unreversed operand.
-      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
-        Value *Res = IID == Intrinsic::vector_reduce_smin
-                         ? Builder.CreateIntMinReduce(Vect, true)
-                         : Builder.CreateIntMaxReduce(Vect, true);
-        return replaceInstUsesWith(CI, Res);
-      }
+
+      if (Instruction *I = simplifyReductionOfShuffle(II))
+        return I;
+
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3428,56 +3455,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   case Intrinsic::vector_reduce_fmin:
   case Intrinsic::vector_reduce_fadd:
   case Intrinsic::vector_reduce_fmul: {
-    bool CanBeReassociated = (IID != Intrinsic::vector_reduce_fadd &&
-                              IID != Intrinsic::vector_reduce_fmul) ||
-                             II->hasAllowReassoc();
-    const unsigned ArgIdx = (IID == Intrinsic::vector_reduce_fadd ||
-                             IID == Intrinsic::vector_reduce_fmul)
-                                ? 1
-                                : 0;
-    Value *Arg = II->getArgOperand(ArgIdx);
-    Value *V;
-
-    if (!CanBeReassociated)
-      break;
-
-    if (match(Arg, m_VecReverse(m_Value(V)))) {
-      Value *Res;
-      switch (IID) {
-      case Intrinsic::vector_reduce_fadd:
-        Res = Builder.CreateFAddReduce(II->getArgOperand(0), V);
-        break;
-      case Intrinsic::vector_reduce_fmul:
-        Res = Builder.CreateFMulReduce(II->getArgOperand(0), V);
-        break;
-      case Intrinsic::vector_reduce_fmin:
-        Res = Builder.CreateFPMinReduce(V);
-        break;
-      case Intrinsic::vector_reduce_fmax:
-        Res = Builder.CreateFPMaxReduce(V);
-        break;
-      }
-      return replaceInstUsesWith(CI, Res);
-    }
-
-    ArrayRef<int> Mask;
-    if (!isa<FixedVectorType>(Arg->getType()) ||
-        !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
-        !cast<ShuffleVectorInst>(Arg)->isSingleSource())
-      break;
-    int Sz = Mask.size();
-    SmallBitVector UsedIndices(Sz);
-    for (int Idx : Mask) {
-      if (Idx == PoisonMaskElem || UsedIndices.test(Idx))
-        break;
-      UsedIndices.set(Idx);
-    }
-    // Can remove shuffle iff just shuffled elements, no repeats, undefs, or
-    // other changes.
-    if (UsedIndices.all()) {
-      replaceUse(II->getOperandUse(ArgIdx), V);
+    if (simplifyReductionOfShuffle(II))
       return nullptr;
-    }
     break;
   }
   case Intrinsic::is_fpclass: {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index db7838bbe3c25..101a12d547dad 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -296,6 +296,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *simplifyMaskedGather(IntrinsicInst &II);
   Instruction *simplifyMaskedScatter(IntrinsicInst &II);
 
+  // Simplify vector.reduce.op(shuffle(x)) to vector.reduce.op(x) if we can
+  // prove the answer is identical, where shuffle could be a shufflevector or
+  // vector.reverse operation. Return the simplified instruction if it can be
+  // simplified or nullptr otherwise.
+  Instruction *simplifyReductionOfShuffle(IntrinsicInst *II);
+
   /// Transform (zext icmp) to bitwise / integer operations in order to
   /// eliminate it.
   ///
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 3e2a23a5ef64e..10f4aca72dbc7 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -41,7 +41,7 @@ define float @diff_of_sums_v4f32(float %a0, <4 x float> %v0, float %a1, <4 x flo
 
 define float @reassoc_sum_of_reverse_v4f32(<4 x float> %v0) {
 ; CHECK-LABEL: @reassoc_sum_of_reverse_v4f32(
-; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[V0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[V0:%.*]])
 ; CHECK-NEXT:    ret float [[RED]]
 ;
   %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
@@ -51,7 +51,7 @@ define float @reassoc_sum_of_reverse_v4f32(<4 x float> %v0) {
 
 define float @reassoc_mul_reduction_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
 ; CHECK-LABEL: @reassoc_mul_reduction_of_reverse_nxv4f32(
-; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmul.nxv4f32(float 1.000000e+00, <vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = call reassoc float @llvm.vector.reduce.fmul.nxv4f32(float 1.000000e+00, <vscale x 4 x float> [[V0:%.*]])
 ; CHECK-NEXT:    ret float [[RED]]
 ;
   %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)


