[llvm] [NVPTX][SelectionDAG] Add IMAD combine rules + infra to disable default SelectionDAG rules for testing (PR #121724)

Justin Fargnoli via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 6 13:56:54 PST 2025


================
@@ -5164,6 +5164,53 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
 }
 
+static SDValue
+PerformIMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, SDValue N2,
+                               TargetLowering::DAGCombinerInfo &DCI) {
+  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); // null if N1 not constant
+  ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2); // null if N2 not constant
+  EVT VT = N0->getValueType(0); // every replacement node is built with N0's type
+  SDLoc DL(N);
+  SDNodeFlags Flags = N->getFlags(); // local copy; reused on the new nodes
+
+  // mad x 1 y => add x y   (operands are N0 * N1 + N2; multiplier is one)
+  if (N1C && N1C->isOne())
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, N0, N2, Flags);
+
+  // mad x -1 y => sub y x   (multiplier is all-ones, i.e. -1)
+  if (N1C && N1C->isAllOnes()) {
+    Flags.setNoUnsignedWrap(false); // nuw is not carried over to the SUB
+    return DCI.DAG.getNode(ISD::SUB, DL, VT, N2, N0, Flags);
+  }
+
+  // mad x 0 y => y   (product is zero; the addend is returned unchanged)
+  if (N1C && N1C->isZero())
+    return N2;
+
+  // mad x y 0 => mul x y   (addend is zero)
+  if (N2C && N2C->isZero())
+    return DCI.DAG.getNode(ISD::MUL, DL, VT, N0, N1, Flags);
+
+  // mad c0 c1 x => add x (c0*c1); taken only if the MUL fold returns a value
+  if (SDValue C =
+          DCI.DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}, Flags))
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, N2, C, Flags);
+
+  return {}; // no rule matched: hand back an empty SDValue
+}
+
+static SDValue PerformIMADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+  SDValue res = PerformIMADCombineWithOperands(N, N0, N1, N2, DCI);
----------------
justinfargnoli wrote:

```suggestion
  SDValue Res = PerformIMADCombineWithOperands(N, N0, N1, N2, DCI);
```

https://github.com/llvm/llvm-project/pull/121724


More information about the llvm-commits mailing list