[llvm] ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering (PR #66924)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 20 08:52:25 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

<details>
<summary>Changes</summary>

Issue #<!-- -->55208 observed that std::rint is vectorized by the SLPVectorizer, but a very similar function, std::lrint, is not. std::lrint corresponds to ISD::LRINT in the SelectionDAG, and std::llrint is a close cousin corresponding to ISD::LLRINT. Currently, neither ISD::LRINT nor ISD::LLRINT has a corresponding vector variant, and the LangRef makes this clear in the documentation of llvm.lrint.* and llvm.llrint.*.

This patch extends the LangRef to include vector variants of llvm.lrint.* and llvm.llrint.*, and lays the necessary groundwork for scalarizing them on all targets. However, this patch would be devoid of motivation unless we show the utility of these new vector variants. Hence, the RISCV target has been chosen to implement a custom lowering to the vfcvt.x.f.v instruction. The patch also includes a CostModel for RISCV, and a trivial follow-up can potentially enable the SLPVectorizer to vectorize std::lrint and std::llrint, fixing #<!-- -->55208.

The patch includes tests, obviously for the RISCV target, but also for the X86, AArch64, and PowerPC targets to justify the addition of the vector variants to the LangRef.

---

Patch is 439.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66924.diff


22 Files Affected:

- (modified) llvm/docs/LangRef.rst (+4-2) 
- (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+6) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+19) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+9-5) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+17) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3-2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+15-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+2) 
- (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+6-6) 
- (modified) llvm/lib/IR/Verifier.cpp (+1-3) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+57-40) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+27) 
- (modified) llvm/test/Analysis/CostModel/RISCV/fround.ll (+130) 
- (added) llvm/test/CodeGen/AArch64/vector-llrint.ll (+621) 
- (added) llvm/test/CodeGen/AArch64/vector-lrint.ll (+622) 
- (added) llvm/test/CodeGen/PowerPC/vector-llrint.ll (+4847) 
- (added) llvm/test/CodeGen/PowerPC/vector-lrint.ll (+4858) 
- (added) llvm/test/CodeGen/RISCV/rvv/llrint-sdnode.ll (+108) 
- (added) llvm/test/CodeGen/RISCV/rvv/lrint-sdnode.ll (+155) 
- (added) llvm/test/CodeGen/X86/vector-llrint.ll (+290) 
- (added) llvm/test/CodeGen/X86/vector-lrint.ll (+429) 


``````````diff
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index d3b0cb0cc50cec2..fdff5da1451cf3d 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -15802,7 +15802,8 @@ Syntax:
 """""""
 
 This is an overloaded intrinsic. You can use ``llvm.lrint`` on any
-floating-point type. Not all targets support all types however.
+floating-point type or vector of floating-point type. Not all targets
+support all types however.
 
 ::
 
@@ -15846,7 +15847,8 @@ Syntax:
 """""""
 
 This is an overloaded intrinsic. You can use ``llvm.llrint`` on any
-floating-point type. Not all targets support all types however.
+floating-point type or vector of floating-point type. Not all targets
+support all types however.
 
 ::
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index c11d558a73e9d09..eeca2f4060ed200 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1843,6 +1843,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     case Intrinsic::rint:
       ISD = ISD::FRINT;
       break;
+    case Intrinsic::lrint:
+      ISD = ISD::LRINT;
+      break;
+    case Intrinsic::llrint:
+      ISD = ISD::LLRINT;
+      break;
     case Intrinsic::round:
       ISD = ISD::FROUND;
       break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 5088d59147492f1..4f3a9ac79965308 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -504,6 +504,7 @@ namespace {
     SDValue visitUINT_TO_FP(SDNode *N);
     SDValue visitFP_TO_SINT(SDNode *N);
     SDValue visitFP_TO_UINT(SDNode *N);
+    SDValue visitXRINT(SDNode *N);
     SDValue visitFP_ROUND(SDNode *N);
     SDValue visitFP_EXTEND(SDNode *N);
     SDValue visitFNEG(SDNode *N);
@@ -2005,6 +2006,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
+  case ISD::LRINT:
+  case ISD::LLRINT:
+    return visitXRINT(N);
   case ISD::FP_ROUND:           return visitFP_ROUND(N);
   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
   case ISD::FNEG:               return visitFNEG(N);
@@ -17206,6 +17210,21 @@ SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
   return FoldIntToFPToInt(N, DAG);
 }
 
+SDValue DAGCombiner::visitXRINT(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+
+  // fold (lrint|llrint undef) -> undef
+  if (N0.isUndef())
+    return DAG.getUNDEF(VT);
+
+  // fold (lrint|llrint c1fp) -> c1
+  if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
+    return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0);
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 95f181217803515..2c8d23fbb674907 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2209,10 +2209,15 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::BITCAST:    R = PromoteFloatOp_BITCAST(N, OpNo); break;
     case ISD::FCOPYSIGN:  R = PromoteFloatOp_FCOPYSIGN(N, OpNo); break;
     case ISD::FP_TO_SINT:
-    case ISD::FP_TO_UINT: R = PromoteFloatOp_FP_TO_XINT(N, OpNo); break;
+    case ISD::FP_TO_UINT:
+    case ISD::LRINT:
+    case ISD::LLRINT:
+      R = PromoteFloatOp_UnaryOp(N, OpNo);
+      break;
     case ISD::FP_TO_SINT_SAT:
     case ISD::FP_TO_UINT_SAT:
-                          R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
+      R = PromoteFloatOp_BinOp(N, OpNo);
+      break;
     case ISD::FP_EXTEND:  R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
@@ -2251,13 +2256,12 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo) {
 }
 
 // Convert the promoted float value to the desired integer type
-SDValue DAGTypeLegalizer::PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo) {
+SDValue DAGTypeLegalizer::PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo) {
   SDValue Op = GetPromotedFloat(N->getOperand(0));
   return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), Op);
 }
 
-SDValue DAGTypeLegalizer::PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N,
-                                                        unsigned OpNo) {
+SDValue DAGTypeLegalizer::PromoteFloatOp_BinOp(SDNode *N, unsigned OpNo) {
   SDValue Op = GetPromotedFloat(N->getOperand(0));
   return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), Op,
                      N->getOperand(1));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index fc9e3ff3734989d..70e85523b9540b5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -162,6 +162,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_UINT_SAT:
                          Res = PromoteIntRes_FP_TO_XINT_SAT(N); break;
 
+  case ISD::LRINT:
+  case ISD::LLRINT:
+    Res = PromoteIntRes_XRINT(N);
+    break;
+
   case ISD::FP_TO_BF16:
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
@@ -539,6 +544,18 @@ SDValue DAGTypeLegalizer::PromoteIntRes_BSWAP(SDNode *N) {
                      Mask, EVL);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
+  SDValue Op = GetPromotedInteger(N->getOperand(0));
+  EVT OVT = N->getValueType(0);
+  EVT NVT = Op.getValueType();
+  SDLoc dl(N);
+
+  unsigned DiffBits = NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits();
+  SDValue ShAmt = DAG.getShiftAmountConstant(DiffBits, NVT, dl);
+  return DAG.getNode(ISD::SRL, dl, NVT,
+                     DAG.getNode(N->getOpcode(), dl, NVT, Op), ShAmt);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_BITREVERSE(SDNode *N) {
   SDValue Op = GetPromotedInteger(N->getOperand(0));
   EVT OVT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index c802604a3470e13..4f7fb9f12289052 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -324,6 +324,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_CTTZ(SDNode *N);
   SDValue PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
+  SDValue PromoteIntRes_XRINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
   SDValue PromoteIntRes_FREEZE(SDNode *N);
@@ -711,8 +712,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
-  SDValue PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo);
-  SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_BinOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_SELECT_CC(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_SETCC(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index dec81475f3a88fc..2bc9c5f0844541d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -396,6 +396,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::FCEIL:
   case ISD::FTRUNC:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FNEARBYINT:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1bb6fbbf064b931..2c5343c3c4b160e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -101,6 +101,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
   case ISD::FSIN:
@@ -681,6 +683,8 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::FP_TO_UINT:
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
+  case ISD::LRINT:
+  case ISD::LLRINT:
     Res = ScalarizeVecOp_UnaryOp(N);
     break;
   case ISD::STRICT_SINT_TO_FP:
@@ -1097,6 +1101,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_FP_TO_UINT:
   case ISD::FRINT:
   case ISD::VP_FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FROUND:
   case ISD::VP_FROUND:
   case ISD::FROUNDEVEN:
@@ -2974,6 +2980,8 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::ZERO_EXTEND:
   case ISD::ANY_EXTEND:
   case ISD::FTRUNC:
+  case ISD::LRINT:
+  case ISD::LLRINT:
     Res = SplitVecOp_UnaryOp(N);
     break;
   case ISD::FLDEXP:
@@ -4209,6 +4217,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FLOG2:
   case ISD::FNEARBYINT:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
   case ISD::FSIN:
@@ -5958,7 +5968,11 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::STRICT_FSETCCS:     Res = WidenVecOp_STRICT_FSETCC(N); break;
   case ISD::VSELECT:            Res = WidenVecOp_VSELECT(N); break;
   case ISD::FLDEXP:
-  case ISD::FCOPYSIGN:          Res = WidenVecOp_UnrollVectorOp(N); break;
+  case ISD::FCOPYSIGN:
+  case ISD::LRINT:
+  case ISD::LLRINT:
+    Res = WidenVecOp_UnrollVectorOp(N);
+    break;
   case ISD::IS_FPCLASS:         Res = WidenVecOp_IS_FPCLASS(N); break;
 
   case ISD::ANY_EXTEND:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7fcd1f4f898911a..a695c1d1f874b76 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5111,6 +5111,8 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FNEARBYINT:
   case ISD::FLDEXP: {
     if (SNaN)
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 3e4bff5ddce1264..99eadf4bb9d578b 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -873,13 +873,13 @@ void TargetLoweringBase::initActions() {
 
     // These operations default to expand for vector types.
     if (VT.isVector())
-      setOperationAction({ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG,
-                          ISD::ANY_EXTEND_VECTOR_INREG,
-                          ISD::SIGN_EXTEND_VECTOR_INREG,
-                          ISD::ZERO_EXTEND_VECTOR_INREG, ISD::SPLAT_VECTOR},
-                         VT, Expand);
+      setOperationAction(
+          {ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG, ISD::ANY_EXTEND_VECTOR_INREG,
+           ISD::SIGN_EXTEND_VECTOR_INREG, ISD::ZERO_EXTEND_VECTOR_INREG,
+           ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT},
+          VT, Expand);
 
-    // Constrained floating-point operations default to expand.
+      // Constrained floating-point operations default to expand.
 #define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)               \
     setOperationAction(ISD::STRICT_##DAGN, VT, Expand);
 #include "llvm/IR/ConstrainedOps.def"
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 5dac691e17cd6ef..7459c7d84874588 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5659,9 +5659,7 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     break;
   }
   case Intrinsic::lround:
-  case Intrinsic::llround:
-  case Intrinsic::lrint:
-  case Intrinsic::llrint: {
+  case Intrinsic::llround: {
     Type *ValTy = Call.getArgOperand(0)->getType();
     Type *ResultTy = Call.getType();
     Check(!ValTy->isVectorTy() && !ResultTy->isVectorTy(),
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f1cea6c6756f4fc..e10002b82e06edf 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -732,7 +732,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          VT, Custom);
       setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
                          Custom);
-
+      setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
       setOperationAction(
           {ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT, Legal);
 
@@ -2687,13 +2687,14 @@ static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
   return RISCVFPRndMode::Invalid;
 }
 
-// Expand vector FTRUNC, FCEIL, FFLOOR, FROUND, VP_FCEIL, VP_FFLOOR, VP_FROUND
-// VP_FROUNDEVEN, VP_FROUNDTOZERO, VP_FRINT and VP_FNEARBYINT by converting to
-// the integer domain and back. Taking care to avoid converting values that are
+// Expand vector FTRUNC, FCEIL, FFLOOR, FROUND, FRINT, VP_FCEIL, VP_FFLOOR,
+// VP_FROUND, VP_FROUNDEVEN, VP_FROUNDTOZERO, VP_FRINT and VP_FNEARBYINT by
+// converting to the integer domain and back. Expand LRINT and LLRINT by
+// converting to integer domain. Take care to avoid converting values that are
 // nan or already correct.
 static SDValue
-lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
-                                      const RISCVSubtarget &Subtarget) {
+lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND_XRINT(SDValue Op, SelectionDAG &DAG,
+                                            const RISCVSubtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
   assert(VT.isVector() && "Unexpected type");
 
@@ -2721,28 +2722,34 @@ lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
   // Freeze the source since we are increasing the number of uses.
   Src = DAG.getFreeze(Src);
 
-  // We do the conversion on the absolute value and fix the sign at the end.
-  SDValue Abs = DAG.getNode(RISCVISD::FABS_VL, DL, ContainerVT, Src, Mask, VL);
-
-  // Determine the largest integer that can be represented exactly. This and
-  // values larger than it don't have any fractional bits so don't need to
-  // be converted.
-  const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(ContainerVT);
-  unsigned Precision = APFloat::semanticsPrecision(FltSem);
-  APFloat MaxVal = APFloat(FltSem);
-  MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
-                          /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
-  SDValue MaxValNode =
-      DAG.getConstantFP(MaxVal, DL, ContainerVT.getVectorElementType());
-  SDValue MaxValSplat = DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
-                                    DAG.getUNDEF(ContainerVT), MaxValNode, VL);
-
-  // If abs(Src) was larger than MaxVal or nan, keep it.
-  MVT SetccVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
-  Mask =
-      DAG.getNode(RISCVISD::SETCC_VL, DL, SetccVT,
-                  {Abs, MaxValSplat, DAG.getCondCode(ISD::SETOLT),
-                   Mask, Mask, VL});
+  // Skip all this logic for LRINT and LLRINT, since ContainerVT isn't a FP
+  // type.
+  if (Op.getOpcode() != ISD::LRINT && Op.getOpcode() != ISD::LLRINT) {
+    // We do the conversion on the absolute value and fix the sign at the end.
+    SDValue Abs =
+        DAG.getNode(RISCVISD::FABS_VL, DL, ContainerVT, Src, Mask, VL);
+
+    // Determine the largest integer that can be represented exactly. This and
+    // values larger than it don't have any fractional bits so don't need to
+    // be converted.
+    const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(ContainerVT);
+    unsigned Precision = APFloat::semanticsPrecision(FltSem);
+    APFloat MaxVal = APFloat(FltSem);
+    MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
+                            /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
+    SDValue MaxValNode =
+        DAG.getConstantFP(MaxVal, DL, ContainerVT.getVectorElementType());
+    SDValue MaxValSplat =
+        DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
+                    DAG.getUNDEF(ContainerVT), MaxValNode, VL);
+
+    // If abs(Src) was larger than MaxVal or nan, keep it.
+    MVT SetccVT =
+        MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+    Mask = DAG.getNode(
+        RISCVISD::SETCC_VL, DL, SetccVT,
+        {Abs, MaxValSplat, DAG.getCondCode(ISD::SETOLT), Mask, Mask, VL});
+  }
 
   // Truncate to integer and convert back to FP.
   MVT IntVT = ContainerVT.changeVectorElementTypeToInteger();
@@ -2773,6 +2780,8 @@ lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
     break;
   case ISD::FRINT:
   case ISD::VP_FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
     Truncated = DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, IntVT, Src, Mask, VL);
     break;
   case ISD::FNEARBYINT:
@@ -2782,14 +2791,17 @@ lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
     break;
   }
 
-  // VFROUND_NOEXCEPT_VL includes SINT_TO_FP_VL.
-  if (Truncated.getOpcode() != RISCVISD::VFROUND_NOEXCEPT_VL)
-    Truncated = DAG.getNode(RISCVISD::SINT_TO_FP_VL, DL, ContainerVT, Truncated,
-                            Mask, VL);
+  // For LRINT and LLRINT, we're done: don't convert back to float.
+  if (Op.getOpcode() != ISD::LRINT && Op.getOpcode() != ISD::LLRINT) {
+    // VFROUND_NOEXCEPT_VL includes SINT_TO_FP_VL.
+    if (Truncated.getOpcode() != RISCVISD::VFROUND_NOEXCEPT_VL)
+      Truncated = DAG.getNode(RISCVISD::SINT_TO_FP_VL, DL, ContainerVT,
+                              Truncated, Mask, VL);
 
-  // Restore the original sign so that -0.0 is preserved.
-  Truncated = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Truncated,
-                          Src, Src, Mask, VL);
+    // Restore the original sign so that -0.0 is preserved.
+    Truncated = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Truncated,
+                            Src, Src, Mask, VL);
+  }
 
   if (!VT.isFixedLengthVector())
     return Truncated;
@@ -2902,11 +2914,14 @@ lowerVectorStrictFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
 }
 
 static SDValue
-lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
-                                const RISCVSubtarget &Subtarget) {
+lowerFTRUNC_FCEIL_FFLOOR_FROUND_XRINT(SDValue Op, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
+  if (Op.getOpcode() == ISD::LRINT || Op.getOpcode() == ISD::LLRINT)
+    assert(VT.isVector() &&
+           "Expected vector in custom lowering of LRINT and LLRINT");
   if (VT.isVector())
-    return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
+    return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND_XRINT(Op, DAG, Subtarget);
 
   if (DAG.shouldOptForSize())
     return SDValue();
@@ -5921,9 +5936,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::FFLOOR:
   case ISD::FNEARBYINT:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
-    return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
+    return lowerFTRUNC_FCEIL_FFLOOR_FROUND_XRINT(Op, DAG, Subtarget);
   case ISD::VECREDUCE_ADD:
   case ISD::VECREDUCE_UMAX:
   case ISD::VECREDUCE_SMAX:
@@ -6287,7 +6304,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
         (Subtarget.hasVInstructionsF16Minimal() &&
          !Subtarget.hasVIns...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/66924


More information about the llvm-commits mailing list