[llvm] c241eb3 - [X86] Fold X * Y + Z --> C + Z for vpmadd52l/vpmadd52h (#156293)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 1 10:18:54 PDT 2025


Author: XChy
Date: 2025-09-01T17:18:50Z
New Revision: c241eb30c5c315942b8c83e1125b508e9a626a7d

URL: https://github.com/llvm/llvm-project/commit/c241eb30c5c315942b8c83e1125b508e9a626a7d
DIFF: https://github.com/llvm/llvm-project/commit/c241eb30c5c315942b8c83e1125b508e9a626a7d.diff

LOG: [X86] Fold X * Y + Z --> C + Z for vpmadd52l/vpmadd52h (#156293)

Address the TODO and implement a constant fold for the intermediate
multiplication result of vpmadd52l/vpmadd52h.
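
As a sketch of the effect (shown at the IR level for readability, though the
combine itself runs on the SelectionDAG; %z and %r are illustrative names):
with both multiplicands constant, 123 * 456 = 56088 fits entirely in the low
52 bits, so the low-half intrinsic reduces to a plain vector add:

    %r = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(
             <2 x i64> %z, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
    ; after the fold:
    %r = add <2 x i64> %z, splat (i64 56088)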

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/combine-vpmadd52.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 08ae0d52d795e..dd4c6088392dd 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44969,11 +44969,19 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
                              TLO, Depth + 1))
       return true;
 
-    // X * 0 + Y --> Y
-    // TODO: Handle cases where lower/higher 52 of bits of Op0 * Op1 are known
-    // zeroes.
-    if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
-      return TLO.CombineTo(Op, Op2);
+    KnownBits KnownMul;
+    KnownOp0 = KnownOp0.trunc(52);
+    KnownOp1 = KnownOp1.trunc(52);
+    KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
+                                        : KnownBits::mulhu(KnownOp0, KnownOp1);
+    KnownMul = KnownMul.zext(64);
+
+    // lo/hi(X * Y) + Z --> C + Z
+    if (KnownMul.isConstant()) {
+      SDLoc DL(Op);
+      SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
+      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
+    }
 
     // TODO: Compute the known bits for VPMADD52L/VPMADD52H.
     break;
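
A worked check of the high-half path, matching the hi52 constant test in the
diff below: both multiplicands are first truncated to their low 52 bits, so
KnownBits::mulhu sees two 52-bit values and reports bits [52..103] of the
full 104-bit product:

    (1 << 51) * (1 << 51) = 1 << 102
    hi52(1 << 102) = 1 << (102 - 52) = 1 << 50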

diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index fd295ea31c55c..9afc1119267ec 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -183,3 +183,110 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 123>, <2 x i64> %x1)
   ret <2 x i64> %1
 }
+
+; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
+define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2))
+  ret <2 x i64> %1
+}
+
+; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
+define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 33554432), <2 x i64> splat (i64 67108864))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
+  ret <2 x i64> %1
+}
+
+; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits are 1 << 50
+define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2251799813685248))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 1073741824) ; 1LL << 30
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 40)
+  %and2 = lshr <2 x i64> %x1, splat (i64 40)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm2
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 2097152) ; 1LL << 21
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX512-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 30)
+  %and2 = lshr <2 x i64> %x1, splat (i64 43)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
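
A sketch of why the negative tests above do not fold: the shifts bound the
operands without fixing them, so the product is bounded but not constant.
For test_vpmadd52h_mul_hi52_negative:

    %and1 = lshr %x0, 30   ->  lo52(%and1) < 1 << 34
    %and2 = lshr %x1, 43   ->  lo52(%and2) < 1 << 21
    product < 1 << 55      ->  bits 52..54 of the product are unknown

KnownBits::mulhu therefore yields no single constant, and the vpmadd52huq
instruction remains.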

More information about the llvm-commits mailing list