[llvm] [AArch64] optimize vselect of bitcast (PR #180375)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 8 03:09:17 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Folkert de Vries (folkertdev)

<details>
<summary>Changes</summary>

This uses code and ideas from the x86 backend to optimize a `select` whose mask is a bitcast integer. Previously, AArch64 extracted each bit of the mask individually, which produces very poor code.

https://rust.godbolt.org/z/576sndT66

```llvm
; Store to %out each i32 lane of %t or %f, chosen by the matching element of
; the <8 x i1> bitcast of %mask.
define void @if_then_else8(ptr %out, i8 %mask, ptr %if_true, ptr %if_false) {
start:
  %t = load <8 x i32>, ptr %if_true, align 4
  %f = load <8 x i32>, ptr %if_false, align 4
  %m = bitcast i8 %mask to <8 x i1>
  %s = select <8 x i1> %m, <8 x i32> %t, <8 x i32> %f
  store <8 x i32> %s, ptr %out, align 4
  ret void
}
```

Before this PR, this was compiled to:

```asm
// Codegen before this change.
if_then_else8:                          // @if_then_else8
        sub     sp, sp, #16
        ubfx    w8, w1, #4, #1
        and     w11, w1, #0x1
        ubfx    w9, w1, #5, #1
        fmov    s1, w11
        ubfx    w10, w1, #1, #1
        fmov    s0, w8
        ubfx    w8, w1, #6, #1
        ldp     q5, q2, [x3]
        mov     v1.h[1], w10
        ldp     q4, q3, [x2]
        mov     v0.h[1], w9
        ubfx    w9, w1, #2, #1
        mov     v1.h[2], w9
        ubfx    w9, w1, #3, #1
        mov     v0.h[2], w8
        ubfx    w8, w1, #7, #1
        mov     v1.h[3], w9
        mov     v0.h[3], w8
        ushll   v1.4s, v1.4h, #0
        ushll   v0.4s, v0.4h, #0
        shl     v1.4s, v1.4s, #31
        shl     v0.4s, v0.4s, #31
        cmlt    v1.4s, v1.4s, #0
        cmlt    v0.4s, v0.4s, #0
        bsl     v1.16b, v4.16b, v5.16b
        bsl     v0.16b, v3.16b, v2.16b
        stp     q1, q0, [x0]
        add     sp, sp, #16
        ret
```

With this PR, it instead compiles to:

```asm
// Codegen with this change.
if_then_else8:
   adrp x8, .LCPI0_1
   dup v0.4s, w1
   ldr q1, [x8, :lo12:.LCPI0_1]
   adrp x8, .LCPI0_0
   ldr q2, [x8, :lo12:.LCPI0_0]
   ldp q4, q3, [x2]
   and v1.16b, v0.16b, v1.16b
   and v0.16b, v0.16b, v2.16b
   ldp q5, q2, [x3]
   cmeq v1.4s, v1.4s, #0
   cmeq v0.4s, v0.4s, #0
   bsl v1.16b, v2.16b, v3.16b
   bsl v0.16b, v5.16b, v4.16b
   stp q0, q1, [x0]
   ret
```

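The result is substantially shorter. Because the combine runs before legalization splits the select, it does not have to build the mask element by element. Instead, it splats the mask value into every vector lane, ANDs each lane with a distinct power of two, and compares the result to build the lane mask. In the final code, that comparison becomes a `cmeq` against zero with the `bsl` operands swapped.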

cc https://github.com/rust-lang/rust/issues/122376
cc https://github.com/llvm/llvm-project/pull/175769

---
Full diff: https://github.com/llvm/llvm-project/pull/180375.diff


1 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+106-2) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2b2e057c80373..178fd40d94be0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27264,12 +27264,98 @@ static SDValue trySwapVSelectOperands(SDNode *N, SelectionDAG &DAG) {
                      {InverseSetCC, SelectB, SelectA});
 }
 
+// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)):
+// splat iX into every lane, then test lane i against bit i of iX.
+static SDValue combineToExtendBoolVectorInReg(
+    unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
+    TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget &Subtarget) {
+  if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
+      Opcode != ISD::ANY_EXTEND)
+    return SDValue();
+  if (!DCI.isBeforeLegalizeOps() || !Subtarget.hasNEON())
+    return SDValue();
+  // Lane i tests bit i of iX, which is what (vXi1 bitcast iX) means only on
+  // little-endian targets; big-endian places element 0 in the MSB.
+  if (!DAG.getDataLayout().isLittleEndian())
+    return SDValue();
+
+  EVT SVT = VT.getScalarType();
+  EVT InSVT = N0.getValueType().getScalarType();
+  unsigned EltSizeInBits = SVT.getSizeInBits();
+
+  // Input type must be extending a bool vector (bit-casted from a scalar
+  // integer) to legal integer types.
+  if (!VT.isVector() || InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
+    return SDValue();
+  if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  EVT SclVT = N00.getValueType();
+  if (!SclVT.isScalarInteger())
+    return SDValue();
+
+  SDValue Vec;
+  SmallVector<int> ShuffleMask;
+  unsigned NumElts = VT.getVectorNumElements();
+  assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
+
+  // Broadcast the scalar integer to the vector elements.
+  if (NumElts > EltSizeInBits) {
+    // If the scalar integer is greater than the vector element size, then we
+    // must split it down into sub-sections for broadcasting. For example:
+    //   i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
+    //   i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
+    assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
+    unsigned Scale = NumElts / EltSizeInBits;
+    EVT BroadcastVT = EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
+    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
+    Vec = DAG.getBitcast(VT, Vec);
+
+    // Element i of Vec is sub-section i; repeat it across EltSizeInBits lanes.
+    for (unsigned i = 0; i != Scale; ++i)
+      ShuffleMask.append(EltSizeInBits, i);
+    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
+  } else {
+    // For smaller scalar integers, we can simply any-extend it to the vector
+    // element size (we don't care about the upper bits) and broadcast it to all
+    // elements.
+    Vec = DAG.getSplat(VT, DL, DAG.getAnyExtOrTrunc(N00, DL, SVT));
+  }
+
+  // Now, mask the relevant bit in each element.
+  SmallVector<SDValue, 32> Bits;
+  for (unsigned i = 0; i != NumElts; ++i) {
+    unsigned BitIdx = i % EltSizeInBits;
+    APInt Bit = APInt::getOneBitSet(EltSizeInBits, BitIdx);
+    Bits.push_back(DAG.getConstant(Bit, DL, SVT));
+  }
+  SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
+  Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
+
+  // Compare against the bitmask and extend the result.
+  EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts);
+  Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
+  Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
+
+  // For SEXT, this is now done, otherwise shift the result down for
+  // zero-extension.
+  if (Opcode == ISD::SIGN_EXTEND)
+    return Vec;
+  return DAG.getNode(ISD::SRL, DL, VT, Vec,
+                     DAG.getConstant(EltSizeInBits - 1, DL, VT));
+}
+
 // vselect (v1i1 setcc) ->
 //     vselect (v1iXX setcc)  (XX is the size of the compared operand type)
 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
 // condition. If it can legalize "VSELECT v1i1" correctly, no need to combine
 // such VSELECT.
-static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
+static SDValue performVSelectCombine(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     const AArch64Subtarget *Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+
   if (auto SwapResult = trySwapVSelectOperands(N, DAG))
     return SwapResult;
 
@@ -27333,6 +27419,24 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
     }
   }
 
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  // Attempt to convert a (vXi1 bitcast(iX N0)) selection mask before it might
+  // get split by legalization.
+  if (N0.getOpcode() == ISD::BITCAST && CCVT.isVector() &&
+      CCVT.getVectorElementType() == MVT::i1 &&
+      TLI.isTypeLegal(ResVT.getScalarType())) {
+
+    SDLoc DL(N);
+    EVT ExtCondVT = ResVT.changeVectorElementTypeToInteger();
+
+    if (SDValue ExtCond = combineToExtendBoolVectorInReg(
+            ISD::SIGN_EXTEND, DL, ExtCondVT, N0, DAG, DCI, *Subtarget)) {
+      ExtCond = DAG.getNode(ISD::TRUNCATE, DL, CCVT, ExtCond);
+      return DAG.getSelect(DL, ResVT, ExtCond, IfTrue, IfFalse);
+    }
+  }
+
   EVT CmpVT = N0.getOperand(0).getValueType();
   if (N0.getOpcode() != ISD::SETCC ||
       CCVT.getVectorElementCount() != ElementCount::getFixed(1) ||
@@ -28712,7 +28816,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SELECT:
     return performSelectCombine(N, DCI);
   case ISD::VSELECT:
-    return performVSelectCombine(N, DCI.DAG);
+    return performVSelectCombine(N, DCI, Subtarget);
   case ISD::SETCC:
     return performSETCCCombine(N, DCI, DAG);
   case ISD::LOAD:

``````````

</details>


https://github.com/llvm/llvm-project/pull/180375


More information about the llvm-commits mailing list