[llvm] [X86] Fold X * Y + Z --> C + Z for vpmadd52l/vpmadd52h (PR #156293)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 1 09:38:23 PDT 2025


https://github.com/XChy updated https://github.com/llvm/llvm-project/pull/156293

>From 52e189c7b2a3d176e86e10b3b6b2a0b224dff369 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 1 Sep 2025 15:22:20 +0800
Subject: [PATCH 1/6] precommit tests for intermediate multiplication

---
 llvm/test/CodeGen/X86/combine-vpmadd52.ll | 149 ++++++++++++++++++++++
 1 file changed, 149 insertions(+)

diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index fd295ea31c55c..7b73a1974359c 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -183,3 +183,152 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 123>, <2 x i64> %x1)
   ret <2 x i64> %1
 }
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
+  ; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_zero:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2251799813685248,2251799813685248]
+; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_zero:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2251799813685248,2251799813685248]
+; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
+  ; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_zero:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [33554432,33554432]
+; AVX512-NEXT:    vpmadd52huq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_zero:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [33554432,33554432]
+; AVX-NEXT:    {vex} vpmadd52huq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 33554432), <2 x i64> splat (i64 67108864))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [123,123]
+; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovsxbq {{.*#+}} xmm1 = [123,123]
+; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
+  ; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits is 1 << 50
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2251799813685248,2251799813685248]
+; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2251799813685248,2251799813685248]
+; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm1, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2251799813685248))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm2 = [1073741824,1073741824]
+; AVX512-NEXT:    vpand %xmm2, %xmm0, %xmm3
+; AVX512-NEXT:    vpand %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52luq %xmm1, %xmm3, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_mask:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpbroadcastq {{.*#+}} xmm2 = [1073741824,1073741824]
+; AVX-NEXT:    vpand %xmm2, %xmm0, %xmm3
+; AVX-NEXT:    vpand %xmm2, %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52luq %xmm1, %xmm3, %xmm0
+; AVX-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 1073741824) ; 1LL << 30
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_mask:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpsrlq $40, %xmm0, %xmm2
+; AVX512-NEXT:    vpsrlq $40, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_mask:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlq $40, %xmm0, %xmm2
+; AVX-NEXT:    vpsrlq $40, %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 40)
+  %and2 = lshr <2 x i64> %x1, splat (i64 40)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm2
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 2097152) ; 1LL << 21
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX512-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 30)
+  %and2 = lshr <2 x i64> %x1, splat (i64 43)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}

>From 7dba1529b76037b65c8f5e72e6d8b035ec1fa69e Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 1 Sep 2025 15:50:08 +0800
Subject: [PATCH 2/6] [X86] Fold C1 * C2 + Z --> C3 + Z for vpmadd52l/vpmadd52h

---
 llvm/lib/Target/X86/X86ISelLowering.cpp   | 33 +++++++---
 llvm/test/CodeGen/X86/combine-vpmadd52.ll | 74 +++++------------------
 2 files changed, 39 insertions(+), 68 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index d78cf00a5a2fc..840c2730625c0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44954,26 +44954,39 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
   }
   case X86ISD::VPMADD52L:
   case X86ISD::VPMADD52H: {
-    KnownBits KnownOp0, KnownOp1;
+    KnownBits Known52BitsOfOp0, Known52BitsOfOp1;
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
     SDValue Op2 = Op.getOperand(2);
     //  Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
     //  operand 2).
     APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
-    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
-                             TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts,
+                             Known52BitsOfOp0, TLO, Depth + 1))
       return true;
 
-    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
-                             TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts,
+                             Known52BitsOfOp1, TLO, Depth + 1))
       return true;
 
-    // X * 0 + Y --> Y
-    // TODO: Handle cases where lower/higher 52 of bits of Op0 * Op1 are known
-    // zeroes.
-    if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
-      return TLO.CombineTo(Op, Op2);
+    KnownBits KnownMul;
+    Known52BitsOfOp0 = Known52BitsOfOp0.trunc(52);
+    Known52BitsOfOp1 = Known52BitsOfOp1.trunc(52);
+    if (Opc == X86ISD::VPMADD52L) {
+      KnownMul =
+          KnownBits::mul(Known52BitsOfOp0.zext(104), Known52BitsOfOp1.zext(104))
+              .trunc(52);
+    } else {
+      KnownMul = KnownBits::mulhu(Known52BitsOfOp0, Known52BitsOfOp1);
+    }
+    KnownMul = KnownMul.zext(64);
+
+    // C1 * C2 + Z --> C3 + Z
+    if (KnownMul.isConstant()) {
+      SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), SDLoc(Op0), VT);
+      return TLO.CombineTo(Op,
+                           TLO.DAG.getNode(ISD::ADD, SDLoc(Op), VT, C, Op2));
+    }
 
     // TODO: Compute the known bits for VPMADD52L/VPMADD52H.
     break;
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index 7b73a1974359c..1e075bfe12a31 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -186,34 +186,18 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
 
 define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
   ; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
-; AVX512-LABEL: test_vpmadd52l_mul_lo52_zero:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2251799813685248,2251799813685248]
-; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
-; AVX512-NEXT:    retq
-;
-; AVX-LABEL: test_vpmadd52l_mul_lo52_zero:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2251799813685248,2251799813685248]
-; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
-; AVX-NEXT:    retq
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2))
   ret <2 x i64> %1
 }
 
 define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
   ; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
-; AVX512-LABEL: test_vpmadd52h_mul_hi52_zero:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [33554432,33554432]
-; AVX512-NEXT:    vpmadd52huq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
-; AVX512-NEXT:    retq
-;
-; AVX-LABEL: test_vpmadd52h_mul_hi52_zero:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [33554432,33554432]
-; AVX-NEXT:    {vex} vpmadd52huq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
-; AVX-NEXT:    retq
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 33554432), <2 x i64> splat (i64 67108864))
   ret <2 x i64> %1
 }
@@ -221,14 +205,12 @@ define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
 define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
 ; AVX512-LABEL: test_vpmadd52l_mul_lo52_const:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [123,123]
-; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
 ; AVX512-NEXT:    retq
 ;
 ; AVX-LABEL: test_vpmadd52l_mul_lo52_const:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovsxbq {{.*#+}} xmm1 = [123,123]
-; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
   ret <2 x i64> %1
@@ -238,35 +220,21 @@ define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
   ; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits is 1 << 50
 ; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2251799813685248,2251799813685248]
-; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm1, %xmm0
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
 ; AVX512-NEXT:    retq
 ;
 ; AVX-LABEL: test_vpmadd52h_mul_hi52_const:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2251799813685248,2251799813685248]
-; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm1, %xmm0
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2251799813685248))
   ret <2 x i64> %1
 }
 
 define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
-; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpbroadcastq {{.*#+}} xmm2 = [1073741824,1073741824]
-; AVX512-NEXT:    vpand %xmm2, %xmm0, %xmm3
-; AVX512-NEXT:    vpand %xmm2, %xmm1, %xmm1
-; AVX512-NEXT:    vpmadd52luq %xmm1, %xmm3, %xmm0
-; AVX512-NEXT:    retq
-;
-; AVX-LABEL: test_vpmadd52l_mul_lo52_mask:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpbroadcastq {{.*#+}} xmm2 = [1073741824,1073741824]
-; AVX-NEXT:    vpand %xmm2, %xmm0, %xmm3
-; AVX-NEXT:    vpand %xmm2, %xmm1, %xmm1
-; AVX-NEXT:    {vex} vpmadd52luq %xmm1, %xmm3, %xmm0
-; AVX-NEXT:    retq
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
   %and1 = and <2 x i64> %x0, splat (i64 1073741824) ; 1LL << 30
   %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
@@ -274,19 +242,9 @@ define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2
 }
 
 define <2 x i64> @test_vpmadd52h_mul_hi52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
-; AVX512-LABEL: test_vpmadd52h_mul_hi52_mask:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpsrlq $40, %xmm0, %xmm2
-; AVX512-NEXT:    vpsrlq $40, %xmm1, %xmm1
-; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm2, %xmm0
-; AVX512-NEXT:    retq
-;
-; AVX-LABEL: test_vpmadd52h_mul_hi52_mask:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpsrlq $40, %xmm0, %xmm2
-; AVX-NEXT:    vpsrlq $40, %xmm1, %xmm1
-; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
-; AVX-NEXT:    retq
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
   %and1 = lshr <2 x i64> %x0, splat (i64 40)
   %and2 = lshr <2 x i64> %x1, splat (i64 40)
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)

>From d6293b1939f3e6938b81a2a9fc1cc55cebe7f3ee Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 1 Sep 2025 17:11:20 +0800
Subject: [PATCH 3/6] fix comments: compute KnownMul directly at 52 bits and use a single SDLoc

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 29 ++++++++++---------------
 1 file changed, 12 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 840c2730625c0..f467898ff869c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44954,38 +44954,33 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
   }
   case X86ISD::VPMADD52L:
   case X86ISD::VPMADD52H: {
-    KnownBits Known52BitsOfOp0, Known52BitsOfOp1;
+    KnownBits KnownOp0, KnownOp1;
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
     SDValue Op2 = Op.getOperand(2);
     //  Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
     //  operand 2).
     APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
-    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts,
-                             Known52BitsOfOp0, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
+                             TLO, Depth + 1))
       return true;
 
-    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts,
-                             Known52BitsOfOp1, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
+                             TLO, Depth + 1))
       return true;
 
     KnownBits KnownMul;
-    Known52BitsOfOp0 = Known52BitsOfOp0.trunc(52);
-    Known52BitsOfOp1 = Known52BitsOfOp1.trunc(52);
-    if (Opc == X86ISD::VPMADD52L) {
-      KnownMul =
-          KnownBits::mul(Known52BitsOfOp0.zext(104), Known52BitsOfOp1.zext(104))
-              .trunc(52);
-    } else {
-      KnownMul = KnownBits::mulhu(Known52BitsOfOp0, Known52BitsOfOp1);
-    }
+    KnownOp0 = KnownOp0.trunc(52);
+    KnownOp1 = KnownOp1.trunc(52);
+    KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
+                                        : KnownBits::mulhu(KnownOp0, KnownOp1);
     KnownMul = KnownMul.zext(64);
 
     // C1 * C2 + Z --> C3 + Z
     if (KnownMul.isConstant()) {
-      SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), SDLoc(Op0), VT);
-      return TLO.CombineTo(Op,
-                           TLO.DAG.getNode(ISD::ADD, SDLoc(Op), VT, C, Op2));
+      SDLoc DL(Op);
+      SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
+      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
     }
 
     // TODO: Compute the known bits for VPMADD52L/VPMADD52H.

>From 740a00e3d528869723af7018d7c31bb3cddbe2cb Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 1 Sep 2025 17:21:19 +0800
Subject: [PATCH 4/6] correct the fold comment to "lo/hi(X * Y) + Z --> C3 + Z"

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f467898ff869c..ab32222c923e8 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44976,7 +44976,7 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
                                         : KnownBits::mulhu(KnownOp0, KnownOp1);
     KnownMul = KnownMul.zext(64);
 
-    // C1 * C2 + Z --> C3 + Z
+    // lo/hi(X * Y) + Z --> C3 + Z
     if (KnownMul.isConstant()) {
       SDLoc DL(Op);
       SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);

>From 2f5ca3b6f493a8c29f5d3b19f946b661368a5d9a Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 1 Sep 2025 17:21:55 +0800
Subject: [PATCH 5/6] correct the fold comment again: rename C3 to C

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ab32222c923e8..814a4bd1df714 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44976,7 +44976,7 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
                                         : KnownBits::mulhu(KnownOp0, KnownOp1);
     KnownMul = KnownMul.zext(64);
 
-    // lo/hi(X * Y) + Z --> C3 + Z
+    // lo/hi(X * Y) + Z --> C + Z
     if (KnownMul.isConstant()) {
       SDLoc DL(Op);
       SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);

>From 880d0ca87977a4fecdbc996e6a9fbaac789fed3c Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Tue, 2 Sep 2025 00:35:29 +0800
Subject: [PATCH 6/6] move the explanatory test comments above the function definitions

---
 llvm/test/CodeGen/X86/combine-vpmadd52.ll | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index 1e075bfe12a31..9afc1119267ec 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -184,8 +184,8 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
   ret <2 x i64> %1
 }
 
+; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
 define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
-  ; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
 ; CHECK-LABEL: test_vpmadd52l_mul_lo52_zero:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    retq
@@ -193,8 +193,8 @@ define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
   ret <2 x i64> %1
 }
 
+; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
 define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
-  ; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
 ; CHECK-LABEL: test_vpmadd52h_mul_hi52_zero:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    retq
@@ -216,8 +216,8 @@ define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
   ret <2 x i64> %1
 }
 
+; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits is 1 << 50
 define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
-  ; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits is 1 << 50
 ; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0



More information about the llvm-commits mailing list