[llvm] [NVPTX] Remove `NVPTX::IMAD` opcode, and rely on instruction selection only (PR #121724)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 6 22:43:43 PST 2025


https://github.com/peterbell10 updated https://github.com/llvm/llvm-project/pull/121724

>From 3e34b3ea613763e4e8a69a004abc662d4971c507 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Sun, 5 Jan 2025 07:25:19 +0000
Subject: [PATCH 1/4] [NVPTX] Add DAG combine patterns to simplify IMAD

I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1,
e.g. in https://gcc.godbolt.org/z/45W3Wcnxz

This happens when DAGCombiner operates on the add before the mul, so
the imad contraction happens regardless of whether the mul could have
been simplified.

This PR adds:
```
mad x 1 y => add x y
mad x -1 y => sub y x
mad x 0 y => y
mad x y 0 => mul x y
mad c0 c1 z => add z (c0 * c1)
```

Another option might be to remove `NVPTXISD::IMAD` and only combine to
mad during selection. This would allow the normal DAGCombiner patterns
to simplify the graph without any NVPTX-specific intervention. However,
it also risks DAGCombiner breaking up the mul-add patterns, which
is why I haven't done it that way.
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 49 +++++++++++++++++++
 llvm/test/CodeGen/NVPTX/combine-mad.ll        | 20 ++++++++
 llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll |  2 +-
 3 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..c4529c9151bc2b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5164,6 +5164,53 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
 }
 
+static SDValue
+PerformIMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, SDValue N2,
+                               TargetLowering::DAGCombinerInfo &DCI) {
+  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+  ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2);
+  EVT VT = N0->getValueType(0);
+  SDLoc DL(N);
+  SDNodeFlags Flags = N->getFlags();
+
+  // mad x 1 y => add x y
+  if (N1C && N1C->isOne())
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, N0, N2, Flags);
+
+  // mad x -1 y => sub y x
+  if (N1C && N1C->isAllOnes()) {
+    Flags.setNoUnsignedWrap(false);
+    return DCI.DAG.getNode(ISD::SUB, DL, VT, N2, N0, Flags);
+  }
+
+  // mad x 0 y => y
+  if (N1C && N1C->isZero())
+    return N2;
+
+  // mad x y 0 => mul x y
+  if (N2C && N2C->isZero())
+    return DCI.DAG.getNode(ISD::MUL, DL, VT, N0, N1, Flags);
+
+  // mad c0 c1 x => add x (c0*c1)
+  if (SDValue C =
+          DCI.DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}, Flags))
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, N2, C, Flags);
+
+  return {};
+}
+
+static SDValue PerformIMADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+  SDValue res = PerformIMADCombineWithOperands(N, N0, N1, N2, DCI);
+  if (res)
+    return res;
+
+  return PerformIMADCombineWithOperands(N, N1, N0, N2, DCI);
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5198,6 +5245,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformVSELECTCombine(N, DCI);
     case ISD::BUILD_VECTOR:
       return PerformBUILD_VECTORCombine(N, DCI);
+    case NVPTXISD::IMAD:
+      return PerformIMADCombine(N, DCI);
   }
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 1b22cfde39725f..7d523a835a1f3f 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -183,3 +183,23 @@ define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
   %add = add i32 %c, %sel
   ret i32 %add
 }
+
+;; This case relies on mad x 1 y => add x y, previously we emit:
+;;     mad.lo.s32      %r3, %r1, 1, %r2;
+define i32 @test_mad_fold(i32 %x) {
+; CHECK-LABEL: test_mad_fold(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_fold_param_0];
+; CHECK-NEXT:    mul.hi.s32 %r2, %r1, -2147221471;
+; CHECK-NEXT:    add.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    shr.u32 %r4, %r3, 31;
+; CHECK-NEXT:    shr.s32 %r5, %r3, 12;
+; CHECK-NEXT:    add.s32 %r6, %r5, %r4;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r6;
+; CHECK-NEXT:    ret;
+  %div = sdiv i32 %x, 8191
+  ret i32 %div
+}
diff --git a/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll b/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
index 27a523b9dd91d2..de19d2983f3435 100644
--- a/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
+++ b/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
@@ -12,7 +12,7 @@
 ; CHECK-NOT: __local_depot
 
 ; CHECK-32:       ld.param.u32  %r[[SIZE:[0-9]]], [test_dynamic_stackalloc_param_0];
-; CHECK-32-NEXT:  mad.lo.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 1, 7;
+; CHECK-32-NEXT:  add.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 7;
 ; CHECK-32-NEXT:  and.b32         %r[[SIZE3:[0-9]]], %r[[SIZE2]], -8;
 ; CHECK-32-NEXT:  alloca.u32  %r[[ALLOCA:[0-9]]], %r[[SIZE3]], 16;
 ; CHECK-32-NEXT:  cvta.local.u32  %r[[ALLOCA]], %r[[ALLOCA]];

>From 66adc32206482e3444057de278bfcaac730e3a7b Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Mon, 6 Jan 2025 01:28:55 +0000
Subject: [PATCH 2/4] Add direct tests on combiner + infra to disable earlier
 opts

---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |   5 +
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  10 +-
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 189 +++++-------------
 llvm/test/CodeGen/NVPTX/combine-mad-only.ll   |  87 ++++++++
 4 files changed, 152 insertions(+), 139 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/combine-mad-only.ll

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index ff7caec41855fd..3a015c8df2066a 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2460,6 +2460,11 @@ class SelectionDAG {
   SDNode *FindNodeOrInsertPos(const FoldingSetNodeID &ID, const SDLoc &DL,
                               void *&InsertPos);
 
+  SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
+                      ArrayRef<SDValue> Ops, SDNodeFlags Flags);
+  SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, SDVTList VTs,
+                      ArrayRef<SDValue> Ops, SDNodeFlags Flags);
+
   /// Maps to auto-CSE operations.
   std::vector<CondCodeSDNode*> CondCodeNodes;
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6b2501591c81a3..6d75809cdaf69f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -153,6 +153,13 @@ static cl::opt<bool> EnableVectorFCopySignExtendRound(
     "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
     cl::desc(
         "Enable merging extends and rounds into FCOPYSIGN on vector types"));
+
+static cl::opt<bool>
+    EnableGenericCombines("combiner-generic-combines", cl::Hidden,
+                          cl::init(true),
+                          cl::desc("Enable generic DAGCombine patterns. Useful "
+                                   "for testing target-specific combines."));
+
 namespace {
 
   class DAGCombiner {
@@ -251,7 +258,8 @@ namespace {
         : DAG(D), TLI(D.getTargetLoweringInfo()),
           STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
       ForCodeSize = DAG.shouldOptForSize();
-      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
+      DisableGenericCombines = !EnableGenericCombines ||
+                               (STI && STI->disableGenericCombines(OptLevel));
 
       MaximumLegalStoreInBits = 0;
       // We use the minimum store size here, since that's all we can guarantee
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 10e8ba93359fbd..6a3799e02edd94 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -115,6 +115,10 @@ static cl::opt<unsigned>
     MaxSteps("has-predecessor-max-steps", cl::Hidden, cl::init(8192),
              cl::desc("DAG combiner limit number of steps when searching DAG "
                       "for predecessor nodes"));
+static cl::opt<bool> EnableSimplifyNodes(
+    "selectiondag-simplify-nodes", cl::Hidden, cl::init(true),
+    cl::desc("Enable SelectionDAG::getNode simplifications. Useful for testing "
+             "DAG combines."));
 
 static void NewSDValueDbgMsg(SDValue V, StringRef Msg, SelectionDAG *G) {
   LLVM_DEBUG(dbgs() << Msg; V.getNode()->dump(G););
@@ -6157,23 +6161,46 @@ static SDValue foldCONCAT_VECTORS(const SDLoc &DL, EVT VT,
 }
 
 /// Gets or creates the specified node.
-SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
+SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
+                                  ArrayRef<SDValue> Ops,
+                                  const SDNodeFlags Flags) {
   SDVTList VTs = getVTList(VT);
-  FoldingSetNodeID ID;
-  AddNodeIDNode(ID, Opcode, VTs, {});
-  void *IP = nullptr;
-  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
-    return SDValue(E, 0);
+  return getNodeImpl(Opcode, DL, VTs, Ops, Flags);
+}
 
-  auto *N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-  CSEMap.InsertNode(N, IP);
+SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL,
+                                  SDVTList VTs, ArrayRef<SDValue> Ops,
+                                  const SDNodeFlags Flags) {
+  SDNode *N;
+  // Don't CSE glue-producing nodes
+  if (VTs.VTs[VTs.NumVTs - 1] != MVT::Glue) {
+    FoldingSetNodeID ID;
+    AddNodeIDNode(ID, Opcode, VTs, Ops);
+    void *IP = nullptr;
+    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+      E->intersectFlagsWith(Flags);
+      return SDValue(E, 0);
+    }
+
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+    CSEMap.InsertNode(N, IP);
+  } else {
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+  }
 
+  N->setFlags(Flags);
   InsertNode(N);
   SDValue V = SDValue(N, 0);
   NewSDValueDbgMsg(V, "Creating new node: ", this);
   return V;
 }
 
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
+  return getNodeImpl(Opcode, DL, VT, {}, SDNodeFlags{});
+}
+
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1) {
   SDNodeFlags Flags;
@@ -6185,6 +6212,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1, const SDNodeFlags Flags) {
   assert(N1.getOpcode() != ISD::DELETED_NODE && "Operand is DELETED_NODE!");
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
 
   // Constant fold unary operations with a vector integer or float operand.
   switch (Opcode) {
@@ -6501,31 +6530,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
 
-  SDNode *N;
-  SDVTList VTs = getVTList(VT);
-  SDValue Ops[] = {N1};
-  if (VT != MVT::Glue) { // Don't CSE glue producing nodes
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
-
-  InsertNode(N);
-  SDValue V = SDValue(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
 }
 
 static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
@@ -7219,6 +7224,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   assert(N1.getOpcode() != ISD::DELETED_NODE &&
          N2.getOpcode() != ISD::DELETED_NODE &&
          "Operand is DELETED_NODE!");
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
 
   canonicalizeCommutativeBinop(Opcode, N1, N2);
 
@@ -7665,32 +7672,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     }
   }
 
-  // Memoize this node if possible.
-  SDNode *N;
-  SDVTList VTs = getVTList(VT);
-  SDValue Ops[] = {N1, N2};
-  if (VT != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
-
-  InsertNode(N);
-  SDValue V = SDValue(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -7708,6 +7690,9 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
          N2.getOpcode() != ISD::DELETED_NODE &&
          N3.getOpcode() != ISD::DELETED_NODE &&
          "Operand is DELETED_NODE!");
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
+
   // Perform various simplifications.
   switch (Opcode) {
   case ISD::FMA:
@@ -7862,33 +7847,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
   }
-
-  // Memoize node if it doesn't produce a glue result.
-  SDNode *N;
-  SDVTList VTs = getVTList(VT);
-  SDValue Ops[] = {N1, N2, N3};
-  if (VT != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
-
-  InsertNode(N);
-  SDValue V = SDValue(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -10343,6 +10302,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     assert(Op.getOpcode() != ISD::DELETED_NODE &&
            "Operand is DELETED_NODE!");
 #endif
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VT, Ops, Flags);
 
   switch (Opcode) {
   default: break;
@@ -10411,34 +10372,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
 
-  // Memoize nodes.
-  SDNode *N;
-  SDVTList VTs = getVTList(VT);
-
-  if (VT != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
-
-  N->setFlags(Flags);
-  InsertNode(N);
-  SDValue V(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VT, Ops, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
@@ -10458,6 +10392,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                               ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
   if (VTList.NumVTs == 1)
     return getNode(Opcode, DL, VTList.VTs[0], Ops, Flags);
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
 
 #ifndef NDEBUG
   for (const auto &Op : Ops)
@@ -10622,30 +10558,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
 #endif
   }
 
-  // Memoize the node unless it returns a glue result.
-  SDNode *N;
-  if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTList, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
-    createOperands(N, Ops);
-  }
-
-  N->setFlags(Flags);
-  InsertNode(N);
-  SDValue V(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad-only.ll b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
new file mode 100644
index 00000000000000..fb4bcc39b5a64d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
@@ -0,0 +1,87 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
+; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
+
+;; mad x 1 y => add y x
+define i32 @test_mad_mul_1(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_1_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_mul_1_param_1];
+; CHECK-NEXT:    add.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %x, 1
+  %add = add i32 %mul, %y
+  ret i32 %add
+}
+
+;; mad x -1 y => sub y x
+define i32 @test_mad_mul_neg_1(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_neg_1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_neg_1_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_mul_neg_1_param_1];
+; CHECK-NEXT:    sub.s32 %r3, %r2, %r1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %x, -1
+  %add = add i32 %mul, %y
+  ret i32 %add
+}
+
+;; mad x 0 y => y
+define i32 @test_mad_mul_0(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_0(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_0_param_1];
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %x, 0
+  %add = add i32 %mul, %y
+  ret i32 %add
+}
+
+;; mad x y 0 => mul x y
+define i32 @test_mad_add_0(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_add_0(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_add_0_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_add_0_param_1];
+; CHECK-NEXT:    mul.lo.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %x, %y
+  %add = add i32 %mul, 0
+  ret i32 %add
+}
+
+;; mad c0 c1 x => add x (c0*c1)
+define i32 @test_mad_fold_mul(i32 %x) {
+; CHECK-LABEL: test_mad_fold_mul(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_fold_mul_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 12;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 4, 3
+  %add = add i32 %mul, %x
+  ret i32 %add
+}

>From 1765171071de2db8fd372d8eeea14f018ad6387b Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Tue, 7 Jan 2025 05:59:57 +0000
Subject: [PATCH 3/4] Remove NVPTX::IMAD and rely on ISel only

---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |   5 -
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  10 +-
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 189 ++++++++++++-----
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  71 +------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h     |   1 -
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  99 +++------
 llvm/test/CodeGen/NVPTX/combine-mad-only.ll   |  87 --------
 llvm/test/CodeGen/NVPTX/combine-mad.ll        |   2 +-
 llvm/test/CodeGen/NVPTX/i128.ll               | 192 +++++++++---------
 9 files changed, 275 insertions(+), 381 deletions(-)
 delete mode 100644 llvm/test/CodeGen/NVPTX/combine-mad-only.ll

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 3a015c8df2066a..ff7caec41855fd 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2460,11 +2460,6 @@ class SelectionDAG {
   SDNode *FindNodeOrInsertPos(const FoldingSetNodeID &ID, const SDLoc &DL,
                               void *&InsertPos);
 
-  SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
-                      ArrayRef<SDValue> Ops, SDNodeFlags Flags);
-  SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, SDVTList VTs,
-                      ArrayRef<SDValue> Ops, SDNodeFlags Flags);
-
   /// Maps to auto-CSE operations.
   std::vector<CondCodeSDNode*> CondCodeNodes;
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6d75809cdaf69f..6b2501591c81a3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -153,13 +153,6 @@ static cl::opt<bool> EnableVectorFCopySignExtendRound(
     "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
     cl::desc(
         "Enable merging extends and rounds into FCOPYSIGN on vector types"));
-
-static cl::opt<bool>
-    EnableGenericCombines("combiner-generic-combines", cl::Hidden,
-                          cl::init(true),
-                          cl::desc("Enable generic DAGCombine patterns. Useful "
-                                   "for testing target-specific combines."));
-
 namespace {
 
   class DAGCombiner {
@@ -258,8 +251,7 @@ namespace {
         : DAG(D), TLI(D.getTargetLoweringInfo()),
           STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
       ForCodeSize = DAG.shouldOptForSize();
-      DisableGenericCombines = !EnableGenericCombines ||
-                               (STI && STI->disableGenericCombines(OptLevel));
+      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
 
       MaximumLegalStoreInBits = 0;
       // We use the minimum store size here, since that's all we can guarantee
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 6a3799e02edd94..10e8ba93359fbd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -115,10 +115,6 @@ static cl::opt<unsigned>
     MaxSteps("has-predecessor-max-steps", cl::Hidden, cl::init(8192),
              cl::desc("DAG combiner limit number of steps when searching DAG "
                       "for predecessor nodes"));
-static cl::opt<bool> EnableSimplifyNodes(
-    "selectiondag-simplify-nodes", cl::Hidden, cl::init(true),
-    cl::desc("Enable SelectionDAG::getNode simplifications. Useful for testing "
-             "DAG combines."));
 
 static void NewSDValueDbgMsg(SDValue V, StringRef Msg, SelectionDAG *G) {
   LLVM_DEBUG(dbgs() << Msg; V.getNode()->dump(G););
@@ -6161,46 +6157,23 @@ static SDValue foldCONCAT_VECTORS(const SDLoc &DL, EVT VT,
 }
 
 /// Gets or creates the specified node.
-SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
-                                  ArrayRef<SDValue> Ops,
-                                  const SDNodeFlags Flags) {
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
   SDVTList VTs = getVTList(VT);
-  return getNodeImpl(Opcode, DL, VTs, Ops, Flags);
-}
-
-SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL,
-                                  SDVTList VTs, ArrayRef<SDValue> Ops,
-                                  const SDNodeFlags Flags) {
-  SDNode *N;
-  // Don't CSE glue-producing nodes
-  if (VTs.VTs[VTs.NumVTs - 1] != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, Opcode, VTs, {});
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
+    return SDValue(E, 0);
 
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
+  auto *N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+  CSEMap.InsertNode(N, IP);
 
-  N->setFlags(Flags);
   InsertNode(N);
   SDValue V = SDValue(N, 0);
   NewSDValueDbgMsg(V, "Creating new node: ", this);
   return V;
 }
 
-SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
-  return getNodeImpl(Opcode, DL, VT, {}, SDNodeFlags{});
-}
-
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1) {
   SDNodeFlags Flags;
@@ -6212,8 +6185,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1, const SDNodeFlags Flags) {
   assert(N1.getOpcode() != ISD::DELETED_NODE && "Operand is DELETED_NODE!");
-  if (!EnableSimplifyNodes)
-    return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
 
   // Constant fold unary operations with a vector integer or float operand.
   switch (Opcode) {
@@ -6530,7 +6501,31 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
 
-  return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
+  SDNode *N;
+  SDVTList VTs = getVTList(VT);
+  SDValue Ops[] = {N1};
+  if (VT != MVT::Glue) { // Don't CSE glue producing nodes
+    FoldingSetNodeID ID;
+    AddNodeIDNode(ID, Opcode, VTs, Ops);
+    void *IP = nullptr;
+    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+      E->intersectFlagsWith(Flags);
+      return SDValue(E, 0);
+    }
+
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    N->setFlags(Flags);
+    createOperands(N, Ops);
+    CSEMap.InsertNode(N, IP);
+  } else {
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+  }
+
+  InsertNode(N);
+  SDValue V = SDValue(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
 }
 
 static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
@@ -7224,8 +7219,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   assert(N1.getOpcode() != ISD::DELETED_NODE &&
          N2.getOpcode() != ISD::DELETED_NODE &&
          "Operand is DELETED_NODE!");
-  if (!EnableSimplifyNodes)
-    return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
 
   canonicalizeCommutativeBinop(Opcode, N1, N2);
 
@@ -7672,7 +7665,32 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     }
   }
 
-  return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
+  // Memoize this node if possible.
+  SDNode *N;
+  SDVTList VTs = getVTList(VT);
+  SDValue Ops[] = {N1, N2};
+  if (VT != MVT::Glue) {
+    FoldingSetNodeID ID;
+    AddNodeIDNode(ID, Opcode, VTs, Ops);
+    void *IP = nullptr;
+    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+      E->intersectFlagsWith(Flags);
+      return SDValue(E, 0);
+    }
+
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    N->setFlags(Flags);
+    createOperands(N, Ops);
+    CSEMap.InsertNode(N, IP);
+  } else {
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+  }
+
+  InsertNode(N);
+  SDValue V = SDValue(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -7690,9 +7708,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
          N2.getOpcode() != ISD::DELETED_NODE &&
          N3.getOpcode() != ISD::DELETED_NODE &&
          "Operand is DELETED_NODE!");
-  if (!EnableSimplifyNodes)
-    return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
-
   // Perform various simplifications.
   switch (Opcode) {
   case ISD::FMA:
@@ -7847,7 +7862,33 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
   }
-  return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
+
+  // Memoize node if it doesn't produce a glue result.
+  SDNode *N;
+  SDVTList VTs = getVTList(VT);
+  SDValue Ops[] = {N1, N2, N3};
+  if (VT != MVT::Glue) {
+    FoldingSetNodeID ID;
+    AddNodeIDNode(ID, Opcode, VTs, Ops);
+    void *IP = nullptr;
+    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+      E->intersectFlagsWith(Flags);
+      return SDValue(E, 0);
+    }
+
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    N->setFlags(Flags);
+    createOperands(N, Ops);
+    CSEMap.InsertNode(N, IP);
+  } else {
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+  }
+
+  InsertNode(N);
+  SDValue V = SDValue(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -10302,8 +10343,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     assert(Op.getOpcode() != ISD::DELETED_NODE &&
            "Operand is DELETED_NODE!");
 #endif
-  if (!EnableSimplifyNodes)
-    return getNodeImpl(Opcode, DL, VT, Ops, Flags);
 
   switch (Opcode) {
   default: break;
@@ -10372,7 +10411,34 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
 
-  return getNodeImpl(Opcode, DL, VT, Ops, Flags);
+  // Memoize nodes.
+  SDNode *N;
+  SDVTList VTs = getVTList(VT);
+
+  if (VT != MVT::Glue) {
+    FoldingSetNodeID ID;
+    AddNodeIDNode(ID, Opcode, VTs, Ops);
+    void *IP = nullptr;
+
+    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+      E->intersectFlagsWith(Flags);
+      return SDValue(E, 0);
+    }
+
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+
+    CSEMap.InsertNode(N, IP);
+  } else {
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+  }
+
+  N->setFlags(Flags);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
@@ -10392,8 +10458,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                               ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
   if (VTList.NumVTs == 1)
     return getNode(Opcode, DL, VTList.VTs[0], Ops, Flags);
-  if (!EnableSimplifyNodes)
-    return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
 
 #ifndef NDEBUG
   for (const auto &Op : Ops)
@@ -10558,7 +10622,30 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
 #endif
   }
 
-  return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
+  // Memoize the node unless it returns a glue result.
+  SDNode *N;
+  if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) {
+    FoldingSetNodeID ID;
+    AddNodeIDNode(ID, Opcode, VTList, Ops);
+    void *IP = nullptr;
+    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+      E->intersectFlagsWith(Flags);
+      return SDValue(E, 0);
+    }
+
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
+    createOperands(N, Ops);
+    CSEMap.InsertNode(N, IP);
+  } else {
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
+    createOperands(N, Ops);
+  }
+
+  N->setFlags(Flags);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c4529c9151bc2b..362dc4338c72e8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1037,7 +1037,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::StoreV4)
     MAKE_CASE(NVPTXISD::FSHL_CLAMP)
     MAKE_CASE(NVPTXISD::FSHR_CLAMP)
-    MAKE_CASE(NVPTXISD::IMAD)
     MAKE_CASE(NVPTXISD::BFE)
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
@@ -4442,14 +4441,8 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   if (!N0.getNode()->hasOneUse())
     return SDValue();
 
-  // fold (add (mul a, b), c) -> (mad a, b, c)
-  //
-  if (N0.getOpcode() == ISD::MUL)
-    return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
-                           N0.getOperand(1), N1);
-
   // fold (add (select cond, 0, (mul a, b)), c)
-  //   -> (select cond, c, (mad a, b, c))
+  //   -> (select cond, c, (add (mul a, b), c))
   //
   if (N0.getOpcode() == ISD::SELECT) {
     unsigned ZeroOpNum;
@@ -4464,8 +4457,9 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
     if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
       return SDValue();
 
-    SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
-                                  M->getOperand(0), M->getOperand(1), N1);
+    SDLoc DL(N);
+    SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
+    SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
     return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
                              ((ZeroOpNum == 1) ? N1 : MAD),
                              ((ZeroOpNum == 1) ? MAD : N1));
@@ -4902,8 +4896,10 @@ static SDValue matchMADConstOnePattern(SDValue Add) {
 static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
                                   TargetLowering::DAGCombinerInfo &DCI) {
 
-  if (SDValue Y = matchMADConstOnePattern(Add))
-    return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Y, X);
+  if (SDValue Y = matchMADConstOnePattern(Add)) {
+    SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X);
+  }
 
   return SDValue();
 }
@@ -4950,7 +4946,7 @@ PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
 
   SDLoc DL(N);
 
-  // (mul x, (add y, 1)) -> (mad x, y, x)
+  // (mul x, (add y, 1)) -> (add (mul x, y), x)
   if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
     return Res;
   if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
@@ -5164,53 +5160,6 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
 }
 
-static SDValue
-PerformIMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, SDValue N2,
-                               TargetLowering::DAGCombinerInfo &DCI) {
-  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
-  ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2);
-  EVT VT = N0->getValueType(0);
-  SDLoc DL(N);
-  SDNodeFlags Flags = N->getFlags();
-
-  // mad x 1 y => add x y
-  if (N1C && N1C->isOne())
-    return DCI.DAG.getNode(ISD::ADD, DL, VT, N0, N2, Flags);
-
-  // mad x -1 y => sub y x
-  if (N1C && N1C->isAllOnes()) {
-    Flags.setNoUnsignedWrap(false);
-    return DCI.DAG.getNode(ISD::SUB, DL, VT, N2, N0, Flags);
-  }
-
-  // mad x 0 y => y
-  if (N1C && N1C->isZero())
-    return N2;
-
-  // mad x y 0 => mul x y
-  if (N2C && N2C->isZero())
-    return DCI.DAG.getNode(ISD::MUL, DL, VT, N0, N1, Flags);
-
-  // mad c0 c1 x => add x (c0*c1)
-  if (SDValue C =
-          DCI.DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}, Flags))
-    return DCI.DAG.getNode(ISD::ADD, DL, VT, N2, C, Flags);
-
-  return {};
-}
-
-static SDValue PerformIMADCombine(SDNode *N,
-                                  TargetLowering::DAGCombinerInfo &DCI) {
-  SDValue N0 = N->getOperand(0);
-  SDValue N1 = N->getOperand(1);
-  SDValue N2 = N->getOperand(2);
-  SDValue res = PerformIMADCombineWithOperands(N, N0, N1, N2, DCI);
-  if (res)
-    return res;
-
-  return PerformIMADCombineWithOperands(N, N1, N0, N2, DCI);
-}
-
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5245,8 +5194,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformVSELECTCombine(N, DCI);
     case ISD::BUILD_VECTOR:
       return PerformBUILD_VECTORCombine(N, DCI);
-    case NVPTXISD::IMAD:
-      return PerformIMADCombine(N, DCI);
   }
   return SDValue();
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 4a98fe21b81dc6..51265ed2179d88 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -55,7 +55,6 @@ enum NodeType : unsigned {
   FSHR_CLAMP,
   MUL_WIDE_SIGNED,
   MUL_WIDE_UNSIGNED,
-  IMAD,
   SETP_F16X2,
   SETP_BF16X2,
   BFE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index c3e72d6ce3a3f8..f24c35d9f1cca7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -141,6 +141,7 @@ def hasLDG : Predicate<"Subtarget->hasLDG()">;
 def hasLDU : Predicate<"Subtarget->hasLDU()">;
 def hasPTXASUnreachableBug : Predicate<"Subtarget->hasPTXASUnreachableBug()">;
 def noPTXASUnreachableBug : Predicate<"!Subtarget->hasPTXASUnreachableBug()">;
+def hasO1 : Predicate<"TM.getOptLevel() != CodeGenOptLevel::None">;
 
 def doF32FTZ : Predicate<"useF32FTZ()">;
 def doNoF32FTZ : Predicate<"!useF32FTZ()">;
@@ -1069,73 +1070,37 @@ def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
 //
 // Integer multiply-add
 //
-def SDTIMAD :
-  SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>,
-                       SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>]>;
-def imad : SDNode<"NVPTXISD::IMAD", SDTIMAD>;
-
-def MAD16rrr :
-  NVPTXInst<(outs Int16Regs:$dst),
-            (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
-            "mad.lo.s16 \t$dst, $a, $b, $c;",
-            [(set i16:$dst, (imad i16:$a, i16:$b, i16:$c))]>;
-def MAD16rri :
-  NVPTXInst<(outs Int16Regs:$dst),
-            (ins Int16Regs:$a, Int16Regs:$b, i16imm:$c),
-            "mad.lo.s16 \t$dst, $a, $b, $c;",
-            [(set i16:$dst, (imad i16:$a, i16:$b, imm:$c))]>;
-def MAD16rir :
-  NVPTXInst<(outs Int16Regs:$dst),
-            (ins Int16Regs:$a, i16imm:$b, Int16Regs:$c),
-            "mad.lo.s16 \t$dst, $a, $b, $c;",
-            [(set i16:$dst, (imad i16:$a, imm:$b, i16:$c))]>;
-def MAD16rii :
-  NVPTXInst<(outs Int16Regs:$dst),
-            (ins Int16Regs:$a, i16imm:$b, i16imm:$c),
-            "mad.lo.s16 \t$dst, $a, $b, $c;",
-            [(set i16:$dst, (imad i16:$a, imm:$b, imm:$c))]>;
-
-def MAD32rrr :
-  NVPTXInst<(outs Int32Regs:$dst),
-            (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
-            "mad.lo.s32 \t$dst, $a, $b, $c;",
-            [(set i32:$dst, (imad i32:$a, i32:$b, i32:$c))]>;
-def MAD32rri :
-  NVPTXInst<(outs Int32Regs:$dst),
-            (ins Int32Regs:$a, Int32Regs:$b, i32imm:$c),
-            "mad.lo.s32 \t$dst, $a, $b, $c;",
-            [(set i32:$dst, (imad i32:$a, i32:$b, imm:$c))]>;
-def MAD32rir :
-  NVPTXInst<(outs Int32Regs:$dst),
-            (ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
-            "mad.lo.s32 \t$dst, $a, $b, $c;",
-            [(set i32:$dst, (imad i32:$a, imm:$b, i32:$c))]>;
-def MAD32rii :
-  NVPTXInst<(outs Int32Regs:$dst),
-            (ins Int32Regs:$a, i32imm:$b, i32imm:$c),
-            "mad.lo.s32 \t$dst, $a, $b, $c;",
-            [(set i32:$dst, (imad i32:$a, imm:$b, imm:$c))]>;
-
-def MAD64rrr :
-  NVPTXInst<(outs Int64Regs:$dst),
-            (ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
-            "mad.lo.s64 \t$dst, $a, $b, $c;",
-            [(set i64:$dst, (imad i64:$a, i64:$b, i64:$c))]>;
-def MAD64rri :
-  NVPTXInst<(outs Int64Regs:$dst),
-            (ins Int64Regs:$a, Int64Regs:$b, i64imm:$c),
-            "mad.lo.s64 \t$dst, $a, $b, $c;",
-            [(set i64:$dst, (imad i64:$a, i64:$b, imm:$c))]>;
-def MAD64rir :
-  NVPTXInst<(outs Int64Regs:$dst),
-            (ins Int64Regs:$a, i64imm:$b, Int64Regs:$c),
-            "mad.lo.s64 \t$dst, $a, $b, $c;",
-            [(set i64:$dst, (imad i64:$a, imm:$b, i64:$c))]>;
-def MAD64rii :
-  NVPTXInst<(outs Int64Regs:$dst),
-            (ins Int64Regs:$a, i64imm:$b, i64imm:$c),
-            "mad.lo.s64 \t$dst, $a, $b, $c;",
-            [(set i64:$dst, (imad i64:$a, imm:$b, imm:$c))]>;
+
+multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> {
+
+  def rrr:
+    NVPTXInst<(outs Reg:$dst),
+              (ins Reg:$a, Reg:$b, Reg:$c),
+              Ptx # " \t$dst, $a, $b, $c;",
+              [(set VT:$dst, (add (mul VT:$a, VT:$b), VT:$c))]>;
+
+  def rir:
+    NVPTXInst<(outs Reg:$dst),
+              (ins Reg:$a, Imm:$b, Reg:$c),
+              Ptx # " \t$dst, $a, $b, $c;",
+              [(set VT:$dst, (add (mul VT:$a, imm:$b), VT:$c))]>;
+  def rri:
+    NVPTXInst<(outs Reg:$dst),
+              (ins Reg:$a, Reg:$b, Imm:$c),
+              Ptx # " \t$dst, $a, $b, $c;",
+              [(set VT:$dst, (add (mul VT:$a, VT:$b), imm:$c))]>;
+  def rii:
+    NVPTXInst<(outs Reg:$dst),
+              (ins Reg:$a, Imm:$b, Imm:$c),
+              Ptx # " \t$dst, $a, $b, $c;",
+              [(set VT:$dst, (add (mul VT:$a, imm:$b), imm:$c))]>;
+}
+
+let Predicates = [hasO1] in {
+defm MAD16 : MAD<"mad.lo.s16", i16, Int16Regs, i16imm>;
+defm MAD32 : MAD<"mad.lo.s32", i32, Int32Regs, i32imm>;
+defm MAD64 : MAD<"mad.lo.s64", i64, Int64Regs, i64imm>;
+}
 
 def INEG16 :
   NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad-only.ll b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
deleted file mode 100644
index fb4bcc39b5a64d..00000000000000
--- a/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
+++ /dev/null
@@ -1,87 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
-; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
-; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
-; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
-
-;; mad x 1 y => add y x
-define i32 @test_mad_mul_1(i32 %x, i32 %y) {
-; CHECK-LABEL: test_mad_mul_1(
-; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<4>;
-; CHECK-EMPTY:
-; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_1_param_0];
-; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_mul_1_param_1];
-; CHECK-NEXT:    add.s32 %r3, %r1, %r2;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
-; CHECK-NEXT:    ret;
-  %mul = mul i32 %x, 1
-  %add = add i32 %mul, %y
-  ret i32 %add
-}
-
-;; mad x -1 y => sub y x
-define i32 @test_mad_mul_neg_1(i32 %x, i32 %y) {
-; CHECK-LABEL: test_mad_mul_neg_1(
-; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<4>;
-; CHECK-EMPTY:
-; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_neg_1_param_0];
-; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_mul_neg_1_param_1];
-; CHECK-NEXT:    sub.s32 %r3, %r2, %r1;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
-; CHECK-NEXT:    ret;
-  %mul = mul i32 %x, -1
-  %add = add i32 %mul, %y
-  ret i32 %add
-}
-
-;; mad x 0 y => y
-define i32 @test_mad_mul_0(i32 %x, i32 %y) {
-; CHECK-LABEL: test_mad_mul_0(
-; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<2>;
-; CHECK-EMPTY:
-; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_0_param_1];
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
-; CHECK-NEXT:    ret;
-  %mul = mul i32 %x, 0
-  %add = add i32 %mul, %y
-  ret i32 %add
-}
-
-;; mad x y 0 => mul x y
-define i32 @test_mad_add_0(i32 %x, i32 %y) {
-; CHECK-LABEL: test_mad_add_0(
-; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<4>;
-; CHECK-EMPTY:
-; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_add_0_param_0];
-; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_add_0_param_1];
-; CHECK-NEXT:    mul.lo.s32 %r3, %r1, %r2;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
-; CHECK-NEXT:    ret;
-  %mul = mul i32 %x, %y
-  %add = add i32 %mul, 0
-  ret i32 %add
-}
-
-;; mad c0 c1 x => add x (c0*c1)
-define i32 @test_mad_fold_mul(i32 %x) {
-; CHECK-LABEL: test_mad_fold_mul(
-; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<3>;
-; CHECK-EMPTY:
-; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_fold_mul_param_0];
-; CHECK-NEXT:    add.s32 %r2, %r1, 12;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
-; CHECK-NEXT:    ret;
-  %mul = mul i32 4, 3
-  %add = add i32 %mul, %x
-  ret i32 %add
-}
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 7d523a835a1f3f..bd88376c374988 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -194,7 +194,7 @@ define i32 @test_mad_fold(i32 %x) {
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_fold_param_0];
 ; CHECK-NEXT:    mul.hi.s32 %r2, %r1, -2147221471;
-; CHECK-NEXT:    add.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    add.s32 %r3, %r2, %r1;
 ; CHECK-NEXT:    shr.u32 %r4, %r3, 31;
 ; CHECK-NEXT:    shr.s32 %r5, %r3, 12;
 ; CHECK-NEXT:    add.s32 %r6, %r5, %r4;
diff --git a/llvm/test/CodeGen/NVPTX/i128.ll b/llvm/test/CodeGen/NVPTX/i128.ll
index accfbe4af0313c..d7ba0aeb02eed2 100644
--- a/llvm/test/CodeGen/NVPTX/i128.ll
+++ b/llvm/test/CodeGen/NVPTX/i128.ll
@@ -7,20 +7,20 @@ define i128 @srem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .pred %p<19>;
 ; CHECK-NEXT:    .reg .b32 %r<20>;
-; CHECK-NEXT:    .reg .b64 %rd<129>;
+; CHECK-NEXT:    .reg .b64 %rd<127>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %_udiv-special-cases
 ; CHECK-NEXT:    ld.param.v2.u64 {%rd45, %rd46}, [srem_i128_param_0];
 ; CHECK-NEXT:    ld.param.v2.u64 {%rd49, %rd50}, [srem_i128_param_1];
 ; CHECK-NEXT:    shr.s64 %rd2, %rd46, 63;
-; CHECK-NEXT:    mov.b64 %rd119, 0;
-; CHECK-NEXT:    sub.cc.s64 %rd52, %rd119, %rd45;
-; CHECK-NEXT:    subc.cc.s64 %rd53, %rd119, %rd46;
+; CHECK-NEXT:    mov.b64 %rd117, 0;
+; CHECK-NEXT:    sub.cc.s64 %rd52, %rd117, %rd45;
+; CHECK-NEXT:    subc.cc.s64 %rd53, %rd117, %rd46;
 ; CHECK-NEXT:    setp.lt.s64 %p1, %rd46, 0;
 ; CHECK-NEXT:    selp.b64 %rd4, %rd53, %rd46, %p1;
 ; CHECK-NEXT:    selp.b64 %rd3, %rd52, %rd45, %p1;
-; CHECK-NEXT:    sub.cc.s64 %rd54, %rd119, %rd49;
-; CHECK-NEXT:    subc.cc.s64 %rd55, %rd119, %rd50;
+; CHECK-NEXT:    sub.cc.s64 %rd54, %rd117, %rd49;
+; CHECK-NEXT:    subc.cc.s64 %rd55, %rd117, %rd50;
 ; CHECK-NEXT:    setp.lt.s64 %p2, %rd50, 0;
 ; CHECK-NEXT:    selp.b64 %rd6, %rd55, %rd50, %p2;
 ; CHECK-NEXT:    selp.b64 %rd5, %rd54, %rd49, %p2;
@@ -44,7 +44,7 @@ define i128 @srem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK-NEXT:    add.s64 %rd64, %rd63, 64;
 ; CHECK-NEXT:    selp.b64 %rd65, %rd62, %rd64, %p7;
 ; CHECK-NEXT:    sub.cc.s64 %rd66, %rd61, %rd65;
-; CHECK-NEXT:    subc.cc.s64 %rd67, %rd119, 0;
+; CHECK-NEXT:    subc.cc.s64 %rd67, %rd117, 0;
 ; CHECK-NEXT:    setp.eq.s64 %p8, %rd67, 0;
 ; CHECK-NEXT:    setp.ne.s64 %p9, %rd67, 0;
 ; CHECK-NEXT:    selp.u32 %r5, -1, 0, %p9;
@@ -57,14 +57,14 @@ define i128 @srem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK-NEXT:    xor.b64 %rd68, %rd66, 127;
 ; CHECK-NEXT:    or.b64 %rd69, %rd68, %rd67;
 ; CHECK-NEXT:    setp.eq.s64 %p13, %rd69, 0;
-; CHECK-NEXT:    selp.b64 %rd128, 0, %rd4, %p12;
-; CHECK-NEXT:    selp.b64 %rd127, 0, %rd3, %p12;
+; CHECK-NEXT:    selp.b64 %rd126, 0, %rd4, %p12;
+; CHECK-NEXT:    selp.b64 %rd125, 0, %rd3, %p12;
 ; CHECK-NEXT:    or.pred %p14, %p12, %p13;
 ; CHECK-NEXT:    @%p14 bra $L__BB0_5;
 ; CHECK-NEXT:  // %bb.3: // %udiv-bb1
-; CHECK-NEXT:    add.cc.s64 %rd121, %rd66, 1;
-; CHECK-NEXT:    addc.cc.s64 %rd122, %rd67, 0;
-; CHECK-NEXT:    or.b64 %rd72, %rd121, %rd122;
+; CHECK-NEXT:    add.cc.s64 %rd119, %rd66, 1;
+; CHECK-NEXT:    addc.cc.s64 %rd120, %rd67, 0;
+; CHECK-NEXT:    or.b64 %rd72, %rd119, %rd120;
 ; CHECK-NEXT:    setp.eq.s64 %p15, %rd72, 0;
 ; CHECK-NEXT:    cvt.u32.u64 %r9, %rd66;
 ; CHECK-NEXT:    mov.b32 %r10, 127;
@@ -78,12 +78,12 @@ define i128 @srem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK-NEXT:    sub.s32 %r15, %r14, %r9;
 ; CHECK-NEXT:    shl.b64 %rd76, %rd3, %r15;
 ; CHECK-NEXT:    setp.gt.s32 %p16, %r11, 63;
-; CHECK-NEXT:    selp.b64 %rd126, %rd76, %rd75, %p16;
-; CHECK-NEXT:    shl.b64 %rd125, %rd3, %r11;
-; CHECK-NEXT:    mov.u64 %rd116, %rd119;
+; CHECK-NEXT:    selp.b64 %rd124, %rd76, %rd75, %p16;
+; CHECK-NEXT:    shl.b64 %rd123, %rd3, %r11;
+; CHECK-NEXT:    mov.u64 %rd114, %rd117;
 ; CHECK-NEXT:    @%p15 bra $L__BB0_4;
 ; CHECK-NEXT:  // %bb.1: // %udiv-preheader
-; CHECK-NEXT:    cvt.u32.u64 %r16, %rd121;
+; CHECK-NEXT:    cvt.u32.u64 %r16, %rd119;
 ; CHECK-NEXT:    shr.u64 %rd79, %rd3, %r16;
 ; CHECK-NEXT:    sub.s32 %r18, %r12, %r16;
 ; CHECK-NEXT:    shl.b64 %rd80, %rd4, %r18;
@@ -91,61 +91,59 @@ define i128 @srem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK-NEXT:    add.s32 %r19, %r16, -64;
 ; CHECK-NEXT:    shr.u64 %rd82, %rd4, %r19;
 ; CHECK-NEXT:    setp.gt.s32 %p17, %r16, 63;
-; CHECK-NEXT:    selp.b64 %rd123, %rd82, %rd81, %p17;
-; CHECK-NEXT:    shr.u64 %rd124, %rd4, %r16;
+; CHECK-NEXT:    selp.b64 %rd121, %rd82, %rd81, %p17;
+; CHECK-NEXT:    shr.u64 %rd122, %rd4, %r16;
 ; CHECK-NEXT:    add.cc.s64 %rd35, %rd5, -1;
 ; CHECK-NEXT:    addc.cc.s64 %rd36, %rd6, -1;
-; CHECK-NEXT:    mov.b64 %rd116, 0;
-; CHECK-NEXT:    mov.u64 %rd119, %rd116;
+; CHECK-NEXT:    mov.b64 %rd114, 0;
+; CHECK-NEXT:    mov.u64 %rd117, %rd114;
 ; CHECK-NEXT:  $L__BB0_2: // %udiv-do-while
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    shr.u64 %rd83, %rd123, 63;
-; CHECK-NEXT:    shl.b64 %rd84, %rd124, 1;
+; CHECK-NEXT:    shr.u64 %rd83, %rd121, 63;
+; CHECK-NEXT:    shl.b64 %rd84, %rd122, 1;
 ; CHECK-NEXT:    or.b64 %rd85, %rd84, %rd83;
-; CHECK-NEXT:    shl.b64 %rd86, %rd123, 1;
-; CHECK-NEXT:    shr.u64 %rd87, %rd126, 63;
+; CHECK-NEXT:    shl.b64 %rd86, %rd121, 1;
+; CHECK-NEXT:    shr.u64 %rd87, %rd124, 63;
 ; CHECK-NEXT:    or.b64 %rd88, %rd86, %rd87;
-; CHECK-NEXT:    shr.u64 %rd89, %rd125, 63;
-; CHECK-NEXT:    shl.b64 %rd90, %rd126, 1;
+; CHECK-NEXT:    shr.u64 %rd89, %rd123, 63;
+; CHECK-NEXT:    shl.b64 %rd90, %rd124, 1;
 ; CHECK-NEXT:    or.b64 %rd91, %rd90, %rd89;
-; CHECK-NEXT:    shl.b64 %rd92, %rd125, 1;
-; CHECK-NEXT:    or.b64 %rd125, %rd119, %rd92;
-; CHECK-NEXT:    or.b64 %rd126, %rd116, %rd91;
+; CHECK-NEXT:    shl.b64 %rd92, %rd123, 1;
+; CHECK-NEXT:    or.b64 %rd123, %rd117, %rd92;
+; CHECK-NEXT:    or.b64 %rd124, %rd114, %rd91;
 ; CHECK-NEXT:    sub.cc.s64 %rd93, %rd35, %rd88;
 ; CHECK-NEXT:    subc.cc.s64 %rd94, %rd36, %rd85;
 ; CHECK-NEXT:    shr.s64 %rd95, %rd94, 63;
-; CHECK-NEXT:    and.b64 %rd119, %rd95, 1;
+; CHECK-NEXT:    and.b64 %rd117, %rd95, 1;
 ; CHECK-NEXT:    and.b64 %rd96, %rd95, %rd5;
 ; CHECK-NEXT:    and.b64 %rd97, %rd95, %rd6;
-; CHECK-NEXT:    sub.cc.s64 %rd123, %rd88, %rd96;
-; CHECK-NEXT:    subc.cc.s64 %rd124, %rd85, %rd97;
-; CHECK-NEXT:    add.cc.s64 %rd121, %rd121, -1;
-; CHECK-NEXT:    addc.cc.s64 %rd122, %rd122, -1;
-; CHECK-NEXT:    or.b64 %rd98, %rd121, %rd122;
+; CHECK-NEXT:    sub.cc.s64 %rd121, %rd88, %rd96;
+; CHECK-NEXT:    subc.cc.s64 %rd122, %rd85, %rd97;
+; CHECK-NEXT:    add.cc.s64 %rd119, %rd119, -1;
+; CHECK-NEXT:    addc.cc.s64 %rd120, %rd120, -1;
+; CHECK-NEXT:    or.b64 %rd98, %rd119, %rd120;
 ; CHECK-NEXT:    setp.eq.s64 %p18, %rd98, 0;
 ; CHECK-NEXT:    @%p18 bra $L__BB0_4;
 ; CHECK-NEXT:    bra.uni $L__BB0_2;
 ; CHECK-NEXT:  $L__BB0_4: // %udiv-loop-exit
-; CHECK-NEXT:    shr.u64 %rd99, %rd125, 63;
-; CHECK-NEXT:    shl.b64 %rd100, %rd126, 1;
+; CHECK-NEXT:    shr.u64 %rd99, %rd123, 63;
+; CHECK-NEXT:    shl.b64 %rd100, %rd124, 1;
 ; CHECK-NEXT:    or.b64 %rd101, %rd100, %rd99;
-; CHECK-NEXT:    shl.b64 %rd102, %rd125, 1;
-; CHECK-NEXT:    or.b64 %rd127, %rd119, %rd102;
-; CHECK-NEXT:    or.b64 %rd128, %rd116, %rd101;
+; CHECK-NEXT:    shl.b64 %rd102, %rd123, 1;
+; CHECK-NEXT:    or.b64 %rd125, %rd117, %rd102;
+; CHECK-NEXT:    or.b64 %rd126, %rd114, %rd101;
 ; CHECK-NEXT:  $L__BB0_5: // %udiv-end
-; CHECK-NEXT:    mul.hi.u64 %rd103, %rd5, %rd127;
-; CHECK-NEXT:    mul.lo.s64 %rd104, %rd5, %rd128;
-; CHECK-NEXT:    add.s64 %rd105, %rd103, %rd104;
-; CHECK-NEXT:    mul.lo.s64 %rd106, %rd6, %rd127;
-; CHECK-NEXT:    add.s64 %rd107, %rd105, %rd106;
-; CHECK-NEXT:    mul.lo.s64 %rd108, %rd5, %rd127;
-; CHECK-NEXT:    sub.cc.s64 %rd109, %rd3, %rd108;
-; CHECK-NEXT:    subc.cc.s64 %rd110, %rd4, %rd107;
-; CHECK-NEXT:    xor.b64 %rd111, %rd109, %rd2;
-; CHECK-NEXT:    xor.b64 %rd112, %rd110, %rd2;
-; CHECK-NEXT:    sub.cc.s64 %rd113, %rd111, %rd2;
-; CHECK-NEXT:    subc.cc.s64 %rd114, %rd112, %rd2;
-; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd113, %rd114};
+; CHECK-NEXT:    mul.hi.u64 %rd103, %rd5, %rd125;
+; CHECK-NEXT:    mad.lo.s64 %rd104, %rd5, %rd126, %rd103;
+; CHECK-NEXT:    mad.lo.s64 %rd105, %rd6, %rd125, %rd104;
+; CHECK-NEXT:    mul.lo.s64 %rd106, %rd5, %rd125;
+; CHECK-NEXT:    sub.cc.s64 %rd107, %rd3, %rd106;
+; CHECK-NEXT:    subc.cc.s64 %rd108, %rd4, %rd105;
+; CHECK-NEXT:    xor.b64 %rd109, %rd107, %rd2;
+; CHECK-NEXT:    xor.b64 %rd110, %rd108, %rd2;
+; CHECK-NEXT:    sub.cc.s64 %rd111, %rd109, %rd2;
+; CHECK-NEXT:    subc.cc.s64 %rd112, %rd110, %rd2;
+; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd111, %rd112};
 ; CHECK-NEXT:    ret;
   %div = srem i128 %lhs, %rhs
   ret i128 %div
@@ -156,7 +154,7 @@ define i128 @urem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .pred %p<17>;
 ; CHECK-NEXT:    .reg .b32 %r<20>;
-; CHECK-NEXT:    .reg .b64 %rd<115>;
+; CHECK-NEXT:    .reg .b64 %rd<113>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %_udiv-special-cases
 ; CHECK-NEXT:    ld.param.v2.u64 {%rd41, %rd42}, [urem_i128_param_0];
@@ -180,9 +178,9 @@ define i128 @urem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK-NEXT:    cvt.u64.u32 %rd52, %r4;
 ; CHECK-NEXT:    add.s64 %rd53, %rd52, 64;
 ; CHECK-NEXT:    selp.b64 %rd54, %rd51, %rd53, %p5;
-; CHECK-NEXT:    mov.b64 %rd105, 0;
+; CHECK-NEXT:    mov.b64 %rd103, 0;
 ; CHECK-NEXT:    sub.cc.s64 %rd56, %rd50, %rd54;
-; CHECK-NEXT:    subc.cc.s64 %rd57, %rd105, 0;
+; CHECK-NEXT:    subc.cc.s64 %rd57, %rd103, 0;
 ; CHECK-NEXT:    setp.eq.s64 %p6, %rd57, 0;
 ; CHECK-NEXT:    setp.ne.s64 %p7, %rd57, 0;
 ; CHECK-NEXT:    selp.u32 %r5, -1, 0, %p7;
@@ -195,14 +193,14 @@ define i128 @urem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK-NEXT:    xor.b64 %rd58, %rd56, 127;
 ; CHECK-NEXT:    or.b64 %rd59, %rd58, %rd57;
 ; CHECK-NEXT:    setp.eq.s64 %p11, %rd59, 0;
-; CHECK-NEXT:    selp.b64 %rd114, 0, %rd42, %p10;
-; CHECK-NEXT:    selp.b64 %rd113, 0, %rd41, %p10;
+; CHECK-NEXT:    selp.b64 %rd112, 0, %rd42, %p10;
+; CHECK-NEXT:    selp.b64 %rd111, 0, %rd41, %p10;
 ; CHECK-NEXT:    or.pred %p12, %p10, %p11;
 ; CHECK-NEXT:    @%p12 bra $L__BB1_5;
 ; CHECK-NEXT:  // %bb.3: // %udiv-bb1
-; CHECK-NEXT:    add.cc.s64 %rd107, %rd56, 1;
-; CHECK-NEXT:    addc.cc.s64 %rd108, %rd57, 0;
-; CHECK-NEXT:    or.b64 %rd62, %rd107, %rd108;
+; CHECK-NEXT:    add.cc.s64 %rd105, %rd56, 1;
+; CHECK-NEXT:    addc.cc.s64 %rd106, %rd57, 0;
+; CHECK-NEXT:    or.b64 %rd62, %rd105, %rd106;
 ; CHECK-NEXT:    setp.eq.s64 %p13, %rd62, 0;
 ; CHECK-NEXT:    cvt.u32.u64 %r9, %rd56;
 ; CHECK-NEXT:    mov.b32 %r10, 127;
@@ -216,12 +214,12 @@ define i128 @urem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK-NEXT:    sub.s32 %r15, %r14, %r9;
 ; CHECK-NEXT:    shl.b64 %rd66, %rd41, %r15;
 ; CHECK-NEXT:    setp.gt.s32 %p14, %r11, 63;
-; CHECK-NEXT:    selp.b64 %rd112, %rd66, %rd65, %p14;
-; CHECK-NEXT:    shl.b64 %rd111, %rd41, %r11;
-; CHECK-NEXT:    mov.u64 %rd102, %rd105;
+; CHECK-NEXT:    selp.b64 %rd110, %rd66, %rd65, %p14;
+; CHECK-NEXT:    shl.b64 %rd109, %rd41, %r11;
+; CHECK-NEXT:    mov.u64 %rd100, %rd103;
 ; CHECK-NEXT:    @%p13 bra $L__BB1_4;
 ; CHECK-NEXT:  // %bb.1: // %udiv-preheader
-; CHECK-NEXT:    cvt.u32.u64 %r16, %rd107;
+; CHECK-NEXT:    cvt.u32.u64 %r16, %rd105;
 ; CHECK-NEXT:    shr.u64 %rd69, %rd41, %r16;
 ; CHECK-NEXT:    sub.s32 %r18, %r12, %r16;
 ; CHECK-NEXT:    shl.b64 %rd70, %rd42, %r18;
@@ -229,57 +227,55 @@ define i128 @urem_i128(i128 %lhs, i128 %rhs) {
 ; CHECK-NEXT:    add.s32 %r19, %r16, -64;
 ; CHECK-NEXT:    shr.u64 %rd72, %rd42, %r19;
 ; CHECK-NEXT:    setp.gt.s32 %p15, %r16, 63;
-; CHECK-NEXT:    selp.b64 %rd109, %rd72, %rd71, %p15;
-; CHECK-NEXT:    shr.u64 %rd110, %rd42, %r16;
+; CHECK-NEXT:    selp.b64 %rd107, %rd72, %rd71, %p15;
+; CHECK-NEXT:    shr.u64 %rd108, %rd42, %r16;
 ; CHECK-NEXT:    add.cc.s64 %rd33, %rd3, -1;
 ; CHECK-NEXT:    addc.cc.s64 %rd34, %rd4, -1;
-; CHECK-NEXT:    mov.b64 %rd102, 0;
-; CHECK-NEXT:    mov.u64 %rd105, %rd102;
+; CHECK-NEXT:    mov.b64 %rd100, 0;
+; CHECK-NEXT:    mov.u64 %rd103, %rd100;
 ; CHECK-NEXT:  $L__BB1_2: // %udiv-do-while
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    shr.u64 %rd73, %rd109, 63;
-; CHECK-NEXT:    shl.b64 %rd74, %rd110, 1;
+; CHECK-NEXT:    shr.u64 %rd73, %rd107, 63;
+; CHECK-NEXT:    shl.b64 %rd74, %rd108, 1;
 ; CHECK-NEXT:    or.b64 %rd75, %rd74, %rd73;
-; CHECK-NEXT:    shl.b64 %rd76, %rd109, 1;
-; CHECK-NEXT:    shr.u64 %rd77, %rd112, 63;
+; CHECK-NEXT:    shl.b64 %rd76, %rd107, 1;
+; CHECK-NEXT:    shr.u64 %rd77, %rd110, 63;
 ; CHECK-NEXT:    or.b64 %rd78, %rd76, %rd77;
-; CHECK-NEXT:    shr.u64 %rd79, %rd111, 63;
-; CHECK-NEXT:    shl.b64 %rd80, %rd112, 1;
+; CHECK-NEXT:    shr.u64 %rd79, %rd109, 63;
+; CHECK-NEXT:    shl.b64 %rd80, %rd110, 1;
 ; CHECK-NEXT:    or.b64 %rd81, %rd80, %rd79;
-; CHECK-NEXT:    shl.b64 %rd82, %rd111, 1;
-; CHECK-NEXT:    or.b64 %rd111, %rd105, %rd82;
-; CHECK-NEXT:    or.b64 %rd112, %rd102, %rd81;
+; CHECK-NEXT:    shl.b64 %rd82, %rd109, 1;
+; CHECK-NEXT:    or.b64 %rd109, %rd103, %rd82;
+; CHECK-NEXT:    or.b64 %rd110, %rd100, %rd81;
 ; CHECK-NEXT:    sub.cc.s64 %rd83, %rd33, %rd78;
 ; CHECK-NEXT:    subc.cc.s64 %rd84, %rd34, %rd75;
 ; CHECK-NEXT:    shr.s64 %rd85, %rd84, 63;
-; CHECK-NEXT:    and.b64 %rd105, %rd85, 1;
+; CHECK-NEXT:    and.b64 %rd103, %rd85, 1;
 ; CHECK-NEXT:    and.b64 %rd86, %rd85, %rd3;
 ; CHECK-NEXT:    and.b64 %rd87, %rd85, %rd4;
-; CHECK-NEXT:    sub.cc.s64 %rd109, %rd78, %rd86;
-; CHECK-NEXT:    subc.cc.s64 %rd110, %rd75, %rd87;
-; CHECK-NEXT:    add.cc.s64 %rd107, %rd107, -1;
-; CHECK-NEXT:    addc.cc.s64 %rd108, %rd108, -1;
-; CHECK-NEXT:    or.b64 %rd88, %rd107, %rd108;
+; CHECK-NEXT:    sub.cc.s64 %rd107, %rd78, %rd86;
+; CHECK-NEXT:    subc.cc.s64 %rd108, %rd75, %rd87;
+; CHECK-NEXT:    add.cc.s64 %rd105, %rd105, -1;
+; CHECK-NEXT:    addc.cc.s64 %rd106, %rd106, -1;
+; CHECK-NEXT:    or.b64 %rd88, %rd105, %rd106;
 ; CHECK-NEXT:    setp.eq.s64 %p16, %rd88, 0;
 ; CHECK-NEXT:    @%p16 bra $L__BB1_4;
 ; CHECK-NEXT:    bra.uni $L__BB1_2;
 ; CHECK-NEXT:  $L__BB1_4: // %udiv-loop-exit
-; CHECK-NEXT:    shr.u64 %rd89, %rd111, 63;
-; CHECK-NEXT:    shl.b64 %rd90, %rd112, 1;
+; CHECK-NEXT:    shr.u64 %rd89, %rd109, 63;
+; CHECK-NEXT:    shl.b64 %rd90, %rd110, 1;
 ; CHECK-NEXT:    or.b64 %rd91, %rd90, %rd89;
-; CHECK-NEXT:    shl.b64 %rd92, %rd111, 1;
-; CHECK-NEXT:    or.b64 %rd113, %rd105, %rd92;
-; CHECK-NEXT:    or.b64 %rd114, %rd102, %rd91;
+; CHECK-NEXT:    shl.b64 %rd92, %rd109, 1;
+; CHECK-NEXT:    or.b64 %rd111, %rd103, %rd92;
+; CHECK-NEXT:    or.b64 %rd112, %rd100, %rd91;
 ; CHECK-NEXT:  $L__BB1_5: // %udiv-end
-; CHECK-NEXT:    mul.hi.u64 %rd93, %rd3, %rd113;
-; CHECK-NEXT:    mul.lo.s64 %rd94, %rd3, %rd114;
-; CHECK-NEXT:    add.s64 %rd95, %rd93, %rd94;
-; CHECK-NEXT:    mul.lo.s64 %rd96, %rd4, %rd113;
-; CHECK-NEXT:    add.s64 %rd97, %rd95, %rd96;
-; CHECK-NEXT:    mul.lo.s64 %rd98, %rd3, %rd113;
-; CHECK-NEXT:    sub.cc.s64 %rd99, %rd41, %rd98;
-; CHECK-NEXT:    subc.cc.s64 %rd100, %rd42, %rd97;
-; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd99, %rd100};
+; CHECK-NEXT:    mul.hi.u64 %rd93, %rd3, %rd111;
+; CHECK-NEXT:    mad.lo.s64 %rd94, %rd3, %rd112, %rd93;
+; CHECK-NEXT:    mad.lo.s64 %rd95, %rd4, %rd111, %rd94;
+; CHECK-NEXT:    mul.lo.s64 %rd96, %rd3, %rd111;
+; CHECK-NEXT:    sub.cc.s64 %rd97, %rd41, %rd96;
+; CHECK-NEXT:    subc.cc.s64 %rd98, %rd42, %rd95;
+; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd97, %rd98};
 ; CHECK-NEXT:    ret;
   %div = urem i128 %lhs, %rhs
   ret i128 %div

>From b841af0dbde56e26eb1105b48a89597c52c140e6 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Tue, 7 Jan 2025 06:12:32 +0000
Subject: [PATCH 4/4] Format

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 362dc4338c72e8..4670f9206b6b3b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4458,7 +4458,8 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
       return SDValue();
 
     SDLoc DL(N);
-    SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
+    SDValue Mul =
+        DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
     SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
     return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
                              ((ZeroOpNum == 1) ? N1 : MAD),



More information about the llvm-commits mailing list