[llvm] [RISCV] Use software guarded branch for indirect jump table branch. (PR #66762)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 19 03:52:19 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
<details>
<summary>Changes</summary>
When Zicfilp is enabled, an indirect jump-table branch should be a software-guarded branch, i.e. an indirect jump whose target address is held in register t2 (x7).
---
Full diff: https://github.com/llvm/llvm-project/pull/66762.diff
5 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+14)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+8)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.td (+16)
- (modified) llvm/lib/Target/RISCV/RISCVRegisterInfo.td (+2)
- (added) llvm/test/CodeGen/RISCV/jumptable-swguarded.ll (+105)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 19e407414627dd4..cc94839f773d00e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18741,6 +18741,20 @@ unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT,
return isCtpopFast(VT) ? 0 : 1;
}
+SDValue RISCVTargetLowering::expandIndirectJTBranch(const SDLoc &dl,
+ SDValue Value, SDValue Addr,
+ int JTI,
+ SelectionDAG &DAG) const {
+ if (Subtarget.hasStdExtZicfilp()) {
+ // When Zicfilp is enabled, jump-table branches must be software-guarded
+ // branches, i.e. the jump target must be held in t2 (x7).
+ SDValue JTInfo = DAG.getJumpTableDebugInfo(JTI, Value, dl);
+ return DAG.getNode(RISCVISD::SW_GUARDED_BRIND, dl, MVT::Other, JTInfo,
+ Addr);
+ }
+ return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, JTI, DAG);
+}
+
namespace llvm::RISCVVIntrinsicsTable {
#define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 815b9be47f56026..3af4e43da5b3cf5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -387,6 +387,10 @@ enum NodeType : unsigned {
CZERO_EQZ, // vt.maskc for XVentanaCondOps.
CZERO_NEZ, // vt.maskcn for XVentanaCondOps.
+ /// Software guarded BRIND node. Operand 0 is the chain operand and
+ /// operand 1 is the target address.
+ SW_GUARDED_BRIND,
+
// FP to 32 bit int conversions for RV64. These are used to keep track of the
// result being sign extended to 64 bit. These saturate out of range inputs.
STRICT_FCVT_W_RV64 = ISD::FIRST_TARGET_STRICTFP_OPCODE,
@@ -815,6 +819,10 @@ class RISCVTargetLowering : public TargetLowering {
bool supportKCFIBundles() const override { return true; }
+ SDValue expandIndirectJTBranch(const SDLoc &dl, SDValue Value, SDValue Addr,
+ int JTI, SelectionDAG &DAG) const override;
+
+
MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB,
MachineBasicBlock::instr_iterator &MBBI,
const TargetInstrInfo *TII) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index 582fe60fd0368e9..5dbf5794b82ef15 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -44,6 +44,7 @@ def SDT_RISCVIntBinOpW : SDTypeProfile<1, 2, [
def SDT_RISCVIntShiftDOpW : SDTypeProfile<1, 3, [
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisVT<0, i64>, SDTCisVT<3, i64>
]>;
+def SDT_RISCVSWGuardedBrind : SDTypeProfile<0, -1, [SDTCisVT<0, iPTR>]>;
// Target-independent nodes, but with target-specific formats.
def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_CallSeqStart,
@@ -67,6 +68,8 @@ def riscv_brcc : SDNode<"RISCVISD::BR_CC", SDT_RISCVBrCC,
def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
+def riscv_sw_guarded_brind : SDNode<"RISCVISD::SW_GUARDED_BRIND",
+ SDT_RISCVSWGuardedBrind, [SDNPHasChain]>;
def riscv_sllw : SDNode<"RISCVISD::SLLW", SDT_RISCVIntBinOpW>;
def riscv_sraw : SDNode<"RISCVISD::SRAW", SDT_RISCVIntBinOpW>;
def riscv_srlw : SDNode<"RISCVISD::SRLW", SDT_RISCVIntBinOpW>;
@@ -1554,6 +1557,13 @@ let isBarrier = 1, isBranch = 1, isIndirectBranch = 1, isTerminator = 1 in
def PseudoBRIND : Pseudo<(outs), (ins GPRJALR:$rs1, simm12:$imm12), []>,
PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>;
+let Predicates = [HasStdExtZicfilp] in {
+let isBarrier = 1, isBranch = 1, isIndirectBranch = 1, isTerminator = 1 in
+def PseudoBRINDX7 : Pseudo<(outs), (ins GPRX7:$rs1, simm12:$imm12), []>,
+ PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>;
+
+}
+
def : Pat<(brind GPRJALR:$rs1), (PseudoBRIND GPRJALR:$rs1, 0)>;
def : Pat<(brind (add GPRJALR:$rs1, simm12:$imm12)),
(PseudoBRIND GPRJALR:$rs1, simm12:$imm12)>;
@@ -1945,6 +1955,12 @@ def : Pat<(binop_allwusers<add> GPR:$rs1, (AddiPair:$rs2)),
(AddiPairImmSmall AddiPair:$rs2))>;
}
+let Predicates = [HasStdExtZicfilp] in {
+def : Pat<(riscv_sw_guarded_brind GPRX7:$rs1), (PseudoBRINDX7 GPRX7:$rs1, 0)>;
+def : Pat<(riscv_sw_guarded_brind (add GPRX7:$rs1, simm12:$imm12)),
+ (PseudoBRINDX7 GPRX7:$rs1, simm12:$imm12)>;
+}
+
//===----------------------------------------------------------------------===//
// Standard extensions
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
index 1a6145f92908134..c5315004e40d1f9 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -142,6 +142,8 @@ def GPRNoX0 : GPRRegisterClass<(sub GPR, X0)>;
def GPRNoX0X2 : GPRRegisterClass<(sub GPR, X0, X2)>;
+def GPRX7 : GPRRegisterClass<(add X7)>;
+
// Don't use X1 or X5 for JALR since that is a hint to pop the return address
// stack on some microarchitectures. Also remove the reserved registers X0, X2,
// X3, and X4 as it reduces the number of register classes that get synthesized
diff --git a/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll b/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll
new file mode 100644
index 000000000000000..68ef376f24dbc67
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll
@@ -0,0 +1,105 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple riscv32 -mattr=+experimental-zicfilp < %s | FileCheck %s
+; RUN: llc -mtriple riscv64 -mattr=+experimental-zicfilp < %s | FileCheck %s
+; RUN: llc -mtriple riscv32 < %s | FileCheck %s --check-prefix=NO-ZICFILP
+; RUN: llc -mtriple riscv64 < %s | FileCheck %s --check-prefix=NO-ZICFILP
+
+; Test that jump-table branches jump through t2 when Zicfilp is enabled.
+define void @above_threshold(i32 signext %in, ptr %out) nounwind {
+; CHECK-LABEL: above_threshold:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: addi a0, a0, -1
+; CHECK-NEXT: li a2, 5
+; CHECK-NEXT: bltu a2, a0, .LBB0_9
+; CHECK-NEXT: # %bb.1: # %entry
+; CHECK-NEXT: slli a0, a0, 2
+; CHECK-NEXT: lui a2, %hi(.LJTI0_0)
+; CHECK-NEXT: addi a2, a2, %lo(.LJTI0_0)
+; CHECK-NEXT: add a0, a0, a2
+; CHECK-NEXT: lw t2, 0(a0)
+; CHECK-NEXT: jr t2
+; CHECK-NEXT: .LBB0_2: # %bb1
+; CHECK-NEXT: li a0, 4
+; CHECK-NEXT: j .LBB0_8
+; CHECK-NEXT: .LBB0_3: # %bb2
+; CHECK-NEXT: li a0, 3
+; CHECK-NEXT: j .LBB0_8
+; CHECK-NEXT: .LBB0_4: # %bb3
+; CHECK-NEXT: li a0, 2
+; CHECK-NEXT: j .LBB0_8
+; CHECK-NEXT: .LBB0_5: # %bb4
+; CHECK-NEXT: li a0, 1
+; CHECK-NEXT: j .LBB0_8
+; CHECK-NEXT: .LBB0_6: # %bb5
+; CHECK-NEXT: li a0, 100
+; CHECK-NEXT: j .LBB0_8
+; CHECK-NEXT: .LBB0_7: # %bb6
+; CHECK-NEXT: li a0, 200
+; CHECK-NEXT: .LBB0_8: # %exit
+; CHECK-NEXT: sw a0, 0(a1)
+; CHECK-NEXT: .LBB0_9: # %exit
+; CHECK-NEXT: ret
+;
+; NO-ZICFILP-LABEL: above_threshold:
+; NO-ZICFILP: # %bb.0: # %entry
+; NO-ZICFILP-NEXT: addi a0, a0, -1
+; NO-ZICFILP-NEXT: li a2, 5
+; NO-ZICFILP-NEXT: bltu a2, a0, .LBB0_9
+; NO-ZICFILP-NEXT: # %bb.1: # %entry
+; NO-ZICFILP-NEXT: slli a0, a0, 2
+; NO-ZICFILP-NEXT: lui a2, %hi(.LJTI0_0)
+; NO-ZICFILP-NEXT: addi a2, a2, %lo(.LJTI0_0)
+; NO-ZICFILP-NEXT: add a0, a0, a2
+; NO-ZICFILP-NEXT: lw a0, 0(a0)
+; NO-ZICFILP-NEXT: jr a0
+; NO-ZICFILP-NEXT: .LBB0_2: # %bb1
+; NO-ZICFILP-NEXT: li a0, 4
+; NO-ZICFILP-NEXT: j .LBB0_8
+; NO-ZICFILP-NEXT: .LBB0_3: # %bb2
+; NO-ZICFILP-NEXT: li a0, 3
+; NO-ZICFILP-NEXT: j .LBB0_8
+; NO-ZICFILP-NEXT: .LBB0_4: # %bb3
+; NO-ZICFILP-NEXT: li a0, 2
+; NO-ZICFILP-NEXT: j .LBB0_8
+; NO-ZICFILP-NEXT: .LBB0_5: # %bb4
+; NO-ZICFILP-NEXT: li a0, 1
+; NO-ZICFILP-NEXT: j .LBB0_8
+; NO-ZICFILP-NEXT: .LBB0_6: # %bb5
+; NO-ZICFILP-NEXT: li a0, 100
+; NO-ZICFILP-NEXT: j .LBB0_8
+; NO-ZICFILP-NEXT: .LBB0_7: # %bb6
+; NO-ZICFILP-NEXT: li a0, 200
+; NO-ZICFILP-NEXT: .LBB0_8: # %exit
+; NO-ZICFILP-NEXT: sw a0, 0(a1)
+; NO-ZICFILP-NEXT: .LBB0_9: # %exit
+; NO-ZICFILP-NEXT: ret
+entry:
+ switch i32 %in, label %exit [
+ i32 1, label %bb1
+ i32 2, label %bb2
+ i32 3, label %bb3
+ i32 4, label %bb4
+ i32 5, label %bb5
+ i32 6, label %bb6
+ ]
+bb1:
+ store i32 4, ptr %out
+ br label %exit
+bb2:
+ store i32 3, ptr %out
+ br label %exit
+bb3:
+ store i32 2, ptr %out
+ br label %exit
+bb4:
+ store i32 1, ptr %out
+ br label %exit
+bb5:
+ store i32 100, ptr %out
+ br label %exit
+bb6:
+ store i32 200, ptr %out
+ br label %exit
+exit:
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/66762
More information about the llvm-commits
mailing list