[llvm-branch-commits] [llvm] [ConstantTime] Native ct.select support for ARM64 (PR #166706)
Julius Alexandre via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Nov 6 09:28:53 PST 2025
https://github.com/wizardengineer updated https://github.com/llvm/llvm-project/pull/166706
>From 7de2b8134fa87b1f113834b5ab0b218cbde5a821 Mon Sep 17 00:00:00 2001
From: wizardengineer <juliuswoosebert at gmail.com>
Date: Wed, 5 Nov 2025 17:09:45 -0500
Subject: [PATCH] [LLVM][AArch64] Add native ct.select support for ARM64
This patch implements architecture-specific lowering for ct.select on AArch64
using CSEL (conditional select) instructions for constant-time selection.
Implementation details:
- Uses CSEL family of instructions for scalar integer types
- Uses FCSEL for floating-point types (F32, F64, and F16/BF16 when +fullfp16 is available; otherwise F16/BF16 are promoted)
- Post-RA pseudo expansion (expandPostRAPseudo), with an MCInstLower fallback, to convert the pseudo-instructions to real CSEL/FCSEL
- Vector types have no native lowering: they are scalarized into per-lane CSEL/FCSEL (i8/i16 element types are promoted)
- Tests for scalar integer (i1-i128), floating-point (f16, f32, f64) and v4i32/v4f32 vector ct.select on AArch64
The implementation includes:
- ISelLowering: Custom lowering to CTSELECT pseudo-instructions
- InstrInfo: Pseudo-instruction definitions and patterns
- MCInstLower: Post-RA lowering of pseudo-instructions to actual CSEL/FCSEL
- The condition is tested once (TST against zero) and consumed only by CSEL/FCSEL, so no branches are introduced
---
.../Target/AArch64/AArch64ISelLowering.cpp | 56 +++++
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 11 +
llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 200 ++++++++----------
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 40 ++++
.../lib/Target/AArch64/AArch64MCInstLower.cpp | 18 ++
llvm/test/CodeGen/AArch64/ctselect.ll | 153 ++++++++++++++
6 files changed, 368 insertions(+), 110 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/ctselect.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 60aa61e993b26..54d0ea168d0b6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -511,12 +511,44 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BR_CC, MVT::f64, Custom);
setOperationAction(ISD::SELECT, MVT::i32, Custom);
setOperationAction(ISD::SELECT, MVT::i64, Custom);
+  // Constant-time select: i32/i64 are custom-lowered (LowerCTSELECT) to the
+  // CSEL-based pseudos; narrower integers are promoted first.
+  setOperationAction(ISD::CTSELECT, MVT::i8, Promote);
+  setOperationAction(ISD::CTSELECT, MVT::i16, Promote);
+  setOperationAction(ISD::CTSELECT, MVT::i32, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::i64, Custom);
if (Subtarget->hasFPARMv8()) {
setOperationAction(ISD::SELECT, MVT::f16, Custom);
setOperationAction(ISD::SELECT, MVT::bf16, Custom);
}
+  // The F16/BF16 CTSELECT pseudos require +fullfp16; otherwise promote.
+  if (Subtarget->hasFullFP16()) {
+    setOperationAction(ISD::CTSELECT, MVT::f16, Custom);
+    setOperationAction(ISD::CTSELECT, MVT::bf16, Custom);
+  } else {
+    setOperationAction(ISD::CTSELECT, MVT::f16, Promote);
+    setOperationAction(ISD::CTSELECT, MVT::bf16, Promote);
+  }
setOperationAction(ISD::SELECT, MVT::f32, Custom);
setOperationAction(ISD::SELECT, MVT::f64, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::f32, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::f64, Custom);
+  // Vector CTSELECT has no native lowering: element types without a usable
+  // scalar CTSELECT pseudo are promoted, everything else is expanded.
+  // NOTE(review): confirm Promote is supported for vector CTSELECT (no
+  // AddPromotedToType is registered here) and that Expand is safe for
+  // scalable vector types, which vector_valuetypes() also yields.
+  for (MVT VT : MVT::vector_valuetypes()) {
+    MVT ElemType = VT.getVectorElementType();
+    if (ElemType == MVT::i8 || ElemType == MVT::i16) {
+      setOperationAction(ISD::CTSELECT, VT, Promote);
+    } else if ((ElemType == MVT::f16 || ElemType == MVT::bf16) &&
+               !Subtarget->hasFullFP16()) {
+      setOperationAction(ISD::CTSELECT, VT, Promote);
+    } else {
+      setOperationAction(ISD::CTSELECT, VT, Expand);
+    }
+  }
setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
@@ -3328,6 +3352,28 @@ void AArch64TargetLowering::fixupPtrauthDiscriminator(
IntDiscOp.setImm(IntDisc);
}
+/// Replace the CTSELECT pseudo \p MI with the concrete conditional-select
+/// instruction \p Opcode, reusing the pseudo's explicit operands
+/// (dst, tval, fval, cc). BuildMI already attaches the implicit operands
+/// declared by \p Opcode's descriptor, so the pseudo's implicit operands are
+/// not copied (copying them would duplicate the NZCV use).
+/// NOTE(review): the CTSELECT pseudos in this patch do not set
+/// usesCustomInserter, no caller is visible, and expandPostRAPseudo performs
+/// the same rewrite; confirm this hook is still needed.
+MachineBasicBlock *AArch64TargetLowering::EmitCTSELECT(MachineInstr &MI,
+                                                        MachineBasicBlock *MBB,
+                                                        unsigned Opcode) const {
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  MachineInstrBuilder Builder =
+      BuildMI(*MBB, MI, MI.getDebugLoc(), TII->get(Opcode));
+  for (const MachineOperand &MO : MI.explicit_operands())
+    Builder.add(MO);
+  Builder->setFlag(MachineInstr::NoMerge);
+  // Erase (unlink and delete) the pseudo rather than only unlinking it.
+  MI.eraseFromParent();
+  return MBB;
+}
+
MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
MachineInstr &MI, MachineBasicBlock *BB) const {
@@ -7590,6 +7628,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerSELECT(Op, DAG);
case ISD::SELECT_CC:
return LowerSELECT_CC(Op, DAG);
+ case ISD::CTSELECT: // Constant-time select; see LowerCTSELECT.
+ return LowerCTSELECT(Op, DAG);
case ISD::JumpTable:
return LowerJumpTable(Op, DAG);
case ISD::BR_JT:
@@ -12149,6 +12189,25 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
return Res;
}
+/// Lower ISD::CTSELECT (cond, tval, fval) to AArch64ISD::CTSELECT, which the
+/// CTSELECT pseudos match and which is ultimately emitted as CSEL/FCSEL.
+/// The condition is compared against zero (SETNE) to produce the NZCV flags
+/// and the AArch64 condition code consumed by the select.
+SDValue AArch64TargetLowering::LowerCTSELECT(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Cond = Op.getOperand(0);
+  SDValue Zero = DAG.getConstant(0, DL, Cond.getValueType());
+
+  // getAArch64Cmp builds the flag-setting compare and hands back the matching
+  // AArch64 condition code through its AArch64CC out-parameter.
+  SDValue AArch64CC;
+  SDValue Flags = getAArch64Cmp(Cond, Zero, ISD::SETNE, AArch64CC, DAG, DL);
+
+  return DAG.getNode(AArch64ISD::CTSELECT, DL, Op.getValueType(),
+                     Op.getOperand(1), Op.getOperand(2), AArch64CC, Flags);
+}
+
SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
SelectionDAG &DAG) const {
// Jump table entries as PC relative offsets. No additional tweaking
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 2cb8ed29f252a..987377bc49023 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -202,6 +202,10 @@ class AArch64TargetLowering : public TargetLowering {
MachineOperand &AddrDiscOp,
const TargetRegisterClass *AddrDiscRC) const;
+  /// Replace the CTSELECT pseudo \p MI with the select instruction \p Opcode.
+  MachineBasicBlock *EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *BB,
+                                  unsigned Opcode) const;
+
MachineBasicBlock *
EmitInstrWithCustomInserter(MachineInstr &MI,
MachineBasicBlock *MBB) const override;
@@ -684,6 +688,8 @@ class AArch64TargetLowering : public TargetLowering {
iterator_range<SDNode::user_iterator> Users,
SDNodeFlags Flags, const SDLoc &dl,
SelectionDAG &DAG) const;
+  /// Lower ISD::CTSELECT to AArch64ISD::CTSELECT (emitted as CSEL/FCSEL).
+  SDValue LowerCTSELECT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
@@ -919,6 +925,10 @@ class AArch64TargetLowering : public TargetLowering {
bool hasMultipleConditionRegisters(EVT VT) const override {
return VT.isScalableVector();
}
+
+  // NOTE(review): the generic TargetLowering::isSelectSupported default
+  // appears to return true as well; confirm this override is still needed.
+  bool isSelectSupported(SelectSupportKind Kind) const override { return true; }
};
namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index ccc8eb8a9706d..bab67f57ea6b6 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2148,16 +2148,55 @@ bool AArch64InstrInfo::removeCmpToZeroOrOne(
return true;
}
-bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
- if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
- MI.getOpcode() != AArch64::CATCHRET)
- return false;
+/// Replace the CTSELECT pseudo \p MI with the real conditional-select
+/// instruction described by \p MCID (CSEL or FCSEL), reusing the pseudo's
+/// explicit operands (dst, tval, fval, cc). BuildMI already attaches the
+/// implicit operands declared by \p MCID, so the pseudo's implicit NZCV use
+/// is not copied again. The new instruction is tagged NoMerge.
+static void expandCtSelect(MachineBasicBlock &MBB, MachineInstr &MI,
+                           const DebugLoc &DL, const MCInstrDesc &MCID) {
+  MachineInstrBuilder Builder = BuildMI(MBB, MI, DL, MCID);
+  for (const MachineOperand &MO : MI.explicit_operands())
+    Builder.add(MO);
+  Builder->setFlag(MachineInstr::NoMerge);
+  // Erase (unlink and delete) the pseudo; callers must not touch MI again.
+  MI.eraseFromParent();
+}
+
+bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
MachineBasicBlock &MBB = *MI.getParent();
auto &Subtarget = MBB.getParent()->getSubtarget<AArch64Subtarget>();
auto TRI = Subtarget.getRegisterInfo();
DebugLoc DL = MI.getDebugLoc();
+  // Constant-time selects: rewrite each CTSELECT pseudo to the CSEL/FCSEL of
+  // matching width. F16 and BF16 both use the half-precision FCSEL.
+  switch (MI.getOpcode()) {
+  case AArch64::I32CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::CSELWr));
+    return true;
+  case AArch64::I64CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::CSELXr));
+    return true;
+  case AArch64::BF16CTSELECT:
+  case AArch64::F16CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr));
+    return true;
+  case AArch64::F32CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::FCSELSrrr));
+    return true;
+  case AArch64::F64CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::FCSELDrrr));
+    return true;
+  default:
+    break;
+  }
+
+  if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
+      MI.getOpcode() != AArch64::CATCHRET)
+    return false;
+
if (MI.getOpcode() == AArch64::CATCHRET) {
// Skip to the first instruction before the epilog.
const TargetInstrInfo *TII =
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 2871a20e28b65..5017a39789d08 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -476,6 +476,9 @@ def SDT_AArch64cbz : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisVT<1, OtherVT>]>;
def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>,
SDTCisVT<2, OtherVT>]>;
+def SDT_AArch64CtSelect : SDTypeProfile<1, 4, // dst; tval, fval, cc, nzcv
+ [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>,
+ SDTCisInt<3>, SDTCisVT<4, i32>]>; // cc: AArch64 condition code; nzcv: flags
def SDT_AArch64CSel : SDTypeProfile<1, 4,
[SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>,
@@ -843,6 +846,7 @@ def AArch64tbz : SDNode<"AArch64ISD::TBZ", SDT_AArch64tbz,
def AArch64tbnz : SDNode<"AArch64ISD::TBNZ", SDT_AArch64tbz,
[SDNPHasChain]>;
+def AArch64ctselect : SDNode<"AArch64ISD::CTSELECT", SDT_AArch64CtSelect>; // dst = cc holds on nzcv ? tval : fval
def AArch64csel : SDNode<"AArch64ISD::CSEL", SDT_AArch64CSel>;
// Conditional select invert.
@@ -5644,6 +5648,33 @@ def F128CSEL : Pseudo<(outs FPR128:$Rd),
let hasNoSchedulingInfo = 1;
}
+//===----------------------------------------------------------------------===//
+// Constant-time conditional selection instructions
+//===----------------------------------------------------------------------===//
+
+// Constant-time select pseudo for value type VT in register class RC:
+//   $dst = (NZCV satisfies $cc) ? $tval : $fval
+// hasSideEffects on the defs below presumably keeps generic optimizations
+// from treating it like an ordinary select. It is rewritten to CSEL/FCSEL
+// by AArch64InstrInfo::expandPostRAPseudo (or AArch64MCInstLower).
+class CtSelectPseudo<RegisterClass RC, ValueType VT>
+    : Pseudo<(outs RC:$dst), (ins RC:$tval, RC:$fval, i32imm:$cc),
+             [(set (VT RC:$dst),
+                   (AArch64ctselect (VT RC:$tval), (VT RC:$fval),
+                                    (i32 imm:$cc), NZCV))]>;
+
+let hasSideEffects = 1, isPseudo = 1, hasNoSchedulingInfo = 1,
+    Uses = [NZCV] in {
+  def I32CTSELECT : CtSelectPseudo<GPR32, i32>;
+  def I64CTSELECT : CtSelectPseudo<GPR64, i64>;
+  let Predicates = [HasFullFP16] in {
+    def F16CTSELECT  : CtSelectPseudo<FPR16, f16>;
+    def BF16CTSELECT : CtSelectPseudo<FPR16, bf16>;
+  }
+  def F32CTSELECT : CtSelectPseudo<FPR32, f32>;
+  def F64CTSELECT : CtSelectPseudo<FPR64, f64>;
+}
+
//===----------------------------------------------------------------------===//
// Instructions used for emitting unwind opcodes on ARM64 Windows.
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp b/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
index 39946633603f6..e2ec9118eb5ee 100644
--- a/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
@@ -393,5 +393,26 @@ void AArch64MCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const {
OutMI.setOpcode(AArch64::RET);
OutMI.addOperand(MCOperand::createReg(AArch64::LR));
break;
+  // Constant-time select pseudos map onto the CSEL/FCSEL of matching width;
+  // only the opcode is changed here (operands are presumably already in
+  // OutMI at this point -- confirm against the start of Lower).
+  // NOTE(review): AArch64InstrInfo::expandPostRAPseudo also rewrites these
+  // pseudos, so these cases may only be a fallback.
+  case AArch64::I32CTSELECT:
+    OutMI.setOpcode(AArch64::CSELWr);
+    break;
+  case AArch64::I64CTSELECT:
+    OutMI.setOpcode(AArch64::CSELXr);
+    break;
+  case AArch64::BF16CTSELECT:
+  case AArch64::F16CTSELECT:
+    OutMI.setOpcode(AArch64::FCSELHrrr);
+    break;
+  case AArch64::F32CTSELECT:
+    OutMI.setOpcode(AArch64::FCSELSrrr);
+    break;
+  case AArch64::F64CTSELECT:
+    OutMI.setOpcode(AArch64::FCSELDrrr);
+    break;
}
}
diff --git a/llvm/test/CodeGen/AArch64/ctselect.ll b/llvm/test/CodeGen/AArch64/ctselect.ll
new file mode 100644
index 0000000000000..77e9cf24e56cf
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/ctselect.ll
@@ -0,0 +1,153 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-eabi | FileCheck %s --check-prefixes=DEFAULT,NOFP16
+; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-eabi -mattr=+fullfp16 | FileCheck %s --check-prefixes=DEFAULT,FP16
+; Expected lowering: one TST of %cond feeding CSEL/FCSEL, no branches. NOTE(review): no bf16 test yet.
+define i1 @ct_i1(i1 %cond, i1 %a, i1 %b) {
+; DEFAULT-LABEL: ct_i1:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: csel w8, w1, w2, ne
+; DEFAULT-NEXT: and w0, w8, #0x1
+; DEFAULT-NEXT: ret
+ %1 = call i1 @llvm.ct.select.i1(i1 %cond, i1 %a, i1 %b)
+ ret i1 %1
+}
+
+define i8 @ct_i8(i1 %cond, i8 %a, i8 %b) {
+; DEFAULT-LABEL: ct_i8:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: csel w0, w1, w2, ne
+; DEFAULT-NEXT: ret
+ %1 = call i8 @llvm.ct.select.i8(i1 %cond, i8 %a, i8 %b)
+ ret i8 %1
+}
+
+define i16 @ct_i16(i1 %cond, i16 %a, i16 %b) {
+; DEFAULT-LABEL: ct_i16:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: csel w0, w1, w2, ne
+; DEFAULT-NEXT: ret
+ %1 = call i16 @llvm.ct.select.i16(i1 %cond, i16 %a, i16 %b)
+ ret i16 %1
+}
+
+define i32 @ct_i32(i1 %cond, i32 %a, i32 %b) {
+; DEFAULT-LABEL: ct_i32:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: csel w0, w1, w2, ne
+; DEFAULT-NEXT: ret
+ %1 = call i32 @llvm.ct.select.i32(i1 %cond, i32 %a, i32 %b)
+ ret i32 %1
+}
+
+define i64 @ct_i64(i1 %cond, i64 %a, i64 %b) {
+; DEFAULT-LABEL: ct_i64:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: csel x0, x1, x2, ne
+; DEFAULT-NEXT: ret
+ %1 = call i64 @llvm.ct.select.i64(i1 %cond, i64 %a, i64 %b)
+ ret i64 %1
+}
+
+define i128 @ct_i128(i1 %cond, i128 %a, i128 %b) {
+; DEFAULT-LABEL: ct_i128:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: csel x0, x2, x4, ne
+; DEFAULT-NEXT: csel x1, x3, x5, ne
+; DEFAULT-NEXT: ret
+ %1 = call i128 @llvm.ct.select.i128(i1 %cond, i128 %a, i128 %b)
+ ret i128 %1
+}
+; Without +fullfp16, half is extended to float (fcvt) and selected with FCSEL on S registers.
+define half @ct_f16(i1 %cond, half %a, half %b) {
+; NOFP16-LABEL: ct_f16:
+; NOFP16: // %bb.0:
+; NOFP16-NEXT: fcvt s1, h1
+; NOFP16-NEXT: fcvt s0, h0
+; NOFP16-NEXT: tst w0, #0x1
+; NOFP16-NEXT: fcsel s0, s0, s1, ne
+; NOFP16-NEXT: fcvt h0, s0
+; NOFP16-NEXT: ret
+;
+; FP16-LABEL: ct_f16:
+; FP16: // %bb.0:
+; FP16-NEXT: tst w0, #0x1
+; FP16-NEXT: fcsel h0, h0, h1, ne
+; FP16-NEXT: ret
+ %1 = call half @llvm.ct.select.f16(i1 %cond, half %a, half %b)
+ ret half %1
+}
+
+define float @ct_f32(i1 %cond, float %a, float %b) {
+; DEFAULT-LABEL: ct_f32:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: fcsel s0, s0, s1, ne
+; DEFAULT-NEXT: ret
+ %1 = call float @llvm.ct.select.f32(i1 %cond, float %a, float %b)
+ ret float %1
+}
+
+define double @ct_f64(i1 %cond, double %a, double %b) {
+; DEFAULT-LABEL: ct_f64:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: fcsel d0, d0, d1, ne
+; DEFAULT-NEXT: ret
+ %1 = call double @llvm.ct.select.f64(i1 %cond, double %a, double %b)
+ ret double %1
+}
+; Vectors are scalarized: one CSEL/FCSEL per lane, all using the flags from a single TST.
+define <4 x i32> @ct_v4i32(i1 %cond, <4 x i32> %a, <4 x i32> %b) {
+; DEFAULT-LABEL: ct_v4i32:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: mov w8, v1.s[1]
+; DEFAULT-NEXT: mov w9, v0.s[1]
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: fmov w10, s1
+; DEFAULT-NEXT: fmov w11, s0
+; DEFAULT-NEXT: csel w8, w9, w8, ne
+; DEFAULT-NEXT: csel w9, w11, w10, ne
+; DEFAULT-NEXT: mov w10, v1.s[2]
+; DEFAULT-NEXT: fmov s2, w9
+; DEFAULT-NEXT: mov w11, v0.s[2]
+; DEFAULT-NEXT: mov w9, v0.s[3]
+; DEFAULT-NEXT: mov v2.s[1], w8
+; DEFAULT-NEXT: mov w8, v1.s[3]
+; DEFAULT-NEXT: csel w10, w11, w10, ne
+; DEFAULT-NEXT: mov v2.s[2], w10
+; DEFAULT-NEXT: csel w8, w9, w8, ne
+; DEFAULT-NEXT: mov v2.s[3], w8
+; DEFAULT-NEXT: mov v0.16b, v2.16b
+; DEFAULT-NEXT: ret
+ %1 = call <4 x i32> @llvm.ct.select.v4i32(i1 %cond, <4 x i32> %a, <4 x i32> %b)
+ ret <4 x i32> %1
+}
+
+define <4 x float> @ct_v4f32(i1 %cond, <4 x float> %a, <4 x float> %b) {
+; DEFAULT-LABEL: ct_v4f32:
+; DEFAULT: // %bb.0:
+; DEFAULT-NEXT: mov s2, v1.s[1]
+; DEFAULT-NEXT: mov s3, v0.s[1]
+; DEFAULT-NEXT: tst w0, #0x1
+; DEFAULT-NEXT: mov s4, v1.s[2]
+; DEFAULT-NEXT: mov s5, v0.s[2]
+; DEFAULT-NEXT: fcsel s3, s3, s2, ne
+; DEFAULT-NEXT: fcsel s2, s0, s1, ne
+; DEFAULT-NEXT: mov s1, v1.s[3]
+; DEFAULT-NEXT: mov s0, v0.s[3]
+; DEFAULT-NEXT: mov v2.s[1], v3.s[0]
+; DEFAULT-NEXT: fcsel s3, s5, s4, ne
+; DEFAULT-NEXT: fcsel s0, s0, s1, ne
+; DEFAULT-NEXT: mov v2.s[2], v3.s[0]
+; DEFAULT-NEXT: mov v2.s[3], v0.s[0]
+; DEFAULT-NEXT: mov v0.16b, v2.16b
+; DEFAULT-NEXT: ret
+ %1 = call <4 x float> @llvm.ct.select.v4f32(i1 %cond, <4 x float> %a, <4 x float> %b)
+ ret <4 x float> %1
+}
More information about the llvm-branch-commits
mailing list