[llvm] [RISCV] Merge GPRPair and GPRF64Pair (PR #116094)
Sam Elliott via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 19 05:35:17 PST 2024
https://github.com/lenary updated https://github.com/llvm/llvm-project/pull/116094
>From 95fc8c4072ea251ae8fee997b1b370bd5f0ef8e8 Mon Sep 17 00:00:00 2001
From: Sam Elliott <quic_aelliott at quicinc.com>
Date: Wed, 13 Nov 2024 11:07:52 -0800
Subject: [PATCH] [WIP][RISCV] Merge GPRPair and GPRF64Pair
---
.../Target/RISCV/AsmParser/RISCVAsmParser.cpp | 10 +---
llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 16 +++---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 51 ++++++++++++-------
llvm/lib/Target/RISCV/RISCVInstrInfoD.td | 18 +++----
llvm/lib/Target/RISCV/RISCVRegisterInfo.td | 22 ++------
.../CodeGen/RISCV/zdinx-asm-constraint.ll | 36 +++++++++++++
6 files changed, 90 insertions(+), 63 deletions(-)
diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
index 1b23b36a59e0ec..b843bb5ae43100 100644
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -497,16 +497,10 @@ struct RISCVOperand final : public MCParsedAsmOperand {
RISCVMCRegisterClasses[RISCV::GPRF32RegClassID].contains(Reg.RegNum);
}
- bool isGPRF64Pair() const {
- return Kind == KindTy::Register &&
- RISCVMCRegisterClasses[RISCV::GPRF64PairRegClassID].contains(
- Reg.RegNum);
- }
-
bool isGPRAsFPR() const { return isGPR() && Reg.IsGPRAsFPR; }
bool isGPRAsFPR16() const { return isGPRF16() && Reg.IsGPRAsFPR; }
bool isGPRAsFPR32() const { return isGPRF32() && Reg.IsGPRAsFPR; }
- bool isGPRPairAsFPR64() const { return isGPRF64Pair() && Reg.IsGPRAsFPR; }
+ bool isGPRPairAsFPR64() const { return isGPRPair() && Reg.IsGPRAsFPR; }
static bool evaluateConstantImm(const MCExpr *Expr, int64_t &Imm,
RISCVMCExpr::VariantKind &VK) {
@@ -2405,7 +2399,7 @@ ParseStatus RISCVAsmParser::parseGPRPairAsFPR64(OperandVector &Operands) {
const MCRegisterInfo *RI = getContext().getRegisterInfo();
MCRegister Pair = RI->getMatchingSuperReg(
Reg, RISCV::sub_gpr_even,
- &RISCVMCRegisterClasses[RISCV::GPRF64PairRegClassID]);
+ &RISCVMCRegisterClasses[RISCV::GPRPairRegClassID]);
Operands.push_back(RISCVOperand::createReg(Pair, S, E, /*isGPRAsFPR=*/true));
return ParseStatus::Success;
}
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index ca368a18c80d64..bd450aba249c65 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -958,20 +958,14 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
assert((!Subtarget->is64Bit() || Opcode == RISCVISD::BuildGPRPair) &&
"BuildPairF64 only handled here on rv32i_zdinx");
- int RegClassID = (Opcode == RISCVISD::BuildGPRPair)
- ? RISCV::GPRPairRegClassID
- : RISCV::GPRF64PairRegClassID;
- MVT OutType = (Opcode == RISCVISD::BuildGPRPair) ? MVT::Untyped : MVT::f64;
-
SDValue Ops[] = {
- CurDAG->getTargetConstant(RegClassID, DL, MVT::i32),
+ CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32),
Node->getOperand(0),
CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32),
Node->getOperand(1),
CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)};
- SDNode *N =
- CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, OutType, Ops);
+ SDNode *N = CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, VT, Ops);
ReplaceNode(Node, N);
return;
}
@@ -982,13 +976,15 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
"SplitF64 only handled here on rv32i_zdinx");
if (!SDValue(Node, 0).use_empty()) {
- SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL, VT,
+ SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL,
+ Node->getSimpleValueType(0),
Node->getOperand(0));
ReplaceUses(SDValue(Node, 0), Lo);
}
if (!SDValue(Node, 1).use_empty()) {
- SDValue Hi = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_odd, DL, VT,
+ SDValue Hi = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_odd, DL,
+ Node->getSimpleValueType(1),
Node->getOperand(0));
ReplaceUses(SDValue(Node, 1), Hi);
}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 649bff2869e24d..976b2478b433e5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -133,7 +133,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.is64Bit())
addRegisterClass(MVT::f64, &RISCV::GPRRegClass);
else
- addRegisterClass(MVT::f64, &RISCV::GPRF64PairRegClass);
+ addRegisterClass(MVT::f64, &RISCV::GPRPairRegClass);
}
static const MVT::SimpleValueType BoolVecVTs[] = {
@@ -20507,7 +20507,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32NoX0RegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
- return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
+ return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
case 'f':
if (VT == MVT::f16) {
@@ -20524,14 +20524,14 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (Subtarget.hasStdExtD())
return std::make_pair(0U, &RISCV::FPR64RegClass);
if (Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
- return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
+ return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
if (Subtarget.hasStdExtZdinx() && Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
}
break;
case 'R':
if (VT == MVT::f64 && !Subtarget.is64Bit() && Subtarget.hasStdExtZdinx())
- return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
+ return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
default:
break;
@@ -20570,7 +20570,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32CRegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
- return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
+ return std::make_pair(0U, &RISCV::GPRPairCRegClass);
if (!VT.isVector())
return std::make_pair(0U, &RISCV::GPRCRegClass);
} else if (Constraint == "cf") {
@@ -20588,7 +20588,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (Subtarget.hasStdExtD())
return std::make_pair(0U, &RISCV::FPR64CRegClass);
if (Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
- return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
+ return std::make_pair(0U, &RISCV::GPRPairCRegClass);
if (Subtarget.hasStdExtZdinx() && Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRCRegClass);
}
@@ -20752,7 +20752,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
// Subtarget into account.
if (Res.second == &RISCV::GPRF16RegClass ||
Res.second == &RISCV::GPRF32RegClass ||
- Res.second == &RISCV::GPRF64PairRegClass)
+ Res.second == &RISCV::GPRPairRegClass)
return std::make_pair(Res.first, &RISCV::GPRRegClass);
return Res;
@@ -21379,12 +21379,19 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
bool IsABIRegCopy = CC.has_value();
EVT ValueVT = Val.getValueType();
- if (ValueVT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
+ MVT PairVT = Subtarget.is64Bit() ? MVT::i128 : MVT::i64;
+ if ((ValueVT == PairVT ||
+ (!Subtarget.is64Bit() && Subtarget.hasStdExtZdinx() &&
+ ValueVT == MVT::f64)) &&
NumParts == 1 && PartVT == MVT::Untyped) {
- // Pairs in Inline Assembly
+ // Pairs in Inline Assembly, and f64 in Inline Assembly on rv32i_zdinx
MVT XLenVT = Subtarget.getXLenVT();
+ if (ValueVT == MVT::f64)
+ Val = DAG.getBitcast(MVT::i64, Val);
auto [Lo, Hi] = DAG.SplitScalar(Val, DL, XLenVT, XLenVT);
- Parts[0] = DAG.getNode(RISCVISD::BuildGPRPair, DL, MVT::Untyped, Lo, Hi);
+ // Always creating an MVT::Untyped part, so always use
+ // RISCVISD::BuildGPRPair.
+ Parts[0] = DAG.getNode(RISCVISD::BuildGPRPair, DL, PartVT, Lo, Hi);
return true;
}
@@ -21396,7 +21403,7 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
DAG.getConstant(0xFFFF0000, DL, MVT::i32));
- Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
+ Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
Parts[0] = Val;
return true;
}
@@ -21465,14 +21472,24 @@ SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();
- if (ValueVT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
+ MVT PairVT = Subtarget.is64Bit() ? MVT::i128 : MVT::i64;
+ if ((ValueVT == PairVT ||
+ (!Subtarget.is64Bit() && Subtarget.hasStdExtZdinx() &&
+ ValueVT == MVT::f64)) &&
NumParts == 1 && PartVT == MVT::Untyped) {
- // Pairs in Inline Assembly
+ // Pairs in Inline Assembly, and f64 in Inline Assembly on rv32i_zdinx
MVT XLenVT = Subtarget.getXLenVT();
- SDValue Res = DAG.getNode(RISCVISD::SplitGPRPair, DL,
- DAG.getVTList(XLenVT, XLenVT), Parts[0]);
- return DAG.getNode(ISD::BUILD_PAIR, DL, ValueVT, Res.getValue(0),
- Res.getValue(1));
+
+ SDValue Val = Parts[0];
+ // Always starting with an MVT::Untyped part, so always use
+ // RISCVISD::SplitGPRPair
+ Val = DAG.getNode(RISCVISD::SplitGPRPair, DL, DAG.getVTList(XLenVT, XLenVT),
+ Val);
+ Val = DAG.getNode(ISD::BUILD_PAIR, DL, PairVT, Val.getValue(0),
+ Val.getValue(1));
+ if (ValueVT == MVT::f64)
+ Val = DAG.getBitcast(ValueVT, Val);
+ return Val;
}
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
index 3c043c3d3864b5..b01af468d9ea2b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
@@ -52,7 +52,7 @@ def FPR64INX : RegisterOperand<GPR> {
let DecoderMethod = "DecodeGPRRegisterClass";
}
-def FPR64IN32X : RegisterOperand<GPRF64Pair> {
+def FPR64IN32X : RegisterOperand<GPRPair> {
let ParserMatchClass = GPRPairAsFPR;
}
@@ -457,16 +457,16 @@ def : PatSetCC<FPR64INX, any_fsetccs, SETOLE, FLE_D_INX, f64>;
let Predicates = [HasStdExtZdinx, IsRV32] in {
// Match signaling FEQ_D
-def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs2, SETEQ)),
+def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs2, SETEQ)),
(AND (XLenVT (FLE_D_IN32X $rs1, $rs2)),
(XLenVT (FLE_D_IN32X $rs2, $rs1)))>;
-def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs2, SETOEQ)),
+def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs2, SETOEQ)),
(AND (XLenVT (FLE_D_IN32X $rs1, $rs2)),
(XLenVT (FLE_D_IN32X $rs2, $rs1)))>;
// If both operands are the same, use a single FLE.
-def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs1, SETEQ)),
+def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs1, SETEQ)),
(FLE_D_IN32X $rs1, $rs1)>;
-def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs1, SETOEQ)),
+def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs1, SETOEQ)),
(FLE_D_IN32X $rs1, $rs1)>;
def : PatSetCC<FPR64IN32X, any_fsetccs, SETLT, FLT_D_IN32X, f64>;
@@ -523,15 +523,15 @@ def PseudoFROUND_D_IN32X : PseudoFROUND<FPR64IN32X, f64>;
/// Loads
let isCall = 0, mayLoad = 1, mayStore = 0, Size = 8, isCodeGenOnly = 1 in
-def PseudoRV32ZdinxLD : Pseudo<(outs GPRF64Pair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
+def PseudoRV32ZdinxLD : Pseudo<(outs GPRPair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
def : Pat<(f64 (load (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12))),
(PseudoRV32ZdinxLD GPR:$rs1, simm12:$imm12)>;
/// Stores
let isCall = 0, mayLoad = 0, mayStore = 1, Size = 8, isCodeGenOnly = 1 in
-def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRF64Pair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
-def : Pat<(store (f64 GPRF64Pair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
- (PseudoRV32ZdinxSD GPRF64Pair:$rs2, GPR:$rs1, simm12:$imm12)>;
+def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
+def : Pat<(store (f64 GPRPair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
+ (PseudoRV32ZdinxSD GPRPair:$rs2, GPR:$rs1, simm12:$imm12)>;
} // Predicates = [HasStdExtZdinx, IsRV32]
let Predicates = [HasStdExtD, IsRV32] in {
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
index e0687b90ad17fe..7eb93973459c0d 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -325,7 +325,7 @@ let RegAltNameIndices = [ABIRegAltName] in {
let RegInfos = XLenPairRI,
DecoderMethod = "DecodeGPRPairRegisterClass" in {
-def GPRPair : RISCVRegisterClass<[XLenPairVT], 64, (add
+def GPRPair : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X16_X17,
X6_X7,
X28_X29, X30_X31,
@@ -334,11 +334,11 @@ def GPRPair : RISCVRegisterClass<[XLenPairVT], 64, (add
X0_Pair, X2_X3, X4_X5
)>;
-def GPRPairNoX0 : RISCVRegisterClass<[XLenPairVT], 64, (sub GPRPair, X0_Pair)>;
+def GPRPairNoX0 : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (sub GPRPair, X0_Pair)>;
} // let RegInfos = XLenPairRI, DecoderMethod = "DecodeGPRPairRegisterClass"
let RegInfos = XLenPairRI in
-def GPRPairC : RISCVRegisterClass<[XLenPairVT], 64, (add
+def GPRPairC : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X8_X9
)>;
@@ -464,22 +464,6 @@ def GPRF32C : RISCVRegisterClass<[f32], 32, (add (sequence "X%u_W", 10, 15),
(sequence "X%u_W", 8, 9))>;
def GPRF32NoX0 : RISCVRegisterClass<[f32], 32, (sub GPRF32, X0_W)>;
-let DecoderMethod = "DecodeGPRPairRegisterClass" in
-def GPRF64Pair : RISCVRegisterClass<[XLenPairFVT], 64, (add
- X10_X11, X12_X13, X14_X15, X16_X17,
- X6_X7,
- X28_X29, X30_X31,
- X8_X9,
- X18_X19, X20_X21, X22_X23, X24_X25, X26_X27,
- X0_Pair, X2_X3, X4_X5
-)>;
-
-def GPRF64PairC : RISCVRegisterClass<[XLenPairFVT], 64, (add
- X10_X11, X12_X13, X14_X15, X8_X9
-)>;
-
-def GPRF64PairNoX0 : RISCVRegisterClass<[XLenPairFVT], 64, (sub GPRF64Pair, X0_Pair)>;
-
//===----------------------------------------------------------------------===//
// Vector type mapping to LLVM types.
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll b/llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll
index fed8c59ad14e07..81a8a8065e6b61 100644
--- a/llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll
+++ b/llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll
@@ -46,6 +46,42 @@ entry:
ret void
}
+define dso_local void @zdinx_asm_inout(ptr nocapture noundef writeonly %a, double noundef %b) nounwind {
+; CHECK-LABEL: zdinx_asm_inout:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: mv a3, a2
+; CHECK-NEXT: mv a2, a1
+; CHECK-NEXT: #APP
+; CHECK-NEXT: fmv.d a2, a2
+; CHECK-NEXT: #NO_APP
+; CHECK-NEXT: sw a2, 8(a0)
+; CHECK-NEXT: sw a3, 12(a0)
+; CHECK-NEXT: ret
+entry:
+ %arrayidx = getelementptr inbounds double, ptr %a, i32 1
+ %0 = tail call double asm "fsgnj.d $0, $1, $1", "=r,0"(double %b)
+ store double %0, ptr %arrayidx, align 8
+ ret void
+}
+
+define dso_local void @zdinx_asm_Pr_inout(ptr nocapture noundef writeonly %a, double noundef %b) nounwind {
+; CHECK-LABEL: zdinx_asm_Pr_inout:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: mv a3, a2
+; CHECK-NEXT: mv a2, a1
+; CHECK-NEXT: #APP
+; CHECK-NEXT: fabs.d a2, a2
+; CHECK-NEXT: #NO_APP
+; CHECK-NEXT: sw a2, 8(a0)
+; CHECK-NEXT: sw a3, 12(a0)
+; CHECK-NEXT: ret
+entry:
+ %arrayidx = getelementptr inbounds double, ptr %a, i32 1
+ %0 = tail call double asm "fsgnjx.d $0, $1, $1", "=R,0"(double %b)
+ store double %0, ptr %arrayidx, align 8
+ ret void
+}
+
define dso_local void @zfinx_asm(ptr nocapture noundef writeonly %a, float noundef %b, float noundef %c) nounwind {
; CHECK-LABEL: zfinx_asm:
; CHECK: # %bb.0: # %entry
More information about the llvm-commits
mailing list