[llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 12:51:00 PDT 2024


================
@@ -5215,103 +5215,129 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
   return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
 }
 
+/// Returns true if \p Operand is an integer constant (ConstantSDNode) whose
+/// value is zero, and false for any non-constant operand.
+///
+/// Uses ConstantSDNode::isZero() instead of comparing getZExtValue() to 0:
+/// getZExtValue() asserts for constants with more than 64 active bits and,
+/// without assertions, yields only the low 64-bit word, so a wide constant
+/// such as an i128 value of 2^64 would otherwise be misreported as zero.
+static bool isConstZero(const SDValue &Operand) {
+  const auto *Const = dyn_cast<ConstantSDNode>(Operand);
+  return Const && Const->isZero();
+}
+
 /// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
 /// operands N0 and N1.  This is a helper for PerformADDCombine that is
 /// called with the default operands, and if that fails, with commuted
 /// operands.
-static SDValue PerformADDCombineWithOperands(
-    SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
-    const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
-  SelectionDAG  &DAG = DCI.DAG;
-  // Skip non-integer, non-scalar case
-  EVT VT=N0.getValueType();
-  if (VT.isVector())
+static SDValue
+PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+                              TargetLowering::DAGCombinerInfo &DCI) {
+  EVT VT = N0.getValueType();
+
+  // Since integer multiply-add costs the same as integer multiply
+  // but is more costly than integer add, do the fusion only when
+  // the mul is only used in the add.
+  if (!N0.getNode()->hasOneUse())
     return SDValue();
 
   // fold (add (mul a, b), c) -> (mad a, b, c)
   //
-  if (N0.getOpcode() == ISD::MUL) {
-    assert (VT.isInteger());
-    // For integer:
-    // Since integer multiply-add costs the same as integer multiply
-    // but is more costly than integer add, do the fusion only when
-    // the mul is only used in the add.
-    if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
-        !N0.getNode()->hasOneUse())
+  if (N0.getOpcode() == ISD::MUL)
+    return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
+                           N0.getOperand(1), N1);
+
+  // fold (add (select cond, 0, (mul a, b)), c)
+  //   -> (select cond, (mad a, b, c), c)
+  //
+  if (N0.getOpcode() == ISD::SELECT) {
+    bool ZeroCond;
+    if (isConstZero(N0->getOperand(1)))
+      ZeroCond = true;
+    else if (isConstZero(N0->getOperand(2)))
+      ZeroCond = false;
+    else
+      return SDValue();
+
+    SDValue M = N0->getOperand(ZeroCond ? 2 : 1);
----------------
AlexMaclean wrote:

Sounds good, I've updated this to use an index.

https://github.com/llvm/llvm-project/pull/96352


More information about the llvm-commits mailing list