[llvm] GlobalISel needs fdiv 1 / sqrt(x) to rsq combine (PR #78673)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 18 22:04:46 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Nick Anderson (nickleus27)

<details>
<summary>Changes</summary>

Fixes #<!-- -->64743

@<!-- -->arsenm @<!-- -->Pierre-vh Could you guys review and let me know if I am headed in the right direction.
1. Is this the MIR I should be matching against?
   ```
   %sqrt:_(s16) = contract G_FSQRT %x
   %one:_(s16) = G_FCONSTANT half 1.0
   %rsq:_(s16) = contract G_FDIV %one, %sqrt
   ```
2. Will the matcher in `AMDGPUCombine.td` match the above MIR and call the function I made called `matchFDivSqrt`?
3. Any advice on what needs to be done in `AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt` would be appreciated. For example, what is the state of MI that is passed in? Is it a single instruction or is it a chain/tree of instructions?

---
Full diff: https://github.com/llvm/llvm-project/pull/78673.diff


3 Files Affected:

- (modified) llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h (+6) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUCombine.td (+6-1) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp (+56) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index ea6ed322e9b1927..6ffb0842db3e4e6 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -495,6 +495,12 @@ m_GFMul(const LHS &L, const RHS &R) {
   return BinaryOp_match<LHS, RHS, TargetOpcode::G_FMUL, true>(L, R);
 }
 
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, false>
+m_GFDiv(const LHS &L, const RHS &R) {
+  return BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, false>(L, R);
+}
+
 template <typename LHS, typename RHS>
 inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FSUB, false>
 m_GFSub(const LHS &L, const RHS &R) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
index b9411e2052120d8..f26fb12dc1149f0 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
@@ -33,6 +33,11 @@ def rcp_sqrt_to_rsq : GICombineRule<
          [{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>;
 
+def fdiv_1_by_sqrt_to_rsq : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_FDIV):$root,
+         [{ return matchFDivSqrt(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
 def cvt_f32_ubyteN_matchdata : GIDefMatchData<"CvtF32UByteMatchInfo">;
 
@@ -156,7 +161,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
   "AMDGPUPostLegalizerCombinerImpl",
   [all_combines, gfx6gfx7_combines, gfx8_combines,
    uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
-   rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
+   rcp_sqrt_to_rsq, fdiv_1_by_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
index a1c34e92a57f356..9cd8436c188dc47 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
@@ -83,6 +83,9 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
   matchRcpSqrtToRsq(MachineInstr &MI,
                     std::function<void(MachineIRBuilder &)> &MatchInfo) const;
 
+  bool matchFDivSqrt(MachineInstr &MI,
+                     std::function<void(MachineIRBuilder &)> &MatchInfo) const;
+
   // FIXME: Should be able to have 2 separate matchdatas rather than custom
   // struct boilerplate.
   struct CvtF32UByteMatchInfo {
@@ -334,6 +337,59 @@ bool AMDGPUPostLegalizerCombinerImpl::matchRcpSqrtToRsq(
   return false;
 }
 
+// Rewrite a contractable reciprocal square root into amdgcn.rsq:
+//
+//   %sqrt:_(s16) = contract G_FSQRT %x
+//   %one:_(s16) = G_FCONSTANT half 1.0
+//   %dst:_(s16) = contract G_FDIV %one, %sqrt
+// =>
+//   %dst:_(s16) = contract G_INTRINSIC intrinsic(@llvm.amdgcn.rsq), %x
+//
+// Both the G_FDIV and the G_FSQRT must carry the contract flag, because the
+// rewrite turns two separately rounded operations into one instruction.
+//
+// Returns true and fills MatchInfo with the builder for the replacement
+// when MI is the G_FDIV of such a pattern.
+bool AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt(
+    MachineInstr &MI,
+    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
+  // The G_FDIV is the instruction being replaced, so it has to be the root.
+  // The rule can also hand us a G_FSQRT; reject anything else here.
+  if (MI.getOpcode() != TargetOpcode::G_FDIV ||
+      !MI.getFlag(MachineInstr::FmContract))
+    return false;
+
+  // Only s16 (the type in the pattern above) is handled for now; wider
+  // types need their own check that rsq is accurate enough.
+  Register Dst = MI.getOperand(0).getReg();
+  if (MRI.getType(Dst) != LLT::scalar(16))
+    return false;
+
+  // G_FDIV operands are virtual registers, never FP immediates, so the 1.0
+  // is found by looking through the register to its defining G_FCONSTANT.
+  std::optional<FPValueAndVReg> Numerator;
+  if (!mi_match(MI.getOperand(1).getReg(), MRI, m_GFCst(Numerator)) ||
+      !Numerator->Value.isExactlyValue(1.0))
+    return false;
+
+  // The denominator must be a G_FSQRT that also allows contraction.
+  const MachineInstr *SqrtMI = MRI.getVRegDef(MI.getOperand(2).getReg());
+  if (!SqrtMI || SqrtMI->getOpcode() != TargetOpcode::G_FSQRT ||
+      !SqrtMI->getFlag(MachineInstr::FmContract))
+    return false;
+
+  // Capture registers and flags by value so the builder does not depend on
+  // the G_FSQRT instruction when it eventually runs.
+  Register X = SqrtMI->getOperand(1).getReg();
+  uint32_t Flags = MI.getFlags();
+  MatchInfo = [Dst, X, Flags](MachineIRBuilder &B) {
+    B.buildIntrinsic(Intrinsic::amdgcn_rsq, Dst)
+        .addUse(X)
+        .setMIFlags(Flags);
+  };
+  return true;
+}
+
 bool AMDGPUPostLegalizerCombinerImpl::matchCvtF32UByteN(
     MachineInstr &MI, CvtF32UByteMatchInfo &MatchInfo) const {
   Register SrcReg = MI.getOperand(1).getReg();

``````````

</details>


https://github.com/llvm/llvm-project/pull/78673


More information about the llvm-commits mailing list