[llvm-branch-commits] [llvm] [ConstantTime] Native ct.select support for ARM32 and Thumb (PR #166707)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Nov 6 12:37:18 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-arm
Author: Julius Alexandre (wizardengineer)
<details>
<summary>Changes</summary>
This patch implements architecture-specific lowering for ct.select on ARM
(both ARM32 and Thumb modes) using unconditionally executed bitwise
operations for constant-time selection; the expansions emit no
conditional-move or predicated instructions.
Implementation details:
- Uses pseudo-instructions that are expanded post-RA into bitwise operations
- Post-RA expansion in ARMBaseInstrInfo emits each expanded sequence as a BUNDLE so it stays together
- Handles scalar integer types, floating-point, and half-precision types
- Handles vector types with NEON when available
- Support for both ARM and Thumb instruction sets (Thumb1 and Thumb2)
- Special handling for Thumb1, which lacks conditional execution (the mask is built with LSL/ASR and the selection uses EOR/AND)
- Comprehensive test coverage including half-precision and vectors
The implementation includes:
- ISelLowering: Custom lowering to CTSELECT pseudo-instructions
- ISelDAGToDAG: Selection of appropriate pseudo-instructions
- ARMBaseInstrInfo: post-RA expansion of the CTSELECT pseudos into bundled bitwise instruction sequences
- InstrInfo.td: Pseudo-instruction definitions for different types
- TargetMachine: Registration of Post-RA expansion pass
- Proper handling of condition codes and register allocation constraints
---
Patch is 166.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166707.diff
10 Files Affected:
- (modified) llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp (+335-2)
- (modified) llvm/lib/Target/ARM/ARMBaseInstrInfo.h (+6)
- (modified) llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp (+86)
- (modified) llvm/lib/Target/ARM/ARMISelLowering.cpp (+164-20)
- (modified) llvm/lib/Target/ARM/ARMISelLowering.h (+11-2)
- (modified) llvm/lib/Target/ARM/ARMInstrInfo.td (+185)
- (modified) llvm/lib/Target/ARM/ARMTargetMachine.cpp (+3-5)
- (added) llvm/test/CodeGen/ARM/ctselect-half.ll (+975)
- (added) llvm/test/CodeGen/ARM/ctselect-vector.ll (+2179)
- (added) llvm/test/CodeGen/ARM/ctselect.ll (+555)
``````````diff
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
index 22769dbf38719..6d8a3b72244fe 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
@@ -1526,18 +1526,351 @@ void ARMBaseInstrInfo::expandMEMCPY(MachineBasicBlock::iterator MI) const {
BB->erase(MI);
}
+// Post-RA expansion of a NEON ctselect pseudo (D or Q register) into a bundled, branch-free RSB/VDUP/VAND/VBIC/VORR sequence. Erases MI and returns true.
+bool ARMBaseInstrInfo::expandCtSelectVector(MachineInstr &MI) const {
+ MachineBasicBlock *MBB = MI.getParent();
+ DebugLoc DL = MI.getDebugLoc();
+ // Operand 0 is the final destination; operand 1 is the GPR that receives the scalar mask.
+ Register DestReg = MI.getOperand(0).getReg();
+ Register MaskReg = MI.getOperand(1).getReg();
+
+ // Default to the 64-bit D-register opcodes; switched to the Q forms below when the destination is a Q register.
+ unsigned AndOp = ARM::VANDd;
+ unsigned BicOp = ARM::VBICd;
+ unsigned OrrOp = ARM::VORRd;
+ unsigned BroadcastOp = ARM::VDUP32d;
+
+ const TargetRegisterInfo *TRI = &getRegisterInfo();
+ const TargetRegisterClass *RC = TRI->getMinimalPhysRegClass(DestReg);
+
+ if (ARM::QPRRegClass.hasSubClassEq(RC)) {
+ AndOp = ARM::VANDq;
+ BicOp = ARM::VBICq;
+ OrrOp = ARM::VORRq;
+ BroadcastOp = ARM::VDUP32q;
+ }
+ // The scalar mask is computed in a GPR, so pick the RSB for the current instruction set.
+ unsigned RsbOp = Subtarget.isThumb2() ? ARM::t2RSBri : ARM::RSBri;
+
+ // Vector pseudo operand layout: (outs $dst, $tmp_mask, $bcast_mask),
+ // (ins $src1, $src2, $cond).
+ Register VectorMaskReg = MI.getOperand(2).getReg();
+ Register Src1Reg = MI.getOperand(3).getReg();
+ Register Src2Reg = MI.getOperand(4).getReg();
+ Register CondReg = MI.getOperand(5).getReg();
+
+ // The steps below compute (src1 & mask) | (src2 & ~mask); every instruction is tagged NoMerge and the whole sequence is bundled at the end.
+
+ // 1. mask = 0 - cond
+ // When cond = 0: mask = 0x00000000.
+ // When cond = 1: mask = 0xFFFFFFFF.
+ // NOTE(review): this assumes $cond is exactly 0 or 1; any other value yields a partial mask.
+ MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
+ .addReg(CondReg)
+ .addImm(0)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 2. A = src1 & mask
+ // First broadcast the 32-bit scalar mask into every lane of $bcast_mask.
+ BuildMI(*MBB, MI, DL, get(BroadcastOp), VectorMaskReg)
+ .addReg(MaskReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+ // NOTE(review): $dst and $bcast_mask are written before $src1/$src2 are fully consumed; this relies on the pseudo keeping them distinct from the sources - confirm the constraints in ARMInstrInfo.td.
+ BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
+ .addReg(Src1Reg)
+ .addReg(VectorMaskReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 3. B = src2 & ~mask (VBIC; reuses the broadcast-mask register to hold B)
+ BuildMI(*MBB, MI, DL, get(BicOp), VectorMaskReg)
+ .addReg(Src2Reg)
+ .addReg(VectorMaskReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 4. result = A | B
+ auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
+ .addReg(DestReg)
+ .addReg(VectorMaskReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ auto BundleStart = FirstNewMI->getIterator();
+ auto BundleEnd = LastNewMI->getIterator();
+
+ // Bundle everything from the RSB through the VORR (the end iterator is exclusive).
+ finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
+
+ MI.eraseFromParent();
+ return true;
+}
+
+// Post-RA expansion of a scalar ctselect pseudo on Thumb1-only subtargets: builds the mask with LSL/ASR and selects via src2 ^ ((src1 ^ src2) & mask). Erases MI and returns true.
+bool ARMBaseInstrInfo::expandCtSelectThumb(MachineInstr &MI) const {
+ MachineBasicBlock *MBB = MI.getParent();
+ DebugLoc DL = MI.getDebugLoc();
+
+ // Thumb1 pseudo operand layout: (outs $dst, $tmp_mask), (ins $src1, $src2,
+ // $cond). All operands are expected to be tGPR registers.
+ Register DestReg = MI.getOperand(0).getReg();
+ Register MaskReg = MI.getOperand(1).getReg();
+ Register Src1Reg = MI.getOperand(2).getReg();
+ Register Src2Reg = MI.getOperand(3).getReg();
+ Register CondReg = MI.getOperand(4).getReg();
+ // NOTE(review): expandPostRAPseudo also routes the f16/bf16/f32/f64 pseudos here on Thumb1-only targets, but those use a different operand layout (extra scratch/broadcast outputs); presumably FP types are softened before ISel there - confirm.
+ // Access register info
+ MachineFunction *MF = MBB->getParent();
+ const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+ MachineRegisterInfo &MRI = MF->getRegInfo();
+ // Shifting left then arithmetic-right by (width - 1) smears bit 0 of $cond across the register.
+ unsigned RegSize = TRI->getRegSizeInBits(MaskReg, MRI);
+ unsigned ShiftAmount = RegSize - 1;
+ // NOTE(review): the shifts and ALU ops below all define CPSR (marked dead); this is only safe if the pseudo is declared to clobber CPSR - confirm in ARMInstrInfo.td.
+ // 1. mask = (cond << ShiftAmount) >> ShiftAmount (arithmetic): all ones if bit 0 of cond is set, else zero.
+ MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::tMOVr), MaskReg)
+ .addReg(CondReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // The mask is built with LSL+ASR rather than RSB. Note that these Thumb1
+ // shifts still write CPSR (dead defs below). tLSLri: (outs tGPR:$Rd,
+ // s_cc_out:$s), (ins tGPR:$Rm, imm0_31:$imm5, pred:$p)
+ BuildMI(*MBB, MI, DL, get(ARM::tLSLri), MaskReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(MaskReg) // $Rm
+ .addImm(ShiftAmount) // imm0_31:$imm5
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // tASRri: (outs tGPR:$Rd, s_cc_out:$s), (ins tGPR:$Rm, imm_sr:$imm5, pred:$p)
+ BuildMI(*MBB, MI, DL, get(ARM::tASRri), MaskReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(MaskReg) // $Rm
+ .addImm(ShiftAmount) // imm_sr:$imm5
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 2. xor_diff = src1 ^ src2
+ BuildMI(*MBB, MI, DL, get(ARM::tMOVr), DestReg)
+ .addReg(Src1Reg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+ // NOTE(review): $dst is overwritten with src1 before src2 is read; this relies on $dst never being allocated to the same register as $src2 - confirm.
+ // tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+ // tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+ BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(DestReg) // tied input $Rn
+ .addReg(Src2Reg) // $Rm
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 3. masked_xor = xor_diff & mask
+ // tAND has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+ // tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+ BuildMI(*MBB, MI, DL, get(ARM::tAND), DestReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(DestReg) // tied input $Rn
+ .addReg(MaskReg, RegState::Kill) // $Rm (last use of the mask)
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 4. result = src2 ^ masked_xor (src1 when mask is all ones, src2 when zero)
+ // tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+ // tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+ auto LastMI =
+ BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(DestReg) // tied input $Rn
+ .addReg(Src2Reg) // $Rm
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // Bundle everything from the first MOV through the final EOR (the end iterator is exclusive).
+ auto BundleStart = FirstNewMI->getIterator();
+ finalizeBundle(*MBB, BundleStart, std::next(LastMI->getIterator()));
+
+ MI.eraseFromParent();
+ return true;
+}
+
+// Post-RA expansion of a scalar ctselect pseudo for ARM/Thumb2: integer pseudos use GPRs directly; f16/bf16/f32 pseudos round-trip through two GPR scratch registers. Erases MI and returns true.
+bool ARMBaseInstrInfo::expandCtSelect(MachineInstr &MI) const {
+ MachineBasicBlock *MBB = MI.getParent();
+ DebugLoc DL = MI.getDebugLoc();
+ // For FP pseudos DestReg is redirected to a GPR scratch below; DestRegSavedRef keeps the real FP destination for the final VMOVSR.
+ Register DestReg = MI.getOperand(0).getReg();
+ Register MaskReg = MI.getOperand(1).getReg();
+ Register DestRegSavedRef = DestReg;
+ Register Src1Reg, Src2Reg, CondReg;
+
+ // Default to the ARM-mode opcodes; switched to the Thumb2 encodings below.
+ unsigned RsbOp = ARM::RSBri;
+ unsigned AndOp = ARM::ANDrr;
+ unsigned BicOp = ARM::BICrr;
+ unsigned OrrOp = ARM::ORRrr;
+
+ if (Subtarget.isThumb2()) {
+ RsbOp = ARM::t2RSBri;
+ AndOp = ARM::t2ANDrr;
+ BicOp = ARM::t2BICrr;
+ OrrOp = ARM::t2ORRrr;
+ }
+
+ unsigned Opcode = MI.getOpcode();
+ bool IsFloat = Opcode == ARM::CTSELECTf32 || Opcode == ARM::CTSELECTf16 ||
+ Opcode == ARM::CTSELECTbf16;
+ MachineInstr *FirstNewMI = nullptr;
+ if (IsFloat) {
+ // Each float pseudo has: (outs $dst, $tmp_mask, $scratch1, $scratch2),
+ // (ins $src1, $src2, $cond). The two GPR scratch registers hold the raw
+ // bits of the FP sources so that integer bitwise ops can be applied.
+ Register GPRScratch1 = MI.getOperand(2).getReg();
+ Register GPRScratch2 = MI.getOperand(3).getReg();
+
+ // choice a from __builtin_ct_select(cond, a, b)
+ Src1Reg = MI.getOperand(4).getReg();
+ // choice b from __builtin_ct_select(cond, a, b)
+ Src2Reg = MI.getOperand(5).getReg();
+ // cond from __builtin_ct_select(cond, a, b)
+ CondReg = MI.getOperand(6).getReg();
+
+ // Copy the bits of FP src1 into GPR scratch1 (VMOVRS); this is the first instruction of the bundle.
+ FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch1)
+ .addReg(Src1Reg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // Copy the bits of FP src2 into GPR scratch2.
+ BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch2)
+ .addReg(Src2Reg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ Src1Reg = GPRScratch1;
+ Src2Reg = GPRScratch2;
+ // Compute the result in GPRScratch1: step 2 below reads src1 from it and overwrites it in the same instruction.
+ DestReg = GPRScratch1;
+ } else {
+ // Any non-float, non-vector pseudo has: (outs $dst, $tmp_mask), (ins $src1,
+ // $src2, $cond).
+ Src1Reg = MI.getOperand(2).getReg();
+ Src2Reg = MI.getOperand(3).getReg();
+ CondReg = MI.getOperand(4).getReg();
+ }
+
+ // The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
+ // NOTE(review): this assumes $cond is exactly 0 or 1; any other value yields a partial mask.
+ // 1. mask = 0 - cond
+ // When cond = 0: mask = 0x00000000.
+ // When cond = 1: mask = 0xFFFFFFFF.
+ auto TmpNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
+ .addReg(CondReg)
+ .addImm(0)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // Integer pseudos have no VMOVRS prologue, so the RSB starts the bundle.
+ if (!FirstNewMI)
+ FirstNewMI = TmpNewMI;
+ // NOTE(review): for integer pseudos DestReg is written here before Src2Reg is read in step 3; this relies on $dst never sharing a register with $src2 - confirm the pseudo's constraints.
+ // 2. A = src1 & mask
+ BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
+ .addReg(Src1Reg)
+ .addReg(MaskReg)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 3. B = src2 & ~mask (BIC; overwrites the mask register with B)
+ BuildMI(*MBB, MI, DL, get(BicOp), MaskReg)
+ .addReg(Src2Reg)
+ .addReg(MaskReg)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 4. result = A | B
+ auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
+ .addReg(DestReg)
+ .addReg(MaskReg)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ if (IsFloat) {
+ // Move the selected bits from the GPR scratch back into the FP destination; this VMOVSR ends the bundle.
+ LastNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVSR), DestRegSavedRef)
+ .addReg(DestReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+ }
+
+ auto BundleStart = FirstNewMI->getIterator();
+ auto BundleEnd = LastNewMI->getIterator();
+
+ // Bundle the whole emitted sequence (the end iterator is exclusive).
+ finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
+
+ MI.eraseFromParent();
+ return true;
+}
+
bool ARMBaseInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
- if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+ auto opcode = MI.getOpcode();
+
+ if (opcode == TargetOpcode::LOAD_STACK_GUARD) {
expandLoadStackGuard(MI);
MI.getParent()->erase(MI);
return true;
}
- if (MI.getOpcode() == ARM::MEMCPY) {
+ if (opcode == ARM::MEMCPY) {
expandMEMCPY(MI);
return true;
}
+ if (opcode == ARM::CTSELECTf64) {
+ if (Subtarget.isThumb1Only()) {
+ LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
+ << "replaced by: " << MI);
+ return expandCtSelectThumb(MI);
+ } else {
+ LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode
+ << "replaced by: " << MI);
+ return expandCtSelectVector(MI);
+ }
+ }
+
+ if (opcode == ARM::CTSELECTv8i8 || opcode == ARM::CTSELECTv4i16 ||
+ opcode == ARM::CTSELECTv2i32 || opcode == ARM::CTSELECTv1i64 ||
+ opcode == ARM::CTSELECTv2f32 || opcode == ARM::CTSELECTv4f16 ||
+ opcode == ARM::CTSELECTv4bf16 || opcode == ARM::CTSELECTv16i8 ||
+ opcode == ARM::CTSELECTv8i16 || opcode == ARM::CTSELECTv4i32 ||
+ opcode == ARM::CTSELECTv2i64 || opcode == ARM::CTSELECTv4f32 ||
+ opcode == ARM::CTSELECTv2f64 || opcode == ARM::CTSELECTv8f16 ||
+ opcode == ARM::CTSELECTv8bf16) {
+ LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode << "replaced by: " << MI);
+ return expandCtSelectVector(MI);
+ }
+
+ if (opcode == ARM::CTSELECTint || opcode == ARM::CTSELECTf16 ||
+ opcode == ARM::CTSELECTbf16 || opcode == ARM::CTSELECTf32) {
+ if (Subtarget.isThumb1Only()) {
+ LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
+ << "replaced by: " << MI);
+ return expandCtSelectThumb(MI);
+ } else {
+ LLVM_DEBUG(dbgs() << "Opcode " << opcode << "replaced by: " << MI);
+ return expandCtSelect(MI);
+ }
+ }
+
// This hook gets to expand COPY instructions before they become
// copyPhysReg() calls. Look for VMOVS instructions that can legally be
// widened to VMOVD. We prefer the VMOVD when possible because it may be
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
index 2869e7f708046..f0e090f09f5dc 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
@@ -221,6 +221,12 @@ class ARMBaseInstrInfo : public ARMGenInstrInfo {
const TargetRegisterInfo *TRI, Register VReg,
MachineInstr::MIFlag Flags = MachineInstr::NoFlags) const override;
+ bool expandCtSelectVector(MachineInstr &MI) const; ///< Post-RA expansion of NEON (and non-Thumb1 f64) ctselect pseudos into a bundle; erases MI, returns true.
+
+ bool expandCtSelectThumb(MachineInstr &MI) const; ///< Post-RA expansion of scalar ctselect pseudos on Thumb1-only subtargets; erases MI, returns true.
+
+ bool expandCtSelect(MachineInstr &MI) const; ///< Post-RA expansion of scalar (int/f16/bf16/f32) ctselect pseudos for ARM/Thumb2; erases MI, returns true.
+
bool expandPostRAPseudo(MachineInstr &MI) const override;
bool shouldSink(const MachineInstr &MI) const override;
diff --git a/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp b/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
index 847b7af5a9b11..3fdc5734baaa5 100644
--- a/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
+++ b/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
@@ -4200,6 +4200,92 @@ void ARMDAGToDAGISel::Select(SDNode *N) {
// Other cases are autogenerated.
break;
}
+ case ARMISD::CTSELECT: {
+ EVT VT = N->getValueType(0);
+ unsigned PseudoOpcode;
+ bool IsFloat = false;
+ bool IsVector = false;
+
+ if (VT == MVT::f16) {
+ PseudoOpcode = ARM::CTSELECTf16;
+ IsFloat = true;
+ } else if (VT == MVT::bf16) {
+ PseudoOpcode = ARM::CTSELECTbf16;
+ IsFloat = true;
+ } else if (VT == MVT::f32) {
+ PseudoOpcode = ARM::CTSELECTf32;
+ IsFloat = true;
+ } else if (VT == MVT::f64) {
+ PseudoOpcode = ARM::CTSELECTf64;
+ IsVector = true;
+ } else if (VT == MVT::v8i8) {
+ PseudoOpcode = ARM::CTSELECTv8i8;
+ IsVector = true;
+ } else if (VT == MVT::v4i16) {
+ PseudoOpcode = ARM::CTSELECTv4i16;
+ IsVector = true;
+ } else if (VT == MVT::v2i32) {
+ PseudoOpcode = ARM::CTSELECTv2i32;
+ IsVector = true;
+ } else if (VT == MVT::v1i64) {
+ PseudoOpcode = ARM::CTSELECTv1i64;
+ IsVector = true;
+ } else if (VT == MVT::v2f32) {
+ PseudoOpcode = ARM::CTSELECTv2f32;
+ IsVector = true;
+ } else if (VT == MVT::v4f16) {
+ PseudoOpcode = ARM::CTSELECTv4f16;
+ IsVector = true;
+ } else if (VT == MVT::v4bf16) {
+ PseudoOpcode = ARM::CTSELECTv4bf16;
+ IsVector = true;
+ } else if (VT == MVT::v16i8) {
+ PseudoOpcode = ARM::CTSELECTv16i8;
+ IsVector = true;
+ } else if (VT == MVT::v8i16) {
+ PseudoOpcode = ARM::CTSELECTv8i16;
+ IsVector = true;
+ } else if (VT == MVT::v4i32) {
+ PseudoOpcode = ARM::CTSELECTv4i32;
+ IsVector = true;
+ } else if (VT == MVT::v2i64) {
+ PseudoOpcode = ARM::CTSELECTv2i64;
+ IsVector = true;
+ } else if (VT == MVT::v4f32) {
+ PseudoOpcode = ARM::CTSELECTv4f32;
+ IsVector = true;
+ } else if (VT == MVT::v2f64) {
+ PseudoOpcode = ARM::CTSELECTv2f64;
+ IsVector = true;
+ } else if (VT == MVT::v8f16) {
+ PseudoOpcode = ARM::CTSELECTv8f16;
+ IsVector = true;
+ } else if (VT == MVT::v8bf16) {
+ PseudoOpcode = ARM::CTSELECTv8bf16;
+ IsVector = true;
+ } else {
+ // i1, i8, i16, i32, i64
+ PseudoOpcode = ARM::CTSELECTint;
+ }
+
+ SmallVector<EVT, 4> VTs;
+ VTs.push_back(VT); // $dst
+ VTs.push_back(MVT::i32); // $tmp_mask (always GPR)
+
+ if (IsVector) {
+ VTs.push_back(VT); // $bcast_mask (same type as dst for vectors)
+ } else if (IsFloat) {
+ VTs.push_back(MVT::i32); // $scratch1 (GPR)
+ VTs.push_back(MVT::i32); // $scratch2 (GPR)
+ }
+
+ // src1, src2, cond
+ SDValue Ops[] = {N->getOperand(0), N->getOperand(1), N->getOperand(2)};
+
+ SDNode *ResNode = CurDAG->getMachineNode(PseudoOpcode, SDLoc(N), VTs, Ops);
+ ReplaceNode(N, ResNode);
+ return;
+ }
case ARMISD::VZIP: {
EVT VT = N->getValueType(0);
// vzip.32 Dd, Dm is a pseudo-instruction expanded to vtrn.32 Dd, Dm.
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 6b0653457cbaf..63005f1c9f989 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -203,6 +203,7 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT) {
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);
setOperationAction(ISD::VSELECT, VT, Expand);
+ setOperationAction(ISD::CTSELECT, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
if (VT.isInteger()) {
setOperationAction(ISD::SHL, VT, Custom);
@@ -304,6 +305,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setOperationAction(ISD::CTPOP, VT, Expand);
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);
+ setOperationAction(ISD::CTSELECT, VT, Custom);
// Vector reductions
setOperationAction(ISD::VECREDUCE_ADD, VT, Legal);
@@ -355,6 +357,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setOperationAction(ISD::MSTORE, VT, Legal);
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);
+ setOperationAction(ISD::CTSELECT, VT, Custom);
// Pre and Post inc are supported on loads and stores
for (unsigned im = (unsigned)ISD::PRE_INC;
@@ -408,6 +411,28 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setOperationAction(ISD::VECREDUCE_FMIN, MVT::v2f16, Custom);
setOperationAction(ISD::VECREDUCE_FMAX, MVT::v2f16, Custom);
+ if (Subtarget->hasFullFP16()) {
+ setOperationAction(ISD::CTSELECT, MVT::v4f16, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::v8f16, Custom);
+ }
+
+ if (Subtarget->hasBF16()) {
+ setOperationAction(ISD::CTSELECT, MVT::v4bf16, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::v8bf16, Custom);
+ }
+
+ // small exotic vectors get scalarised for ctselect
+ setOperationAction(ISD::CTSELECT, MVT::v1i8, Expand);
+ setOperationAction(ISD::CTSELECT, MVT::v1i16, Expand);
+ setOperationAction(ISD::CTSELECT, MVT::v1i32, Expand);
+ setOperationAction(ISD::CTSELECT, MVT::v1f32, Expand);
+ setOperationAction(ISD::CTSELECT, MVT::v2i8, Expand);
+
+ setOperationAction(ISD::CTSELECT, MVT::v2i16, Promote);
+ setOperationPromotedToType(ISD::CTSELECT, MVT::v2i16, MVT::v4i16);
+ setOperationAction(IS...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/166707
More information about the llvm-branch-commits
mailing list