[llvm] [RISCV] Use software guarded branch for indirect jump table branch. (PR #66762)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 19 03:52:19 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

<details>
<summary>Changes</summary>

When Zicfilp is enabled, indirect jump table branches should be software guarded branches.

---
Full diff: https://github.com/llvm/llvm-project/pull/66762.diff


5 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+14) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+8) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.td (+16) 
- (modified) llvm/lib/Target/RISCV/RISCVRegisterInfo.td (+2) 
- (added) llvm/test/CodeGen/RISCV/jumptable-swguarded.ll (+105) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 19e407414627dd4..cc94839f773d00e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18741,6 +18741,20 @@ unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT,
   return isCtpopFast(VT) ? 0 : 1;
 }
 
+SDValue RISCVTargetLowering::expandIndirectJTBranch(const SDLoc &dl,
+                                                    SDValue Value, SDValue Addr,
+                                                    int JTI,
+                                                    SelectionDAG &DAG) const {
+  // Without Zicfilp, defer to the default TargetLowering implementation.
+  if (!Subtarget.hasStdExtZicfilp())
+    return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, JTI, DAG);
+
+  // With Zicfilp, the jump table branch must be a software guarded branch
+  // (RISCVISD::SW_GUARDED_BRIND), whose target address must live in X7 (t2).
+  SDValue JTInfo = DAG.getJumpTableDebugInfo(JTI, Value, dl);
+  return DAG.getNode(RISCVISD::SW_GUARDED_BRIND, dl, MVT::Other, JTInfo, Addr);
+}
+
 namespace llvm::RISCVVIntrinsicsTable {
 
 #define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 815b9be47f56026..3af4e43da5b3cf5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -387,6 +387,10 @@ enum NodeType : unsigned {
   CZERO_EQZ, // vt.maskc for XVentanaCondOps.
   CZERO_NEZ, // vt.maskcn for XVentanaCondOps.
 
+  /// Software guarded BRIND node. Operand 0 is the chain operand and
+  /// operand 1 is the target address.
+  SW_GUARDED_BRIND,
+
   // FP to 32 bit int conversions for RV64. These are used to keep track of the
   // result being sign extended to 64 bit. These saturate out of range inputs.
   STRICT_FCVT_W_RV64 = ISD::FIRST_TARGET_STRICTFP_OPCODE,
@@ -815,6 +819,10 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool supportKCFIBundles() const override { return true; }
 
+  /// Lower a jump table branch; uses SW_GUARDED_BRIND when Zicfilp is enabled.
+  SDValue expandIndirectJTBranch(const SDLoc &dl, SDValue Value, SDValue Addr,
+                                 int JTI, SelectionDAG &DAG) const override;
+
   MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB,
                               MachineBasicBlock::instr_iterator &MBBI,
                               const TargetInstrInfo *TII) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index 582fe60fd0368e9..5dbf5794b82ef15 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -44,6 +44,7 @@ def SDT_RISCVIntBinOpW : SDTypeProfile<1, 2, [
 def SDT_RISCVIntShiftDOpW : SDTypeProfile<1, 3, [
   SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisVT<0, i64>, SDTCisVT<3, i64>
 ]>;
+def SDT_RISCVSWGuardedBrind : SDTypeProfile<0, 1, [SDTCisVT<0, iPTR>]>;
 
 // Target-independent nodes, but with target-specific formats.
 def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_CallSeqStart,
@@ -67,6 +68,8 @@ def riscv_brcc      : SDNode<"RISCVISD::BR_CC", SDT_RISCVBrCC,
 def riscv_tail      : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
                              [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
                               SDNPVariadic]>;
+def riscv_sw_guarded_brind : SDNode<"RISCVISD::SW_GUARDED_BRIND",
+                                    SDT_RISCVSWGuardedBrind, [SDNPHasChain]>;
 def riscv_sllw      : SDNode<"RISCVISD::SLLW", SDT_RISCVIntBinOpW>;
 def riscv_sraw      : SDNode<"RISCVISD::SRAW", SDT_RISCVIntBinOpW>;
 def riscv_srlw      : SDNode<"RISCVISD::SRLW", SDT_RISCVIntBinOpW>;
@@ -1554,6 +1557,13 @@ let isBarrier = 1, isBranch = 1, isIndirectBranch = 1, isTerminator = 1 in
 def PseudoBRIND : Pseudo<(outs), (ins GPRJALR:$rs1, simm12:$imm12), []>,
                   PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>;
 
+// Like PseudoBRIND, but the target register is restricted to X7 (t2) so that
+// the jump table dispatch is a software guarded branch under Zicfilp.
+let Predicates = [HasStdExtZicfilp],
+    isBarrier = 1, isBranch = 1, isIndirectBranch = 1, isTerminator = 1 in
+def PseudoBRINDX7 : Pseudo<(outs), (ins GPRX7:$rs1, simm12:$imm12), []>,
+                    PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>;
+
 def : Pat<(brind GPRJALR:$rs1), (PseudoBRIND GPRJALR:$rs1, 0)>;
 def : Pat<(brind (add GPRJALR:$rs1, simm12:$imm12)),
           (PseudoBRIND GPRJALR:$rs1, simm12:$imm12)>;
@@ -1945,6 +1955,12 @@ def : Pat<(binop_allwusers<add> GPR:$rs1, (AddiPair:$rs2)),
                  (AddiPairImmSmall AddiPair:$rs2))>;
 }
 
+let Predicates = [HasStdExtZicfilp] in {
+def : Pat<(riscv_sw_guarded_brind GPRX7:$rs1), (PseudoBRINDX7 GPRX7:$rs1, 0)>;
+def : Pat<(riscv_sw_guarded_brind (add GPRX7:$rs1, simm12:$imm12)),
+          (PseudoBRINDX7 GPRX7:$rs1, simm12:$imm12)>;
+} // Predicates = [HasStdExtZicfilp]
+
 //===----------------------------------------------------------------------===//
 // Standard extensions
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
index 1a6145f92908134..c5315004e40d1f9 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -142,6 +142,8 @@ def GPRNoX0 : GPRRegisterClass<(sub GPR, X0)>;
 
 def GPRNoX0X2 : GPRRegisterClass<(sub GPR, X0, X2)>;
 
+def GPRX7 : GPRRegisterClass<(add X7)>;
+
 // Don't use X1 or X5 for JALR since that is a hint to pop the return address
 // stack on some microarchitectures. Also remove the reserved registers X0, X2,
 // X3, and X4 as it reduces the number of register classes that get synthesized
diff --git a/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll b/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll
new file mode 100644
index 000000000000000..68ef376f24dbc67
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll
@@ -0,0 +1,105 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple riscv32 -mattr=+experimental-zicfilp < %s | FileCheck %s
+; RUN: llc -mtriple riscv64 -mattr=+experimental-zicfilp < %s | FileCheck %s
+; RUN: llc -mtriple riscv32 < %s | FileCheck %s --check-prefix=NO-ZICFILP
+; RUN: llc -mtriple riscv64 < %s | FileCheck %s --check-prefix=NO-ZICFILP
+
+; Check that, with Zicfilp, the jump table branch target is held in t2.
+define void @above_threshold(i32 signext %in, ptr %out) nounwind {
+; CHECK-LABEL: above_threshold:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    addi a0, a0, -1
+; CHECK-NEXT:    li a2, 5
+; CHECK-NEXT:    bltu a2, a0, .LBB0_9
+; CHECK-NEXT:  # %bb.1: # %entry
+; CHECK-NEXT:    slli a0, a0, 2
+; CHECK-NEXT:    lui a2, %hi(.LJTI0_0)
+; CHECK-NEXT:    addi a2, a2, %lo(.LJTI0_0)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    lw t2, 0(a0)
+; CHECK-NEXT:    jr t2
+; CHECK-NEXT:  .LBB0_2: # %bb1
+; CHECK-NEXT:    li a0, 4
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_3: # %bb2
+; CHECK-NEXT:    li a0, 3
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_4: # %bb3
+; CHECK-NEXT:    li a0, 2
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_5: # %bb4
+; CHECK-NEXT:    li a0, 1
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_6: # %bb5
+; CHECK-NEXT:    li a0, 100
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_7: # %bb6
+; CHECK-NEXT:    li a0, 200
+; CHECK-NEXT:  .LBB0_8: # %exit
+; CHECK-NEXT:    sw a0, 0(a1)
+; CHECK-NEXT:  .LBB0_9: # %exit
+; CHECK-NEXT:    ret
+;
+; NO-ZICFILP-LABEL: above_threshold:
+; NO-ZICFILP:       # %bb.0: # %entry
+; NO-ZICFILP-NEXT:    addi a0, a0, -1
+; NO-ZICFILP-NEXT:    li a2, 5
+; NO-ZICFILP-NEXT:    bltu a2, a0, .LBB0_9
+; NO-ZICFILP-NEXT:  # %bb.1: # %entry
+; NO-ZICFILP-NEXT:    slli a0, a0, 2
+; NO-ZICFILP-NEXT:    lui a2, %hi(.LJTI0_0)
+; NO-ZICFILP-NEXT:    addi a2, a2, %lo(.LJTI0_0)
+; NO-ZICFILP-NEXT:    add a0, a0, a2
+; NO-ZICFILP-NEXT:    lw a0, 0(a0)
+; NO-ZICFILP-NEXT:    jr a0
+; NO-ZICFILP-NEXT:  .LBB0_2: # %bb1
+; NO-ZICFILP-NEXT:    li a0, 4
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_3: # %bb2
+; NO-ZICFILP-NEXT:    li a0, 3
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_4: # %bb3
+; NO-ZICFILP-NEXT:    li a0, 2
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_5: # %bb4
+; NO-ZICFILP-NEXT:    li a0, 1
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_6: # %bb5
+; NO-ZICFILP-NEXT:    li a0, 100
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_7: # %bb6
+; NO-ZICFILP-NEXT:    li a0, 200
+; NO-ZICFILP-NEXT:  .LBB0_8: # %exit
+; NO-ZICFILP-NEXT:    sw a0, 0(a1)
+; NO-ZICFILP-NEXT:  .LBB0_9: # %exit
+; NO-ZICFILP-NEXT:    ret
+entry:
+  switch i32 %in, label %exit [
+    i32 1, label %bb1
+    i32 2, label %bb2
+    i32 3, label %bb3
+    i32 4, label %bb4
+    i32 5, label %bb5
+    i32 6, label %bb6
+  ]
+bb1:
+  store i32 4, ptr %out
+  br label %exit
+bb2:
+  store i32 3, ptr %out
+  br label %exit
+bb3:
+  store i32 2, ptr %out
+  br label %exit
+bb4:
+  store i32 1, ptr %out
+  br label %exit
+bb5:
+  store i32 100, ptr %out
+  br label %exit
+bb6:
+  store i32 200, ptr %out
+  br label %exit
+exit:
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/66762


More information about the llvm-commits mailing list