[llvm] [X86] Fold X * 1 + Z --> X + Z for VPMADD52L (PR #158516)
Hongyu Chen via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 16 01:32:44 PDT 2025
https://github.com/XChy updated https://github.com/llvm/llvm-project/pull/158516
>From dffe8d4e642c63910afe0f59225f370e1b9d4e9f Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 15 Sep 2025 02:23:10 +0800
Subject: [PATCH 1/4] Precommit tests
---
llvm/test/CodeGen/X86/combine-vpmadd52.ll | 75 +++++++++++++++++++++++
1 file changed, 75 insertions(+)
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index 2cb060ea92b14..3e9f6ddab9a4a 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -398,3 +398,78 @@ define <2 x i64> @test3_knownbits_vpmadd52h_negative(<2 x i64> %x0, <2 x i64> %x
%ret = and <2 x i64> %madd, splat (i64 1)
ret <2 x i64> %ret
}
+
+define <2 x i64> @test_vpmadd52l_mul_one(<2 x i64> %x0, <2 x i32> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_one:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_one:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT: retq
+ %ext = zext <2 x i32> %x1 to <2 x i64>
+ %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %ext)
+ ret <2 x i64> %ifma
+}
+
+define <2 x i64> @test_vpmadd52l_mul_one_commuted(<2 x i64> %x0, <2 x i32> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_one_commuted:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_one_commuted:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT: retq
+ %ext = zext <2 x i32> %x1 to <2 x i64>
+ %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %ext, <2 x i64> splat(i64 1))
+ ret <2 x i64> %ifma
+}
+
+define <2 x i64> @test_vpmadd52l_mul_one_no_mask(<2 x i64> %x0, <2 x i64> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_one_no_mask:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_one_no_mask:
+; AVX: # %bb.0:
+; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT: retq
+ %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %x1)
+ ret <2 x i64> %ifma
+}
+
+; Mul by (1 << 52) + 1
+define <2 x i64> @test_vpmadd52l_mul_one_in_52bits(<2 x i64> %x0, <2 x i32> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_one_in_52bits:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_one_in_52bits:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT: retq
+ %ext = zext <2 x i32> %x1 to <2 x i64>
+ %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 4503599627370497), <2 x i64> %ext)
+ ret <2 x i64> %ifma
+}
+
+; lo(x1) * 1 = lo(x1) fits in 52 bits, so the high 52 bits of the 104-bit product are still zero.
+define <2 x i64> @test_vpmadd52h_mul_one(<2 x i64> %x0, <2 x i64> %x1) {
+; CHECK-LABEL: test_vpmadd52h_mul_one:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %x1)
+ ret <2 x i64> %ifma
+}
>From a1bd2886a9ad4690f2ad13ec7214f4f76929e64d Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 15 Sep 2025 02:58:14 +0800
Subject: [PATCH 2/4] [X86] Fold X * 1 + Z --> X + Z for VPMADD52L
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 25 +++++++++---
llvm/test/CodeGen/X86/combine-vpmadd52.ll | 48 +++++++----------------
2 files changed, 35 insertions(+), 38 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3631016b0f5c7..00ccddc3956b4 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44958,6 +44958,7 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
}
case X86ISD::VPMADD52L:
case X86ISD::VPMADD52H: {
+ KnownBits OrigKnownOp0, OrigKnownOp1;
KnownBits KnownOp0, KnownOp1, KnownOp2;
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
@@ -44965,11 +44966,11 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
// Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
// operand 2).
APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
- if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
+ if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, OrigKnownOp0,
TLO, Depth + 1))
return true;
- if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
+ if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, OrigKnownOp1,
TLO, Depth + 1))
return true;
@@ -44978,19 +44979,33 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
return true;
KnownBits KnownMul;
- KnownOp0 = KnownOp0.trunc(52);
- KnownOp1 = KnownOp1.trunc(52);
+ KnownOp0 = OrigKnownOp0.trunc(52);
+ KnownOp1 = OrigKnownOp1.trunc(52);
KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
: KnownBits::mulhu(KnownOp0, KnownOp1);
KnownMul = KnownMul.zext(64);
+ SDLoc DL(Op);
// lo/hi(X * Y) + Z --> C + Z
if (KnownMul.isConstant()) {
- SDLoc DL(Op);
SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
}
+ // C * X --> X * C
+ if (KnownOp0.isConstant()) {
+ std::swap(OrigKnownOp0, OrigKnownOp1);
+ std::swap(KnownOp0, KnownOp1);
+ std::swap(Op0, Op1);
+ }
+
+ // lo(X * 1) + Z --> lo(X) + Z --> X iff X == lo(X)
+ if (Opc == X86ISD::VPMADD52L && KnownOp1.isConstant() &&
+ KnownOp1.getConstant().isOne() &&
+ OrigKnownOp0.countMinLeadingZeros() >= 12) {
+ return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, Op0, Op2));
+ }
+
Known = KnownBits::add(KnownMul, KnownOp2);
return false;
}
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index 3e9f6ddab9a4a..8b741e9ef9482 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -400,34 +400,22 @@ define <2 x i64> @test3_knownbits_vpmadd52h_negative(<2 x i64> %x0, <2 x i64> %x
}
define <2 x i64> @test_vpmadd52l_mul_one(<2 x i64> %x0, <2 x i32> %x1) {
-; AVX512-LABEL: test_vpmadd52l_mul_one:
-; AVX512: # %bb.0:
-; AVX512-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
-; AVX512-NEXT: retq
-;
-; AVX-LABEL: test_vpmadd52l_mul_one:
-; AVX: # %bb.0:
-; AVX-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
-; AVX-NEXT: retq
+; CHECK-LABEL: test_vpmadd52l_mul_one:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: retq
%ext = zext <2 x i32> %x1 to <2 x i64>
%ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %ext)
ret <2 x i64> %ifma
}
define <2 x i64> @test_vpmadd52l_mul_one_commuted(<2 x i64> %x0, <2 x i32> %x1) {
-; AVX512-LABEL: test_vpmadd52l_mul_one_commuted:
-; AVX512: # %bb.0:
-; AVX512-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
-; AVX512-NEXT: retq
-;
-; AVX-LABEL: test_vpmadd52l_mul_one_commuted:
-; AVX: # %bb.0:
-; AVX-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
-; AVX-NEXT: retq
+; CHECK-LABEL: test_vpmadd52l_mul_one_commuted:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: retq
%ext = zext <2 x i32> %x1 to <2 x i64>
%ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %ext, <2 x i64> splat(i64 1))
ret <2 x i64> %ifma
@@ -449,17 +437,11 @@ define <2 x i64> @test_vpmadd52l_mul_one_no_mask(<2 x i64> %x0, <2 x i64> %x1) {
; Mul by (1 << 52) + 1
define <2 x i64> @test_vpmadd52l_mul_one_in_52bits(<2 x i64> %x0, <2 x i32> %x1) {
-; AVX512-LABEL: test_vpmadd52l_mul_one_in_52bits:
-; AVX512: # %bb.0:
-; AVX512-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
-; AVX512-NEXT: retq
-;
-; AVX-LABEL: test_vpmadd52l_mul_one_in_52bits:
-; AVX: # %bb.0:
-; AVX-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
-; AVX-NEXT: retq
+; CHECK-LABEL: test_vpmadd52l_mul_one_in_52bits:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: retq
%ext = zext <2 x i32> %x1 to <2 x i64>
%ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 4503599627370497), <2 x i64> %ext)
ret <2 x i64> %ifma
>From 574fe4ccd8980c8cd3168b5664ba2f9362d8f26e Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Tue, 16 Sep 2025 13:15:38 +0800
Subject: [PATCH 3/4] combine
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 54 ++++++++++++++++---------
1 file changed, 34 insertions(+), 20 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 00ccddc3956b4..3b551a29d28fc 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44958,7 +44958,6 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
}
case X86ISD::VPMADD52L:
case X86ISD::VPMADD52H: {
- KnownBits OrigKnownOp0, OrigKnownOp1;
KnownBits KnownOp0, KnownOp1, KnownOp2;
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
@@ -44966,11 +44965,11 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
// Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
// operand 2).
APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
- if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, OrigKnownOp0,
+ if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
TLO, Depth + 1))
return true;
- if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, OrigKnownOp1,
+ if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
TLO, Depth + 1))
return true;
@@ -44979,8 +44978,8 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
return true;
KnownBits KnownMul;
- KnownOp0 = OrigKnownOp0.trunc(52);
- KnownOp1 = OrigKnownOp1.trunc(52);
+ KnownOp0 = KnownOp0.trunc(52);
+ KnownOp1 = KnownOp1.trunc(52);
KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
: KnownBits::mulhu(KnownOp0, KnownOp1);
KnownMul = KnownMul.zext(64);
@@ -44992,20 +44991,6 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
}
- // C * X --> X * C
- if (KnownOp0.isConstant()) {
- std::swap(OrigKnownOp0, OrigKnownOp1);
- std::swap(KnownOp0, KnownOp1);
- std::swap(Op0, Op1);
- }
-
- // lo(X * 1) + Z --> lo(X) + Z --> X iff X == lo(X)
- if (Opc == X86ISD::VPMADD52L && KnownOp1.isConstant() &&
- KnownOp1.getConstant().isOne() &&
- OrigKnownOp0.countMinLeadingZeros() >= 12) {
- return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, Op0, Op2));
- }
-
Known = KnownBits::add(KnownMul, KnownOp2);
return false;
}
@@ -60201,8 +60186,37 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
static SDValue combineVPMADD52LH(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
MVT VT = N->getSimpleValueType(0);
- unsigned NumEltBits = VT.getScalarSizeInBits();
+
+ bool AddLow = N->getOpcode() == X86ISD::VPMADD52L;
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ SDValue Op2 = N->getOperand(2);
+ SDLoc DL(N);
+
+ APInt C0, C1;
+ bool HasC0 = X86::isConstantSplat(Op0, C0),
+ HasC1 = X86::isConstantSplat(Op1, C1);
+
+ // lo/hi(C * X) + Z --> lo/hi(X * C) + Z
+ if (HasC0 && !HasC1)
+ return DAG.getNode(N->getOpcode(), DL, VT, Op1, Op0, Op2);
+
+ // Only keep the low 52 bits of C1
+ if (HasC1 && C1.countLeadingZeros() < 12) {
+ C1.clearBits(52, 64);
+ SDValue LowC1 = DAG.getConstant(C1, DL, VT);
+ return DAG.getNode(N->getOpcode(), DL, VT, Op0, LowC1, Op2);
+ }
+
+ // lo(X * 1) + Z --> lo(X) + Z iff X == lo(X)
+ if (AddLow && HasC1 && C1.isOne()) {
+ KnownBits KnownOp0 = DAG.computeKnownBits(Op0);
+ if (KnownOp0.countMinLeadingZeros() >= 12)
+ return DAG.getNode(ISD::ADD, DL, VT, Op0, Op2);
+ }
+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ unsigned NumEltBits = VT.getScalarSizeInBits();
if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
DCI))
return SDValue(N, 0);
>From 1f155fc4515c5274b9872cc45ff1dc3bbc53a9d4 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Tue, 16 Sep 2025 16:32:28 +0800
Subject: [PATCH 4/4] fix diff
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3b551a29d28fc..790f65ac547fb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44984,9 +44984,9 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
: KnownBits::mulhu(KnownOp0, KnownOp1);
KnownMul = KnownMul.zext(64);
- SDLoc DL(Op);
// lo/hi(X * Y) + Z --> C + Z
if (KnownMul.isConstant()) {
+ SDLoc DL(Op);
SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
}
More information about the llvm-commits mailing list