[llvm] [SelectionDAG][RISCV] Operations with static rounding (PR #100999)
Serge Pavlov via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 1 04:09:20 PDT 2024
https://github.com/spavloff updated https://github.com/llvm/llvm-project/pull/100999
>From 38dadb6e81f3b9426cfb925ea3e44115c4676869 Mon Sep 17 00:00:00 2001
From: Serge Pavlov <sepavloff at gmail.com>
Date: Thu, 30 May 2024 22:11:38 +0700
Subject: [PATCH 1/3] [SelectionDAG][RISCV] Operations with static rounding
Some targets, including RISC-V, support a rounding mode specified in the
instruction rather than in a register. To use this feature, new DAG
nodes are introduced by this patch. They have the same operands as their
default-mode counterparts, but take an additional integer operand, which
specifies the rounding mode to use. Lowering of these DAG nodes is
implemented for the F and Zfinx extensions.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 14 +
llvm/include/llvm/CodeGen/TargetLowering.h | 10 +
.../include/llvm/Target/TargetSelectionDAG.td | 25 ++
.../SelectionDAG/LegalizeIntegerTypes.cpp | 16 ++
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 2 +
.../SelectionDAG/SelectionDAGBuilder.cpp | 78 ++++--
.../SelectionDAG/SelectionDAGDumper.cpp | 9 +
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 22 ++
llvm/lib/Target/RISCV/RISCVISelLowering.h | 5 +
llvm/lib/Target/RISCV/RISCVInstrInfoF.td | 82 ++++++
llvm/test/CodeGen/RISCV/float-mode.ll | 249 ++++++++++++++++++
11 files changed, 496 insertions(+), 16 deletions(-)
create mode 100644 llvm/test/CodeGen/RISCV/float-mode.ll
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 5b657fb171296..71a3933250ccb 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -411,6 +411,20 @@ enum NodeType {
STRICT_FREM,
STRICT_FMA,
+ /// Basic floating-point operations with statically specified control modes,
+ /// usually rounding. These have the same operands as the corresponding
+ /// default mode operations with an additional integer operand that represents
+ /// the specified modes in a target-dependent format.
+ FADD_MODE,
+ FSUB_MODE,
+ FMUL_MODE,
+ FDIV_MODE,
+ FSQRT_MODE,
+ FMA_MODE,
+ SINT_TO_FP_MODE,
+ UINT_TO_FP_MODE,
+ FP_ROUND_MODE,
+
/// Constrained versions of libm-equivalent floating point intrinsics.
/// These will be lowered to the equivalent non-constrained pseudo-op
/// (or expanded to the equivalent library call) before final selection.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9d9886f4920a2..d9ba3e3ef9d2f 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -354,6 +354,16 @@ class TargetLoweringBase {
return IsStrictFPEnabled;
}
+ /// Returns true if the target supports static rounding mode for the given
+ /// instruction.
+ virtual bool isStaticRoundingSupportedFor(const Instruction &I) const {
+ return false;
+ }
+
+ /// Returns the target-specific representation of the given static rounding
+ /// mode, or -1 if this rounding mode is not supported.
+ virtual int getMachineRoundingMode(RoundingMode RM) const { return -1; }
+
protected:
/// Initialize all of the actions to default values.
void initActions();
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 46044aab79a83..b206a9d9df983 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -134,12 +134,18 @@ def SDTIntScaledBinOp : SDTypeProfile<1, 3, [ // smulfix, sdivfix, etc
def SDTFPBinOp : SDTypeProfile<1, 2, [ // fadd, fmul, etc.
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>
]>;
+def SDTFPBinModeOp : SDTypeProfile<1, 3, [ // fadd, fmul, etc. with static modes
+ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>, SDTCisInt<3>
+]>;
def SDTFPSignOp : SDTypeProfile<1, 2, [ // fcopysign.
SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisFP<2>
]>;
def SDTFPTernaryOp : SDTypeProfile<1, 3, [ // fmadd, fnmsub, etc.
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>
]>;
+def SDTFPTernaryModeOp : SDTypeProfile<1, 4, [ // fmadd, fnmsub, etc. with static modes
+ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>, SDTCisInt<4>
+]>;
def SDTIntUnaryOp : SDTypeProfile<1, 1, [ // bitreverse
SDTCisSameAs<0, 1>, SDTCisInt<0>
]>;
@@ -155,9 +161,15 @@ def SDTIntTruncOp : SDTypeProfile<1, 1, [ // trunc
def SDTFPUnaryOp : SDTypeProfile<1, 1, [ // fneg, fsqrt, etc
SDTCisSameAs<0, 1>, SDTCisFP<0>
]>;
+def SDTFPUnaryModeOp : SDTypeProfile<1, 2, [ // fsqrt with static modes
+ SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisInt<2>
+]>;
def SDTFPRoundOp : SDTypeProfile<1, 1, [ // fpround
SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<0, 1>, SDTCisSameNumEltsAs<0, 1>
]>;
+def SDTFPRoundModeOp : SDTypeProfile<1, 2, [ // fpround with static modes
+ SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<0, 1>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<2>
+]>;
def SDTFPExtendOp : SDTypeProfile<1, 1, [ // fpextend
SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<1, 0>, SDTCisSameNumEltsAs<0, 1>
]>;
@@ -167,6 +179,9 @@ def SDIsFPClassOp : SDTypeProfile<1, 2, [ // is_fpclass
def SDTIntToFPOp : SDTypeProfile<1, 1, [ // [su]int_to_fp
SDTCisFP<0>, SDTCisInt<1>, SDTCisSameNumEltsAs<0, 1>
]>;
+def SDTIntToFPModeOp : SDTypeProfile<1, 2, [ // [su]int_to_fp with static modes
+ SDTCisFP<0>, SDTCisInt<1>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<2>
+]>;
def SDTFPToIntOp : SDTypeProfile<1, 1, [ // fp_to_[su]int
SDTCisInt<0>, SDTCisFP<1>, SDTCisSameNumEltsAs<0, 1>
]>;
@@ -662,6 +677,16 @@ def strict_fp_to_bf16 : SDNode<"ISD::STRICT_FP_TO_BF16",
def strict_fsetcc : SDNode<"ISD::STRICT_FSETCC", SDTSetCC, [SDNPHasChain]>;
def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
+def fadd_mode : SDNode<"ISD::FADD_MODE" , SDTFPBinModeOp, [SDNPCommutative]>;
+def fsub_mode : SDNode<"ISD::FSUB_MODE" , SDTFPBinModeOp>;
+def fmul_mode : SDNode<"ISD::FMUL_MODE" , SDTFPBinModeOp, [SDNPCommutative]>;
+def fdiv_mode : SDNode<"ISD::FDIV_MODE" , SDTFPBinModeOp>;
+def fsqrt_mode : SDNode<"ISD::FSQRT_MODE" , SDTFPUnaryModeOp>;
+def fma_mode : SDNode<"ISD::FMA_MODE" , SDTFPTernaryModeOp, [SDNPCommutative]>;
+def sint_to_fp_mode: SDNode<"ISD::SINT_TO_FP_MODE" , SDTIntToFPModeOp>;
+def uint_to_fp_mode: SDNode<"ISD::UINT_TO_FP_MODE" , SDTIntToFPModeOp>;
+def fpround_mode : SDNode<"ISD::FP_ROUND_MODE" , SDTFPRoundModeOp>;
+
def get_fpenv : SDNode<"ISD::GET_FPENV", SDTGetFPStateOp, [SDNPHasChain]>;
def set_fpenv : SDNode<"ISD::SET_FPENV", SDTSetFPStateOp, [SDNPHasChain]>;
def reset_fpenv : SDNode<"ISD::RESET_FPENV", SDTNone, [SDNPHasChain]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 33a53dfc81379..d3be39faca2f6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1943,6 +1943,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_SIGN_EXTEND: Res = PromoteIntOp_VP_SIGN_EXTEND(N); break;
case ISD::VP_SINT_TO_FP:
case ISD::SINT_TO_FP: Res = PromoteIntOp_SINT_TO_FP(N); break;
+ case ISD::SINT_TO_FP_MODE: Res = PromoteIntOp_SINT_TO_FP_MODE(N); break;
case ISD::STRICT_SINT_TO_FP: Res = PromoteIntOp_STRICT_SINT_TO_FP(N); break;
case ISD::STORE: Res = PromoteIntOp_STORE(cast<StoreSDNode>(N),
OpNo); break;
@@ -1963,6 +1964,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::FP16_TO_FP:
case ISD::VP_UINT_TO_FP:
case ISD::UINT_TO_FP: Res = PromoteIntOp_UINT_TO_FP(N); break;
+ case ISD::UINT_TO_FP_MODE: Res = PromoteIntOp_UINT_TO_FP_MODE(N); break;
case ISD::STRICT_FP16_TO_FP:
case ISD::STRICT_UINT_TO_FP: Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
case ISD::ZERO_EXTEND: Res = PromoteIntOp_ZERO_EXTEND(N); break;
@@ -2344,6 +2346,20 @@ SDValue DAGTypeLegalizer::PromoteIntOp_SINT_TO_FP(SDNode *N) {
SExtPromotedInteger(N->getOperand(0))), 0);
}
+SDValue DAGTypeLegalizer::PromoteIntOp_SINT_TO_FP_MODE(SDNode *N) {
+ return SDValue(DAG.UpdateNodeOperands(N,
+ SExtPromotedInteger(N->getOperand(0)),
+ N->getOperand(1)),
+ 0);
+}
+
+SDValue DAGTypeLegalizer::PromoteIntOp_UINT_TO_FP_MODE(SDNode *N) {
+ return SDValue(DAG.UpdateNodeOperands(N,
+ ZExtPromotedInteger(N->getOperand(0)),
+ N->getOperand(1)),
+ 0);
+}
+
SDValue DAGTypeLegalizer::PromoteIntOp_STRICT_SINT_TO_FP(SDNode *N) {
return SDValue(DAG.UpdateNodeOperands(N, N->getOperand(0),
SExtPromotedInteger(N->getOperand(1))), 0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index d4e61c8588901..eea01aea42a65 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -402,6 +402,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_SIGN_EXTEND(SDNode *N);
SDValue PromoteIntOp_VP_SIGN_EXTEND(SDNode *N);
SDValue PromoteIntOp_SINT_TO_FP(SDNode *N);
+ SDValue PromoteIntOp_SINT_TO_FP_MODE(SDNode *N);
+ SDValue PromoteIntOp_UINT_TO_FP_MODE(SDNode *N);
SDValue PromoteIntOp_STRICT_SINT_TO_FP(SDNode *N);
SDValue PromoteIntOp_STORE(StoreSDNode *N, unsigned OpNo);
SDValue PromoteIntOp_TRUNCATE(SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9f5e6466309e9..26de077171092 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8131,15 +8131,73 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
void SelectionDAGBuilder::visitConstrainedFPIntrinsic(
const ConstrainedFPIntrinsic &FPI) {
SDLoc sdl = getCurSDLoc();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ EVT VT = TLI.getValueType(DAG.getDataLayout(), FPI.getType());
+ fp::ExceptionBehavior EB = *FPI.getExceptionBehavior();
+ std::optional<RoundingMode> RM = FPI.getRoundingMode();
+
+ SDNodeFlags Flags;
+ if (EB == fp::ExceptionBehavior::ebIgnore)
+ Flags.setNoFPExcept(true);
+
+ if (auto *FPOp = dyn_cast<FPMathOperator>(&FPI))
+ Flags.copyFMF(*FPOp);
+
+ bool UseStaticRounding = EB == fp::ExceptionBehavior::ebIgnore && RM &&
+ *RM != RoundingMode::Dynamic &&
+ TLI.isStaticRoundingSupportedFor(FPI);
+ unsigned Opcode = 0;
+ SmallVector<SDValue, 4> Opers;
+ for (unsigned I = 0, E = FPI.getNonMetadataArgCount(); I != E; ++I)
+ Opers.push_back(getValue(FPI.getArgOperand(I)));
+
+ if (UseStaticRounding) {
+ switch (FPI.getIntrinsicID()) {
+ case Intrinsic::experimental_constrained_fadd:
+ Opcode = ISD::FADD_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fsub:
+ Opcode = ISD::FSUB_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fmul:
+ Opcode = ISD::FMUL_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fdiv:
+ Opcode = ISD::FDIV_MODE;
+ break;
+ case Intrinsic::experimental_constrained_sqrt:
+ Opcode = ISD::FSQRT_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fma:
+ Opcode = ISD::FMA_MODE;
+ break;
+ case Intrinsic::experimental_constrained_sitofp:
+ Opcode = ISD::SINT_TO_FP_MODE;
+ break;
+ case Intrinsic::experimental_constrained_uitofp:
+ Opcode = ISD::UINT_TO_FP_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fptrunc:
+ Opcode = ISD::FP_ROUND_MODE;
+ break;
+ }
+ if (Opcode) {
+ int MachineRM = TLI.getMachineRoundingMode(*RM);
+ assert(MachineRM >= 0 && "Unsupported rounding mode");
+ EVT RMType = TLI.getTypeToTransformTo(*DAG.getContext(), MVT::i32);
+ Opers.push_back(DAG.getConstant(static_cast<uint64_t>(MachineRM), sdl,
+ RMType, true));
+ SDValue Result = DAG.getNode(Opcode, sdl, VT, Opers, Flags);
+ setValue(&FPI, Result);
+ return;
+ }
+ }
// We do not need to serialize constrained FP intrinsics against
// each other or against (nonvolatile) loads, so they can be
// chained like loads.
SDValue Chain = DAG.getRoot();
- SmallVector<SDValue, 4> Opers;
- Opers.push_back(Chain);
- for (unsigned I = 0, E = FPI.getNonMetadataArgCount(); I != E; ++I)
- Opers.push_back(getValue(FPI.getArgOperand(I)));
+ Opers.insert(Opers.begin(), Chain);
auto pushOutChain = [this](SDValue Result, fp::ExceptionBehavior EB) {
assert(Result.getNode()->getNumValues() == 2);
@@ -8167,19 +8225,7 @@ void SelectionDAGBuilder::visitConstrainedFPIntrinsic(
}
};
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- EVT VT = TLI.getValueType(DAG.getDataLayout(), FPI.getType());
SDVTList VTs = DAG.getVTList(VT, MVT::Other);
- fp::ExceptionBehavior EB = *FPI.getExceptionBehavior();
-
- SDNodeFlags Flags;
- if (EB == fp::ExceptionBehavior::ebIgnore)
- Flags.setNoFPExcept(true);
-
- if (auto *FPOp = dyn_cast<FPMathOperator>(&FPI))
- Flags.copyFMF(*FPOp);
-
- unsigned Opcode;
switch (FPI.getIntrinsicID()) {
default: llvm_unreachable("Impossible intrinsic"); // Can't reach here.
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 16fc52caebb75..30779a9a38379 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -206,6 +206,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FNEG: return "fneg";
case ISD::FSQRT: return "fsqrt";
case ISD::STRICT_FSQRT: return "strict_fsqrt";
+ case ISD::FSQRT_MODE: return "fsqrt_mode";
case ISD::FCBRT: return "fcbrt";
case ISD::FSIN: return "fsin";
case ISD::STRICT_FSIN: return "strict_fsin";
@@ -284,14 +285,19 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FSHR: return "fshr";
case ISD::FADD: return "fadd";
case ISD::STRICT_FADD: return "strict_fadd";
+ case ISD::FADD_MODE: return "fadd_mode";
case ISD::FSUB: return "fsub";
case ISD::STRICT_FSUB: return "strict_fsub";
+ case ISD::FSUB_MODE: return "fsub_mode";
case ISD::FMUL: return "fmul";
case ISD::STRICT_FMUL: return "strict_fmul";
+ case ISD::FMUL_MODE: return "fmul_mode";
case ISD::FDIV: return "fdiv";
case ISD::STRICT_FDIV: return "strict_fdiv";
+ case ISD::FDIV_MODE: return "fdiv_mode";
case ISD::FMA: return "fma";
case ISD::STRICT_FMA: return "strict_fma";
+ case ISD::FMA_MODE: return "fma_mode";
case ISD::FMAD: return "fmad";
case ISD::FREM: return "frem";
case ISD::STRICT_FREM: return "strict_frem";
@@ -382,13 +388,16 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::TRUNCATE: return "truncate";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
+ case ISD::FP_ROUND_MODE: return "fp_round_mode";
case ISD::FP_EXTEND: return "fp_extend";
case ISD::STRICT_FP_EXTEND: return "strict_fp_extend";
case ISD::SINT_TO_FP: return "sint_to_fp";
case ISD::STRICT_SINT_TO_FP: return "strict_sint_to_fp";
+ case ISD::SINT_TO_FP_MODE: return "sint_to_fp_mode";
case ISD::UINT_TO_FP: return "uint_to_fp";
case ISD::STRICT_UINT_TO_FP: return "strict_uint_to_fp";
+ case ISD::UINT_TO_FP_MODE: return "uint_to_fp_mode";
case ISD::FP_TO_SINT: return "fp_to_sint";
case ISD::STRICT_FP_TO_SINT: return "strict_fp_to_sint";
case ISD::FP_TO_UINT: return "fp_to_uint";
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9ce669a3122f5..c77e284d5fcc7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -434,6 +434,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
ISD::STRICT_FSQRT, ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS};
+ static const unsigned FPStaticRoundNodes[] = {ISD::FADD_MODE, ISD::FSUB_MODE,
+ ISD::FMUL_MODE, ISD::FDIV_MODE,
+ ISD::FSQRT_MODE, ISD::FMA_MODE};
+
static const ISD::CondCode FPCCToExpand[] = {
ISD::SETOGT, ISD::SETOGE, ISD::SETONE, ISD::SETUEQ, ISD::SETUGT,
ISD::SETUGE, ISD::SETULT, ISD::SETULE, ISD::SETUNE, ISD::SETGT,
@@ -526,6 +530,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.hasStdExtFOrZfinx()) {
setOperationAction(FPLegalNodeTypes, MVT::f32, Legal);
+ setOperationAction(FPStaticRoundNodes, MVT::f32, Legal);
setOperationAction(FPRndMode, MVT::f32,
Subtarget.hasStdExtZfa() ? Legal : Custom);
setCondCodeAction(FPCCToExpand, MVT::f32, Expand);
@@ -21427,6 +21432,23 @@ bool RISCVTargetLowering::preferScalarizeSplat(SDNode *N) const {
return true;
}
+int RISCVTargetLowering::getMachineRoundingMode(RoundingMode RM) const {
+ switch (RM) {
+ case RoundingMode::TowardZero:
+ return RISCVFPRndMode::RTZ;
+ case RoundingMode::NearestTiesToEven:
+ return RISCVFPRndMode::RNE;
+ case RoundingMode::TowardNegative:
+ return RISCVFPRndMode::RDN;
+ case RoundingMode::TowardPositive:
+ return RISCVFPRndMode::RUP;
+ case RoundingMode::NearestTiesToAway:
+ return RISCVFPRndMode::RMM;
+ default:
+ return -1;
+ }
+}
+
static Value *useTpOffset(IRBuilderBase &IRB, unsigned Offset) {
Module *M = IRB.GetInsertBlock()->getParent()->getParent();
Function *ThreadPointerFunc =
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 498c77f1875ed..b541472304d22 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -535,6 +535,11 @@ class RISCVTargetLowering : public TargetLowering {
bool softPromoteHalfType() const override { return true; }
+ bool isStaticRoundingSupportedFor(const Instruction &I) const override {
+ return true;
+ }
+ int getMachineRoundingMode(RoundingMode RM) const override;
+
/// Return the register type for a given MVT, ensuring vectors are treated
/// as a series of gpr sized integers.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
index e6c25e0844fb2..e8c3bb0da192f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
@@ -77,6 +77,12 @@ def any_fma_nsz : PatFrag<(ops node:$rs1, node:$rs2, node:$rs3),
(any_fma node:$rs1, node:$rs2, node:$rs3), [{
return N->getFlags().hasNoSignedZeros();
}]>;
+
+def fma_mode_nsz : PatFrag<(ops node:$rs1, node:$rs2, node:$rs3, node:$rm),
+ (fma_mode node:$rs1, node:$rs2, node:$rs3, node:$rm), [{
+ return N->getFlags().hasNoSignedZeros();
+}]>;
+
//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
//===----------------------------------------------------------------------===//
@@ -513,6 +519,18 @@ multiclass PatFprFprDynFrm_m<SDPatternOperator OpNode, RVInstRFrm Inst,
Ext.PrimaryTy, Ext.PrimaryVT>;
}
+class PatFprFprModeFrm<SDPatternOperator OpNode, RVInstRFrm Inst,
+ DAGOperand RegTy, ValueType vt>
+ : Pat<(OpNode (vt RegTy:$rs1), (vt RegTy:$rs2), (XLenVT timm:$rm)),
+ (Inst $rs1, $rs2, $rm)>;
+multiclass PatFprFprModeFrm_m<SDPatternOperator OpNode, RVInstRFrm Inst,
+ ExtInfo Ext> {
+ let Predicates = Ext.Predicates in
+ def Ext.Suffix : PatFprFprModeFrm<OpNode,
+ !cast<RVInstRFrm>(Inst#Ext.Suffix),
+ Ext.PrimaryTy, Ext.PrimaryVT>;
+}
+
/// Float conversion operations
// [u]int32<->float conversion patterns must be gated on IsRV32 or IsRV64, so
@@ -526,8 +544,17 @@ foreach Ext = FExts in {
defm : PatFprFprDynFrm_m<any_fdiv, FDIV_S, Ext>;
}
+foreach Ext = FExts in {
+ defm : PatFprFprModeFrm_m<fadd_mode, FADD_S, Ext>;
+ defm : PatFprFprModeFrm_m<fsub_mode, FSUB_S, Ext>;
+ defm : PatFprFprModeFrm_m<fmul_mode, FMUL_S, Ext>;
+ defm : PatFprFprModeFrm_m<fdiv_mode, FDIV_S, Ext>;
+}
+
let Predicates = [HasStdExtF] in {
def : Pat<(any_fsqrt FPR32:$rs1), (FSQRT_S FPR32:$rs1, FRM_DYN)>;
+def : Pat<(fsqrt_mode FPR32:$rs1, (XLenVT timm:$rm)),
+ (FSQRT_S FPR32:$rs1, frmarg:$rm)>;
def : Pat<(fneg FPR32:$rs1), (FSGNJN_S $rs1, $rs1)>;
def : Pat<(fabs FPR32:$rs1), (FSGNJX_S $rs1, $rs1)>;
@@ -537,6 +564,8 @@ def : Pat<(riscv_fclass FPR32:$rs1), (FCLASS_S $rs1)>;
let Predicates = [HasStdExtZfinx] in {
def : Pat<(any_fsqrt FPR32INX:$rs1), (FSQRT_S_INX FPR32INX:$rs1, FRM_DYN)>;
+def : Pat<(fsqrt_mode FPR32INX:$rs1, (XLenVT timm:$rm)),
+ (FSQRT_S_INX FPR32INX:$rs1, frmarg:$rm)>;
def : Pat<(fneg FPR32INX:$rs1), (FSGNJN_S_INX $rs1, $rs1)>;
def : Pat<(fabs FPR32INX:$rs1), (FSGNJX_S_INX $rs1, $rs1)>;
@@ -555,22 +584,33 @@ def : Pat<(fcopysign FPR32:$rs1, (fneg FPR32:$rs2)), (FSGNJN_S $rs1, $rs2)>;
// fmadd: rs1 * rs2 + rs3
def : Pat<(any_fma FPR32:$rs1, FPR32:$rs2, FPR32:$rs3),
(FMADD_S $rs1, $rs2, $rs3, FRM_DYN)>;
+def : Pat<(fma_mode FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, (XLenVT timm:$rm)),
+ (FMADD_S $rs1, $rs2, $rs3, frmarg:$rm)>;
// fmsub: rs1 * rs2 - rs3
def : Pat<(any_fma FPR32:$rs1, FPR32:$rs2, (fneg FPR32:$rs3)),
(FMSUB_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, FRM_DYN)>;
+def : Pat<(fma_mode FPR32:$rs1, FPR32:$rs2, (fneg FPR32:$rs3),
+ (XLenVT timm:$rm)),
+ (FMSUB_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, frmarg:$rm)>;
// fnmsub: -rs1 * rs2 + rs3
def : Pat<(any_fma (fneg FPR32:$rs1), FPR32:$rs2, FPR32:$rs3),
(FNMSUB_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, FRM_DYN)>;
+def : Pat<(fma_mode (fneg FPR32:$rs1), FPR32:$rs2, FPR32:$rs3, (XLenVT timm:$rm)),
+ (FNMSUB_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, frmarg:$rm)>;
// fnmadd: -rs1 * rs2 - rs3
def : Pat<(any_fma (fneg FPR32:$rs1), FPR32:$rs2, (fneg FPR32:$rs3)),
(FNMADD_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, FRM_DYN)>;
+def : Pat<(fma_mode (fneg FPR32:$rs1), FPR32:$rs2, (fneg FPR32:$rs3), (XLenVT timm:$rm)),
+ (FNMADD_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, frmarg:$rm)>;
// fnmadd: -(rs1 * rs2 + rs3) (the nsz flag on the FMA)
def : Pat<(fneg (any_fma_nsz FPR32:$rs1, FPR32:$rs2, FPR32:$rs3)),
(FNMADD_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, FRM_DYN)>;
+def : Pat<(fneg (fma_mode_nsz FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, (XLenVT timm:$rm))),
+ (FNMADD_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, frmarg:$rm)>;
} // Predicates = [HasStdExtF]
let Predicates = [HasStdExtZfinx] in {
@@ -579,22 +619,32 @@ def : Pat<(fcopysign FPR32INX:$rs1, (fneg FPR32INX:$rs2)), (FSGNJN_S_INX $rs1, $
// fmadd: rs1 * rs2 + rs3
def : Pat<(any_fma FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3),
(FMADD_S_INX $rs1, $rs2, $rs3, FRM_DYN)>;
+def : Pat<(fma_mode FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, (XLenVT timm:$rm)),
+ (FMADD_S_INX $rs1, $rs2, $rs3, frmarg:$rm)>;
// fmsub: rs1 * rs2 - rs3
def : Pat<(any_fma FPR32INX:$rs1, FPR32INX:$rs2, (fneg FPR32INX:$rs3)),
(FMSUB_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, FRM_DYN)>;
+def : Pat<(fma_mode FPR32INX:$rs1, FPR32INX:$rs2, (fneg FPR32INX:$rs3), (XLenVT timm:$rm)),
+ (FMSUB_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, frmarg:$rm)>;
// fnmsub: -rs1 * rs2 + rs3
def : Pat<(any_fma (fneg FPR32INX:$rs1), FPR32INX:$rs2, FPR32INX:$rs3),
(FNMSUB_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, FRM_DYN)>;
+def : Pat<(fma_mode (fneg FPR32INX:$rs1), FPR32INX:$rs2, FPR32INX:$rs3, (XLenVT timm:$rm)),
+ (FNMSUB_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, frmarg:$rm)>;
// fnmadd: -rs1 * rs2 - rs3
def : Pat<(any_fma (fneg FPR32INX:$rs1), FPR32INX:$rs2, (fneg FPR32INX:$rs3)),
(FNMADD_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, FRM_DYN)>;
+def : Pat<(fma_mode (fneg FPR32INX:$rs1), FPR32INX:$rs2, (fneg FPR32INX:$rs3), (XLenVT timm:$rm)),
+ (FNMADD_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, frmarg:$rm)>;
// fnmadd: -(rs1 * rs2 + rs3) (the nsz flag on the FMA)
def : Pat<(fneg (any_fma_nsz FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3)),
(FNMADD_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, FRM_DYN)>;
+def : Pat<(fneg (fma_mode_nsz FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, (XLenVT timm:$rm))),
+ (FNMADD_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, frmarg:$rm)>;
} // Predicates = [HasStdExtZfinx]
// The ratified 20191213 ISA spec defines fmin and fmax in a way that matches
@@ -717,6 +767,12 @@ def : Pat<(i32 (any_lround FPR32:$rs1)), (FCVT_W_S $rs1, FRM_RMM)>;
// [u]int->float. Match GCC and default to using dynamic rounding mode.
def : Pat<(any_sint_to_fp (i32 GPR:$rs1)), (FCVT_S_W $rs1, FRM_DYN)>;
def : Pat<(any_uint_to_fp (i32 GPR:$rs1)), (FCVT_S_WU $rs1, FRM_DYN)>;
+
+// [u]int->float using static rounding mode.
+def : Pat<(sint_to_fp_mode (i32 GPR:$rs1), (XLenVT timm:$rm)),
+ (FCVT_S_W $rs1, frmarg:$rm)>;
+def : Pat<(uint_to_fp_mode (i32 GPR:$rs1), (XLenVT timm:$rm)),
+ (FCVT_S_WU $rs1, frmarg:$rm)>;
} // Predicates = [HasStdExtF]
let Predicates = [HasStdExtZfinx] in {
@@ -737,6 +793,12 @@ def : Pat<(i32 (any_lround FPR32INX:$rs1)), (FCVT_W_S_INX $rs1, FRM_RMM)>;
// [u]int->float. Match GCC and default to using dynamic rounding mode.
def : Pat<(any_sint_to_fp (i32 GPR:$rs1)), (FCVT_S_W_INX $rs1, FRM_DYN)>;
def : Pat<(any_uint_to_fp (i32 GPR:$rs1)), (FCVT_S_WU_INX $rs1, FRM_DYN)>;
+
+// [u]int->float using static rounding mode.
+def : Pat<(sint_to_fp_mode (i32 GPR:$rs1), (XLenVT timm:$rm)),
+ (FCVT_S_W_INX $rs1, frmarg:$rm)>;
+def : Pat<(uint_to_fp_mode (i32 GPR:$rs1), (XLenVT timm:$rm)),
+ (FCVT_S_WU_INX $rs1, frmarg:$rm)>;
} // Predicates = [HasStdExtZfinx]
let Predicates = [HasStdExtF, IsRV64] in {
@@ -771,6 +833,16 @@ def : Pat<(any_sint_to_fp (i64 (sexti32 (i64 GPR:$rs1)))), (FCVT_S_W $rs1, FRM_D
def : Pat<(any_uint_to_fp (i64 (zexti32 (i64 GPR:$rs1)))), (FCVT_S_WU $rs1, FRM_DYN)>;
def : Pat<(any_sint_to_fp (i64 GPR:$rs1)), (FCVT_S_L $rs1, FRM_DYN)>;
def : Pat<(any_uint_to_fp (i64 GPR:$rs1)), (FCVT_S_LU $rs1, FRM_DYN)>;
+
+// [u]int->fp using static rounding mode.
+def : Pat<(sint_to_fp_mode (i64 (sexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
+ (FCVT_S_W $rs1, frmarg:$rm)>;
+def : Pat<(uint_to_fp_mode (i64 (zexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
+ (FCVT_S_WU $rs1, frmarg:$rm)>;
+def : Pat<(sint_to_fp_mode (i64 GPR:$rs1), (i64 timm:$rm)),
+ (FCVT_S_L $rs1, frmarg:$rm)>;
+def : Pat<(uint_to_fp_mode (i64 GPR:$rs1), (i64 timm:$rm)),
+ (FCVT_S_LU $rs1, frmarg:$rm)>;
} // Predicates = [HasStdExtF, IsRV64]
let Predicates = [HasStdExtZfinx, IsRV64] in {
@@ -805,4 +877,14 @@ def : Pat<(any_sint_to_fp (i64 (sexti32 (i64 GPR:$rs1)))), (FCVT_S_W_INX $rs1, F
def : Pat<(any_uint_to_fp (i64 (zexti32 (i64 GPR:$rs1)))), (FCVT_S_WU_INX $rs1, FRM_DYN)>;
def : Pat<(any_sint_to_fp (i64 GPR:$rs1)), (FCVT_S_L_INX $rs1, FRM_DYN)>;
def : Pat<(any_uint_to_fp (i64 GPR:$rs1)), (FCVT_S_LU_INX $rs1, FRM_DYN)>;
+
+// [u]int->fp using static rounding mode.
+def : Pat<(sint_to_fp_mode (i64 (sexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
+ (FCVT_S_W_INX $rs1, frmarg:$rm)>;
+def : Pat<(uint_to_fp_mode (i64 (zexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
+ (FCVT_S_WU_INX $rs1, frmarg:$rm)>;
+def : Pat<(sint_to_fp_mode (i64 GPR:$rs1), (i64 timm:$rm)),
+ (FCVT_S_L_INX $rs1, frmarg:$rm)>;
+def : Pat<(uint_to_fp_mode (i64 GPR:$rs1), (i64 timm:$rm)),
+ (FCVT_S_LU_INX $rs1, frmarg:$rm)>;
} // Predicates = [HasStdExtZfinx, IsRV64]
diff --git a/llvm/test/CodeGen/RISCV/float-mode.ll b/llvm/test/CodeGen/RISCV/float-mode.ll
new file mode 100644
index 0000000000000..3d751b1d3aa56
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/float-mode.ll
@@ -0,0 +1,249 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv32 -mattr=+f -verify-machineinstrs -target-abi=ilp32f < %s | FileCheck --check-prefix=FLOAT %s
+; RUN: llc -mtriple=riscv32 -mattr=+zfinx -verify-machineinstrs -target-abi=ilp32 < %s | FileCheck --check-prefix=FINT %s
+; RUN: llc -mtriple=riscv64 -mattr=+f -verify-machineinstrs -target-abi=lp64f < %s | FileCheck --check-prefix=FLOAT %s
+; RUN: llc -mtriple=riscv64 -mattr=+zfinx -verify-machineinstrs -target-abi=lp64 < %s | FileCheck --check-prefix=FINT %s
+
+
+define float @add_dyn(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: add_dyn:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fadd.s fa0, fa0, fa1
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: add_dyn:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a0, a0, a1
+; FINT-NEXT: ret
+ %add = tail call float @llvm.experimental.constrained.fadd.f32(float %x, float %y, metadata !"round.dynamic", metadata !"fpexcept.ignore") strictfp
+ ret float %add
+}
+
+define float @add_rte(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: add_rte:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fadd.s fa0, fa0, fa1, rne
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: add_rte:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a0, a0, a1, rne
+; FINT-NEXT: ret
+ %add = tail call float @llvm.experimental.constrained.fadd.f32(float %x, float %y, metadata !"round.tonearest", metadata !"fpexcept.ignore") strictfp
+ ret float %add
+}
+
+define float @add_rtz(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: add_rtz:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fadd.s fa0, fa0, fa1, rtz
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: add_rtz:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a0, a0, a1, rtz
+; FINT-NEXT: ret
+ %add = tail call float @llvm.experimental.constrained.fadd.f32(float %x, float %y, metadata !"round.towardzero", metadata !"fpexcept.ignore") strictfp
+ ret float %add
+}
+
+define float @add_rup(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: add_rup:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fadd.s fa0, fa0, fa1, rup
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: add_rup:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a0, a0, a1, rup
+; FINT-NEXT: ret
+ %add = tail call float @llvm.experimental.constrained.fadd.f32(float %x, float %y, metadata !"round.upward", metadata !"fpexcept.ignore") strictfp
+ ret float %add
+}
+
+define float @add_rdn(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: add_rdn:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fadd.s fa0, fa0, fa1, rdn
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: add_rdn:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a0, a0, a1, rdn
+; FINT-NEXT: ret
+ %add = tail call float @llvm.experimental.constrained.fadd.f32(float %x, float %y, metadata !"round.downward", metadata !"fpexcept.ignore") strictfp
+ ret float %add
+}
+
+define float @add_rmm(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: add_rmm:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fadd.s fa0, fa0, fa1, rmm
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: add_rmm:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a0, a0, a1, rmm
+; FINT-NEXT: ret
+ %add = tail call float @llvm.experimental.constrained.fadd.f32(float %x, float %y, metadata !"round.tonearestaway", metadata !"fpexcept.ignore") strictfp
+ ret float %add
+}
+
+define float @sub_rup(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: sub_rup:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fsub.s fa0, fa0, fa1, rup
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: sub_rup:
+; FINT: # %bb.0:
+; FINT-NEXT: fsub.s a0, a0, a1, rup
+; FINT-NEXT: ret
+ %sub = tail call float @llvm.experimental.constrained.fsub.f32(float %x, float %y, metadata !"round.upward", metadata !"fpexcept.ignore") strictfp
+ ret float %sub
+}
+
+define float @mul_rup(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: mul_rup:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fmul.s fa0, fa0, fa1, rup
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: mul_rup:
+; FINT: # %bb.0:
+; FINT-NEXT: fmul.s a0, a0, a1, rup
+; FINT-NEXT: ret
+ %mul = tail call float @llvm.experimental.constrained.fmul.f32(float %x, float %y, metadata !"round.upward", metadata !"fpexcept.ignore") strictfp
+ ret float %mul
+}
+
+define float @div_rup(float %x, float %y) strictfp nounwind {
+; FLOAT-LABEL: div_rup:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fdiv.s fa0, fa0, fa1, rup
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: div_rup:
+; FINT: # %bb.0:
+; FINT-NEXT: fdiv.s a0, a0, a1, rup
+; FINT-NEXT: ret
+ %div = tail call float @llvm.experimental.constrained.fdiv.f32(float %x, float %y, metadata !"round.upward", metadata !"fpexcept.ignore") strictfp
+ ret float %div
+}
+
+define float @sqrt_rdn(float %x) strictfp nounwind {
+; FLOAT-LABEL: sqrt_rdn:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fsqrt.s fa0, fa0, rdn
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: sqrt_rdn:
+; FINT: # %bb.0:
+; FINT-NEXT: fsqrt.s a0, a0, rdn
+; FINT-NEXT: ret
+ %sqrt = tail call float @llvm.experimental.constrained.sqrt.f32(float %x, metadata !"round.downward", metadata !"fpexcept.ignore") strictfp
+ ret float %sqrt
+}
+
+define float @fmadd_rup(float %a, float %b, float %c) nounwind strictfp {
+; FLOAT-LABEL: fmadd_rup:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fmadd.s fa0, fa0, fa1, fa2, rup
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: fmadd_rup:
+; FINT: # %bb.0:
+; FINT-NEXT: fmadd.s a0, a0, a1, a2, rup
+; FINT-NEXT: ret
+ %1 = call float @llvm.experimental.constrained.fma.f32(float %a, float %b, float %c, metadata !"round.upward", metadata !"fpexcept.ignore") strictfp
+ ret float %1
+}
+
+define float @fmsub_rdn(float %a, float %b, float %c) nounwind strictfp {
+; FLOAT-LABEL: fmsub_rdn:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fmv.w.x fa5, zero
+; FLOAT-NEXT: fadd.s fa5, fa2, fa5
+; FLOAT-NEXT: fmsub.s fa0, fa0, fa1, fa5, rdn
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: fmsub_rdn:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a2, a2, zero
+; FINT-NEXT: fmsub.s a0, a0, a1, a2, rdn
+; FINT-NEXT: ret
+ %c_ = fadd float 0.0, %c ; avoid negation using xor
+ %negc = fneg float %c_
+ %1 = call float @llvm.experimental.constrained.fma.f32(float %a, float %b, float %negc, metadata !"round.downward", metadata !"fpexcept.ignore") strictfp
+ ret float %1
+}
+
+define float @fnmadd_rtz(float %a, float %b, float %c) nounwind strictfp {
+; FLOAT-LABEL: fnmadd_rtz:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fmv.w.x fa5, zero
+; FLOAT-NEXT: fadd.s fa4, fa0, fa5
+; FLOAT-NEXT: fadd.s fa5, fa2, fa5
+; FLOAT-NEXT: fnmadd.s fa0, fa4, fa1, fa5, rtz
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: fnmadd_rtz:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a0, a0, zero
+; FINT-NEXT: fadd.s a2, a2, zero
+; FINT-NEXT: fnmadd.s a0, a0, a1, a2, rtz
+; FINT-NEXT: ret
+ %a_ = fadd float 0.0, %a
+ %c_ = fadd float 0.0, %c
+ %nega = fneg float %a_
+ %negc = fneg float %c_
+ %1 = call float @llvm.experimental.constrained.fma.f32(float %nega, float %b, float %negc, metadata !"round.towardzero", metadata !"fpexcept.ignore") strictfp
+ ret float %1
+}
+
+define float @fnmsub_rte(float %a, float %b, float %c) nounwind strictfp {
+; FLOAT-LABEL: fnmsub_rte:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fmv.w.x fa5, zero
+; FLOAT-NEXT: fadd.s fa5, fa0, fa5
+; FLOAT-NEXT: fnmsub.s fa0, fa5, fa1, fa2, rne
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: fnmsub_rte:
+; FINT: # %bb.0:
+; FINT-NEXT: fadd.s a0, a0, zero
+; FINT-NEXT: fnmsub.s a0, a0, a1, a2, rne
+; FINT-NEXT: ret
+ %a_ = fadd float 0.0, %a
+ %nega = fneg float %a_
+ %1 = call float @llvm.experimental.constrained.fma.f32(float %nega, float %b, float %c, metadata !"round.tonearest", metadata !"fpexcept.ignore") strictfp
+ ret float %1
+}
+
+define float @sitofp_rmm(i32 %a) nounwind strictfp {
+; FLOAT-LABEL: sitofp_rmm:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fcvt.s.w fa0, a0, rmm
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: sitofp_rmm:
+; FINT: # %bb.0:
+; FINT-NEXT: fcvt.s.w a0, a0, rmm
+; FINT-NEXT: ret
+ %1 = call float @llvm.experimental.constrained.sitofp.f32.i32(i32 %a, metadata !"round.tonearestaway", metadata !"fpexcept.ignore") strictfp
+ ret float %1
+}
+
+define float @uitofp_rne(i32 %a) nounwind strictfp {
+; FLOAT-LABEL: uitofp_rne:
+; FLOAT: # %bb.0:
+; FLOAT-NEXT: fcvt.s.wu fa0, a0, rne
+; FLOAT-NEXT: ret
+;
+; FINT-LABEL: uitofp_rne:
+; FINT: # %bb.0:
+; FINT-NEXT: fcvt.s.wu a0, a0, rne
+; FINT-NEXT: ret
+ %1 = call float @llvm.experimental.constrained.uitofp.f32.i32(i32 %a, metadata !"round.tonearest", metadata !"fpexcept.ignore") strictfp
+ ret float %1
+}
>From 193feaa5649dbbbdc569c74f83e248545e7c3462 Mon Sep 17 00:00:00 2001
From: Serge Pavlov <sepavloff at gmail.com>
Date: Wed, 31 Jul 2024 22:57:37 +0700
Subject: [PATCH 2/3] Address review comments
- change node names from *_MODE to *_ROUND. Also rename related objects,
- remove handling of fptrunc_round. Tests for it require another FP type
that supports static rounding.
- get rid of the target callback getMachineRoundingMode. This requires new
RISCV-specific nodes to represent operations with a machine-specific
rounding mode,
- provide an implementation of isStaticRoundingSupportedFor.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 25 ++--
llvm/include/llvm/CodeGen/TargetLowering.h | 4 -
.../include/llvm/Target/TargetSelectionDAG.td | 28 ++---
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 5 +
.../SelectionDAG/LegalizeIntegerTypes.cpp | 8 +-
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 4 +-
.../SelectionDAG/SelectionDAGBuilder.cpp | 28 ++---
.../SelectionDAG/SelectionDAGDumper.cpp | 17 ++-
.../Target/RISCV/MCTargetDesc/RISCVBaseInfo.h | 20 ++++
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 111 +++++++++++++++---
llvm/lib/Target/RISCV/RISCVISelLowering.h | 19 ++-
llvm/lib/Target/RISCV/RISCVInstrInfoF.td | 80 ++++++++-----
12 files changed, 231 insertions(+), 118 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 71a3933250ccb..20812a0fbbd35 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -411,19 +411,18 @@ enum NodeType {
STRICT_FREM,
STRICT_FMA,
- /// Basic floating-point operations with statically specified control modes,
- /// usually rounding. These have the same operands as the corresponding
- /// default mode operations with an additional integer operand that represents
- /// the specified modes in a target-dependent format.
- FADD_MODE,
- FSUB_MODE,
- FMUL_MODE,
- FDIV_MODE,
- FSQRT_MODE,
- FMA_MODE,
- SINT_TO_FP_MODE,
- UINT_TO_FP_MODE,
- FP_ROUND_MODE,
+ /// Basic floating-point operations with statically specified rounding mode.
+ /// They have the same operands as the corresponding default mode operations
+ /// with an additional integer operand that represents the specified rounding
+ /// mode in target-independent format, the same as used in llvm.get_rounding.
+ FADD_ROUND,
+ FSUB_ROUND,
+ FMUL_ROUND,
+ FDIV_ROUND,
+ FSQRT_ROUND,
+ FMA_ROUND,
+ SINT_TO_FP_ROUND,
+ UINT_TO_FP_ROUND,
/// Constrained versions of libm-equivalent floating point intrinsics.
/// These will be lowered to the equivalent non-constrained pseudo-op
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index d9ba3e3ef9d2f..8c4902742c9eb 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -360,10 +360,6 @@ class TargetLoweringBase {
return false;
}
- /// Returns target-specific representation of the given static rounding mode
- /// or -1, if this rounding mode is not supported.
- virtual int getMachineRoundingMode(RoundingMode RM) const { return -1; }
-
protected:
/// Initialize all of the actions to default values.
void initActions();
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index b206a9d9df983..748b061fa26a2 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -134,7 +134,7 @@ def SDTIntScaledBinOp : SDTypeProfile<1, 3, [ // smulfix, sdivfix, etc
def SDTFPBinOp : SDTypeProfile<1, 2, [ // fadd, fmul, etc.
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>
]>;
-def SDTFPBinModeOp : SDTypeProfile<1, 3, [ // fadd, fmul, etc. with static modes
+def SDTFPBinRoundOp : SDTypeProfile<1, 3, [ // fadd, fmul, etc. with static rounding
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>, SDTCisInt<3>
]>;
def SDTFPSignOp : SDTypeProfile<1, 2, [ // fcopysign.
@@ -143,7 +143,7 @@ def SDTFPSignOp : SDTypeProfile<1, 2, [ // fcopysign.
def SDTFPTernaryOp : SDTypeProfile<1, 3, [ // fmadd, fnmsub, etc.
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>
]>;
-def SDTFPTernaryModeOp : SDTypeProfile<1, 4, [ // fmadd, fnmsub, etc. with static modes
+def SDTFPTernaryRoundOp : SDTypeProfile<1, 4, [ // fmadd, fnmsub, etc. with static rounding
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>, SDTCisInt<4>
]>;
def SDTIntUnaryOp : SDTypeProfile<1, 1, [ // bitreverse
@@ -161,15 +161,12 @@ def SDTIntTruncOp : SDTypeProfile<1, 1, [ // trunc
def SDTFPUnaryOp : SDTypeProfile<1, 1, [ // fneg, fsqrt, etc
SDTCisSameAs<0, 1>, SDTCisFP<0>
]>;
-def SDTFPUnaryModeOp : SDTypeProfile<1, 2, [ // fsqrt with static modes
+def SDTFPUnaryRoundOp : SDTypeProfile<1, 2, [ // fsqrt with static rounding
SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisInt<2>
]>;
def SDTFPRoundOp : SDTypeProfile<1, 1, [ // fpround
SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<0, 1>, SDTCisSameNumEltsAs<0, 1>
]>;
-def SDTFPRoundModeOp : SDTypeProfile<1, 2, [ // fpround with static modes
- SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<0, 1>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<2>
-]>;
def SDTFPExtendOp : SDTypeProfile<1, 1, [ // fpextend
SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<1, 0>, SDTCisSameNumEltsAs<0, 1>
]>;
@@ -179,7 +176,7 @@ def SDIsFPClassOp : SDTypeProfile<1, 2, [ // is_fpclass
def SDTIntToFPOp : SDTypeProfile<1, 1, [ // [su]int_to_fp
SDTCisFP<0>, SDTCisInt<1>, SDTCisSameNumEltsAs<0, 1>
]>;
-def SDTIntToFPModeOp : SDTypeProfile<1, 2, [ // [su]int_to_fp with static modes
+def SDTIntToFPRoundOp : SDTypeProfile<1, 2, [ // [su]int_to_fp with static rounding
SDTCisFP<0>, SDTCisInt<1>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<2>
]>;
def SDTFPToIntOp : SDTypeProfile<1, 1, [ // fp_to_[su]int
@@ -677,15 +674,14 @@ def strict_fp_to_bf16 : SDNode<"ISD::STRICT_FP_TO_BF16",
def strict_fsetcc : SDNode<"ISD::STRICT_FSETCC", SDTSetCC, [SDNPHasChain]>;
def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
-def fadd_mode : SDNode<"ISD::FADD_MODE" , SDTFPBinModeOp, [SDNPCommutative]>;
-def fsub_mode : SDNode<"ISD::FSUB_MODE" , SDTFPBinModeOp>;
-def fmul_mode : SDNode<"ISD::FMUL_MODE" , SDTFPBinModeOp, [SDNPCommutative]>;
-def fdiv_mode : SDNode<"ISD::FDIV_MODE" , SDTFPBinModeOp>;
-def fsqrt_mode : SDNode<"ISD::FSQRT_MODE" , SDTFPUnaryModeOp>;
-def fma_mode : SDNode<"ISD::FMA_MODE" , SDTFPTernaryModeOp, [SDNPCommutative]>;
-def sint_to_fp_mode: SDNode<"ISD::SINT_TO_FP_MODE" , SDTIntToFPModeOp>;
-def uint_to_fp_mode: SDNode<"ISD::UINT_TO_FP_MODE" , SDTIntToFPModeOp>;
-def fpround_mode : SDNode<"ISD::FP_ROUND_MODE" , SDTFPRoundModeOp>;
+def fadd_round : SDNode<"ISD::FADD_ROUND" , SDTFPBinRoundOp, [SDNPCommutative]>;
+def fsub_round : SDNode<"ISD::FSUB_ROUND" , SDTFPBinRoundOp>;
+def fmul_round : SDNode<"ISD::FMUL_ROUND" , SDTFPBinRoundOp, [SDNPCommutative]>;
+def fdiv_round : SDNode<"ISD::FDIV_ROUND" , SDTFPBinRoundOp>;
+def fsqrt_round : SDNode<"ISD::FSQRT_ROUND" , SDTFPUnaryRoundOp>;
+def fma_round : SDNode<"ISD::FMA_ROUND" , SDTFPTernaryRoundOp, [SDNPCommutative]>;
+def sint_to_fp_round: SDNode<"ISD::SINT_TO_FP_ROUND" , SDTIntToFPRoundOp>;
+def uint_to_fp_round: SDNode<"ISD::UINT_TO_FP_ROUND" , SDTIntToFPRoundOp>;
def get_fpenv : SDNode<"ISD::GET_FPENV", SDTGetFPStateOp, [SDNPHasChain]>;
def set_fpenv : SDNode<"ISD::SET_FPENV", SDTSetFPStateOp, [SDNPHasChain]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index bdb7917073020..58058b6ca5378 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1025,6 +1025,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(1).getValueType());
break;
+ case ISD::SINT_TO_FP_ROUND:
+ case ISD::UINT_TO_FP_ROUND:
+ Action = TLI.getOperationAction(Node->getOpcode(),
+ Node->getOperand(0).getValueType());
+ break;
case ISD::SIGN_EXTEND_INREG: {
EVT InnerType = cast<VTSDNode>(Node->getOperand(1))->getVT();
Action = TLI.getOperationAction(Node->getOpcode(), InnerType);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index d3be39faca2f6..7a48123378100 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1943,7 +1943,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_SIGN_EXTEND: Res = PromoteIntOp_VP_SIGN_EXTEND(N); break;
case ISD::VP_SINT_TO_FP:
case ISD::SINT_TO_FP: Res = PromoteIntOp_SINT_TO_FP(N); break;
- case ISD::SINT_TO_FP_MODE: Res = PromoteIntOp_SINT_TO_FP_MODE(N); break;
+ case ISD::SINT_TO_FP_ROUND: Res = PromoteIntOp_SINT_TO_FP_ROUND(N); break;
case ISD::STRICT_SINT_TO_FP: Res = PromoteIntOp_STRICT_SINT_TO_FP(N); break;
case ISD::STORE: Res = PromoteIntOp_STORE(cast<StoreSDNode>(N),
OpNo); break;
@@ -1964,7 +1964,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::FP16_TO_FP:
case ISD::VP_UINT_TO_FP:
case ISD::UINT_TO_FP: Res = PromoteIntOp_UINT_TO_FP(N); break;
- case ISD::UINT_TO_FP_MODE: Res = PromoteIntOp_UINT_TO_FP_MODE(N); break;
+ case ISD::UINT_TO_FP_ROUND: Res = PromoteIntOp_UINT_TO_FP_ROUND(N); break;
case ISD::STRICT_FP16_TO_FP:
case ISD::STRICT_UINT_TO_FP: Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
case ISD::ZERO_EXTEND: Res = PromoteIntOp_ZERO_EXTEND(N); break;
@@ -2346,14 +2346,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_SINT_TO_FP(SDNode *N) {
SExtPromotedInteger(N->getOperand(0))), 0);
}
-SDValue DAGTypeLegalizer::PromoteIntOp_SINT_TO_FP_MODE(SDNode *N) {
+SDValue DAGTypeLegalizer::PromoteIntOp_SINT_TO_FP_ROUND(SDNode *N) {
return SDValue(DAG.UpdateNodeOperands(N,
SExtPromotedInteger(N->getOperand(0)),
N->getOperand(1)),
0);
}
-SDValue DAGTypeLegalizer::PromoteIntOp_UINT_TO_FP_MODE(SDNode *N) {
+SDValue DAGTypeLegalizer::PromoteIntOp_UINT_TO_FP_ROUND(SDNode *N) {
return SDValue(DAG.UpdateNodeOperands(N,
ZExtPromotedInteger(N->getOperand(0)),
N->getOperand(1)),
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index eea01aea42a65..9b406a706e3ba 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -402,8 +402,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_SIGN_EXTEND(SDNode *N);
SDValue PromoteIntOp_VP_SIGN_EXTEND(SDNode *N);
SDValue PromoteIntOp_SINT_TO_FP(SDNode *N);
- SDValue PromoteIntOp_SINT_TO_FP_MODE(SDNode *N);
- SDValue PromoteIntOp_UINT_TO_FP_MODE(SDNode *N);
+ SDValue PromoteIntOp_SINT_TO_FP_ROUND(SDNode *N);
+ SDValue PromoteIntOp_UINT_TO_FP_ROUND(SDNode *N);
SDValue PromoteIntOp_STRICT_SINT_TO_FP(SDNode *N);
SDValue PromoteIntOp_STORE(StoreSDNode *N, unsigned OpNo);
SDValue PromoteIntOp_TRUNCATE(SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 26de077171092..c2dbc3ad72765 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8153,40 +8153,36 @@ void SelectionDAGBuilder::visitConstrainedFPIntrinsic(
if (UseStaticRounding) {
switch (FPI.getIntrinsicID()) {
+ default:
+ break;
case Intrinsic::experimental_constrained_fadd:
- Opcode = ISD::FADD_MODE;
+ Opcode = ISD::FADD_ROUND;
break;
case Intrinsic::experimental_constrained_fsub:
- Opcode = ISD::FSUB_MODE;
+ Opcode = ISD::FSUB_ROUND;
break;
case Intrinsic::experimental_constrained_fmul:
- Opcode = ISD::FMUL_MODE;
+ Opcode = ISD::FMUL_ROUND;
break;
case Intrinsic::experimental_constrained_fdiv:
- Opcode = ISD::FDIV_MODE;
+ Opcode = ISD::FDIV_ROUND;
break;
case Intrinsic::experimental_constrained_sqrt:
- Opcode = ISD::FSQRT_MODE;
+ Opcode = ISD::FSQRT_ROUND;
break;
case Intrinsic::experimental_constrained_fma:
- Opcode = ISD::FMA_MODE;
+ Opcode = ISD::FMA_ROUND;
break;
case Intrinsic::experimental_constrained_sitofp:
- Opcode = ISD::SINT_TO_FP_MODE;
+ Opcode = ISD::SINT_TO_FP_ROUND;
break;
case Intrinsic::experimental_constrained_uitofp:
- Opcode = ISD::UINT_TO_FP_MODE;
- break;
- case Intrinsic::experimental_constrained_fptrunc:
- Opcode = ISD::FP_ROUND_MODE;
+ Opcode = ISD::UINT_TO_FP_ROUND;
break;
}
if (Opcode) {
- int MachineRM = TLI.getMachineRoundingMode(*RM);
- assert(MachineRM >= 0 && "Unsupported rounding mode");
- EVT RMType = TLI.getTypeToTransformTo(*DAG.getContext(), MVT::i32);
- Opers.push_back(DAG.getConstant(static_cast<uint64_t>(MachineRM), sdl,
- RMType, true));
+ Opers.push_back(DAG.getTargetConstant(static_cast<uint64_t>(*RM), sdl,
+ MVT::i8, true));
SDValue Result = DAG.getNode(Opcode, sdl, VT, Opers, Flags);
setValue(&FPI, Result);
return;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 30779a9a38379..be5d0b6037a0d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -206,7 +206,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FNEG: return "fneg";
case ISD::FSQRT: return "fsqrt";
case ISD::STRICT_FSQRT: return "strict_fsqrt";
- case ISD::FSQRT_MODE: return "fsqrt_mode";
+ case ISD::FSQRT_ROUND: return "fsqrt_round";
case ISD::FCBRT: return "fcbrt";
case ISD::FSIN: return "fsin";
case ISD::STRICT_FSIN: return "strict_fsin";
@@ -285,19 +285,19 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FSHR: return "fshr";
case ISD::FADD: return "fadd";
case ISD::STRICT_FADD: return "strict_fadd";
- case ISD::FADD_MODE: return "fadd_mode";
+ case ISD::FADD_ROUND: return "fadd_round";
case ISD::FSUB: return "fsub";
case ISD::STRICT_FSUB: return "strict_fsub";
- case ISD::FSUB_MODE: return "fsub_mode";
+ case ISD::FSUB_ROUND: return "fsub_round";
case ISD::FMUL: return "fmul";
case ISD::STRICT_FMUL: return "strict_fmul";
- case ISD::FMUL_MODE: return "fmul_mode";
+ case ISD::FMUL_ROUND: return "fmul_round";
case ISD::FDIV: return "fdiv";
case ISD::STRICT_FDIV: return "strict_fdiv";
- case ISD::FDIV_MODE: return "fdiv_mode";
+ case ISD::FDIV_ROUND: return "fdiv_round";
case ISD::FMA: return "fma";
case ISD::STRICT_FMA: return "strict_fma";
- case ISD::FMA_MODE: return "fma_mode";
+ case ISD::FMA_ROUND: return "fma_round";
case ISD::FMAD: return "fmad";
case ISD::FREM: return "frem";
case ISD::STRICT_FREM: return "strict_frem";
@@ -388,16 +388,15 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::TRUNCATE: return "truncate";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
- case ISD::FP_ROUND_MODE: return "fp_round_mode";
case ISD::FP_EXTEND: return "fp_extend";
case ISD::STRICT_FP_EXTEND: return "strict_fp_extend";
case ISD::SINT_TO_FP: return "sint_to_fp";
case ISD::STRICT_SINT_TO_FP: return "strict_sint_to_fp";
- case ISD::SINT_TO_FP_MODE: return "sint_to_fp_mode";
+ case ISD::SINT_TO_FP_ROUND: return "sint_to_fp_round";
case ISD::UINT_TO_FP: return "uint_to_fp";
case ISD::STRICT_UINT_TO_FP: return "strict_uint_to_fp";
- case ISD::UINT_TO_FP_MODE: return "uint_to_fp_mode";
+ case ISD::UINT_TO_FP_ROUND: return "uint_to_fp_round";
case ISD::FP_TO_SINT: return "fp_to_sint";
case ISD::STRICT_FP_TO_SINT: return "strict_fp_to_sint";
case ISD::FP_TO_UINT: return "fp_to_uint";
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index 626206962e752..1ebc59692cb95 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -375,6 +375,26 @@ inline static bool isValidRoundingMode(unsigned Mode) {
return true;
}
}
+
+inline static RoundingMode getMachineMode(llvm::RoundingMode RM) {
+ switch (RM) {
+ case llvm::RoundingMode::TowardZero :
+ return RISCVFPRndMode::RTZ;
+ case llvm::RoundingMode::NearestTiesToEven:
+ return RISCVFPRndMode::RNE;
+ case llvm::RoundingMode::TowardNegative :
+ return RISCVFPRndMode::RDN;
+ case llvm::RoundingMode::TowardPositive:
+ return RISCVFPRndMode::RUP;
+ case llvm::RoundingMode::NearestTiesToAway:
+ return RISCVFPRndMode::RMM;
+ case llvm::RoundingMode::Dynamic:
+ return RISCVFPRndMode::DYN;
+ default:
+ return RoundingMode::Invalid;
+ }
+}
+
} // namespace RISCVFPRndMode
namespace RISCVVXRndMode {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c77e284d5fcc7..4dae262551335 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -434,9 +434,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
ISD::STRICT_FSQRT, ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS};
- static const unsigned FPStaticRoundNodes[] = {ISD::FADD_MODE, ISD::FSUB_MODE,
- ISD::FMUL_MODE, ISD::FDIV_MODE,
- ISD::FSQRT, ISD::FMA_MODE};
+ static const unsigned FPStaticRoundNodes[] = {
+ ISD::FADD_ROUND, ISD::FSUB_ROUND, ISD::FMUL_ROUND,
+ ISD::FDIV_ROUND, ISD::FSQRT_ROUND, ISD::FMA_ROUND};
static const ISD::CondCode FPCCToExpand[] = {
ISD::SETOGT, ISD::SETOGE, ISD::SETONE, ISD::SETUEQ, ISD::SETUGT,
@@ -530,7 +530,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.hasStdExtFOrZfinx()) {
setOperationAction(FPLegalNodeTypes, MVT::f32, Legal);
- setOperationAction(FPStaticRoundNodes, MVT::f32, Legal);
+ setOperationAction(FPStaticRoundNodes, MVT::f32, Custom);
setOperationAction(FPRndMode, MVT::f32,
Subtarget.hasStdExtZfa() ? Legal : Custom);
setCondCodeAction(FPCCToExpand, MVT::f32, Expand);
@@ -612,11 +612,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
XLenVT, Legal);
+ setOperationAction({ISD::UINT_TO_FP_ROUND, ISD::SINT_TO_FP_ROUND},
+ XLenVT, Custom);
- if (RV64LegalI32 && Subtarget.is64Bit())
+ if (RV64LegalI32 && Subtarget.is64Bit()) {
setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
MVT::i32, Legal);
+ setOperationAction({ISD::UINT_TO_FP_ROUND, ISD::SINT_TO_FP_ROUND},
+ MVT::i32, Custom);
+ }
setOperationAction(ISD::GET_ROUNDING, XLenVT, Custom);
setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
@@ -7085,6 +7090,15 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
!Subtarget.hasVInstructionsF16()))
return SplitStrictFPVectorOp(Op, DAG);
return lowerToScalableOp(Op, DAG);
+ case ISD::FADD_ROUND:
+ case ISD::FSUB_ROUND:
+ case ISD::FMUL_ROUND:
+ case ISD::FDIV_ROUND:
+ case ISD::FSQRT_ROUND:
+ case ISD::FMA_ROUND:
+ case ISD::SINT_TO_FP_ROUND:
+ case ISD::UINT_TO_FP_ROUND:
+ return convertToMachineRounding(Op, DAG);
case ISD::STRICT_FSETCC:
case ISD::STRICT_FSETCCS:
return lowerVectorStrictFSetcc(Op, DAG);
@@ -11992,6 +12006,53 @@ SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op,
RMValue);
}
+SDValue RISCVTargetLowering::convertToMachineRounding(SDValue Op,
+ SelectionDAG &DAG) const {
+ unsigned Opcode = 0;
+ switch (Op->getOpcode()) {
+ default:
+ break;
+ case ISD::FADD_ROUND:
+ Opcode = RISCVISD::FADD_ROUND;
+ break;
+ case ISD::FSUB_ROUND:
+ Opcode = RISCVISD::FSUB_ROUND;
+ break;
+ case ISD::FMUL_ROUND:
+ Opcode = RISCVISD::FMUL_ROUND;
+ break;
+ case ISD::FDIV_ROUND:
+ Opcode = RISCVISD::FDIV_ROUND;
+ break;
+ case ISD::FSQRT_ROUND:
+ Opcode = RISCVISD::FSQRT_ROUND;
+ break;
+ case ISD::FMA_ROUND:
+ Opcode = RISCVISD::FMA_ROUND;
+ break;
+ case ISD::SINT_TO_FP_ROUND:
+ Opcode = RISCVISD::SINT_TO_FP_ROUND;
+ break;
+ case ISD::UINT_TO_FP_ROUND:
+ Opcode = RISCVISD::UINT_TO_FP_ROUND;
+ break;
+ }
+ if (Opcode == 0)
+ return SDValue();
+
+ SDLoc DL(Op);
+ SmallVector<SDValue, 4> Opers;
+ for (unsigned I = 0, E = Op.getNumOperands() - 1; I != E; ++I)
+ Opers.push_back(Op.getOperand(I));
+ SDValue RMVal = Op->ops().back();
+ unsigned RM = cast<ConstantSDNode>(RMVal)->getZExtValue();
+ RM = RISCVFPRndMode::getMachineMode(static_cast<llvm::RoundingMode>(RM));
+ RMVal = DAG.getTargetConstant(RM, DL, Subtarget.getXLenVT());
+ Opers.push_back(RMVal);
+
+ return DAG.getNode(Opcode, DL, Op->getVTList(), Opers);
+}
+
SDValue RISCVTargetLowering::lowerEH_DWARF_CFA(SDValue Op,
SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
@@ -20466,6 +20527,14 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SF_VC_V_IVW_SE)
NODE_NAME_CASE(SF_VC_V_VVW_SE)
NODE_NAME_CASE(SF_VC_V_FVW_SE)
+ NODE_NAME_CASE(FADD_ROUND)
+ NODE_NAME_CASE(FSUB_ROUND)
+ NODE_NAME_CASE(FMUL_ROUND)
+ NODE_NAME_CASE(FDIV_ROUND)
+ NODE_NAME_CASE(FSQRT_ROUND)
+ NODE_NAME_CASE(FMA_ROUND)
+ NODE_NAME_CASE(SINT_TO_FP_ROUND)
+ NODE_NAME_CASE(UINT_TO_FP_ROUND)
}
// clang-format on
return nullptr;
@@ -21432,21 +21501,25 @@ bool RISCVTargetLowering::preferScalarizeSplat(SDNode *N) const {
return true;
}
-int RISCVTargetLowering::getMachineRoundingMode(RoundingMode RM) const {
- switch (RM) {
- case RoundingMode::TowardZero:
- return RISCVFPRndMode::RTZ;
- case RoundingMode::NearestTiesToEven:
- return RISCVFPRndMode::RNE;
- case RoundingMode::TowardNegative:
- return RISCVFPRndMode::RDN;
- case RoundingMode::TowardPositive:
- return RISCVFPRndMode::RUP;
- case RoundingMode::NearestTiesToAway:
- return RISCVFPRndMode::RMM;
- default:
- return -1;
+bool RISCVTargetLowering::isStaticRoundingSupportedFor(
+ const Instruction &I) const {
+ if (auto *CI = dyn_cast<ConstrainedFPIntrinsic>(&I)) {
+ switch (CI->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::experimental_constrained_fadd:
+ case Intrinsic::experimental_constrained_fsub:
+ case Intrinsic::experimental_constrained_fmul:
+ case Intrinsic::experimental_constrained_fdiv:
+ case Intrinsic::experimental_constrained_sqrt:
+ case Intrinsic::experimental_constrained_fma:
+ case Intrinsic::experimental_constrained_sitofp:
+ case Intrinsic::experimental_constrained_uitofp:
+ if (CI->getType()->isFloatTy())
+ return true;
+ }
}
+ return false;
}
static Value *useTpOffset(IRBuilderBase &IRB, unsigned Offset) {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index b541472304d22..2eced7942de4c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -468,6 +468,18 @@ enum NodeType : unsigned {
SF_VC_V_VVW_SE,
SF_VC_V_FVW_SE,
+ // These nodes represent the same operations as the target-independent nodes
+ // with the same names. The only difference is using target-specific encoding
+ // of rounding mode.
+ FADD_ROUND,
+ FSUB_ROUND,
+ FMUL_ROUND,
+ FDIV_ROUND,
+ FSQRT_ROUND,
+ FMA_ROUND,
+ SINT_TO_FP_ROUND,
+ UINT_TO_FP_ROUND,
+
// WARNING: Do not add anything in the end unless you want the node to
// have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
// opcodes will be thought as target memory ops!
@@ -535,10 +547,7 @@ class RISCVTargetLowering : public TargetLowering {
bool softPromoteHalfType() const override { return true; }
- bool isStaticRoundingSupportedFor(const Instruction &I) const override {
- return true;
- }
- int getMachineRoundingMode(RoundingMode RM) const override;
+ bool isStaticRoundingSupportedFor(const Instruction &I) const override;
/// Return the register type for a given MVT, ensuring vectors are treated
/// as a series of gpr sized integers.
@@ -1003,6 +1012,8 @@ class RISCVTargetLowering : public TargetLowering {
SDValue expandUnalignedRVVLoad(SDValue Op, SelectionDAG &DAG) const;
SDValue expandUnalignedRVVStore(SDValue Op, SelectionDAG &DAG) const;
+ SDValue convertToMachineRounding(SDValue Op, SelectionDAG &DAG) const;
+
bool isEligibleForTailCallOptimization(
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
const SmallVector<CCValAssign, 16> &ArgLocs) const;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
index e8c3bb0da192f..a5747d6c463a7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
@@ -66,6 +66,15 @@ def riscv_strict_fcvt_wu_rv64
: SDNode<"RISCVISD::STRICT_FCVT_WU_RV64", SDT_RISCVFCVT_W_RV64,
[SDNPHasChain]>;
+def riscv_fadd_round : SDNode<"RISCVISD::FADD_ROUND", SDTFPBinRoundOp>;
+def riscv_fsub_round : SDNode<"RISCVISD::FSUB_ROUND", SDTFPBinRoundOp>;
+def riscv_fmul_round : SDNode<"RISCVISD::FMUL_ROUND", SDTFPBinRoundOp>;
+def riscv_fdiv_round : SDNode<"RISCVISD::FDIV_ROUND", SDTFPBinRoundOp>;
+def riscv_fsqrt_round : SDNode<"RISCVISD::FSQRT_ROUND", SDTFPUnaryRoundOp>;
+def riscv_fma_round : SDNode<"RISCVISD::FMA_ROUND", SDTFPTernaryRoundOp>;
+def riscv_sint_to_fp_round : SDNode<"RISCVISD::SINT_TO_FP_ROUND", SDTIntToFPRoundOp>;
+def riscv_uint_to_fp_round : SDNode<"RISCVISD::UINT_TO_FP_ROUND", SDTIntToFPRoundOp>;
+
def riscv_any_fcvt_w_rv64 : PatFrags<(ops node:$src, node:$frm),
[(riscv_strict_fcvt_w_rv64 node:$src, node:$frm),
(riscv_fcvt_w_rv64 node:$src, node:$frm)]>;
@@ -78,8 +87,8 @@ def any_fma_nsz : PatFrag<(ops node:$rs1, node:$rs2, node:$rs3),
return N->getFlags().hasNoSignedZeros();
}]>;
-def fma_mode_nsz : PatFrag<(ops node:$rs1, node:$rs2, node:$rs3, node:$rm),
- (fma_mode node:$rs1, node:$rs2, node:$rs3, node:$rm), [{
+def riscv_fma_round_nsz : PatFrag<(ops node:$rs1, node:$rs2, node:$rs3, node:$rm),
+ (riscv_fma_round node:$rs1, node:$rs2, node:$rs3, node:$rm), [{
return N->getFlags().hasNoSignedZeros();
}]>;
@@ -545,15 +554,15 @@ foreach Ext = FExts in {
}
foreach Ext = FExts in {
- defm : PatFprFprModeFrm_m<fadd_mode, FADD_S, Ext>;
- defm : PatFprFprModeFrm_m<fsub_mode, FSUB_S, Ext>;
- defm : PatFprFprModeFrm_m<fmul_mode, FMUL_S, Ext>;
- defm : PatFprFprModeFrm_m<fdiv_mode, FDIV_S, Ext>;
+ defm : PatFprFprModeFrm_m<riscv_fadd_round, FADD_S, Ext>;
+ defm : PatFprFprModeFrm_m<riscv_fsub_round, FSUB_S, Ext>;
+ defm : PatFprFprModeFrm_m<riscv_fmul_round, FMUL_S, Ext>;
+ defm : PatFprFprModeFrm_m<riscv_fdiv_round, FDIV_S, Ext>;
}
let Predicates = [HasStdExtF] in {
def : Pat<(any_fsqrt FPR32:$rs1), (FSQRT_S FPR32:$rs1, FRM_DYN)>;
-def : Pat<(fsqrt_mode FPR32:$rs1, (XLenVT timm:$rm)),
+def : Pat<(riscv_fsqrt_round FPR32:$rs1, (XLenVT timm:$rm)),
(FSQRT_S FPR32:$rs1, frmarg:$rm)>;
def : Pat<(fneg FPR32:$rs1), (FSGNJN_S $rs1, $rs1)>;
@@ -564,7 +573,7 @@ def : Pat<(riscv_fclass FPR32:$rs1), (FCLASS_S $rs1)>;
let Predicates = [HasStdExtZfinx] in {
def : Pat<(any_fsqrt FPR32INX:$rs1), (FSQRT_S_INX FPR32INX:$rs1, FRM_DYN)>;
-def : Pat<(fsqrt_mode FPR32INX:$rs1, (XLenVT timm:$rm)),
+def : Pat<(riscv_fsqrt_round FPR32INX:$rs1, (XLenVT timm:$rm)),
(FSQRT_S_INX FPR32INX:$rs1, frmarg:$rm)>;
def : Pat<(fneg FPR32INX:$rs1), (FSGNJN_S_INX $rs1, $rs1)>;
@@ -584,32 +593,36 @@ def : Pat<(fcopysign FPR32:$rs1, (fneg FPR32:$rs2)), (FSGNJN_S $rs1, $rs2)>;
// fmadd: rs1 * rs2 + rs3
def : Pat<(any_fma FPR32:$rs1, FPR32:$rs2, FPR32:$rs3),
(FMADD_S $rs1, $rs2, $rs3, FRM_DYN)>;
-def : Pat<(fma_mode FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, (XLenVT timm:$rm)),
+def : Pat<(riscv_fma_round FPR32:$rs1, FPR32:$rs2, FPR32:$rs3,
+ (XLenVT timm:$rm)),
(FMADD_S $rs1, $rs2, $rs3, frmarg:$rm)>;
// fmsub: rs1 * rs2 - rs3
def : Pat<(any_fma FPR32:$rs1, FPR32:$rs2, (fneg FPR32:$rs3)),
(FMSUB_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, FRM_DYN)>;
-def : Pat<(fma_mode FPR32:$rs1, FPR32:$rs2, (fneg FPR32:$rs3),
- (XLenVT timm:$rm)),
+def : Pat<(riscv_fma_round FPR32:$rs1, FPR32:$rs2, (fneg FPR32:$rs3),
+ (XLenVT timm:$rm)),
(FMSUB_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, frmarg:$rm)>;
// fnmsub: -rs1 * rs2 + rs3
def : Pat<(any_fma (fneg FPR32:$rs1), FPR32:$rs2, FPR32:$rs3),
(FNMSUB_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, FRM_DYN)>;
-def : Pat<(fma_mode (fneg FPR32:$rs1), FPR32:$rs2, FPR32:$rs3, (XLenVT timm:$rm)),
+def : Pat<(riscv_fma_round (fneg FPR32:$rs1), FPR32:$rs2, FPR32:$rs3,
+ (XLenVT timm:$rm)),
(FNMSUB_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, frmarg:$rm)>;
// fnmadd: -rs1 * rs2 - rs3
def : Pat<(any_fma (fneg FPR32:$rs1), FPR32:$rs2, (fneg FPR32:$rs3)),
(FNMADD_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, FRM_DYN)>;
-def : Pat<(fma_mode (fneg FPR32:$rs1), FPR32:$rs2, (fneg FPR32:$rs3), (XLenVT timm:$rm)),
+def : Pat<(riscv_fma_round (fneg FPR32:$rs1), FPR32:$rs2, (fneg FPR32:$rs3),
+ (XLenVT timm:$rm)),
(FNMADD_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, frmarg:$rm)>;
// fnmadd: -(rs1 * rs2 + rs3) (the nsz flag on the FMA)
def : Pat<(fneg (any_fma_nsz FPR32:$rs1, FPR32:$rs2, FPR32:$rs3)),
(FNMADD_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, FRM_DYN)>;
-def : Pat<(fneg (fma_mode_nsz FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, (XLenVT timm:$rm))),
+def : Pat<(fneg (riscv_fma_round_nsz FPR32:$rs1, FPR32:$rs2, FPR32:$rs3,
+ (XLenVT timm:$rm))),
(FNMADD_S FPR32:$rs1, FPR32:$rs2, FPR32:$rs3, frmarg:$rm)>;
} // Predicates = [HasStdExtF]
@@ -619,31 +632,36 @@ def : Pat<(fcopysign FPR32INX:$rs1, (fneg FPR32INX:$rs2)), (FSGNJN_S_INX $rs1, $
// fmadd: rs1 * rs2 + rs3
def : Pat<(any_fma FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3),
(FMADD_S_INX $rs1, $rs2, $rs3, FRM_DYN)>;
-def : Pat<(fma_mode FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, (XLenVT timm:$rm)),
+def : Pat<(riscv_fma_round FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3,
+ (XLenVT timm:$rm)),
(FMADD_S_INX $rs1, $rs2, $rs3, frmarg:$rm)>;
// fmsub: rs1 * rs2 - rs3
def : Pat<(any_fma FPR32INX:$rs1, FPR32INX:$rs2, (fneg FPR32INX:$rs3)),
(FMSUB_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, FRM_DYN)>;
-def : Pat<(fma_mode FPR32INX:$rs1, FPR32INX:$rs2, (fneg FPR32INX:$rs3), (XLenVT timm:$rm)),
+def : Pat<(riscv_fma_round FPR32INX:$rs1, FPR32INX:$rs2, (fneg FPR32INX:$rs3),
+ (XLenVT timm:$rm)),
(FMSUB_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, frmarg:$rm)>;
// fnmsub: -rs1 * rs2 + rs3
def : Pat<(any_fma (fneg FPR32INX:$rs1), FPR32INX:$rs2, FPR32INX:$rs3),
(FNMSUB_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, FRM_DYN)>;
-def : Pat<(fma_mode (fneg FPR32INX:$rs1), FPR32INX:$rs2, FPR32INX:$rs3, (XLenVT timm:$rm)),
+def : Pat<(riscv_fma_round (fneg FPR32INX:$rs1), FPR32INX:$rs2, FPR32INX:$rs3,
+ (XLenVT timm:$rm)),
(FNMSUB_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, frmarg:$rm)>;
// fnmadd: -rs1 * rs2 - rs3
def : Pat<(any_fma (fneg FPR32INX:$rs1), FPR32INX:$rs2, (fneg FPR32INX:$rs3)),
(FNMADD_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, FRM_DYN)>;
-def : Pat<(fma_mode (fneg FPR32INX:$rs1), FPR32INX:$rs2, (fneg FPR32INX:$rs3), (XLenVT timm:$rm)),
+def : Pat<(riscv_fma_round (fneg FPR32INX:$rs1), FPR32INX:$rs2, (fneg FPR32INX:$rs3),
+ (XLenVT timm:$rm)),
(FNMADD_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, frmarg:$rm)>;
// fnmadd: -(rs1 * rs2 + rs3) (the nsz flag on the FMA)
def : Pat<(fneg (any_fma_nsz FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3)),
(FNMADD_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, FRM_DYN)>;
-def : Pat<(fneg (fma_mode_nsz FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, (XLenVT timm:$rm))),
+def : Pat<(fneg (riscv_fma_round_nsz FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3,
+ (XLenVT timm:$rm))),
(FNMADD_S_INX FPR32INX:$rs1, FPR32INX:$rs2, FPR32INX:$rs3, frmarg:$rm)>;
} // Predicates = [HasStdExtZfinx]
@@ -769,9 +787,9 @@ def : Pat<(any_sint_to_fp (i32 GPR:$rs1)), (FCVT_S_W $rs1, FRM_DYN)>;
def : Pat<(any_uint_to_fp (i32 GPR:$rs1)), (FCVT_S_WU $rs1, FRM_DYN)>;
// [u]int->float using static rounding mode.
-def : Pat<(sint_to_fp_mode (i32 GPR:$rs1), (XLenVT timm:$rm)),
+def : Pat<(riscv_sint_to_fp_round (i32 GPR:$rs1), (XLenVT timm:$rm)),
(FCVT_S_W $rs1, frmarg:$rm)>;
-def : Pat<(uint_to_fp_mode (i32 GPR:$rs1), (XLenVT timm:$rm)),
+def : Pat<(riscv_uint_to_fp_round (i32 GPR:$rs1), (XLenVT timm:$rm)),
(FCVT_S_WU $rs1, frmarg:$rm)>;
} // Predicates = [HasStdExtF]
@@ -795,9 +813,9 @@ def : Pat<(any_sint_to_fp (i32 GPR:$rs1)), (FCVT_S_W_INX $rs1, FRM_DYN)>;
def : Pat<(any_uint_to_fp (i32 GPR:$rs1)), (FCVT_S_WU_INX $rs1, FRM_DYN)>;
// [u]int->float using static rounding mode.
-def : Pat<(sint_to_fp_mode (i32 GPR:$rs1), (XLenVT timm:$rm)),
+def : Pat<(riscv_sint_to_fp_round (i32 GPR:$rs1), (XLenVT timm:$rm)),
(FCVT_S_W_INX $rs1, frmarg:$rm)>;
-def : Pat<(uint_to_fp_mode (i32 GPR:$rs1), (XLenVT timm:$rm)),
+def : Pat<(riscv_uint_to_fp_round (i32 GPR:$rs1), (XLenVT timm:$rm)),
(FCVT_S_WU_INX $rs1, frmarg:$rm)>;
} // Predicates = [HasStdExtZfinx]
@@ -835,13 +853,13 @@ def : Pat<(any_sint_to_fp (i64 GPR:$rs1)), (FCVT_S_L $rs1, FRM_DYN)>;
def : Pat<(any_uint_to_fp (i64 GPR:$rs1)), (FCVT_S_LU $rs1, FRM_DYN)>;
// [u]int->fp using static rounding mode.
-def : Pat<(sint_to_fp_mode (i64 (sexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
+def : Pat<(riscv_sint_to_fp_round (i64 (sexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
(FCVT_S_W $rs1, frmarg:$rm)>;
-def : Pat<(uint_to_fp_mode (i64 (zexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
+def : Pat<(riscv_uint_to_fp_round (i64 (zexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
(FCVT_S_WU $rs1, frmarg:$rm)>;
-def : Pat<(sint_to_fp_mode (i64 GPR:$rs1), (i64 timm:$rm)),
+def : Pat<(riscv_sint_to_fp_round (i64 GPR:$rs1), (i64 timm:$rm)),
(FCVT_S_L $rs1, frmarg:$rm)>;
-def : Pat<(uint_to_fp_mode (i64 GPR:$rs1), (i64 timm:$rm)),
+def : Pat<(riscv_uint_to_fp_round (i64 GPR:$rs1), (i64 timm:$rm)),
(FCVT_S_LU $rs1, frmarg:$rm)>;
} // Predicates = [HasStdExtF, IsRV64]
@@ -879,12 +897,12 @@ def : Pat<(any_sint_to_fp (i64 GPR:$rs1)), (FCVT_S_L_INX $rs1, FRM_DYN)>;
def : Pat<(any_uint_to_fp (i64 GPR:$rs1)), (FCVT_S_LU_INX $rs1, FRM_DYN)>;
// [u]int->fp using static rounding mode.
-def : Pat<(sint_to_fp_mode (i64 (sexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
+def : Pat<(riscv_sint_to_fp_round (i64 (sexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
(FCVT_S_W_INX $rs1, frmarg:$rm)>;
-def : Pat<(uint_to_fp_mode (i64 (zexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
+def : Pat<(riscv_uint_to_fp_round (i64 (zexti32 (i64 GPR:$rs1))), (i64 timm:$rm)),
(FCVT_S_WU_INX $rs1, frmarg:$rm)>;
-def : Pat<(sint_to_fp_mode (i64 GPR:$rs1), (i64 timm:$rm)),
+def : Pat<(riscv_sint_to_fp_round (i64 GPR:$rs1), (i64 timm:$rm)),
(FCVT_S_L_INX $rs1, frmarg:$rm)>;
-def : Pat<(uint_to_fp_mode (i64 GPR:$rs1), (i64 timm:$rm)),
+def : Pat<(riscv_uint_to_fp_round (i64 GPR:$rs1), (i64 timm:$rm)),
(FCVT_S_LU_INX $rs1, frmarg:$rm)>;
} // Predicates = [HasStdExtZfinx, IsRV64]
>From 49b38263593aec94fd000f6c8dc0ebb686b505c6 Mon Sep 17 00:00:00 2001
From: Serge Pavlov <sepavloff at gmail.com>
Date: Thu, 1 Aug 2024 18:08:55 +0700
Subject: [PATCH 3/3] Add check for F/Zfinx
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4dae262551335..934d9104016fc 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21503,6 +21503,8 @@ bool RISCVTargetLowering::preferScalarizeSplat(SDNode *N) const {
bool RISCVTargetLowering::isStaticRoundingSupportedFor(
const Instruction &I) const {
+ if (!Subtarget.hasStdExtFOrZfinx())
+ return false;
if (auto *CI = dyn_cast<ConstrainedFPIntrinsic>(&I)) {
switch (CI->getIntrinsicID()) {
default:
More information about the llvm-commits
mailing list