[llvm] [SelectionDAG][x86] Ensure vector reduction optimization (PR #144231)

Suhajda Tamás via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 17 10:58:59 PDT 2025


https://github.com/sutajo updated https://github.com/llvm/llvm-project/pull/144231

>From 2b2130a54aa74635ca194d6533e2e9ecc313f39e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Suhajda=20Tam=C3=A1s?= <sutajo at gmail.com>
Date: Sat, 14 Jun 2025 19:04:49 +0200
Subject: [PATCH 1/4] [x86] Add test for reduction

---
 llvm/test/CodeGen/X86/optimize-reduction.ll | 140 ++++++++++++++++++++
 1 file changed, 140 insertions(+)
 create mode 100644 llvm/test/CodeGen/X86/optimize-reduction.ll

diff --git a/llvm/test/CodeGen/X86/optimize-reduction.ll b/llvm/test/CodeGen/X86/optimize-reduction.ll
new file mode 100644
index 0000000000000..003c41612b8bf
--- /dev/null
+++ b/llvm/test/CodeGen/X86/optimize-reduction.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+sse4.1,+fast-hops | FileCheck %s --check-prefixes=SSE41
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx2,+fast-hops | FileCheck %s --check-prefixes=AVX2
+
+define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_umin:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    pminuw %xmm1, %xmm4
+; SSE41-NEXT:    pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
+; SSE41-NEXT:    pminuw %xmm4, %xmm5
+; SSE41-NEXT:    pshufd {{.*#+}} xmm6 = xmm5[1,1,1,1]
+; SSE41-NEXT:    pminuw %xmm5, %xmm6
+; SSE41-NEXT:    movdqa %xmm6, %xmm5
+; SSE41-NEXT:    psrld $16, %xmm5
+; SSE41-NEXT:    pminuw %xmm6, %xmm5
+; SSE41-NEXT:    phminposuw %xmm4, %xmm4
+; SSE41-NEXT:    movd %xmm4, %eax
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm5[0,0,0,0,4,5,6,7]
+; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT:    pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT:    pxor %xmm5, %xmm1
+; SSE41-NEXT:    por %xmm3, %xmm1
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT:    pxor %xmm5, %xmm0
+; SSE41-NEXT:    por %xmm2, %xmm0
+; SSE41-NEXT:    pminuw %xmm1, %xmm0
+; SSE41-NEXT:    phminposuw %xmm0, %xmm0
+; SSE41-NEXT:    movd %xmm0, %edx
+; SSE41-NEXT:    # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT:    # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_umin:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT:    vpminuw %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm3 = xmm2[2,3,2,3]
+; AVX2-NEXT:    vpminuw %xmm3, %xmm2, %xmm3
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm4 = xmm3[1,1,1,1]
+; AVX2-NEXT:    vpminuw %xmm4, %xmm3, %xmm3
+; AVX2-NEXT:    vpsrld $16, %xmm3, %xmm4
+; AVX2-NEXT:    vphminposuw %xmm2, %xmm2
+; AVX2-NEXT:    vmovd %xmm2, %eax
+; AVX2-NEXT:    vpminuw %xmm4, %xmm3, %xmm2
+; AVX2-NEXT:    vpbroadcastw %xmm2, %ymm2
+; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vphminposuw %xmm0, %xmm0
+; AVX2-NEXT:    vmovd %xmm0, %edx
+; AVX2-NEXT:    # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT:    # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+  %min_x = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %x) ; returned, and also splatted for the compare below
+  %min_x_vec = insertelement <1 x i16> poison, i16 %min_x, i64 0
+  %min_x_splat = shufflevector <1 x i16> %min_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+  %cmp = icmp eq <16 x i16> %x, %min_x_splat
+  %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+  %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select) ; umin of the %y lanes where %x == %min_x (other lanes are 0xFFFF)
+  %ret_0 = insertvalue { i16, i16 } poison, i16 %min_x, 0
+  %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+  ret { i16, i16 } %ret
+}
+
+define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_add:
+; SSE41:       # %bb.0: # %start
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    paddw %xmm1, %xmm4
+; SSE41-NEXT:    pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
+; SSE41-NEXT:    paddw %xmm4, %xmm5
+; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm5[1,1,1,1]
+; SSE41-NEXT:    paddw %xmm5, %xmm4
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
+; SSE41-NEXT:    movdqa %xmm1, %xmm5
+; SSE41-NEXT:    phaddw %xmm0, %xmm5
+; SSE41-NEXT:    phaddw %xmm5, %xmm5
+; SSE41-NEXT:    phaddw %xmm5, %xmm5
+; SSE41-NEXT:    phaddw %xmm5, %xmm5
+; SSE41-NEXT:    movd %xmm5, %eax
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
+; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT:    pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT:    pxor %xmm5, %xmm1
+; SSE41-NEXT:    por %xmm3, %xmm1
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT:    pxor %xmm5, %xmm0
+; SSE41-NEXT:    por %xmm2, %xmm0
+; SSE41-NEXT:    pminuw %xmm1, %xmm0
+; SSE41-NEXT:    phminposuw %xmm0, %xmm0
+; SSE41-NEXT:    movd %xmm0, %edx
+; SSE41-NEXT:    # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT:    # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_add:
+; AVX2:       # %bb.0: # %start
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT:    vpaddw %xmm2, %xmm0, %xmm3
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm4 = xmm3[2,3,2,3]
+; AVX2-NEXT:    vpaddw %xmm4, %xmm3, %xmm3
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm4 = xmm3[1,1,1,1]
+; AVX2-NEXT:    vpaddw %xmm4, %xmm3, %xmm3
+; AVX2-NEXT:    vphaddw %xmm3, %xmm3, %xmm3
+; AVX2-NEXT:    vphaddw %xmm0, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vmovd %xmm2, %eax
+; AVX2-NEXT:    vpbroadcastw %xmm3, %ymm2
+; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vphminposuw %xmm0, %xmm0
+; AVX2-NEXT:    vmovd %xmm0, %edx
+; AVX2-NEXT:    # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT:    # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+start:
+  %sum_x = tail call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %x) ; returned, and also splatted for the compare below
+  %sum_x_vec = insertelement <1 x i16> poison, i16 %sum_x, i64 0
+  %sum_x_splat = shufflevector <1 x i16> %sum_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+  %cmp = icmp eq <16 x i16> %x, %sum_x_splat
+  %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+  %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select) ; umin of the %y lanes where %x == %sum_x (other lanes are 0xFFFF)
+  %ret_0 = insertvalue { i16, i16 } poison, i16 %sum_x, 0
+  %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+  ret { i16, i16 } %ret
+}

>From 93df21fcff0b9f836947cfd9bb8e88b565b1c435 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Suhajda=20Tam=C3=A1s?= <sutajo at gmail.com>
Date: Sat, 14 Jun 2025 19:09:44 +0200
Subject: [PATCH 2/4] [x86] Implement optimization and update tests

---
 llvm/lib/Target/X86/X86ISelLowering.cpp     | 52 ++++++++++++++++++---
 llvm/test/CodeGen/X86/optimize-reduction.ll | 40 +++-------------
 2 files changed, 52 insertions(+), 40 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b4670e270141f..61e3979b6c0bb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47081,7 +47081,10 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
 /// scalars back, while for x64 we should use 64-bit extracts and shifts.
+/// \p TransformedBinOpReduction is set to true (never reset) when the result
+/// comes from a SAD/VPDPBUSD/predicate/min-max/arithmetic reduction combine.
 static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
                                        TargetLowering::DAGCombinerInfo &DCI,
-                                       const X86Subtarget &Subtarget) {
+                                       const X86Subtarget &Subtarget,
+                                       bool &TransformedBinOpReduction) {
   if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget))
     return NewOp;
 
@@ -47169,23 +47172,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   // Check whether this extract is the root of a sum of absolute differences
   // pattern. This has to be done here because we really want it to happen
   // pre-legalization,
-  if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
+  if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return SAD;
+  }
 
-  if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
+  if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return VPDPBUSD;
+  }
 
   // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
-  if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
+  if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return Cmp;
+  }
 
   // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
-  if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget))
+  if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return MinMax;
+  }
 
   // Attempt to optimize ADD/FADD/MUL reductions with HADD, promotion etc..
-  if (SDValue V = combineArithReduction(N, DAG, Subtarget))
+  if (SDValue V = combineArithReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return V;
+  }
 
   if (SDValue V = scalarizeExtEltFP(N, DAG, Subtarget, DCI))
     return V;
@@ -47255,6 +47266,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+/// Run combineExtractVectorElt; if it formed a horizontal reduction, also
+/// replace the reduced vector V so its binop + shuffle pyramid can go away.
+static SDValue combineExtractVectorEltAndOperand(
+    SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI,
+    const X86Subtarget &Subtarget) {
+  bool TransformedBinOpReduction = false;
+  SDValue Op = combineExtractVectorElt(N, DAG, DCI, Subtarget,
+                                       TransformedBinOpReduction);
+
+  if (TransformedBinOpReduction) {
+    // In case we simplified N = extract_vector_element(V, 0) with Op and V
+    // resulted from a reduction, then we need to replace all uses of V with
+    // scalar_to_vector(Op) to make sure that we eliminated the binop + shuffle
+    // pyramid. This is safe to do, because the elements of V are undefined except 
+    // for the zeroth element.
+
+    auto OldV = N->getOperand(0);
+    auto NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), OldV->getValueType(0), Op);
+
+    auto NV = DCI.CombineTo(N, Op);
+    DCI.CombineTo(OldV.getNode(), NewV);
+    return NV; // N is already replaced; SDValue(N, 0) stops a revisit.
+  }
+
+  return Op;
+}
+
 // Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
 // This is more or less the reverse of combineBitcastvxi1.
 static SDValue combineToExtendBoolVectorInReg(
@@ -60702,7 +60740,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::EXTRACT_VECTOR_ELT:
   case X86ISD::PEXTRW:
   case X86ISD::PEXTRB:
-    return combineExtractVectorElt(N, DAG, DCI, Subtarget);
+    return combineExtractVectorEltAndOperand(N, DAG, DCI, Subtarget);
   case ISD::CONCAT_VECTORS:
     return combineCONCAT_VECTORS(N, DAG, DCI, Subtarget);
   case ISD::INSERT_SUBVECTOR:
diff --git a/llvm/test/CodeGen/X86/optimize-reduction.ll b/llvm/test/CodeGen/X86/optimize-reduction.ll
index 003c41612b8bf..e51ac1bd3c13c 100644
--- a/llvm/test/CodeGen/X86/optimize-reduction.ll
+++ b/llvm/test/CodeGen/X86/optimize-reduction.ll
@@ -7,16 +7,9 @@ define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y)
 ; SSE41:       # %bb.0:
 ; SSE41-NEXT:    movdqa %xmm0, %xmm4
 ; SSE41-NEXT:    pminuw %xmm1, %xmm4
-; SSE41-NEXT:    pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
-; SSE41-NEXT:    pminuw %xmm4, %xmm5
-; SSE41-NEXT:    pshufd {{.*#+}} xmm6 = xmm5[1,1,1,1]
-; SSE41-NEXT:    pminuw %xmm5, %xmm6
-; SSE41-NEXT:    movdqa %xmm6, %xmm5
-; SSE41-NEXT:    psrld $16, %xmm5
-; SSE41-NEXT:    pminuw %xmm6, %xmm5
 ; SSE41-NEXT:    phminposuw %xmm4, %xmm4
 ; SSE41-NEXT:    movd %xmm4, %eax
-; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm5[0,0,0,0,4,5,6,7]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
 ; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
 ; SSE41-NEXT:    pcmpeqw %xmm4, %xmm1
 ; SSE41-NEXT:    pcmpeqd %xmm5, %xmm5
@@ -36,14 +29,8 @@ define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y)
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
 ; AVX2-NEXT:    vpminuw %xmm2, %xmm0, %xmm2
-; AVX2-NEXT:    vpshufd {{.*#+}} xmm3 = xmm2[2,3,2,3]
-; AVX2-NEXT:    vpminuw %xmm3, %xmm2, %xmm3
-; AVX2-NEXT:    vpshufd {{.*#+}} xmm4 = xmm3[1,1,1,1]
-; AVX2-NEXT:    vpminuw %xmm4, %xmm3, %xmm3
-; AVX2-NEXT:    vpsrld $16, %xmm3, %xmm4
 ; AVX2-NEXT:    vphminposuw %xmm2, %xmm2
 ; AVX2-NEXT:    vmovd %xmm2, %eax
-; AVX2-NEXT:    vpminuw %xmm4, %xmm3, %xmm2
 ; AVX2-NEXT:    vpbroadcastw %xmm2, %ymm2
 ; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
 ; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
@@ -71,19 +58,12 @@ define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y)
 define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
 ; SSE41-LABEL: test_reduce_v16i16_with_add:
 ; SSE41:       # %bb.0: # %start
-; SSE41-NEXT:    movdqa %xmm0, %xmm4
-; SSE41-NEXT:    paddw %xmm1, %xmm4
-; SSE41-NEXT:    pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
-; SSE41-NEXT:    paddw %xmm4, %xmm5
-; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm5[1,1,1,1]
-; SSE41-NEXT:    paddw %xmm5, %xmm4
+; SSE41-NEXT:    movdqa %xmm1, %xmm4
+; SSE41-NEXT:    phaddw %xmm0, %xmm4
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
 ; SSE41-NEXT:    phaddw %xmm4, %xmm4
-; SSE41-NEXT:    movdqa %xmm1, %xmm5
-; SSE41-NEXT:    phaddw %xmm0, %xmm5
-; SSE41-NEXT:    phaddw %xmm5, %xmm5
-; SSE41-NEXT:    phaddw %xmm5, %xmm5
-; SSE41-NEXT:    phaddw %xmm5, %xmm5
-; SSE41-NEXT:    movd %xmm5, %eax
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
+; SSE41-NEXT:    movd %xmm4, %eax
 ; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
 ; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
 ; SSE41-NEXT:    pcmpeqw %xmm4, %xmm1
@@ -103,18 +83,12 @@ define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
 ; AVX2-LABEL: test_reduce_v16i16_with_add:
 ; AVX2:       # %bb.0: # %start
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
-; AVX2-NEXT:    vpaddw %xmm2, %xmm0, %xmm3
-; AVX2-NEXT:    vpshufd {{.*#+}} xmm4 = xmm3[2,3,2,3]
-; AVX2-NEXT:    vpaddw %xmm4, %xmm3, %xmm3
-; AVX2-NEXT:    vpshufd {{.*#+}} xmm4 = xmm3[1,1,1,1]
-; AVX2-NEXT:    vpaddw %xmm4, %xmm3, %xmm3
-; AVX2-NEXT:    vphaddw %xmm3, %xmm3, %xmm3
 ; AVX2-NEXT:    vphaddw %xmm0, %xmm2, %xmm2
 ; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
 ; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
 ; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
 ; AVX2-NEXT:    vmovd %xmm2, %eax
-; AVX2-NEXT:    vpbroadcastw %xmm3, %ymm2
+; AVX2-NEXT:    vpbroadcastw %xmm2, %ymm2
 ; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
 ; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
 ; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0

>From 57a37886a8d426a049c1a1c63b5064b9c55cf086 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Suhajda=20Tam=C3=A1s?= <sutajo at gmail.com>
Date: Sat, 14 Jun 2025 19:59:01 +0200
Subject: [PATCH 3/4] [x86] Assert that the new reduction does not depend on
 the converted one

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 61e3979b6c0bb..b606de022daf0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47279,9 +47279,44 @@ static SDValue combineExtractVectorEltAndOperand(SDNode* N, SelectionDAG& DAG,
     // resulted from a reduction, then we need to replace all uses of V with
     // scalar_to_vector(Op) to make sure that we eliminated the binop + shuffle
-    // pyramid. This is safe to do, because the elements of V are undefined except 
-    // for the zeroth element.
+    // pyramid. That is only safe when Op does not depend on V and every other
+    // user of V reads nothing but lane 0, because the remaining lanes of
+    // scalar_to_vector(Op) are undefined.
 
     auto OldV = N->getOperand(0);
+    assert(!Op.getNode()->hasPredecessor(OldV.getNode()) &&
+           "Op must not depend on the converted reduction");
+
+    // True if this use of OldV reads nothing but lane 0 (N itself qualifies).
+    auto ReadsOnlyLaneZero = [&](SDUse &U) {
+      SDNode *User = U.getUser();
+      if (User == N)
+        return true;
+      switch (User->getOpcode()) {
+      case ISD::EXTRACT_VECTOR_ELT:
+      case X86ISD::PEXTRW:
+      case X86ISD::PEXTRB:
+        return isNullConstant(User->getOperand(1));
+      case X86ISD::VBROADCAST:
+        return true;
+      case ISD::VECTOR_SHUFFLE: {
+        // Mask entries in [First, First + NumElts) select lanes of OldV.
+        int NumElts = OldV.getValueType().getVectorNumElements();
+        int First = static_cast<int>(U.getOperandNo()) * NumElts;
+        return llvm::all_of(cast<ShuffleVectorSDNode>(User)->getMask(),
+                            [&](int M) {
+                              return M < First || M >= First + NumElts ||
+                                     M == First;
+                            });
+      }
+      default:
+        return false;
+      }
+    };
+    // Otherwise keep V untouched and only replace N with Op.
+    if (OldV->getNumValues() != 1 ||
+        !llvm::all_of(OldV->uses(), ReadsOnlyLaneZero))
+      return Op;
+
     auto NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), OldV->getValueType(0), Op);
 
     auto NV = DCI.CombineTo(N, Op);

>From ebb3ba0cdf1a8159b279367e3e94f65dab38e654 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Suhajda=20Tam=C3=A1s?= <sutajo at gmail.com>
Date: Tue, 17 Jun 2025 19:58:42 +0200
Subject: [PATCH 4/4] [x86] Add custom lowering for min/max vector reductions

---
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 163 ++++++++++++------
 .../lib/Target/X86/X86TargetTransformInfo.cpp |  19 ++
 llvm/lib/Target/X86/X86TargetTransformInfo.h  |   1 +
 llvm/test/CodeGen/X86/optimize-reduction.ll   |   1 +
 4 files changed, 130 insertions(+), 54 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b606de022daf0..684092e416ca4 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1435,6 +1435,20 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::BITREVERSE, VT, Custom);
   }
 
+  // Custom-lower vXi8/vXi16 min/max reductions to PHMINPOSUW. Shapes it cannot
+  // handle make LowerVECTOR_REDUCE_MINMAX return an empty SDValue.
+  if (Subtarget.hasSSE41()) {
+    for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
+      MVT EltVT = VT.getVectorElementType();
+      if (EltVT != MVT::i8 && EltVT != MVT::i16)
+        continue;
+      setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
+    }
+  }
+
   if (!Subtarget.useSoftFloat() && Subtarget.hasAVX()) {
     bool HasInt256 = Subtarget.hasInt256();
 
@@ -25409,6 +25423,97 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op,
   return SignExt;
 }
 
+// Create a min/max v8i16/v16i8 horizontal reduction with PHMINPOSUW.
+// Src is first split and combined with BinOp down to 128 bits. Returns an
+// empty SDValue if Src's element type is not TargetVT or its width is not a
+// multiple of 128 bits.
+static SDValue createMinMaxReduction(SDValue Src, EVT TargetVT,
+                                     const SDLoc &DL, ISD::NodeType BinOp,
+                                     SelectionDAG &DAG,
+                                     const X86Subtarget &Subtarget) {
+  assert(Subtarget.hasSSE41() &&
+         "The caller must check if SSE4.1 is available");
+
+  EVT SrcVT = Src.getValueType();
+  EVT SrcSVT = SrcVT.getScalarType();
+
+  if (SrcSVT != TargetVT || (SrcVT.getSizeInBits() % 128) != 0)
+    return SDValue();
+
+  // First, reduce the source down to 128-bit, applying BinOp to lo/hi.
+  while (SrcVT.getSizeInBits() > 128) {
+    SDValue Lo, Hi;
+    std::tie(Lo, Hi) = splitVector(Src, DAG, DL);
+    SrcVT = Lo.getValueType();
+    Src = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
+  }
+  assert(((SrcVT == MVT::v8i16 && TargetVT == MVT::i16) ||
+          (SrcVT == MVT::v16i8 && TargetVT == MVT::i8)) &&
+         "Unexpected value type");
+
+  // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
+  // to flip the value accordingly.
+  SDValue Mask;
+  unsigned MaskEltsBits = TargetVT.getSizeInBits();
+  if (BinOp == ISD::SMAX)
+    Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
+  else if (BinOp == ISD::SMIN)
+    Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
+  else if (BinOp == ISD::UMAX)
+    Mask = DAG.getAllOnesConstant(DL, SrcVT);
+
+  if (Mask)
+    Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src);
+
+  // For v16i8 cases we need to perform UMIN on pairs of byte elements,
+  // shuffling each upper element down and insert zeros. This means that the
+  // v16i8 UMIN will leave the upper element as zero, performing zero-extension
+  // ready for the PHMINPOS.
+  if (TargetVT == MVT::i8) {
+    SDValue Upper = DAG.getVectorShuffle(
+        SrcVT, DL, Src, DAG.getConstant(0, DL, MVT::v16i8),
+        {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
+    Src = DAG.getNode(ISD::UMIN, DL, SrcVT, Src, Upper);
+  }
+
+  // Perform the PHMINPOS on a v8i16 vector,
+  Src = DAG.getBitcast(MVT::v8i16, Src);
+  Src = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, Src);
+  Src = DAG.getBitcast(SrcVT, Src);
+
+  if (Mask)
+    Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src);
+
+  return DAG.getExtractVectorElt(DL, TargetVT, Src, 0);
+}
+
+// Custom lowering of VECREDUCE_UMIN/UMAX/SMIN/SMAX via createMinMaxReduction;
+// an empty SDValue means the operand type is not supported.
+static SDValue LowerVECTOR_REDUCE_MINMAX(SDValue Op,
+                                         const X86Subtarget &Subtarget,
+                                         SelectionDAG &DAG) {
+  ISD::NodeType BinOp;
+  switch (Op.getOpcode()) {
+  default:
+    llvm_unreachable("Expected min/max reduction");
+  case ISD::VECREDUCE_UMIN:
+    BinOp = ISD::UMIN;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    BinOp = ISD::UMAX;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    BinOp = ISD::SMIN;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    BinOp = ISD::SMAX;
+    break;
+  }
+
+  return createMinMaxReduction(Op.getOperand(0), Op.getValueType(), SDLoc(Op),
+                               BinOp, DAG, Subtarget);
+}
+
 static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
                                 SelectionDAG &DAG) {
   MVT VT = Op->getSimpleValueType(0);
@@ -33620,6 +33725,11 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::ZERO_EXTEND_VECTOR_INREG:
   case ISD::SIGN_EXTEND_VECTOR_INREG:
     return LowerEXTEND_VECTOR_INREG(Op, Subtarget, DAG);
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_SMAX:
+    return LowerVECTOR_REDUCE_MINMAX(Op, Subtarget, DAG);
   case ISD::FP_TO_SINT:
   case ISD::STRICT_FP_TO_SINT:
   case ISD::FP_TO_UINT:
@@ -46192,60 +46299,9 @@ static SDValue combineMinMaxReduction(SDNode *Extract, SelectionDAG &DAG,
   if (!Src)
     return SDValue();
 
-  EVT SrcVT = Src.getValueType();
-  EVT SrcSVT = SrcVT.getScalarType();
-  if (SrcSVT != ExtractVT || (SrcVT.getSizeInBits() % 128) != 0)
-    return SDValue();
-
-  SDLoc DL(Extract);
-  SDValue MinPos = Src;
-
-  // First, reduce the source down to 128-bit, applying BinOp to lo/hi.
-  while (SrcVT.getSizeInBits() > 128) {
-    SDValue Lo, Hi;
-    std::tie(Lo, Hi) = splitVector(MinPos, DAG, DL);
-    SrcVT = Lo.getValueType();
-    MinPos = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
-  }
-  assert(((SrcVT == MVT::v8i16 && ExtractVT == MVT::i16) ||
-          (SrcVT == MVT::v16i8 && ExtractVT == MVT::i8)) &&
-         "Unexpected value type");
-
-  // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
-  // to flip the value accordingly.
-  SDValue Mask;
-  unsigned MaskEltsBits = ExtractVT.getSizeInBits();
-  if (BinOp == ISD::SMAX)
-    Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
-  else if (BinOp == ISD::SMIN)
-    Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
-  else if (BinOp == ISD::UMAX)
-    Mask = DAG.getAllOnesConstant(DL, SrcVT);
-
-  if (Mask)
-    MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);
-
-  // For v16i8 cases we need to perform UMIN on pairs of byte elements,
-  // shuffling each upper element down and insert zeros. This means that the
-  // v16i8 UMIN will leave the upper element as zero, performing zero-extension
-  // ready for the PHMINPOS.
-  if (ExtractVT == MVT::i8) {
-    SDValue Upper = DAG.getVectorShuffle(
-        SrcVT, DL, MinPos, DAG.getConstant(0, DL, MVT::v16i8),
-        {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
-    MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper);
-  }
-
-  // Perform the PHMINPOS on a v8i16 vector,
-  MinPos = DAG.getBitcast(MVT::v8i16, MinPos);
-  MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, MinPos);
-  MinPos = DAG.getBitcast(SrcVT, MinPos);
-
-  if (Mask)
-    MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);
-
-  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, MinPos,
-                     DAG.getVectorIdxConstant(0, DL));
+  // The PHMINPOSUW lowering is shared with LowerVECTOR_REDUCE_MINMAX.
+  return createMinMaxReduction(Src, ExtractVT, SDLoc(Extract), BinOp, DAG,
+                               Subtarget);
 }
 
 // Attempt to replace an all_of/any_of/parity style horizontal reduction with a MOVMSK.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index a1a177528eb23..3c479fc72ce30 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6575,6 +6575,31 @@ X86TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
   return Options;
 }
 
+bool X86TTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
+  switch (II->getIntrinsicID()) {
+  default:
+    return true;
+
+  case Intrinsic::vector_reduce_umin:
+  case Intrinsic::vector_reduce_umax:
+  case Intrinsic::vector_reduce_smin:
+  case Intrinsic::vector_reduce_smax: {
+    // Keep the reduction intact (return false) only when the backend can
+    // lower it with PHMINPOSUW: SSE4.1, i8/i16 elements and a vector width
+    // that is a multiple of 128 bits (mirrors createMinMaxReduction).
+    auto *VType = dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
+    if (!VType)
+      return true;
+    Type *SType = VType->getScalarType();
+    bool CanUsePHMINPOSUW =
+        ST->hasSSE41() && II->getType() == SType &&
+        (VType->getPrimitiveSizeInBits().getFixedValue() % 128) == 0 &&
+        (SType->isIntegerTy(8) || SType->isIntegerTy(16));
+    return !CanUsePHMINPOSUW;
+  }
+  }
+}
+
 bool X86TTIImpl::prefersVectorizedAddressing() const {
   return supportsGather();
 }
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 72673d6fbd80f..5e2fe40f9f902 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -303,6 +303,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   TTI::MemCmpExpansionOptions
   enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const override;
   bool preferAlternateOpcodeVectorization() const override { return false; }
+  bool shouldExpandReduction(const IntrinsicInst *II) const override;
   bool prefersVectorizedAddressing() const override;
   bool supportsEfficientVectorElementLoadStore() const override;
   bool enableInterleavedAccessVectorization() const override;
diff --git a/llvm/test/CodeGen/X86/optimize-reduction.ll b/llvm/test/CodeGen/X86/optimize-reduction.ll
index e51ac1bd3c13c..4e9732882d2bb 100644
--- a/llvm/test/CodeGen/X86/optimize-reduction.ll
+++ b/llvm/test/CodeGen/X86/optimize-reduction.ll
@@ -31,6 +31,7 @@ define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y)
 ; AVX2-NEXT:    vpminuw %xmm2, %xmm0, %xmm2
 ; AVX2-NEXT:    vphminposuw %xmm2, %xmm2
 ; AVX2-NEXT:    vmovd %xmm2, %eax
+; AVX2-NEXT:    vmovd %eax, %xmm2
 ; AVX2-NEXT:    vpbroadcastw %xmm2, %ymm2
 ; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
 ; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2



More information about the llvm-commits mailing list