[llvm] [DAG] Constant fold FMAD (PR #69324)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 17 05:12:47 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Pierre van Houtryve (Pierre-vh)

<details>
<summary>Changes</summary>

This has very little effect on codegen in practice, but it is a nice-to-have, I think.

See #<!-- -->68315

---
Full diff: https://github.com/llvm/llvm-project/pull/69324.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+19) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+7-2) 
- (modified) llvm/test/CodeGen/AMDGPU/udiv.ll (+8-27) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 20ad4c766a1a3fc..eac0a14d8303fa4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -495,6 +495,7 @@ namespace {
     SDValue visitFSUB(SDNode *N);
     SDValue visitFMUL(SDNode *N);
     template <class MatchContextClass> SDValue visitFMA(SDNode *N);
+    SDValue visitFMAD(SDNode *N);
     SDValue visitFDIV(SDNode *N);
     SDValue visitFREM(SDNode *N);
     SDValue visitFSQRT(SDNode *N);
@@ -2000,6 +2001,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::FSUB:               return visitFSUB(N);
   case ISD::FMUL:               return visitFMUL(N);
   case ISD::FMA:                return visitFMA<EmptyMatchContext>(N);
+  case ISD::FMAD:
+    return visitFMAD(N);
   case ISD::FDIV:               return visitFDIV(N);
   case ISD::FREM:               return visitFREM(N);
   case ISD::FSQRT:              return visitFSQRT(N);
@@ -16752,6 +16755,22 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitFMAD(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // Constant fold FMAD.
+  if (isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1) &&
+      isa<ConstantFPSDNode>(N2)) {
+    return DAG.getNode(ISD::FMAD, DL, VT, N0, N1, N2);
+  }
+
+  return SDValue();
+}
+
 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
 // reciprocal.
 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3f06d0bd4eaa1d5..b028c483718107e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7069,7 +7069,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
          "Operand is DELETED_NODE!");
   // Perform various simplifications.
   switch (Opcode) {
-  case ISD::FMA: {
+  case ISD::FMA:
+  case ISD::FMAD: {
     assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
     assert(N1.getValueType() == VT && N2.getValueType() == VT &&
            N3.getValueType() == VT && "FMA types must match!");
@@ -7080,7 +7081,11 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
       APFloat  V1 = N1CFP->getValueAPF();
       const APFloat &V2 = N2CFP->getValueAPF();
       const APFloat &V3 = N3CFP->getValueAPF();
-      V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
+      if (Opcode == ISD::FMAD) {
+        V1.multiply(V2, APFloat::rmNearestTiesToEven);
+        V1.add(V3, APFloat::rmNearestTiesToEven);
+      } else
+        V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
       return getConstantFP(V1, DL, VT);
     }
     break;
diff --git a/llvm/test/CodeGen/AMDGPU/udiv.ll b/llvm/test/CodeGen/AMDGPU/udiv.ll
index 012b3f976734dec..e554f912ff64886 100644
--- a/llvm/test/CodeGen/AMDGPU/udiv.ll
+++ b/llvm/test/CodeGen/AMDGPU/udiv.ll
@@ -2619,39 +2619,20 @@ define i64 @v_test_udiv64_mulhi_fold(i64 %arg) {
 ; VI-LABEL: v_test_udiv64_mulhi_fold:
 ; VI:       ; %bb.0:
 ; VI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; VI-NEXT:    v_mov_b32_e32 v2, 0x4f800000
-; VI-NEXT:    v_madak_f32 v2, 0, v2, 0x47c35000
-; VI-NEXT:    v_rcp_f32_e32 v2, v2
+; VI-NEXT:    v_mov_b32_e32 v4, 0xa7c5
+; VI-NEXT:    v_mul_u32_u24_e32 v3, 0x500, v4
+; VI-NEXT:    v_mul_hi_u32_u24_e32 v2, 0x500, v4
+; VI-NEXT:    v_add_u32_e32 v3, vcc, 0x4237, v3
+; VI-NEXT:    v_addc_u32_e32 v5, vcc, 0, v2, vcc
+; VI-NEXT:    v_add_u32_e32 v6, vcc, 0xa9000000, v3
 ; VI-NEXT:    s_mov_b32 s6, 0xfffe7960
-; VI-NEXT:    v_mul_f32_e32 v2, 0x5f7ffffc, v2
-; VI-NEXT:    v_mul_f32_e32 v3, 0x2f800000, v2
-; VI-NEXT:    v_trunc_f32_e32 v3, v3
-; VI-NEXT:    v_madmk_f32 v2, v3, 0xcf800000, v2
-; VI-NEXT:    v_cvt_u32_f32_e32 v6, v2
-; VI-NEXT:    v_cvt_u32_f32_e32 v7, v3
-; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v6, s6, 0
-; VI-NEXT:    v_mul_lo_u32 v4, v7, s6
-; VI-NEXT:    v_sub_u32_e32 v3, vcc, v3, v6
-; VI-NEXT:    v_add_u32_e32 v8, vcc, v3, v4
-; VI-NEXT:    v_mul_hi_u32 v5, v6, v2
-; VI-NEXT:    v_mad_u64_u32 v[3:4], s[4:5], v6, v8, 0
-; VI-NEXT:    v_add_u32_e32 v9, vcc, v5, v3
-; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v7, v2, 0
-; VI-NEXT:    v_addc_u32_e32 v10, vcc, 0, v4, vcc
-; VI-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], v7, v8, 0
-; VI-NEXT:    v_add_u32_e32 v2, vcc, v9, v2
-; VI-NEXT:    v_addc_u32_e32 v2, vcc, v10, v3, vcc
-; VI-NEXT:    v_addc_u32_e32 v3, vcc, 0, v5, vcc
-; VI-NEXT:    v_add_u32_e32 v2, vcc, v2, v4
-; VI-NEXT:    v_addc_u32_e32 v3, vcc, 0, v3, vcc
-; VI-NEXT:    v_add_u32_e32 v6, vcc, v6, v2
-; VI-NEXT:    v_addc_u32_e32 v7, vcc, v7, v3, vcc
 ; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v6, s6, 0
+; VI-NEXT:    v_addc_u32_e32 v7, vcc, v5, v4, vcc
 ; VI-NEXT:    v_mul_lo_u32 v4, v7, s6
 ; VI-NEXT:    v_sub_u32_e32 v3, vcc, v3, v6
+; VI-NEXT:    v_mul_hi_u32 v8, v6, v2
 ; VI-NEXT:    v_add_u32_e32 v5, vcc, v4, v3
 ; VI-NEXT:    v_mad_u64_u32 v[3:4], s[4:5], v6, v5, 0
-; VI-NEXT:    v_mul_hi_u32 v8, v6, v2
 ; VI-NEXT:    v_add_u32_e32 v8, vcc, v8, v3
 ; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v7, v2, 0
 ; VI-NEXT:    v_addc_u32_e32 v9, vcc, 0, v4, vcc

``````````

</details>


https://github.com/llvm/llvm-project/pull/69324


More information about the llvm-commits mailing list