[llvm] [RISCV] Implement forward inserting save/restore FRM instructions. (PR #77744)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 02:16:37 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Yeting Kuo (yetingk)

<details>
<summary>Changes</summary>

Previously, RISCVInsertReadWriteCSR inserted an FRM swap for any value other than 7 and restored the original value right after the vector instruction. This is inefficient if multiple vector instructions use the same rounding mode or if the next vector instruction uses a different explicit rounding mode.

This patch implements a local optimization to solve the above problem. We assume the starting rounding mode of the basic block is "dynamic." When iterating through a basic block and encountering an instruction whose rounding mode differs from the current rounding mode, we switch FRM to the instruction's rounding mode, saving the previous FRM value if needed. We may also need to restore FRM when encountering a function call, inline asm, or certain uses of FRM.

A more advanced version of this would perform cross-basic-block analysis to determine the starting rounding mode of each basic block.

---
Full diff: https://github.com/llvm/llvm-project/pull/77744.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp (+114-2) 
- (added) llvm/test/CodeGen/RISCV/rvv/frm-insert.ll (+217) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp b/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
index b807abcc56819b..4d574a588adc46 100644
--- a/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
@@ -23,6 +23,10 @@ using namespace llvm;
 #define DEBUG_TYPE "riscv-insert-read-write-csr"
 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
 
+static cl::opt<bool> DisableFRMInsertOpt(
+    "riscv-disable-frm-insert-opt", cl::init(false), cl::Hidden,
+    cl::desc("Disable optimized frm insertion."));
+
 namespace {
 
 class RISCVInsertReadWriteCSR : public MachineFunctionPass {
@@ -46,6 +50,7 @@ class RISCVInsertReadWriteCSR : public MachineFunctionPass {
 
 private:
   bool emitWriteRoundingMode(MachineBasicBlock &MBB);
+  bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB);
 };
 
 } // end anonymous namespace
@@ -55,6 +60,109 @@ char RISCVInsertReadWriteCSR::ID = 0;
 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
                 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
 
+// TODO: Use more accurate rounding mode at the start of MBB.
+bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) {
+  bool Changed = false;
+  MachineInstr *LastFRMChanger = nullptr;
+  std::optional<unsigned> CurrentRM = RISCVFPRndMode::DYN;
+  std::optional<Register> SavedFRM;
+
+  for (MachineInstr &MI : MBB) {
+    if (MI.getOpcode() == RISCV::SwapFRMImm ||
+        MI.getOpcode() == RISCV::WriteFRMImm ) {
+      CurrentRM = MI.getOperand(0).getImm();
+      SavedFRM = std::nullopt;
+      continue;
+    }
+
+    if (MI.getOpcode() == RISCV::WriteFRM) {
+      CurrentRM = RISCVFPRndMode::DYN;
+      SavedFRM = std::nullopt;
+      continue;
+    }
+
+    if (MI.isCall() || MI.isInlineAsm() || MI.readsRegister(RISCV::FRM)) {
+      // Restore FRM before unknown operations.
+      if (SavedFRM.has_value())
+        BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM))
+          .addReg(*SavedFRM);
+      CurrentRM = RISCVFPRndMode::DYN;
+      SavedFRM = std::nullopt;
+      continue;
+    }
+
+    assert(!MI.modifiesRegister(RISCV::FRM) &&
+           "Expected that MI could not modify FRM.");
+
+    auto getInstructionRM = [](MachineInstr &MI) -> std::optional<unsigned> {
+      int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
+      if (FRMIdx >= 0)
+        return MI.getOperand(FRMIdx).getImm();
+
+      if (!MI.hasRegisterImplicitUseOperand(RISCV::FRM))
+        return std::nullopt;
+
+      // FIXME: Return nullopt if the rounding mode of MI is not DYN, like
+      // FADD_S with RTZ.
+      return RISCVFPRndMode::DYN;
+    };
+
+    std::optional<unsigned> InstrRM = getInstructionRM(MI);
+
+    // Skip if MI does not need FRM.
+    if (!InstrRM.has_value())
+      continue;
+
+    if (InstrRM != RISCVFPRndMode::DYN)
+      LastFRMChanger = &MI;
+
+    if (!MI.readsRegister(RISCV::FRM))
+      MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
+                                              /*IsImp*/ true));
+
+    // Skip if MI uses same rounding mode as FRM.
+    if (InstrRM == CurrentRM)
+      continue;
+
+    if (InstrRM == RISCVFPRndMode::DYN) {
+      if (!SavedFRM.has_value())
+        continue;
+      // SavedFRM not having a value means current FRM has correct rounding mode.
+      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM))
+        .addReg(*SavedFRM);
+      SavedFRM = std::nullopt;
+      CurrentRM = RISCVFPRndMode::DYN;
+      continue;
+    }
+
+    if (CurrentRM == RISCVFPRndMode::DYN) {
+      // Save current FRM value to SavedFRM.
+      MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
+      SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
+      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
+              *SavedFRM)
+       .addImm(*InstrRM);
+    } else {
+      // Don't need to save current FRM when CurrentRM != DYN.
+      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm))
+        .addImm(*InstrRM);
+    }
+    CurrentRM = InstrRM;
+    Changed = true;
+  }
+
+  // Restore FRM if needed.
+  if (SavedFRM.has_value()) {
+    assert(LastFRMChanger && "Expected valid pointer.");
+    MachineInstrBuilder MIB =
+        BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
+            .addReg(*SavedFRM);
+    MBB.insertAfter(LastFRMChanger, MIB);
+  }
+
+  return Changed;
+}
+
 // This function also swaps frm and restores it when encountering an RVV
 // floating point instruction with a static rounding mode.
 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
@@ -99,8 +207,12 @@ bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
 
   bool Changed = false;
 
-  for (MachineBasicBlock &MBB : MF)
-    Changed |= emitWriteRoundingMode(MBB);
+  for (MachineBasicBlock &MBB : MF) {
+    if (DisableFRMInsertOpt)
+      Changed |= emitWriteRoundingMode(MBB);
+    else
+      Changed |= emitWriteRoundingModeOpt(MBB);
+  }
 
   return Changed;
 }
diff --git a/llvm/test/CodeGen/RISCV/rvv/frm-insert.ll b/llvm/test/CodeGen/RISCV/rvv/frm-insert.ll
new file mode 100644
index 00000000000000..c75c35602fbabd
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/frm-insert.ll
@@ -0,0 +1,217 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs -target-abi=lp64d < %s | FileCheck %s
+
+declare <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+  <vscale x 1 x float>,
+  <vscale x 1 x float>,
+  <vscale x 1 x float>,
+  i64, i64);
+
+; Test only save/restore frm once.
+define <vscale x 1 x float> @test(<vscale x 1 x float> %0, <vscale x 1 x float> %1, i64 %2) nounwind {
+; CHECK-LABEL: test:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    fsrmi a0, 0
+; CHECK-NEXT:    vfadd.vv v8, v8, v9
+; CHECK-NEXT:    vfadd.vv v8, v8, v8
+; CHECK-NEXT:    fsrm a0
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %0,
+    <vscale x 1 x float> %1,
+    i64 0, i64 %2)
+  %b = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %a,
+    <vscale x 1 x float> %a,
+    i64 0, i64 %2)
+  ret <vscale x 1 x float> %b
+}
+
+; Test only restore frm once.
+define <vscale x 1 x float> @test2(<vscale x 1 x float> %0, <vscale x 1 x float> %1, i64 %2) nounwind {
+; CHECK-LABEL: test2:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    fsrmi a0, 0
+; CHECK-NEXT:    vfadd.vv v8, v8, v9
+; CHECK-NEXT:    fsrmi 1
+; CHECK-NEXT:    vfadd.vv v8, v8, v8
+; CHECK-NEXT:    fsrm a0
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %0,
+    <vscale x 1 x float> %1,
+    i64 0, i64 %2)
+  %b = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %a,
+    <vscale x 1 x float> %a,
+    i64 1, i64 %2)
+  ret <vscale x 1 x float> %b
+}
+
+; Test restoring frm before function call and doing nothing with following dynamic
+; rounding mode operations.
+declare void @foo()
+define <vscale x 1 x float> @test3(<vscale x 1 x float> %0, <vscale x 1 x float> %1, i64 %2) nounwind {
+; CHECK-LABEL: test3:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    addi sp, sp, -32
+; CHECK-NEXT:    sd ra, 24(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s0, 16(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    csrr a1, vlenb
+; CHECK-NEXT:    slli a1, a1, 1
+; CHECK-NEXT:    sub sp, sp, a1
+; CHECK-NEXT:    mv s0, a0
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    fsrmi a0, 0
+; CHECK-NEXT:    vfadd.vv v8, v8, v9
+; CHECK-NEXT:    addi a1, sp, 16
+; CHECK-NEXT:    vs1r.v v8, (a1) # Unknown-size Folded Spill
+; CHECK-NEXT:    fsrm a0
+; CHECK-NEXT:    call foo
+; CHECK-NEXT:    vsetvli zero, s0, e32, mf2, ta, ma
+; CHECK-NEXT:    addi a0, sp, 16
+; CHECK-NEXT:    vl1r.v v8, (a0) # Unknown-size Folded Reload
+; CHECK-NEXT:    vfadd.vv v8, v8, v8
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    slli a0, a0, 1
+; CHECK-NEXT:    add sp, sp, a0
+; CHECK-NEXT:    ld ra, 24(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s0, 16(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 32
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %0,
+    <vscale x 1 x float> %1,
+    i64 0, i64 %2)
+  call void @foo()
+  %b = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %a,
+    <vscale x 1 x float> %a,
+    i64 7, i64 %2)
+  ret <vscale x 1 x float> %b
+}
+
+; Test restoring frm before inline asm and doing nothing with following dynamic
+; rounding mode operations.
+define <vscale x 1 x float> @test4(<vscale x 1 x float> %0, <vscale x 1 x float> %1, i64 %2) nounwind {
+; CHECK-LABEL: test4:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    fsrmi a1, 0
+; CHECK-NEXT:    vfadd.vv v8, v8, v9
+; CHECK-NEXT:    fsrm a1
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    vfadd.vv v8, v8, v8
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %0,
+    <vscale x 1 x float> %1,
+    i64 0, i64 %2)
+  call void asm sideeffect "", ""()
+  %b = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %a,
+    <vscale x 1 x float> %a,
+    i64 7, i64 %2)
+  ret <vscale x 1 x float> %b
+}
+
+; Test restoring frm before reading frm and doing nothing with following dynamic
+; rounding mode operations.
+declare i32 @llvm.get.rounding()
+define <vscale x 1 x float> @test5(<vscale x 1 x float> %0, <vscale x 1 x float> %1, i64 %2, ptr %p) nounwind {
+; CHECK-LABEL: test5:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    fsrmi a0, 0
+; CHECK-NEXT:    vfadd.vv v8, v8, v9
+; CHECK-NEXT:    fsrm a0
+; CHECK-NEXT:    frrm a0
+; CHECK-NEXT:    slli a0, a0, 2
+; CHECK-NEXT:    lui a2, 66
+; CHECK-NEXT:    addiw a2, a2, 769
+; CHECK-NEXT:    srl a0, a2, a0
+; CHECK-NEXT:    andi a0, a0, 7
+; CHECK-NEXT:    vfadd.vv v8, v8, v8
+; CHECK-NEXT:    sw a0, 0(a1)
+; CHECK-NEXT:    ret
+entry:
+  %a = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %0,
+    <vscale x 1 x float> %1,
+    i64 0, i64 %2)
+  %rm = call i32 @llvm.get.rounding()
+  store i32 %rm, ptr %p, align 4
+  %b = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %a,
+    <vscale x 1 x float> %a,
+    i64 7, i64 %2)
+  ret <vscale x 1 x float> %b
+}
+
+; Test that FRM is not set for the two vfadds after WriteFRMImm.
+declare void @llvm.set.rounding(i32)
+define <vscale x 1 x float> @test6(<vscale x 1 x float> %0, <vscale x 1 x float> %1, i64 %2) nounwind {
+; CHECK-LABEL: test6:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    fsrmi 4
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    vfadd.vv v8, v8, v9
+; CHECK-NEXT:    vfadd.vv v8, v8, v8
+; CHECK-NEXT:    ret
+entry:
+  call void @llvm.set.rounding(i32 4)
+  %a = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %0,
+    <vscale x 1 x float> %1,
+    i64 4, i64 %2)
+  %b = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %a,
+    <vscale x 1 x float> %a,
+    i64 7, i64 %2)
+  ret <vscale x 1 x float> %b
+}
+
+; Test that FRM is not set for the vfadd after WriteFRM.
+define <vscale x 1 x float> @test7(<vscale x 1 x float> %0, <vscale x 1 x float> %1, i32 %rm, i64 %2) nounwind {
+; CHECK-LABEL: test7:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 30
+; CHECK-NEXT:    lui a2, 66
+; CHECK-NEXT:    addiw a2, a2, 769
+; CHECK-NEXT:    srl a0, a2, a0
+; CHECK-NEXT:    andi a0, a0, 7
+; CHECK-NEXT:    fsrm a0
+; CHECK-NEXT:    vsetvli zero, a1, e32, mf2, ta, ma
+; CHECK-NEXT:    vfadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+entry:
+  call void @llvm.set.rounding(i32 %rm)
+  %a = call <vscale x 1 x float> @llvm.riscv.vfadd.nxv1f32.nxv1f32(
+    <vscale x 1 x float> undef,
+    <vscale x 1 x float> %0,
+    <vscale x 1 x float> %1,
+    i64 7, i64 %2)
+  ret <vscale x 1 x float> %a
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/77744


More information about the llvm-commits mailing list