[llvm] c241eb3 - [X86] Fold X * Y + Z --> C + Z for vpmadd52l/vpmadd52h (#156293)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 1 10:18:54 PDT 2025
Author: XChy
Date: 2025-09-01T17:18:50Z
New Revision: c241eb30c5c315942b8c83e1125b508e9a626a7d
URL: https://github.com/llvm/llvm-project/commit/c241eb30c5c315942b8c83e1125b508e9a626a7d
DIFF: https://github.com/llvm/llvm-project/commit/c241eb30c5c315942b8c83e1125b508e9a626a7d.diff
LOG: [X86] Fold X * Y + Z --> C + Z for vpmadd52l/vpmadd52h (#156293)
Address the TODO and implement constant folding of the intermediate
multiplication result of vpmadd52l/vpmadd52h.
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/combine-vpmadd52.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 08ae0d52d795e..dd4c6088392dd 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44969,11 +44969,19 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
TLO, Depth + 1))
return true;
- // X * 0 + Y --> Y
- // TODO: Handle cases where lower/higher 52 of bits of Op0 * Op1 are known
- // zeroes.
- if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
- return TLO.CombineTo(Op, Op2);
+ KnownBits KnownMul;
+ KnownOp0 = KnownOp0.trunc(52);
+ KnownOp1 = KnownOp1.trunc(52);
+ KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
+ : KnownBits::mulhu(KnownOp0, KnownOp1);
+ KnownMul = KnownMul.zext(64);
+
+ // lo/hi(X * Y) + Z --> C + Z
+ if (KnownMul.isConstant()) {
+ SDLoc DL(Op);
+ SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
+ return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
+ }
// TODO: Compute the known bits for VPMADD52L/VPMADD52H.
break;
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index fd295ea31c55c..9afc1119267ec 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -183,3 +183,110 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
%1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 123>, <2 x i64> %x1)
ret <2 x i64> %1
}
+
+; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
+define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_zero:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2))
+ ret <2 x i64> %1
+}
+
+; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
+define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_zero:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 33554432), <2 x i64> splat (i64 67108864))
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX: # %bb.0:
+; AVX-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT: retq
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
+ ret <2 x i64> %1
+}
+
+; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits is 1 << 50
+define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX: # %bb.0:
+; AVX-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT: retq
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2251799813685248))
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_mask:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %and1 = and <2 x i64> %x0, splat (i64 1073741824) ; 1LL << 30
+ %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_mask:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %and1 = lshr <2 x i64> %x0, splat (i64 40)
+ %and2 = lshr <2 x i64> %x1, splat (i64 40)
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm2
+; AVX512-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm1
+; AVX512-NEXT: vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX: # %bb.0:
+; AVX-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2
+; AVX-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX-NEXT: {vex} vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX-NEXT: retq
+ %and1 = and <2 x i64> %x0, splat (i64 2097152) ; 1LL << 21
+ %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpsrlq $30, %xmm0, %xmm2
+; AVX512-NEXT: vpsrlq $43, %xmm1, %xmm1
+; AVX512-NEXT: vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX: # %bb.0:
+; AVX-NEXT: vpsrlq $30, %xmm0, %xmm2
+; AVX-NEXT: vpsrlq $43, %xmm1, %xmm1
+; AVX-NEXT: {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX-NEXT: retq
+ %and1 = lshr <2 x i64> %x0, splat (i64 30)
+ %and2 = lshr <2 x i64> %x1, splat (i64 43)
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+ ret <2 x i64> %1
+}
More information about the llvm-commits
mailing list