[llvm] [ConstantTime][LLVM] Add llvm.ct.select intrinsic with generic SelectionDAG lowering (PR #166702)
Simon Tatham via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 7 06:24:55 PST 2025
================
@@ -6496,6 +6496,105 @@ void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
setValue(&I, Result);
}
+/// Fallback implementation for constant-time select using DAG chaining.
+/// This implementation uses data dependencies through virtual registers to
+/// prevent optimizations from breaking the constant-time property.
+/// It handles scalars, vectors (fixed and scalable), and floating-point types.
+SDValue SelectionDAGBuilder::createProtectedCtSelectFallback(
+ SelectionDAG &DAG, const SDLoc &DL, SDValue Cond, SDValue T, SDValue F,
+ EVT VT) {
+
+ SDValue WorkingT = T;
+ SDValue WorkingF = F;
+ EVT WorkingVT = VT;
+
+ SDValue Chain = DAG.getEntryNode();
+ MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
+
+ // Handle vector condition: splat scalar condition to vector
+ if (VT.isVector() && !Cond.getValueType().isVector()) {
+ ElementCount NumElems = VT.getVectorElementCount();
+ EVT CondVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElems);
+
+ if (VT.isScalableVector()) {
+ Cond = DAG.getSplatVector(CondVT, DL, Cond);
+ } else {
+ Cond = DAG.getSplatBuildVector(CondVT, DL, Cond);
+ }
+ }
+
+ // Handle floating-point types: bitcast to integer for bitwise operations
+ if (VT.isFloatingPoint()) {
+ if (VT.isVector()) {
+ // float vector -> int vector
+ EVT ElemVT = VT.getVectorElementType();
+ unsigned int ElemBitWidth = ElemVT.getScalarSizeInBits();
+ EVT IntElemVT = EVT::getIntegerVT(*DAG.getContext(), ElemBitWidth);
+
+ WorkingVT = EVT::getVectorVT(*DAG.getContext(), IntElemVT,
+ VT.getVectorElementCount());
+ } else {
+ WorkingVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+ }
+
+ WorkingT = DAG.getBitcast(WorkingVT, T);
+ WorkingF = DAG.getBitcast(WorkingVT, F);
+ }
+
+ // Create mask: sign-extend condition to all bits
+ SDValue Mask = DAG.getSExtOrTrunc(Cond, DL, WorkingVT);
+
+ // Create all-ones constant for inversion
+ SDValue AllOnes;
+ if (WorkingVT.isScalableVector()) {
+ unsigned BitWidth = WorkingVT.getScalarSizeInBits();
+ APInt AllOnesVal = APInt::getAllOnes(BitWidth);
+ SDValue ScalarAllOnes =
+ DAG.getConstant(AllOnesVal, DL, WorkingVT.getScalarType());
+ AllOnes = DAG.getSplatVector(WorkingVT, DL, ScalarAllOnes);
+ } else {
+ AllOnes = DAG.getAllOnesConstant(DL, WorkingVT);
+ }
+
+ // Invert mask for false value
+ SDValue Invert = DAG.getNode(ISD::XOR, DL, WorkingVT, Mask, AllOnes);
+
+ // Compute: (T & Mask) | (F & ~Mask)
----------------
statham-arm wrote:
Is it better to compute `F ^ ((T ^ F) & Mask)`? That's what I normally do in my handwritten bit-twiddling selects. It's the same number of binary bitwise operations (two XORs and an AND, instead of two ANDs and an OR), but it avoids the extra unary operation of having to invert `Mask`.
https://github.com/llvm/llvm-project/pull/166702
More information about the llvm-commits
mailing list