[llvm] [CodeGen] [AMDGPU] Attempt DAGCombine for fmul with select to ldexp (PR #111109)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 15 00:48:00 PDT 2024


================
@@ -14476,6 +14477,60 @@ SDValue SITargetLowering::performFDivCombine(SDNode *N,
   return SDValue();
 }
 
+SDValue SITargetLowering::performFMulCombine(SDNode *N,
+                                             DAGCombinerInfo &DCI) const {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+  EVT i32VT = VT.changeElementType(MVT::i32);
+
+  SDLoc SL(N);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // ldexp(x, zext(i1 y)) -> fmul x, (select y, 2.0, 1.0)
+  // ldexp(x, sext(i1 y)) -> fmul x, (select y, 0.5, 1.0)
+  //
+  // The ldexp folding mentioned above works fine for bf16/f32,
+  // but for f64 it creates an f64 select, which is costlier
+  // to materialize than an f64 ldexp, so here we undo the
+  // transform for the f64 datatype.
+  // Also, in the case of f16, it's cheaper to materialize an inline
+  // 32-bit constant (via the ldexp use) than to use fmul.
+  //
+  // Given : A = 2^a  &  B = 2^b ; where a and b are integers.
+  // fmul x, (select y, A, B)     -> ldexp( x, (select i32 y, a, b) )
+  // fmul x, (select y, -A, -B)   -> ldexp( (fneg x), (select i32 y, a, b) )
+  // Note : It takes care of generic scenario which covers undoing
+  // of special case(zext/sext) as mentioned.
+  if (VT.getScalarType() == MVT::f64 || VT.getScalarType() == MVT::f16) {
+    if (RHS.hasOneUse() && RHS.getOpcode() == ISD::SELECT) {
+      const ConstantFPSDNode *TrueNode =
+          isConstOrConstSplatFP(RHS.getOperand(1));
+      const ConstantFPSDNode *FalseNode =
+          isConstOrConstSplatFP(RHS.getOperand(2));
+
+      if (!TrueNode || !FalseNode)
+        return SDValue();
+
+      if (TrueNode->isNegative() != FalseNode->isNegative())
+        return SDValue();
+      LHS = TrueNode->isNegative() ? DAG.getNode(ISD::FNEG, SL, VT, LHS) : LHS;
+
+      int TrueNodeExpVal = TrueNode->getValueAPF().getExactLog2Abs();
+      int FalseNodeExpVal = FalseNode->getValueAPF().getExactLog2Abs();
+      if (TrueNodeExpVal != INT_MIN && FalseNodeExpVal != INT_MIN) {
+        SDValue SelectNode =
+            DAG.getNode(ISD::SELECT, SL, i32VT, RHS.getOperand(0),
+                        DAG.getConstant(TrueNodeExpVal, SL, i32VT),
+                        DAG.getConstant(FalseNodeExpVal, SL, i32VT));
+        return DAG.getNode(ISD::FLDEXP, SL, VT, LHS, SelectNode);
----------------
arsenm wrote:

This lost the flags: the new FLDEXP node is created without passing along the flags of the original fmul node.

https://github.com/llvm/llvm-project/pull/111109


More information about the llvm-commits mailing list