[llvm] [X86] combineSelect - move vselect(cond, pshufb(x), pshufb(y)) -> or(pshufb(x), pshufb(y)) fold (PR #145475)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 24 01:19:12 PDT 2025


https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/145475

Move the OR(PSHUFB(),PSHUFB()) fold so that it reuses the existing createShuffleMaskFromVSELECT result and runs before the combineX86ShufflesRecursively combine. This prevents some hasOneUse failures noticed in #133947 (combineX86ShufflesRecursively still unnecessarily widens vectors in several locations). Because the fold now lives in the shared VSELECT/BLENDV constant-condition path, it also applies to X86ISD::BLENDV nodes, not just ISD::VSELECT.

>From 3d7732c58134996ddbb0a05b7488b8c2b8cf8a9e Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Tue, 24 Jun 2025 09:16:49 +0100
Subject: [PATCH] [X86] combineSelect - move vselect(cond, pshufb(x),
 pshufb(y)) -> or (pshufb(x), pshufb(y)) fold

Move the OR(PSHUFB(),PSHUFB()) fold so that it reuses the createShuffleMaskFromVSELECT result and runs before the combineX86ShufflesRecursively combine. This prevents some hasOneUse failures noticed in #133947 (combineX86ShufflesRecursively still unnecessarily widens vectors in several locations). Because the fold now lives in the shared VSELECT/BLENDV constant-condition path, it also applies to X86ISD::BLENDV nodes, not just ISD::VSELECT.
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 78 ++++++++++++-------------
 1 file changed, 38 insertions(+), 40 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2541182de1208..d5837ab938d4e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47690,12 +47690,47 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
       return V;
 
   if (N->getOpcode() == ISD::VSELECT || N->getOpcode() == X86ISD::BLENDV) {
-    SmallVector<int, 64> Mask;
-    if (createShuffleMaskFromVSELECT(Mask, Cond,
+    SmallVector<int, 64> CondMask;
+    if (createShuffleMaskFromVSELECT(CondMask, Cond,
                                      N->getOpcode() == X86ISD::BLENDV)) {
       // Convert vselects with constant condition into shuffles.
       if (DCI.isBeforeLegalizeOps())
-        return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask);
+        return DAG.getVectorShuffle(VT, DL, LHS, RHS, CondMask);
+
+      // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
+      // by forcing the unselected elements to zero.
+      // TODO: Can we handle more shuffles with this?
+      if (LHS.hasOneUse() && RHS.hasOneUse()) {
+        SmallVector<SDValue, 1> LHSOps, RHSOps;
+        SmallVector<int, 64> LHSMask, RHSMask, ByteMask;
+        SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
+        SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
+        if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
+            RHSShuf.getOpcode() == X86ISD::PSHUFB &&
+            scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
+            getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
+            getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
+          assert(ByteMask.size() == LHSMask.size() &&
+                 ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
+          for (auto [I, M] : enumerate(ByteMask)) {
+            // getConstVector sets negative shuffle mask values as undef, so
+            // ensure we hardcode SM_SentinelZero values to zero (0x80).
+            if (M < (int)ByteMask.size()) {
+              LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I];
+              RHSMask[I] = 0x80;
+            } else {
+              LHSMask[I] = 0x80;
+              RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I];
+            }
+          }
+          MVT ByteVT = LHSShuf.getSimpleValueType();
+          LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
+                            getConstVector(LHSMask, ByteVT, DAG, DL, true));
+          RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
+                            getConstVector(RHSMask, ByteVT, DAG, DL, true));
+          return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
+        }
+      }
 
       // Attempt to combine as shuffle.
       SDValue Op(N, 0);
@@ -47704,43 +47739,6 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
-  // by forcing the unselected elements to zero.
-  // TODO: Can we handle more shuffles with this?
-  if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() && LHS.hasOneUse() &&
-      RHS.hasOneUse()) {
-    SmallVector<SDValue, 1> LHSOps, RHSOps;
-    SmallVector<int, 64> LHSMask, RHSMask, CondMask, ByteMask;
-    SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
-    SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
-    if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
-        RHSShuf.getOpcode() == X86ISD::PSHUFB &&
-        createShuffleMaskFromVSELECT(CondMask, Cond) &&
-        scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
-        getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
-        getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
-      assert(ByteMask.size() == LHSMask.size() &&
-             ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
-      for (auto [I, M] : enumerate(ByteMask)) {
-        // getConstVector sets negative shuffle mask values as undef, so ensure
-        // we hardcode SM_SentinelZero values to zero (0x80).
-        if (M < (int)ByteMask.size()) {
-          LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I];
-          RHSMask[I] = 0x80;
-        } else {
-          LHSMask[I] = 0x80;
-          RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I];
-        }
-      }
-      MVT ByteVT = LHSShuf.getSimpleValueType();
-      LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
-                        getConstVector(LHSMask, ByteVT, DAG, DL, true));
-      RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
-                        getConstVector(RHSMask, ByteVT, DAG, DL, true));
-      return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
-    }
-  }
-
   // If we have SSE[12] support, try to form min/max nodes. SSE min/max
   // instructions match the semantics of the common C idiom x<y?x:y but not
   // x<=y?x:y, because of how they handle negative zero (which can be



More information about the llvm-commits mailing list