[llvm] [GlobalISel][AArch64] Add G_FPTOSI_SAT/G_FPTOUI_SAT (PR #96297)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 21 04:05:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-llvm-globalisel

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

This is an implementation of the saturating fp to int conversions for GlobalISel. On AArch64 the conversion instructions work this way, producing saturating results. LegalizerHelper::lowerFPTOINT_SAT is ported from SDAG.

AArch64 has a lot of existing tests for fptosi_sat, covering a wide range of types. I have tried to make most of them work all at once, but a few fall back due to other missing features such as f128 handling for min/max.

---

Patch is 544.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96297.diff


19 Files Affected:

- (modified) llvm/docs/GlobalISel/GenericOpcode.rst (+5) 
- (modified) llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h (+2) 
- (modified) llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h (+1) 
- (modified) llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h (+10) 
- (modified) llvm/include/llvm/Support/TargetOpcodes.def (+6) 
- (modified) llvm/include/llvm/Target/GenericOpcodes.td (+12) 
- (modified) llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td (+2) 
- (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+8) 
- (modified) llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp (+156-9) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+6) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+50) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp (+4) 
- (modified) llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir (+7) 
- (modified) llvm/test/CodeGen/AArch64/fptosi-sat-scalar.ll (+650-324) 
- (modified) llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll (+3796-2177) 
- (modified) llvm/test/CodeGen/AArch64/fptoui-sat-scalar.ll (+477-235) 
- (modified) llvm/test/CodeGen/AArch64/fptoui-sat-vector.ll (+2967-1702) 
- (modified) llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-cxx.td (+19-19) 
- (modified) llvm/test/TableGen/GlobalISelEmitter.td (+1-1) 


``````````diff
diff --git a/llvm/docs/GlobalISel/GenericOpcode.rst b/llvm/docs/GlobalISel/GenericOpcode.rst
index 42f56348885b4..54a78745fa8f9 100644
--- a/llvm/docs/GlobalISel/GenericOpcode.rst
+++ b/llvm/docs/GlobalISel/GenericOpcode.rst
@@ -473,6 +473,11 @@ G_FPTOSI, G_FPTOUI, G_SITOFP, G_UITOFP
 
 Convert between integer and floating point.
 
+G_FPTOSI_SAT, G_FPTOUI_SAT
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Saturating conversion from floating point to integer.
+
 G_FABS
 ^^^^^^
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 995031f7c00be..2d0773627aa4c 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -818,6 +818,8 @@ class GCastOp : public GenericMachineInstr {
     case TargetOpcode::G_FPEXT:
     case TargetOpcode::G_FPTOSI:
     case TargetOpcode::G_FPTOUI:
+    case TargetOpcode::G_FPTOSI_SAT:
+    case TargetOpcode::G_FPTOUI_SAT:
     case TargetOpcode::G_FPTRUNC:
     case TargetOpcode::G_INTTOPTR:
     case TargetOpcode::G_PTRTOINT:
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
index 284f434fbb9b0..0d2572b982e34 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
@@ -393,6 +393,7 @@ class LegalizerHelper {
   LegalizeResult lowerSITOFP(MachineInstr &MI);
   LegalizeResult lowerFPTOUI(MachineInstr &MI);
   LegalizeResult lowerFPTOSI(MachineInstr &MI);
+  LegalizeResult lowerFPTOINT_SAT(MachineInstr &MI);
 
   LegalizeResult lowerFPTRUNC_F64_TO_F16(MachineInstr &MI);
   LegalizeResult lowerFPTRUNC(MachineInstr &MI);
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
index 92e05ee858a75..7429ca42640bb 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
@@ -1965,6 +1965,16 @@ class MachineIRBuilder {
     return buildInstr(TargetOpcode::G_FPTOSI, {Dst}, {Src0});
   }
 
+  /// Build and insert \p Res = G_FPTOUI_SAT \p Src0
+  MachineInstrBuilder buildFPTOUI_SAT(const DstOp &Dst, const SrcOp &Src0) {
+    return buildInstr(TargetOpcode::G_FPTOUI_SAT, {Dst}, {Src0});
+  }
+
+  /// Build and insert \p Res = G_FPTOSI_SAT \p Src0
+  MachineInstrBuilder buildFPTOSI_SAT(const DstOp &Dst, const SrcOp &Src0) {
+    return buildInstr(TargetOpcode::G_FPTOSI_SAT, {Dst}, {Src0});
+  }
+
   /// Build and insert \p Dst = G_INTRINSIC_ROUNDEVEN \p Src0, \p Src1
   MachineInstrBuilder
   buildIntrinsicRoundeven(const DstOp &Dst, const SrcOp &Src0,
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index df4b264af72a8..a82b3555ae924 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -665,6 +665,12 @@ HANDLE_TARGET_OPCODE(G_SITOFP)
 /// Generic unsigned-int to float conversion
 HANDLE_TARGET_OPCODE(G_UITOFP)
 
+/// Generic saturating float to signed-int conversion
+HANDLE_TARGET_OPCODE(G_FPTOSI_SAT)
+
+/// Generic saturating float to unsigned-int conversion
+HANDLE_TARGET_OPCODE(G_FPTOUI_SAT)
+
 /// Generic FP absolute value.
 HANDLE_TARGET_OPCODE(G_FABS)
 
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 4abffe6476c85..eae1ff825a0d4 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -749,6 +749,18 @@ def G_UITOFP : GenericInstruction {
   let hasSideEffects = false;
 }
 
+def G_FPTOSI_SAT : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type1:$src);
+  let hasSideEffects = false;
+}
+
+def G_FPTOUI_SAT : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type1:$src);
+  let hasSideEffects = false;
+}
+
 def G_FABS : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type0:$src);
diff --git a/llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td b/llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
index 560d3b434d07d..be47f50a803e0 100644
--- a/llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
+++ b/llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
@@ -98,6 +98,8 @@ def : GINodeEquiv<G_FPTOSI, fp_to_sint>;
 def : GINodeEquiv<G_FPTOUI, fp_to_uint>;
 def : GINodeEquiv<G_SITOFP, sint_to_fp>;
 def : GINodeEquiv<G_UITOFP, uint_to_fp>;
+def : GINodeEquiv<G_FPTOSI_SAT, fp_to_sint_sat>;
+def : GINodeEquiv<G_FPTOUI_SAT, fp_to_uint_sat>;
 def : GINodeEquiv<G_FADD, fadd>;
 def : GINodeEquiv<G_FSUB, fsub>;
 def : GINodeEquiv<G_FMA, fma>;
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index c06b35a98e434..c80a90557a569 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2330,6 +2330,14 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
                            MachineInstr::copyFlagsFromInstruction(CI));
     return true;
   }
+  case Intrinsic::fptosi_sat:
+    MIRBuilder.buildFPTOSI_SAT(getOrCreateVReg(CI),
+                               getOrCreateVReg(*CI.getArgOperand(0)));
+    return true;
+  case Intrinsic::fptoui_sat:
+    MIRBuilder.buildFPTOUI_SAT(getOrCreateVReg(CI),
+                               getOrCreateVReg(*CI.getArgOperand(0)));
+    return true;
   case Intrinsic::memcpy_inline:
     return translateMemFunc(CI, MIRBuilder, TargetOpcode::G_MEMCPY_INLINE);
   case Intrinsic::memcpy:
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 430fcae731689..d9c243327e92a 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -1715,6 +1715,8 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
   }
   case TargetOpcode::G_FPTOUI:
   case TargetOpcode::G_FPTOSI:
+  case TargetOpcode::G_FPTOUI_SAT:
+  case TargetOpcode::G_FPTOSI_SAT:
     return narrowScalarFPTOI(MI, TypeIdx, NarrowTy);
   case TargetOpcode::G_FPEXT:
     if (TypeIdx != 0)
@@ -2358,7 +2360,8 @@ LegalizerHelper::widenScalarMulo(MachineInstr &MI, unsigned TypeIdx,
 
 LegalizerHelper::LegalizeResult
 LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
-  switch (MI.getOpcode()) {
+  unsigned Opcode = MI.getOpcode();
+  switch (Opcode) {
   default:
     return UnableToLegalize;
   case TargetOpcode::G_ATOMICRMW_XCHG:
@@ -2442,13 +2445,13 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
     Register SrcReg = MI.getOperand(1).getReg();
 
     // First extend the input.
-    unsigned ExtOpc = MI.getOpcode() == TargetOpcode::G_CTTZ ||
-                              MI.getOpcode() == TargetOpcode::G_CTTZ_ZERO_UNDEF
+    unsigned ExtOpc = Opcode == TargetOpcode::G_CTTZ ||
+                              Opcode == TargetOpcode::G_CTTZ_ZERO_UNDEF
                           ? TargetOpcode::G_ANYEXT
                           : TargetOpcode::G_ZEXT;
     auto MIBSrc = MIRBuilder.buildInstr(ExtOpc, {WideTy}, {SrcReg});
     LLT CurTy = MRI.getType(SrcReg);
-    unsigned NewOpc = MI.getOpcode();
+    unsigned NewOpc = Opcode;
     if (NewOpc == TargetOpcode::G_CTTZ) {
       // The count is the same in the larger type except if the original
       // value was zero.  This can be handled by setting the bit just off
@@ -2464,8 +2467,8 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
     // Perform the operation at the larger size.
     auto MIBNewOp = MIRBuilder.buildInstr(NewOpc, {WideTy}, {MIBSrc});
     // This is already the correct result for CTPOP and CTTZs
-    if (MI.getOpcode() == TargetOpcode::G_CTLZ ||
-        MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
+    if (Opcode == TargetOpcode::G_CTLZ ||
+        Opcode == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
       // The correct result is NewOp - (Difference in widety and current ty).
       unsigned SizeDiff = WideTy.getSizeInBits() - CurTy.getSizeInBits();
       MIBNewOp = MIRBuilder.buildSub(
@@ -2614,8 +2617,8 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
     Observer.changingInstr(MI);
 
     if (TypeIdx == 0) {
-      unsigned CvtOp = MI.getOpcode() == TargetOpcode::G_ASHR ?
-        TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
+      unsigned CvtOp = Opcode == TargetOpcode::G_ASHR ? TargetOpcode::G_SEXT
+                                                      : TargetOpcode::G_ZEXT;
 
       widenScalarSrc(MI, WideTy, 1, CvtOp);
       widenScalarDst(MI, WideTy);
@@ -2697,6 +2700,47 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
     else
       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
 
+    Observer.changedInstr(MI);
+    return Legalized;
+  case TargetOpcode::G_FPTOSI_SAT:
+  case TargetOpcode::G_FPTOUI_SAT:
+    Observer.changingInstr(MI);
+
+    if (TypeIdx == 0) {
+      Register OldDst = MI.getOperand(0).getReg();
+      LLT Ty = MRI.getType(OldDst);
+      Register ExtReg = MRI.createGenericVirtualRegister(WideTy);
+      Register NewDst;
+      MI.getOperand(0).setReg(ExtReg);
+      uint64_t ShortBits = Ty.getScalarSizeInBits();
+      uint64_t WideBits = WideTy.getScalarSizeInBits();
+      MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
+      if (Opcode == TargetOpcode::G_FPTOSI_SAT) {
+        // z = i16 fptosi_sat(a)
+        // ->
+        // x = i32 fptosi_sat(a)
+        // y = smin(x, 32767)
+        // z = smax(y, -32768)
+        auto MaxVal = MIRBuilder.buildConstant(
+            WideTy, APInt::getSignedMaxValue(ShortBits).sext(WideBits));
+        auto MinVal = MIRBuilder.buildConstant(
+            WideTy, APInt::getSignedMinValue(ShortBits).sext(WideBits));
+        Register MidReg =
+            MIRBuilder.buildSMin(WideTy, ExtReg, MaxVal).getReg(0);
+        NewDst = MIRBuilder.buildSMax(WideTy, MidReg, MinVal).getReg(0);
+      } else {
+        // z = i16 fptoui_sat(a)
+        // ->
+        // x = i32 fptoui_sat(a)
+        // y = umin(x, 65535)
+        auto MaxVal = MIRBuilder.buildConstant(
+            WideTy, APInt::getAllOnes(ShortBits).zext(WideBits));
+        NewDst = MIRBuilder.buildUMin(WideTy, ExtReg, MaxVal).getReg(0);
+      }
+      MIRBuilder.buildTrunc(OldDst, NewDst);
+    } else
+      widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_FPEXT);
+
     Observer.changedInstr(MI);
     return Legalized;
   case TargetOpcode::G_LOAD:
@@ -2921,7 +2965,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
   case TargetOpcode::G_FLDEXP:
   case TargetOpcode::G_STRICT_FLDEXP: {
     if (TypeIdx == 0) {
-      if (MI.getOpcode() == TargetOpcode::G_STRICT_FLDEXP)
+      if (Opcode == TargetOpcode::G_STRICT_FLDEXP)
         return UnableToLegalize;
 
       Observer.changingInstr(MI);
@@ -3914,6 +3958,9 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
     return lowerFPTOUI(MI);
   case G_FPTOSI:
     return lowerFPTOSI(MI);
+  case G_FPTOUI_SAT:
+  case G_FPTOSI_SAT:
+    return lowerFPTOINT_SAT(MI);
   case G_FPTRUNC:
     return lowerFPTRUNC(MI);
   case G_FPOWI:
@@ -4713,6 +4760,8 @@ LegalizerHelper::fewerElementsVector(MachineInstr &MI, unsigned TypeIdx,
   case G_UITOFP:
   case G_FPTOSI:
   case G_FPTOUI:
+  case G_FPTOSI_SAT:
+  case G_FPTOUI_SAT:
   case G_INTTOPTR:
   case G_PTRTOINT:
   case G_ADDRSPACE_CAST:
@@ -5504,6 +5553,8 @@ LegalizerHelper::moreElementsVector(MachineInstr &MI, unsigned TypeIdx,
   case TargetOpcode::G_FPEXT:
   case TargetOpcode::G_FPTOSI:
   case TargetOpcode::G_FPTOUI:
+  case TargetOpcode::G_FPTOSI_SAT:
+  case TargetOpcode::G_FPTOUI_SAT:
   case TargetOpcode::G_SITOFP:
   case TargetOpcode::G_UITOFP: {
     Observer.changingInstr(MI);
@@ -7012,6 +7063,102 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPTOSI(MachineInstr &MI) {
   return Legalized;
 }
 
+LegalizerHelper::LegalizeResult
+LegalizerHelper::lowerFPTOINT_SAT(MachineInstr &MI) {
+  auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
+
+  bool IsSigned = MI.getOpcode() == TargetOpcode::G_FPTOSI_SAT;
+  unsigned SatWidth = DstTy.getScalarSizeInBits();
+
+  // Determine minimum and maximum integer values and their corresponding
+  // floating-point values.
+  APInt MinInt, MaxInt;
+  if (IsSigned) {
+    MinInt = APInt::getSignedMinValue(SatWidth);
+    MaxInt = APInt::getSignedMaxValue(SatWidth);
+  } else {
+    MinInt = APInt::getMinValue(SatWidth);
+    MaxInt = APInt::getMaxValue(SatWidth);
+  }
+
+  const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());
+  APFloat MinFloat(Semantics);
+  APFloat MaxFloat(Semantics);
+
+  APFloat::opStatus MinStatus =
+      MinFloat.convertFromAPInt(MinInt, IsSigned, APFloat::rmTowardZero);
+  APFloat::opStatus MaxStatus =
+      MaxFloat.convertFromAPInt(MaxInt, IsSigned, APFloat::rmTowardZero);
+  bool AreExactFloatBounds = !(MinStatus & APFloat::opStatus::opInexact) &&
+                             !(MaxStatus & APFloat::opStatus::opInexact);
+
+  // If the integer bounds are exactly representable as floats and min/max are
+  // legal, emit a min+max+fptoi sequence. Otherwise we have to use a sequence
+  // of comparisons and selects.
+  bool MinMaxLegal = LI.isLegal({TargetOpcode::G_FMINNUM, SrcTy}) &&
+                     LI.isLegal({TargetOpcode::G_FMAXNUM, SrcTy});
+  if (AreExactFloatBounds && MinMaxLegal) {
+    // Clamp Src by MinFloat from below. If Src is NaN the result is MinFloat.
+    auto Max = MIRBuilder.buildFMaxNum(
+        SrcTy, Src, MIRBuilder.buildFConstant(SrcTy, MinFloat));
+    // Clamp by MaxFloat from above. NaN cannot occur.
+    auto Min = MIRBuilder.buildFMinNum(
+        SrcTy, Max, MIRBuilder.buildFConstant(SrcTy, MaxFloat));
+    // Convert clamped value to integer. In the unsigned case we're done,
+    // because we mapped NaN to MinFloat, which will cast to zero.
+    if (!IsSigned) {
+      MIRBuilder.buildFPTOUI(Dst, Min);
+      MI.eraseFromParent();
+      return Legalized;
+    }
+
+    // Otherwise, select 0 if Src is NaN.
+    auto FpToInt = MIRBuilder.buildFPTOSI(DstTy, Min);
+    auto IsZero = MIRBuilder.buildFCmp(CmpInst::FCMP_UNO,
+                                       DstTy.changeElementSize(1), Src, Src);
+    MIRBuilder.buildSelect(Dst, IsZero, MIRBuilder.buildConstant(DstTy, 0),
+                           FpToInt);
+    MI.eraseFromParent();
+    return Legalized;
+  }
+
+  // Result of direct conversion. The assumption here is that the operation is
+  // non-trapping and it's fine to apply it to an out-of-range value if we
+  // select it away later.
+  auto FpToInt = IsSigned ? MIRBuilder.buildFPTOSI(DstTy, Src)
+                          : MIRBuilder.buildFPTOUI(DstTy, Src);
+
+  // If Src ULT MinFloat, select MinInt. In particular, this also selects
+  // MinInt if Src is NaN.
+  auto ULT =
+      MIRBuilder.buildFCmp(CmpInst::FCMP_ULT, SrcTy.changeElementSize(1), Src,
+                           MIRBuilder.buildFConstant(SrcTy, MinFloat));
+  auto Max = MIRBuilder.buildSelect(
+      DstTy, ULT, MIRBuilder.buildConstant(DstTy, MinInt), FpToInt);
+  // If Src OGT MaxFloat, select MaxInt.
+  auto OGT =
+      MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, SrcTy.changeElementSize(1), Src,
+                           MIRBuilder.buildFConstant(SrcTy, MaxFloat));
+
+  // In the unsigned case we are done, because we mapped NaN to MinInt, which
+  // is already zero.
+  if (!IsSigned) {
+    MIRBuilder.buildSelect(Dst, OGT, MIRBuilder.buildConstant(DstTy, MaxInt),
+                           Max);
+    MI.eraseFromParent();
+    return Legalized;
+  }
+
+  // Otherwise, select 0 if Src is NaN.
+  auto Min = MIRBuilder.buildSelect(
+      DstTy, OGT, MIRBuilder.buildConstant(DstTy, MaxInt), Max);
+  auto IsZero = MIRBuilder.buildFCmp(CmpInst::FCMP_UNO,
+                                     DstTy.changeElementSize(1), Src, Src);
+  MIRBuilder.buildSelect(Dst, IsZero, MIRBuilder.buildConstant(DstTy, 0), Min);
+  MI.eraseFromParent();
+  return Legalized;
+}
+
 // f64 -> f16 conversion using round-to-nearest-even rounding mode.
 LegalizerHelper::LegalizeResult
 LegalizerHelper::lowerFPTRUNC_F64_TO_F16(MachineInstr &MI) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index d32007ec45fb6..9966c0f7c0450 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -2130,6 +2130,12 @@ bool AArch64InstructionSelector::preISelLower(MachineInstr &I) {
     }
     return false;
   }
+  case TargetOpcode::G_FPTOSI_SAT:
+    I.setDesc(TII.get(TargetOpcode::G_FPTOSI));
+    return true;
+  case TargetOpcode::G_FPTOUI_SAT:
+    I.setDesc(TII.get(TargetOpcode::G_FPTOUI));
+    return true;
   default:
     return false;
   }
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index fef0b722efe45..f0444416f84f7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -708,6 +708,56 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .libcallFor(
           {{s32, s128}, {s64, s128}, {s128, s128}, {s128, s32}, {s128, s64}});
 
+  getActionDefinitionsBuilder({G_FPTOSI_SAT, G_FPTOUI_SAT})
+      .legalFor({{s32, s32},
+                 {s64, s32},
+                 {s32, s64},
+                 {s64, s64},
+                 {v2s64, v2s64},
+                 {v4s32, v4s32},
+                 {v2s32, v2s32}})
+      .legalIf([=](const LegalityQuery &Query) {
+        return HasFP16 &&
+               (Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
+                Query.Types[1] == v8s16) &&
+               (Query.Types[0] == s32 || Query.Types[0] == s64 ||
+                Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
+      })
+      // Handle types larger than i64 by scalarizing/lowering.
+      .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
+      .scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
+      // The range of a fp16 value fits into an i17, so we can lower the width
+      // to i64.
+      .narrowScalarIf(
+          [=](const LegalityQuery &Query) {
+            return Query.Types[1] == s16 &&
+                   Query.Types[0].getSizeInBits() > 64;
+          },
+          changeTo(0, s64))
+      .lowerIf(::any(scalarWiderThan(0, 64), scalarWiderThan(1, 64)), 0)
+      .moreElementsToNextPow2(0)
+      .widenScalarToNextPow2(0, /*MinSize=*/32)
+      .minScalar(0, s32)
+      .widenScalarOrEltToNextPow2OrMinSize(1, /*MinSize=*/HasFP16 ? 16 : 32)
+      .widenScalarIf(
+          [=](const LegalityQuery &Query) {
+            unsigned ITySize = Query.Types[0].getScalarSizeInBits();
+            return (ITySize == 16 || ITySize == 32 || ITySize == 64) &&
+                   ITySize > Query.Types[1].getScalarSizeInBits();
+          },
+          LegalizeMutations::changeElementSizeTo(1, 0))
+      .widenScalarIf(
+          [=](const LegalityQuery &Query) {
+            unsigned FTySize = Query.Types[1].getScalarSizeInBits();
+            return (FTySize == 16 || FTySize == 32 || FTySize == 64) &&
+                   Query.Types[0].getScalarSizeInBits() < FTySize;
+          },
+          LegalizeMutations::changeElementSizeTo(0, 1))
+      .widenScalarOrEltToNextPow2(0)
+      .clampNumElements(0, v4s16, v8s16)
+      .clampNumElements(0, v2s32, v4s32)
+      .clampMaxNumElements(0, s64, 2);
+
   getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
       .legalForCartesianProduct({s32, s64, v2s64, v4s32, v2s32})
       .legalIf([=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 5616d063f70bc..be5465a299b56 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -559,6 +559,8 @@ bool AArc...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/96297


More information about the llvm-commits mailing list