[llvm] [SelectionDAG] Recurse through mask expression trees in WidenVSELECTMask (PR #188085)

Valeriy Savchenko via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 1 04:43:44 PDT 2026


https://github.com/SavchenkoValeriy updated https://github.com/llvm/llvm-project/pull/188085

>From b92b34f4e595e197d2b6bcfd8f235cde95a3df1c Mon Sep 17 00:00:00 2001
From: Valeriy Savchenko <vsavchenko at apple.com>
Date: Mon, 23 Mar 2026 17:51:12 +0000
Subject: [PATCH 1/3] [AArch64][NFC] Add tests for mask widening

---
 .../AArch64/vselect-widen-mask-tree.ll        | 218 ++++++++++++++++++
 1 file changed, 218 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll

diff --git a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
new file mode 100644
index 0000000000000..c459c78c3e5df
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
@@ -0,0 +1,218 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=arm64-apple-ios -o - %s | FileCheck %s
+
+define <4 x i32> @freeze_or_setcc(<4 x i32> %a, <4 x i32> %b, <4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: freeze_or_setcc:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    add.4s v4, v0, v1
+; CHECK-NEXT:    cmgt.4s v0, v1, v0
+; CHECK-NEXT:    cmgt.4s v4, v4, #0
+; CHECK-NEXT:    orr.16b v0, v4, v0
+; CHECK-NEXT:    bsl.16b v0, v2, v3
+; CHECK-NEXT:    ret
+  %add = add nsw <4 x i32> %a, %b
+  %cmp1 = icmp sgt <4 x i32> %add, zeroinitializer
+  %sub = sub nsw <4 x i32> %a, %b
+  %cmp2 = icmp slt <4 x i32> %sub, zeroinitializer
+  %or = or <4 x i1> %cmp1, %cmp2
+  %fr = freeze <4 x i1> %or
+  %sel = select <4 x i1> %fr, <4 x i32> %x, <4 x i32> %y
+  ret <4 x i32> %sel
+}
+
+define <4 x i32> @select_allones_or_setcc(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
+; CHECK-LABEL: select_allones_or_setcc:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    cmgt.4s v0, v0, #0
+; CHECK-NEXT:    tst w0, #0x1
+; CHECK-NEXT:    csetm w8, ne
+; CHECK-NEXT:    dup.4h v3, w8
+; CHECK-NEXT:    xtn.4h v0, v0
+; CHECK-NEXT:    orr.8b v0, v0, v3
+; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    bsl.16b v0, v1, v2
+; CHECK-NEXT:    ret
+  %cmp = icmp sgt <4 x i32> %a, zeroinitializer
+  %mask = select i1 %cond, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> %cmp
+  %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
+  ret <4 x i32> %sel
+}
+
+define <4 x i32> @select_setcc_or_allzeros(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
+; CHECK-LABEL: select_setcc_or_allzeros:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    cmgt.4s v0, v0, #0
+; CHECK-NEXT:    tst w0, #0x1
+; CHECK-NEXT:    csetm w8, ne
+; CHECK-NEXT:    dup.4h v3, w8
+; CHECK-NEXT:    xtn.4h v0, v0
+; CHECK-NEXT:    and.8b v0, v0, v3
+; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    bsl.16b v0, v1, v2
+; CHECK-NEXT:    ret
+  %cmp = icmp sgt <4 x i32> %a, zeroinitializer
+  %mask = select i1 %cond, <4 x i1> %cmp, <4 x i1> zeroinitializer
+  %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
+  ret <4 x i32> %sel
+}
+
+define <4 x i32> @select_allzeros_or_allones(<4 x i32> %x, <4 x i32> %y, i1 %cond) {
+; CHECK-LABEL: select_allzeros_or_allones:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    tst w0, #0x1
+; CHECK-NEXT:    csetm w8, ne
+; CHECK-NEXT:    dup.4h v2, w8
+; CHECK-NEXT:    mvn.8b v2, v2
+; CHECK-NEXT:    sshll.4s v2, v2, #0
+; CHECK-NEXT:    bif.16b v0, v1, v2
+; CHECK-NEXT:    ret
+  %mask = select i1 %cond, <4 x i1> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>
+  %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
+  ret <4 x i32> %sel
+}
+
+define <4 x i32> @vselect_of_setccs(<4 x i32> %a, <4 x i32> %b, <4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: vselect_of_setccs:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    movi.4s v4, #100
+; CHECK-NEXT:    cmgt.4s v5, v0, #0
+; CHECK-NEXT:    cmeq.4s v1, v1, #0
+; CHECK-NEXT:    cmgt.4s v0, v4, v0
+; CHECK-NEXT:    xtn.4h v4, v1
+; CHECK-NEXT:    and.16b v1, v5, v1
+; CHECK-NEXT:    xtn.4h v0, v0
+; CHECK-NEXT:    xtn.4h v1, v1
+; CHECK-NEXT:    bic.8b v0, v0, v4
+; CHECK-NEXT:    orr.8b v0, v1, v0
+; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    bsl.16b v0, v2, v3
+; CHECK-NEXT:    ret
+  %cmp1 = icmp sgt <4 x i32> %a, zeroinitializer
+  %cmp2 = icmp slt <4 x i32> %a, <i32 100, i32 100, i32 100, i32 100>
+  %cond = icmp eq <4 x i32> %b, zeroinitializer
+  %mask = select <4 x i1> %cond, <4 x i1> %cmp1, <4 x i1> %cmp2
+  %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
+  ret <4 x i32> %sel
+}
+
+define <4 x i32> @select_scalar_cond_setccs(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
+; CHECK-LABEL: select_scalar_cond_setccs:
+; CHECK:       ; %bb.0: ; %entry
+; CHECK-NEXT:    movi.4s v3, #100
+; CHECK-NEXT:    cmgt.4s v4, v0, #0
+; CHECK-NEXT:    tst w0, #0x1
+; CHECK-NEXT:    csetm w8, ne
+; CHECK-NEXT:    cmgt.4s v0, v3, v0
+; CHECK-NEXT:    xtn.4h v3, v4
+; CHECK-NEXT:    dup.4h v4, w8
+; CHECK-NEXT:    xtn.4h v0, v0
+; CHECK-NEXT:    bif.8b v0, v3, v4
+; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    bsl.16b v0, v1, v2
+; CHECK-NEXT:    ret
+entry:
+  %cmp1 = icmp sgt <4 x i32> %a, zeroinitializer
+  br i1 %cond, label %then, label %else
+
+then:
+  %cmp2 = icmp slt <4 x i32> %a, <i32 100, i32 100, i32 100, i32 100>
+  br label %merge
+
+else:
+  br label %merge
+
+merge:
+  %mask = phi <4 x i1> [ %cmp2, %then ], [ %cmp1, %else ]
+  %fr = freeze <4 x i1> %mask
+  %sel = select <4 x i1> %fr, <4 x i32> %x, <4 x i32> %y
+  ret <4 x i32> %sel
+}
+
+define <3 x i64> @or_setcc_i16_i32_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i32> %c, <3 x i32> %d, <3 x i64> %x, <3 x i64> %y) {
+; CHECK-LABEL: or_setcc_i16_i32_sel_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    cmgt.4h v0, v0, v1
+; CHECK-NEXT:    cmgt.4s v1, v2, v3
+; CHECK-NEXT:    ; kill: def $d7 killed $d7 def $q7
+; CHECK-NEXT:    ; kill: def $d4 killed $d4 def $q4
+; CHECK-NEXT:    ; kill: def $d5 killed $d5 def $q5
+; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
+; CHECK-NEXT:    mov.d v4[1], v5[0]
+; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    orr.16b v1, v0, v1
+; CHECK-NEXT:    ldp d0, d2, [sp]
+; CHECK-NEXT:    mov.d v7[1], v0[0]
+; CHECK-NEXT:    sshll.2d v0, v1, #0
+; CHECK-NEXT:    sshll2.2d v1, v1, #0
+; CHECK-NEXT:    bit.16b v2, v6, v1
+; CHECK-NEXT:    bsl.16b v0, v4, v7
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT:    ext.16b v1, v0, v0, #8
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT:    ret
+  %cmp0 = icmp sgt <3 x i16> %a, %b
+  %cmp1 = icmp sgt <3 x i32> %c, %d
+  %or = or <3 x i1> %cmp0, %cmp1
+  %sel = select <3 x i1> %or, <3 x i64> %x, <3 x i64> %y
+  ret <3 x i64> %sel
+}
+
+define <3 x i64> @and_setcc_i32_i32_sel_i64(<3 x i32> %a, <3 x i32> %b, <3 x i32> %c, <3 x i32> %d, <3 x i64> %x, <3 x i64> %y) {
+; CHECK-LABEL: and_setcc_i32_i32_sel_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    cmgt.4s v2, v2, v3
+; CHECK-NEXT:    cmgt.4s v0, v0, v1
+; CHECK-NEXT:    ; kill: def $d7 killed $d7 def $q7
+; CHECK-NEXT:    ; kill: def $d4 killed $d4 def $q4
+; CHECK-NEXT:    ; kill: def $d5 killed $d5 def $q5
+; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
+; CHECK-NEXT:    mov.d v4[1], v5[0]
+; CHECK-NEXT:    and.16b v1, v0, v2
+; CHECK-NEXT:    ldp d0, d2, [sp]
+; CHECK-NEXT:    mov.d v7[1], v0[0]
+; CHECK-NEXT:    sshll.2d v0, v1, #0
+; CHECK-NEXT:    sshll2.2d v1, v1, #0
+; CHECK-NEXT:    bit.16b v2, v6, v1
+; CHECK-NEXT:    bsl.16b v0, v4, v7
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT:    ext.16b v1, v0, v0, #8
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT:    ret
+  %cmp0 = icmp sgt <3 x i32> %a, %b
+  %cmp1 = icmp sgt <3 x i32> %c, %d
+  %and = and <3 x i1> %cmp0, %cmp1
+  %sel = select <3 x i1> %and, <3 x i64> %x, <3 x i64> %y
+  ret <3 x i64> %sel
+}
+
+define <3 x i64> @or_setcc_i16_i16_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i16> %c, <3 x i16> %d, <3 x i64> %x, <3 x i64> %y) {
+; CHECK-LABEL: or_setcc_i16_i16_sel_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    cmgt.4h v2, v2, v3
+; CHECK-NEXT:    cmgt.4h v0, v0, v1
+; CHECK-NEXT:    ; kill: def $d7 killed $d7 def $q7
+; CHECK-NEXT:    ; kill: def $d4 killed $d4 def $q4
+; CHECK-NEXT:    ; kill: def $d5 killed $d5 def $q5
+; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
+; CHECK-NEXT:    mov.d v4[1], v5[0]
+; CHECK-NEXT:    orr.8b v0, v0, v2
+; CHECK-NEXT:    sshll.4s v1, v0, #0
+; CHECK-NEXT:    ldp d0, d2, [sp]
+; CHECK-NEXT:    mov.d v7[1], v0[0]
+; CHECK-NEXT:    sshll.2d v0, v1, #0
+; CHECK-NEXT:    sshll2.2d v1, v1, #0
+; CHECK-NEXT:    bit.16b v2, v6, v1
+; CHECK-NEXT:    bsl.16b v0, v4, v7
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT:    ext.16b v1, v0, v0, #8
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT:    ret
+  %cmp0 = icmp sgt <3 x i16> %a, %b
+  %cmp1 = icmp sgt <3 x i16> %c, %d
+  %or = or <3 x i1> %cmp0, %cmp1
+  %sel = select <3 x i1> %or, <3 x i64> %x, <3 x i64> %y
+  ret <3 x i64> %sel
+}

>From 0539a4c00361118cad4897ff22ec679e3b9b42ae Mon Sep 17 00:00:00 2001
From: Valeriy Savchenko <vsavchenko at apple.com>
Date: Fri, 6 Mar 2026 16:37:00 +0000
Subject: [PATCH 2/3] [SelectionDAG] Recurse through mask expression trees in
 WidenVSELECTMask

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |   5 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      | 129 +++++++++++-------
 llvm/test/CodeGen/AArch64/arm64-zip.ll        |  27 ++--
 .../AArch64/vselect-widen-mask-tree.ll        |  96 +++++++------
 .../X86/bitcast-int-to-vector-bool-sext.ll    |  65 +++++----
 5 files changed, 172 insertions(+), 150 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index da592e3cad0f5..4d69c0e8bcde8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1149,6 +1149,11 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   /// MaskVT to ToMaskVT if needed with vector extension or truncation.
   SDValue convertMask(SDValue InMask, EVT MaskVT, EVT ToMaskVT);
 
+  /// Recursively widen a mask expression tree to ToVT. Handles SETCC,
+  /// logical ops (AND/OR/XOR), VECTOR_SHUFFLE, FREEZE, SELECT/VSELECT,
+  /// and constant BUILD_VECTORs.
+  SDValue widenMaskTree(SDValue V, EVT ToVT, unsigned Depth = 0);
+
   //===--------------------------------------------------------------------===//
   // Generic Splitting: LegalizeTypesGeneric.cpp
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index aeb9d4d7bdc1d..0b7ab61cf3df8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -6806,9 +6806,8 @@ static inline bool isSETCCorConvertedSETCC(SDValue N) {
 // to ToMaskVT if needed with vector extension or truncation.
 SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
                                       EVT ToMaskVT) {
-  // Currently a SETCC or a AND/OR/XOR with two SETCCs are handled.
-  // FIXME: This code seems to be too restrictive, we might consider
-  // generalizing it or dropping it.
+  // Called from widenMaskTree for SETCC leaf nodes. Re-creates the SETCC with
+  // result type MaskVT, then sign-extends/truncates and pads to ToMaskVT.
   assert(isSETCCorConvertedSETCC(InMask) && "Unexpected mask argument.");
 
   // Make a new Mask node, with a legal result VT.
@@ -6862,6 +6861,82 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
   return Mask;
 }
 
+// Recursively rebuild the mask expression tree V with result type ToVT.
+// Leaves are SETCCs and all-zeros/all-ones BUILD_VECTORs. Returns SDValue()
+// if any node in the tree cannot be rebuilt (or recursion is too deep).
+SDValue DAGTypeLegalizer::widenMaskTree(SDValue V, EVT ToVT, unsigned Depth) {
+  if (Depth >= DAG.MaxRecursionDepth)
+    return SDValue();
+
+  unsigned Opcode = V.getOpcode();
+
+  // Base case: SETCC produces the mask directly.
+  if (isSETCCOp(Opcode)) {
+    EVT MaskVT = getSetCCResultType(getSETCCOperandType(V));
+    return convertMask(V, MaskVT, ToVT);
+  }
+
+  // Base case: all-zeros or all-ones BUILD_VECTOR.
+  if (ISD::isBuildVectorAllZeros(V.getNode()))
+    return DAG.getConstant(0, SDLoc(V), ToVT);
+  if (ISD::isBuildVectorAllOnes(V.getNode()))
+    return DAG.getAllOnesConstant(SDLoc(V), ToVT);
+
+  SDLoc DL(V);
+
+  // Logical operations (AND/OR/XOR): widen both operands and rebuild.
+  if (isLogicalMaskOp(Opcode)) {
+    SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+    if (!Op0)
+      return SDValue();
+    SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+    if (!Op1)
+      return SDValue();
+    return DAG.getNode(Opcode, DL, ToVT, Op0, Op1);
+  }
+
+  // FREEZE: widen the operand and re-wrap.
+  if (Opcode == ISD::FREEZE) {
+    SDValue Inner = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+    if (!Inner)
+      return SDValue();
+    return DAG.getNode(ISD::FREEZE, DL, ToVT, Inner);
+  }
+
+  // Shuffle: the mask indexes V's lanes, so ToVT must keep V's lane count.
+  auto *Shuf = dyn_cast<ShuffleVectorSDNode>(V);
+  if (Shuf && Shuf->getMask().size() == ToVT.getVectorNumElements()) {
+    SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+    if (!Op0)
+      return SDValue();
+    SDValue Op1 = V.getOperand(1).isUndef()
+                      ? DAG.getUNDEF(ToVT)
+                      : widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+    if (!Op1)
+      return SDValue();
+    return DAG.getVectorShuffle(ToVT, DL, Op0, Op1, Shuf->getMask());
+  }
+
+  // SELECT/VSELECT: widen both arms; a vector condition is itself a mask.
+  if (Opcode == ISD::SELECT || Opcode == ISD::VSELECT) {
+    SDValue Cond = V.getOperand(0);
+    if (Cond.getValueType().isVector()) {
+      Cond = widenMaskTree(Cond, ToVT, Depth + 1);
+      if (!Cond)
+        return SDValue();
+    }
+    SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+    if (!Op1)
+      return SDValue();
+    SDValue Op2 = widenMaskTree(V.getOperand(2), ToVT, Depth + 1);
+    if (!Op2)
+      return SDValue();
+    return DAG.getNode(Opcode, DL, ToVT, Cond, Op1, Op2);
+  }
+
+  return SDValue();
+}
+
 // This method tries to handle some special cases for the vselect mask
 // and if needed adjusting the mask vector type to match that of the VSELECT.
 // Without it, many cases end up with scalarization of the SETCC, with many
@@ -6873,9 +6948,6 @@ SDValue DAGTypeLegalizer::WidenVSELECTMask(SDNode *N) {
   if (N->getOpcode() != ISD::VSELECT)
     return SDValue();
 
-  if (!isSETCCOp(Cond->getOpcode()) && !isLogicalMaskOp(Cond->getOpcode()))
-    return SDValue();
-
   // If this is a splitted VSELECT that was previously already handled, do
   // nothing.
   EVT CondVT = Cond->getValueType(0);
@@ -6928,49 +7000,8 @@ SDValue DAGTypeLegalizer::WidenVSELECTMask(SDNode *N) {
   if (!ToMaskVT.getScalarType().isInteger())
     ToMaskVT = ToMaskVT.changeVectorElementTypeToInteger();
 
-  SDValue Mask;
-  if (isSETCCOp(Cond->getOpcode())) {
-    EVT MaskVT = getSetCCResultType(getSETCCOperandType(Cond));
-    Mask = convertMask(Cond, MaskVT, ToMaskVT);
-  } else if (isLogicalMaskOp(Cond->getOpcode()) &&
-             isSETCCOp(Cond->getOperand(0).getOpcode()) &&
-             isSETCCOp(Cond->getOperand(1).getOpcode())) {
-    // Cond is (AND/OR/XOR (SETCC, SETCC))
-    SDValue SETCC0 = Cond->getOperand(0);
-    SDValue SETCC1 = Cond->getOperand(1);
-    EVT VT0 = getSetCCResultType(getSETCCOperandType(SETCC0));
-    EVT VT1 = getSetCCResultType(getSETCCOperandType(SETCC1));
-    unsigned ScalarBits0 = VT0.getScalarSizeInBits();
-    unsigned ScalarBits1 = VT1.getScalarSizeInBits();
-    unsigned ScalarBits_ToMask = ToMaskVT.getScalarSizeInBits();
-    EVT MaskVT;
-    // If the two SETCCs have different VTs, either extend/truncate one of
-    // them to the other "towards" ToMaskVT, or truncate one and extend the
-    // other to ToMaskVT.
-    if (ScalarBits0 != ScalarBits1) {
-      EVT NarrowVT = ((ScalarBits0 < ScalarBits1) ? VT0 : VT1);
-      EVT WideVT = ((NarrowVT == VT0) ? VT1 : VT0);
-      if (ScalarBits_ToMask >= WideVT.getScalarSizeInBits())
-        MaskVT = WideVT;
-      else if (ScalarBits_ToMask <= NarrowVT.getScalarSizeInBits())
-        MaskVT = NarrowVT;
-      else
-        MaskVT = ToMaskVT;
-    } else
-      // If the two SETCCs have the same VT, don't change it.
-      MaskVT = VT0;
-
-    // Make new SETCCs and logical nodes.
-    SETCC0 = convertMask(SETCC0, VT0, MaskVT);
-    SETCC1 = convertMask(SETCC1, VT1, MaskVT);
-    Cond = DAG.getNode(Cond->getOpcode(), SDLoc(Cond), MaskVT, SETCC0, SETCC1);
-
-    // Convert the logical op for VSELECT if needed.
-    Mask = convertMask(Cond, MaskVT, ToMaskVT);
-  } else
-    return SDValue();
-
-  return Mask;
+  // Try to recursively widen the mask expression tree to the target type.
+  return widenMaskTree(Cond, ToMaskVT);
 }
 
 SDValue DAGTypeLegalizer::WidenVecRes_Select(SDNode *N) {
diff --git a/llvm/test/CodeGen/AArch64/arm64-zip.ll b/llvm/test/CodeGen/AArch64/arm64-zip.ll
index 44411a1032dca..fac75f24e9666 100644
--- a/llvm/test/CodeGen/AArch64/arm64-zip.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-zip.ll
@@ -382,13 +382,10 @@ define <4 x float> @shuffle_zip1(<4 x float> %arg) {
 ; CHECK-LABEL: shuffle_zip1:
 ; CHECK:       // %bb.0: // %bb
 ; CHECK-NEXT:    fcmgt.4s v0, v0, #0.0
-; CHECK-NEXT:    uzp1.8h v1, v0, v0
-; CHECK-NEXT:    xtn.4h v0, v0
-; CHECK-NEXT:    xtn.4h v1, v1
-; CHECK-NEXT:    zip2.4h v0, v0, v1
 ; CHECK-NEXT:    fmov.4s v1, #1.00000000
-; CHECK-NEXT:    zip1.4h v0, v0, v0
-; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    uzp1.4s v2, v0, v0
+; CHECK-NEXT:    zip2.4s v0, v0, v2
+; CHECK-NEXT:    zip1.4s v0, v0, v0
 ; CHECK-NEXT:    and.16b v0, v0, v1
 ; CHECK-NEXT:    ret
 bb:
@@ -403,13 +400,10 @@ define <4 x i32> @shuffle_zip2(<4 x i32> %arg) {
 ; CHECK-LABEL: shuffle_zip2:
 ; CHECK:       // %bb.0: // %bb
 ; CHECK-NEXT:    cmtst.4s v0, v0, v0
-; CHECK-NEXT:    uzp1.8h v1, v0, v0
-; CHECK-NEXT:    xtn.4h v0, v0
-; CHECK-NEXT:    xtn.4h v1, v1
-; CHECK-NEXT:    zip2.4h v0, v0, v1
 ; CHECK-NEXT:    movi.4s v1, #1
-; CHECK-NEXT:    zip1.4h v0, v0, v0
-; CHECK-NEXT:    ushll.4s v0, v0, #0
+; CHECK-NEXT:    uzp1.4s v2, v0, v0
+; CHECK-NEXT:    zip2.4s v0, v0, v2
+; CHECK-NEXT:    zip1.4s v0, v0, v0
 ; CHECK-NEXT:    and.16b v0, v0, v1
 ; CHECK-NEXT:    ret
 bb:
@@ -424,13 +418,10 @@ define <4 x i32> @shuffle_zip3(<4 x i32> %arg) {
 ; CHECK-LABEL: shuffle_zip3:
 ; CHECK:       // %bb.0: // %bb
 ; CHECK-NEXT:    cmgt.4s v0, v0, #0
-; CHECK-NEXT:    uzp1.8h v1, v0, v0
-; CHECK-NEXT:    xtn.4h v0, v0
-; CHECK-NEXT:    xtn.4h v1, v1
-; CHECK-NEXT:    zip2.4h v0, v0, v1
 ; CHECK-NEXT:    movi.4s v1, #1
-; CHECK-NEXT:    zip1.4h v0, v0, v0
-; CHECK-NEXT:    ushll.4s v0, v0, #0
+; CHECK-NEXT:    uzp1.4s v2, v0, v0
+; CHECK-NEXT:    zip2.4s v0, v0, v2
+; CHECK-NEXT:    zip1.4s v0, v0, v0
 ; CHECK-NEXT:    and.16b v0, v0, v1
 ; CHECK-NEXT:    ret
 bb:
diff --git a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
index c459c78c3e5df..2d9712c276594 100644
--- a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
+++ b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
@@ -23,13 +23,11 @@ define <4 x i32> @freeze_or_setcc(<4 x i32> %a, <4 x i32> %b, <4 x i32> %x, <4 x
 define <4 x i32> @select_allones_or_setcc(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
 ; CHECK-LABEL: select_allones_or_setcc:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    cmgt.4s v0, v0, #0
 ; CHECK-NEXT:    tst w0, #0x1
+; CHECK-NEXT:    cmgt.4s v0, v0, #0
 ; CHECK-NEXT:    csetm w8, ne
-; CHECK-NEXT:    dup.4h v3, w8
-; CHECK-NEXT:    xtn.4h v0, v0
-; CHECK-NEXT:    orr.8b v0, v0, v3
-; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    dup.4s v3, w8
+; CHECK-NEXT:    orr.16b v0, v0, v3
 ; CHECK-NEXT:    bsl.16b v0, v1, v2
 ; CHECK-NEXT:    ret
   %cmp = icmp sgt <4 x i32> %a, zeroinitializer
@@ -41,13 +39,11 @@ define <4 x i32> @select_allones_or_setcc(<4 x i32> %a, <4 x i32> %x, <4 x i32>
 define <4 x i32> @select_setcc_or_allzeros(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
 ; CHECK-LABEL: select_setcc_or_allzeros:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    cmgt.4s v0, v0, #0
 ; CHECK-NEXT:    tst w0, #0x1
+; CHECK-NEXT:    cmgt.4s v0, v0, #0
 ; CHECK-NEXT:    csetm w8, ne
-; CHECK-NEXT:    dup.4h v3, w8
-; CHECK-NEXT:    xtn.4h v0, v0
-; CHECK-NEXT:    and.8b v0, v0, v3
-; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    dup.4s v3, w8
+; CHECK-NEXT:    and.16b v0, v0, v3
 ; CHECK-NEXT:    bsl.16b v0, v1, v2
 ; CHECK-NEXT:    ret
   %cmp = icmp sgt <4 x i32> %a, zeroinitializer
@@ -61,10 +57,8 @@ define <4 x i32> @select_allzeros_or_allones(<4 x i32> %x, <4 x i32> %y, i1 %con
 ; CHECK:       ; %bb.0:
 ; CHECK-NEXT:    tst w0, #0x1
 ; CHECK-NEXT:    csetm w8, ne
-; CHECK-NEXT:    dup.4h v2, w8
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    sshll.4s v2, v2, #0
-; CHECK-NEXT:    bif.16b v0, v1, v2
+; CHECK-NEXT:    dup.4s v2, w8
+; CHECK-NEXT:    bit.16b v0, v1, v2
 ; CHECK-NEXT:    ret
   %mask = select i1 %cond, <4 x i1> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>
   %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
@@ -78,13 +72,7 @@ define <4 x i32> @vselect_of_setccs(<4 x i32> %a, <4 x i32> %b, <4 x i32> %x, <4
 ; CHECK-NEXT:    cmgt.4s v5, v0, #0
 ; CHECK-NEXT:    cmeq.4s v1, v1, #0
 ; CHECK-NEXT:    cmgt.4s v0, v4, v0
-; CHECK-NEXT:    xtn.4h v4, v1
-; CHECK-NEXT:    and.16b v1, v5, v1
-; CHECK-NEXT:    xtn.4h v0, v0
-; CHECK-NEXT:    xtn.4h v1, v1
-; CHECK-NEXT:    bic.8b v0, v0, v4
-; CHECK-NEXT:    orr.8b v0, v1, v0
-; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    bit.16b v0, v5, v1
 ; CHECK-NEXT:    bsl.16b v0, v2, v3
 ; CHECK-NEXT:    ret
   %cmp1 = icmp sgt <4 x i32> %a, zeroinitializer
@@ -99,15 +87,12 @@ define <4 x i32> @select_scalar_cond_setccs(<4 x i32> %a, <4 x i32> %x, <4 x i32
 ; CHECK-LABEL: select_scalar_cond_setccs:
 ; CHECK:       ; %bb.0: ; %entry
 ; CHECK-NEXT:    movi.4s v3, #100
-; CHECK-NEXT:    cmgt.4s v4, v0, #0
 ; CHECK-NEXT:    tst w0, #0x1
+; CHECK-NEXT:    cmgt.4s v4, v0, #0
 ; CHECK-NEXT:    csetm w8, ne
 ; CHECK-NEXT:    cmgt.4s v0, v3, v0
-; CHECK-NEXT:    xtn.4h v3, v4
-; CHECK-NEXT:    dup.4h v4, w8
-; CHECK-NEXT:    xtn.4h v0, v0
-; CHECK-NEXT:    bif.8b v0, v3, v4
-; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    dup.4s v3, w8
+; CHECK-NEXT:    bif.16b v0, v4, v3
 ; CHECK-NEXT:    bsl.16b v0, v1, v2
 ; CHECK-NEXT:    ret
 entry:
@@ -139,17 +124,21 @@ define <3 x i64> @or_setcc_i16_i32_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i32>
 ; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
 ; CHECK-NEXT:    mov.d v4[1], v5[0]
 ; CHECK-NEXT:    sshll.4s v0, v0, #0
-; CHECK-NEXT:    orr.16b v1, v0, v1
-; CHECK-NEXT:    ldp d0, d2, [sp]
-; CHECK-NEXT:    mov.d v7[1], v0[0]
-; CHECK-NEXT:    sshll.2d v0, v1, #0
-; CHECK-NEXT:    sshll2.2d v1, v1, #0
-; CHECK-NEXT:    bit.16b v2, v6, v1
+; CHECK-NEXT:    ext.16b v2, v1, v1, #8
+; CHECK-NEXT:    ext.16b v3, v0, v0, #8
+; CHECK-NEXT:    orr.8b v0, v0, v1
+; CHECK-NEXT:    ldp d1, d16, [sp]
+; CHECK-NEXT:    sshll.2d v0, v0, #0
+; CHECK-NEXT:    mov.d v7[1], v1[0]
+; CHECK-NEXT:    orr.8b v1, v3, v2
+; CHECK-NEXT:    sshll.2d v1, v1, #0
 ; CHECK-NEXT:    bsl.16b v0, v4, v7
-; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT:    mov.16b v2, v1
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    bsl.16b v2, v6, v16
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ret
   %cmp0 = icmp sgt <3 x i16> %a, %b
   %cmp1 = icmp sgt <3 x i32> %c, %d
@@ -168,17 +157,21 @@ define <3 x i64> @and_setcc_i32_i32_sel_i64(<3 x i32> %a, <3 x i32> %b, <3 x i32
 ; CHECK-NEXT:    ; kill: def $d5 killed $d5 def $q5
 ; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
 ; CHECK-NEXT:    mov.d v4[1], v5[0]
-; CHECK-NEXT:    and.16b v1, v0, v2
-; CHECK-NEXT:    ldp d0, d2, [sp]
-; CHECK-NEXT:    mov.d v7[1], v0[0]
-; CHECK-NEXT:    sshll.2d v0, v1, #0
-; CHECK-NEXT:    sshll2.2d v1, v1, #0
-; CHECK-NEXT:    bit.16b v2, v6, v1
+; CHECK-NEXT:    ext.16b v1, v2, v2, #8
+; CHECK-NEXT:    ext.16b v3, v0, v0, #8
+; CHECK-NEXT:    and.8b v0, v0, v2
+; CHECK-NEXT:    ldp d2, d16, [sp]
+; CHECK-NEXT:    sshll.2d v0, v0, #0
+; CHECK-NEXT:    and.8b v1, v3, v1
+; CHECK-NEXT:    mov.d v7[1], v2[0]
+; CHECK-NEXT:    sshll.2d v1, v1, #0
 ; CHECK-NEXT:    bsl.16b v0, v4, v7
-; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT:    mov.16b v2, v1
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    bsl.16b v2, v6, v16
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ret
   %cmp0 = icmp sgt <3 x i32> %a, %b
   %cmp1 = icmp sgt <3 x i32> %c, %d
@@ -197,18 +190,23 @@ define <3 x i64> @or_setcc_i16_i16_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i16>
 ; CHECK-NEXT:    ; kill: def $d5 killed $d5 def $q5
 ; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
 ; CHECK-NEXT:    mov.d v4[1], v5[0]
-; CHECK-NEXT:    orr.8b v0, v0, v2
-; CHECK-NEXT:    sshll.4s v1, v0, #0
-; CHECK-NEXT:    ldp d0, d2, [sp]
-; CHECK-NEXT:    mov.d v7[1], v0[0]
-; CHECK-NEXT:    sshll.2d v0, v1, #0
-; CHECK-NEXT:    sshll2.2d v1, v1, #0
-; CHECK-NEXT:    bit.16b v2, v6, v1
+; CHECK-NEXT:    sshll.4s v1, v2, #0
+; CHECK-NEXT:    sshll.4s v0, v0, #0
+; CHECK-NEXT:    ext.16b v2, v1, v1, #8
+; CHECK-NEXT:    ext.16b v3, v0, v0, #8
+; CHECK-NEXT:    orr.8b v0, v0, v1
+; CHECK-NEXT:    ldp d1, d16, [sp]
+; CHECK-NEXT:    sshll.2d v0, v0, #0
+; CHECK-NEXT:    mov.d v7[1], v1[0]
+; CHECK-NEXT:    orr.8b v1, v3, v2
+; CHECK-NEXT:    sshll.2d v1, v1, #0
 ; CHECK-NEXT:    bsl.16b v0, v4, v7
-; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT:    mov.16b v2, v1
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    bsl.16b v2, v6, v16
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ret
   %cmp0 = icmp sgt <3 x i16> %a, %b
   %cmp1 = icmp sgt <3 x i16> %c, %d
diff --git a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
index 474be4465d9b7..19819e3ffbe43 100644
--- a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
+++ b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
@@ -660,32 +660,30 @@ define <8 x i32> @PR157382(ptr %p0, ptr %p1, ptr %p2) {
 ; SSE2-SSSE3:       # %bb.0:
 ; SSE2-SSSE3-NEXT:    movdqu (%rdi), %xmm3
 ; SSE2-SSSE3-NEXT:    movdqu 16(%rdi), %xmm2
-; SSE2-SSSE3-NEXT:    movdqu (%rsi), %xmm0
+; SSE2-SSSE3-NEXT:    movdqu (%rsi), %xmm5
 ; SSE2-SSSE3-NEXT:    movdqu 16(%rsi), %xmm4
 ; SSE2-SSSE3-NEXT:    movq {{.*#+}} xmm1 = mem[0],zero
-; SSE2-SSSE3-NEXT:    pxor %xmm5, %xmm5
+; SSE2-SSSE3-NEXT:    pxor %xmm0, %xmm0
 ; SSE2-SSSE3-NEXT:    pxor %xmm6, %xmm6
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm3, %xmm6
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm2, %xmm6
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm7, %xmm7
 ; SSE2-SSSE3-NEXT:    pxor %xmm7, %xmm6
 ; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm8
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm2, %xmm8
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm3, %xmm8
 ; SSE2-SSSE3-NEXT:    pxor %xmm7, %xmm8
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm5, %xmm0
-; SSE2-SSSE3-NEXT:    por %xmm6, %xmm0
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm5, %xmm4
-; SSE2-SSSE3-NEXT:    por %xmm8, %xmm4
-; SSE2-SSSE3-NEXT:    packssdw %xmm4, %xmm0
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm0, %xmm4
+; SSE2-SSSE3-NEXT:    por %xmm6, %xmm4
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm0, %xmm5
+; SSE2-SSSE3-NEXT:    por %xmm8, %xmm5
 ; SSE2-SSSE3-NEXT:    punpcklbw {{.*#+}} xmm1 = xmm1[0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7]
-; SSE2-SSSE3-NEXT:    pcmpeqb %xmm5, %xmm1
+; SSE2-SSSE3-NEXT:    pcmpeqb %xmm0, %xmm1
 ; SSE2-SSSE3-NEXT:    pxor %xmm7, %xmm1
-; SSE2-SSSE3-NEXT:    por %xmm0, %xmm1
+; SSE2-SSSE3-NEXT:    movdqa %xmm1, %xmm0
 ; SSE2-SSSE3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; SSE2-SSSE3-NEXT:    psrad $16, %xmm0
-; SSE2-SSSE3-NEXT:    pand %xmm3, %xmm0
+; SSE2-SSSE3-NEXT:    por %xmm5, %xmm0
 ; SSE2-SSSE3-NEXT:    punpckhwd {{.*#+}} xmm1 = xmm1[4,4,5,5,6,6,7,7]
-; SSE2-SSSE3-NEXT:    pslld $31, %xmm1
-; SSE2-SSSE3-NEXT:    psrad $31, %xmm1
+; SSE2-SSSE3-NEXT:    por %xmm4, %xmm1
+; SSE2-SSSE3-NEXT:    pand %xmm3, %xmm0
 ; SSE2-SSSE3-NEXT:    pand %xmm2, %xmm1
 ; SSE2-SSSE3-NEXT:    retq
 ;
@@ -693,28 +691,27 @@ define <8 x i32> @PR157382(ptr %p0, ptr %p1, ptr %p2) {
 ; AVX1:       # %bb.0:
 ; AVX1-NEXT:    vmovdqu (%rdi), %ymm0
 ; AVX1-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
-; AVX1-NEXT:    vpxor %xmm2, %xmm2, %xmm2
-; AVX1-NEXT:    vpcmpgtd %xmm0, %xmm2, %xmm3
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm2
+; AVX1-NEXT:    vpxor %xmm3, %xmm3, %xmm3
+; AVX1-NEXT:    vpcmpgtd %xmm2, %xmm3, %xmm2
 ; AVX1-NEXT:    vpcmpeqd %xmm4, %xmm4, %xmm4
-; AVX1-NEXT:    vpxor %xmm4, %xmm3, %xmm3
-; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm5
-; AVX1-NEXT:    vpcmpgtd %xmm5, %xmm2, %xmm5
+; AVX1-NEXT:    vpxor %xmm4, %xmm2, %xmm2
+; AVX1-NEXT:    vpcmpgtd %xmm0, %xmm3, %xmm5
 ; AVX1-NEXT:    vpxor %xmm4, %xmm5, %xmm5
-; AVX1-NEXT:    vmovdqu (%rsi), %xmm6
-; AVX1-NEXT:    vmovdqu 16(%rsi), %xmm7
-; AVX1-NEXT:    vpcmpgtd %xmm2, %xmm6, %xmm6
-; AVX1-NEXT:    vpor %xmm6, %xmm3, %xmm3
-; AVX1-NEXT:    vpcmpgtd %xmm2, %xmm7, %xmm6
-; AVX1-NEXT:    vpor %xmm6, %xmm5, %xmm5
-; AVX1-NEXT:    vpackssdw %xmm5, %xmm3, %xmm3
-; AVX1-NEXT:    vpcmpeqb %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm5, %ymm2
+; AVX1-NEXT:    vmovdqu (%rsi), %xmm5
+; AVX1-NEXT:    vmovdqu 16(%rsi), %xmm6
+; AVX1-NEXT:    vpcmpgtd %xmm3, %xmm6, %xmm6
+; AVX1-NEXT:    vpcmpgtd %xmm3, %xmm5, %xmm5
+; AVX1-NEXT:    vinsertf128 $1, %xmm6, %ymm5, %ymm5
+; AVX1-NEXT:    vorps %ymm5, %ymm2, %ymm2
+; AVX1-NEXT:    vpcmpeqb %xmm3, %xmm1, %xmm1
 ; AVX1-NEXT:    vpxor %xmm4, %xmm1, %xmm1
-; AVX1-NEXT:    vpmovsxbw %xmm1, %xmm1
-; AVX1-NEXT:    vpor %xmm1, %xmm3, %xmm1
-; AVX1-NEXT:    vpmovsxwd %xmm1, %xmm2
-; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX1-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm1
+; AVX1-NEXT:    vpmovsxbd %xmm1, %xmm3
+; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[1,1,1,1]
+; AVX1-NEXT:    vpmovsxbd %xmm1, %xmm1
+; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm3, %ymm1
+; AVX1-NEXT:    vorps %ymm1, %ymm2, %ymm1
 ; AVX1-NEXT:    vandps %ymm0, %ymm1, %ymm0
 ; AVX1-NEXT:    retq
 ;

>From 1c9ea2b89f2c0aaa2f3e4d04965f881aa9142b5a Mon Sep 17 00:00:00 2001
From: Valeriy Savchenko <vsavchenko at apple.com>
Date: Wed, 25 Mar 2026 15:39:23 +0000
Subject: [PATCH 3/3] [SelectionDAG] Restore the pre-existing logic of type
 selection

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |  15 +-
 .../SelectionDAG/LegalizeVectorTypes.cpp      | 166 +++++++++++-------
 .../AArch64/vselect-widen-mask-tree.ll        |  57 +++---
 3 files changed, 138 insertions(+), 100 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 4d69c0e8bcde8..2537672fbdbf2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1145,13 +1145,22 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   /// By default, the vector will be widened with undefined values.
   SDValue ModifyToType(SDValue InOp, EVT NVT, bool FillWithZeroes = false);
 
+  /// Adjust element width (sign-extend/truncate) and element count
+  /// (extract/concat) of Mask to match ToMaskVT.
+  SDValue adjustMaskToType(SDValue Mask, EVT ToMaskVT);
+
+  /// Sign-extend/truncate Op0 and Op1 (equal element counts) to a common mask
+  /// VT chosen to minimize ext/trunc overhead towards ToVT; returns that VT.
+  EVT unifyMaskTypes(SDValue &Op0, SDValue &Op1, EVT ToVT);
+
   /// Return a mask of vector type MaskVT to replace InMask. Also adjust
   /// MaskVT to ToMaskVT if needed with vector extension or truncation.
   SDValue convertMask(SDValue InMask, EVT MaskVT, EVT ToMaskVT);

-  /// Recursively widen a mask expression tree to ToVT. Handles SETCC,
-  /// logical ops (AND/OR/XOR), VECTOR_SHUFFLE, FREEZE, SELECT/VSELECT,
-  /// and constant BUILD_VECTORs.
+  /// Recursively widen a mask expression tree to ToVT, walking through
+  /// mask-preserving operations down to SETCC leaves. Interior nodes keep an
+  /// intermediate mask type and only the root is adjusted to ToVT, avoiding
+  /// redundant ext/trunc chains. Returns SDValue() if it cannot be widened.
   SDValue widenMaskTree(SDValue V, EVT ToVT, unsigned Depth = 0);
 
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 0b7ab61cf3df8..95405fdc3126c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -6824,9 +6824,14 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
     Mask = DAG.getNode(InMask->getOpcode(), SDLoc(InMask), MaskVT, Ops,
                        InMask->getFlags());
 
-  // If MaskVT has smaller or bigger elements than ToMaskVT, a vector sign
-  // extend or truncate is needed.
+  return adjustMaskToType(Mask, ToMaskVT);
+}
+
+// Sign-extend/truncate Mask's elements to ToMaskVT's scalar width and
+// extract/concat subvectors so its element count matches ToMaskVT.
+SDValue DAGTypeLegalizer::adjustMaskToType(SDValue Mask, EVT ToMaskVT) {
   LLVMContext &Ctx = *DAG.getContext();
+  EVT MaskVT = Mask.getValueType();
   unsigned MaskScalarBits = MaskVT.getScalarSizeInBits();
   unsigned ToMaskScalBits = ToMaskVT.getScalarSizeInBits();
   if (MaskScalarBits < ToMaskScalBits) {
@@ -6861,80 +6866,117 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
   return Mask;
 }
 
-// Recursively widen a mask expression tree to ToVT. Handles any chain of
-// mask-preserving operations rooted at SETCC nodes. Returns SDValue() if
-// the tree cannot be widened.
+// Adjust both operands to a common mask type whose scalar width is ToVT's
+// width clamped to the operands' [narrow, wide] range, minimizing ext/trunc.
+EVT DAGTypeLegalizer::unifyMaskTypes(SDValue &Op0, SDValue &Op1, EVT ToVT) {
+  assert(Op0.getValueType().getVectorNumElements() ==
+             Op1.getValueType().getVectorNumElements() &&
+         "unifyMaskTypes only handles scalar width differences");
+  unsigned Bits0 = Op0.getValueType().getScalarSizeInBits();
+  unsigned Bits1 = Op1.getValueType().getScalarSizeInBits();
+  unsigned NarrowBits = std::min(Bits0, Bits1);
+  unsigned WideBits = std::max(Bits0, Bits1);
+  unsigned ToBits = ToVT.getScalarSizeInBits();
+  unsigned IntBits = NarrowBits == WideBits ? NarrowBits
+                     : ToBits >= WideBits   ? WideBits
+                     : ToBits <= NarrowBits ? NarrowBits
+                                            : ToBits;
+  EVT OpVT = EVT::getVectorVT(*DAG.getContext(), MVT::getIntegerVT(IntBits),
+                              Op0.getValueType().getVectorNumElements());
+  Op0 = adjustMaskToType(Op0, OpVT);
+  Op1 = adjustMaskToType(Op1, OpVT);
+  return OpVT;
+}
+
 SDValue DAGTypeLegalizer::widenMaskTree(SDValue V, EVT ToVT, unsigned Depth) {
   if (Depth >= DAG.MaxRecursionDepth)
     return SDValue();

-  unsigned Opcode = V.getOpcode();
+  SDValue Result = [&]() -> SDValue {
+    unsigned Opcode = V.getOpcode();

-  // Base case: SETCC produces the mask directly.
-  if (isSETCCOp(Opcode)) {
-    EVT MaskVT = getSetCCResultType(getSETCCOperandType(V));
-    return convertMask(V, MaskVT, ToVT);
-  }
+    // Base case: SETCC at its natural mask type; the root adjusts to ToVT.
+    if (isSETCCOp(Opcode)) {
+      EVT MaskVT = getSetCCResultType(getSETCCOperandType(V));
+      return convertMask(V, MaskVT, MaskVT);
+    }

-  // Base case: all-zeros or all-ones BUILD_VECTOR.
-  if (ISD::isBuildVectorAllZeros(V.getNode()))
-    return DAG.getConstant(0, SDLoc(V), ToVT);
-  if (ISD::isBuildVectorAllOnes(V.getNode()))
-    return DAG.getAllOnesConstant(SDLoc(V), ToVT);
+    // Base case: all-zero/all-one splats are sext/trunc-invariant, so emit
+    // them as ToVT. NOTE(review): assumes ToVT has V's element count.
+    if (ISD::isBuildVectorAllZeros(V.getNode()))
+      return DAG.getConstant(0, SDLoc(V), ToVT);
+    if (ISD::isBuildVectorAllOnes(V.getNode()))
+      return DAG.getAllOnesConstant(SDLoc(V), ToVT);

-  SDLoc DL(V);
+    SDLoc DL(V);

-  // Logical operations (AND/OR/XOR): widen both operands and rebuild.
-  if (isLogicalMaskOp(Opcode)) {
-    SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
-    if (!Op0)
-      return SDValue();
-    SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
-    if (!Op1)
-      return SDValue();
-    return DAG.getNode(Opcode, DL, ToVT, Op0, Op1);
-  }
+    // Logical operations (AND/OR/XOR): widen both operands, then unify them
+    // to a common width chosen from their element widths and ToVT.
+    if (isLogicalMaskOp(Opcode)) {
+      SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+      if (!Op0)
+        return SDValue();
+      SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+      if (!Op1)
+        return SDValue();
+      EVT OpVT = unifyMaskTypes(Op0, Op1, ToVT);
+      return DAG.getNode(Opcode, DL, OpVT, Op0, Op1);
+    }

-  // FREEZE: widen the operand and re-wrap.
-  if (Opcode == ISD::FREEZE) {
-    SDValue Inner = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
-    if (!Inner)
-      return SDValue();
-    return DAG.getNode(ISD::FREEZE, DL, ToVT, Inner);
-  }
+    // FREEZE: widen the operand and re-wrap.
+    if (Opcode == ISD::FREEZE) {
+      SDValue Inner = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+      if (!Inner)
+        return SDValue();
+      return DAG.getNode(ISD::FREEZE, DL, Inner.getValueType(), Inner);
+    }

-  // Vector shuffle: widen inputs and apply the same mask.
-  if (Opcode == ISD::VECTOR_SHUFFLE) {
-    auto *Shuf = cast<ShuffleVectorSDNode>(V);
-    SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
-    if (!Op0)
-      return SDValue();
-    SDValue Op1 = V.getOperand(1).isUndef()
-                      ? DAG.getUNDEF(ToVT)
-                      : widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
-    if (!Op1)
-      return SDValue();
-    return DAG.getVectorShuffle(ToVT, DL, Op0, Op1, Shuf->getMask());
-  }
+    // Vector shuffle: keep Op0's type if Op1 is undef, else unify operands.
+    if (Opcode == ISD::VECTOR_SHUFFLE) {
+      auto *Shuf = cast<ShuffleVectorSDNode>(V);
+      SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+      if (!Op0)
+        return SDValue();
+      if (V.getOperand(1).isUndef()) {
+        EVT OpVT = Op0.getValueType();
+        return DAG.getVectorShuffle(OpVT, DL, Op0, DAG.getUNDEF(OpVT),
+                                    Shuf->getMask());
+      }
+      SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+      if (!Op1)
+        return SDValue();
+      EVT OpVT = unifyMaskTypes(Op0, Op1, ToVT);
+      return DAG.getVectorShuffle(OpVT, DL, Op0, Op1, Shuf->getMask());
+    }

-  // SELECT/VSELECT: widen both true and false mask values.
-  if (Opcode == ISD::SELECT || Opcode == ISD::VSELECT) {
-    SDValue Cond = V.getOperand(0);
-    if (Cond.getValueType().isVector()) {
-      Cond = widenMaskTree(Cond, ToVT, Depth + 1);
-      if (!Cond)
+    // SELECT/VSELECT: unify both value operands, then fit VSELECT's Cond too.
+    if (Opcode == ISD::SELECT || Opcode == ISD::VSELECT) {
+      SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+      if (!Op1)
         return SDValue();
+      SDValue Op2 = widenMaskTree(V.getOperand(2), ToVT, Depth + 1);
+      if (!Op2)
+        return SDValue();
+      EVT OpVT = unifyMaskTypes(Op1, Op2, ToVT);
+
+      SDValue Cond = V.getOperand(0);
+      if (Opcode == ISD::VSELECT) {
+        Cond = widenMaskTree(Cond, ToVT, Depth + 1);
+        if (!Cond)
+          return SDValue();
+        Cond = adjustMaskToType(Cond, OpVT);
+      }
+      return DAG.getNode(Opcode, DL, OpVT, Cond, Op1, Op2);
     }
-    SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
-    if (!Op1)
-      return SDValue();
-    SDValue Op2 = widenMaskTree(V.getOperand(2), ToVT, Depth + 1);
-    if (!Op2)
-      return SDValue();
-    return DAG.getNode(Opcode, DL, ToVT, Cond, Op1, Op2);
-  }

-  return SDValue();
+    return SDValue();
+  }();
+
+  if (!Result)
+    return SDValue();
+  if (Depth == 0)
+    Result = adjustMaskToType(Result, ToVT);
+  return Result;
 }
 
 // This method tries to handle some special cases for the vselect mask
diff --git a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
index 2d9712c276594..b1454cfb1c5b1 100644
--- a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
+++ b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
@@ -124,21 +124,17 @@ define <3 x i64> @or_setcc_i16_i32_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i32>
 ; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
 ; CHECK-NEXT:    mov.d v4[1], v5[0]
 ; CHECK-NEXT:    sshll.4s v0, v0, #0
-; CHECK-NEXT:    ext.16b v2, v1, v1, #8
-; CHECK-NEXT:    ext.16b v3, v0, v0, #8
-; CHECK-NEXT:    orr.8b v0, v0, v1
-; CHECK-NEXT:    ldp d1, d16, [sp]
-; CHECK-NEXT:    sshll.2d v0, v0, #0
-; CHECK-NEXT:    mov.d v7[1], v1[0]
-; CHECK-NEXT:    orr.8b v1, v3, v2
-; CHECK-NEXT:    sshll.2d v1, v1, #0
+; CHECK-NEXT:    orr.16b v1, v0, v1
+; CHECK-NEXT:    ldp d0, d2, [sp]
+; CHECK-NEXT:    mov.d v7[1], v0[0]
+; CHECK-NEXT:    sshll.2d v0, v1, #0
+; CHECK-NEXT:    sshll2.2d v1, v1, #0
+; CHECK-NEXT:    bit.16b v2, v6, v1
 ; CHECK-NEXT:    bsl.16b v0, v4, v7
-; CHECK-NEXT:    mov.16b v2, v1
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
-; CHECK-NEXT:    bsl.16b v2, v6, v16
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
-; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ret
   %cmp0 = icmp sgt <3 x i16> %a, %b
   %cmp1 = icmp sgt <3 x i32> %c, %d
@@ -157,21 +153,17 @@ define <3 x i64> @and_setcc_i32_i32_sel_i64(<3 x i32> %a, <3 x i32> %b, <3 x i32
 ; CHECK-NEXT:    ; kill: def $d5 killed $d5 def $q5
 ; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
 ; CHECK-NEXT:    mov.d v4[1], v5[0]
-; CHECK-NEXT:    ext.16b v1, v2, v2, #8
-; CHECK-NEXT:    ext.16b v3, v0, v0, #8
-; CHECK-NEXT:    and.8b v0, v0, v2
-; CHECK-NEXT:    ldp d2, d16, [sp]
-; CHECK-NEXT:    sshll.2d v0, v0, #0
-; CHECK-NEXT:    and.8b v1, v3, v1
-; CHECK-NEXT:    mov.d v7[1], v2[0]
-; CHECK-NEXT:    sshll.2d v1, v1, #0
+; CHECK-NEXT:    and.16b v1, v0, v2
+; CHECK-NEXT:    ldp d0, d2, [sp]
+; CHECK-NEXT:    mov.d v7[1], v0[0]
+; CHECK-NEXT:    sshll.2d v0, v1, #0
+; CHECK-NEXT:    sshll2.2d v1, v1, #0
+; CHECK-NEXT:    bit.16b v2, v6, v1
 ; CHECK-NEXT:    bsl.16b v0, v4, v7
-; CHECK-NEXT:    mov.16b v2, v1
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
-; CHECK-NEXT:    bsl.16b v2, v6, v16
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
-; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ret
   %cmp0 = icmp sgt <3 x i32> %a, %b
   %cmp1 = icmp sgt <3 x i32> %c, %d
@@ -190,23 +182,18 @@ define <3 x i64> @or_setcc_i16_i16_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i16>
 ; CHECK-NEXT:    ; kill: def $d5 killed $d5 def $q5
 ; CHECK-NEXT:    ; kill: def $d6 killed $d6 def $q6
 ; CHECK-NEXT:    mov.d v4[1], v5[0]
-; CHECK-NEXT:    sshll.4s v1, v2, #0
-; CHECK-NEXT:    sshll.4s v0, v0, #0
-; CHECK-NEXT:    ext.16b v2, v1, v1, #8
-; CHECK-NEXT:    ext.16b v3, v0, v0, #8
-; CHECK-NEXT:    orr.8b v0, v0, v1
-; CHECK-NEXT:    ldp d1, d16, [sp]
-; CHECK-NEXT:    sshll.2d v0, v0, #0
-; CHECK-NEXT:    mov.d v7[1], v1[0]
-; CHECK-NEXT:    orr.8b v1, v3, v2
-; CHECK-NEXT:    sshll.2d v1, v1, #0
+; CHECK-NEXT:    orr.8b v0, v0, v2
+; CHECK-NEXT:    sshll.4s v1, v0, #0
+; CHECK-NEXT:    ldp d0, d2, [sp]
+; CHECK-NEXT:    mov.d v7[1], v0[0]
+; CHECK-NEXT:    sshll.2d v0, v1, #0
+; CHECK-NEXT:    sshll2.2d v1, v1, #0
+; CHECK-NEXT:    bit.16b v2, v6, v1
 ; CHECK-NEXT:    bsl.16b v0, v4, v7
-; CHECK-NEXT:    mov.16b v2, v1
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
-; CHECK-NEXT:    bsl.16b v2, v6, v16
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
-; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ret
   %cmp0 = icmp sgt <3 x i16> %a, %b
   %cmp1 = icmp sgt <3 x i16> %c, %d



More information about the llvm-commits mailing list