[llvm] [SelectionDAG] Recurse through mask expression trees in WidenVSELECTMask (PR #188085)
Valeriy Savchenko via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 1 04:43:44 PDT 2026
https://github.com/SavchenkoValeriy updated https://github.com/llvm/llvm-project/pull/188085
>From b92b34f4e595e197d2b6bcfd8f235cde95a3df1c Mon Sep 17 00:00:00 2001
From: Valeriy Savchenko <vsavchenko at apple.com>
Date: Mon, 23 Mar 2026 17:51:12 +0000
Subject: [PATCH 1/3] [AArch64][NFC] Add tests for mask widening
---
.../AArch64/vselect-widen-mask-tree.ll | 218 ++++++++++++++++++
1 file changed, 218 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
diff --git a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
new file mode 100644
index 0000000000000..c459c78c3e5df
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
@@ -0,0 +1,218 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=arm64-apple-ios -o - %s | FileCheck %s
+
+define <4 x i32> @freeze_or_setcc(<4 x i32> %a, <4 x i32> %b, <4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: freeze_or_setcc:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: add.4s v4, v0, v1
+; CHECK-NEXT: cmgt.4s v0, v1, v0
+; CHECK-NEXT: cmgt.4s v4, v4, #0
+; CHECK-NEXT: orr.16b v0, v4, v0
+; CHECK-NEXT: bsl.16b v0, v2, v3
+; CHECK-NEXT: ret
+ %add = add nsw <4 x i32> %a, %b
+ %cmp1 = icmp sgt <4 x i32> %add, zeroinitializer
+ %sub = sub nsw <4 x i32> %a, %b
+ %cmp2 = icmp slt <4 x i32> %sub, zeroinitializer
+ %or = or <4 x i1> %cmp1, %cmp2
+ %fr = freeze <4 x i1> %or
+ %sel = select <4 x i1> %fr, <4 x i32> %x, <4 x i32> %y
+ ret <4 x i32> %sel
+}
+
+define <4 x i32> @select_allones_or_setcc(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
+; CHECK-LABEL: select_allones_or_setcc:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: cmgt.4s v0, v0, #0
+; CHECK-NEXT: tst w0, #0x1
+; CHECK-NEXT: csetm w8, ne
+; CHECK-NEXT: dup.4h v3, w8
+; CHECK-NEXT: xtn.4h v0, v0
+; CHECK-NEXT: orr.8b v0, v0, v3
+; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: bsl.16b v0, v1, v2
+; CHECK-NEXT: ret
+ %cmp = icmp sgt <4 x i32> %a, zeroinitializer
+ %mask = select i1 %cond, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> %cmp
+ %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
+ ret <4 x i32> %sel
+}
+
+define <4 x i32> @select_setcc_or_allzeros(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
+; CHECK-LABEL: select_setcc_or_allzeros:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: cmgt.4s v0, v0, #0
+; CHECK-NEXT: tst w0, #0x1
+; CHECK-NEXT: csetm w8, ne
+; CHECK-NEXT: dup.4h v3, w8
+; CHECK-NEXT: xtn.4h v0, v0
+; CHECK-NEXT: and.8b v0, v0, v3
+; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: bsl.16b v0, v1, v2
+; CHECK-NEXT: ret
+ %cmp = icmp sgt <4 x i32> %a, zeroinitializer
+ %mask = select i1 %cond, <4 x i1> %cmp, <4 x i1> zeroinitializer
+ %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
+ ret <4 x i32> %sel
+}
+
+define <4 x i32> @select_allzeros_or_allones(<4 x i32> %x, <4 x i32> %y, i1 %cond) {
+; CHECK-LABEL: select_allzeros_or_allones:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: tst w0, #0x1
+; CHECK-NEXT: csetm w8, ne
+; CHECK-NEXT: dup.4h v2, w8
+; CHECK-NEXT: mvn.8b v2, v2
+; CHECK-NEXT: sshll.4s v2, v2, #0
+; CHECK-NEXT: bif.16b v0, v1, v2
+; CHECK-NEXT: ret
+ %mask = select i1 %cond, <4 x i1> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>
+ %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
+ ret <4 x i32> %sel
+}
+
+define <4 x i32> @vselect_of_setccs(<4 x i32> %a, <4 x i32> %b, <4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: vselect_of_setccs:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: movi.4s v4, #100
+; CHECK-NEXT: cmgt.4s v5, v0, #0
+; CHECK-NEXT: cmeq.4s v1, v1, #0
+; CHECK-NEXT: cmgt.4s v0, v4, v0
+; CHECK-NEXT: xtn.4h v4, v1
+; CHECK-NEXT: and.16b v1, v5, v1
+; CHECK-NEXT: xtn.4h v0, v0
+; CHECK-NEXT: xtn.4h v1, v1
+; CHECK-NEXT: bic.8b v0, v0, v4
+; CHECK-NEXT: orr.8b v0, v1, v0
+; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: bsl.16b v0, v2, v3
+; CHECK-NEXT: ret
+ %cmp1 = icmp sgt <4 x i32> %a, zeroinitializer
+ %cmp2 = icmp slt <4 x i32> %a, <i32 100, i32 100, i32 100, i32 100>
+ %cond = icmp eq <4 x i32> %b, zeroinitializer
+ %mask = select <4 x i1> %cond, <4 x i1> %cmp1, <4 x i1> %cmp2
+ %sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
+ ret <4 x i32> %sel
+}
+
+define <4 x i32> @select_scalar_cond_setccs(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
+; CHECK-LABEL: select_scalar_cond_setccs:
+; CHECK: ; %bb.0: ; %entry
+; CHECK-NEXT: movi.4s v3, #100
+; CHECK-NEXT: cmgt.4s v4, v0, #0
+; CHECK-NEXT: tst w0, #0x1
+; CHECK-NEXT: csetm w8, ne
+; CHECK-NEXT: cmgt.4s v0, v3, v0
+; CHECK-NEXT: xtn.4h v3, v4
+; CHECK-NEXT: dup.4h v4, w8
+; CHECK-NEXT: xtn.4h v0, v0
+; CHECK-NEXT: bif.8b v0, v3, v4
+; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: bsl.16b v0, v1, v2
+; CHECK-NEXT: ret
+entry:
+ %cmp1 = icmp sgt <4 x i32> %a, zeroinitializer
+ br i1 %cond, label %then, label %else
+
+then:
+ %cmp2 = icmp slt <4 x i32> %a, <i32 100, i32 100, i32 100, i32 100>
+ br label %merge
+
+else:
+ br label %merge
+
+merge:
+ %mask = phi <4 x i1> [ %cmp2, %then ], [ %cmp1, %else ]
+ %fr = freeze <4 x i1> %mask
+ %sel = select <4 x i1> %fr, <4 x i32> %x, <4 x i32> %y
+ ret <4 x i32> %sel
+}
+
+define <3 x i64> @or_setcc_i16_i32_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i32> %c, <3 x i32> %d, <3 x i64> %x, <3 x i64> %y) {
+; CHECK-LABEL: or_setcc_i16_i32_sel_i64:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: cmgt.4h v0, v0, v1
+; CHECK-NEXT: cmgt.4s v1, v2, v3
+; CHECK-NEXT: ; kill: def $d7 killed $d7 def $q7
+; CHECK-NEXT: ; kill: def $d4 killed $d4 def $q4
+; CHECK-NEXT: ; kill: def $d5 killed $d5 def $q5
+; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
+; CHECK-NEXT: mov.d v4[1], v5[0]
+; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: orr.16b v1, v0, v1
+; CHECK-NEXT: ldp d0, d2, [sp]
+; CHECK-NEXT: mov.d v7[1], v0[0]
+; CHECK-NEXT: sshll.2d v0, v1, #0
+; CHECK-NEXT: sshll2.2d v1, v1, #0
+; CHECK-NEXT: bit.16b v2, v6, v1
+; CHECK-NEXT: bsl.16b v0, v4, v7
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT: ext.16b v1, v0, v0, #8
+; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT: ret
+ %cmp0 = icmp sgt <3 x i16> %a, %b
+ %cmp1 = icmp sgt <3 x i32> %c, %d
+ %or = or <3 x i1> %cmp0, %cmp1
+ %sel = select <3 x i1> %or, <3 x i64> %x, <3 x i64> %y
+ ret <3 x i64> %sel
+}
+
+define <3 x i64> @and_setcc_i32_i32_sel_i64(<3 x i32> %a, <3 x i32> %b, <3 x i32> %c, <3 x i32> %d, <3 x i64> %x, <3 x i64> %y) {
+; CHECK-LABEL: and_setcc_i32_i32_sel_i64:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: cmgt.4s v2, v2, v3
+; CHECK-NEXT: cmgt.4s v0, v0, v1
+; CHECK-NEXT: ; kill: def $d7 killed $d7 def $q7
+; CHECK-NEXT: ; kill: def $d4 killed $d4 def $q4
+; CHECK-NEXT: ; kill: def $d5 killed $d5 def $q5
+; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
+; CHECK-NEXT: mov.d v4[1], v5[0]
+; CHECK-NEXT: and.16b v1, v0, v2
+; CHECK-NEXT: ldp d0, d2, [sp]
+; CHECK-NEXT: mov.d v7[1], v0[0]
+; CHECK-NEXT: sshll.2d v0, v1, #0
+; CHECK-NEXT: sshll2.2d v1, v1, #0
+; CHECK-NEXT: bit.16b v2, v6, v1
+; CHECK-NEXT: bsl.16b v0, v4, v7
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT: ext.16b v1, v0, v0, #8
+; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT: ret
+ %cmp0 = icmp sgt <3 x i32> %a, %b
+ %cmp1 = icmp sgt <3 x i32> %c, %d
+ %and = and <3 x i1> %cmp0, %cmp1
+ %sel = select <3 x i1> %and, <3 x i64> %x, <3 x i64> %y
+ ret <3 x i64> %sel
+}
+
+define <3 x i64> @or_setcc_i16_i16_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i16> %c, <3 x i16> %d, <3 x i64> %x, <3 x i64> %y) {
+; CHECK-LABEL: or_setcc_i16_i16_sel_i64:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: cmgt.4h v2, v2, v3
+; CHECK-NEXT: cmgt.4h v0, v0, v1
+; CHECK-NEXT: ; kill: def $d7 killed $d7 def $q7
+; CHECK-NEXT: ; kill: def $d4 killed $d4 def $q4
+; CHECK-NEXT: ; kill: def $d5 killed $d5 def $q5
+; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
+; CHECK-NEXT: mov.d v4[1], v5[0]
+; CHECK-NEXT: orr.8b v0, v0, v2
+; CHECK-NEXT: sshll.4s v1, v0, #0
+; CHECK-NEXT: ldp d0, d2, [sp]
+; CHECK-NEXT: mov.d v7[1], v0[0]
+; CHECK-NEXT: sshll.2d v0, v1, #0
+; CHECK-NEXT: sshll2.2d v1, v1, #0
+; CHECK-NEXT: bit.16b v2, v6, v1
+; CHECK-NEXT: bsl.16b v0, v4, v7
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT: ext.16b v1, v0, v0, #8
+; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT: ret
+ %cmp0 = icmp sgt <3 x i16> %a, %b
+ %cmp1 = icmp sgt <3 x i16> %c, %d
+ %or = or <3 x i1> %cmp0, %cmp1
+ %sel = select <3 x i1> %or, <3 x i64> %x, <3 x i64> %y
+ ret <3 x i64> %sel
+}
>From 0539a4c00361118cad4897ff22ec679e3b9b42ae Mon Sep 17 00:00:00 2001
From: Valeriy Savchenko <vsavchenko at apple.com>
Date: Fri, 6 Mar 2026 16:37:00 +0000
Subject: [PATCH 2/3] [SelectionDAG] Recurse through mask expression trees in
WidenVSELECTMask
---
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 5 +
.../SelectionDAG/LegalizeVectorTypes.cpp | 129 +++++++++++-------
llvm/test/CodeGen/AArch64/arm64-zip.ll | 27 ++--
.../AArch64/vselect-widen-mask-tree.ll | 96 +++++++------
.../X86/bitcast-int-to-vector-bool-sext.ll | 65 +++++----
5 files changed, 172 insertions(+), 150 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index da592e3cad0f5..4d69c0e8bcde8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1149,6 +1149,11 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
/// MaskVT to ToMaskVT if needed with vector extension or truncation.
SDValue convertMask(SDValue InMask, EVT MaskVT, EVT ToMaskVT);
+ /// Recursively widen a mask expression tree to ToVT. Handles SETCC,
+ /// logical ops (AND/OR/XOR), VECTOR_SHUFFLE, FREEZE, SELECT/VSELECT,
+ /// and constant BUILD_VECTORs.
+ SDValue widenMaskTree(SDValue V, EVT ToVT, unsigned Depth = 0);
+
//===--------------------------------------------------------------------===//
// Generic Splitting: LegalizeTypesGeneric.cpp
//===--------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index aeb9d4d7bdc1d..0b7ab61cf3df8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -6806,9 +6806,8 @@ static inline bool isSETCCorConvertedSETCC(SDValue N) {
// to ToMaskVT if needed with vector extension or truncation.
SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
EVT ToMaskVT) {
- // Currently a SETCC or a AND/OR/XOR with two SETCCs are handled.
- // FIXME: This code seems to be too restrictive, we might consider
- // generalizing it or dropping it.
+ // Called from widenMaskTree for SETCC leaf nodes. Re-creates the SETCC with
+ // result type MaskVT, then sign-extends/truncates and pads to ToMaskVT.
assert(isSETCCorConvertedSETCC(InMask) && "Unexpected mask argument.");
// Make a new Mask node, with a legal result VT.
@@ -6862,6 +6861,82 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
return Mask;
}
+// Recursively widen a mask expression tree to ToVT. Handles any chain of
+// mask-preserving operations rooted at SETCC nodes. Returns SDValue() if
+// the tree cannot be widened.
+SDValue DAGTypeLegalizer::widenMaskTree(SDValue V, EVT ToVT, unsigned Depth) {
+ if (Depth >= DAG.MaxRecursionDepth)
+ return SDValue();
+
+ unsigned Opcode = V.getOpcode();
+
+ // Base case: SETCC produces the mask directly.
+ if (isSETCCOp(Opcode)) {
+ EVT MaskVT = getSetCCResultType(getSETCCOperandType(V));
+ return convertMask(V, MaskVT, ToVT);
+ }
+
+ // Base case: all-zeros or all-ones BUILD_VECTOR.
+ if (ISD::isBuildVectorAllZeros(V.getNode()))
+ return DAG.getConstant(0, SDLoc(V), ToVT);
+ if (ISD::isBuildVectorAllOnes(V.getNode()))
+ return DAG.getAllOnesConstant(SDLoc(V), ToVT);
+
+ SDLoc DL(V);
+
+ // Logical operations (AND/OR/XOR): widen both operands and rebuild.
+ if (isLogicalMaskOp(Opcode)) {
+ SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+ if (!Op0)
+ return SDValue();
+ SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+ if (!Op1)
+ return SDValue();
+ return DAG.getNode(Opcode, DL, ToVT, Op0, Op1);
+ }
+
+ // FREEZE: widen the operand and re-wrap.
+ if (Opcode == ISD::FREEZE) {
+ SDValue Inner = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+ if (!Inner)
+ return SDValue();
+ return DAG.getNode(ISD::FREEZE, DL, ToVT, Inner);
+ }
+
+ // Vector shuffle: widen inputs and apply the same mask.
+ if (Opcode == ISD::VECTOR_SHUFFLE) {
+ auto *Shuf = cast<ShuffleVectorSDNode>(V);
+ SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+ if (!Op0)
+ return SDValue();
+ SDValue Op1 = V.getOperand(1).isUndef()
+ ? DAG.getUNDEF(ToVT)
+ : widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+ if (!Op1)
+ return SDValue();
+ return DAG.getVectorShuffle(ToVT, DL, Op0, Op1, Shuf->getMask());
+ }
+
+ // SELECT/VSELECT: widen both true and false mask values.
+ if (Opcode == ISD::SELECT || Opcode == ISD::VSELECT) {
+ SDValue Cond = V.getOperand(0);
+ if (Cond.getValueType().isVector()) {
+ Cond = widenMaskTree(Cond, ToVT, Depth + 1);
+ if (!Cond)
+ return SDValue();
+ }
+ SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+ if (!Op1)
+ return SDValue();
+ SDValue Op2 = widenMaskTree(V.getOperand(2), ToVT, Depth + 1);
+ if (!Op2)
+ return SDValue();
+ return DAG.getNode(Opcode, DL, ToVT, Cond, Op1, Op2);
+ }
+
+ return SDValue();
+}
+
// This method tries to handle some special cases for the vselect mask
// and if needed adjusting the mask vector type to match that of the VSELECT.
// Without it, many cases end up with scalarization of the SETCC, with many
@@ -6873,9 +6948,6 @@ SDValue DAGTypeLegalizer::WidenVSELECTMask(SDNode *N) {
if (N->getOpcode() != ISD::VSELECT)
return SDValue();
- if (!isSETCCOp(Cond->getOpcode()) && !isLogicalMaskOp(Cond->getOpcode()))
- return SDValue();
-
// If this is a splitted VSELECT that was previously already handled, do
// nothing.
EVT CondVT = Cond->getValueType(0);
@@ -6928,49 +7000,8 @@ SDValue DAGTypeLegalizer::WidenVSELECTMask(SDNode *N) {
if (!ToMaskVT.getScalarType().isInteger())
ToMaskVT = ToMaskVT.changeVectorElementTypeToInteger();
- SDValue Mask;
- if (isSETCCOp(Cond->getOpcode())) {
- EVT MaskVT = getSetCCResultType(getSETCCOperandType(Cond));
- Mask = convertMask(Cond, MaskVT, ToMaskVT);
- } else if (isLogicalMaskOp(Cond->getOpcode()) &&
- isSETCCOp(Cond->getOperand(0).getOpcode()) &&
- isSETCCOp(Cond->getOperand(1).getOpcode())) {
- // Cond is (AND/OR/XOR (SETCC, SETCC))
- SDValue SETCC0 = Cond->getOperand(0);
- SDValue SETCC1 = Cond->getOperand(1);
- EVT VT0 = getSetCCResultType(getSETCCOperandType(SETCC0));
- EVT VT1 = getSetCCResultType(getSETCCOperandType(SETCC1));
- unsigned ScalarBits0 = VT0.getScalarSizeInBits();
- unsigned ScalarBits1 = VT1.getScalarSizeInBits();
- unsigned ScalarBits_ToMask = ToMaskVT.getScalarSizeInBits();
- EVT MaskVT;
- // If the two SETCCs have different VTs, either extend/truncate one of
- // them to the other "towards" ToMaskVT, or truncate one and extend the
- // other to ToMaskVT.
- if (ScalarBits0 != ScalarBits1) {
- EVT NarrowVT = ((ScalarBits0 < ScalarBits1) ? VT0 : VT1);
- EVT WideVT = ((NarrowVT == VT0) ? VT1 : VT0);
- if (ScalarBits_ToMask >= WideVT.getScalarSizeInBits())
- MaskVT = WideVT;
- else if (ScalarBits_ToMask <= NarrowVT.getScalarSizeInBits())
- MaskVT = NarrowVT;
- else
- MaskVT = ToMaskVT;
- } else
- // If the two SETCCs have the same VT, don't change it.
- MaskVT = VT0;
-
- // Make new SETCCs and logical nodes.
- SETCC0 = convertMask(SETCC0, VT0, MaskVT);
- SETCC1 = convertMask(SETCC1, VT1, MaskVT);
- Cond = DAG.getNode(Cond->getOpcode(), SDLoc(Cond), MaskVT, SETCC0, SETCC1);
-
- // Convert the logical op for VSELECT if needed.
- Mask = convertMask(Cond, MaskVT, ToMaskVT);
- } else
- return SDValue();
-
- return Mask;
+ // Try to recursively widen the mask expression tree to the target type.
+ return widenMaskTree(Cond, ToMaskVT);
}
SDValue DAGTypeLegalizer::WidenVecRes_Select(SDNode *N) {
diff --git a/llvm/test/CodeGen/AArch64/arm64-zip.ll b/llvm/test/CodeGen/AArch64/arm64-zip.ll
index 44411a1032dca..fac75f24e9666 100644
--- a/llvm/test/CodeGen/AArch64/arm64-zip.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-zip.ll
@@ -382,13 +382,10 @@ define <4 x float> @shuffle_zip1(<4 x float> %arg) {
; CHECK-LABEL: shuffle_zip1:
; CHECK: // %bb.0: // %bb
; CHECK-NEXT: fcmgt.4s v0, v0, #0.0
-; CHECK-NEXT: uzp1.8h v1, v0, v0
-; CHECK-NEXT: xtn.4h v0, v0
-; CHECK-NEXT: xtn.4h v1, v1
-; CHECK-NEXT: zip2.4h v0, v0, v1
; CHECK-NEXT: fmov.4s v1, #1.00000000
-; CHECK-NEXT: zip1.4h v0, v0, v0
-; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: uzp1.4s v2, v0, v0
+; CHECK-NEXT: zip2.4s v0, v0, v2
+; CHECK-NEXT: zip1.4s v0, v0, v0
; CHECK-NEXT: and.16b v0, v0, v1
; CHECK-NEXT: ret
bb:
@@ -403,13 +400,10 @@ define <4 x i32> @shuffle_zip2(<4 x i32> %arg) {
; CHECK-LABEL: shuffle_zip2:
; CHECK: // %bb.0: // %bb
; CHECK-NEXT: cmtst.4s v0, v0, v0
-; CHECK-NEXT: uzp1.8h v1, v0, v0
-; CHECK-NEXT: xtn.4h v0, v0
-; CHECK-NEXT: xtn.4h v1, v1
-; CHECK-NEXT: zip2.4h v0, v0, v1
; CHECK-NEXT: movi.4s v1, #1
-; CHECK-NEXT: zip1.4h v0, v0, v0
-; CHECK-NEXT: ushll.4s v0, v0, #0
+; CHECK-NEXT: uzp1.4s v2, v0, v0
+; CHECK-NEXT: zip2.4s v0, v0, v2
+; CHECK-NEXT: zip1.4s v0, v0, v0
; CHECK-NEXT: and.16b v0, v0, v1
; CHECK-NEXT: ret
bb:
@@ -424,13 +418,10 @@ define <4 x i32> @shuffle_zip3(<4 x i32> %arg) {
; CHECK-LABEL: shuffle_zip3:
; CHECK: // %bb.0: // %bb
; CHECK-NEXT: cmgt.4s v0, v0, #0
-; CHECK-NEXT: uzp1.8h v1, v0, v0
-; CHECK-NEXT: xtn.4h v0, v0
-; CHECK-NEXT: xtn.4h v1, v1
-; CHECK-NEXT: zip2.4h v0, v0, v1
; CHECK-NEXT: movi.4s v1, #1
-; CHECK-NEXT: zip1.4h v0, v0, v0
-; CHECK-NEXT: ushll.4s v0, v0, #0
+; CHECK-NEXT: uzp1.4s v2, v0, v0
+; CHECK-NEXT: zip2.4s v0, v0, v2
+; CHECK-NEXT: zip1.4s v0, v0, v0
; CHECK-NEXT: and.16b v0, v0, v1
; CHECK-NEXT: ret
bb:
diff --git a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
index c459c78c3e5df..2d9712c276594 100644
--- a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
+++ b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
@@ -23,13 +23,11 @@ define <4 x i32> @freeze_or_setcc(<4 x i32> %a, <4 x i32> %b, <4 x i32> %x, <4 x
define <4 x i32> @select_allones_or_setcc(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
; CHECK-LABEL: select_allones_or_setcc:
; CHECK: ; %bb.0:
-; CHECK-NEXT: cmgt.4s v0, v0, #0
; CHECK-NEXT: tst w0, #0x1
+; CHECK-NEXT: cmgt.4s v0, v0, #0
; CHECK-NEXT: csetm w8, ne
-; CHECK-NEXT: dup.4h v3, w8
-; CHECK-NEXT: xtn.4h v0, v0
-; CHECK-NEXT: orr.8b v0, v0, v3
-; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: dup.4s v3, w8
+; CHECK-NEXT: orr.16b v0, v0, v3
; CHECK-NEXT: bsl.16b v0, v1, v2
; CHECK-NEXT: ret
%cmp = icmp sgt <4 x i32> %a, zeroinitializer
@@ -41,13 +39,11 @@ define <4 x i32> @select_allones_or_setcc(<4 x i32> %a, <4 x i32> %x, <4 x i32>
define <4 x i32> @select_setcc_or_allzeros(<4 x i32> %a, <4 x i32> %x, <4 x i32> %y, i1 %cond) {
; CHECK-LABEL: select_setcc_or_allzeros:
; CHECK: ; %bb.0:
-; CHECK-NEXT: cmgt.4s v0, v0, #0
; CHECK-NEXT: tst w0, #0x1
+; CHECK-NEXT: cmgt.4s v0, v0, #0
; CHECK-NEXT: csetm w8, ne
-; CHECK-NEXT: dup.4h v3, w8
-; CHECK-NEXT: xtn.4h v0, v0
-; CHECK-NEXT: and.8b v0, v0, v3
-; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: dup.4s v3, w8
+; CHECK-NEXT: and.16b v0, v0, v3
; CHECK-NEXT: bsl.16b v0, v1, v2
; CHECK-NEXT: ret
%cmp = icmp sgt <4 x i32> %a, zeroinitializer
@@ -61,10 +57,8 @@ define <4 x i32> @select_allzeros_or_allones(<4 x i32> %x, <4 x i32> %y, i1 %con
; CHECK: ; %bb.0:
; CHECK-NEXT: tst w0, #0x1
; CHECK-NEXT: csetm w8, ne
-; CHECK-NEXT: dup.4h v2, w8
-; CHECK-NEXT: mvn.8b v2, v2
-; CHECK-NEXT: sshll.4s v2, v2, #0
-; CHECK-NEXT: bif.16b v0, v1, v2
+; CHECK-NEXT: dup.4s v2, w8
+; CHECK-NEXT: bit.16b v0, v1, v2
; CHECK-NEXT: ret
%mask = select i1 %cond, <4 x i1> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>
%sel = select <4 x i1> %mask, <4 x i32> %x, <4 x i32> %y
@@ -78,13 +72,7 @@ define <4 x i32> @vselect_of_setccs(<4 x i32> %a, <4 x i32> %b, <4 x i32> %x, <4
; CHECK-NEXT: cmgt.4s v5, v0, #0
; CHECK-NEXT: cmeq.4s v1, v1, #0
; CHECK-NEXT: cmgt.4s v0, v4, v0
-; CHECK-NEXT: xtn.4h v4, v1
-; CHECK-NEXT: and.16b v1, v5, v1
-; CHECK-NEXT: xtn.4h v0, v0
-; CHECK-NEXT: xtn.4h v1, v1
-; CHECK-NEXT: bic.8b v0, v0, v4
-; CHECK-NEXT: orr.8b v0, v1, v0
-; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: bit.16b v0, v5, v1
; CHECK-NEXT: bsl.16b v0, v2, v3
; CHECK-NEXT: ret
%cmp1 = icmp sgt <4 x i32> %a, zeroinitializer
@@ -99,15 +87,12 @@ define <4 x i32> @select_scalar_cond_setccs(<4 x i32> %a, <4 x i32> %x, <4 x i32
; CHECK-LABEL: select_scalar_cond_setccs:
; CHECK: ; %bb.0: ; %entry
; CHECK-NEXT: movi.4s v3, #100
-; CHECK-NEXT: cmgt.4s v4, v0, #0
; CHECK-NEXT: tst w0, #0x1
+; CHECK-NEXT: cmgt.4s v4, v0, #0
; CHECK-NEXT: csetm w8, ne
; CHECK-NEXT: cmgt.4s v0, v3, v0
-; CHECK-NEXT: xtn.4h v3, v4
-; CHECK-NEXT: dup.4h v4, w8
-; CHECK-NEXT: xtn.4h v0, v0
-; CHECK-NEXT: bif.8b v0, v3, v4
-; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: dup.4s v3, w8
+; CHECK-NEXT: bif.16b v0, v4, v3
; CHECK-NEXT: bsl.16b v0, v1, v2
; CHECK-NEXT: ret
entry:
@@ -139,17 +124,21 @@ define <3 x i64> @or_setcc_i16_i32_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i32>
; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
; CHECK-NEXT: mov.d v4[1], v5[0]
; CHECK-NEXT: sshll.4s v0, v0, #0
-; CHECK-NEXT: orr.16b v1, v0, v1
-; CHECK-NEXT: ldp d0, d2, [sp]
-; CHECK-NEXT: mov.d v7[1], v0[0]
-; CHECK-NEXT: sshll.2d v0, v1, #0
-; CHECK-NEXT: sshll2.2d v1, v1, #0
-; CHECK-NEXT: bit.16b v2, v6, v1
+; CHECK-NEXT: ext.16b v2, v1, v1, #8
+; CHECK-NEXT: ext.16b v3, v0, v0, #8
+; CHECK-NEXT: orr.8b v0, v0, v1
+; CHECK-NEXT: ldp d1, d16, [sp]
+; CHECK-NEXT: sshll.2d v0, v0, #0
+; CHECK-NEXT: mov.d v7[1], v1[0]
+; CHECK-NEXT: orr.8b v1, v3, v2
+; CHECK-NEXT: sshll.2d v1, v1, #0
; CHECK-NEXT: bsl.16b v0, v4, v7
-; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT: mov.16b v2, v1
; CHECK-NEXT: ext.16b v1, v0, v0, #8
; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: bsl.16b v2, v6, v16
; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ret
%cmp0 = icmp sgt <3 x i16> %a, %b
%cmp1 = icmp sgt <3 x i32> %c, %d
@@ -168,17 +157,21 @@ define <3 x i64> @and_setcc_i32_i32_sel_i64(<3 x i32> %a, <3 x i32> %b, <3 x i32
; CHECK-NEXT: ; kill: def $d5 killed $d5 def $q5
; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
; CHECK-NEXT: mov.d v4[1], v5[0]
-; CHECK-NEXT: and.16b v1, v0, v2
-; CHECK-NEXT: ldp d0, d2, [sp]
-; CHECK-NEXT: mov.d v7[1], v0[0]
-; CHECK-NEXT: sshll.2d v0, v1, #0
-; CHECK-NEXT: sshll2.2d v1, v1, #0
-; CHECK-NEXT: bit.16b v2, v6, v1
+; CHECK-NEXT: ext.16b v1, v2, v2, #8
+; CHECK-NEXT: ext.16b v3, v0, v0, #8
+; CHECK-NEXT: and.8b v0, v0, v2
+; CHECK-NEXT: ldp d2, d16, [sp]
+; CHECK-NEXT: sshll.2d v0, v0, #0
+; CHECK-NEXT: and.8b v1, v3, v1
+; CHECK-NEXT: mov.d v7[1], v2[0]
+; CHECK-NEXT: sshll.2d v1, v1, #0
; CHECK-NEXT: bsl.16b v0, v4, v7
-; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT: mov.16b v2, v1
; CHECK-NEXT: ext.16b v1, v0, v0, #8
; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: bsl.16b v2, v6, v16
; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ret
%cmp0 = icmp sgt <3 x i32> %a, %b
%cmp1 = icmp sgt <3 x i32> %c, %d
@@ -197,18 +190,23 @@ define <3 x i64> @or_setcc_i16_i16_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i16>
; CHECK-NEXT: ; kill: def $d5 killed $d5 def $q5
; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
; CHECK-NEXT: mov.d v4[1], v5[0]
-; CHECK-NEXT: orr.8b v0, v0, v2
-; CHECK-NEXT: sshll.4s v1, v0, #0
-; CHECK-NEXT: ldp d0, d2, [sp]
-; CHECK-NEXT: mov.d v7[1], v0[0]
-; CHECK-NEXT: sshll.2d v0, v1, #0
-; CHECK-NEXT: sshll2.2d v1, v1, #0
-; CHECK-NEXT: bit.16b v2, v6, v1
+; CHECK-NEXT: sshll.4s v1, v2, #0
+; CHECK-NEXT: sshll.4s v0, v0, #0
+; CHECK-NEXT: ext.16b v2, v1, v1, #8
+; CHECK-NEXT: ext.16b v3, v0, v0, #8
+; CHECK-NEXT: orr.8b v0, v0, v1
+; CHECK-NEXT: ldp d1, d16, [sp]
+; CHECK-NEXT: sshll.2d v0, v0, #0
+; CHECK-NEXT: mov.d v7[1], v1[0]
+; CHECK-NEXT: orr.8b v1, v3, v2
+; CHECK-NEXT: sshll.2d v1, v1, #0
; CHECK-NEXT: bsl.16b v0, v4, v7
-; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
+; CHECK-NEXT: mov.16b v2, v1
; CHECK-NEXT: ext.16b v1, v0, v0, #8
; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: bsl.16b v2, v6, v16
; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ret
%cmp0 = icmp sgt <3 x i16> %a, %b
%cmp1 = icmp sgt <3 x i16> %c, %d
diff --git a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
index 474be4465d9b7..19819e3ffbe43 100644
--- a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
+++ b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
@@ -660,32 +660,30 @@ define <8 x i32> @PR157382(ptr %p0, ptr %p1, ptr %p2) {
; SSE2-SSSE3: # %bb.0:
; SSE2-SSSE3-NEXT: movdqu (%rdi), %xmm3
; SSE2-SSSE3-NEXT: movdqu 16(%rdi), %xmm2
-; SSE2-SSSE3-NEXT: movdqu (%rsi), %xmm0
+; SSE2-SSSE3-NEXT: movdqu (%rsi), %xmm5
; SSE2-SSSE3-NEXT: movdqu 16(%rsi), %xmm4
; SSE2-SSSE3-NEXT: movq {{.*#+}} xmm1 = mem[0],zero
-; SSE2-SSSE3-NEXT: pxor %xmm5, %xmm5
+; SSE2-SSSE3-NEXT: pxor %xmm0, %xmm0
; SSE2-SSSE3-NEXT: pxor %xmm6, %xmm6
-; SSE2-SSSE3-NEXT: pcmpgtd %xmm3, %xmm6
+; SSE2-SSSE3-NEXT: pcmpgtd %xmm2, %xmm6
; SSE2-SSSE3-NEXT: pcmpeqd %xmm7, %xmm7
; SSE2-SSSE3-NEXT: pxor %xmm7, %xmm6
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm8
-; SSE2-SSSE3-NEXT: pcmpgtd %xmm2, %xmm8
+; SSE2-SSSE3-NEXT: pcmpgtd %xmm3, %xmm8
; SSE2-SSSE3-NEXT: pxor %xmm7, %xmm8
-; SSE2-SSSE3-NEXT: pcmpgtd %xmm5, %xmm0
-; SSE2-SSSE3-NEXT: por %xmm6, %xmm0
-; SSE2-SSSE3-NEXT: pcmpgtd %xmm5, %xmm4
-; SSE2-SSSE3-NEXT: por %xmm8, %xmm4
-; SSE2-SSSE3-NEXT: packssdw %xmm4, %xmm0
+; SSE2-SSSE3-NEXT: pcmpgtd %xmm0, %xmm4
+; SSE2-SSSE3-NEXT: por %xmm6, %xmm4
+; SSE2-SSSE3-NEXT: pcmpgtd %xmm0, %xmm5
+; SSE2-SSSE3-NEXT: por %xmm8, %xmm5
; SSE2-SSSE3-NEXT: punpcklbw {{.*#+}} xmm1 = xmm1[0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7]
-; SSE2-SSSE3-NEXT: pcmpeqb %xmm5, %xmm1
+; SSE2-SSSE3-NEXT: pcmpeqb %xmm0, %xmm1
; SSE2-SSSE3-NEXT: pxor %xmm7, %xmm1
-; SSE2-SSSE3-NEXT: por %xmm0, %xmm1
+; SSE2-SSSE3-NEXT: movdqa %xmm1, %xmm0
; SSE2-SSSE3-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; SSE2-SSSE3-NEXT: psrad $16, %xmm0
-; SSE2-SSSE3-NEXT: pand %xmm3, %xmm0
+; SSE2-SSSE3-NEXT: por %xmm5, %xmm0
; SSE2-SSSE3-NEXT: punpckhwd {{.*#+}} xmm1 = xmm1[4,4,5,5,6,6,7,7]
-; SSE2-SSSE3-NEXT: pslld $31, %xmm1
-; SSE2-SSSE3-NEXT: psrad $31, %xmm1
+; SSE2-SSSE3-NEXT: por %xmm4, %xmm1
+; SSE2-SSSE3-NEXT: pand %xmm3, %xmm0
; SSE2-SSSE3-NEXT: pand %xmm2, %xmm1
; SSE2-SSSE3-NEXT: retq
;
@@ -693,28 +691,27 @@ define <8 x i32> @PR157382(ptr %p0, ptr %p1, ptr %p2) {
; AVX1: # %bb.0:
; AVX1-NEXT: vmovdqu (%rdi), %ymm0
; AVX1-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero
-; AVX1-NEXT: vpxor %xmm2, %xmm2, %xmm2
-; AVX1-NEXT: vpcmpgtd %xmm0, %xmm2, %xmm3
+; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm2
+; AVX1-NEXT: vpxor %xmm3, %xmm3, %xmm3
+; AVX1-NEXT: vpcmpgtd %xmm2, %xmm3, %xmm2
; AVX1-NEXT: vpcmpeqd %xmm4, %xmm4, %xmm4
-; AVX1-NEXT: vpxor %xmm4, %xmm3, %xmm3
-; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm5
-; AVX1-NEXT: vpcmpgtd %xmm5, %xmm2, %xmm5
+; AVX1-NEXT: vpxor %xmm4, %xmm2, %xmm2
+; AVX1-NEXT: vpcmpgtd %xmm0, %xmm3, %xmm5
; AVX1-NEXT: vpxor %xmm4, %xmm5, %xmm5
-; AVX1-NEXT: vmovdqu (%rsi), %xmm6
-; AVX1-NEXT: vmovdqu 16(%rsi), %xmm7
-; AVX1-NEXT: vpcmpgtd %xmm2, %xmm6, %xmm6
-; AVX1-NEXT: vpor %xmm6, %xmm3, %xmm3
-; AVX1-NEXT: vpcmpgtd %xmm2, %xmm7, %xmm6
-; AVX1-NEXT: vpor %xmm6, %xmm5, %xmm5
-; AVX1-NEXT: vpackssdw %xmm5, %xmm3, %xmm3
-; AVX1-NEXT: vpcmpeqb %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm5, %ymm2
+; AVX1-NEXT: vmovdqu (%rsi), %xmm5
+; AVX1-NEXT: vmovdqu 16(%rsi), %xmm6
+; AVX1-NEXT: vpcmpgtd %xmm3, %xmm6, %xmm6
+; AVX1-NEXT: vpcmpgtd %xmm3, %xmm5, %xmm5
+; AVX1-NEXT: vinsertf128 $1, %xmm6, %ymm5, %ymm5
+; AVX1-NEXT: vorps %ymm5, %ymm2, %ymm2
+; AVX1-NEXT: vpcmpeqb %xmm3, %xmm1, %xmm1
; AVX1-NEXT: vpxor %xmm4, %xmm1, %xmm1
-; AVX1-NEXT: vpmovsxbw %xmm1, %xmm1
-; AVX1-NEXT: vpor %xmm1, %xmm3, %xmm1
-; AVX1-NEXT: vpmovsxwd %xmm1, %xmm2
-; AVX1-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX1-NEXT: vpmovsxwd %xmm1, %xmm1
-; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm2, %ymm1
+; AVX1-NEXT: vpmovsxbd %xmm1, %xmm3
+; AVX1-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[1,1,1,1]
+; AVX1-NEXT: vpmovsxbd %xmm1, %xmm1
+; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm3, %ymm1
+; AVX1-NEXT: vorps %ymm1, %ymm2, %ymm1
; AVX1-NEXT: vandps %ymm0, %ymm1, %ymm0
; AVX1-NEXT: retq
;
>From 1c9ea2b89f2c0aaa2f3e4d04965f881aa9142b5a Mon Sep 17 00:00:00 2001
From: Valeriy Savchenko <vsavchenko at apple.com>
Date: Wed, 25 Mar 2026 15:39:23 +0000
Subject: [PATCH 3/3] [SelectionDAG] Restore the pre-existing logic of type
selection
---
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 15 +-
.../SelectionDAG/LegalizeVectorTypes.cpp | 166 +++++++++++-------
.../AArch64/vselect-widen-mask-tree.ll | 57 +++---
3 files changed, 138 insertions(+), 100 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 4d69c0e8bcde8..2537672fbdbf2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1145,13 +1145,22 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
/// By default, the vector will be widened with undefined values.
SDValue ModifyToType(SDValue InOp, EVT NVT, bool FillWithZeroes = false);
+ /// Adjust element width (sign-extend/truncate) and element count
+ /// (extract/concat) of Mask to match ToMaskVT.
+ SDValue adjustMaskToType(SDValue Mask, EVT ToMaskVT);
+
+ /// Pick an intermediate VT and adjust both operands to it, minimizing
+ /// extend/truncate overhead given the final target ToVT.
+ EVT unifyMaskTypes(SDValue &Op0, SDValue &Op1, EVT ToVT);
+
/// Return a mask of vector type MaskVT to replace InMask. Also adjust
/// MaskVT to ToMaskVT if needed with vector extension or truncation.
SDValue convertMask(SDValue InMask, EVT MaskVT, EVT ToMaskVT);
- /// Recursively widen a mask expression tree to ToVT. Handles SETCC,
- /// logical ops (AND/OR/XOR), VECTOR_SHUFFLE, FREEZE, SELECT/VSELECT,
- /// and constant BUILD_VECTORs.
+ /// Recursively widen a mask expression tree to ToVT, walking through
+ /// mask-preserving operations down to SETCC leaves. Avoids redundant
+ /// extend/truncate chains that arise when each node is widened independently.
+ /// Returns SDValue() if the tree cannot be widened.
SDValue widenMaskTree(SDValue V, EVT ToVT, unsigned Depth = 0);
//===--------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 0b7ab61cf3df8..95405fdc3126c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -6824,9 +6824,14 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
Mask = DAG.getNode(InMask->getOpcode(), SDLoc(InMask), MaskVT, Ops,
InMask->getFlags());
- // If MaskVT has smaller or bigger elements than ToMaskVT, a vector sign
- // extend or truncate is needed.
+ return adjustMaskToType(Mask, ToMaskVT);
+}
+
+// Adjust element width (sign-extend/truncate) and element count
+// (extract/concat) of Mask to match ToMaskVT.
+SDValue DAGTypeLegalizer::adjustMaskToType(SDValue Mask, EVT ToMaskVT) {
LLVMContext &Ctx = *DAG.getContext();
+ EVT MaskVT = Mask.getValueType();
unsigned MaskScalarBits = MaskVT.getScalarSizeInBits();
unsigned ToMaskScalBits = ToMaskVT.getScalarSizeInBits();
if (MaskScalarBits < ToMaskScalBits) {
@@ -6861,80 +6866,117 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
return Mask;
}
-// Recursively widen a mask expression tree to ToVT. Handles any chain of
-// mask-preserving operations rooted at SETCC nodes. Returns SDValue() if
-// the tree cannot be widened.
+// Adjust both operands to a common intermediate mask type, picking a scalar
+// width that minimizes extend/truncate overhead given the final target ToVT.
+EVT DAGTypeLegalizer::unifyMaskTypes(SDValue &Op0, SDValue &Op1, EVT ToVT) {
+ assert(Op0.getValueType().getVectorNumElements() ==
+ Op1.getValueType().getVectorNumElements() &&
+ "unifyMaskTypes only handles scalar width differences");
+ unsigned Bits0 = Op0.getValueType().getScalarSizeInBits();
+ unsigned Bits1 = Op1.getValueType().getScalarSizeInBits();
+ unsigned NarrowBits = std::min(Bits0, Bits1);
+ unsigned WideBits = std::max(Bits0, Bits1);
+ unsigned ToBits = ToVT.getScalarSizeInBits();
+ unsigned IntBits = NarrowBits == WideBits ? NarrowBits
+ : ToBits >= WideBits ? WideBits
+ : ToBits <= NarrowBits ? NarrowBits
+ : ToBits;
+ EVT OpVT = EVT::getVectorVT(*DAG.getContext(), MVT::getIntegerVT(IntBits),
+ Op0.getValueType().getVectorNumElements());
+ Op0 = adjustMaskToType(Op0, OpVT);
+ Op1 = adjustMaskToType(Op1, OpVT);
+ return OpVT;
+}
+
SDValue DAGTypeLegalizer::widenMaskTree(SDValue V, EVT ToVT, unsigned Depth) {
if (Depth >= DAG.MaxRecursionDepth)
return SDValue();
- unsigned Opcode = V.getOpcode();
+ SDValue Result = [&]() -> SDValue {
+ unsigned Opcode = V.getOpcode();
- // Base case: SETCC produces the mask directly.
- if (isSETCCOp(Opcode)) {
- EVT MaskVT = getSetCCResultType(getSETCCOperandType(V));
- return convertMask(V, MaskVT, ToVT);
- }
+ // Base case: SETCC produces the mask at its natural type.
+ if (isSETCCOp(Opcode)) {
+ EVT MaskVT = getSetCCResultType(getSETCCOperandType(V));
+ return convertMask(V, MaskVT, MaskVT);
+ }
- // Base case: all-zeros or all-ones BUILD_VECTOR.
- if (ISD::isBuildVectorAllZeros(V.getNode()))
- return DAG.getConstant(0, SDLoc(V), ToVT);
- if (ISD::isBuildVectorAllOnes(V.getNode()))
- return DAG.getAllOnesConstant(SDLoc(V), ToVT);
+ // Base case: all-zeros or all-ones BUILD_VECTOR. Use ToVT directly since
+ // these are invariant under sign-extend/truncate.
+ if (ISD::isBuildVectorAllZeros(V.getNode()))
+ return DAG.getConstant(0, SDLoc(V), ToVT);
+ if (ISD::isBuildVectorAllOnes(V.getNode()))
+ return DAG.getAllOnesConstant(SDLoc(V), ToVT);
- SDLoc DL(V);
+ SDLoc DL(V);
- // Logical operations (AND/OR/XOR): widen both operands and rebuild.
- if (isLogicalMaskOp(Opcode)) {
- SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
- if (!Op0)
- return SDValue();
- SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
- if (!Op1)
- return SDValue();
- return DAG.getNode(Opcode, DL, ToVT, Op0, Op1);
- }
+ // Logical operations (AND/OR/XOR): widen both operands, then pick the
+ // best-fitting common element width from the operands' widths.
+ if (isLogicalMaskOp(Opcode)) {
+ SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+ if (!Op0)
+ return SDValue();
+ SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+ if (!Op1)
+ return SDValue();
+ EVT OpVT = unifyMaskTypes(Op0, Op1, ToVT);
+ return DAG.getNode(Opcode, DL, OpVT, Op0, Op1);
+ }
- // FREEZE: widen the operand and re-wrap.
- if (Opcode == ISD::FREEZE) {
- SDValue Inner = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
- if (!Inner)
- return SDValue();
- return DAG.getNode(ISD::FREEZE, DL, ToVT, Inner);
- }
+ // FREEZE: widen the operand and re-wrap.
+ if (Opcode == ISD::FREEZE) {
+ SDValue Inner = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+ if (!Inner)
+ return SDValue();
+ return DAG.getNode(ISD::FREEZE, DL, Inner.getValueType(), Inner);
+ }
- // Vector shuffle: widen inputs and apply the same mask.
- if (Opcode == ISD::VECTOR_SHUFFLE) {
- auto *Shuf = cast<ShuffleVectorSDNode>(V);
- SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
- if (!Op0)
- return SDValue();
- SDValue Op1 = V.getOperand(1).isUndef()
- ? DAG.getUNDEF(ToVT)
- : widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
- if (!Op1)
- return SDValue();
- return DAG.getVectorShuffle(ToVT, DL, Op0, Op1, Shuf->getMask());
- }
+ // Vector shuffle: infer the best-fitting element width from the operands.
+ if (Opcode == ISD::VECTOR_SHUFFLE) {
+ auto *Shuf = cast<ShuffleVectorSDNode>(V);
+ SDValue Op0 = widenMaskTree(V.getOperand(0), ToVT, Depth + 1);
+ if (!Op0)
+ return SDValue();
+ if (V.getOperand(1).isUndef()) {
+ EVT OpVT = Op0.getValueType();
+ return DAG.getVectorShuffle(OpVT, DL, Op0, DAG.getUNDEF(OpVT),
+ Shuf->getMask());
+ }
+ SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+ if (!Op1)
+ return SDValue();
+ EVT OpVT = unifyMaskTypes(Op0, Op1, ToVT);
+ return DAG.getVectorShuffle(OpVT, DL, Op0, Op1, Shuf->getMask());
+ }
- // SELECT/VSELECT: widen both true and false mask values.
- if (Opcode == ISD::SELECT || Opcode == ISD::VSELECT) {
- SDValue Cond = V.getOperand(0);
- if (Cond.getValueType().isVector()) {
- Cond = widenMaskTree(Cond, ToVT, Depth + 1);
- if (!Cond)
+ // SELECT/VSELECT: infer the best-fitting element width from the operands.
+ if (Opcode == ISD::SELECT || Opcode == ISD::VSELECT) {
+ SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
+ if (!Op1)
return SDValue();
+ SDValue Op2 = widenMaskTree(V.getOperand(2), ToVT, Depth + 1);
+ if (!Op2)
+ return SDValue();
+ EVT OpVT = unifyMaskTypes(Op1, Op2, ToVT);
+
+ SDValue Cond = V.getOperand(0);
+ if (Opcode == ISD::VSELECT) {
+ Cond = widenMaskTree(Cond, ToVT, Depth + 1);
+ if (!Cond)
+ return SDValue();
+ Cond = adjustMaskToType(Cond, OpVT);
+ }
+ return DAG.getNode(Opcode, DL, OpVT, Cond, Op1, Op2);
}
- SDValue Op1 = widenMaskTree(V.getOperand(1), ToVT, Depth + 1);
- if (!Op1)
- return SDValue();
- SDValue Op2 = widenMaskTree(V.getOperand(2), ToVT, Depth + 1);
- if (!Op2)
- return SDValue();
- return DAG.getNode(Opcode, DL, ToVT, Cond, Op1, Op2);
- }
- return SDValue();
+ return SDValue();
+ }();
+
+ if (!Result)
+ return SDValue();
+ if (Depth == 0)
+ Result = adjustMaskToType(Result, ToVT);
+ return Result;
}
// This method tries to handle some special cases for the vselect mask
diff --git a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
index 2d9712c276594..b1454cfb1c5b1 100644
--- a/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
+++ b/llvm/test/CodeGen/AArch64/vselect-widen-mask-tree.ll
@@ -124,21 +124,17 @@ define <3 x i64> @or_setcc_i16_i32_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i32>
; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
; CHECK-NEXT: mov.d v4[1], v5[0]
; CHECK-NEXT: sshll.4s v0, v0, #0
-; CHECK-NEXT: ext.16b v2, v1, v1, #8
-; CHECK-NEXT: ext.16b v3, v0, v0, #8
-; CHECK-NEXT: orr.8b v0, v0, v1
-; CHECK-NEXT: ldp d1, d16, [sp]
-; CHECK-NEXT: sshll.2d v0, v0, #0
-; CHECK-NEXT: mov.d v7[1], v1[0]
-; CHECK-NEXT: orr.8b v1, v3, v2
-; CHECK-NEXT: sshll.2d v1, v1, #0
+; CHECK-NEXT: orr.16b v1, v0, v1
+; CHECK-NEXT: ldp d0, d2, [sp]
+; CHECK-NEXT: mov.d v7[1], v0[0]
+; CHECK-NEXT: sshll.2d v0, v1, #0
+; CHECK-NEXT: sshll2.2d v1, v1, #0
+; CHECK-NEXT: bit.16b v2, v6, v1
; CHECK-NEXT: bsl.16b v0, v4, v7
-; CHECK-NEXT: mov.16b v2, v1
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ext.16b v1, v0, v0, #8
; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
-; CHECK-NEXT: bsl.16b v2, v6, v16
; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
-; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ret
%cmp0 = icmp sgt <3 x i16> %a, %b
%cmp1 = icmp sgt <3 x i32> %c, %d
@@ -157,21 +153,17 @@ define <3 x i64> @and_setcc_i32_i32_sel_i64(<3 x i32> %a, <3 x i32> %b, <3 x i32
; CHECK-NEXT: ; kill: def $d5 killed $d5 def $q5
; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
; CHECK-NEXT: mov.d v4[1], v5[0]
-; CHECK-NEXT: ext.16b v1, v2, v2, #8
-; CHECK-NEXT: ext.16b v3, v0, v0, #8
-; CHECK-NEXT: and.8b v0, v0, v2
-; CHECK-NEXT: ldp d2, d16, [sp]
-; CHECK-NEXT: sshll.2d v0, v0, #0
-; CHECK-NEXT: and.8b v1, v3, v1
-; CHECK-NEXT: mov.d v7[1], v2[0]
-; CHECK-NEXT: sshll.2d v1, v1, #0
+; CHECK-NEXT: and.16b v1, v0, v2
+; CHECK-NEXT: ldp d0, d2, [sp]
+; CHECK-NEXT: mov.d v7[1], v0[0]
+; CHECK-NEXT: sshll.2d v0, v1, #0
+; CHECK-NEXT: sshll2.2d v1, v1, #0
+; CHECK-NEXT: bit.16b v2, v6, v1
; CHECK-NEXT: bsl.16b v0, v4, v7
-; CHECK-NEXT: mov.16b v2, v1
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ext.16b v1, v0, v0, #8
; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
-; CHECK-NEXT: bsl.16b v2, v6, v16
; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
-; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ret
%cmp0 = icmp sgt <3 x i32> %a, %b
%cmp1 = icmp sgt <3 x i32> %c, %d
@@ -190,23 +182,18 @@ define <3 x i64> @or_setcc_i16_i16_sel_i64(<3 x i16> %a, <3 x i16> %b, <3 x i16>
; CHECK-NEXT: ; kill: def $d5 killed $d5 def $q5
; CHECK-NEXT: ; kill: def $d6 killed $d6 def $q6
; CHECK-NEXT: mov.d v4[1], v5[0]
-; CHECK-NEXT: sshll.4s v1, v2, #0
-; CHECK-NEXT: sshll.4s v0, v0, #0
-; CHECK-NEXT: ext.16b v2, v1, v1, #8
-; CHECK-NEXT: ext.16b v3, v0, v0, #8
-; CHECK-NEXT: orr.8b v0, v0, v1
-; CHECK-NEXT: ldp d1, d16, [sp]
-; CHECK-NEXT: sshll.2d v0, v0, #0
-; CHECK-NEXT: mov.d v7[1], v1[0]
-; CHECK-NEXT: orr.8b v1, v3, v2
-; CHECK-NEXT: sshll.2d v1, v1, #0
+; CHECK-NEXT: orr.8b v0, v0, v2
+; CHECK-NEXT: sshll.4s v1, v0, #0
+; CHECK-NEXT: ldp d0, d2, [sp]
+; CHECK-NEXT: mov.d v7[1], v0[0]
+; CHECK-NEXT: sshll.2d v0, v1, #0
+; CHECK-NEXT: sshll2.2d v1, v1, #0
+; CHECK-NEXT: bit.16b v2, v6, v1
; CHECK-NEXT: bsl.16b v0, v4, v7
-; CHECK-NEXT: mov.16b v2, v1
+; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ext.16b v1, v0, v0, #8
; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
-; CHECK-NEXT: bsl.16b v2, v6, v16
; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1
-; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2
; CHECK-NEXT: ret
%cmp0 = icmp sgt <3 x i16> %a, %b
%cmp1 = icmp sgt <3 x i16> %c, %d
More information about the llvm-commits
mailing list