[llvm] [CodeGen] [AMDGPU] Attempt DAGCombine for fmul with select to ldexp (PR #111109)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 7 06:21:54 PDT 2024


================
@@ -14476,6 +14477,65 @@ SDValue SITargetLowering::performFDivCombine(SDNode *N,
   return SDValue();
 }
 
+SDValue SITargetLowering::performFMulCombine(SDNode *N,
+                                             DAGCombinerInfo &DCI) const {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+
+  SDLoc SL(N);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // ldexp(x, zext(i1 y)) -> fmul x, (select y, 2.0, 1.0)
+  // ldexp(x, sext(i1 y)) -> fmul x, (select y, 0.5, 1.0)
+  //
+  // The ldexp folding mentioned above works fine for
+  // f16/f32, but for f64 it creates an f64 select, which
+  // is costlier to materialize than an f64 ldexp, so
+  // here we undo the transform for f64 as follows:
+  //
+  // fmul x, (select y, 2.0, 1.0)   -> ldexp(  x, zext(i1 y) )
+  // fmul x, (select y, -2.0, -1.0) -> ldexp( (fneg x), zext(i1 y) )
+  // fmul x, (select y, 0.5, 1.0)   -> ldexp(  x, sext(i1 y) )
+  // fmul x, (select y, -0.5, -1.0) -> ldexp( (fneg x), sext(i1 y) )
+  if (VT == MVT::f64) {
+    if (RHS.hasOneUse() && RHS.getOpcode() == ISD::SELECT) {
+      const ConstantFPSDNode *TrueNode =
+          isConstOrConstSplatFP(RHS.getOperand(1));
+      const ConstantFPSDNode *FalseNode =
+          isConstOrConstSplatFP(RHS.getOperand(2));
+      bool isNeg;
+
+      if (!TrueNode || !FalseNode)
+        return SDValue();
+
+      if (TrueNode->isNegative() && FalseNode->isNegative())
+        isNeg = true;
+      else if (!TrueNode->isNegative() && !FalseNode->isNegative())
+        isNeg = false;
+      else
+        return SDValue();
+
+      unsigned ExtOp;
+      if (FalseNode->isExactlyValue(1.0) || FalseNode->isExactlyValue(-1.0)) {
+        if (TrueNode->isExactlyValue(2.0) || TrueNode->isExactlyValue(-2.0))
+          ExtOp = ISD::ZERO_EXTEND;
+        else if (TrueNode->isExactlyValue(0.5) ||
+                 TrueNode->isExactlyValue(-0.5))
+          ExtOp = ISD::SIGN_EXTEND;
+        else
+          return SDValue();
+
+        SDValue ExtNode = DAG.getNode(ExtOp, SL, MVT::i32, RHS.getOperand(0));
----------------
arsenm wrote:

You can generalize this to a select of integers, and handle any pair of powers of 2.

https://github.com/llvm/llvm-project/pull/111109


More information about the llvm-commits mailing list