[llvm] GlobalISel needs fdiv 1 / sqrt(x) to rsq combine (PR #78673)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 18 22:04:46 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Nick Anderson (nickleus27)
<details>
<summary>Changes</summary>
Fixes #<!-- -->64743
@<!-- -->arsenm @<!-- -->Pierre-vh Could you guys review and let me know if I am headed in the right direction.
1. Is the following the MIR I should be trying to match against:
` %sqrt:_(s16) = contract G_FSQRT %x
%one:_(s16) = G_FCONSTANT half 1.0
%rsq:_(s16) = contract G_FDIV %one, %sqrt
` ?
2. Will the matcher in `AMDGPUCombine.td` match the above MIR and call the `matchFDivSqrt` function I added?
3. Any advice on what needs to be done in `AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt` would be appreciated. For example, what is the state of MI that is passed in? Is it a single instruction or is it a chain/tree of instructions?
---
Full diff: https://github.com/llvm/llvm-project/pull/78673.diff
3 Files Affected:
- (modified) llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h (+6)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUCombine.td (+6-1)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp (+56)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index ea6ed322e9b1927..6ffb0842db3e4e6 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -495,6 +495,12 @@ m_GFMul(const LHS &L, const RHS &R) {
return BinaryOp_match<LHS, RHS, TargetOpcode::G_FMUL, true>(L, R);
}
+template <typename LHS, typename RHS> // Pattern matcher for G_FDIV.
+inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, false>
+m_GFDiv(const LHS &L, const RHS &R) { // L matches the dividend, R the divisor.
+ return BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, false>(L, R); // Not commutable: x / y != y / x.
+}
+
template <typename LHS, typename RHS>
inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FSUB, false>
m_GFSub(const LHS &L, const RHS &R) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
index b9411e2052120d8..f26fb12dc1149f0 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
@@ -33,6 +33,11 @@ def rcp_sqrt_to_rsq : GICombineRule<
[{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>;
+def fdiv_1_by_sqrt_to_rsq : GICombineRule< // Rooted at G_FSQRT or G_FDIV; the C++ predicate is matchFDivSqrt.
+ (defs root:$root, build_fn_matchinfo:$matchinfo),
+ (match (wip_match_opcode G_FSQRT, G_FDIV):$root,
+ [{ return matchFDivSqrt(*${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
def cvt_f32_ubyteN_matchdata : GIDefMatchData<"CvtF32UByteMatchInfo">;
@@ -156,7 +161,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
"AMDGPUPostLegalizerCombinerImpl",
[all_combines, gfx6gfx7_combines, gfx8_combines,
uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
- rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
+ rcp_sqrt_to_rsq, fdiv_1_by_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
let CombineAllMethodName = "tryCombineAllImpl";
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
index a1c34e92a57f356..9cd8436c188dc47 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
@@ -83,6 +83,9 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
matchRcpSqrtToRsq(MachineInstr &MI,
std::function<void(MachineIRBuilder &)> &MatchInfo) const;
+ bool matchFDivSqrt(MachineInstr &MI,
+ std::function<void(MachineIRBuilder &)> &MatchInfo) const;
+
// FIXME: Should be able to have 2 separate matchdatas rather than custom
// struct boilerplate.
struct CvtF32UByteMatchInfo {
@@ -334,6 +337,59 @@ bool AMDGPUPostLegalizerCombinerImpl::matchRcpSqrtToRsq(
return false;
}
+bool AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt(
+ MachineInstr &MI,
+ std::function<void(MachineIRBuilder &)> &MatchInfo) const {
+
+ // WIP: on a match, MatchInfo is set to a callback that builds amdgcn_rsq.
+ // TODO: Can I match fdiv 1.0 / sqrt(x) from here? My apologies, this code is
+ // still a mess; I am trying to figure out what value MI holds at this point.
+
+ auto getSqrtSrc = [=](const MachineInstr &MI) -> MachineInstr * {
+ if (!MI.getFlag(MachineInstr::FmContract))
+ return nullptr;
+ MachineInstr *SqrtSrcMI = nullptr;
+ auto Match =
+ mi_match(MI.getOperand(0).getReg(), MRI, m_GFSqrt(m_MInstr(SqrtSrcMI)));
+ (void)Match; // Match result is discarded; only the bound pointer is returned.
+ return SqrtSrcMI;
+ };
+
+ // Do I need to write a matcher for %one:_(s16) = G_FCONSTANT half 1.0
+ // ??
+
+ auto getFdivSrc = [=](const MachineInstr &MI) -> MachineInstr * {
+ if (!MI.getFlag(MachineInstr::FmContract))
+ return nullptr;
+
+ MachineInstr *FDivSrcMI = nullptr;
+ Register One;
+ auto Match = mi_match(MI.getOperand(0).getReg(), MRI,
+ m_GFDiv(m_Reg(One), m_MInstr(FDivSrcMI)));
+ // Not sure how to check that the FDiv operand has a 1.0 value?
+ if (!MI.getOperand(1).isFPImm()) {
+ return nullptr;
+ }
+ if (!MI.getOperand(1).getFPImm()->isOneValue()) {
+ return nullptr;
+ }
+ (void)Match; // Match result is discarded; only the bound pointer is returned.
+ return FDivSrcMI;
+ };
+
+ MachineInstr *FDivSrcMI = nullptr, *SqrtSrcMI = nullptr;
+ if ((SqrtSrcMI = getSqrtSrc(MI)) && (FDivSrcMI = getFdivSrc(*SqrtSrcMI))) {
+ MatchInfo = [SqrtSrcMI, &MI](MachineIRBuilder &B) {
+ B.buildIntrinsic(Intrinsic::amdgcn_rsq, {MI.getOperand(0)})
+ .addUse(SqrtSrcMI->getOperand(0).getReg())
+ .setMIFlags(MI.getFlags());
+ };
+ return true;
+ }
+
+ return false;
+}
+
bool AMDGPUPostLegalizerCombinerImpl::matchCvtF32UByteN(
MachineInstr &MI, CvtF32UByteMatchInfo &MatchInfo) const {
Register SrcReg = MI.getOperand(1).getReg();
``````````
</details>
https://github.com/llvm/llvm-project/pull/78673
More information about the llvm-commits
mailing list