[llvm] [X86] Fold C1 * C2 + Z --> C3 + Z for vpmadd52l/vpmadd52h (PR #156293)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 1 02:01:51 PDT 2025
================
@@ -44954,26 +44954,39 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
}
case X86ISD::VPMADD52L:
case X86ISD::VPMADD52H: {
- KnownBits KnownOp0, KnownOp1;
+ KnownBits Known52BitsOfOp0, Known52BitsOfOp1;
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
SDValue Op2 = Op.getOperand(2);
// Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
// operand 2).
APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
- if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
- TLO, Depth + 1))
+ if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts,
+ Known52BitsOfOp0, TLO, Depth + 1))
return true;
- if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
- TLO, Depth + 1))
+ if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts,
+ Known52BitsOfOp1, TLO, Depth + 1))
return true;
- // X * 0 + Y --> Y
- // TODO: Handle cases where lower/higher 52 of bits of Op0 * Op1 are known
- // zeroes.
- if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
- return TLO.CombineTo(Op, Op2);
+ KnownBits KnownMul;
+ Known52BitsOfOp0 = Known52BitsOfOp0.trunc(52);
+ Known52BitsOfOp1 = Known52BitsOfOp1.trunc(52);
+ if (Opc == X86ISD::VPMADD52L) {
+ KnownMul =
+ KnownBits::mul(Known52BitsOfOp0.zext(104), Known52BitsOfOp1.zext(104))
+ .trunc(52);
+ } else {
+ KnownMul = KnownBits::mulhu(Known52BitsOfOp0, Known52BitsOfOp1);
+ }
+ KnownMul = KnownMul.zext(64);
+
+ // C1 * C2 + Z --> C3 + Z
+ if (KnownMul.isConstant()) {
+ SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), SDLoc(Op0), VT);
----------------
RKSimon wrote:
Pull out the repeated SDLoc() construction into a single local variable.
https://github.com/llvm/llvm-project/pull/156293
More information about the llvm-commits mailing list