[llvm] [AMDGPU] Add support for llvm.lround and llvm.lrint intrinsics lowering. (PR #96817)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 26 13:37:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Sumanth Gundapaneni (sgundapa)

<details>
<summary>Changes</summary>

This patch adds support for lowering the llvm.lround and llvm.lrint intrinsics for AMDGPU in both SelectionDAGISel and GlobalISel. To support vector floating-point inputs for llvm.lround, this patch extends the target-independent APIs and provides support for scalarizing, in line with llvm.lrint.

---

Patch is 132.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96817.diff


19 Files Affected:

- (modified) llvm/docs/LangRef.rst (+2-1) 
- (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+6) 
- (modified) llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp (+4) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+7-3) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+3-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+14-2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+2) 
- (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+2-1) 
- (modified) llvm/lib/IR/Verifier.cpp (+21-8) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+28-4) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h (+2) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (+45) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h (+4) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/lrint.ll (+493) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/lround.ll (+814) 
- (added) llvm/test/CodeGen/AMDGPU/lrint.ll (+467) 
- (added) llvm/test/CodeGen/AMDGPU/lround.ll (+807) 


``````````diff
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index edb362c617565..b8ed6906fc3d2 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -16606,7 +16606,8 @@ Syntax:
 """""""
 
 This is an overloaded intrinsic. You can use ``llvm.lround`` on any
-floating-point type. Not all targets support all types however.
+floating-point type or vector of floating-point type. Not all targets
+support all types however.
 
 ::
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 9f8d3ded9b3c1..9e18e7851e780 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2040,6 +2040,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     case Intrinsic::llrint:
       ISD = ISD::LLRINT;
       break;
+    case Intrinsic::lround:
+      ISD = ISD::LROUND;
+      break;
+    case Intrinsic::llround:
+      ISD = ISD::LLROUND;
+      break;
     case Intrinsic::round:
       ISD = ISD::FROUND;
       break;
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 430fcae731689..b181f98fea3d5 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -4656,6 +4656,10 @@ LegalizerHelper::fewerElementsVector(MachineInstr &MI, unsigned TypeIdx,
   case G_FCEIL:
   case G_FFLOOR:
   case G_FRINT:
+  case G_INTRINSIC_LRINT:
+  case G_INTRINSIC_LLRINT:
+  case G_LROUND:
+  case G_LLROUND:
   case G_INTRINSIC_ROUND:
   case G_INTRINSIC_ROUNDEVEN:
   case G_INTRINSIC_TRUNC:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 254d63abdf805..c14a4bf54823e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -506,7 +506,7 @@ namespace {
     SDValue visitUINT_TO_FP(SDNode *N);
     SDValue visitFP_TO_SINT(SDNode *N);
     SDValue visitFP_TO_UINT(SDNode *N);
-    SDValue visitXRINT(SDNode *N);
+    SDValue visitXROUND(SDNode *N);
     SDValue visitFP_ROUND(SDNode *N);
     SDValue visitFP_EXTEND(SDNode *N);
     SDValue visitFNEG(SDNode *N);
@@ -1927,7 +1927,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
   case ISD::LRINT:
-  case ISD::LLRINT:             return visitXRINT(N);
+  case ISD::LLRINT:
+  case ISD::LROUND:
+  case ISD::LLROUND:            return visitXROUND(N);
   case ISD::FP_ROUND:           return visitFP_ROUND(N);
   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
   case ISD::FNEG:               return visitFNEG(N);
@@ -17835,15 +17837,17 @@ SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
   return FoldIntToFPToInt(N, DAG);
 }
 
-SDValue DAGCombiner::visitXRINT(SDNode *N) {
+SDValue DAGCombiner::visitXROUND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
 
   // fold (lrint|llrint undef) -> undef
+  // fold (lround|llround undef) -> undef
   if (N0.isUndef())
     return DAG.getUNDEF(VT);
 
   // fold (lrint|llrint c1fp) -> c1
+  // fold (lround|llround c1fp) -> c1
   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0);
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index aa116c9de5d8c..76c17395ef245 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2296,7 +2296,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::FP_TO_SINT:
     case ISD::FP_TO_UINT:
     case ISD::LRINT:
-    case ISD::LLRINT:     R = PromoteFloatOp_UnaryOp(N, OpNo); break;
+    case ISD::LLRINT:
+    case ISD::LROUND:
+    case ISD::LLROUND:     R = PromoteFloatOp_UnaryOp(N, OpNo); break;
     case ISD::FP_TO_SINT_SAT:
     case ISD::FP_TO_UINT_SAT:
                           R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 85f947efe2c75..858e139df3199 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1028,7 +1028,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecRes_Convert(SDNode *N);
   SDValue WidenVecRes_Convert_StrictFP(SDNode *N);
   SDValue WidenVecRes_FP_TO_XINT_SAT(SDNode *N);
-  SDValue WidenVecRes_XRINT(SDNode *N);
+  SDValue WidenVecRes_XROUND(SDNode *N);
   SDValue WidenVecRes_FCOPYSIGN(SDNode *N);
   SDValue WidenVecRes_UnarySameEltsWithScalarArg(SDNode *N);
   SDValue WidenVecRes_ExpOp(SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 14b147cc5b01b..e5559fb957847 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -466,6 +466,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
                                               Node->getValueType(0), Scale);
     break;
   }
+  case ISD::LROUND:
+  case ISD::LLROUND:
   case ISD::LRINT:
   case ISD::LLRINT:
   case ISD::SINT_TO_FP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 532c6306fb3d1..9d03bb818375e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -104,6 +104,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FRINT:
   case ISD::LRINT:
   case ISD::LLRINT:
+  case ISD::LROUND:
+  case ISD::LLROUND:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
   case ISD::FSIN:
@@ -746,6 +748,8 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::UINT_TO_FP:
   case ISD::LRINT:
   case ISD::LLRINT:
+  case ISD::LROUND:
+  case ISD::LLROUND:
     Res = ScalarizeVecOp_UnaryOp(N);
     break;
   case ISD::STRICT_SINT_TO_FP:
@@ -1176,6 +1180,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_LRINT:
   case ISD::LLRINT:
   case ISD::VP_LLRINT:
+  case ISD::LROUND:
+  case ISD::LLROUND:
   case ISD::FROUND:
   case ISD::VP_FROUND:
   case ISD::FROUNDEVEN:
@@ -3157,6 +3163,8 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::ZERO_EXTEND:
   case ISD::ANY_EXTEND:
   case ISD::FTRUNC:
+  case ISD::LROUND:
+  case ISD::LLROUND:
   case ISD::LRINT:
   case ISD::LLRINT:
     Res = SplitVecOp_UnaryOp(N);
@@ -4471,11 +4479,13 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
     Res = WidenVecRes_FP_TO_XINT_SAT(N);
     break;
 
+  case ISD::LROUND:
+  case ISD::LLROUND:
   case ISD::LRINT:
   case ISD::LLRINT:
   case ISD::VP_LRINT:
   case ISD::VP_LLRINT:
-    Res = WidenVecRes_XRINT(N);
+    Res = WidenVecRes_XROUND(N);
     break;
 
   case ISD::FABS:
@@ -5086,7 +5096,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_FP_TO_XINT_SAT(SDNode *N) {
   return DAG.getNode(N->getOpcode(), dl, WidenVT, Src, N->getOperand(1));
 }
 
-SDValue DAGTypeLegalizer::WidenVecRes_XRINT(SDNode *N) {
+SDValue DAGTypeLegalizer::WidenVecRes_XROUND(SDNode *N) {
   SDLoc dl(N);
   EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   ElementCount WidenNumElts = WidenVT.getVectorElementCount();
@@ -6309,6 +6319,8 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VSELECT:            Res = WidenVecOp_VSELECT(N); break;
   case ISD::FLDEXP:
   case ISD::FCOPYSIGN:
+  case ISD::LROUND:
+  case ISD::LLROUND:
   case ISD::LRINT:
   case ISD::LLRINT:
     Res = WidenVecOp_UnrollVectorOp(N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 8463e94d7f933..99928b5a255d6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5392,6 +5392,8 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
   case ISD::FRINT:
+  case ISD::LROUND:
+  case ISD::LLROUND:
   case ISD::LRINT:
   case ISD::LLRINT:
   case ISD::FNEARBYINT:
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index ff684c7cb6bba..7e9ba14cb2ed2 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -964,7 +964,8 @@ void TargetLoweringBase::initActions() {
       setOperationAction(
           {ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG, ISD::ANY_EXTEND_VECTOR_INREG,
            ISD::SIGN_EXTEND_VECTOR_INREG, ISD::ZERO_EXTEND_VECTOR_INREG,
-           ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT, ISD::FTAN},
+           ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT, ISD::FTAN, ISD::LROUND,
+           ISD::LLROUND},
           VT, Expand);
 
       // Constrained floating-point operations default to expand.
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index c98f61d555140..c84bb98a9305e 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5928,6 +5928,27 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     }
     break;
   }
+  case Intrinsic::lround:
+  case Intrinsic::llround: {
+    Type *ValTy = Call.getArgOperand(0)->getType();
+    Type *ResultTy = Call.getType();
+    Check(
+        ValTy->isFPOrFPVectorTy() && ResultTy->isIntOrIntVectorTy(),
+        "llvm.lround, llvm.llround: argument must be floating-point or vector "
+        "of floating-points, and result must be integer or vector of integers",
+        &Call);
+    Check(
+        ValTy->isVectorTy() == ResultTy->isVectorTy(),
+        "llvm.lround, llvm.llround: argument and result disagree on vector use",
+        &Call);
+    if (ValTy->isVectorTy()) {
+      Check(cast<VectorType>(ValTy)->getElementCount() ==
+                cast<VectorType>(ResultTy)->getElementCount(),
+            "llvm.lround, llvm.llround: argument must be same length as result",
+            &Call);
+    }
+    break;
+  }
   case Intrinsic::lrint:
   case Intrinsic::llrint: {
     Type *ValTy = Call.getArgOperand(0)->getType();
@@ -5948,14 +5969,6 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     }
     break;
   }
-  case Intrinsic::lround:
-  case Intrinsic::llround: {
-    Type *ValTy = Call.getArgOperand(0)->getType();
-    Type *ResultTy = Call.getType();
-    Check(!ValTy->isVectorTy() && !ResultTy->isVectorTy(),
-          "Intrinsic does not support vectors", &Call);
-    break;
-  }
   case Intrinsic::bswap: {
     Type *Ty = Call.getType();
     unsigned Size = Ty->getScalarSizeInBits();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 522b3a34161cd..b1e28b14329d6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -393,7 +393,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
                      MVT::f32, Legal);
 
   setOperationAction(ISD::FLOG2, MVT::f32, Custom);
-  setOperationAction(ISD::FROUND, {MVT::f32, MVT::f64}, Custom);
+  setOperationAction({ISD::FROUND, ISD::LROUND, ISD::LLROUND},
+                     {MVT::f16, MVT::f32, MVT::f64}, Custom);
 
   setOperationAction(
       {ISD::FLOG, ISD::FLOG10, ISD::FEXP, ISD::FEXP2, ISD::FEXP10}, MVT::f32,
@@ -401,7 +402,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
 
   setOperationAction(ISD::FNEARBYINT, {MVT::f16, MVT::f32, MVT::f64}, Custom);
 
-  setOperationAction(ISD::FRINT, {MVT::f16, MVT::f32, MVT::f64}, Custom);
+  setOperationAction({ISD::FRINT, ISD::LRINT, ISD::LLRINT},
+                     {MVT::f16, MVT::f32, MVT::f64}, Custom);
 
   setOperationAction(ISD::FREM, {MVT::f16, MVT::f32, MVT::f64}, Custom);
 
@@ -1385,10 +1387,17 @@ SDValue AMDGPUTargetLowering::LowerOperation(SDValue Op,
   case ISD::FCEIL: return LowerFCEIL(Op, DAG);
   case ISD::FTRUNC: return LowerFTRUNC(Op, DAG);
   case ISD::FRINT: return LowerFRINT(Op, DAG);
-  case ISD::FNEARBYINT: return LowerFNEARBYINT(Op, DAG);
+  case ISD::LRINT:
+  case ISD::LLRINT:
+    return LowerLRINT(Op, DAG);
+  case ISD::FNEARBYINT:
+    return LowerFNEARBYINT(Op, DAG);
   case ISD::FROUNDEVEN:
     return LowerFROUNDEVEN(Op, DAG);
   case ISD::FROUND: return LowerFROUND(Op, DAG);
+  case ISD::LROUND:
+  case ISD::LLROUND:
+    return LowerLROUND(Op, DAG);
   case ISD::FFLOOR: return LowerFFLOOR(Op, DAG);
   case ISD::FLOG2:
     return LowerFLOG2(Op, DAG);
@@ -2493,6 +2502,14 @@ SDValue AMDGPUTargetLowering::LowerFRINT(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(ISD::FROUNDEVEN, SDLoc(Op), VT, Arg);
 }
 
+SDValue AMDGPUTargetLowering::LowerLRINT(SDValue Op, SelectionDAG &DAG) const {
+  auto ResVT = Op.getValueType();
+  auto Arg = Op.getOperand(0u);
+  auto ArgVT = Arg.getValueType();
+  SDValue RoundNode = DAG.getNode(ISD::FROUNDEVEN, SDLoc(Op), ArgVT, Arg);
+  return DAG.getNode(ISD::FP_TO_SINT, SDLoc(Op), ResVT, RoundNode);
+}
+
 // XXX - May require not supporting f32 denormals?
 
 // Don't handle v2f16. The extra instructions to scalarize and repack around the
@@ -2501,7 +2518,7 @@ SDValue AMDGPUTargetLowering::LowerFRINT(SDValue Op, SelectionDAG &DAG) const {
 SDValue AMDGPUTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
   SDLoc SL(Op);
   SDValue X = Op.getOperand(0);
-  EVT VT = Op.getValueType();
+  EVT VT = X.getValueType();
 
   SDValue T = DAG.getNode(ISD::FTRUNC, SL, VT, X);
 
@@ -2525,6 +2542,13 @@ SDValue AMDGPUTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(ISD::FADD, SL, VT, T, SignedOffset);
 }
 
+SDValue AMDGPUTargetLowering::LowerLROUND(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  EVT ResVT = Op.getValueType();
+  SDValue FRoundNode = LowerFROUND(Op, DAG);
+  return DAG.getNode(ISD::FP_TO_SINT, SL, ResVT, FRoundNode);
+}
+
 SDValue AMDGPUTargetLowering::LowerFFLOOR(SDValue Op, SelectionDAG &DAG) const {
   SDLoc SL(Op);
   SDValue Src = Op.getOperand(0);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 37572af3897f2..28cce7b6da517 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -55,10 +55,12 @@ class AMDGPUTargetLowering : public TargetLowering {
   SDValue LowerFCEIL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFTRUNC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFRINT(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerLRINT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFNEARBYINT(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerFROUNDEVEN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerLROUND(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFFLOOR(SDValue Op, SelectionDAG &DAG) const;
 
   static bool allowApproxFunc(const SelectionDAG &DAG, SDNodeFlags Flags);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index f1254b2e9e1d2..46ca530008f70 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -1141,6 +1141,18 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
       .scalarize(0)
       .lower();
 
+  getActionDefinitionsBuilder({G_LROUND, G_LLROUND})
+      .customFor({{S32, S16}, {S32, S32}, {S64, S64}, {S32, S64}, {S64, S32}})
+      .clampScalar(0, S16, S64)
+      .scalarize(0)
+      .lower();
+
+  getActionDefinitionsBuilder({G_INTRINSIC_LRINT, G_INTRINSIC_LLRINT})
+      .customFor({{S32, S16}, {S32, S32}, {S64, S64}, {S32, S64}, {S64, S32}})
+      .clampScalar(0, S16, S64)
+      .scalarize(0)
+      .lower();
+
   if (ST.has16BitInsts()) {
     getActionDefinitionsBuilder(
         {G_INTRINSIC_TRUNC, G_FCEIL, G_INTRINSIC_ROUNDEVEN})
@@ -2156,6 +2168,12 @@ bool AMDGPULegalizerInfo::legalizeCustom(
     return legalizeCTLZ_ZERO_UNDEF(MI, MRI, B);
   case TargetOpcode::G_INTRINSIC_FPTRUNC_ROUND:
     return legalizeFPTruncRound(MI, B);
+  case TargetOpcode::G_LROUND:
+  case TargetOpcode::G_LLROUND:
+    return legalizeLROUND(MI, MRI, B);
+  case TargetOpcode::G_INTRINSIC_LRINT:
+  case TargetOpcode::G_INTRINSIC_LLRINT:
+    return legalizeLRINT(MI, MRI, B);
   case TargetOpcode::G_STACKSAVE:
     return legalizeStackSave(MI, B);
   case TargetOpcode::G_GET_FPENV:
@@ -7104,6 +7122,33 @@ bool AMDGPULegalizerInfo::legalizeFPTruncRound(MachineInstr &MI,
   return true;
 }
 
+bool AMDGPULegalizerInfo::legalizeLROUND(MachineInstr &MI,
+                                         MachineRegisterInfo &MRI,
+                                         MachineIRBuilder &B) const {
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(1).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+
+  auto Round = B.buildInstr(TargetOpcode::G_INTRINSIC_ROUND, {SrcTy}, {SrcReg});
+  B.buildFPTOSI(DstReg, Round);
+  MI.eraseFromParent();
+  return true;
+}
+
+bool AMDGPULegalizerInfo::legalizeLRINT(MachineInstr &MI,
+                                        MachineRegisterInfo &MRI,
+                                        MachineIRBuilder &B) const {
+
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(1).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+
+  auto Round = B.buildIntrinsicRoundeven(SrcTy, SrcReg);
+  B.buildFPTOSI(DstReg, Round);
+  MI.eraseFromParent();
+  return true;
+}
+
 bool AMDGPULegalizerInfo::legalizeStackSave(MachineInstr &MI,
                                             MachineIRBuilder &B) const {
   const SITargetLowering *TLI = ST.getTargetLowering();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
index ae01bb29c1108..f0dae35224907 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
@@ -55,6 +55,10 @@ class AMDGPULegalizerInfo final : public LegalizerInfo {
                      MachineIRBuilder &B, bool Signed) const;
   bool legalizeFPTOI(MachineInstr &MI, MachineRegisterInfo &MRI,
                      MachineIRBuilder &B, bool Signed) const;
+  bool legalizeLROUND(MachineInstr &MI, MachineRegisterInfo &MRI,
+                      MachineIRBuilder &B) const;
+  bool legalizeLRINT(MachineInstr &MI, MachineRegisterInfo &MRI,
+                     MachineIRBuilder &B) const;
   bool legalizeMinNumMaxNum(LegalizerHelper &Helper, MachineInstr &MI) const;
   bool legalizeExtractVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,
                                 MachineIRBuilder &B) const;
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/lrint.ll b/llvm/test/CodeGen/AMDGPU/GlobalISel/lrint.ll
new file mode 100644
index 0000000000000..c6ac0b2dd3334
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/lrint.ll
@@ -0,0 +1,493 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -global-isel -mtriple=amdgcn -mcpu=gfx900 < %s | FileCheck -check-prefixes=GCN,GFX9 %s
+; RUN: llc -global-isel -mtriple=amdgcn -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GCN,GFX10 %s
+; RUN: llc -global-isel -mtriple=amdgcn -mcpu=gfx1100 < %s | FileCheck -check-prefixes=GCN,GFX11 %s
+
+declare float @llvm.rint.f32(float)
+declare i32 @llvm.lrint.i32.f32(float)
+declare i32 @llvm.lrint.i32.f64(double)
+declare i64 @llvm.lrint.i64.f32(float)
+declare i64 @llvm.lrint.i64.f64(double)
+declare i64 @llvm.llrint.i64.f32(float)
+declare half @llvm.rint.f16(half)
+declare i32 @llvm.lrint.i32.f16(half %arg)
+declare <2 x float> @llvm.rint.v2f32.v2f32(<2 x float> %arg)
+declare <2 x i32> @llvm.lrint.v2i32.v2f32(<2 x float> %arg)
+declare <2 x i64> @llvm.lrint.v2i64.v2f32(<2 x float> %arg)
+
+define float @intrinsic_frint(float %arg) {
+; GCN-LABEL: intrinsic_frint:
+; GCN:       ; %bb.0: ; %entry
+; GCN-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GCN-NEXT:    v_rndne_f32_e32 v0, v0
+; GCN-NEXT:    s_setpc_b64 s[30:31]
+entry:
+  %0 = ta...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/96817


More information about the llvm-commits mailing list