[llvm-branch-commits] [RISCV] Support llvm.readsteadycounter intrinsic (PR #82322)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Feb 20 00:56:49 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Wang Pengcheng (wangpc-pp)

<details>
<summary>Changes</summary>

This intrinsic was introduced by #<!-- -->81331 and is a lot like
`llvm.readcyclecounter`.

For the RISCV implementation, we rename `ReadCycleWide` pseudo to
`ReadCounterWide` and make it accept two operands (the low and high
parts of the counter). As for legalization and lowering parts, we
reuse the code of `ISD::READCYCLECOUNTER` (make it able to handle
both intrinsics).

Tests using Clang builtins were run on real hardware and work as
expected.


---
Full diff: https://github.com/llvm/llvm-project/pull/82322.diff


4 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+44-25) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+2-2) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.td (+20-13) 
- (added) llvm/test/CodeGen/RISCV/readsteadycounter.ll (+28) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9ab6895aed521e..32d47a669020f1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -629,6 +629,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   // Unfortunately this can't be determined just from the ISA naming string.
   setOperationAction(ISD::READCYCLECOUNTER, MVT::i64,
                      Subtarget.is64Bit() ? Legal : Custom);
+  setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64,
+                     Subtarget.is64Bit() ? Legal : Custom);
 
   setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
@@ -11725,13 +11727,27 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(Result);
     break;
   }
-  case ISD::READCYCLECOUNTER: {
-    assert(!Subtarget.is64Bit() &&
-           "READCYCLECOUNTER only has custom type legalization on riscv32");
+  case ISD::READCYCLECOUNTER:
+  case ISD::READSTEADYCOUNTER: {
+    assert(!Subtarget.is64Bit() && "READCYCLECOUNTER/READSTEADYCOUNTER only "
+                                   "has custom type legalization on riscv32");
 
+    SDValue LoCounter, HiCounter;
+    MVT XLenVT = Subtarget.getXLenVT();
+    if (N->getOpcode() == ISD::READCYCLECOUNTER) {
+      LoCounter = DAG.getConstant(
+          RISCVSysReg::lookupSysRegByName("CYCLE")->Encoding, DL, XLenVT);
+      HiCounter = DAG.getConstant(
+          RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding, DL, XLenVT);
+    } else if (N->getOpcode() == ISD::READSTEADYCOUNTER) {
+      LoCounter = DAG.getConstant(
+          RISCVSysReg::lookupSysRegByName("TIME")->Encoding, DL, XLenVT);
+      HiCounter = DAG.getConstant(
+          RISCVSysReg::lookupSysRegByName("TIMEH")->Encoding, DL, XLenVT);
+    }
     SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32, MVT::Other);
-    SDValue RCW =
-        DAG.getNode(RISCVISD::READ_CYCLE_WIDE, DL, VTs, N->getOperand(0));
+    SDValue RCW = DAG.getNode(RISCVISD::READ_COUNTER_WIDE, DL, VTs,
+                              N->getOperand(0), LoCounter, HiCounter);
 
     Results.push_back(
         DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, RCW, RCW.getValue(1)));
@@ -16903,29 +16919,30 @@ RISCVTargetLowering::getTargetConstantFromLoad(LoadSDNode *Ld) const {
   return CNodeLo->getConstVal();
 }
 
-static MachineBasicBlock *emitReadCycleWidePseudo(MachineInstr &MI,
-                                                  MachineBasicBlock *BB) {
-  assert(MI.getOpcode() == RISCV::ReadCycleWide && "Unexpected instruction");
+static MachineBasicBlock *emitReadCounterWidePseudo(MachineInstr &MI,
+                                                    MachineBasicBlock *BB) {
+  assert(MI.getOpcode() == RISCV::ReadCounterWide && "Unexpected instruction");
 
-  // To read the 64-bit cycle CSR on a 32-bit target, we read the two halves.
+  // To read a 64-bit counter CSR on a 32-bit target, we read the two halves.
   // Should the count have wrapped while it was being read, we need to try
   // again.
-  // ...
+  // For example:
+  // ```
   // read:
-  // rdcycleh x3 # load high word of cycle
-  // rdcycle  x2 # load low word of cycle
-  // rdcycleh x4 # load high word of cycle
-  // bne x3, x4, read # check if high word reads match, otherwise try again
-  // ...
+  //   csrrs x3, counterh # load high word of counter
+  //   csrrs x2, counter # load low word of counter
+  //   csrrs x4, counterh # load high word of counter
+  //   bne x3, x4, read # check if high word reads match, otherwise try again
+  // ```
 
   MachineFunction &MF = *BB->getParent();
-  const BasicBlock *LLVM_BB = BB->getBasicBlock();
+  const BasicBlock *LLVMBB = BB->getBasicBlock();
   MachineFunction::iterator It = ++BB->getIterator();
 
-  MachineBasicBlock *LoopMBB = MF.CreateMachineBasicBlock(LLVM_BB);
+  MachineBasicBlock *LoopMBB = MF.CreateMachineBasicBlock(LLVMBB);
   MF.insert(It, LoopMBB);
 
-  MachineBasicBlock *DoneMBB = MF.CreateMachineBasicBlock(LLVM_BB);
+  MachineBasicBlock *DoneMBB = MF.CreateMachineBasicBlock(LLVMBB);
   MF.insert(It, DoneMBB);
 
   // Transfer the remainder of BB and its successor edges to DoneMBB.
@@ -16939,17 +16956,19 @@ static MachineBasicBlock *emitReadCycleWidePseudo(MachineInstr &MI,
   Register ReadAgainReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass);
   Register LoReg = MI.getOperand(0).getReg();
   Register HiReg = MI.getOperand(1).getReg();
+  int64_t LoCounter = MI.getOperand(2).getImm();
+  int64_t HiCounter = MI.getOperand(3).getImm();
   DebugLoc DL = MI.getDebugLoc();
 
   const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
   BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), HiReg)
-      .addImm(RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding)
+      .addImm(HiCounter)
       .addReg(RISCV::X0);
   BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), LoReg)
-      .addImm(RISCVSysReg::lookupSysRegByName("CYCLE")->Encoding)
+      .addImm(LoCounter)
       .addReg(RISCV::X0);
   BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), ReadAgainReg)
-      .addImm(RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding)
+      .addImm(HiCounter)
       .addReg(RISCV::X0);
 
   BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
@@ -17528,10 +17547,10 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
   switch (MI.getOpcode()) {
   default:
     llvm_unreachable("Unexpected instr type to insert");
-  case RISCV::ReadCycleWide:
+  case RISCV::ReadCounterWide:
     assert(!Subtarget.is64Bit() &&
-           "ReadCycleWrite is only to be used on riscv32");
-    return emitReadCycleWidePseudo(MI, BB);
+           "ReadCounterWide is only to be used on riscv32");
+    return emitReadCounterWidePseudo(MI, BB);
   case RISCV::Select_GPR_Using_CC_GPR:
   case RISCV::Select_FPR16_Using_CC_GPR:
   case RISCV::Select_FPR16INX_Using_CC_GPR:
@@ -19203,7 +19222,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FCLASS)
   NODE_NAME_CASE(FMAX)
   NODE_NAME_CASE(FMIN)
-  NODE_NAME_CASE(READ_CYCLE_WIDE)
+  NODE_NAME_CASE(READ_COUNTER_WIDE)
   NODE_NAME_CASE(BREV8)
   NODE_NAME_CASE(ORC_B)
   NODE_NAME_CASE(ZIP)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 255b1d0e15eedd..879af0ecdf8bc0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -126,9 +126,9 @@ enum NodeType : unsigned {
   // Floating point fmax and fmin matching the RISC-V instruction semantics.
   FMAX, FMIN,
 
-  // READ_CYCLE_WIDE - A read of the 64-bit cycle CSR on a 32-bit target
+  // READ_COUNTER_WIDE - A read of the 64-bit counter CSR on a 32-bit target
   // (returns (Lo, Hi)). It takes a chain operand.
-  READ_CYCLE_WIDE,
+  READ_COUNTER_WIDE,
   // brev8, orc.b, zip, and unzip from Zbb and Zbkb. All operands are i32 or
   // XLenVT.
   BREV8,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index 7fe9b626b66d68..0d2ffac4883a34 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -33,8 +33,10 @@ def SDT_RISCVReadCSR  : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDT_RISCVWriteCSR : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDT_RISCVSwapCSR  : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>,
                                              SDTCisInt<2>]>;
-def SDT_RISCVReadCycleWide : SDTypeProfile<2, 0, [SDTCisVT<0, i32>,
-                                                  SDTCisVT<1, i32>]>;
+def SDT_RISCVReadCounterWide : SDTypeProfile<2, 2, [SDTCisVT<0, i32>,
+                                                    SDTCisVT<1, i32>,
+                                                    SDTCisInt<2>,
+                                                    SDTCisInt<3>]>;
 def SDT_RISCVIntUnaryOpW : SDTypeProfile<1, 1, [
   SDTCisSameAs<0, 1>, SDTCisVT<0, i64>
 ]>;
@@ -77,9 +79,9 @@ def riscv_write_csr : SDNode<"RISCVISD::WRITE_CSR", SDT_RISCVWriteCSR,
 def riscv_swap_csr  : SDNode<"RISCVISD::SWAP_CSR", SDT_RISCVSwapCSR,
                              [SDNPHasChain]>;
 
-def riscv_read_cycle_wide : SDNode<"RISCVISD::READ_CYCLE_WIDE",
-                                   SDT_RISCVReadCycleWide,
-                                   [SDNPHasChain, SDNPSideEffect]>;
+def riscv_read_counter_wide : SDNode<"RISCVISD::READ_COUNTER_WIDE",
+                                     SDT_RISCVReadCounterWide,
+                                     [SDNPHasChain, SDNPSideEffect]>;
 
 def riscv_add_lo : SDNode<"RISCVISD::ADD_LO", SDTIntBinOp>;
 def riscv_hi : SDNode<"RISCVISD::HI", SDTIntUnaryOp>;
@@ -363,7 +365,7 @@ def CSRSystemRegister : AsmOperandClass {
   let DiagnosticType = "InvalidCSRSystemRegister";
 }
 
-def csr_sysreg : RISCVOp {
+def csr_sysreg : RISCVOp, ImmLeaf<XLenVT, "return isUInt<12>(Imm);"> {
   let ParserMatchClass = CSRSystemRegister;
   let PrintMethod = "printCSRSystemRegister";
   let DecoderMethod = "decodeUImmOperand<12>";
@@ -1827,16 +1829,21 @@ def : StPat<truncstorei32, SW, GPR, i64>;
 def : StPat<store, SD, GPR, i64>;
 } // Predicates = [IsRV64]
 
+// On RV64, we can directly read these 64-bit counter CSRs.
+let Predicates = [IsRV64] in {
 /// readcyclecounter
-// On RV64, we can directly read the 64-bit "cycle" CSR.
-let Predicates = [IsRV64] in
 def : Pat<(i64 (readcyclecounter)), (CSRRS CYCLE.Encoding, (XLenVT X0))>;
-// On RV32, ReadCycleWide will be expanded to the suggested loop reading both
-// halves of the 64-bit "cycle" CSR.
+/// readsteadycounter
+def : Pat<(i64 (readsteadycounter)), (CSRRS TIME.Encoding, (XLenVT X0))>;
+}
+
+// On RV32, ReadCounterWide will be expanded to the suggested loop reading both
+// halves of 64-bit counter CSRs.
 let Predicates = [IsRV32], usesCustomInserter = 1, hasNoSchedulingInfo = 1 in
-def ReadCycleWide : Pseudo<(outs GPR:$lo, GPR:$hi), (ins),
-                           [(set GPR:$lo, GPR:$hi, (riscv_read_cycle_wide))],
-                           "", "">;
+def ReadCounterWide : Pseudo<(outs GPR:$lo, GPR:$hi), (ins i32imm:$csr_lo, i32imm:$csr_hi),
+                             [(set GPR:$lo, GPR:$hi,
+                              (riscv_read_counter_wide csr_sysreg:$csr_lo, csr_sysreg:$csr_hi))],
+                             "", "">;
 
 /// traps
 
diff --git a/llvm/test/CodeGen/RISCV/readsteadycounter.ll b/llvm/test/CodeGen/RISCV/readsteadycounter.ll
new file mode 100644
index 00000000000000..19eab64530c66a
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/readsteadycounter.ll
@@ -0,0 +1,28 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32I %s
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64I %s
+
+; Verify that we lower @llvm.readsteadycounter() correctly.
+
+declare i64 @llvm.readsteadycounter()
+
+define i64 @test_builtin_readsteadycounter() nounwind {
+; RV32I-LABEL: test_builtin_readsteadycounter:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:  .LBB0_1: # =>This Inner Loop Header: Depth=1
+; RV32I-NEXT:    rdtimeh a1
+; RV32I-NEXT:    rdtime a0
+; RV32I-NEXT:    rdtimeh a2
+; RV32I-NEXT:    bne a1, a2, .LBB0_1
+; RV32I-NEXT:  # %bb.2:
+; RV32I-NEXT:    ret
+;
+; RV64I-LABEL: test_builtin_readsteadycounter:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    rdtime a0
+; RV64I-NEXT:    ret
+  %1 = tail call i64 @llvm.readsteadycounter()
+  ret i64 %1
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/82322


More information about the llvm-branch-commits mailing list