[llvm] [SelectionDAG][x86] Ensure vector reduction optimization (PR #144231)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Jun 14 10:15:53 PDT 2025
Suhajda Tamás <sutajo at gmail.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/144231 at github.com>
llvmbot wrote:
@llvm/pr-subscribers-backend-x86
Author: Suhajda Tamás (sutajo)
Changes:
Fixes #144227
---
Full diff: https://github.com/llvm/llvm-project/pull/144231.diff
2 Files Affected:
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+45-7)
- (added) llvm/test/CodeGen/X86/optimize-reduction.ll (+114)
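The reduction combines below (SAD, VPDPBUSD, predicate all_of/any_of, min/max, ADD/FADD/MUL) fire on the final extract_vector_elt of a reduction. The patch threads a TransformedBinOpReduction flag out of combineExtractVectorElt and, when it is set, also replaces the reduction vector itself with scalar_to_vector of the combined result, so that the binop + shuffle pyramid is removed even when that vector has users other than the extract, as in the new tests where the reduction result is splatted back and compared against the source vector. A condensed sketch of that motivating pattern, adapted from the first test below (function and value names are illustrative):

```llvm
; Condensed sketch of the motivating pattern, adapted from the first test in
; optimize-reduction.ll (names are illustrative). The reduction result %min is
; both kept as a scalar and splatted back into a compare against %x, so the
; reduction vector ends up with more than one user during DAG combining.
declare i16 @llvm.vector.reduce.umin.v16i16(<16 x i16>)

define i16 @umin_reduce_reused(<16 x i16> %x, <16 x i16> %y) {
  %min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %x)
  %vec = insertelement <1 x i16> poison, i16 %min, i64 0
  %splat = shufflevector <1 x i16> %vec, <1 x i16> poison, <16 x i32> zeroinitializer
  %cmp = icmp eq <16 x i16> %x, %splat
  %sel = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
  %min2 = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %sel)
  %sum = add i16 %min, %min2 ; keep both scalar results live
  ret i16 %sum
}
```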
``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b4670e270141f..61e3979b6c0bb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47081,7 +47081,8 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
/// scalars back, while for x64 we should use 64-bit extracts and shifts.
static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
- const X86Subtarget &Subtarget) {
+ const X86Subtarget &Subtarget,
+                                       bool &TransformedBinOpReduction) {
if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget))
return NewOp;
@@ -47169,23 +47170,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
// Check whether this extract is the root of a sum of absolute differences
// pattern. This has to be done here because we really want it to happen
// pre-legalization,
- if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
+ if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return SAD;
+ }
- if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
+ if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return VPDPBUSD;
+ }
// Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
- if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
+ if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return Cmp;
+ }
// Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
- if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget))
+ if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return MinMax;
+ }
// Attempt to optimize ADD/FADD/MUL reductions with HADD, promotion etc..
- if (SDValue V = combineArithReduction(N, DAG, Subtarget))
+ if (SDValue V = combineArithReduction(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return V;
+ }
if (SDValue V = scalarizeExtEltFP(N, DAG, Subtarget, DCI))
return V;
@@ -47255,6 +47266,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+static SDValue
+combineExtractVectorEltAndOperand(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const X86Subtarget &Subtarget) {
+  bool TransformedBinOpReduction = false;
+  SDValue Op = combineExtractVectorElt(N, DAG, DCI, Subtarget,
+                                       TransformedBinOpReduction);
+
+  if (TransformedBinOpReduction) {
+    // We simplified N = extract_vector_elt(V, 0) to Op, where V is the result
+    // of a reduction. Replace all uses of V with scalar_to_vector(Op) so that
+    // the binop + shuffle pyramid producing V is eliminated as well. This is
+    // safe because all elements of V other than the zeroth are undefined.
+    SDValue OldV = N->getOperand(0);
+    SDValue NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N),
+                               OldV->getValueType(0), Op);
+
+    SDValue NV = DCI.CombineTo(N, Op);
+    DCI.CombineTo(OldV.getNode(), NewV);
+
+    // Return N so it doesn't get rechecked.
+    Op = NV;
+  }
+
+  return Op;
+}
+
// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
// This is more or less the reverse of combineBitcastvxi1.
static SDValue combineToExtendBoolVectorInReg(
@@ -60702,7 +60740,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::EXTRACT_VECTOR_ELT:
case X86ISD::PEXTRW:
case X86ISD::PEXTRB:
- return combineExtractVectorElt(N, DAG, DCI, Subtarget);
+ return combineExtractVectorEltAndOperand(N, DAG, DCI, Subtarget);
case ISD::CONCAT_VECTORS:
return combineCONCAT_VECTORS(N, DAG, DCI, Subtarget);
case ISD::INSERT_SUBVECTOR:
diff --git a/llvm/test/CodeGen/X86/optimize-reduction.ll b/llvm/test/CodeGen/X86/optimize-reduction.ll
new file mode 100644
index 0000000000000..e51ac1bd3c13c
--- /dev/null
+++ b/llvm/test/CodeGen/X86/optimize-reduction.ll
@@ -0,0 +1,114 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+sse4.1,+fast-hops | FileCheck %s --check-prefixes=SSE41
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx2,+fast-hops | FileCheck %s --check-prefixes=AVX2
+
+define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_umin:
+; SSE41: # %bb.0:
+; SSE41-NEXT: movdqa %xmm0, %xmm4
+; SSE41-NEXT: pminuw %xmm1, %xmm4
+; SSE41-NEXT: phminposuw %xmm4, %xmm4
+; SSE41-NEXT: movd %xmm4, %eax
+; SSE41-NEXT: pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
+; SSE41-NEXT: pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT: pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT: pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT: pxor %xmm5, %xmm1
+; SSE41-NEXT: por %xmm3, %xmm1
+; SSE41-NEXT: pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT: pxor %xmm5, %xmm0
+; SSE41-NEXT: por %xmm2, %xmm0
+; SSE41-NEXT: pminuw %xmm1, %xmm0
+; SSE41-NEXT: phminposuw %xmm0, %xmm0
+; SSE41-NEXT: movd %xmm0, %edx
+; SSE41-NEXT: # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT: # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT: retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_umin:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT: vpminuw %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vphminposuw %xmm2, %xmm2
+; AVX2-NEXT: vmovd %xmm2, %eax
+; AVX2-NEXT: vpbroadcastw %xmm2, %ymm2
+; AVX2-NEXT: vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT: vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT: vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT: vphminposuw %xmm0, %xmm0
+; AVX2-NEXT: vmovd %xmm0, %edx
+; AVX2-NEXT: # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT: # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT: vzeroupper
+; AVX2-NEXT: retq
+ %min_x = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %x)
+ %min_x_vec = insertelement <1 x i16> poison, i16 %min_x, i64 0
+ %min_x_splat = shufflevector <1 x i16> %min_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+ %cmp = icmp eq <16 x i16> %x, %min_x_splat
+ %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+ %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)
+ %ret_0 = insertvalue { i16, i16 } poison, i16 %min_x, 0
+ %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+ ret { i16, i16 } %ret
+}
+
+define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_add:
+; SSE41: # %bb.0: # %start
+; SSE41-NEXT: movdqa %xmm1, %xmm4
+; SSE41-NEXT: phaddw %xmm0, %xmm4
+; SSE41-NEXT: phaddw %xmm4, %xmm4
+; SSE41-NEXT: phaddw %xmm4, %xmm4
+; SSE41-NEXT: phaddw %xmm4, %xmm4
+; SSE41-NEXT: movd %xmm4, %eax
+; SSE41-NEXT: pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
+; SSE41-NEXT: pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT: pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT: pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT: pxor %xmm5, %xmm1
+; SSE41-NEXT: por %xmm3, %xmm1
+; SSE41-NEXT: pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT: pxor %xmm5, %xmm0
+; SSE41-NEXT: por %xmm2, %xmm0
+; SSE41-NEXT: pminuw %xmm1, %xmm0
+; SSE41-NEXT: phminposuw %xmm0, %xmm0
+; SSE41-NEXT: movd %xmm0, %edx
+; SSE41-NEXT: # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT: # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT: retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_add:
+; AVX2: # %bb.0: # %start
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT: vphaddw %xmm0, %xmm2, %xmm2
+; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT: vmovd %xmm2, %eax
+; AVX2-NEXT: vpbroadcastw %xmm2, %ymm2
+; AVX2-NEXT: vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT: vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT: vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT: vphminposuw %xmm0, %xmm0
+; AVX2-NEXT: vmovd %xmm0, %edx
+; AVX2-NEXT: # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT: # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT: vzeroupper
+; AVX2-NEXT: retq
+start:
+ %sum_x = tail call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %x)
+ %sum_x_vec = insertelement <1 x i16> poison, i16 %sum_x, i64 0
+ %sum_x_splat = shufflevector <1 x i16> %sum_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+ %cmp = icmp eq <16 x i16> %x, %sum_x_splat
+ %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+ %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)
+ %ret_0 = insertvalue { i16, i16 } poison, i16 %sum_x, 0
+ %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+ ret { i16, i16 } %ret
+}
``````````
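For context, combineBasicSADPattern is one of the combines that now records TransformedBinOpReduction; it matches an add reduction over absolute byte differences so the whole thing can be lowered with PSADBW. A rough sketch of that kind of input (illustrative example, not taken from the PR):

```llvm
; Rough sketch of a sum-of-absolute-differences reduction of the sort that
; combineBasicSADPattern targets (illustrative only, not part of the PR).
declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1)
declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

define i32 @sad_sketch(<16 x i8> %a, <16 x i8> %b) {
  %az = zext <16 x i8> %a to <16 x i32>
  %bz = zext <16 x i8> %b to <16 x i32>
  %d  = sub <16 x i32> %az, %bz
  %ad = call <16 x i32> @llvm.abs.v16i32(<16 x i32> %d, i1 true)
  %r  = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %ad)
  ret i32 %r
}
```

The flag-guarded CombineTo on the reduction operand only becomes visible in cases like the tests above, where the reduction vector is also consumed by the splatted compare rather than only by the final extract.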
https://github.com/llvm/llvm-project/pull/144231
More information about the llvm-commits mailing list