[llvm] GlobalISel needs fdiv 1 / sqrt(x) to rsq combine (PR #78673)

Nick Anderson via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 18 22:04:17 PST 2024


https://github.com/nickleus27 created https://github.com/llvm/llvm-project/pull/78673

Fixes #64743

@arsenm @Pierre-vh Could you please review this and let me know whether I am headed in the right direction?
1. Is the MIR I am trying to match against
`   %sqrt:_(s16) = contract G_FSQRT %x
    %one:_(s16) = G_FCONSTANT half 1.0
    %rsq:_(s16) = contract G_FDIV %one, %sqrt
` ?
2. Will the matcher in `AMDGPUCombine.td` match the above MIR and call the `matchFDivSqrt` function I added?
3. Any advice on what needs to be done in `AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt` would be appreciated. For example, what is the state of MI that is passed in? Is it a single instruction or is it a chain/tree of instructions?

>From 533b4241326618d2c931eb97f48373705f1d1481 Mon Sep 17 00:00:00 2001
From: Nick Anderson <nickleus27 at gmail.com>
Date: Mon, 15 Jan 2024 02:38:21 -0800
Subject: [PATCH] GlobalISel needs fdiv 1 / sqrt(x) to rsq combine

---
 .../llvm/CodeGen/GlobalISel/MIPatternMatch.h  |  6 ++
 llvm/lib/Target/AMDGPU/AMDGPUCombine.td       |  7 ++-
 .../AMDGPU/AMDGPUPostLegalizerCombiner.cpp    | 56 +++++++++++++++++++
 3 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index ea6ed322e9b192..6ffb0842db3e4e 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -495,6 +495,12 @@ m_GFMul(const LHS &L, const RHS &R) {
   return BinaryOp_match<LHS, RHS, TargetOpcode::G_FMUL, true>(L, R);
 }
 
+template <typename LHS, typename RHS> // G_FDIV is ordered: L / R != R / L.
+inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, false>
+m_GFDiv(const LHS &L, const RHS &R) {
+  return BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, false>(L, R);
+}
+
 template <typename LHS, typename RHS>
 inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FSUB, false>
 m_GFSub(const LHS &L, const RHS &R) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
index b9411e2052120d..f26fb12dc1149f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
@@ -33,6 +33,11 @@ def rcp_sqrt_to_rsq : GICombineRule<
          [{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>;
 
+def fdiv_1_by_sqrt_to_rsq : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_FDIV):$root, // The divide is what gets replaced.
+         [{ return matchFDivSqrt(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
 def cvt_f32_ubyteN_matchdata : GIDefMatchData<"CvtF32UByteMatchInfo">;
 
@@ -156,7 +161,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
   "AMDGPUPostLegalizerCombinerImpl",
   [all_combines, gfx6gfx7_combines, gfx8_combines,
    uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
-   rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
+   rcp_sqrt_to_rsq, fdiv_1_by_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
index a1c34e92a57f35..9cd8436c188dc4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
@@ -83,6 +83,9 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
   matchRcpSqrtToRsq(MachineInstr &MI,
                     std::function<void(MachineIRBuilder &)> &MatchInfo) const;
 
+  bool matchFDivSqrt(MachineInstr &MI,
+                     std::function<void(MachineIRBuilder &)> &MatchInfo) const;
+
   // FIXME: Should be able to have 2 separate matchdatas rather than custom
   // struct boilerplate.
   struct CvtF32UByteMatchInfo {
@@ -334,6 +337,59 @@ bool AMDGPUPostLegalizerCombinerImpl::matchRcpSqrtToRsq(
   return false;
 }
 
+bool AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt(
+    MachineInstr &MI,
+    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
+  // Fold the contractable sequence
+  //   %sqrt = contract G_FSQRT %x
+  //   %one  = G_FCONSTANT 1.0
+  //   %dst  = contract G_FDIV %one, %sqrt
+  // into a single llvm.amdgcn.rsq of %x that defines %dst.
+  // MI is only the root instruction being visited, not a tree; the rest of the
+  // pattern is found by following vreg definitions through MRI. The rewrite
+  // replaces the divide, so a G_FSQRT root can never match here.
+  if (MI.getOpcode() != TargetOpcode::G_FDIV ||
+      !MI.getFlag(MachineInstr::FmContract))
+    return false;
+
+  // G_FDIV operands: 0 = result, 1 = numerator, 2 = denominator.
+  const Register DstReg = MI.getOperand(0).getReg();
+  const Register NumReg = MI.getOperand(1).getReg();
+  const Register DenReg = MI.getOperand(2).getReg();
+
+  // NOTE(review): limited to the s16 case from the motivating MIR; wider types
+  // would need a precision argument before rsq may stand in for a full fdiv.
+  if (MRI.getType(DstReg) != LLT::scalar(16))
+    return false;
+
+  // The numerator is a virtual register, never an FP immediate operand, so
+  // the 1.0 has to be recovered from the constant that defines it.
+  const std::optional<FPValueAndVReg> NumCst =
+      getFConstantVRegValWithLookThrough(NumReg, MRI);
+  if (!NumCst || !NumCst->Value.isExactlyValue(1.0))
+    return false;
+
+  // The denominator must be produced by a contractable G_FSQRT.
+  const MachineInstr *SqrtMI = MRI.getVRegDef(DenReg);
+  if (!SqrtMI || SqrtMI->getOpcode() != TargetOpcode::G_FSQRT ||
+      !SqrtMI->getFlag(MachineInstr::FmContract))
+    return false;
+
+  // Keep a multi-use sqrt as is; an extra rsq would only add an instruction.
+  if (!MRI.hasOneNonDBGUse(DenReg))
+    return false;
+
+  // Capture plain values so the callback never touches MI or SqrtMI.
+  const Register SrcReg = SqrtMI->getOperand(1).getReg();
+  const auto Flags = MI.getFlags();
+  MatchInfo = [DstReg, SrcReg, Flags](MachineIRBuilder &B) {
+    B.buildIntrinsic(Intrinsic::amdgcn_rsq, {DstReg})
+        .addUse(SrcReg)
+        .setMIFlags(Flags);
+  };
+  return true;
+}
+
 bool AMDGPUPostLegalizerCombinerImpl::matchCvtF32UByteN(
     MachineInstr &MI, CvtF32UByteMatchInfo &MatchInfo) const {
   Register SrcReg = MI.getOperand(1).getReg();



More information about the llvm-commits mailing list