[llvm] 1076b31 - [GlobalISel] Combine select + fcmp to fminnum/fmaxnum/fminimum/fmaximum

Jessica Paquette via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 16 13:36:07 PDT 2022


Author: Jessica Paquette
Date: 2022-09-16T13:35:46-07:00
New Revision: 1076b31da8da8854e1dfca8f1a9a03de9fd4f8f5

URL: https://github.com/llvm/llvm-project/commit/1076b31da8da8854e1dfca8f1a9a03de9fd4f8f5
DIFF: https://github.com/llvm/llvm-project/commit/1076b31da8da8854e1dfca8f1a9a03de9fd4f8f5.diff

LOG: [GlobalISel] Combine select + fcmp to fminnum/fmaxnum/fminimum/fmaximum

This is a partial port of the code used by the SelectionDAGBuilder to
translate selects.

In particular, see matchSelectPattern in ValueTracking.cpp. This is a
GISel-equivalent of the portion which handles fminnum/fmaxnum/fminimum/fmaximum.

I tried to set it up so it'd be easy to add the non-FP cases. Those are simpler.
On the AArch64-end, it seems like the FP cases are more important for perf
right now, so I bit the bullet and went at the more complicated problem. :)

I elected to do this as a post-legalize combine rather than in the
IRTranslator because:

- Deciding which fmax/fmin to use can depend on legalization rules
- Philosophically-speaking (TM), putting it in a combine just feels cleaner
- Being able to enable/disable the combine is handy

Another option would be to use the ValueTracking code in the IRTranslator and
match what SelectionDAGBuilder::visitSelect does. I think that may be somewhat
annoying since we'd need to write lowerings back into the selects in the
legalizer. I'm not strongly opposed to the approach.

We'd also want to be careful with vector selects once that's implemented,
which explicitly check if a vector select is legal on the target. That'd
probably need a hook.

From what I can tell, doing this as a combine is probably a cleaner option
long-term.

Differential Revision: https://reviews.llvm.org/D116702

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
    llvm/include/llvm/Target/GlobalISel/Combine.td
    llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
    llvm/lib/Target/AArch64/AArch64Combine.td

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index cc9b85755619f..d5e323523024e 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/Register.h"
 #include "llvm/Support/LowLevelTypeImpl.h"
+#include "llvm/IR/InstrTypes.h"
 #include <functional>
 
 namespace llvm {
@@ -755,6 +756,10 @@ class CombinerHelper {
   /// Transform G_ADD(G_SUB(y, x), x) to y.
   bool matchAddSubSameReg(MachineInstr &MI, Register &Src);
 
+  /// \returns true if it is possible to simplify a select instruction \p MI
+  /// to a min/max instruction of some sort.
+  bool matchSimplifySelectToMinMax(MachineInstr &MI, BuildFnTy &MatchInfo);
+
 private:
   /// Given a non-indexed load or store instruction \p MI, find an offset that
   /// can be usefully and legally folded into it as a post-indexing operation.
@@ -800,6 +805,49 @@ class CombinerHelper {
   /// a re-association of its operands would break an existing legal addressing
   /// mode that the address computation currently represents.
   bool reassociationCanBreakAddressingModePattern(MachineInstr &PtrAdd);
+
+  /// Behavior when a floating point min/max is given one NaN and one
+  /// non-NaN as input.
+  enum class SelectPatternNaNBehaviour {
+    NOT_APPLICABLE = 0, /// NaN behavior not applicable.
+    RETURNS_NAN,        /// Given one NaN input, returns the NaN.
+    RETURNS_OTHER,      /// Given one NaN input, returns the non-NaN.
+    RETURNS_ANY /// Given one NaN input, can return either (or both operands are
+                /// known non-NaN.)
+  };
+
+  /// \returns which of \p LHS and \p RHS would be the result of a non-equality
+  /// floating point comparison where one of \p LHS and \p RHS may be NaN.
+  ///
+  /// If both \p LHS and \p RHS may be NaN, returns
+  /// SelectPatternNaNBehaviour::NOT_APPLICABLE.
+  SelectPatternNaNBehaviour
+  computeRetValAgainstNaN(Register LHS, Register RHS,
+                          bool IsOrderedComparison) const;
+
+  /// Determines the floating point min/max opcode which should be used for
+  /// a G_SELECT fed by a G_FCMP with predicate \p Pred.
+  ///
+  /// \returns 0 if this G_SELECT should not be combined to a floating point
+  /// min or max. If it should be combined, returns one of
+  ///
+  /// * G_FMAXNUM
+  /// * G_FMAXIMUM
+  /// * G_FMINNUM
+  /// * G_FMINIMUM
+  ///
+  /// Helper function for matchFPSelectToMinMax.
+  unsigned getFPMinMaxOpcForSelect(CmpInst::Predicate Pred, LLT DstTy,
+                                   SelectPatternNaNBehaviour VsNaNRetVal) const;
+
+  /// Handle floating point cases for matchSimplifySelectToMinMax.
+  ///
+  /// E.g.
+  ///
+  /// select (fcmp uge x, 1.0) x, 1.0 -> fmax x, 1.0
+  /// select (fcmp uge x, 1.0) 1.0, x -> fminnm x, 1.0
+  bool matchFPSelectToMinMax(Register Dst, Register Cond, Register TrueVal,
+                             Register FalseVal, BuildFnTy &MatchInfo);
 };
 } // namespace llvm
 

diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index cc726bdb70ee9..5ec829f29869e 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -929,6 +929,12 @@ def add_sub_reg: GICombineRule <
   (apply [{ return Helper.replaceSingleDefInstWithReg(*${root},
                                                       ${matchinfo}); }])>;
 
+def select_to_minmax: GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$info),
+  (match (wip_match_opcode G_SELECT):$root,
+         [{ return Helper.matchSimplifySelectToMinMax(*${root}, ${info}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
+
 // FIXME: These should use the custom predicate feature once it lands.
 def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
                                      undef_to_negative_one,

diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 88ecb41ed9b27..070a70aeb9285 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -5818,6 +5818,138 @@ bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) {
   return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
 }
 
+unsigned CombinerHelper::getFPMinMaxOpcForSelect(
+    CmpInst::Predicate Pred, LLT DstTy,
+    SelectPatternNaNBehaviour VsNaNRetVal) const {
+  assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
+         "Expected a NaN behaviour?");
+  // Choose an opcode based off of legality or the behaviour when one of the
+  // LHS/RHS may be NaN.
+  switch (Pred) {
+  default:
+    return 0;
+  case CmpInst::FCMP_UGT:
+  case CmpInst::FCMP_UGE:
+  case CmpInst::FCMP_OGT:
+  case CmpInst::FCMP_OGE:
+    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
+      return TargetOpcode::G_FMAXNUM;
+    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
+      return TargetOpcode::G_FMAXIMUM;
+    if (isLegal({TargetOpcode::G_FMAXNUM, {DstTy}}))
+      return TargetOpcode::G_FMAXNUM;
+    if (isLegal({TargetOpcode::G_FMAXIMUM, {DstTy}}))
+      return TargetOpcode::G_FMAXIMUM;
+    return 0;
+  case CmpInst::FCMP_ULT:
+  case CmpInst::FCMP_ULE:
+  case CmpInst::FCMP_OLT:
+  case CmpInst::FCMP_OLE:
+    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
+      return TargetOpcode::G_FMINNUM;
+    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
+      return TargetOpcode::G_FMINIMUM;
+    if (isLegal({TargetOpcode::G_FMINNUM, {DstTy}}))
+      return TargetOpcode::G_FMINNUM;
+    if (!isLegal({TargetOpcode::G_FMINIMUM, {DstTy}}))
+      return 0;
+    return TargetOpcode::G_FMINIMUM;
+  }
+}
+
+CombinerHelper::SelectPatternNaNBehaviour
+CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
+                                        bool IsOrderedComparison) const {
+  bool LHSSafe = isKnownNeverNaN(LHS, MRI);
+  bool RHSSafe = isKnownNeverNaN(RHS, MRI);
+  // Completely unsafe.
+  if (!LHSSafe && !RHSSafe)
+    return SelectPatternNaNBehaviour::NOT_APPLICABLE;
+  if (LHSSafe && RHSSafe)
+    return SelectPatternNaNBehaviour::RETURNS_ANY;
+  // An ordered comparison will return false when given a NaN, so it
+  // returns the RHS.
+  if (IsOrderedComparison)
+    return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
+                   : SelectPatternNaNBehaviour::RETURNS_OTHER;
+  // An unordered comparison will return true when given a NaN, so it
+  // returns the LHS.
+  return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
+                 : SelectPatternNaNBehaviour::RETURNS_NAN;
+}
+
+bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
+                                           Register TrueVal, Register FalseVal,
+                                           BuildFnTy &MatchInfo) {
+  // Match: select (fcmp cond x, y) x, y
+  //        select (fcmp cond x, y) y, x
+  // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition.
+  LLT DstTy = MRI.getType(Dst);
+  // Bail out early on pointers, since we'll never want to fold to a min/max.
+  // TODO: Handle vectors.
+  if (DstTy.isPointer() || DstTy.isVector())
+    return false;
+  // Match a floating point compare with a less-than/greater-than predicate.
+  // TODO: Allow multiple users of the compare if they are all selects.
+  CmpInst::Predicate Pred;
+  Register CmpLHS, CmpRHS;
+  if (!mi_match(Cond, MRI,
+                m_OneNonDBGUse(
+                    m_GFCmp(m_Pred(Pred), m_Reg(CmpLHS), m_Reg(CmpRHS)))) ||
+      CmpInst::isEquality(Pred))
+    return false;
+  SelectPatternNaNBehaviour ResWithKnownNaNInfo =
+      computeRetValAgainstNaN(CmpLHS, CmpRHS, CmpInst::isOrdered(Pred));
+  if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
+    return false;
+  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
+    std::swap(CmpLHS, CmpRHS);
+    Pred = CmpInst::getSwappedPredicate(Pred);
+    if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
+      ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
+    else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
+      ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
+  }
+  if (TrueVal != CmpLHS || FalseVal != CmpRHS)
+    return false;
+  // Decide what type of max/min this should be based off of the predicate.
+  unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, ResWithKnownNaNInfo);
+  if (!Opc || !isLegal({Opc, {DstTy}}))
+    return false;
+  // Comparisons between signed zero and zero may have different results...
+  // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
+  if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
+    // We don't know if a comparison between two 0s will give us a consistent
+    // result. Be conservative and only proceed if at least one side is
+    // non-zero.
+    auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpLHS, MRI);
+    if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
+      KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpRHS, MRI);
+      if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
+        return false;
+    }
+  }
+  MatchInfo = [=](MachineIRBuilder &B) {
+    B.buildInstr(Opc, {Dst}, {CmpLHS, CmpRHS});
+  };
+  return true;
+}
+
+bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
+                                                 BuildFnTy &MatchInfo) {
+  // TODO: Handle integer cases.
+  assert(MI.getOpcode() == TargetOpcode::G_SELECT);
+  // Condition may be fed by a truncated compare.
+  Register Cond = MI.getOperand(1).getReg();
+  Register MaybeTrunc;
+  if (mi_match(Cond, MRI, m_OneNonDBGUse(m_GTrunc(m_Reg(MaybeTrunc)))))
+    Cond = MaybeTrunc;
+  Register Dst = MI.getOperand(0).getReg();
+  Register TrueVal = MI.getOperand(2).getReg();
+  Register FalseVal = MI.getOperand(3).getReg();
+  return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
+}
+
 bool CombinerHelper::tryCombine(MachineInstr &MI) {
   if (tryCombineCopy(MI))
     return true;

diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index 18c111255e538..fff69211d25a1 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -228,6 +228,7 @@ def AArch64PostLegalizerCombinerHelper
                         select_combines, fold_merge_to_zext,
                         constant_fold, identity_combines,
                         ptr_add_immed_chain, overlapping_and,
-                        split_store_zero_128, undef_combines]> {
+                        split_store_zero_128, undef_combines,
+                        select_to_minmax]> {
   let DisableRuleOption = "aarch64postlegalizercombiner-disable-rule";
 }


        


More information about the llvm-commits mailing list