[llvm] [NVPTX] Improve folding to mad with immediate 1 (PR #93628)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Wed May 29 15:33:06 PDT 2024


================
@@ -5614,17 +5614,101 @@ static SDValue TryMULWIDECombine(SDNode *N,
   return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
 }
 
+/// Returns true if \p Operand is an integer constant equal to one.
+static bool isConstOne(SDValue Operand) {
+  const auto *Const = dyn_cast<ConstantSDNode>(Operand);
+  // isOne() compares the full APInt value. Unlike getZExtValue() == 1, it
+  // cannot assert on a constant whose value does not fit in 64 bits.
+  return Const && Const->isOne();
+}
+
+/// Matches `Add` against (add Y, 1) or (add 1, Y).
+/// Returns Y when `Add` is an ISD::ADD with a constant-one operand in either
+/// position; otherwise returns an empty SDValue.
+static SDValue matchMADConstOnePattern(SDValue Add) {
+  if (Add->getOpcode() != ISD::ADD)
+    return SDValue();
+
+  if (isConstOne(Add->getOperand(0)))
+    return Add->getOperand(1);
+
+  if (isConstOne(Add->getOperand(1)))
+    return Add->getOperand(0);
+
+  return SDValue();
+}
+
+/// If `Add` has the form (add Y, 1) or (add 1, Y), builds
+/// NVPTXISD::IMAD(X, Y, X) of type `VT` at `DL`; otherwise returns an empty
+/// SDValue. NOTE(review): this presumes IMAD(a, b, c) computes a * b + c, so
+/// that X * (Y + 1) folds to X * Y + X -- confirm against the IMAD definition.
+static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue Multiplicand = matchMADConstOnePattern(Add);
+  if (!Multiplicand)
+    return SDValue();
+  return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Multiplicand, X);
+}
+
+static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
+                                        SDLoc DL,
+                                        TargetLowering::DAGCombinerInfo &DCI) {
+  if (Select->getOpcode() != ISD::SELECT)
+    return SDValue();
+
+  SDValue Cond = Select->getOperand(0);
+
+  unsigned ConstOpNo = 1;
+  auto *Const = dyn_cast<ConstantSDNode>(Select->getOperand(ConstOpNo));
+  if (!Const || Const->getZExtValue() != 1) {
----------------
Artem-B wrote:

It looks like we could extract the common pattern into a helper function:

```
bool isConstOne(Operand) {
  const auto *Const = dyn_cast<ConstantSDNode>(Operand);
  return Const && Const->getZExtValue() == 1;
}
```

and then use it in the handful of instances of this pattern throughout the code.

https://github.com/llvm/llvm-project/pull/93628


More information about the llvm-commits mailing list