[llvm] [X86] Fold X * 1 + Z --> X + Z for VPMADD52L (PR #158516)

Hongyu Chen via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 16 01:32:44 PDT 2025


https://github.com/XChy updated https://github.com/llvm/llvm-project/pull/158516

>From dffe8d4e642c63910afe0f59225f370e1b9d4e9f Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 15 Sep 2025 02:23:10 +0800
Subject: [PATCH 1/4] Precommit tests

---
 llvm/test/CodeGen/X86/combine-vpmadd52.ll | 75 +++++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index 2cb060ea92b14..3e9f6ddab9a4a 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -398,3 +398,78 @@ define <2 x i64> @test3_knownbits_vpmadd52h_negative(<2 x i64> %x0, <2 x i64> %x
   %ret = and <2 x i64> %madd, splat (i64 1)
   ret <2 x i64> %ret
 }
+
+define <2 x i64> @test_vpmadd52l_mul_one(<2 x i64> %x0, <2 x i32> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_one:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_one:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    retq
+  %ext = zext <2 x i32> %x1 to <2 x i64>
+  %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %ext)
+  ret <2 x i64> %ifma
+}
+
+define <2 x i64> @test_vpmadd52l_mul_one_commuted(<2 x i64> %x0, <2 x i32> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_one_commuted:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_one_commuted:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    retq
+  %ext = zext <2 x i32> %x1 to <2 x i64>
+  %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %ext, <2 x i64> splat(i64 1))
+  ret <2 x i64> %ifma
+}
+
+define <2 x i64> @test_vpmadd52l_mul_one_no_mask(<2 x i64> %x0, <2 x i64> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_one_no_mask:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_one_no_mask:
+; AVX:       # %bb.0:
+; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    retq
+  %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %x1)
+  ret <2 x i64> %ifma
+}
+
+; Mul by (1 << 52) + 1
+define <2 x i64> @test_vpmadd52l_mul_one_in_52bits(<2 x i64> %x0, <2 x i32> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_one_in_52bits:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_one_in_52bits:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    retq
+  %ext = zext <2 x i32> %x1 to <2 x i64>
+  %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 4503599627370497), <2 x i64> %ext)
+  ret <2 x i64> %ifma
+}
+
+; lo(x1) * 1 = lo(x1) fits in 52 bits, so the high 52 bits of the product are still zero.
+define <2 x i64> @test_vpmadd52h_mul_one(<2 x i64> %x0, <2 x i64> %x1) {
+; CHECK-LABEL: test_vpmadd52h_mul_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %x1)
+  ret <2 x i64> %ifma
+}

>From a1bd2886a9ad4690f2ad13ec7214f4f76929e64d Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 15 Sep 2025 02:58:14 +0800
Subject: [PATCH 2/4] [X86] Fold X * 1 + Z --> X + Z for VPMADD52L

---
 llvm/lib/Target/X86/X86ISelLowering.cpp   | 25 +++++++++---
 llvm/test/CodeGen/X86/combine-vpmadd52.ll | 48 +++++++----------------
 2 files changed, 35 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3631016b0f5c7..00ccddc3956b4 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44958,6 +44958,7 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
   }
   case X86ISD::VPMADD52L:
   case X86ISD::VPMADD52H: {
+    KnownBits OrigKnownOp0, OrigKnownOp1;
     KnownBits KnownOp0, KnownOp1, KnownOp2;
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
@@ -44965,11 +44966,11 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
     //  Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
     //  operand 2).
     APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
-    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
+    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, OrigKnownOp0,
                              TLO, Depth + 1))
       return true;
 
-    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
+    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, OrigKnownOp1,
                              TLO, Depth + 1))
       return true;
 
@@ -44978,19 +44979,33 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
       return true;
 
     KnownBits KnownMul;
-    KnownOp0 = KnownOp0.trunc(52);
-    KnownOp1 = KnownOp1.trunc(52);
+    KnownOp0 = OrigKnownOp0.trunc(52);
+    KnownOp1 = OrigKnownOp1.trunc(52);
     KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
                                         : KnownBits::mulhu(KnownOp0, KnownOp1);
     KnownMul = KnownMul.zext(64);
 
+    SDLoc DL(Op);
     // lo/hi(X * Y) + Z --> C + Z
     if (KnownMul.isConstant()) {
-      SDLoc DL(Op);
       SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
     }
 
+    // C * X --> X * C
+    if (KnownOp0.isConstant()) {
+      std::swap(OrigKnownOp0, OrigKnownOp1);
+      std::swap(KnownOp0, KnownOp1);
+      std::swap(Op0, Op1);
+    }
+
+    // lo(X * 1) + Z --> lo(X) + Z --> X iff X == lo(X)
+    if (Opc == X86ISD::VPMADD52L && KnownOp1.isConstant() &&
+        KnownOp1.getConstant().isOne() &&
+        OrigKnownOp0.countMinLeadingZeros() >= 12) {
+      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, Op0, Op2));
+    }
+
     Known = KnownBits::add(KnownMul, KnownOp2);
     return false;
   }
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index 3e9f6ddab9a4a..8b741e9ef9482 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -400,34 +400,22 @@ define <2 x i64> @test3_knownbits_vpmadd52h_negative(<2 x i64> %x0, <2 x i64> %x
 }
 
 define <2 x i64> @test_vpmadd52l_mul_one(<2 x i64> %x0, <2 x i32> %x1) {
-; AVX512-LABEL: test_vpmadd52l_mul_one:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
-; AVX512-NEXT:    retq
-;
-; AVX-LABEL: test_vpmadd52l_mul_one:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
-; AVX-NEXT:    retq
+; CHECK-LABEL: test_vpmadd52l_mul_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; CHECK-NEXT:    vpaddq %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
   %ext = zext <2 x i32> %x1 to <2 x i64>
   %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %ext)
   ret <2 x i64> %ifma
 }
 
 define <2 x i64> @test_vpmadd52l_mul_one_commuted(<2 x i64> %x0, <2 x i32> %x1) {
-; AVX512-LABEL: test_vpmadd52l_mul_one_commuted:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
-; AVX512-NEXT:    retq
-;
-; AVX-LABEL: test_vpmadd52l_mul_one_commuted:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
-; AVX-NEXT:    retq
+; CHECK-LABEL: test_vpmadd52l_mul_one_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; CHECK-NEXT:    vpaddq %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
   %ext = zext <2 x i32> %x1 to <2 x i64>
   %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %ext, <2 x i64> splat(i64 1))
   ret <2 x i64> %ifma
@@ -449,17 +437,11 @@ define <2 x i64> @test_vpmadd52l_mul_one_no_mask(<2 x i64> %x0, <2 x i64> %x1) {
 
 ; Mul by (1 << 52) + 1
 define <2 x i64> @test_vpmadd52l_mul_one_in_52bits(<2 x i64> %x0, <2 x i32> %x1) {
-; AVX512-LABEL: test_vpmadd52l_mul_one_in_52bits:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
-; AVX512-NEXT:    retq
-;
-; AVX-LABEL: test_vpmadd52l_mul_one_in_52bits:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
-; AVX-NEXT:    retq
+; CHECK-LABEL: test_vpmadd52l_mul_one_in_52bits:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
+; CHECK-NEXT:    vpaddq %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
   %ext = zext <2 x i32> %x1 to <2 x i64>
   %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 4503599627370497), <2 x i64> %ext)
   ret <2 x i64> %ifma

>From 574fe4ccd8980c8cd3168b5664ba2f9362d8f26e Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Tue, 16 Sep 2025 13:15:38 +0800
Subject: [PATCH 3/4] combine

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 54 ++++++++++++++++---------
 1 file changed, 34 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 00ccddc3956b4..3b551a29d28fc 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44958,7 +44958,6 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
   }
   case X86ISD::VPMADD52L:
   case X86ISD::VPMADD52H: {
-    KnownBits OrigKnownOp0, OrigKnownOp1;
     KnownBits KnownOp0, KnownOp1, KnownOp2;
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
@@ -44966,11 +44965,11 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
     //  Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
     //  operand 2).
     APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
-    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, OrigKnownOp0,
+    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
                              TLO, Depth + 1))
       return true;
 
-    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, OrigKnownOp1,
+    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
                              TLO, Depth + 1))
       return true;
 
@@ -44979,8 +44978,8 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
       return true;
 
     KnownBits KnownMul;
-    KnownOp0 = OrigKnownOp0.trunc(52);
-    KnownOp1 = OrigKnownOp1.trunc(52);
+    KnownOp0 = KnownOp0.trunc(52);
+    KnownOp1 = KnownOp1.trunc(52);
     KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
                                         : KnownBits::mulhu(KnownOp0, KnownOp1);
     KnownMul = KnownMul.zext(64);
@@ -44992,20 +44991,6 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
     }
 
-    // C * X --> X * C
-    if (KnownOp0.isConstant()) {
-      std::swap(OrigKnownOp0, OrigKnownOp1);
-      std::swap(KnownOp0, KnownOp1);
-      std::swap(Op0, Op1);
-    }
-
-    // lo(X * 1) + Z --> lo(X) + Z --> X iff X == lo(X)
-    if (Opc == X86ISD::VPMADD52L && KnownOp1.isConstant() &&
-        KnownOp1.getConstant().isOne() &&
-        OrigKnownOp0.countMinLeadingZeros() >= 12) {
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, Op0, Op2));
-    }
-
     Known = KnownBits::add(KnownMul, KnownOp2);
     return false;
   }
@@ -60201,8 +60186,37 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
 static SDValue combineVPMADD52LH(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI) {
   MVT VT = N->getSimpleValueType(0);
-  unsigned NumEltBits = VT.getScalarSizeInBits();
+
+  bool AddLow = N->getOpcode() == X86ISD::VPMADD52L;
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
+  SDLoc DL(N);
+
+  APInt C0, C1;
+  bool HasC0 = X86::isConstantSplat(Op0, C0),
+       HasC1 = X86::isConstantSplat(Op1, C1);
+
+  // lo/hi(C * X) + Z --> lo/hi(X * C) + Z
+  if (HasC0 && !HasC1)
+    return DAG.getNode(N->getOpcode(), DL, VT, Op1, Op0, Op2);
+
+  // Only keep the low 52 bits of C1
+  if (HasC1 && C1.countLeadingZeros() < 12) {
+    C1.clearBits(52, 64);
+    SDValue LowC1 = DAG.getConstant(C1, DL, VT);
+    return DAG.getNode(N->getOpcode(), DL, VT, Op0, LowC1, Op2);
+  }
+
+  // lo(X * 1) + Z --> lo(X) + Z --> X + Z iff X == lo(X)
+  if (AddLow && HasC1 && C1.isOne()) {
+    KnownBits KnownOp0 = DAG.computeKnownBits(Op0);
+    if (KnownOp0.countMinLeadingZeros() >= 12)
+      return DAG.getNode(ISD::ADD, DL, VT, Op0, Op2);
+  }
+
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  unsigned NumEltBits = VT.getScalarSizeInBits();
   if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
                                DCI))
     return SDValue(N, 0);

>From 1f155fc4515c5274b9872cc45ff1dc3bbc53a9d4 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Tue, 16 Sep 2025 16:32:28 +0800
Subject: [PATCH 4/4] fix diff

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3b551a29d28fc..790f65ac547fb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44984,9 +44984,9 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
                                         : KnownBits::mulhu(KnownOp0, KnownOp1);
     KnownMul = KnownMul.zext(64);
 
-    SDLoc DL(Op);
     // lo/hi(X * Y) + Z --> C + Z
     if (KnownMul.isConstant()) {
+      SDLoc DL(Op);
       SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
     }



More information about the llvm-commits mailing list