[llvm] [X86] getScalarMaskingNode - if the mask is zero just return the blended passthrough and preserved source value (PR #153575)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 2 01:05:12 PDT 2025


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/153575

>From 14a802c3d66828a49f49a8c8ce38df479867a710 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Thu, 14 Aug 2025 14:25:21 +0100
Subject: [PATCH 1/2] [X86] getScalarMaskingNode - if the mask is zero just
 return the preserved source value

We already return Op when the mask is one, so this adds the complementary case of returning PreservedSrc when the mask is zero.

This only handles the fold at the point where the X86ISD::SELECTS node would be created. An alternative would be to add a full combineSELECTS combine to also handle cases where this appears in later folds - would that be preferable?

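I had to adjust the test case for #98306 as, AFAICT, it had been over-reduced: it used a constant zero mask, which this fold would now simplify away, so the test now takes a variable mask instead.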

Fixes #153570
---
 llvm/lib/Target/X86/X86ISelLowering.cpp           |  4 +---
 .../test/CodeGen/X86/avx512cfmulsh-instrinsics.ll |  6 +++---
 llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll    | 15 +++++++++++++++
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 97cdf5b784bc0..00623fa8c8972 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -26261,10 +26261,8 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask,
                                     SDValue PreservedSrc,
                                     const X86Subtarget &Subtarget,
                                     SelectionDAG &DAG) {
-
   if (auto *MaskConst = dyn_cast<ConstantSDNode>(Mask))
-    if (MaskConst->getZExtValue() & 0x1)
-      return Op;
+    return (MaskConst->getZExtValue() & 0x1) ? Op : PreservedSrc;
 
   MVT VT = Op.getSimpleValueType();
   SDLoc dl(Op);
diff --git a/llvm/test/CodeGen/X86/avx512cfmulsh-instrinsics.ll b/llvm/test/CodeGen/X86/avx512cfmulsh-instrinsics.ll
index e449c7192e4bf..b60d7a5463d6b 100644
--- a/llvm/test/CodeGen/X86/avx512cfmulsh-instrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512cfmulsh-instrinsics.ll
@@ -278,14 +278,14 @@ define <4 x float> @test_int_x86_avx512fp16_maskz_cfcmadd_sh(<4 x float> %x0, <4
   ret <4 x float> %res
 }
 
-define <4 x float> @PR98306() {
+define <4 x float> @PR98306(i8 %m) {
 ; CHECK-LABEL: PR98306:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    kxorw %k0, %k0, %k1
+; CHECK-NEXT:    kmovd %edi, %k1
 ; CHECK-NEXT:    vmovaps {{.*#+}} xmm1 = [7.8125E-3,1.050912E+6,4.203776E+6,1.6815616E+7]
 ; CHECK-NEXT:    vmovaps {{.*#+}} xmm0 = [3.2E+1,4.03288064E+8,8.0658432E+8,1.61318502E+9]
 ; CHECK-NEXT:    vfmaddcsh {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0 {%k1} {z}
 ; CHECK-NEXT:    retq
-  %res = call <4 x float> @llvm.x86.avx512fp16.maskz.vfmadd.csh(<4 x float> <float 7.812500e-03, float 0x4130092000000000, float 0x4150094000000000, float 0x4170096000000000>, <4 x float> <float 2.000000e+00, float 0x4188098000000000, float 0x4198099000000000, float 0x41A809A000000000>, <4 x float> <float 3.200000e+01, float 0x41B809B000000000, float 0x41C809C000000000, float 0x41D809D000000000>, i8 0, i32 4)
+  %res = call <4 x float> @llvm.x86.avx512fp16.maskz.vfmadd.csh(<4 x float> <float 7.812500e-03, float 0x4130092000000000, float 0x4150094000000000, float 0x4170096000000000>, <4 x float> <float 2.000000e+00, float 0x4188098000000000, float 0x4198099000000000, float 0x41A809A000000000>, <4 x float> <float 3.200000e+01, float 0x41B809B000000000, float 0x41C809C000000000, float 0x41D809D000000000>, i8 %m, i32 4)
   ret <4 x float> %res
 }
diff --git a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
index 627a94799424c..89410a9c1b476 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
@@ -1361,3 +1361,18 @@ define <32 x half> @test_mm512_castph256_ph512_freeze(<16 x half> %a0) nounwind
   %res = shufflevector <16 x half> %a0, <16 x half> %a1, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
   ret <32 x half> %res
 }
+
+define <8 x half> @PR153570(ptr %p) {
+; CHECK-LABEL: PR153570:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmovsh {{.*#+}} xmm0 = [2.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0]
+; CHECK-NEXT:    vmovsh {{.*#+}} xmm1 = [1.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0]
+; CHECK-NEXT:    vmulsh {rn-sae}, %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vmovaps %xmm1, (%rdi)
+; CHECK-NEXT:    retq
+  %r = tail call <8 x half> @llvm.x86.avx512fp16.mask.mul.sh.round(<8 x half> <half 0xH3C00, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000>, <8 x half> <half 0xH4000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000>, <8 x half> zeroinitializer, i8 0, i32 8)
+  store <8 x half> %r, ptr %p, align 16
+  %r1 = tail call <8 x half> @llvm.x86.avx512fp16.mask.mul.sh.round(<8 x half> <half 0xH3C00, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000>, <8 x half> <half 0xH4000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000>, <8 x half> zeroinitializer, i8 1, i32 8)
+  ret <8 x half> %r1
+}

>From bc69d8dc84b80e6749857cf3cc9a4c24a6a4e54b Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Thu, 28 Aug 2025 15:09:17 +0100
Subject: [PATCH 2/2] [X86] getScalarMaskingNode - emit a scalar blend pattern
 if mask is false

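Assumes the passthrough for the upper elements of the scalar op is operand 0 of Op (Op.getOperand(0)); element 0 is taken from PreservedSrc.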
---
 llvm/lib/Target/X86/X86ISelLowering.cpp        | 16 ++++++++++++++--
 llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll | 11 ++++++-----
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 8ad6cdd1d683c..ac3e164fdf1de 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -26269,8 +26269,9 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask,
                                     SDValue PreservedSrc,
                                     const X86Subtarget &Subtarget,
                                     SelectionDAG &DAG) {
-  if (auto *MaskConst = dyn_cast<ConstantSDNode>(Mask))
-    return (MaskConst->getZExtValue() & 0x1) ? Op : PreservedSrc;
+  auto *MaskConst = dyn_cast<ConstantSDNode>(Mask);
+  if (MaskConst && (MaskConst->getZExtValue() & 0x1))
+    return Op;
 
   MVT VT = Op.getSimpleValueType();
   SDLoc dl(Op);
@@ -26286,6 +26287,17 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask,
 
   if (PreservedSrc.isUndef())
     PreservedSrc = getZeroVector(VT, Subtarget, DAG, dl);
+
+  if (MaskConst) {
+    assert((MaskConst->getZExtValue() & 0x1) == 0 && "Expected false mask");
+    // Discard op and blend passthrough with scalar op src/dst.
+    SmallVector<int, 16> ShuffleMask(VT.getVectorNumElements());
+    std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+    ShuffleMask[0] = VT.getVectorNumElements();
+    return DAG.getVectorShuffle(VT, dl, Op.getOperand(0), PreservedSrc,
+                                ShuffleMask);
+  }
+
   return DAG.getNode(X86ISD::SELECTS, dl, VT, IMask, Op, PreservedSrc);
 }
 
diff --git a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
index 89410a9c1b476..b1bacd92f073b 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
@@ -1365,14 +1365,15 @@ define <32 x half> @test_mm512_castph256_ph512_freeze(<16 x half> %a0) nounwind
 define <8 x half> @PR153570(ptr %p) {
 ; CHECK-LABEL: PR153570:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vmovsh {{.*#+}} xmm0 = [2.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0]
-; CHECK-NEXT:    vmovsh {{.*#+}} xmm1 = [1.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0,0.0E+0]
+; CHECK-NEXT:    vpbroadcastw {{.*#+}} xmm0 = [2.0E+0,2.0E+0,2.0E+0,2.0E+0,2.0E+0,2.0E+0,2.0E+0,2.0E+0]
+; CHECK-NEXT:    vpbroadcastw {{.*#+}} xmm1 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0]
 ; CHECK-NEXT:    vmulsh {rn-sae}, %xmm0, %xmm1, %xmm0
-; CHECK-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vpbroadcastw {{.*#+}} xmm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
+; CHECK-NEXT:    vmovsh %xmm2, %xmm1, %xmm1
 ; CHECK-NEXT:    vmovaps %xmm1, (%rdi)
 ; CHECK-NEXT:    retq
-  %r = tail call <8 x half> @llvm.x86.avx512fp16.mask.mul.sh.round(<8 x half> <half 0xH3C00, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000>, <8 x half> <half 0xH4000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000>, <8 x half> zeroinitializer, i8 0, i32 8)
+  %r = tail call <8 x half> @llvm.x86.avx512fp16.mask.mul.sh.round(<8 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>, <8 x half> <half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000>, <8 x half> <half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000>, i8 0, i32 8)
   store <8 x half> %r, ptr %p, align 16
-  %r1 = tail call <8 x half> @llvm.x86.avx512fp16.mask.mul.sh.round(<8 x half> <half 0xH3C00, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000>, <8 x half> <half 0xH4000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000, half 0xH0000>, <8 x half> zeroinitializer, i8 1, i32 8)
+  %r1 = tail call <8 x half> @llvm.x86.avx512fp16.mask.mul.sh.round(<8 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>, <8 x half> <half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000, half 0xH4000>, <8 x half> <half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000, half 0xH8000>, i8 1, i32 8)
   ret <8 x half> %r1
 }



More information about the llvm-commits mailing list