[llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 25 12:51:00 PDT 2024
- Previous message: [llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
- Next message: [llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
- Messages sorted by:
[ date ]
[ thread ]
[ subject ]
[ author ]
================
@@ -5215,103 +5215,129 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}
+/// Return true if \p Operand is a ConstantSDNode whose value is zero.
+/// Uses ConstantSDNode::isZero() instead of comparing getZExtValue() to 0:
+/// getZExtValue() asserts when the constant does not fit in 64 bits, whereas
+/// isZero() is valid for integer constants of any width.
+static bool isConstZero(const SDValue &Operand) {
+ const auto *Const = dyn_cast<ConstantSDNode>(Operand);
+ return Const && Const->isZero();
+}
+
/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
/// operands N0 and N1. This is a helper for PerformADDCombine that is
/// called with the default operands, and if that fails, with commuted
/// operands.
-static SDValue PerformADDCombineWithOperands(
- SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
- const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
- SelectionDAG &DAG = DCI.DAG;
- // Skip non-integer, non-scalar case
- EVT VT=N0.getValueType();
- if (VT.isVector())
+static SDValue
+PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ EVT VT = N0.getValueType();
+
+ // Since integer multiply-add costs the same as integer multiply
+ // but is more costly than integer add, do the fusion only when
+ // the mul is only used in the add.
+ if (!N0.getNode()->hasOneUse())
return SDValue();
// fold (add (mul a, b), c) -> (mad a, b, c)
//
- if (N0.getOpcode() == ISD::MUL) {
- assert (VT.isInteger());
- // For integer:
- // Since integer multiply-add costs the same as integer multiply
- // but is more costly than integer add, do the fusion only when
- // the mul is only used in the add.
- if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
- !N0.getNode()->hasOneUse())
+ if (N0.getOpcode() == ISD::MUL)
+ return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
+ N0.getOperand(1), N1);
+
+ // fold (add (select cond, 0, (mul a, b)), c)
+ // -> (select cond, (mad a, b, c), c)
+ //
+ if (N0.getOpcode() == ISD::SELECT) {
+ bool ZeroCond;
+ if (isConstZero(N0->getOperand(1)))
+ ZeroCond = true;
+ else if (isConstZero(N0->getOperand(2)))
+ ZeroCond = false;
+ else
+ return SDValue();
+
+ SDValue M = N0->getOperand(ZeroCond ? 2 : 1);
----------------
AlexMaclean wrote:
Sounds good, I've updated this to use an index.
https://github.com/llvm/llvm-project/pull/96352
- Previous message: [llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
- Next message: [llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
- Messages sorted by:
[ date ]
[ thread ]
[ subject ]
[ author ]
More information about the llvm-commits
mailing list