[llvm] [RISCV] Add experimental support for making i32 a legal type on RV64 in SelectionDAG. (PR #70357)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 26 10:31:23 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

This will select i32 operations directly to W instructions without custom nodes. Hopefully this can allow us to be less dependent on hasAllNBitUsers to recover i32 operations in RISCVISelDAGToDAG.cpp.

This support is enabled with a command line option that is off by default.

The generated code is still far from optimal.

I've duplicated many test cases for this, but the coverage is not complete. I believe that all existing lit tests run without crashing when this option is enabled.

---

Patch is 378.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70357.diff


32 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+18-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+1-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+1-1) 
- (modified) llvm/lib/Target/RISCV/RISCVGISel.td (+1-47) 
- (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+9-3) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+248-84) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+6) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.td (+53-6) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoA.td (+58) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoD.td (+2-2) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoF.td (+8-8) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoM.td (+15) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td (+12-2) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZb.td (+26) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td (+1-1) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td (+8-8) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/alu32.ll (+240) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/div.ll (+699) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/imm.ll (+2564) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/mem.ll (+92) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/mem64.ll (+341) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rem.ll (+390) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64xtheadbb.ll (+902) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll (+1798) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zbb-intrinsic.ll (+77) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zbb-zbkb.ll (+600) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zbb.ll (+1068) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zbc-intrinsic.ll (+42) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zbc-zbkc-intrinsic.ll (+67) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zbkb-intrinsic.ll (+73) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zbs.ll (+1000) 
- (added) llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll (+1308) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index f19beea3a3ed8b7..82751a442dbc3bc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -5023,6 +5023,10 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
   case ISD::SREM:
   case ISD::UDIV:
   case ISD::UREM:
+  case ISD::SMIN:
+  case ISD::SMAX:
+  case ISD::UMIN:
+  case ISD::UMAX:
   case ISD::AND:
   case ISD::OR:
   case ISD::XOR: {
@@ -5039,12 +5043,21 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
         break;
       case ISD::SDIV:
       case ISD::SREM:
+      case ISD::SMIN:
+      case ISD::SMAX:
         ExtOp = ISD::SIGN_EXTEND;
         break;
       case ISD::UDIV:
       case ISD::UREM:
         ExtOp = ISD::ZERO_EXTEND;
         break;
+      case ISD::UMIN:
+      case ISD::UMAX:
+        if (TLI.isSExtCheaperThanZExt(OVT, NVT))
+          ExtOp = ISD::SIGN_EXTEND;
+        else
+          ExtOp = ISD::ZERO_EXTEND;
+        break;
       }
       TruncOp = ISD::TRUNCATE;
     }
@@ -5166,7 +5179,11 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     unsigned ExtOp = ISD::FP_EXTEND;
     if (NVT.isInteger()) {
       ISD::CondCode CCCode = cast<CondCodeSDNode>(Node->getOperand(2))->get();
-      ExtOp = isSignedIntSetCC(CCCode) ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+      if (isSignedIntSetCC(CCCode) ||
+          TLI.isSExtCheaperThanZExt(Node->getOperand(0).getValueType(), NVT))
+        ExtOp = ISD::SIGN_EXTEND;
+      else
+        ExtOp = ISD::ZERO_EXTEND;
     }
     if (Node->isStrictFPOpcode()) {
       SDValue InChain = Node->getOperand(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index fc9e3ff3734989d..4364606a403001d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -371,7 +371,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_AtomicCmpSwap(AtomicSDNode *N,
         N->getMemOperand());
     ReplaceValueWith(SDValue(N, 0), Res.getValue(0));
     ReplaceValueWith(SDValue(N, 2), Res.getValue(2));
-    return Res.getValue(1);
+    return DAG.getSExtOrTrunc(Res.getValue(1), SDLoc(N), NVT);
   }
 
   // Op2 is used for the comparison and thus must be extended according to the
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 71f6a3791c2cee0..e4f0e27e577befd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3468,7 +3468,7 @@ void SelectionDAGBuilder::visitSelect(const User &I) {
     }
 
     if (!IsUnaryAbs && Opc != ISD::DELETED_NODE &&
-        (TLI.isOperationLegalOrCustom(Opc, VT) ||
+        (TLI.isOperationLegalOrCustomOrPromote(Opc, VT) ||
          (UseScalarMinMax &&
           TLI.isOperationLegalOrCustom(Opc, VT.getScalarType()))) &&
         // If the underlying comparison instruction is used by any other
diff --git a/llvm/lib/Target/RISCV/RISCVGISel.td b/llvm/lib/Target/RISCV/RISCVGISel.td
index 60896106bc0b5bb..887671ecb435d2e 100644
--- a/llvm/lib/Target/RISCV/RISCVGISel.td
+++ b/llvm/lib/Target/RISCV/RISCVGISel.td
@@ -21,8 +21,6 @@ def simm12Plus1 : ImmLeaf<XLenVT, [{
 def simm12Plus1i32 : ImmLeaf<i32, [{
     return (isInt<12>(Imm) && Imm != -2048) || Imm == 2048;}]>;
 
-def simm12i32 : ImmLeaf<i32, [{return isInt<12>(Imm);}]>;
-
 def uimm5i32 : ImmLeaf<i32, [{return isUInt<5>(Imm);}]>;
 
 // FIXME: This doesn't check that the G_CONSTANT we're deriving the immediate
@@ -49,11 +47,6 @@ def GIAddrRegImm :
   GIComplexOperandMatcher<s32, "selectAddrRegImm">,
   GIComplexPatternEquiv<AddrRegImm>;
 
-// Convert from i32 immediate to i64 target immediate to make SelectionDAG type
-// checking happy so we can use ADDIW which expects an XLen immediate.
-def as_i64imm : SDNodeXForm<imm, [{
-  return CurDAG->getTargetConstant(N->getSExtValue(), SDLoc(N), MVT::i64);
-}]>;
 def gi_as_i64imm : GICustomOperandRenderer<"renderImm">,
   GISDNodeXFormEquiv<as_i64imm>;
 
@@ -88,11 +81,6 @@ def : Pat<(XLenVT (sub GPR:$rs1, simm12Plus1:$imm)),
           (ADDI GPR:$rs1, (NegImm simm12Plus1:$imm))>;
 
 let Predicates = [IsRV64] in {
-def : Pat<(i32 (add GPR:$rs1, GPR:$rs2)), (ADDW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i32 (sub GPR:$rs1, GPR:$rs2)), (SUBW GPR:$rs1, GPR:$rs2)>;
-
-def : Pat<(i32 (add GPR:$rs1, simm12i32:$imm)),
-          (ADDIW GPR:$rs1, (i64 (as_i64imm $imm)))>;
 def : Pat<(i32 (sub GPR:$rs1, simm12Plus1i32:$imm)),
           (ADDIW GPR:$rs1, (i64 (NegImm $imm)))>;
 
@@ -106,19 +94,6 @@ def : Pat<(i32 (sra GPR:$rs1, uimm5i32:$imm)),
           (SRAIW GPR:$rs1, (i64 (as_i64imm $imm)))>;
 def : Pat<(i32 (srl GPR:$rs1, uimm5i32:$imm)),
           (SRLIW GPR:$rs1, (i64 (as_i64imm $imm)))>;
-
-def : Pat<(i64 (sext i32:$rs)), (ADDIW GPR:$rs, 0)>;
-}
-
-let Predicates = [HasStdExtMOrZmmul, IsRV64] in {
-def : Pat<(i32 (mul GPR:$rs1, GPR:$rs2)), (MULW GPR:$rs1, GPR:$rs2)>;
-}
-
-let Predicates = [HasStdExtM, IsRV64] in {
-def : Pat<(i32 (sdiv GPR:$rs1, GPR:$rs2)), (DIVW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i32 (srem GPR:$rs1, GPR:$rs2)), (REMW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i32 (udiv GPR:$rs1, GPR:$rs2)), (DIVUW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i32 (urem GPR:$rs1, GPR:$rs2)), (REMUW GPR:$rs1, GPR:$rs2)>;
 }
 
 let Predicates = [HasStdExtZba, IsRV64] in {
@@ -126,13 +101,8 @@ let Predicates = [HasStdExtZba, IsRV64] in {
 // in SDISel for RV64, which is not the case in GISel.
 def : Pat<(shl (i64 (zext i32:$rs1)), uimm5:$shamt),
           (SLLI_UW GPR:$rs1, uimm5:$shamt)>;
-
-def : Pat<(i64 (zext i32:$rs)), (ADD_UW GPR:$rs, (XLenVT X0))>;
 } // Predicates = [HasStdExtZba, IsRV64]
 
-let Predicates = [IsRV64, NotHasStdExtZba] in
-def: Pat<(i64 (zext i32:$rs)), (SRLI (SLLI GPR:$rs, 32), 32)>;
-
 // Ptr type used in patterns with GlobalISelEmitter
 def PtrVT : PtrValueTypeByHwMode<XLenVT, 0>;
 
@@ -186,8 +156,6 @@ def : Pat<(XLenVT (setle (Ty GPR:$rs1), (Ty GPR:$rs2))),
           (XORI (SLT GPR:$rs2, GPR:$rs1), 1)>;
 }
 
-// Define pattern expansions for load/extload and store/truncstore operations
-// for ptr return type
 let Predicates = [IsRV32] in {
 def : LdPat<load, LW, PtrVT>;
 def : StPat<store, SW, GPR, PtrVT>;
@@ -196,18 +164,4 @@ def : StPat<store, SW, GPR, PtrVT>;
 let Predicates = [IsRV64] in {
 def : LdPat<load, LD, PtrVT>;
 def : StPat<store, SD, GPR, PtrVT>;
-
-// Define pattern expansions for rv64 load/extloads and store/truncstore
-// operations for i32 return type
-def : LdPat<sextloadi8, LB, i32>;
-def : LdPat<extloadi8, LBU, i32>;
-def : LdPat<zextloadi8, LBU, i32>;
-def : LdPat<sextloadi16, LH, i32>;
-def : LdPat<extloadi16, LH, i32>;
-def : LdPat<zextloadi16, LHU, i32>;
-def : LdPat<load, LW, i32>;
-
-def : StPat<truncstorei8, SB, GPR, i32>;
-def : StPat<truncstorei16, SH, GPR, i32>;
-def : StPat<store, SW, GPR, i32>;
-} // Predicates = [IsRV64]
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 6c156057ccd7d0e..588224a221395a5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -67,8 +67,11 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
           VT.isInteger() ? RISCVISD::VMV_V_X_VL : RISCVISD::VFMV_V_F_VL;
       SDLoc DL(N);
       SDValue VL = CurDAG->getRegister(RISCV::X0, Subtarget->getXLenVT());
-      Result = CurDAG->getNode(Opc, DL, VT, CurDAG->getUNDEF(VT),
-                               N->getOperand(0), VL);
+      SDValue Src = N->getOperand(0);
+      if (VT.isInteger())
+        Src = CurDAG->getNode(ISD::ANY_EXTEND, DL, Subtarget->getXLenVT(),
+                              N->getOperand(0));
+      Result = CurDAG->getNode(Opc, DL, VT, CurDAG->getUNDEF(VT), Src, VL);
       break;
     }
     case RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL: {
@@ -833,7 +836,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
 
   switch (Opcode) {
   case ISD::Constant: {
-    assert(VT == Subtarget->getXLenVT() && "Unexpected VT");
+    assert((VT == Subtarget->getXLenVT() || VT == MVT::i32) && "Unexpected VT");
     auto *ConstNode = cast<ConstantSDNode>(Node);
     if (ConstNode->isZero()) {
       SDValue New =
@@ -3288,6 +3291,9 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
   case RISCV::TH_MULAH:
   case RISCV::TH_MULSW:
   case RISCV::TH_MULSH:
+    if (N0.getValueType() == MVT::i32)
+      break;
+
     // Result is already sign extended just remove the sext.w.
     // NOTE: We only handle the nodes that are selected with hasAllWUsers.
     ReplaceUses(N, N0.getNode());
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index beb371063f89b2d..364bfb6fc77947a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -75,6 +75,10 @@ static cl::opt<int>
                        "use for creating a floating-point immediate value"),
               cl::init(2));
 
+static cl::opt<bool>
+    RV64LegalI32("riscv-experimental-rv64-legal-i32", cl::ReallyHidden,
+                 cl::desc("Make i32 a legal type for SelectionDAG on RV64."));
+
 RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                                          const RISCVSubtarget &STI)
     : TargetLowering(TM), Subtarget(STI) {
@@ -115,6 +119,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   // Set up the register classes.
   addRegisterClass(XLenVT, &RISCV::GPRRegClass);
+  if (Subtarget.is64Bit() && RV64LegalI32)
+    addRegisterClass(MVT::i32, &RISCV::GPRRegClass);
 
   if (Subtarget.hasStdExtZfhOrZfhmin())
     addRegisterClass(MVT::f16, &RISCV::FPR16RegClass);
@@ -237,8 +243,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
   setOperationAction(ISD::BR_CC, XLenVT, Expand);
+  if (RV64LegalI32 && Subtarget.is64Bit())
+    setOperationAction(ISD::BR_CC, MVT::i32, Expand);
   setOperationAction(ISD::BRCOND, MVT::Other, Custom);
   setOperationAction(ISD::SELECT_CC, XLenVT, Expand);
+  if (RV64LegalI32 && Subtarget.is64Bit())
+    setOperationAction(ISD::SELECT_CC, MVT::i32, Expand);
 
   setCondCodeAction(ISD::SETLE, XLenVT, Expand);
   setCondCodeAction(ISD::SETGT, XLenVT, Custom);
@@ -247,6 +257,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setCondCodeAction(ISD::SETUGT, XLenVT, Custom);
   setCondCodeAction(ISD::SETUGE, XLenVT, Expand);
 
+  if (RV64LegalI32 && Subtarget.is64Bit())
+    setOperationAction(ISD::SETCC, MVT::i32, Promote);
+
   setOperationAction({ISD::STACKSAVE, ISD::STACKRESTORE}, MVT::Other, Expand);
 
   setOperationAction(ISD::VASTART, MVT::Other, Custom);
@@ -262,14 +275,20 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.is64Bit()) {
     setOperationAction(ISD::EH_DWARF_CFA, MVT::i64, Custom);
 
-    setOperationAction(ISD::LOAD, MVT::i32, Custom);
+    if (!RV64LegalI32)
+      setOperationAction(ISD::LOAD, MVT::i32, Custom);
 
-    setOperationAction({ISD::ADD, ISD::SUB, ISD::SHL, ISD::SRA, ISD::SRL},
-                       MVT::i32, Custom);
+    if (RV64LegalI32)
+      setOperationAction({ISD::AND, ISD::OR, ISD::XOR}, MVT::i32, Promote);
+    else
+      setOperationAction({ISD::ADD, ISD::SUB, ISD::SHL, ISD::SRA, ISD::SRL},
+                         MVT::i32, Custom);
 
-    setOperationAction(ISD::SADDO, MVT::i32, Custom);
-    setOperationAction({ISD::UADDO, ISD::USUBO, ISD::UADDSAT, ISD::USUBSAT},
-                       MVT::i32, Custom);
+    if (!RV64LegalI32) {
+      setOperationAction(ISD::SADDO, MVT::i32, Custom);
+      setOperationAction({ISD::UADDO, ISD::USUBO, ISD::UADDSAT, ISD::USUBSAT},
+                         MVT::i32, Custom);
+    }
   } else {
     setLibcallName(
         {RTLIB::SHL_I128, RTLIB::SRL_I128, RTLIB::SRA_I128, RTLIB::MUL_I128},
@@ -277,19 +296,36 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setLibcallName(RTLIB::MULO_I64, nullptr);
   }
 
-  if (!Subtarget.hasStdExtM() && !Subtarget.hasStdExtZmmul())
+  if (!Subtarget.hasStdExtM() && !Subtarget.hasStdExtZmmul()) {
     setOperationAction({ISD::MUL, ISD::MULHS, ISD::MULHU}, XLenVT, Expand);
-  else if (Subtarget.is64Bit())
-    setOperationAction(ISD::MUL, {MVT::i32, MVT::i128}, Custom);
-  else
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction(ISD::MUL, MVT::i32, Promote);
+  } else if (Subtarget.is64Bit()) {
+    setOperationAction(ISD::MUL, MVT::i128, Custom);
+    if (!RV64LegalI32)
+      setOperationAction(ISD::MUL, MVT::i32, Custom);
+  } else {
     setOperationAction(ISD::MUL, MVT::i64, Custom);
+  }
 
-  if (!Subtarget.hasStdExtM())
+  if (!Subtarget.hasStdExtM()) {
     setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM},
                        XLenVT, Expand);
-  else if (Subtarget.is64Bit())
-    setOperationAction({ISD::SDIV, ISD::UDIV, ISD::UREM},
-                       {MVT::i8, MVT::i16, MVT::i32}, Custom);
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM}, MVT::i32,
+                         Promote);
+  } else if (Subtarget.is64Bit()) {
+    if (!RV64LegalI32)
+      setOperationAction({ISD::SDIV, ISD::UDIV, ISD::UREM},
+                         {MVT::i8, MVT::i16, MVT::i32}, Custom);
+  }
+
+  if (RV64LegalI32 && Subtarget.is64Bit()) {
+    setOperationAction({ISD::MULHS, ISD::MULHU}, MVT::i32, Expand);
+    setOperationAction(
+        {ISD::SDIVREM, ISD::UDIVREM, ISD::SMUL_LOHI, ISD::UMUL_LOHI}, MVT::i32,
+        Expand);
+  }
 
   setOperationAction(
       {ISD::SDIVREM, ISD::UDIVREM, ISD::SMUL_LOHI, ISD::UMUL_LOHI}, XLenVT,
@@ -299,7 +335,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                      Custom);
 
   if (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()) {
-    if (Subtarget.is64Bit())
+    if (!RV64LegalI32 && Subtarget.is64Bit())
       setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Custom);
   } else if (Subtarget.hasVendorXTHeadBb()) {
     if (Subtarget.is64Bit())
@@ -307,6 +343,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction({ISD::ROTL, ISD::ROTR}, XLenVT, Custom);
   } else {
     setOperationAction({ISD::ROTL, ISD::ROTR}, XLenVT, Expand);
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Expand);
   }
 
   // With Zbb we have an XLen rev8 instruction, but not GREVI. So we'll
@@ -316,6 +354,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                       Subtarget.hasVendorXTHeadBb())
                          ? Legal
                          : Expand);
+  if (RV64LegalI32 && Subtarget.is64Bit())
+    setOperationAction(ISD::BSWAP, MVT::i32,
+                       (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb() ||
+                        Subtarget.hasVendorXTHeadBb())
+                           ? Promote
+                           : Expand);
+
   // Zbkb can use rev8+brev8 to implement bitreverse.
   setOperationAction(ISD::BITREVERSE, XLenVT,
                      Subtarget.hasStdExtZbkb() ? Custom : Expand);
@@ -323,30 +368,49 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.hasStdExtZbb()) {
     setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, XLenVT,
                        Legal);
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, MVT::i32,
+                         Promote);
 
-    if (Subtarget.is64Bit())
-      setOperationAction(
-          {ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF},
-          MVT::i32, Custom);
+    if (Subtarget.is64Bit()) {
+      if (RV64LegalI32)
+        setOperationAction(ISD::CTTZ, MVT::i32, Legal);
+      else
+        setOperationAction({ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF}, MVT::i32, Custom);
+    }
   } else {
     setOperationAction({ISD::CTTZ, ISD::CTPOP}, XLenVT, Expand);
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction({ISD::CTTZ, ISD::CTPOP}, MVT::i32, Expand);
   }
 
   if (Subtarget.hasStdExtZbb() || Subtarget.hasVendorXTHeadBb()) {
     // We need the custom lowering to make sure that the resulting sequence
     // for the 32bit case is efficient on 64bit targets.
-    if (Subtarget.is64Bit())
-      setOperationAction({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, MVT::i32, Custom);
+    if (Subtarget.is64Bit()) {
+      if (RV64LegalI32) {
+        setOperationAction(ISD::CTLZ, MVT::i32,
+                           Subtarget.hasStdExtZbb() ? Legal : Promote);
+        if (!Subtarget.hasStdExtZbb())
+          setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Promote);
+      } else
+        setOperationAction({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, MVT::i32, Custom);
+    }
   } else {
     setOperationAction(ISD::CTLZ, XLenVT, Expand);
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction(ISD::CTLZ, MVT::i32, Expand);
   }
 
-  if (Subtarget.is64Bit())
+  if (!RV64LegalI32 && Subtarget.is64Bit())
     setOperationAction(ISD::ABS, MVT::i32, Custom);
 
   if (!Subtarget.hasVendorXTHeadCondMov())
     setOperationAction(ISD::SELECT, XLenVT, Custom);
 
+  if (RV64LegalI32 && Subtarget.is64Bit())
+    setOperationAction(ISD::SELECT, MVT::i32, Promote);
+
   static const unsigned FPLegalNodeTypes[] = {
       ISD::FMINNUM,        ISD::FMAXNUM,       ISD::LRINT,
       ISD::LLRINT,         ISD::LROUND,        ISD::LLROUND,
@@ -525,6 +589,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                         ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
                        XLenVT, Legal);
 
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
+                          ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
+                         MVT::i32, Legal);
+
     setOperationAction(ISD::GET_ROUNDING, XLenVT, Custom);
     setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
   }
@@ -569,6 +638,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setBooleanVectorContents(ZeroOrOneBooleanContent);
 
     setOperationAction(ISD::VSCALE, XLenVT, Custom);
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction(ISD::VSCALE, MVT::i32, Custom);
 
     // RVV intrinsics may have illegal operands.
     // We also need to custom legalize vmv.x.s.
@@ -1247,8 +1318,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     }
   }
 
-  if (Subtarget.hasStdExtA())
+  if (Subtarget.hasStdExtA()) {
     setOperationAction(ISD::ATOMIC_LOAD_SUB, XLenVT, Expand);
+    if (RV64LegalI32 && Subtarget.is64Bit())
+      setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, Expand);
+  }
 
   if (Subtarget.hasForcedAtomics()) {
     // Force __sync libcalls to be emitted for atomic rmw/cas operations.
@@ -2079,7 +2153,12 @@ MVT RISCVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
       !Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
     return MVT::f32;
 
-  return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
+  MVT PartVT = TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
+
+  if (RV64LegalI32 && Subtarget.is64Bit() && PartVT == MVT::i32)
+    return MVT::i64;
+
+  return PartVT;
 }
 
 unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
@@ -2094,6 +2173,21 @@ unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context
   return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
 }
 
+unsigned RISCVTargetLowering::getVectorTypeBreakdownForCallingConv(
+    LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT,
+    unsigned &NumIntermediates, MVT &RegisterVT) const {
+  unsigned NumRegs = TargetLowering::getVectorTypeBreakdownForCallingConv(
+      Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
+
+  if (RV64LegalI32 && Subtarget.is64Bit() && IntermediateVT == MVT::i32)
+    IntermediateVT = MVT::i64;
+
+  if (RV64LegalI32 && Subtarget.is64Bit() && RegisterVT == MVT::i32)
+    RegisterVT = MVT::i64;
+
+...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/70357


More information about the llvm-commits mailing list