[llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 25 12:49:24 PDT 2024
- Previous message: [llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
- Next message: [llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
- Messages sorted by:
[ date ]
[ thread ]
[ subject ]
[ author ]
https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/96352
>From dbd09da875aee8d77b7e1c3fda518f0c88569668 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Fri, 21 Jun 2024 19:58:33 +0000
Subject: [PATCH 1/2] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select
(mad a, b, c), c)
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 219 ++++++++++++--------
llvm/test/CodeGen/NVPTX/combine-mad.ll | 49 +++++
2 files changed, 185 insertions(+), 83 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f4ef7c9914f13..0c609554370a3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5215,103 +5215,129 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}
+static bool isConstZero(const SDValue &Operand) {
+ const auto *Const = dyn_cast<ConstantSDNode>(Operand);
+ return Const && Const->getZExtValue() == 0;
+}
+
/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
/// operands N0 and N1. This is a helper for PerformADDCombine that is
/// called with the default operands, and if that fails, with commuted
/// operands.
-static SDValue PerformADDCombineWithOperands(
- SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
- const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
- SelectionDAG &DAG = DCI.DAG;
- // Skip non-integer, non-scalar case
- EVT VT=N0.getValueType();
- if (VT.isVector())
+static SDValue
+PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ EVT VT = N0.getValueType();
+
+ // Since integer multiply-add costs the same as integer multiply
+ // but is more costly than integer add, do the fusion only when
+ // the mul is only used in the add.
+ if (!N0.getNode()->hasOneUse())
return SDValue();
// fold (add (mul a, b), c) -> (mad a, b, c)
//
- if (N0.getOpcode() == ISD::MUL) {
- assert (VT.isInteger());
- // For integer:
- // Since integer multiply-add costs the same as integer multiply
- // but is more costly than integer add, do the fusion only when
- // the mul is only used in the add.
- if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
- !N0.getNode()->hasOneUse())
+ if (N0.getOpcode() == ISD::MUL)
+ return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
+ N0.getOperand(1), N1);
+
+ // fold (add (select cond, 0, (mul a, b)), c)
+ // -> (select cond, (mad a, b, c), c)
+ //
+ if (N0.getOpcode() == ISD::SELECT) {
+ bool ZeroCond;
+ if (isConstZero(N0->getOperand(1)))
+ ZeroCond = true;
+ else if (isConstZero(N0->getOperand(2)))
+ ZeroCond = false;
+ else
+ return SDValue();
+
+ SDValue M = N0->getOperand(ZeroCond ? 2 : 1);
+ if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
return SDValue();
- // Do the folding
- return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1), N1);
+ SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
+ M->getOperand(0), M->getOperand(1), N1);
+ return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
+ (ZeroCond ? N1 : MAD), (ZeroCond ? MAD : N1));
}
- else if (N0.getOpcode() == ISD::FMUL) {
- if (VT == MVT::f32 || VT == MVT::f64) {
- const auto *TLI = static_cast<const NVPTXTargetLowering *>(
- &DAG.getTargetLoweringInfo());
- if (!TLI->allowFMA(DAG.getMachineFunction(), OptLevel))
- return SDValue();
- // For floating point:
- // Do the fusion only when the mul has less than 5 uses and all
- // are add.
- // The heuristic is that if a use is not an add, then that use
- // cannot be fused into fma, therefore mul is still needed anyway.
- // If there are more than 4 uses, even if they are all add, fusing
- // them will increase register pressue.
- //
- int numUses = 0;
- int nonAddCount = 0;
- for (const SDNode *User : N0.getNode()->uses()) {
- numUses++;
- if (User->getOpcode() != ISD::FADD)
- ++nonAddCount;
- }
+ return SDValue();
+}
+
+static SDValue
+PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ EVT VT = N0.getValueType();
+ if (N0.getOpcode() == ISD::FMUL) {
+ const auto *TLI = static_cast<const NVPTXTargetLowering *>(
+ &DCI.DAG.getTargetLoweringInfo());
+ if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
+ return SDValue();
+
+ // For floating point:
+ // Do the fusion only when the mul has less than 5 uses and all
+ // are add.
+ // The heuristic is that if a use is not an add, then that use
+ // cannot be fused into fma, therefore mul is still needed anyway.
+ // If there are more than 4 uses, even if they are all add, fusing
+ // them will increase register pressure.
+ //
+ int numUses = 0;
+ int nonAddCount = 0;
+ for (const SDNode *User : N0.getNode()->uses()) {
+ numUses++;
+ if (User->getOpcode() != ISD::FADD)
+ ++nonAddCount;
if (numUses >= 5)
return SDValue();
- if (nonAddCount) {
- int orderNo = N->getIROrder();
- int orderNo2 = N0.getNode()->getIROrder();
- // simple heuristics here for considering potential register
- // pressure, the logics here is that the differnce are used
- // to measure the distance between def and use, the longer distance
- // more likely cause register pressure.
- if (orderNo - orderNo2 < 500)
- return SDValue();
-
- // Now, check if at least one of the FMUL's operands is live beyond the node N,
- // which guarantees that the FMA will not increase register pressure at node N.
- bool opIsLive = false;
- const SDNode *left = N0.getOperand(0).getNode();
- const SDNode *right = N0.getOperand(1).getNode();
-
- if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
- opIsLive = true;
-
- if (!opIsLive)
- for (const SDNode *User : left->uses()) {
- int orderNo3 = User->getIROrder();
- if (orderNo3 > orderNo) {
- opIsLive = true;
- break;
- }
- }
+ }
+ if (nonAddCount) {
+ int orderNo = N->getIROrder();
+ int orderNo2 = N0.getNode()->getIROrder();
+ // Simple heuristic for considering potential register pressure:
+ // the difference in IR order is used to measure the distance
+ // between def and use; the longer the distance, the more likely
+ // it is to cause register pressure.
+ if (orderNo - orderNo2 < 500)
+ return SDValue();
- if (!opIsLive)
- for (const SDNode *User : right->uses()) {
- int orderNo3 = User->getIROrder();
- if (orderNo3 > orderNo) {
- opIsLive = true;
- break;
- }
+ // Now, check if at least one of the FMUL's operands is live beyond the
+ // node N, which guarantees that the FMA will not increase register
+ // pressure at node N.
+ bool opIsLive = false;
+ const SDNode *left = N0.getOperand(0).getNode();
+ const SDNode *right = N0.getOperand(1).getNode();
+
+ if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
+ opIsLive = true;
+
+ if (!opIsLive)
+ for (const SDNode *User : left->uses()) {
+ int orderNo3 = User->getIROrder();
+ if (orderNo3 > orderNo) {
+ opIsLive = true;
+ break;
}
+ }
- if (!opIsLive)
- return SDValue();
- }
+ if (!opIsLive)
+ for (const SDNode *User : right->uses()) {
+ int orderNo3 = User->getIROrder();
+ if (orderNo3 > orderNo) {
+ opIsLive = true;
+ break;
+ }
+ }
- return DAG.getNode(ISD::FMA, SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1), N1);
+ if (!opIsLive)
+ return SDValue();
}
+
+ return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
+ N0.getOperand(1), N1);
}
return SDValue();
@@ -5332,18 +5358,44 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
///
static SDValue PerformADDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
- const NVPTXSubtarget &Subtarget,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+
+ // Skip vector types and any scalar type other than i32.
+ EVT VT = N0.getValueType();
+ if (VT.isVector() || VT != MVT::i32)
+ return SDValue();
+
+ // First try with the default operand order.
+ if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
+ return Result;
+
+ // If that didn't work, try again with the operands commuted.
+ return PerformADDCombineWithOperands(N, N1, N0, DCI);
+}
+
+/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
+///
+static SDValue PerformFADDCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
+ EVT VT = N0.getValueType();
+ if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
+ return SDValue();
+
// First try with the default operand order.
- if (SDValue Result =
- PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget, OptLevel))
+ if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
return Result;
// If that didn't work, try again with the operands commuted.
- return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel);
+ return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}
static SDValue PerformANDCombine(SDNode *N,
@@ -5876,8 +5928,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
switch (N->getOpcode()) {
default: break;
case ISD::ADD:
+ return PerformADDCombine(N, DCI, OptLevel);
case ISD::FADD:
- return PerformADDCombine(N, DCI, STI, OptLevel);
+ return PerformFADDCombine(N, DCI, OptLevel);
case ISD::MUL:
return PerformMULCombine(N, DCI, OptLevel);
case ISD::SHL:
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 0637bc916ea49..56bfaa14c5877 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -134,3 +134,52 @@ define i32 @test3(i32 %n, i32 %m, i32 %s) {
%mul = mul i32 %sel, %m
ret i32 %mul
}
+
+;; (add (select 0, (mul a, b)), c) -> (select (mad a, b, c), c)
+define i32 @test4(i32 %a, i32 %b, i32 %c, i1 %p) {
+; CHECK-LABEL: test4(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u8 %rs1, [test4_param_3];
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT: ld.param.u32 %r1, [test4_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test4_param_1];
+; CHECK-NEXT: ld.param.u32 %r3, [test4_param_2];
+; CHECK-NEXT: mad.lo.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: selp.b32 %r5, %r4, %r3, %p1;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %a, %b
+ %sel = select i1 %p, i32 %mul, i32 0
+ %add = add i32 %c, %sel
+ ret i32 %add
+}
+
+define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
+; CHECK-LABEL: test4_rev(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u8 %rs1, [test4_rev_param_3];
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT: ld.param.u32 %r1, [test4_rev_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test4_rev_param_1];
+; CHECK-NEXT: ld.param.u32 %r3, [test4_rev_param_2];
+; CHECK-NEXT: mad.lo.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: selp.b32 %r5, %r3, %r4, %p1;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %a, %b
+ %sel = select i1 %p, i32 0, i32 %mul
+ %add = add i32 %c, %sel
+ ret i32 %add
+}
>From 94733c9cc88ec1a249984a81dfd5e19cff1528ee Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Tue, 25 Jun 2024 19:49:07 +0000
Subject: [PATCH 2/2] address comments
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 0c609554370a3..d59f330cf0b77 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,25 +5242,26 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
N0.getOperand(1), N1);
// fold (add (select cond, 0, (mul a, b)), c)
- // -> (select cond, (mad a, b, c), c)
+ // -> (select cond, c, (mad a, b, c))
//
if (N0.getOpcode() == ISD::SELECT) {
- bool ZeroCond;
+ unsigned ZeroOpNum;
if (isConstZero(N0->getOperand(1)))
- ZeroCond = true;
+ ZeroOpNum = 1;
else if (isConstZero(N0->getOperand(2)))
- ZeroCond = false;
+ ZeroOpNum = 2;
else
return SDValue();
- SDValue M = N0->getOperand(ZeroCond ? 2 : 1);
+ SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
return SDValue();
SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
M->getOperand(0), M->getOperand(1), N1);
return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
- (ZeroCond ? N1 : MAD), (ZeroCond ? MAD : N1));
+ ((ZeroOpNum == 1) ? N1 : MAD),
+ ((ZeroOpNum == 1) ? MAD : N1));
}
return SDValue();
- Previous message: [llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
- Next message: [llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)
- Messages sorted by:
[ date ]
[ thread ]
[ subject ]
[ author ]
More information about the llvm-commits
mailing list