[llvm] [SelectionDAG][RISCV] Operations with static rounding (PR #100999)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 29 04:28:51 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: Serge Pavlov (spavloff)
<details>
<summary>Changes</summary>
Some targets, including RISC-V, support specifying the rounding mode in the instruction itself rather than in a register. To use this feature, this patch introduces new DAG nodes. They have the same operands as their default-mode counterparts, plus an additional integer operand that specifies the rounding mode to use. Lowering of these DAG nodes is implemented for the F and Zfinx extensions.
---
Patch is 36.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/100999.diff
11 Files Affected:
- (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+14)
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+10)
- (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+25)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+16)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+2)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+62-16)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+9)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+22)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+5)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoF.td (+82)
- (added) llvm/test/CodeGen/RISCV/float-mode.ll (+249)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 5b657fb171296..71a3933250ccb 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -411,6 +411,20 @@ enum NodeType {
STRICT_FREM,
STRICT_FMA,
+ /// Basic floating-point operations with statically specified control modes,
+ /// usually rounding. These have the same operands as the corresponding
+ /// default mode operations with an additional integer operand that represents
+ /// the specified modes in a target-dependent format.
+ FADD_MODE,
+ FSUB_MODE,
+ FMUL_MODE,
+ FDIV_MODE,
+ FSQRT_MODE,
+ FMA_MODE,
+ SINT_TO_FP_MODE,
+ UINT_TO_FP_MODE,
+ FP_ROUND_MODE,
+
/// Constrained versions of libm-equivalent floating point intrinsics.
/// These will be lowered to the equivalent non-constrained pseudo-op
/// (or expanded to the equivalent library call) before final selection.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9d9886f4920a2..d9ba3e3ef9d2f 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -354,6 +354,16 @@ class TargetLoweringBase {
return IsStrictFPEnabled;
}
+ /// Returns true if the target supports static rounding mode for the given
+ /// instruction.
+ virtual bool isStaticRoundingSupportedFor(const Instruction &I) const {
+ return false;
+ }
+
+ /// Returns target-specific representation of the given static rounding mode
+ /// or -1, if this rounding mode is not supported.
+ virtual int getMachineRoundingMode(RoundingMode RM) const { return -1; }
+
protected:
/// Initialize all of the actions to default values.
void initActions();
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 46044aab79a83..b206a9d9df983 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -134,12 +134,18 @@ def SDTIntScaledBinOp : SDTypeProfile<1, 3, [ // smulfix, sdivfix, etc
def SDTFPBinOp : SDTypeProfile<1, 2, [ // fadd, fmul, etc.
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>
]>;
+def SDTFPBinModeOp : SDTypeProfile<1, 3, [ // fadd, fmul, etc. with static modes
+ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>, SDTCisInt<3>
+]>;
def SDTFPSignOp : SDTypeProfile<1, 2, [ // fcopysign.
SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisFP<2>
]>;
def SDTFPTernaryOp : SDTypeProfile<1, 3, [ // fmadd, fnmsub, etc.
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>
]>;
+def SDTFPTernaryModeOp : SDTypeProfile<1, 4, [ // fmadd, fnmsub, etc. with static modes
+ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>, SDTCisInt<4>
+]>;
def SDTIntUnaryOp : SDTypeProfile<1, 1, [ // bitreverse
SDTCisSameAs<0, 1>, SDTCisInt<0>
]>;
@@ -155,9 +161,15 @@ def SDTIntTruncOp : SDTypeProfile<1, 1, [ // trunc
def SDTFPUnaryOp : SDTypeProfile<1, 1, [ // fneg, fsqrt, etc
SDTCisSameAs<0, 1>, SDTCisFP<0>
]>;
+def SDTFPUnaryModeOp : SDTypeProfile<1, 2, [ // fsqrt with static modes
+ SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisInt<2>
+]>;
def SDTFPRoundOp : SDTypeProfile<1, 1, [ // fpround
SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<0, 1>, SDTCisSameNumEltsAs<0, 1>
]>;
+def SDTFPRoundModeOp : SDTypeProfile<1, 2, [ // fpround with static modes
+ SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<0, 1>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<2>
+]>;
def SDTFPExtendOp : SDTypeProfile<1, 1, [ // fpextend
SDTCisFP<0>, SDTCisFP<1>, SDTCisOpSmallerThanOp<1, 0>, SDTCisSameNumEltsAs<0, 1>
]>;
@@ -167,6 +179,9 @@ def SDIsFPClassOp : SDTypeProfile<1, 2, [ // is_fpclass
def SDTIntToFPOp : SDTypeProfile<1, 1, [ // [su]int_to_fp
SDTCisFP<0>, SDTCisInt<1>, SDTCisSameNumEltsAs<0, 1>
]>;
+def SDTIntToFPModeOp : SDTypeProfile<1, 2, [ // [su]int_to_fp with static modes
+ SDTCisFP<0>, SDTCisInt<1>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<2>
+]>;
def SDTFPToIntOp : SDTypeProfile<1, 1, [ // fp_to_[su]int
SDTCisInt<0>, SDTCisFP<1>, SDTCisSameNumEltsAs<0, 1>
]>;
@@ -662,6 +677,16 @@ def strict_fp_to_bf16 : SDNode<"ISD::STRICT_FP_TO_BF16",
def strict_fsetcc : SDNode<"ISD::STRICT_FSETCC", SDTSetCC, [SDNPHasChain]>;
def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
+def fadd_mode : SDNode<"ISD::FADD_MODE" , SDTFPBinModeOp, [SDNPCommutative]>;
+def fsub_mode : SDNode<"ISD::FSUB_MODE" , SDTFPBinModeOp>;
+def fmul_mode : SDNode<"ISD::FMUL_MODE" , SDTFPBinModeOp, [SDNPCommutative]>;
+def fdiv_mode : SDNode<"ISD::FDIV_MODE" , SDTFPBinModeOp>;
+def fsqrt_mode : SDNode<"ISD::FSQRT_MODE" , SDTFPUnaryModeOp>;
+def fma_mode : SDNode<"ISD::FMA_MODE" , SDTFPTernaryModeOp, [SDNPCommutative]>;
+def sint_to_fp_mode: SDNode<"ISD::SINT_TO_FP_MODE" , SDTIntToFPModeOp>;
+def uint_to_fp_mode: SDNode<"ISD::UINT_TO_FP_MODE" , SDTIntToFPModeOp>;
+def fpround_mode : SDNode<"ISD::FP_ROUND_MODE" , SDTFPRoundModeOp>;
+
def get_fpenv : SDNode<"ISD::GET_FPENV", SDTGetFPStateOp, [SDNPHasChain]>;
def set_fpenv : SDNode<"ISD::SET_FPENV", SDTSetFPStateOp, [SDNPHasChain]>;
def reset_fpenv : SDNode<"ISD::RESET_FPENV", SDTNone, [SDNPHasChain]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 33a53dfc81379..d3be39faca2f6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1943,6 +1943,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_SIGN_EXTEND: Res = PromoteIntOp_VP_SIGN_EXTEND(N); break;
case ISD::VP_SINT_TO_FP:
case ISD::SINT_TO_FP: Res = PromoteIntOp_SINT_TO_FP(N); break;
+ case ISD::SINT_TO_FP_MODE: Res = PromoteIntOp_SINT_TO_FP_MODE(N); break;
case ISD::STRICT_SINT_TO_FP: Res = PromoteIntOp_STRICT_SINT_TO_FP(N); break;
case ISD::STORE: Res = PromoteIntOp_STORE(cast<StoreSDNode>(N),
OpNo); break;
@@ -1963,6 +1964,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::FP16_TO_FP:
case ISD::VP_UINT_TO_FP:
case ISD::UINT_TO_FP: Res = PromoteIntOp_UINT_TO_FP(N); break;
+ case ISD::UINT_TO_FP_MODE: Res = PromoteIntOp_UINT_TO_FP_MODE(N); break;
case ISD::STRICT_FP16_TO_FP:
case ISD::STRICT_UINT_TO_FP: Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
case ISD::ZERO_EXTEND: Res = PromoteIntOp_ZERO_EXTEND(N); break;
@@ -2344,6 +2346,20 @@ SDValue DAGTypeLegalizer::PromoteIntOp_SINT_TO_FP(SDNode *N) {
SExtPromotedInteger(N->getOperand(0))), 0);
}
+SDValue DAGTypeLegalizer::PromoteIntOp_SINT_TO_FP_MODE(SDNode *N) {
+ return SDValue(DAG.UpdateNodeOperands(N,
+ SExtPromotedInteger(N->getOperand(0)),
+ N->getOperand(1)),
+ 0);
+}
+
+SDValue DAGTypeLegalizer::PromoteIntOp_UINT_TO_FP_MODE(SDNode *N) {
+ return SDValue(DAG.UpdateNodeOperands(N,
+ ZExtPromotedInteger(N->getOperand(0)),
+ N->getOperand(1)),
+ 0);
+}
+
SDValue DAGTypeLegalizer::PromoteIntOp_STRICT_SINT_TO_FP(SDNode *N) {
return SDValue(DAG.UpdateNodeOperands(N, N->getOperand(0),
SExtPromotedInteger(N->getOperand(1))), 0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index d4e61c8588901..eea01aea42a65 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -402,6 +402,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_SIGN_EXTEND(SDNode *N);
SDValue PromoteIntOp_VP_SIGN_EXTEND(SDNode *N);
SDValue PromoteIntOp_SINT_TO_FP(SDNode *N);
+ SDValue PromoteIntOp_SINT_TO_FP_MODE(SDNode *N);
+ SDValue PromoteIntOp_UINT_TO_FP_MODE(SDNode *N);
SDValue PromoteIntOp_STRICT_SINT_TO_FP(SDNode *N);
SDValue PromoteIntOp_STORE(StoreSDNode *N, unsigned OpNo);
SDValue PromoteIntOp_TRUNCATE(SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9f5e6466309e9..26de077171092 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8131,15 +8131,73 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
void SelectionDAGBuilder::visitConstrainedFPIntrinsic(
const ConstrainedFPIntrinsic &FPI) {
SDLoc sdl = getCurSDLoc();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ EVT VT = TLI.getValueType(DAG.getDataLayout(), FPI.getType());
+ fp::ExceptionBehavior EB = *FPI.getExceptionBehavior();
+ std::optional<RoundingMode> RM = FPI.getRoundingMode();
+
+ SDNodeFlags Flags;
+ if (EB == fp::ExceptionBehavior::ebIgnore)
+ Flags.setNoFPExcept(true);
+
+ if (auto *FPOp = dyn_cast<FPMathOperator>(&FPI))
+ Flags.copyFMF(*FPOp);
+
+ bool UseStaticRounding = EB == fp::ExceptionBehavior::ebIgnore && RM &&
+ *RM != RoundingMode::Dynamic &&
+ TLI.isStaticRoundingSupportedFor(FPI);
+ unsigned Opcode = 0;
+ SmallVector<SDValue, 4> Opers;
+ for (unsigned I = 0, E = FPI.getNonMetadataArgCount(); I != E; ++I)
+ Opers.push_back(getValue(FPI.getArgOperand(I)));
+
+ if (UseStaticRounding) {
+ switch (FPI.getIntrinsicID()) {
+ case Intrinsic::experimental_constrained_fadd:
+ Opcode = ISD::FADD_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fsub:
+ Opcode = ISD::FSUB_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fmul:
+ Opcode = ISD::FMUL_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fdiv:
+ Opcode = ISD::FDIV_MODE;
+ break;
+ case Intrinsic::experimental_constrained_sqrt:
+ Opcode = ISD::FSQRT_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fma:
+ Opcode = ISD::FMA_MODE;
+ break;
+ case Intrinsic::experimental_constrained_sitofp:
+ Opcode = ISD::SINT_TO_FP_MODE;
+ break;
+ case Intrinsic::experimental_constrained_uitofp:
+ Opcode = ISD::UINT_TO_FP_MODE;
+ break;
+ case Intrinsic::experimental_constrained_fptrunc:
+ Opcode = ISD::FP_ROUND_MODE;
+ break;
+ }
+ if (Opcode) {
+ int MachineRM = TLI.getMachineRoundingMode(*RM);
+ assert(MachineRM >= 0 && "Unsupported rounding mode");
+ EVT RMType = TLI.getTypeToTransformTo(*DAG.getContext(), MVT::i32);
+ Opers.push_back(DAG.getConstant(static_cast<uint64_t>(MachineRM), sdl,
+ RMType, true));
+ SDValue Result = DAG.getNode(Opcode, sdl, VT, Opers, Flags);
+ setValue(&FPI, Result);
+ return;
+ }
+ }
// We do not need to serialize constrained FP intrinsics against
// each other or against (nonvolatile) loads, so they can be
// chained like loads.
SDValue Chain = DAG.getRoot();
- SmallVector<SDValue, 4> Opers;
- Opers.push_back(Chain);
- for (unsigned I = 0, E = FPI.getNonMetadataArgCount(); I != E; ++I)
- Opers.push_back(getValue(FPI.getArgOperand(I)));
+ Opers.insert(Opers.begin(), Chain);
auto pushOutChain = [this](SDValue Result, fp::ExceptionBehavior EB) {
assert(Result.getNode()->getNumValues() == 2);
@@ -8167,19 +8225,7 @@ void SelectionDAGBuilder::visitConstrainedFPIntrinsic(
}
};
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- EVT VT = TLI.getValueType(DAG.getDataLayout(), FPI.getType());
SDVTList VTs = DAG.getVTList(VT, MVT::Other);
- fp::ExceptionBehavior EB = *FPI.getExceptionBehavior();
-
- SDNodeFlags Flags;
- if (EB == fp::ExceptionBehavior::ebIgnore)
- Flags.setNoFPExcept(true);
-
- if (auto *FPOp = dyn_cast<FPMathOperator>(&FPI))
- Flags.copyFMF(*FPOp);
-
- unsigned Opcode;
switch (FPI.getIntrinsicID()) {
default: llvm_unreachable("Impossible intrinsic"); // Can't reach here.
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 16fc52caebb75..30779a9a38379 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -206,6 +206,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FNEG: return "fneg";
case ISD::FSQRT: return "fsqrt";
case ISD::STRICT_FSQRT: return "strict_fsqrt";
+ case ISD::FSQRT_MODE: return "fsqrt_mode";
case ISD::FCBRT: return "fcbrt";
case ISD::FSIN: return "fsin";
case ISD::STRICT_FSIN: return "strict_fsin";
@@ -284,14 +285,19 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FSHR: return "fshr";
case ISD::FADD: return "fadd";
case ISD::STRICT_FADD: return "strict_fadd";
+ case ISD::FADD_MODE: return "fadd_mode";
case ISD::FSUB: return "fsub";
case ISD::STRICT_FSUB: return "strict_fsub";
+ case ISD::FSUB_MODE: return "fsub_mode";
case ISD::FMUL: return "fmul";
case ISD::STRICT_FMUL: return "strict_fmul";
+ case ISD::FMUL_MODE: return "fmul_mode";
case ISD::FDIV: return "fdiv";
case ISD::STRICT_FDIV: return "strict_fdiv";
+ case ISD::FDIV_MODE: return "fdiv_mode";
case ISD::FMA: return "fma";
case ISD::STRICT_FMA: return "strict_fma";
+ case ISD::FMA_MODE: return "fma_mode";
case ISD::FMAD: return "fmad";
case ISD::FREM: return "frem";
case ISD::STRICT_FREM: return "strict_frem";
@@ -382,13 +388,16 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::TRUNCATE: return "truncate";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
+ case ISD::FP_ROUND_MODE: return "fp_round_mode";
case ISD::FP_EXTEND: return "fp_extend";
case ISD::STRICT_FP_EXTEND: return "strict_fp_extend";
case ISD::SINT_TO_FP: return "sint_to_fp";
case ISD::STRICT_SINT_TO_FP: return "strict_sint_to_fp";
+ case ISD::SINT_TO_FP_MODE: return "sint_to_fp_mode";
case ISD::UINT_TO_FP: return "uint_to_fp";
case ISD::STRICT_UINT_TO_FP: return "strict_uint_to_fp";
+ case ISD::UINT_TO_FP_MODE: return "uint_to_fp_mode";
case ISD::FP_TO_SINT: return "fp_to_sint";
case ISD::STRICT_FP_TO_SINT: return "strict_fp_to_sint";
case ISD::FP_TO_UINT: return "fp_to_uint";
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9ce669a3122f5..c77e284d5fcc7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -434,6 +434,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
ISD::STRICT_FSQRT, ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS};
+ static const unsigned FPStaticRoundNodes[] = {ISD::FADD_MODE, ISD::FSUB_MODE,
+ ISD::FMUL_MODE, ISD::FDIV_MODE,
+ ISD::FSQRT_MODE, ISD::FMA_MODE};
+
static const ISD::CondCode FPCCToExpand[] = {
ISD::SETOGT, ISD::SETOGE, ISD::SETONE, ISD::SETUEQ, ISD::SETUGT,
ISD::SETUGE, ISD::SETULT, ISD::SETULE, ISD::SETUNE, ISD::SETGT,
@@ -526,6 +530,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.hasStdExtFOrZfinx()) {
setOperationAction(FPLegalNodeTypes, MVT::f32, Legal);
+ setOperationAction(FPStaticRoundNodes, MVT::f32, Legal);
setOperationAction(FPRndMode, MVT::f32,
Subtarget.hasStdExtZfa() ? Legal : Custom);
setCondCodeAction(FPCCToExpand, MVT::f32, Expand);
@@ -21427,6 +21432,23 @@ bool RISCVTargetLowering::preferScalarizeSplat(SDNode *N) const {
return true;
}
+int RISCVTargetLowering::getMachineRoundingMode(RoundingMode RM) const {
+ switch (RM) {
+ case RoundingMode::TowardZero:
+ return RISCVFPRndMode::RTZ;
+ case RoundingMode::NearestTiesToEven:
+ return RISCVFPRndMode::RNE;
+ case RoundingMode::TowardNegative:
+ return RISCVFPRndMode::RDN;
+ case RoundingMode::TowardPositive:
+ return RISCVFPRndMode::RUP;
+ case RoundingMode::NearestTiesToAway:
+ return RISCVFPRndMode::RMM;
+ default:
+ return -1;
+ }
+}
+
static Value *useTpOffset(IRBuilderBase &IRB, unsigned Offset) {
Module *M = IRB.GetInsertBlock()->getParent()->getParent();
Function *ThreadPointerFunc =
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 498c77f1875ed..b541472304d22 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -535,6 +535,11 @@ class RISCVTargetLowering : public TargetLowering {
bool softPromoteHalfType() const override { return true; }
+ bool isStaticRoundingSupportedFor(const Instruction &I) const override {
+ return true;
+ }
+ int getMachineRoundingMode(RoundingMode RM) const override;
+
/// Return the register type for a given MVT, ensuring vectors are treated
/// as a series of gpr sized integers.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
index e6c25e0844fb2..e8c3bb0da192f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
@@ -77,6 +77,12 @@ def any_fma_nsz : PatFrag<(ops node:$rs1, node:$rs2, node:$rs3),
(any_fma node:$rs1, node:$rs2, node:$rs3), [{
return N->getFlags().hasNoSignedZeros();
}]>;
+
+def fma_mode_nsz : PatFrag<(ops node:$rs1, node:$rs2, node:$rs3, node:$rm),
+ (fma_mode node:$rs1, node:$rs2, node:$rs3, node:$rm), [{
+ return N->getFlags().hasNoSignedZeros();
+}]>;
+
//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
//===----------------------------------------------------------------------===//
@@ -513,6 +519,18 @@ multiclass PatFprFprDynFrm_m<SDPatternOperator OpNode, RVInstRFrm Inst,
Ext.PrimaryTy, Ext.PrimaryVT>;
}
+class PatFprFprModeFrm<SDPatternOperator OpNode, RVInstRFrm Inst,
+ DAGOperand RegTy, ValueType vt>
+ : Pat<(OpNode (vt RegTy:$rs1), (vt RegTy:$rs2), (XLenVT timm:$rm)),
+ (Inst $rs1, $rs2, $rm)>;
+multiclass PatFprFprModeFrm_m<SDPatternOperator OpNode, RVInstRFrm Inst,
+ ExtInfo Ext> {
+ let Predicates = Ext.Predicates in
+ def Ext.Suffix : PatFprFprModeFrm<OpNode,
+ !cast<RVInstRFrm>(Inst#Ext.Suffix),
+ Ext.PrimaryTy, Ext.PrimaryVT>;
+}
+
/// Float conversion operations
// [u]int32<->float conversion patterns must be gated on IsRV32 or IsRV64, so
@@ -526,8 +544,17 @@ foreach Ext = FExts in {
defm : PatFprFprDynFrm_m<any_fdiv, FDIV_S, Ext>;
}
+foreach Ext = FExts in {
+ defm : PatFprFprModeFrm_m<fadd_mode, FADD_S, Ext>;
+ defm : PatFprFprModeFrm_m<fsub_mode, FSUB_S, Ext>;
+ defm : PatFprFprModeFrm_m<fmul_mode, FMUL_S, Ext>;
+ defm : PatFprFprModeFrm_m<fdiv_mode, FDIV_S, Ext>;
+}
+
let Predicates = [HasStdExtF] in {
def : Pat<(any_fsqrt FPR32:$rs1), (FSQRT_S FPR32:$rs1, FRM_DYN)...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/100999
More information about the llvm-commits
mailing list