[llvm] [SelectionDAG][x86] Ensure vector reduction optimization (PR #144231)
Suhajda Tamás via llvm-commits
llvm-commits at lists.llvm.org
Sat Jun 14 10:15:07 PDT 2025
https://github.com/sutajo created https://github.com/llvm/llvm-project/pull/144231
Fixes #144227
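When combineExtractVectorElt matches a horizontal reduction rooted at an
extract_vector_elt (SAD, VPDPBUSD, predicate all_of/any_of, min/max via
PHMINPOSUW, or ADD/FADD/MUL via horizontal adds), it only replaces the extract
itself. If the reduced vector has additional uses, the binop + shuffle pyramid
feeding the extract stays in the DAG and the reduction is effectively computed
twice. This patch threads a TransformedBinOpReduction flag out of
combineExtractVectorElt and, when it is set, also replaces the reduced vector
with scalar_to_vector of the combined value so the pyramid is eliminated.

A minimal sketch of the affected pattern, taken from the test added in this PR
(the reduction result is splatted and compared back against the source vector,
so the reduced vector feeds more than just the scalar return value):

  %min_x = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %x)
  %min_x_vec = insertelement <1 x i16> poison, i16 %min_x, i64 0
  %min_x_splat = shufflevector <1 x i16> %min_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
  %cmp = icmp eq <16 x i16> %x, %min_x_splat
  %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
  %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)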
From 2b2130a54aa74635ca194d6533e2e9ecc313f39e Mon Sep 17 00:00:00 2001
From: Suhajda Tamás <sutajo at gmail.com>
Date: Sat, 14 Jun 2025 19:04:49 +0200
Subject: [PATCH 1/2] [x86] Add test for reduction
---
llvm/test/CodeGen/X86/optimize-reduction.ll | 140 ++++++++++++++++++++
1 file changed, 140 insertions(+)
create mode 100644 llvm/test/CodeGen/X86/optimize-reduction.ll
diff --git a/llvm/test/CodeGen/X86/optimize-reduction.ll b/llvm/test/CodeGen/X86/optimize-reduction.ll
new file mode 100644
index 0000000000000..003c41612b8bf
--- /dev/null
+++ b/llvm/test/CodeGen/X86/optimize-reduction.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+sse4.1,+fast-hops | FileCheck %s --check-prefixes=SSE41
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx2,+fast-hops | FileCheck %s --check-prefixes=AVX2
+
+define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_umin:
+; SSE41: # %bb.0:
+; SSE41-NEXT: movdqa %xmm0, %xmm4
+; SSE41-NEXT: pminuw %xmm1, %xmm4
+; SSE41-NEXT: pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
+; SSE41-NEXT: pminuw %xmm4, %xmm5
+; SSE41-NEXT: pshufd {{.*#+}} xmm6 = xmm5[1,1,1,1]
+; SSE41-NEXT: pminuw %xmm5, %xmm6
+; SSE41-NEXT: movdqa %xmm6, %xmm5
+; SSE41-NEXT: psrld $16, %xmm5
+; SSE41-NEXT: pminuw %xmm6, %xmm5
+; SSE41-NEXT: phminposuw %xmm4, %xmm4
+; SSE41-NEXT: movd %xmm4, %eax
+; SSE41-NEXT: pshuflw {{.*#+}} xmm4 = xmm5[0,0,0,0,4,5,6,7]
+; SSE41-NEXT: pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT: pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT: pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT: pxor %xmm5, %xmm1
+; SSE41-NEXT: por %xmm3, %xmm1
+; SSE41-NEXT: pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT: pxor %xmm5, %xmm0
+; SSE41-NEXT: por %xmm2, %xmm0
+; SSE41-NEXT: pminuw %xmm1, %xmm0
+; SSE41-NEXT: phminposuw %xmm0, %xmm0
+; SSE41-NEXT: movd %xmm0, %edx
+; SSE41-NEXT: # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT: # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT: retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_umin:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT: vpminuw %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vpshufd {{.*#+}} xmm3 = xmm2[2,3,2,3]
+; AVX2-NEXT: vpminuw %xmm3, %xmm2, %xmm3
+; AVX2-NEXT: vpshufd {{.*#+}} xmm4 = xmm3[1,1,1,1]
+; AVX2-NEXT: vpminuw %xmm4, %xmm3, %xmm3
+; AVX2-NEXT: vpsrld $16, %xmm3, %xmm4
+; AVX2-NEXT: vphminposuw %xmm2, %xmm2
+; AVX2-NEXT: vmovd %xmm2, %eax
+; AVX2-NEXT: vpminuw %xmm4, %xmm3, %xmm2
+; AVX2-NEXT: vpbroadcastw %xmm2, %ymm2
+; AVX2-NEXT: vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT: vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT: vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT: vphminposuw %xmm0, %xmm0
+; AVX2-NEXT: vmovd %xmm0, %edx
+; AVX2-NEXT: # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT: # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT: vzeroupper
+; AVX2-NEXT: retq
+ %min_x = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %x)
+ %min_x_vec = insertelement <1 x i16> poison, i16 %min_x, i64 0
+ %min_x_splat = shufflevector <1 x i16> %min_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+ %cmp = icmp eq <16 x i16> %x, %min_x_splat
+ %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+ %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)
+ %ret_0 = insertvalue { i16, i16 } poison, i16 %min_x, 0
+ %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+ ret { i16, i16 } %ret
+}
+
+define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_add:
+; SSE41: # %bb.0: # %start
+; SSE41-NEXT: movdqa %xmm0, %xmm4
+; SSE41-NEXT: paddw %xmm1, %xmm4
+; SSE41-NEXT: pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
+; SSE41-NEXT: paddw %xmm4, %xmm5
+; SSE41-NEXT: pshufd {{.*#+}} xmm4 = xmm5[1,1,1,1]
+; SSE41-NEXT: paddw %xmm5, %xmm4
+; SSE41-NEXT: phaddw %xmm4, %xmm4
+; SSE41-NEXT: movdqa %xmm1, %xmm5
+; SSE41-NEXT: phaddw %xmm0, %xmm5
+; SSE41-NEXT: phaddw %xmm5, %xmm5
+; SSE41-NEXT: phaddw %xmm5, %xmm5
+; SSE41-NEXT: phaddw %xmm5, %xmm5
+; SSE41-NEXT: movd %xmm5, %eax
+; SSE41-NEXT: pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
+; SSE41-NEXT: pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT: pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT: pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT: pxor %xmm5, %xmm1
+; SSE41-NEXT: por %xmm3, %xmm1
+; SSE41-NEXT: pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT: pxor %xmm5, %xmm0
+; SSE41-NEXT: por %xmm2, %xmm0
+; SSE41-NEXT: pminuw %xmm1, %xmm0
+; SSE41-NEXT: phminposuw %xmm0, %xmm0
+; SSE41-NEXT: movd %xmm0, %edx
+; SSE41-NEXT: # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT: # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT: retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_add:
+; AVX2: # %bb.0: # %start
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT: vpaddw %xmm2, %xmm0, %xmm3
+; AVX2-NEXT: vpshufd {{.*#+}} xmm4 = xmm3[2,3,2,3]
+; AVX2-NEXT: vpaddw %xmm4, %xmm3, %xmm3
+; AVX2-NEXT: vpshufd {{.*#+}} xmm4 = xmm3[1,1,1,1]
+; AVX2-NEXT: vpaddw %xmm4, %xmm3, %xmm3
+; AVX2-NEXT: vphaddw %xmm3, %xmm3, %xmm3
+; AVX2-NEXT: vphaddw %xmm0, %xmm2, %xmm2
+; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT: vmovd %xmm2, %eax
+; AVX2-NEXT: vpbroadcastw %xmm3, %ymm2
+; AVX2-NEXT: vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT: vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT: vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT: vphminposuw %xmm0, %xmm0
+; AVX2-NEXT: vmovd %xmm0, %edx
+; AVX2-NEXT: # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT: # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT: vzeroupper
+; AVX2-NEXT: retq
+start:
+ %sum_x = tail call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %x)
+ %sum_x_vec = insertelement <1 x i16> poison, i16 %sum_x, i64 0
+ %sum_x_splat = shufflevector <1 x i16> %sum_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+ %cmp = icmp eq <16 x i16> %x, %sum_x_splat
+ %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+ %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)
+ %ret_0 = insertvalue { i16, i16 } poison, i16 %sum_x, 0
+ %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+ ret { i16, i16 } %ret
+}
From 93df21fcff0b9f836947cfd9bb8e88b565b1c435 Mon Sep 17 00:00:00 2001
From: Suhajda Tamás <sutajo at gmail.com>
Date: Sat, 14 Jun 2025 19:09:44 +0200
Subject: [PATCH 2/2] [x86] Implement optimization and update tests
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 52 ++++++++++++++++++---
llvm/test/CodeGen/X86/optimize-reduction.ll | 40 +++-------------
2 files changed, 52 insertions(+), 40 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b4670e270141f..61e3979b6c0bb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47081,7 +47081,8 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
/// scalars back, while for x64 we should use 64-bit extracts and shifts.
static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
- const X86Subtarget &Subtarget) {
+ const X86Subtarget &Subtarget,
+                                       bool &TransformedBinOpReduction) {
if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget))
return NewOp;
@@ -47169,23 +47170,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
// Check whether this extract is the root of a sum of absolute differences
// pattern. This has to be done here because we really want it to happen
// pre-legalization,
- if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
+ if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return SAD;
+ }
- if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
+ if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return VPDPBUSD;
+ }
// Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
- if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
+ if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return Cmp;
+ }
// Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
- if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget))
+ if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return MinMax;
+ }
// Attempt to optimize ADD/FADD/MUL reductions with HADD, promotion etc..
- if (SDValue V = combineArithReduction(N, DAG, Subtarget))
+ if (SDValue V = combineArithReduction(N, DAG, Subtarget)) {
+ TransformedBinOpReduction = true;
return V;
+ }
if (SDValue V = scalarizeExtEltFP(N, DAG, Subtarget, DCI))
return V;
@@ -47255,6 +47266,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+static SDValue
+combineExtractVectorEltAndOperand(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const X86Subtarget &Subtarget) {
+  bool TransformedBinOpReduction = false;
+  SDValue Op = combineExtractVectorElt(N, DAG, DCI, Subtarget,
+                                       TransformedBinOpReduction);
+
+  if (TransformedBinOpReduction) {
+    // If we simplified N = extract_vector_elt(V, 0) into Op and V was
+    // produced by a reduction, replace all uses of V with
+    // scalar_to_vector(Op) so the binop + shuffle pyramid computing V is
+    // eliminated as well. This is safe because all elements of V other
+    // than element zero are undefined.
+    SDValue OldV = N->getOperand(0);
+    SDValue NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N),
+                               OldV.getValueType(), Op);
+
+    SDValue NV = DCI.CombineTo(N, Op);
+    DCI.CombineTo(OldV.getNode(), NewV);
+
+    Op = NV; // Return the CombineTo result so N is not revisited.
+  }
+
+  return Op;
+}
+
// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
// This is more or less the reverse of combineBitcastvxi1.
static SDValue combineToExtendBoolVectorInReg(
@@ -60702,7 +60740,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::EXTRACT_VECTOR_ELT:
case X86ISD::PEXTRW:
case X86ISD::PEXTRB:
- return combineExtractVectorElt(N, DAG, DCI, Subtarget);
+ return combineExtractVectorEltAndOperand(N, DAG, DCI, Subtarget);
case ISD::CONCAT_VECTORS:
return combineCONCAT_VECTORS(N, DAG, DCI, Subtarget);
case ISD::INSERT_SUBVECTOR:
diff --git a/llvm/test/CodeGen/X86/optimize-reduction.ll b/llvm/test/CodeGen/X86/optimize-reduction.ll
index 003c41612b8bf..e51ac1bd3c13c 100644
--- a/llvm/test/CodeGen/X86/optimize-reduction.ll
+++ b/llvm/test/CodeGen/X86/optimize-reduction.ll
@@ -7,16 +7,9 @@ define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y)
; SSE41: # %bb.0:
; SSE41-NEXT: movdqa %xmm0, %xmm4
; SSE41-NEXT: pminuw %xmm1, %xmm4
-; SSE41-NEXT: pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
-; SSE41-NEXT: pminuw %xmm4, %xmm5
-; SSE41-NEXT: pshufd {{.*#+}} xmm6 = xmm5[1,1,1,1]
-; SSE41-NEXT: pminuw %xmm5, %xmm6
-; SSE41-NEXT: movdqa %xmm6, %xmm5
-; SSE41-NEXT: psrld $16, %xmm5
-; SSE41-NEXT: pminuw %xmm6, %xmm5
; SSE41-NEXT: phminposuw %xmm4, %xmm4
; SSE41-NEXT: movd %xmm4, %eax
-; SSE41-NEXT: pshuflw {{.*#+}} xmm4 = xmm5[0,0,0,0,4,5,6,7]
+; SSE41-NEXT: pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
; SSE41-NEXT: pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
; SSE41-NEXT: pcmpeqw %xmm4, %xmm1
; SSE41-NEXT: pcmpeqd %xmm5, %xmm5
@@ -36,14 +29,8 @@ define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y)
; AVX2: # %bb.0:
; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm2
; AVX2-NEXT: vpminuw %xmm2, %xmm0, %xmm2
-; AVX2-NEXT: vpshufd {{.*#+}} xmm3 = xmm2[2,3,2,3]
-; AVX2-NEXT: vpminuw %xmm3, %xmm2, %xmm3
-; AVX2-NEXT: vpshufd {{.*#+}} xmm4 = xmm3[1,1,1,1]
-; AVX2-NEXT: vpminuw %xmm4, %xmm3, %xmm3
-; AVX2-NEXT: vpsrld $16, %xmm3, %xmm4
; AVX2-NEXT: vphminposuw %xmm2, %xmm2
; AVX2-NEXT: vmovd %xmm2, %eax
-; AVX2-NEXT: vpminuw %xmm4, %xmm3, %xmm2
; AVX2-NEXT: vpbroadcastw %xmm2, %ymm2
; AVX2-NEXT: vpcmpeqw %ymm2, %ymm0, %ymm0
; AVX2-NEXT: vpcmpeqd %ymm2, %ymm2, %ymm2
@@ -71,19 +58,12 @@ define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y)
define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
; SSE41-LABEL: test_reduce_v16i16_with_add:
; SSE41: # %bb.0: # %start
-; SSE41-NEXT: movdqa %xmm0, %xmm4
-; SSE41-NEXT: paddw %xmm1, %xmm4
-; SSE41-NEXT: pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
-; SSE41-NEXT: paddw %xmm4, %xmm5
-; SSE41-NEXT: pshufd {{.*#+}} xmm4 = xmm5[1,1,1,1]
-; SSE41-NEXT: paddw %xmm5, %xmm4
+; SSE41-NEXT: movdqa %xmm1, %xmm4
+; SSE41-NEXT: phaddw %xmm0, %xmm4
+; SSE41-NEXT: phaddw %xmm4, %xmm4
; SSE41-NEXT: phaddw %xmm4, %xmm4
-; SSE41-NEXT: movdqa %xmm1, %xmm5
-; SSE41-NEXT: phaddw %xmm0, %xmm5
-; SSE41-NEXT: phaddw %xmm5, %xmm5
-; SSE41-NEXT: phaddw %xmm5, %xmm5
-; SSE41-NEXT: phaddw %xmm5, %xmm5
-; SSE41-NEXT: movd %xmm5, %eax
+; SSE41-NEXT: phaddw %xmm4, %xmm4
+; SSE41-NEXT: movd %xmm4, %eax
; SSE41-NEXT: pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
; SSE41-NEXT: pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
; SSE41-NEXT: pcmpeqw %xmm4, %xmm1
@@ -103,18 +83,12 @@ define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
; AVX2-LABEL: test_reduce_v16i16_with_add:
; AVX2: # %bb.0: # %start
; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm2
-; AVX2-NEXT: vpaddw %xmm2, %xmm0, %xmm3
-; AVX2-NEXT: vpshufd {{.*#+}} xmm4 = xmm3[2,3,2,3]
-; AVX2-NEXT: vpaddw %xmm4, %xmm3, %xmm3
-; AVX2-NEXT: vpshufd {{.*#+}} xmm4 = xmm3[1,1,1,1]
-; AVX2-NEXT: vpaddw %xmm4, %xmm3, %xmm3
-; AVX2-NEXT: vphaddw %xmm3, %xmm3, %xmm3
; AVX2-NEXT: vphaddw %xmm0, %xmm2, %xmm2
; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
; AVX2-NEXT: vphaddw %xmm2, %xmm2, %xmm2
; AVX2-NEXT: vmovd %xmm2, %eax
-; AVX2-NEXT: vpbroadcastw %xmm3, %ymm2
+; AVX2-NEXT: vpbroadcastw %xmm2, %ymm2
; AVX2-NEXT: vpcmpeqw %ymm2, %ymm0, %ymm0
; AVX2-NEXT: vpcmpeqd %ymm2, %ymm2, %ymm2
; AVX2-NEXT: vpxor %ymm2, %ymm0, %ymm0