[llvm] [ConstantTime][LLVM] Add llvm.ct.select intrinsic with generic SelectionDAG lowering (PR #166702)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 6 12:37:14 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-risc-v

Author: Julius Alexandre (wizardengineer)

<details>
<summary>Changes</summary>


This is in reference to our [[RFC] Constant-Time Coding Support](https://discourse.llvm.org/t/rfc-constant-time-coding-support/87781) proposal.

  ## Summary

 This PR introduces core infrastructure for constant-time selection operations in LLVM, providing a foundation for
 cryptographic code that prevents timing side-channels. This is the first PR in a stacked series implementing
 comprehensive constant-time selection support across multiple architectures.

  ## Changes

  ### 1. Core Intrinsic Definition

  Adds the `llvm.ct.select` intrinsic family to LLVM IR:
  - Provides constant-time selection semantics: `result = condition ? true_value : false_value`
  - Is designed so that execution time does not depend on the condition value
  - Targets timing side channels in security-sensitive code
  - Defined in `llvm/include/llvm/IR/Intrinsics.td`

  ### 2. Generic Fallback Implementation

  Implements architecture-agnostic SelectionDAG lowering using bitwise operations:
  - **Pattern**: `result = (true_val & mask) | (false_val & ~mask)` where `mask = -(condition)`, with the condition widened to the operand width so the mask is either all zeros or all ones
  - Works on any architecture without specialized instruction support
  - Ensures constant-time execution through branch-free operations
  - Located in `llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp`

  The fallback implementation converts the boolean condition into a bitmask and uses bitwise arithmetic to achieve
  constant-time selection:
  - `condition (0 or 1) → NEG → mask (0x0000... or 0xFFFF...)`
  - `result = (true_val & mask) | (false_val & ~mask)`
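
As an illustration of the mask pattern described above, here is a minimal C++ sketch for a 64-bit operand. The name `ct_select_u64` is hypothetical and not part of the patch. It only demonstrates the arithmetic: an ordinary C++ compiler may still recognize this idiom and turn it into a branch or a conditional move, which is the kind of transformation `llvm.ct.select` is meant to rule out at the IR level.

```cpp
#include <cstdint>

// Hypothetical helper: branch-free select using the mask pattern.
//   mask   = 0 - cond  -> 0x000...0 when cond is false, 0xFFF...F when true
//   result = (true_val & mask) | (false_val & ~mask)
constexpr std::uint64_t ct_select_u64(bool cond, std::uint64_t true_val,
                                      std::uint64_t false_val) {
  // Same as -(cond) in two's complement, written as a subtraction to avoid
  // unary-minus-on-unsigned warnings on some compilers.
  const std::uint64_t mask = std::uint64_t{0} - static_cast<std::uint64_t>(cond);
  return (true_val & mask) | (false_val & ~mask);
}
```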

  This approach guarantees:
  - No conditional branches that could leak timing information
  - Constant execution time regardless of condition value
  - Portable implementation across all LLVM targets

  ### 3. Test Coverage

  Includes basic test cases demonstrating fallback functionality:
  - **RISC-V**: `llvm/test/CodeGen/RISCV/ctselect-fallback.ll` - Generic fallback pattern verification
  - **X86**: `llvm/test/CodeGen/X86/ctselect.ll` - Demonstrates generic lowering before optimization

  These tests verify that the intrinsic correctly lowers to constant-time bitwise operations on architectures without
  native support.

  ## Architecture Support

  This PR provides the **fallback implementation** that works on all architectures. Subsequent PRs in the stack will add:
  - Clang frontend support (`__builtin_ct_select`)
  - Architecture-specific optimizations (e.g. CMOV and other conditional-move instructions)
  - Comprehensive test suites for each target

  ## Security Properties

  The fallback implementation ensures:
  1. **No conditional branches** - Execution path is independent of condition value
  2. **Constant execution time** - The same instruction sequence executes whichever value the condition takes
  3. **No data-dependent memory access** - All memory operations are unconditional
  4. **Compiler barrier semantics** - Prevents optimization across ct.select boundaries

  ## Related PRs

  This is part of a stacked PR series implementing constant-time selection:
  1. **#166702** (this PR): Core infrastructure and fallback implementation
  2. **#166703**: Clang frontend support (`__builtin_ct_select`)
  3. **#166708**: RISC-V comprehensive tests
  4. **#166705**: MIPS comprehensive tests
  5. **#166709**: WebAssembly comprehensive tests
  6. **#166704**: X86/i386 optimizations and tests
  7. **#166706**: AArch64 (ARM64) optimizations and tests
  8. **#166707**: ARM32/Thumb optimizations and tests

  ## Testing

  All changes pass existing regression tests. New tests verify:
  - Correct lowering to bitwise operations
  - Absence of conditional branches in generated code
  - Proper handling of various data types (integers, pointers, floats)

  ## References

  - RFC Discussion: https://discourse.llvm.org/t/rfc-constant-time-coding-support/87781

---

Patch is 70.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166702.diff


19 Files Affected:

- (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+4) 
- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+7) 
- (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+3-1) 
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+11-7) 
- (modified) llvm/include/llvm/IR/Intrinsics.td (+9) 
- (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+6) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+110-2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+44-2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+16-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+20) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+5-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp (+14) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+13) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+131) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h (+3) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+1) 
- (added) llvm/test/CodeGen/RISCV/ctselect-fallback.ll (+330) 
- (added) llvm/test/CodeGen/X86/ctselect.ll (+779) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index ff3dd0d4c3c51..656f6e718f029 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -783,6 +783,10 @@ enum NodeType {
   /// i1 then the high bits must conform to getBooleanContents.
   SELECT,
 
+  /// Constant-time select. Operands are the same as SELECT; targets may lower
+  /// it to a conditional move (e.g. CMOV) or use the generic bitwise expansion.
+  CTSELECT,
+
   /// Select with a vector condition (op #0) and two vector operands (ops #1
   /// and #2), returning a vector result.  All vectors have the same length.
   /// Much like the scalar select and setcc, each bit in the condition selects
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1a5ffb38f2568..b5debd490d9cb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1352,6 +1352,13 @@ class SelectionDAG {
     return getNode(Opcode, DL, VT, Cond, LHS, RHS, Flags);
   }
 
+  SDValue getCTSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS,
+                      SDValue RHS, SDNodeFlags Flags = SDNodeFlags()) {
+    assert(LHS.getValueType() == VT && RHS.getValueType() == VT &&
+           "Cannot use select on differing types");
+    return getNode(ISD::CTSELECT, DL, VT, Cond, LHS, RHS, Flags);
+  }
+
   /// Helper function to make it easier to build SelectCC's if you just have an
   /// ISD::CondCode instead of an SDValue.
   SDValue getSelectCC(const SDLoc &DL, SDValue LHS, SDValue RHS, SDValue True,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 1759463ea7965..8e18eb2f7db0e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -435,6 +435,9 @@ struct SDNodeFlags {
                             NonNeg | NoNaNs | NoInfs | SameSign | InBounds,
     FastMathFlags = NoNaNs | NoInfs | NoSignedZeros | AllowReciprocal |
                     AllowContract | ApproximateFuncs | AllowReassociation,
+
+    // Flag for disabling optimization
+    NoMerge = 1 << 15,
   };
 
   /// Default constructor turns off all optimization flags.
@@ -486,7 +489,6 @@ struct SDNodeFlags {
   bool hasNoFPExcept() const { return Flags & NoFPExcept; }
   bool hasUnpredictable() const { return Flags & Unpredictable; }
   bool hasInBounds() const { return Flags & InBounds; }
-
   bool operator==(const SDNodeFlags &Other) const {
     return Flags == Other.Flags;
   }
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 78f63b4406eb0..8198485803d8b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -242,11 +242,15 @@ class LLVM_ABI TargetLoweringBase {
 
   /// Enum that describes what type of support for selects the target has.
   enum SelectSupportKind {
-    ScalarValSelect,      // The target supports scalar selects (ex: cmov).
-    ScalarCondVectorVal,  // The target supports selects with a scalar condition
-                          // and vector values (ex: cmov).
-    VectorMaskSelect      // The target supports vector selects with a vector
-                          // mask (ex: x86 blends).
+    ScalarValSelect,     // The target supports scalar selects (ex: cmov).
+    ScalarCondVectorVal, // The target supports selects with a scalar condition
+                         // and vector values (ex: cmov).
+    VectorMaskSelect,    // The target supports vector selects with a vector
+                         // mask (ex: x86 blends).
+    CtSelect,            // The target implements a custom constant-time select.
+    ScalarCondVectorValCtSelect, // The target supports selects with a scalar
+                                 // condition and vector values.
+    VectorMaskValCtSelect, // The target supports vector selects with a vector mask.
   };
 
   /// Enum that specifies what an atomic load/AtomicRMWInst is expanded
@@ -476,8 +480,8 @@ class LLVM_ABI TargetLoweringBase {
   MachineMemOperand::Flags
   getVPIntrinsicMemOperandFlags(const VPIntrinsic &VPIntrin) const;
 
-  virtual bool isSelectSupported(SelectSupportKind /*kind*/) const {
-    return true;
+  virtual bool isSelectSupported(SelectSupportKind kind) const {
+    return kind != CtSelect;
   }
 
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 6a079f62dd9cf..d41c61777089d 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1825,6 +1825,15 @@ def int_coro_subfn_addr : DefaultAttrsIntrinsic<
     [IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>,
      NoCapture<ArgIndex<0>>]>;
 
+///===-------------------------- Constant Time Intrinsics
+///--------------------------===//
+//
+// Intrinsic to support constant time select
+def int_ct_select
+    : DefaultAttrsIntrinsic<[llvm_any_ty],
+                            [llvm_i1_ty, LLVMMatchType<0>, LLVMMatchType<0>],
+                            [IntrWriteMem, IntrWillReturn, NoUndef<RetIndex>]>;
+
 ///===-------------------------- Other Intrinsics --------------------------===//
 //
 // TODO: We should introduce a new memory kind fo traps (and other side effects
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 07a858fd682fc..de4abd713d3cf 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -214,6 +214,11 @@ def SDTSelect : SDTypeProfile<1, 3, [       // select
   SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>
 ]>;
 
+def SDTCtSelect
+    : SDTypeProfile<1, 3,
+                    [ // ctselect
+                        SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>]>;
+
 def SDTVSelect : SDTypeProfile<1, 3, [       // vselect
   SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>
 ]>;
@@ -717,6 +722,7 @@ def reset_fpmode   : SDNode<"ISD::RESET_FPMODE", SDTNone, [SDNPHasChain]>;
 
 def setcc      : SDNode<"ISD::SETCC"      , SDTSetCC>;
 def select     : SDNode<"ISD::SELECT"     , SDTSelect>;
+def ctselect : SDNode<"ISD::CTSELECT", SDTCtSelect>;
 def vselect    : SDNode<"ISD::VSELECT"    , SDTVSelect>;
 def selectcc   : SDNode<"ISD::SELECT_CC"  , SDTSelectCC>;
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 46c4bb85a7420..28fcebbb4a92a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -484,6 +484,7 @@ namespace {
     SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
     SDValue visitCTPOP(SDNode *N);
     SDValue visitSELECT(SDNode *N);
+    SDValue visitCTSELECT(SDNode *N);
     SDValue visitVSELECT(SDNode *N);
     SDValue visitVP_SELECT(SDNode *N);
     SDValue visitSELECT_CC(SDNode *N);
@@ -1898,6 +1899,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
 }
 
 SDValue DAGCombiner::visit(SDNode *N) {
+
   // clang-format off
   switch (N->getOpcode()) {
   default: break;
@@ -1968,6 +1970,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
   case ISD::CTPOP:              return visitCTPOP(N);
   case ISD::SELECT:             return visitSELECT(N);
+  case ISD::CTSELECT:           return visitCTSELECT(N);
   case ISD::VSELECT:            return visitVSELECT(N);
   case ISD::SELECT_CC:          return visitSELECT_CC(N);
   case ISD::SETCC:              return visitSETCC(N);
@@ -6032,6 +6035,7 @@ static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
     N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
     break;
   case ISD::SELECT:
+  case ISD::CTSELECT:
   case ISD::VSELECT:
     if (N0.getOperand(0).getOpcode() != ISD::SETCC)
       return SDValue();
@@ -12184,8 +12188,9 @@ template <class MatchContextClass>
 static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
                                      SelectionDAG &DAG) {
   assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
-          N->getOpcode() == ISD::VP_SELECT) &&
-         "Expected a (v)(vp.)select");
+          N->getOpcode() == ISD::VP_SELECT ||
+          N->getOpcode() == ISD::CTSELECT) &&
+         "Expected a (v)(vp.)(ct) select");
   SDValue Cond = N->getOperand(0);
   SDValue T = N->getOperand(1), F = N->getOperand(2);
   EVT VT = N->getValueType(0);
@@ -12547,6 +12552,109 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitCTSELECT(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+  EVT VT = N->getValueType(0);
+  EVT VT0 = N0.getValueType();
+  SDLoc DL(N);
+  SDNodeFlags Flags = N->getFlags();
+
+  if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
+    return V;
+
+  // ctselect (not Cond), N1, N2 -> ctselect Cond, N2, N1
+  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
+    SDValue SelectOp = DAG.getNode(ISD::CTSELECT, DL, VT, F, N2, N1);
+    SelectOp->setFlags(Flags);
+    return SelectOp;
+  }
+
+  if (VT0 == MVT::i1) {
+    // The code in this block deals with the following 2 equivalences:
+    //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
+    //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
+    // The target can specify its preferred form with the
+    // shouldNormalizeToSelectSequence() callback. However we always transform
+    // to the right anyway if we find the inner select exists in the DAG anyway
+    // and we always transform to the left side if we know that we can further
+    // optimize the combination of the conditions.
+    bool normalizeToSequence =
+        TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
+    // ctselect (and Cond0, Cond1), X, Y
+    //   -> ctselect Cond0, (ctselect Cond1, X, Y), Y
+    if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
+      SDValue Cond0 = N0->getOperand(0);
+      SDValue Cond1 = N0->getOperand(1);
+      SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(),
+                                        Cond1, N1, N2, Flags);
+      if (normalizeToSequence || !InnerSelect.use_empty())
+        return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0,
+                           InnerSelect, N2, Flags);
+      // Cleanup on failure.
+      if (InnerSelect.use_empty())
+        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
+    }
+    // ctselect (or Cond0, Cond1), X, Y -> ctselect Cond0, X, (ctselect Cond1,
+    // X, Y)
+    if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
+      SDValue Cond0 = N0->getOperand(0);
+      SDValue Cond1 = N0->getOperand(1);
+      SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(),
+                                        Cond1, N1, N2, Flags);
+      if (normalizeToSequence || !InnerSelect.use_empty())
+        return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0, N1,
+                           InnerSelect, Flags);
+      // Cleanup on failure.
+      if (InnerSelect.use_empty())
+        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
+    }
+
+    // ctselect Cond0, (ctselect Cond1, X, Y), Y -> ctselect (and Cond0, Cond1),
+    // X, Y
+    if (N1->getOpcode() == ISD::CTSELECT && N1->hasOneUse()) {
+      SDValue N1_0 = N1->getOperand(0);
+      SDValue N1_1 = N1->getOperand(1);
+      SDValue N1_2 = N1->getOperand(2);
+      if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
+        // Create the actual and node if we can generate good code for it.
+        if (!normalizeToSequence) {
+          SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
+          return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), And, N1_1,
+                             N2, Flags);
+        }
+        // Otherwise see if we can optimize the "and" to a better pattern.
+        if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
+          return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined,
+                             N1_1, N2, Flags);
+        }
+      }
+    }
+    // ctselect Cond0, X, (ctselect Cond1, X, Y) -> ctselect (or Cond0, Cond1),
+    // X, Y
+    if (N2->getOpcode() == ISD::CTSELECT && N2->hasOneUse()) {
+      SDValue N2_0 = N2->getOperand(0);
+      SDValue N2_1 = N2->getOperand(1);
+      SDValue N2_2 = N2->getOperand(2);
+      if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
+        // Create the actual or node if we can generate good code for it.
+        if (!normalizeToSequence) {
+          SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
+          return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Or, N1, N2_2,
+                             Flags);
+        }
+        // Otherwise see if we can optimize to a better pattern.
+        if (SDValue Combined = visitORLike(N0, N2_0, DL))
+          return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined, N1,
+                             N2_2, Flags);
+      }
+    }
+  }
+
+  return SDValue();
+}
+
 // This function assumes all the vselect's arguments are CONCAT_VECTOR
 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 431a81002074f..8178fd8981519 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4136,6 +4136,46 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     }
     Results.push_back(Tmp1);
     break;
+  case ISD::CTSELECT: {
+    Tmp1 = Node->getOperand(0);
+    Tmp2 = Node->getOperand(1);
+    Tmp3 = Node->getOperand(2);
+    EVT VT = Tmp2.getValueType();
+    if (VT.isVector()) {
+      SmallVector<SDValue> Elements;
+      unsigned NumElements = VT.getVectorNumElements();
+      EVT ScalarVT = VT.getScalarType();
+      for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
+        SDValue IdxVal = DAG.getVectorIdxConstant(Idx, dl);
+        SDValue TVal =
+            DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp2, IdxVal);
+        SDValue FVal =
+            DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp3, IdxVal);
+        Elements.push_back(
+            DAG.getCTSelect(dl, ScalarVT, Tmp1, TVal, FVal, Node->getFlags()));
+      }
+      Tmp1 = DAG.getBuildVector(VT, dl, Elements);
+    } else if (VT.isFloatingPoint()) {
+      EVT IntegerVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+      Tmp2 = DAG.getBitcast(IntegerVT, Tmp2);
+      Tmp3 = DAG.getBitcast(IntegerVT, Tmp3);
+      Tmp1 = DAG.getBitcast(VT, DAG.getCTSelect(dl, IntegerVT, Tmp1, Tmp2, Tmp3,
+                                                Node->getFlags()));
+    } else {
+      assert(VT.isInteger());
+      EVT HalfVT = VT.getHalfSizedIntegerVT(*DAG.getContext());
+      auto [Tmp2Lo, Tmp2Hi] = DAG.SplitScalar(Tmp2, dl, HalfVT, HalfVT);
+      auto [Tmp3Lo, Tmp3Hi] = DAG.SplitScalar(Tmp3, dl, HalfVT, HalfVT);
+      SDValue ResLo =
+          DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Lo, Tmp3Lo, Node->getFlags());
+      SDValue ResHi =
+          DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Hi, Tmp3Hi, Node->getFlags());
+      Tmp1 = DAG.getNode(ISD::BUILD_PAIR, dl, VT, ResLo, ResHi);
+      Tmp1->setFlags(Node->getFlags());
+    }
+    Results.push_back(Tmp1);
+    break;
+  }
   case ISD::BR_JT: {
     SDValue Chain = Node->getOperand(0);
     SDValue Table = Node->getOperand(1);
@@ -5474,7 +5514,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2));
     break;
   }
-  case ISD::SELECT: {
+  case ISD::SELECT:
+  case ISD::CTSELECT: {
     unsigned ExtOp, TruncOp;
     if (Node->getValueType(0).isVector() ||
         Node->getValueType(0).getSizeInBits() == NVT.getSizeInBits()) {
@@ -5492,7 +5533,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1));
     Tmp3 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(2));
     // Perform the larger operation, then round down.
-    Tmp1 = DAG.getSelect(dl, NVT, Tmp1, Tmp2, Tmp3);
+    Tmp1 = DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1, Tmp2, Tmp3);
+    Tmp1->setFlags(Node->getFlags());
     if (TruncOp != ISD::FP_ROUND)
       Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1);
     else
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 58983cb57d7f6..855a15a744cfe 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -159,6 +159,7 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::ATOMIC_LOAD: R = SoftenFloatRes_ATOMIC_LOAD(N); break;
     case ISD::ATOMIC_SWAP: R = BitcastToInt_ATOMIC_SWAP(N); break;
     case ISD::SELECT:      R = SoftenFloatRes_SELECT(N); break;
+    case ISD::CTSELECT:    R = SoftenFloatRes_CTSELECT(N); break;
     case ISD::SELECT_CC:   R = SoftenFloatRes_SELECT_CC(N); break;
     case ISD::FREEZE:      R = SoftenFloatRes_FREEZE(N); break;
     case ISD::STRICT_SINT_TO_FP:
@@ -1041,6 +1042,13 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_SELECT(SDNode *N) {
                        LHS.getValueType(), N->getOperand(0), LHS, RHS);
 }
 
+SDValue DAGTypeLegalizer::SoftenFloatRes_CTSELECT(SDNode *N) {
+  SDValue LHS = GetSoftenedFloat(N->getOperand(1));
+  SDValue RHS = GetSoftenedFloat(N->getOperand(2));
+  return DAG.getCTSelect(SDLoc(N), LHS.getValueType(), N->getOperand(0), LHS,
+                         RHS);
+}
+
 SDValue DAGTypeLegalizer::SoftenFloatRes_SELECT_CC(SDNode *N) {
   SDValue LHS = GetSoftenedFloat(N->getOperand(2));
   SDValue RHS = GetSoftenedFloat(N->getOperand(3));
@@ -1561,6 +1569,7 @@ void DAGTypeLegalizer::ExpandFloatResult(SDNode *N, unsigned ResNo) {
   case ISD::POISON:
   case ISD::UNDEF:        SplitRes_UNDEF(N, Lo, Hi); break;
   case ISD::SELECT:       SplitRes_Select(N, Lo, Hi); break;
+  case ISD::CTSELECT:     SplitRes_Select(N, Lo, Hi); break;
   case ISD::SELECT_CC:    SplitRes_SELECT_CC(N, Lo, Hi); break;
 
   case ISD::MERGE_VALUES:       ExpandRes_MERGE_VALUES(N, ResNo, Lo, Hi); break;
@@ -2917,6 +2926,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
       R = PromoteFloatRes_ATOMIC_LOAD(N);
       break;
     case ISD::SELECT:     R = PromoteFloatRes_SELECT(N); break;
+    case ISD::CTSELECT:
+      R = PromoteFloatRes_SELECT(N);
+      break;
     case ISD::SELECT_CC:  R = PromoteFloatRes_SELECT_CC(N); break;
 
     case ISD::SINT_TO_FP:
@@ -3219,7 +3231,7 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_SELECT(SDNode *N) {
   SDValue TrueVal = GetPromotedFloat(N->getOperand(1));
   SDValue FalseVal = GetPromotedFloat(N->getOperand(2));
 
-  return DAG.getNode(ISD::SELECT, SDLoc(N), TrueVal->getValueType(0),
+  return DAG.getNode(N->getOpcode(), SDLoc(N), TrueVal->getValueType(0),
                      N->getOperand(0), TrueVal, FalseVal);
 }
 
@@ -3403,6 +3415,9 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     R = SoftPromoteHalfRes_ATOMIC_LOAD(N);
     break;
   case ISD::SELECT:      R = SoftPromoteHalfRes_SELECT(N); break;
+  case ISD::CTSELECT:
+    R = SoftPromoteHalfRes_SELECT(N);
+    break;
   case ISD::SELECT_CC:   R = SoftPromoteHalfRes_SELECT_CC(N); break;
   case ISD::STRICT_SINT_TO_FP:
   case ISD::STRICT_UINT_TO_FP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 44e5a187c4281..0135b3195438b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -95,6 +95,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
     Res = PromoteIntRes_VECTOR_COMPRESS(N);
     break;
   case ISD::SELECT:
+  case ISD::CTSELECT:
   case ISD::VSELECT:
   case ISD::VP_SELECT:
   case ISD::VP_MERGE:
@@ -2013,6 +2014,9 @@...
[truncated]

``````````
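
The `visitCTSELECT` combines in the diff rely on value identities of select: swapping the operands under a negated condition, and splitting or merging `and`/`or` conditions into nested selects. The hypothetical C++ check below (not part of the patch) exhaustively verifies those identities for every pair of condition values, using a mask-based select as a stand-in for `ISD::CTSELECT`. It says nothing about timing; it only confirms that the rewrites preserve the selected value.

```cpp
#include <cstdint>

// Hypothetical mask-based stand-in for ISD::CTSELECT.
constexpr std::uint32_t sel(bool c, std::uint32_t t, std::uint32_t f) {
  const std::uint32_t mask = std::uint32_t{0} - static_cast<std::uint32_t>(c);
  return (t & mask) | (f & ~mask);
}

// Returns true if, for every (C0, C1), these identities hold for X = x, Y = y:
//   ctselect (not C0), X, Y      == ctselect C0, Y, X
//   ctselect (and C0, C1), X, Y  == ctselect C0, (ctselect C1, X, Y), Y
//   ctselect (or C0, C1), X, Y   == ctselect C0, X, (ctselect C1, X, Y)
// The reverse folds (nested ctselect -> and/or condition) are the same
// identities read right to left.
constexpr bool check_ctselect_combines(std::uint32_t x, std::uint32_t y) {
  for (int i = 0; i < 4; ++i) {
    const bool c0 = (i & 1) != 0;
    const bool c1 = (i & 2) != 0;
    if (sel(!c0, x, y) != sel(c0, y, x))
      return false;
    if (sel(c0 && c1, x, y) != sel(c0, sel(c1, x, y), y))
      return false;
    if (sel(c0 || c1, x, y) != sel(c0, x, sel(c1, x, y)))
      return false;
  }
  return true;
}
```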

</details>
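
The legalization changes in the diff follow a few recurring strategies: scalarize a vector element by element under the same condition, reinterpret a floating-point value as a same-width integer, split an over-wide integer into halves and select each half, or promote a narrow type to a wider one and truncate the result. A rough C++ model of each strategy follows, assuming a target whose widest legal integer is 64 bits. All names (`sel64`, `U128`, `sel128`, `sel_f64`, `sel_vec`, `sel8`) are illustrative only and do not correspond to SelectionDAG APIs; the promotion sketch zero-extends, whereas the patch reuses the existing SELECT promotion path, which chooses its extend and truncate opcodes per type.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Mask-based select on the widest "legal" integer in this model.
constexpr std::uint64_t sel64(bool c, std::uint64_t t, std::uint64_t f) {
  const std::uint64_t mask = std::uint64_t{0} - static_cast<std::uint64_t>(c);
  return (t & mask) | (f & ~mask);
}

// Wide integer: split into halves, select each half with the same condition,
// then reassemble (compare SplitScalar + BUILD_PAIR in the expansion).
struct U128 {
  std::uint64_t lo, hi;
};
constexpr U128 sel128(bool c, U128 t, U128 f) {
  return U128{sel64(c, t.lo, f.lo), sel64(c, t.hi, f.hi)};
}

// Floating point: select on the raw bit pattern of a same-width integer.
static_assert(sizeof(double) == sizeof(std::uint64_t), "model assumes 64-bit double");
inline double sel_f64(bool c, double t, double f) {
  std::uint64_t tb, fb;
  std::memcpy(&tb, &t, sizeof tb);
  std::memcpy(&fb, &f, sizeof fb);
  const std::uint64_t rb = sel64(c, tb, fb);
  double r;
  std::memcpy(&r, &rb, sizeof r);
  return r;
}

// Vector: every lane is selected with the same scalar condition.
template <std::size_t N>
std::array<std::uint64_t, N> sel_vec(bool c, const std::array<std::uint64_t, N> &t,
                                     const std::array<std::uint64_t, N> &f) {
  std::array<std::uint64_t, N> r{};
  for (std::size_t i = 0; i < N; ++i)
    r[i] = sel64(c, t[i], f[i]);
  return r;
}

// Promotion: widen an 8-bit value, select in 64 bits, truncate back.
constexpr std::uint8_t sel8(bool c, std::uint8_t t, std::uint8_t f) {
  return static_cast<std::uint8_t>(sel64(c, t, f));
}
```

In every strategy, each piece (half, lane, or widened value) is selected with the same condition and none is skipped, so the amount of work does not depend on the condition.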


https://github.com/llvm/llvm-project/pull/166702

