[llvm] [SelectionDAG][x86] Ensure vector reduction optimization (PR #144231)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 14 10:15:53 PDT 2025


Suhajda Tamás <sutajo at gmail.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/144231 at github.com>


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Suhajda Tamás (sutajo)

<details>
<summary>Changes</summary>

Fixes #144227

---
Full diff: https://github.com/llvm/llvm-project/pull/144231.diff


2 Files Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+45-7) 
- (added) llvm/test/CodeGen/X86/optimize-reduction.ll (+114) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b4670e270141f..61e3979b6c0bb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47081,7 +47081,8 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
 /// scalars back, while for x64 we should use 64-bit extracts and shifts.
 static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
                                        TargetLowering::DAGCombinerInfo &DCI,
-                                       const X86Subtarget &Subtarget) {
+                                       const X86Subtarget &Subtarget, 
+                                       bool& TransformedBinOpReduction) {
   if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget))
     return NewOp;
 
@@ -47169,23 +47170,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   // Check whether this extract is the root of a sum of absolute differences
   // pattern. This has to be done here because we really want it to happen
   // pre-legalization,
-  if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
+  if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return SAD;
+  }
 
-  if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
+  if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return VPDPBUSD;
+  }
 
   // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
-  if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
+  if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return Cmp;
+  }
 
   // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
-  if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget))
+  if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return MinMax;
+  }
 
   // Attempt to optimize ADD/FADD/MUL reductions with HADD, promotion etc..
-  if (SDValue V = combineArithReduction(N, DAG, Subtarget))
+  if (SDValue V = combineArithReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return V;
+  }
 
   if (SDValue V = scalarizeExtEltFP(N, DAG, Subtarget, DCI))
     return V;
@@ -47255,6 +47266,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue combineExtractVectorEltAndOperand(SDNode* N, SelectionDAG& DAG,
+    TargetLowering::DAGCombinerInfo& DCI,
+    const X86Subtarget& Subtarget)
+{
+  bool TransformedBinOpReduction = false;
+  auto Op = combineExtractVectorElt(N, DAG, DCI, Subtarget, TransformedBinOpReduction);
+
+  if (TransformedBinOpReduction)
+  {
+    // The extract N was replaced by the scalar Op via one of the reduction
+    // combines (SAD, VPDPBUSD, predicate, min/max or arithmetic reduction).
+    // Also replace all uses of N's vector operand with scalar_to_vector(Op) so
+    // the binop + shuffle pyramid can be eliminated. NOTE(review): the author
+    // states only element 0 of that vector is defined - confirm this holds.
+
+    auto OldV = N->getOperand(0);
+    auto NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), OldV->getValueType(0), Op);
+
+    auto NV = DCI.CombineTo(N, Op);
+    DCI.CombineTo(OldV.getNode(), NewV);
+
+    Op = NV; // Return the CombineTo result for N so N doesn't get rechecked.
+  }
+
+  return Op;
+}
+
 // Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
 // This is more or less the reverse of combineBitcastvxi1.
 static SDValue combineToExtendBoolVectorInReg(
@@ -60702,7 +60740,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::EXTRACT_VECTOR_ELT:
   case X86ISD::PEXTRW:
   case X86ISD::PEXTRB:
-    return combineExtractVectorElt(N, DAG, DCI, Subtarget);
+    return combineExtractVectorEltAndOperand(N, DAG, DCI, Subtarget);
   case ISD::CONCAT_VECTORS:
     return combineCONCAT_VECTORS(N, DAG, DCI, Subtarget);
   case ISD::INSERT_SUBVECTOR:
diff --git a/llvm/test/CodeGen/X86/optimize-reduction.ll b/llvm/test/CodeGen/X86/optimize-reduction.ll
new file mode 100644
index 0000000000000..e51ac1bd3c13c
--- /dev/null
+++ b/llvm/test/CodeGen/X86/optimize-reduction.ll
@@ -0,0 +1,114 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+sse4.1,+fast-hops | FileCheck %s --check-prefixes=SSE41
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx2,+fast-hops | FileCheck %s --check-prefixes=AVX2
+
+define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_umin:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    pminuw %xmm1, %xmm4
+; SSE41-NEXT:    phminposuw %xmm4, %xmm4
+; SSE41-NEXT:    movd %xmm4, %eax
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
+; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT:    pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT:    pxor %xmm5, %xmm1
+; SSE41-NEXT:    por %xmm3, %xmm1
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT:    pxor %xmm5, %xmm0
+; SSE41-NEXT:    por %xmm2, %xmm0
+; SSE41-NEXT:    pminuw %xmm1, %xmm0
+; SSE41-NEXT:    phminposuw %xmm0, %xmm0
+; SSE41-NEXT:    movd %xmm0, %edx
+; SSE41-NEXT:    # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT:    # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_umin:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT:    vpminuw %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vphminposuw %xmm2, %xmm2
+; AVX2-NEXT:    vmovd %xmm2, %eax
+; AVX2-NEXT:    vpbroadcastw %xmm2, %ymm2
+; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vphminposuw %xmm0, %xmm0
+; AVX2-NEXT:    vmovd %xmm0, %edx
+; AVX2-NEXT:    # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT:    # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+  %min_x = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %x)
+  %min_x_vec = insertelement <1 x i16> poison, i16 %min_x, i64 0
+  %min_x_splat = shufflevector <1 x i16> %min_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+  %cmp = icmp eq <16 x i16> %x, %min_x_splat
+  %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+  %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)
+  %ret_0 = insertvalue { i16, i16 } poison, i16 %min_x, 0
+  %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+  ret { i16, i16 } %ret
+}
+
+define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_add:
+; SSE41:       # %bb.0: # %start
+; SSE41-NEXT:    movdqa %xmm1, %xmm4
+; SSE41-NEXT:    phaddw %xmm0, %xmm4
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
+; SSE41-NEXT:    movd %xmm4, %eax
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
+; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT:    pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT:    pxor %xmm5, %xmm1
+; SSE41-NEXT:    por %xmm3, %xmm1
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT:    pxor %xmm5, %xmm0
+; SSE41-NEXT:    por %xmm2, %xmm0
+; SSE41-NEXT:    pminuw %xmm1, %xmm0
+; SSE41-NEXT:    phminposuw %xmm0, %xmm0
+; SSE41-NEXT:    movd %xmm0, %edx
+; SSE41-NEXT:    # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT:    # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_add:
+; AVX2:       # %bb.0: # %start
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT:    vphaddw %xmm0, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vmovd %xmm2, %eax
+; AVX2-NEXT:    vpbroadcastw %xmm2, %ymm2
+; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vphminposuw %xmm0, %xmm0
+; AVX2-NEXT:    vmovd %xmm0, %edx
+; AVX2-NEXT:    # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT:    # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+start:
+  %sum_x = tail call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %x)
+  %sum_x_vec = insertelement <1 x i16> poison, i16 %sum_x, i64 0
+  %sum_x_splat = shufflevector <1 x i16> %sum_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+  %cmp = icmp eq <16 x i16> %x, %sum_x_splat
+  %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+  %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)
+  %ret_0 = insertvalue { i16, i16 } poison, i16 %sum_x, 0
+  %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+  ret { i16, i16 } %ret
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/144231


More information about the llvm-commits mailing list