[clang] [llvm] [RISCV] Inline Assembly Support for GPR Pairs ('Pr') (PR #112983)

Sam Elliott via cfe-commits cfe-commits at lists.llvm.org
Wed Nov 13 08:37:59 PST 2024


https://github.com/lenary updated https://github.com/llvm/llvm-project/pull/112983

>From 0b98a56337d3210e82cac0f509eb7d3d547083f9 Mon Sep 17 00:00:00 2001
From: Sam Elliott <quic_aelliott at quicinc.com>
Date: Tue, 22 Oct 2024 12:37:48 -0700
Subject: [PATCH] [RISCV] GPR Pairs for Inline Asm using `Pr`

This patch adds support for getting even-odd general purpose register
pairs into and out of inline assembly using the `Pr` constraint as
proposed in riscv-non-isa/riscv-c-api-doc#92

There are a few different pieces to this patch, each of which needs its
own explanation.

- Renames the Register Class used for f64 values on rv32i_zdinx from
  `GPRPair*` to `GPRF64Pair*`. These register classes are kept broadly
  unmodified, as their primary value type is used for type inference
  over selection patterns. This rename affects quite a lot of files.

- Adds new `GPRPair*` register classes which will be used for `Pr`
  constraints and for instructions that need an even-odd GPR pair. This
  new type is used for `amocas.d.*`(rv32) and `amocas.q.*`(rv64) in
  Zacas, instead of the `GPRF64Pair` class being used before.

- Marks the new `GPRPair` class as legal for holding an `MVT::Untyped`.
  Two new RISCVISD node types are added for creating and destructing a
  pair - `BuildGPRPair` and `SplitGPRPair` - and are introduced when
  bitcasting between the pair type and `untyped`.

- Adds functionality to `splitValueIntoRegisterParts` and
  `joinRegisterPartsIntoValue` to handle changing `i<2*xlen>` MVTs into
  `untyped` pairs.

- Adds an override for `getNumRegisters` to ensure that `i<2*xlen>`
  values, when going to/from inline assembly, only allocate one (pair)
  register (they would otherwise allocate two). This is due to a bug in
  SelectionDAGBuilder.cpp which other backends also work around.

- Ensures that Clang understands that `Pr` is a valid inline assembly
  constraint.

- Adds conditions to the GPRF64Pair-related changes to `LowerOperation`
  and `ReplaceNodeResults` which match when BITCAST for the relevant
  types should be handled in a custom manner.

- This also allows `Pr` to be used for `f64` types on `rv32_zdinx`
  architectures, where doubles are stored in a GPR pair.
---
 clang/lib/Basic/Targets/RISCV.cpp             | 11 ++-
 clang/test/CodeGen/RISCV/riscv-inline-asm.c   | 13 ++++
 .../Target/RISCV/AsmParser/RISCVAsmParser.cpp | 22 ++++--
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp   | 24 ++++--
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 58 +++++++++++++--
 llvm/lib/Target/RISCV/RISCVISelLowering.h     | 17 +++++
 llvm/lib/Target/RISCV/RISCVInstrInfoD.td      | 12 +--
 llvm/lib/Target/RISCV/RISCVRegisterInfo.td    | 23 +++++-
 .../CodeGen/RISCV/rv32-inline-asm-pairs.ll    | 73 +++++++++++++++++++
 .../CodeGen/RISCV/rv64-inline-asm-pairs.ll    | 73 +++++++++++++++++++
 .../CodeGen/RISCV/zdinx-asm-constraint.ll     | 26 +++++++
 11 files changed, 320 insertions(+), 32 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/rv32-inline-asm-pairs.ll
 create mode 100644 llvm/test/CodeGen/RISCV/rv64-inline-asm-pairs.ll

diff --git a/clang/lib/Basic/Targets/RISCV.cpp b/clang/lib/Basic/Targets/RISCV.cpp
index eaaba7642bd7b2..07bf002ed73928 100644
--- a/clang/lib/Basic/Targets/RISCV.cpp
+++ b/clang/lib/Basic/Targets/RISCV.cpp
@@ -108,6 +108,14 @@ bool RISCVTargetInfo::validateAsmConstraint(
       return true;
     }
     return false;
+  case 'P':
+    // An even-odd register pair - GPR
+    if (Name[1] == 'r') {
+      Info.setAllowsRegister();
+      Name += 1;
+      return true;
+    }
+    return false;
   case 'v':
     // A vector register.
     if (Name[1] == 'r' || Name[1] == 'd' || Name[1] == 'm') {
@@ -122,8 +130,9 @@ bool RISCVTargetInfo::validateAsmConstraint(
 std::string RISCVTargetInfo::convertConstraint(const char *&Constraint) const {
   std::string R;
   switch (*Constraint) {
-  // c* and v* are two-letter constraints on RISC-V.
+  // c*, P*, and v* are all two-letter constraints on RISC-V.
   case 'c':
+  case 'P':
   case 'v':
     R = std::string("^") + std::string(Constraint, 2);
     Constraint += 1;
diff --git a/clang/test/CodeGen/RISCV/riscv-inline-asm.c b/clang/test/CodeGen/RISCV/riscv-inline-asm.c
index 75b91d3c497c50..eb6e42f3eb9529 100644
--- a/clang/test/CodeGen/RISCV/riscv-inline-asm.c
+++ b/clang/test/CodeGen/RISCV/riscv-inline-asm.c
@@ -33,6 +33,19 @@ void test_cf(float f, double d) {
   asm volatile("" : "=cf"(cd) : "cf"(d));
 }
 
+#if __riscv_xlen == 32
+typedef long long double_xlen_t;
+#elif __riscv_xlen == 64
+typedef __int128_t double_xlen_t;
+#endif
+double_xlen_t test_Pr_wide_scalar(double_xlen_t p) {
+// CHECK-LABEL: define{{.*}} {{i128|i64}} @test_Pr_wide_scalar(
+// CHECK: call {{i128|i64}} asm sideeffect "", "=^Pr,^Pr"({{i128|i64}} %{{.*}})
+  double_xlen_t ret;
+  asm volatile("" : "=Pr"(ret) : "Pr"(p));
+  return ret;
+}
+
 void test_I(void) {
 // CHECK-LABEL: define{{.*}} void @test_I()
 // CHECK: call void asm sideeffect "", "I"(i32 2047)
diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
index 4d46afb8c4ef97..1b23b36a59e0ec 100644
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -481,6 +481,12 @@ struct RISCVOperand final : public MCParsedAsmOperand {
            RISCVMCRegisterClasses[RISCV::GPRRegClassID].contains(Reg.RegNum);
   }
 
+  bool isGPRPair() const {
+    return Kind == KindTy::Register &&
+           RISCVMCRegisterClasses[RISCV::GPRPairRegClassID].contains(
+               Reg.RegNum);
+  }
+
   bool isGPRF16() const {
     return Kind == KindTy::Register &&
            RISCVMCRegisterClasses[RISCV::GPRF16RegClassID].contains(Reg.RegNum);
@@ -491,17 +497,17 @@ struct RISCVOperand final : public MCParsedAsmOperand {
            RISCVMCRegisterClasses[RISCV::GPRF32RegClassID].contains(Reg.RegNum);
   }
 
-  bool isGPRAsFPR() const { return isGPR() && Reg.IsGPRAsFPR; }
-  bool isGPRAsFPR16() const { return isGPRF16() && Reg.IsGPRAsFPR; }
-  bool isGPRAsFPR32() const { return isGPRF32() && Reg.IsGPRAsFPR; }
-  bool isGPRPairAsFPR() const { return isGPRPair() && Reg.IsGPRAsFPR; }
-
-  bool isGPRPair() const {
+  bool isGPRF64Pair() const {
     return Kind == KindTy::Register &&
-           RISCVMCRegisterClasses[RISCV::GPRPairRegClassID].contains(
+           RISCVMCRegisterClasses[RISCV::GPRF64PairRegClassID].contains(
                Reg.RegNum);
   }
 
+  bool isGPRAsFPR() const { return isGPR() && Reg.IsGPRAsFPR; }
+  bool isGPRAsFPR16() const { return isGPRF16() && Reg.IsGPRAsFPR; }
+  bool isGPRAsFPR32() const { return isGPRF32() && Reg.IsGPRAsFPR; }
+  bool isGPRPairAsFPR64() const { return isGPRF64Pair() && Reg.IsGPRAsFPR; }
+
   static bool evaluateConstantImm(const MCExpr *Expr, int64_t &Imm,
                                   RISCVMCExpr::VariantKind &VK) {
     if (auto *RE = dyn_cast<RISCVMCExpr>(Expr)) {
@@ -2399,7 +2405,7 @@ ParseStatus RISCVAsmParser::parseGPRPairAsFPR64(OperandVector &Operands) {
   const MCRegisterInfo *RI = getContext().getRegisterInfo();
   MCRegister Pair = RI->getMatchingSuperReg(
       Reg, RISCV::sub_gpr_even,
-      &RISCVMCRegisterClasses[RISCV::GPRPairRegClassID]);
+      &RISCVMCRegisterClasses[RISCV::GPRF64PairRegClassID]);
   Operands.push_back(RISCVOperand::createReg(Pair, S, E, /*isGPRAsFPR=*/true));
   return ParseStatus::Success;
 }
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index a1b74faf17fab3..034314c88f79f0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -952,27 +952,36 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     ReplaceNode(Node, Res);
     return;
   }
+  case RISCVISD::BuildGPRPair:
   case RISCVISD::BuildPairF64: {
-    if (!Subtarget->hasStdExtZdinx())
+    if (Opcode == RISCVISD::BuildPairF64 && !Subtarget->hasStdExtZdinx())
       break;
 
-    assert(!Subtarget->is64Bit() && "Unexpected subtarget");
+    assert((!Subtarget->is64Bit() || Opcode == RISCVISD::BuildGPRPair) &&
+           "BuildPairF64 only handled here on rv32i_zdinx");
+
+    int RegClassID = (Opcode == RISCVISD::BuildGPRPair)
+                         ? RISCV::GPRPairRegClassID
+                         : RISCV::GPRF64PairRegClassID;
+    MVT OutType = (Opcode == RISCVISD::BuildGPRPair) ? MVT::Untyped : MVT::f64;
 
     SDValue Ops[] = {
-        CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32),
+        CurDAG->getTargetConstant(RegClassID, DL, MVT::i32),
         Node->getOperand(0),
         CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32),
         Node->getOperand(1),
         CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)};
 
     SDNode *N =
-        CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, MVT::f64, Ops);
+        CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, OutType, Ops);
     ReplaceNode(Node, N);
     return;
   }
+  case RISCVISD::SplitGPRPair:
   case RISCVISD::SplitF64: {
-    if (Subtarget->hasStdExtZdinx()) {
-      assert(!Subtarget->is64Bit() && "Unexpected subtarget");
+    if (Subtarget->hasStdExtZdinx() || Opcode != RISCVISD::SplitF64) {
+      assert((!Subtarget->is64Bit() || Opcode == RISCVISD::SplitGPRPair) &&
+             "SplitF64 only handled here on rv32i_zdinx");
 
       if (!SDValue(Node, 0).use_empty()) {
         SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL, VT,
@@ -990,6 +999,9 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       return;
     }
 
+    assert(Opcode != RISCVISD::SplitGPRPair &&
+           "SplitGPRPair should already be handled");
+
     if (!Subtarget->hasStdExtZfa())
       break;
     assert(Subtarget->hasStdExtD() && !Subtarget->is64Bit() &&
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3df8eca8cae7fb..b6e66a3f7c5fa1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -133,7 +133,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     if (Subtarget.is64Bit())
       addRegisterClass(MVT::f64, &RISCV::GPRRegClass);
     else
-      addRegisterClass(MVT::f64, &RISCV::GPRPairRegClass);
+      addRegisterClass(MVT::f64, &RISCV::GPRF64PairRegClass);
   }
 
   static const MVT::SimpleValueType BoolVecVTs[] = {
@@ -2225,6 +2225,17 @@ MVT RISCVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
   return PartVT;
 }
 
+unsigned
+RISCVTargetLowering::getNumRegisters(LLVMContext &Context, EVT VT,
+                                     std::optional<MVT> RegisterVT) const {
+  // An i<2*xlen> value whose register type is untyped needs one (pair) reg.
+  if (VT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) && RegisterVT &&
+      *RegisterVT == MVT::Untyped)
+    return 1;
+
+  return TargetLowering::getNumRegisters(Context, VT, RegisterVT);
+}
+
 unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
                                                            CallingConv::ID CC,
                                                            EVT VT) const {
@@ -6422,7 +6433,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0);
       return DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, NewOp0);
     }
-    if (VT == MVT::f64 && Op0VT == MVT::i64 && XLenVT == MVT::i32) {
+    if (VT == MVT::f64 && Op0VT == MVT::i64 && !Subtarget.is64Bit() &&
+        Subtarget.hasStdExtDOrZdinx()) {
       SDValue Lo, Hi;
       std::tie(Lo, Hi) = DAG.SplitScalar(Op0, DL, MVT::i32, MVT::i32);
       return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi);
@@ -12940,7 +12952,8 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
       SDValue FPConv =
           DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Op0);
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPConv));
-    } else if (VT == MVT::i64 && Op0VT == MVT::f64 && XLenVT == MVT::i32) {
+    } else if (VT == MVT::i64 && Op0VT == MVT::f64 && !Subtarget.is64Bit() &&
+               Subtarget.hasStdExtDOrZdinx()) {
       SDValue NewReg = DAG.getNode(RISCVISD::SplitF64, DL,
                                    DAG.getVTList(MVT::i32, MVT::i32), Op0);
       SDValue RetReg = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
@@ -20185,6 +20198,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(TAIL)
   NODE_NAME_CASE(SELECT_CC)
   NODE_NAME_CASE(BR_CC)
+  NODE_NAME_CASE(BuildGPRPair)
+  NODE_NAME_CASE(SplitGPRPair)
   NODE_NAME_CASE(BuildPairF64)
   NODE_NAME_CASE(SplitF64)
   NODE_NAME_CASE(ADD_LO)
@@ -20461,6 +20476,8 @@ RISCVTargetLowering::getConstraintType(StringRef Constraint) const {
       return C_RegisterClass;
     if (Constraint == "cr" || Constraint == "cf")
       return C_RegisterClass;
+    if (Constraint == "Pr")
+      return C_RegisterClass;
   }
   return TargetLowering::getConstraintType(Constraint);
 }
@@ -20482,7 +20499,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
         return std::make_pair(0U, &RISCV::GPRF32NoX0RegClass);
       if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
-        return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
+        return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
       return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
     case 'f':
       if (VT == MVT::f16) {
@@ -20499,7 +20516,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         if (Subtarget.hasStdExtD())
           return std::make_pair(0U, &RISCV::FPR64RegClass);
         if (Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
-          return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
+          return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
         if (Subtarget.hasStdExtZdinx() && Subtarget.is64Bit())
           return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
       }
@@ -20541,7 +20558,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
     if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
       return std::make_pair(0U, &RISCV::GPRF32CRegClass);
     if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
-      return std::make_pair(0U, &RISCV::GPRPairCRegClass);
+      return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
     if (!VT.isVector())
       return std::make_pair(0U, &RISCV::GPRCRegClass);
   } else if (Constraint == "cf") {
@@ -20559,10 +20576,14 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       if (Subtarget.hasStdExtD())
         return std::make_pair(0U, &RISCV::FPR64CRegClass);
       if (Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
-        return std::make_pair(0U, &RISCV::GPRPairCRegClass);
+        return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
       if (Subtarget.hasStdExtZdinx() && Subtarget.is64Bit())
         return std::make_pair(0U, &RISCV::GPRCRegClass);
     }
+  } else if (Constraint == "Pr") {
+    if (VT == MVT::f64 && !Subtarget.is64Bit() && Subtarget.hasStdExtZdinx())
+      return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
+    return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
   }
 
   // Clang will correctly decode the usage of register name aliases into their
@@ -20723,7 +20744,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
   // Subtarget into account.
   if (Res.second == &RISCV::GPRF16RegClass ||
       Res.second == &RISCV::GPRF32RegClass ||
-      Res.second == &RISCV::GPRPairRegClass)
+      Res.second == &RISCV::GPRF64PairRegClass)
     return std::make_pair(Res.first, &RISCV::GPRRegClass);
 
   return Res;
@@ -21349,6 +21370,16 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
     unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.has_value();
   EVT ValueVT = Val.getValueType();
+
+  if (ValueVT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
+      NumParts == 1 && PartVT == MVT::Untyped) {
+    // Pairs in Inline Assembly
+    MVT XLenVT = Subtarget.getXLenVT();
+    auto [Lo, Hi] = DAG.SplitScalar(Val, DL, XLenVT, XLenVT);
+    Parts[0] = DAG.getNode(RISCVISD::BuildGPRPair, DL, MVT::Untyped, Lo, Hi);
+    return true;
+  }
+
   if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
       PartVT == MVT::f32) {
     // Cast the [b]f16 to i16, extend to i32, pad with ones to make a float
@@ -21420,6 +21451,17 @@ SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
     SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
     MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.has_value();
+
+  if (ValueVT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
+      NumParts == 1 && PartVT == MVT::Untyped) {
+    // Pairs in Inline Assembly
+    MVT XLenVT = Subtarget.getXLenVT();
+    SDValue Res = DAG.getNode(RISCVISD::SplitGPRPair, DL,
+                              DAG.getVTList(XLenVT, XLenVT), Parts[0]);
+    return DAG.getNode(ISD::BUILD_PAIR, DL, ValueVT, Res.getValue(0),
+                       Res.getValue(1));
+  }
+
   if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
       PartVT == MVT::f32) {
     SDValue Val = Parts[0];
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 9ae70d257fa442..773729d69a143f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -44,6 +44,18 @@ enum NodeType : unsigned {
   SELECT_CC,
   BR_CC,
 
+  /// Turn a pair of `i<xlen>`s into an even-odd register pair (`untyped`).
+  /// - Output: `untyped` even-odd register pair
+  /// - Input 0: `i<xlen>` low-order bits, for even register.
+  /// - Input 1: `i<xlen>` high-order bits, for odd register.
+  BuildGPRPair,
+
+  /// Turn an even-odd register pair (`untyped`) into a pair of `i<xlen>`s.
+  /// - Output 0: `i<xlen>` low-order bits, from even register.
+  /// - Output 1: `i<xlen>` high-order bits, from odd register.
+  /// - Input: `untyped` even-odd register pair
+  SplitGPRPair,
+
   /// Turns a pair of `i32`s into an `f64`. Needed for rv32d/ilp32.
   /// - Output: `f64`.
   /// - Input 0: low-order bits (31-0) (as `i32`), for even register.
@@ -547,6 +559,11 @@ class RISCVTargetLowering : public TargetLowering {
   MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
                                     EVT VT) const override;
 
+  /// Number of registers needed for VT; an inline-asm GPR pair counts as one.
+  unsigned
+  getNumRegisters(LLVMContext &Context, EVT VT,
+                  std::optional<MVT> RegisterVT = std::nullopt) const override;
+
   /// Return the number of registers for a given MVT, ensuring vectors are
   /// treated as a series of gpr sized integers.
   unsigned getNumRegistersForCallingConv(LLVMContext &Context,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
index b5e1298c7770a8..961c22e502b522 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
@@ -36,7 +36,7 @@ def AddrRegImmINX : ComplexPattern<iPTR, 2, "SelectAddrRegImmRV32Zdinx">;
 def GPRPairAsFPR : AsmOperandClass {
   let Name = "GPRPairAsFPR";
   let ParserMethod = "parseGPRPairAsFPR64";
-  let PredicateMethod = "isGPRPairAsFPR";
+  let PredicateMethod = "isGPRPairAsFPR64";
   let RenderMethod = "addRegOperands";
 }
 
@@ -52,7 +52,7 @@ def FPR64INX : RegisterOperand<GPR> {
   let DecoderMethod = "DecodeGPRRegisterClass";
 }
 
-def FPR64IN32X : RegisterOperand<GPRPair> {
+def FPR64IN32X : RegisterOperand<GPRF64Pair> {
   let ParserMatchClass = GPRPairAsFPR;
 }
 
@@ -523,15 +523,15 @@ def PseudoFROUND_D_IN32X : PseudoFROUND<FPR64IN32X, f64>;
 
 /// Loads
 let isCall = 0, mayLoad = 1, mayStore = 0, Size = 8, isCodeGenOnly = 1 in
-def PseudoRV32ZdinxLD : Pseudo<(outs GPRPair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
+def PseudoRV32ZdinxLD : Pseudo<(outs GPRF64Pair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
 def : Pat<(f64 (load (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12))),
           (PseudoRV32ZdinxLD GPR:$rs1, simm12:$imm12)>;
 
 /// Stores
 let isCall = 0, mayLoad = 0, mayStore = 1, Size = 8, isCodeGenOnly = 1 in
-def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
-def : Pat<(store (f64 GPRPair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
-          (PseudoRV32ZdinxSD GPRPair:$rs2, GPR:$rs1, simm12:$imm12)>;
+def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRF64Pair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
+def : Pat<(store (f64 GPRF64Pair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
+          (PseudoRV32ZdinxSD GPRF64Pair:$rs2, GPR:$rs1, simm12:$imm12)>;
 } // Predicates = [HasStdExtZdinx, IsRV32]
 
 let Predicates = [HasStdExtD, IsRV32] in {
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
index 8a722baae89c90..c37f10e1b2c6f8 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -208,6 +208,8 @@ let RegAltNameIndices = [ABIRegAltName] in {
 
 def XLenVT : ValueTypeByHwMode<[RV32, RV64],
                                [i32,  i64]>;
+defvar XLenPairVT = untyped;
+
 // Allow f64 in GPR for ZDINX on RV64.
 def XLenFVT : ValueTypeByHwMode<[RV64],
                                 [f64]>;
@@ -323,7 +325,7 @@ let RegAltNameIndices = [ABIRegAltName] in {
 
 let RegInfos = XLenPairRI,
     DecoderMethod = "DecodeGPRPairRegisterClass" in {
-def GPRPair : RISCVRegisterClass<[XLenPairFVT], 64, (add
+def GPRPair : RISCVRegisterClass<[XLenPairVT], 64, (add
     X10_X11, X12_X13, X14_X15, X16_X17,
     X6_X7,
     X28_X29, X30_X31,
@@ -332,11 +334,11 @@ def GPRPair : RISCVRegisterClass<[XLenPairFVT], 64, (add
     X0_Pair, X2_X3, X4_X5
 )>;
 
-def GPRPairNoX0 : RISCVRegisterClass<[XLenPairFVT], 64, (sub GPRPair, X0_Pair)>;
+def GPRPairNoX0 : RISCVRegisterClass<[XLenPairVT], 64, (sub GPRPair, X0_Pair)>;
 } // let RegInfos = XLenPairRI, DecoderMethod = "DecodeGPRPairRegisterClass"
 
 let RegInfos = XLenPairRI in
-def GPRPairC : RISCVRegisterClass<[XLenPairFVT], 64, (add
+def GPRPairC : RISCVRegisterClass<[XLenPairVT], 64, (add
   X10_X11, X12_X13, X14_X15, X8_X9
 )>;
 
@@ -462,6 +464,21 @@ def GPRF32C : RISCVRegisterClass<[f32], 32, (add (sequence "X%u_W", 10, 15),
                                                  (sequence "X%u_W", 8, 9))>;
 def GPRF32NoX0 : RISCVRegisterClass<[f32], 32, (sub GPRF32, X0_W)>;
 
+let DecoderMethod = "DecodeGPRPairRegisterClass" in
+def GPRF64Pair : RISCVRegisterClass<[XLenPairFVT], 64, (add
+    X10_X11, X12_X13, X14_X15, X16_X17,
+    X6_X7,
+    X28_X29, X30_X31,
+    X8_X9,
+    X18_X19, X20_X21, X22_X23, X24_X25, X26_X27,
+    X0_Pair, X2_X3, X4_X5
+)>;
+
+def GPRF64PairC : RISCVRegisterClass<[XLenPairFVT], 64, (add
+  X10_X11, X12_X13, X14_X15, X8_X9
+)>;
+
+def GPRF64PairNoX0 : RISCVRegisterClass<[XLenPairFVT], 64, (sub GPRF64Pair, X0_Pair)>;
 
 //===----------------------------------------------------------------------===//
 // Vector type mapping to LLVM types.
diff --git a/llvm/test/CodeGen/RISCV/rv32-inline-asm-pairs.ll b/llvm/test/CodeGen/RISCV/rv32-inline-asm-pairs.ll
new file mode 100644
index 00000000000000..a7f121c67e4abd
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rv32-inline-asm-pairs.ll
@@ -0,0 +1,73 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck %s
+
+define i64 @test_Pr_wide_scalar_simple(i64 noundef %0) nounwind {
+; CHECK-LABEL: test_Pr_wide_scalar_simple:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    # a2 <- a0
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    mv a0, a2
+; CHECK-NEXT:    mv a1, a3
+; CHECK-NEXT:    ret
+entry:
+  %1 = call i64 asm sideeffect "/* $0 <- $1 */", "=&^Pr,^Pr"(i64 %0)
+  ret i64 %1
+}
+
+define i32 @test_Pr_wide_scalar_with_ops(i32 noundef %0) nounwind {
+; CHECK-LABEL: test_Pr_wide_scalar_with_ops:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    mv a1, a0
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    # a2 <- a0
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    or a0, a2, a3
+; CHECK-NEXT:    ret
+entry:
+  %1 = zext i32 %0 to i64
+  %2 = shl i64 %1, 32
+  %3 = or i64 %1, %2
+  %4 = call i64 asm sideeffect "/* $0 <- $1 */", "=&^Pr,^Pr"(i64 %3)
+  %5 = trunc i64 %4 to i32
+  %6 = lshr i64 %4, 32
+  %7 = trunc i64 %6 to i32
+  %8 = or i32 %5, %7
+  ret i32 %8
+}
+
+define i64 @test_Pr_wide_scalar_inout(ptr %0, i64 noundef %1) nounwind {
+; CHECK-LABEL: test_Pr_wide_scalar_inout:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    addi sp, sp, -16
+; CHECK-NEXT:    mv a3, a2
+; CHECK-NEXT:    sw a0, 12(sp)
+; CHECK-NEXT:    mv a2, a1
+; CHECK-NEXT:    sw a1, 0(sp)
+; CHECK-NEXT:    sw a3, 4(sp)
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    # a0; a2
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    sw a0, 12(sp)
+; CHECK-NEXT:    sw a2, 0(sp)
+; CHECK-NEXT:    sw a3, 4(sp)
+; CHECK-NEXT:    mv a0, a2
+; CHECK-NEXT:    mv a1, a3
+; CHECK-NEXT:    addi sp, sp, 16
+; CHECK-NEXT:    ret
+entry:
+  %2 = alloca ptr, align 4
+  %3 = alloca i64, align 8
+  store ptr %0, ptr %2, align 4
+  store i64 %1, ptr %3, align 8
+  %4 = load ptr, ptr %2, align 4
+  %5 = load i64, ptr %3, align 8
+  %6 = call { ptr, i64 } asm sideeffect "/* $0; $1 */", "=r,=^Pr,0,1"(ptr %4, i64 %5)
+  %7 = extractvalue { ptr, i64} %6, 0
+  %8 = extractvalue { ptr, i64 } %6, 1
+  store ptr %7, ptr %2, align 4
+  store i64 %8, ptr %3, align 8
+  %9 = load i64, ptr %3, align 8
+  ret i64 %9
+}
diff --git a/llvm/test/CodeGen/RISCV/rv64-inline-asm-pairs.ll b/llvm/test/CodeGen/RISCV/rv64-inline-asm-pairs.ll
new file mode 100644
index 00000000000000..d8b4b2e21c4f84
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rv64-inline-asm-pairs.ll
@@ -0,0 +1,73 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck %s
+
+define i128 @test_Pr_wide_scalar_simple(i128 noundef %0) nounwind {
+; CHECK-LABEL: test_Pr_wide_scalar_simple:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    # a2 <- a0
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    mv a0, a2
+; CHECK-NEXT:    mv a1, a3
+; CHECK-NEXT:    ret
+entry:
+  %1 = call i128 asm sideeffect "/* $0 <- $1 */", "=&^Pr,^Pr"(i128 %0)
+  ret i128 %1
+}
+
+define i64 @test_Pr_wide_scalar_with_ops(i64 noundef %0) nounwind {
+; CHECK-LABEL: test_Pr_wide_scalar_with_ops:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    mv a1, a0
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    # a2 <- a0
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    or a0, a2, a3
+; CHECK-NEXT:    ret
+entry:
+  %1 = zext i64 %0 to i128
+  %2 = shl i128 %1, 64
+  %3 = or i128 %1, %2
+  %4 = call i128 asm sideeffect "/* $0 <- $1 */", "=&^Pr,^Pr"(i128 %3)
+  %5 = trunc i128 %4 to i64
+  %6 = lshr i128 %4, 64
+  %7 = trunc i128 %6 to i64
+  %8 = or i64 %5, %7
+  ret i64 %8
+}
+
+define i128 @test_Pr_wide_scalar_inout(ptr %0, i128 noundef %1) nounwind {
+; CHECK-LABEL: test_Pr_wide_scalar_inout:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    addi sp, sp, -32
+; CHECK-NEXT:    mv a3, a2
+; CHECK-NEXT:    sd a0, 24(sp)
+; CHECK-NEXT:    mv a2, a1
+; CHECK-NEXT:    sd a1, 0(sp)
+; CHECK-NEXT:    sd a3, 8(sp)
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    # a0; a2
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    sd a0, 24(sp)
+; CHECK-NEXT:    sd a2, 0(sp)
+; CHECK-NEXT:    sd a3, 8(sp)
+; CHECK-NEXT:    mv a0, a2
+; CHECK-NEXT:    mv a1, a3
+; CHECK-NEXT:    addi sp, sp, 32
+; CHECK-NEXT:    ret
+entry:
+  %2 = alloca ptr, align 8
+  %3 = alloca i128, align 16
+  store ptr %0, ptr %2, align 8
+  store i128 %1, ptr %3, align 16
+  %4 = load ptr, ptr %2, align 8
+  %5 = load i128, ptr %3, align 16
+  %6 = call { ptr, i128 } asm sideeffect "/* $0; $1 */", "=r,=^Pr,0,1"(ptr %4, i128 %5)
+  %7 = extractvalue { ptr, i128} %6, 0
+  %8 = extractvalue { ptr, i128 } %6, 1
+  store ptr %7, ptr %2, align 8
+  store i128 %8, ptr %3, align 16
+  %9 = load i128, ptr %3, align 16
+  ret i128 %9
+}
diff --git a/llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll b/llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll
index 18bd41a210f53f..9f8acb6370c6f5 100644
--- a/llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll
+++ b/llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll
@@ -26,6 +26,32 @@ entry:
   ret void
 }
 
+define dso_local void @zdinx_asm_Pr(ptr nocapture noundef writeonly %a, double noundef %b, double noundef %c) nounwind {
+; CHECK-LABEL: zdinx_asm_Pr:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    addi sp, sp, -16
+; CHECK-NEXT:    sw s0, 12(sp) # 4-byte Folded Spill
+; CHECK-NEXT:    sw s1, 8(sp) # 4-byte Folded Spill
+; CHECK-NEXT:    mv a5, a4
+; CHECK-NEXT:    mv s1, a2
+; CHECK-NEXT:    mv a4, a3
+; CHECK-NEXT:    mv s0, a1
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    fsgnjx.d a2, s0, a4
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    sw a2, 8(a0)
+; CHECK-NEXT:    sw a3, 12(a0)
+; CHECK-NEXT:    lw s0, 12(sp) # 4-byte Folded Reload
+; CHECK-NEXT:    lw s1, 8(sp) # 4-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 16
+; CHECK-NEXT:    ret
+entry:
+  %arrayidx = getelementptr inbounds double, ptr %a, i32 1
+  %0 = tail call double asm "fsgnjx.d $0, $1, $2", "=^Pr,^Pr,^Pr"(double %b, double %c)
+  store double %0, ptr %arrayidx, align 8
+  ret void
+}
+
 define dso_local void @zfinx_asm(ptr nocapture noundef writeonly %a, float noundef %b, float noundef %c) nounwind {
 ; CHECK-LABEL: zfinx_asm:
 ; CHECK:       # %bb.0: # %entry



More information about the cfe-commits mailing list