[llvm] [X86] Generate `kmov` for masking integers (PR #120593)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 24 01:56:42 PST 2024


https://github.com/abhishek-kaushik22 updated https://github.com/llvm/llvm-project/pull/120593

>From 822ae48049fdebc769269291868270314f30ca9a Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Thu, 19 Dec 2024 21:14:10 +0530
Subject: [PATCH 1/4] Generate `kmov` for masking integers

When an integer is used as a bit mask, the LLVM IR looks something like this:
```
  %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
  %cmp1 = icmp ne <16 x i32> %1, zeroinitializer
```
where `%.splat` is a vector with the mask value broadcast to every lane.
The generated assembly looks like this:
```
vpbroadcastd    %ecx, %zmm0
vptestmd        .LCPI0_0(%rip), %zmm0, %k1
```
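where `.LCPI0_0` is a constant-pool table holding the powers of 2 (1, 2, 4, ..., 32768).
Because lane `i` of that table is exactly `1 << i`, lane `i` of the compare result is simply bit `i` of the scalar mask (inverted for `icmp eq`). Instead of broadcasting and testing against a table, we can therefore move the relevant bits directly into a `k` register with a single `kmov` instruction, followed by a `knot` in the `eq` case. This is faster, reduces code size, and removes the constant-pool load.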
---
 llvm/lib/Target/X86/X86ISelDAGToDAG.cpp |  79 +++++++--
 llvm/test/CodeGen/X86/kmov.ll           | 205 ++++++++++++++++++++++++
 llvm/test/CodeGen/X86/pr78897.ll        |   4 +-
 3 files changed, 273 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/kmov.ll

diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index bb20e6ecf281b0..8c199a30dfbce7 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -592,7 +592,7 @@ namespace {
     bool matchVPTERNLOG(SDNode *Root, SDNode *ParentA, SDNode *ParentB,
                         SDNode *ParentC, SDValue A, SDValue B, SDValue C,
                         uint8_t Imm);
-    bool tryVPTESTM(SDNode *Root, SDValue Setcc, SDValue Mask);
+    bool tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc, SDValue Mask);
     bool tryMatchBitSelect(SDNode *N);
 
     MachineSDNode *emitPCMPISTR(unsigned ROpc, unsigned MOpc, bool MayFoldLoad,
@@ -4897,10 +4897,10 @@ VPTESTM_CASE(v32i16, WZ##SUFFIX)
 #undef VPTESTM_CASE
 }
 
-// Try to create VPTESTM instruction. If InMask is not null, it will be used
-// to form a masked operation.
-bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
-                                 SDValue InMask) {
+// Try to create VPTESTM or KMOV instruction. If InMask is not null, it will be
+// used to form a masked operation.
+bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
+                                       SDValue InMask) {
   assert(Subtarget->hasAVX512() && "Expected AVX512!");
   assert(Setcc.getSimpleValueType().getVectorElementType() == MVT::i1 &&
          "Unexpected VT!");
@@ -4975,12 +4975,70 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
     return tryFoldBroadcast(Root, P, L, Base, Scale, Index, Disp, Segment);
   };
 
+  auto canUseKMOV = [&]() {
+    if (Src0.getOpcode() != X86ISD::VBROADCAST)
+      return false;
+
+    if (Src1.getOpcode() != ISD::LOAD ||
+        Src1.getOperand(1).getOpcode() != X86ISD::Wrapper ||
+        Src1.getOperand(1).getOperand(0).getOpcode() != ISD::TargetConstantPool)
+      return false;
+
+    const auto *ConstPool =
+        dyn_cast<ConstantPoolSDNode>(Src1.getOperand(1).getOperand(0));
+    if (!ConstPool)
+      return false;
+
+    const auto *ConstVec = ConstPool->getConstVal();
+    const auto *ConstVecType = dyn_cast<FixedVectorType>(ConstVec->getType());
+    if (!ConstVecType)
+      return false;
+
+    for (unsigned i = 0, e = ConstVecType->getNumElements(), k = 1; i != e;
+         ++i, k *= 2) {
+      const auto *Element = ConstVec->getAggregateElement(i);
+      if (llvm::isa<llvm::UndefValue>(Element)) {
+        for (unsigned j = i + 1; j != e; ++j) {
+          if (!llvm::isa<llvm::UndefValue>(ConstVec->getAggregateElement(j)))
+            return false;
+        }
+        return i != 0;
+      }
+
+      if (Element->getUniqueInteger() != k) {
+        return false;
+      }
+    }
+
+    return true;
+  };
+
   // We can only fold loads if the sources are unique.
   bool CanFoldLoads = Src0 != Src1;
 
   bool FoldedLoad = false;
   SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
+  SDLoc dl(Root);
+  bool IsTestN = CC == ISD::SETEQ;
+  MachineSDNode *CNode;
+  MVT ResVT = Setcc.getSimpleValueType();
   if (CanFoldLoads) {
+    if (canUseKMOV()) {
+      auto Op = Src0.getOperand(0);
+      if (Op.getSimpleValueType() == MVT::i8) {
+        Op = SDValue(CurDAG->getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Op));
+      }
+      CNode = CurDAG->getMachineNode(
+          ResVT.getVectorNumElements() <= 8 ? X86::KMOVBkr : X86::KMOVWkr, dl,
+          ResVT, Op);
+      if (IsTestN)
+        CNode = CurDAG->getMachineNode(
+            ResVT.getVectorNumElements() <= 8 ? X86::KNOTBkk : X86::KNOTWkk, dl,
+            ResVT, SDValue(CNode, 0));
+      ReplaceUses(SDValue(Root, 0), SDValue(CNode, 0));
+      CurDAG->RemoveDeadNode(Root);
+      return true;
+    }
     FoldedLoad = tryFoldLoadOrBCast(Root, N0.getNode(), Src1, Tmp0, Tmp1, Tmp2,
                                     Tmp3, Tmp4);
     if (!FoldedLoad) {
@@ -4996,9 +5054,6 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
 
   bool IsMasked = InMask.getNode() != nullptr;
 
-  SDLoc dl(Root);
-
-  MVT ResVT = Setcc.getSimpleValueType();
   MVT MaskVT = ResVT;
   if (Widen) {
     // Widen the inputs using insert_subreg or copy_to_regclass.
@@ -5023,11 +5078,9 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
     }
   }
 
-  bool IsTestN = CC == ISD::SETEQ;
   unsigned Opc = getVPTESTMOpc(CmpVT, IsTestN, FoldedLoad, FoldedBCast,
                                IsMasked);
 
-  MachineSDNode *CNode;
   if (FoldedLoad) {
     SDVTList VTs = CurDAG->getVTList(MaskVT, MVT::Other);
 
@@ -5466,10 +5519,10 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
       SDValue N0 = Node->getOperand(0);
       SDValue N1 = Node->getOperand(1);
       if (N0.getOpcode() == ISD::SETCC && N0.hasOneUse() &&
-          tryVPTESTM(Node, N0, N1))
+          tryVPTESTMOrKMOV(Node, N0, N1))
         return;
       if (N1.getOpcode() == ISD::SETCC && N1.hasOneUse() &&
-          tryVPTESTM(Node, N1, N0))
+          tryVPTESTMOrKMOV(Node, N1, N0))
         return;
     }
 
@@ -6393,7 +6446,7 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
   }
 
   case ISD::SETCC: {
-    if (NVT.isVector() && tryVPTESTM(Node, SDValue(Node, 0), SDValue()))
+    if (NVT.isVector() && tryVPTESTMOrKMOV(Node, SDValue(Node, 0), SDValue()))
       return;
 
     break;
diff --git a/llvm/test/CodeGen/X86/kmov.ll b/llvm/test/CodeGen/X86/kmov.ll
new file mode 100644
index 00000000000000..6d72a8923c5ab3
--- /dev/null
+++ b/llvm/test/CodeGen/X86/kmov.ll
@@ -0,0 +1,205 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=skylake-avx512 | FileCheck %s
+
+define dso_local void @foo_16_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_16_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovw %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %zmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %zmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %zmm1, %zmm0, %zmm0
+; CHECK-NEXT:    vmovups %zmm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <16 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <16 x i32> %.splatinsert, <16 x i32> poison, <16 x i32> zeroinitializer
+  %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
+  %hir.cmp.45 = icmp ne <16 x i32> %1, zeroinitializer
+  %2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
+  tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
+  ret void
+}
+
+; Function Attrs: mustprogress nounwind uwtable
+define dso_local void @foo_16_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_16_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovw %ecx, %k0
+; CHECK-NEXT:    knotw %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %zmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %zmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %zmm1, %zmm0, %zmm0
+; CHECK-NEXT:    vmovups %zmm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <16 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <16 x i32> %.splatinsert, <16 x i32> poison, <16 x i32> zeroinitializer
+  %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
+  %hir.cmp.45 = icmp eq <16 x i32> %1, zeroinitializer
+  %2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
+  tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_8_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_8_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %ymm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %ymm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vmovups %ymm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <8 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <8 x i32> %.splatinsert, <8 x i32> poison, <8 x i32> zeroinitializer
+  %1 = and <8 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128>
+  %hir.cmp.45 = icmp ne <8 x i32> %1, zeroinitializer
+  %2 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %b, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %3 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %a, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <8 x float> %2, %3
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %4, ptr %c, i32 4, <8 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_8_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_8_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %ymm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %ymm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vmovups %ymm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <8 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <8 x i32> %.splatinsert, <8 x i32> poison, <8 x i32> zeroinitializer
+  %1 = and <8 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128>
+  %hir.cmp.45 = icmp eq <8 x i32> %1, zeroinitializer
+  %2 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %b, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %3 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %a, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <8 x float> %2, %3
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %4, ptr %c, i32 4, <8 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_4_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_4_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer
+  %1 = and <4 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8>
+  %hir.cmp.45 = icmp ne <4 x i32> %1, zeroinitializer
+  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %3 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %a, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <4 x float> %2, %3
+  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %c, i32 4, <4 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_4_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_4_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer
+  %1 = and <4 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8>
+  %hir.cmp.45 = icmp eq <4 x i32> %1, zeroinitializer
+  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %3 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %a, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <4 x float> %2, %3
+  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %c, i32 4, <4 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_2_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_2_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    kshiftlb $6, %k0, %k0
+; CHECK-NEXT:    kshiftrb $6, %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %.splatinsert = insertelement <2 x i32> poison, i32 %mask, i64 0
+  %.splat = shufflevector <2 x i32> %.splatinsert, <2 x i32> poison, <2 x i32> zeroinitializer
+  %0 = and <2 x i32> %.splat, <i32 1, i32 2>
+  %hir.cmp.44 = icmp ne <2 x i32> %0, zeroinitializer
+  %1 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %b, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %2 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %a, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %3 = fadd reassoc nsz arcp contract afn <2 x float> %1, %2
+  tail call void @llvm.masked.store.v2f32.p0(<2 x float> %3, ptr %c, i32 4, <2 x i1> %hir.cmp.44)
+  ret void
+}
+
+define dso_local void @foo_2_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_2_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k0
+; CHECK-NEXT:    kshiftlb $6, %k0, %k0
+; CHECK-NEXT:    kshiftrb $6, %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %.splatinsert = insertelement <2 x i32> poison, i32 %mask, i64 0
+  %.splat = shufflevector <2 x i32> %.splatinsert, <2 x i32> poison, <2 x i32> zeroinitializer
+  %0 = and <2 x i32> %.splat, <i32 1, i32 2>
+  %hir.cmp.44 = icmp eq <2 x i32> %0, zeroinitializer
+  %1 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %b, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %2 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %a, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %3 = fadd reassoc nsz arcp contract afn <2 x float> %1, %2
+  tail call void @llvm.masked.store.v2f32.p0(<2 x float> %3, ptr %c, i32 4, <2 x i1> %hir.cmp.44)
+  ret void
+}
+
+declare <2 x float> @llvm.masked.load.v2f32.p0(ptr nocapture, i32 immarg, <2 x i1>, <2 x float>) #1
+
+declare void @llvm.masked.store.v2f32.p0(<2 x float>, ptr nocapture, i32 immarg, <2 x i1>) #2
+
+declare <4 x float> @llvm.masked.load.v4f32.p0(ptr nocapture, i32 immarg, <4 x i1>, <4 x float>) #1
+
+declare void @llvm.masked.store.v4f32.p0(<4 x float>, ptr nocapture, i32 immarg, <4 x i1>) #2
+
+declare <8 x float> @llvm.masked.load.v8f32.p0(ptr nocapture, i32 immarg, <8 x i1>, <8 x float>)
+
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr nocapture, i32 immarg, <8 x i1>)
+
+declare <16 x float> @llvm.masked.load.v16f32.p0(ptr nocapture, i32 immarg, <16 x i1>, <16 x float>)
+
+declare void @llvm.masked.store.v16f32.p0(<16 x float>, ptr nocapture, i32 immarg, <16 x i1>)
diff --git a/llvm/test/CodeGen/X86/pr78897.ll b/llvm/test/CodeGen/X86/pr78897.ll
index 56e4ec2bc8ecbb..38a1800df956b5 100644
--- a/llvm/test/CodeGen/X86/pr78897.ll
+++ b/llvm/test/CodeGen/X86/pr78897.ll
@@ -256,8 +256,8 @@ define <16 x i8> @produceShuffleVectorForByte(i8 zeroext %0) nounwind {
 ;
 ; X64-AVX512-LABEL: produceShuffleVectorForByte:
 ; X64-AVX512:       # %bb.0: # %entry
-; X64-AVX512-NEXT:    vpbroadcastb %edi, %xmm0
-; X64-AVX512-NEXT:    vptestnmb {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %k1
+; X64-AVX512-NEXT:    kmovw %edi, %k0
+; X64-AVX512-NEXT:    knotw %k0, %k1
 ; X64-AVX512-NEXT:    vmovdqu8 {{.*#+}} xmm0 {%k1} {z} = [17,17,17,17,17,17,17,17,u,u,u,u,u,u,u,u]
 ; X64-AVX512-NEXT:    vmovq %xmm0, %rax
 ; X64-AVX512-NEXT:    movabsq $1229782938247303440, %rcx # imm = 0x1111111111111110

>From 3f39f655f46dc7f9f3a5d3218ecffcaeb782fccb Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Fri, 20 Dec 2024 12:29:36 +0530
Subject: [PATCH 2/4] Address review comments: also match VBROADCAST_LOAD, simplify power-of-two check

---
 llvm/lib/Target/X86/X86ISelDAGToDAG.cpp | 23 +++++++++++------------
 llvm/test/CodeGen/X86/pr78897.ll        |  5 +++--
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 8c199a30dfbce7..054ff6743b9a5b 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4976,7 +4976,8 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
   };
 
   auto canUseKMOV = [&]() {
-    if (Src0.getOpcode() != X86ISD::VBROADCAST)
+    if (Src0.getOpcode() != X86ISD::VBROADCAST &&
+        Src0.getOpcode() != X86ISD::VBROADCAST_LOAD)
       return false;
 
     if (Src1.getOpcode() != ISD::LOAD ||
@@ -4994,20 +4995,18 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
     if (!ConstVecType)
       return false;
 
-    for (unsigned i = 0, e = ConstVecType->getNumElements(), k = 1; i != e;
-         ++i, k *= 2) {
-      const auto *Element = ConstVec->getAggregateElement(i);
+    for (unsigned I = 0, E = ConstVecType->getNumElements(); I != E; ++I) {
+      const auto *Element = ConstVec->getAggregateElement(I);
       if (llvm::isa<llvm::UndefValue>(Element)) {
-        for (unsigned j = i + 1; j != e; ++j) {
-          if (!llvm::isa<llvm::UndefValue>(ConstVec->getAggregateElement(j)))
+        for (unsigned J = I + 1; J != E; ++J) {
+          if (!llvm::isa<llvm::UndefValue>(ConstVec->getAggregateElement(J)))
             return false;
         }
-        return i != 0;
+        return I != 0;
       }
 
-      if (Element->getUniqueInteger() != k) {
+      if (Element->getUniqueInteger() != 1 << I)
         return false;
-      }
     }
 
     return true;
@@ -5024,10 +5023,10 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
   MVT ResVT = Setcc.getSimpleValueType();
   if (CanFoldLoads) {
     if (canUseKMOV()) {
-      auto Op = Src0.getOperand(0);
-      if (Op.getSimpleValueType() == MVT::i8) {
+      auto Op = Src0.getOpcode() == X86ISD::VBROADCAST ? Src0.getOperand(0)
+                                                       : Src0.getOperand(1);
+      if (Op.getSimpleValueType() == MVT::i8)
         Op = SDValue(CurDAG->getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Op));
-      }
       CNode = CurDAG->getMachineNode(
           ResVT.getVectorNumElements() <= 8 ? X86::KMOVBkr : X86::KMOVWkr, dl,
           ResVT, Op);
diff --git a/llvm/test/CodeGen/X86/pr78897.ll b/llvm/test/CodeGen/X86/pr78897.ll
index 38a1800df956b5..0c4c03de5901ec 100644
--- a/llvm/test/CodeGen/X86/pr78897.ll
+++ b/llvm/test/CodeGen/X86/pr78897.ll
@@ -223,8 +223,9 @@ define <16 x i8> @produceShuffleVectorForByte(i8 zeroext %0) nounwind {
 ; X86-AVX512-NEXT:    pushl %ebx
 ; X86-AVX512-NEXT:    pushl %edi
 ; X86-AVX512-NEXT:    pushl %esi
-; X86-AVX512-NEXT:    vpbroadcastb {{[0-9]+}}(%esp), %xmm0
-; X86-AVX512-NEXT:    vptestnmb {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %k1
+; X86-AVX512-NEXT:    leal {{[0-9]+}}(%esp)
+; X86-AVX512-NEXT:    kmovw %eax, %k0
+; X86-AVX512-NEXT:    knotw %k0, %k1
 ; X86-AVX512-NEXT:    vmovdqu8 {{.*#+}} xmm0 {%k1} {z} = [17,17,17,17,17,17,17,17,u,u,u,u,u,u,u,u]
 ; X86-AVX512-NEXT:    vpextrd $1, %xmm0, %eax
 ; X86-AVX512-NEXT:    vmovd %xmm0, %edx

>From 57c4aa0ad3fa524f15fb03acee7e92cab250f979 Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Mon, 23 Dec 2024 11:06:02 +0530
Subject: [PATCH 3/4] Update tests

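- Remove `dso_local` and parameter attributes (`nocapture`, `noundef`, `readonly`/`writeonly`)
- Remove fast-math flags
- Simplify tests by returning the compare result directly instead of feeding it into masked loads/stores (only the 4-element tests keep a single masked load/store)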
---
 llvm/test/CodeGen/X86/kmov.ll | 179 ++++++++++------------------------
 1 file changed, 52 insertions(+), 127 deletions(-)

diff --git a/llvm/test/CodeGen/X86/kmov.ll b/llvm/test/CodeGen/X86/kmov.ll
index 6d72a8923c5ab3..f17a559012e676 100644
--- a/llvm/test/CodeGen/X86/kmov.ll
+++ b/llvm/test/CodeGen/X86/kmov.ll
@@ -1,108 +1,73 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=skylake-avx512 | FileCheck %s
 
-define dso_local void @foo_16_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
-; CHECK-LABEL: foo_16_ne:
+define <16 x i1> @pr120593_16_ne(i32 %mask) {
+; CHECK-LABEL: pr120593_16_ne:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovw %ecx, %k1
-; CHECK-NEXT:    vmovups (%rdx), %zmm0 {%k1} {z}
-; CHECK-NEXT:    vmovups (%rsi), %zmm1 {%k1} {z}
-; CHECK-NEXT:    vaddps %zmm1, %zmm0, %zmm0
-; CHECK-NEXT:    vmovups %zmm0, (%rdi) {%k1}
-; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    kmovw %edi, %k0
+; CHECK-NEXT:    vpmovm2b %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = and i32 %mask, 65535
   %.splatinsert = insertelement <16 x i32> poison, i32 %0, i64 0
   %.splat = shufflevector <16 x i32> %.splatinsert, <16 x i32> poison, <16 x i32> zeroinitializer
   %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
-  %hir.cmp.45 = icmp ne <16 x i32> %1, zeroinitializer
-  %2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
-  %3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
-  %4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
-  tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
-  ret void
+  %cmp.45 = icmp ne <16 x i32> %1, zeroinitializer
+  ret <16 x i1> %cmp.45
 }
 
-; Function Attrs: mustprogress nounwind uwtable
-define dso_local void @foo_16_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
-; CHECK-LABEL: foo_16_eq:
+define <16 x i1> @pr120593_16_eq(i32 %mask) {
+; CHECK-LABEL: pr120593_16_eq:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovw %ecx, %k0
-; CHECK-NEXT:    knotw %k0, %k1
-; CHECK-NEXT:    vmovups (%rdx), %zmm0 {%k1} {z}
-; CHECK-NEXT:    vmovups (%rsi), %zmm1 {%k1} {z}
-; CHECK-NEXT:    vaddps %zmm1, %zmm0, %zmm0
-; CHECK-NEXT:    vmovups %zmm0, (%rdi) {%k1}
-; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    kmovw %edi, %k0
+; CHECK-NEXT:    knotw %k0, %k0
+; CHECK-NEXT:    vpmovm2b %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = and i32 %mask, 65535
   %.splatinsert = insertelement <16 x i32> poison, i32 %0, i64 0
   %.splat = shufflevector <16 x i32> %.splatinsert, <16 x i32> poison, <16 x i32> zeroinitializer
   %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
-  %hir.cmp.45 = icmp eq <16 x i32> %1, zeroinitializer
-  %2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
-  %3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
-  %4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
-  tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
-  ret void
+  %cmp.45 = icmp eq <16 x i32> %1, zeroinitializer
+  ret <16 x i1> %cmp.45
 }
 
-define dso_local void @foo_8_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
-; CHECK-LABEL: foo_8_ne:
+define <8 x i1> @pr120593_8_ne(i32 %mask) {
+; CHECK-LABEL: pr120593_8_ne:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %ecx, %k1
-; CHECK-NEXT:    vmovups (%rdx), %ymm0 {%k1} {z}
-; CHECK-NEXT:    vmovups (%rsi), %ymm1 {%k1} {z}
-; CHECK-NEXT:    vaddps %ymm1, %ymm0, %ymm0
-; CHECK-NEXT:    vmovups %ymm0, (%rdi) {%k1}
-; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    kmovb %edi, %k0
+; CHECK-NEXT:    vpmovm2w %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = and i32 %mask, 65535
   %.splatinsert = insertelement <8 x i32> poison, i32 %0, i64 0
   %.splat = shufflevector <8 x i32> %.splatinsert, <8 x i32> poison, <8 x i32> zeroinitializer
   %1 = and <8 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128>
-  %hir.cmp.45 = icmp ne <8 x i32> %1, zeroinitializer
-  %2 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %b, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
-  %3 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %a, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
-  %4 = fadd reassoc nsz arcp contract afn <8 x float> %2, %3
-  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %4, ptr %c, i32 4, <8 x i1> %hir.cmp.45)
-  ret void
+  %cmp.45 = icmp ne <8 x i32> %1, zeroinitializer
+  ret <8 x i1> %cmp.45
 }
 
-define dso_local void @foo_8_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
-; CHECK-LABEL: foo_8_eq:
+define <8 x i1> @pr120593_8_eq(i32 %mask) {
+; CHECK-LABEL: pr120593_8_eq:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %ecx, %k0
-; CHECK-NEXT:    knotb %k0, %k1
-; CHECK-NEXT:    vmovups (%rdx), %ymm0 {%k1} {z}
-; CHECK-NEXT:    vmovups (%rsi), %ymm1 {%k1} {z}
-; CHECK-NEXT:    vaddps %ymm1, %ymm0, %ymm0
-; CHECK-NEXT:    vmovups %ymm0, (%rdi) {%k1}
-; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    kmovb %edi, %k0
+; CHECK-NEXT:    knotb %k0, %k0
+; CHECK-NEXT:    vpmovm2w %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = and i32 %mask, 65535
   %.splatinsert = insertelement <8 x i32> poison, i32 %0, i64 0
   %.splat = shufflevector <8 x i32> %.splatinsert, <8 x i32> poison, <8 x i32> zeroinitializer
   %1 = and <8 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128>
-  %hir.cmp.45 = icmp eq <8 x i32> %1, zeroinitializer
-  %2 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %b, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
-  %3 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %a, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
-  %4 = fadd reassoc nsz arcp contract afn <8 x float> %2, %3
-  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %4, ptr %c, i32 4, <8 x i1> %hir.cmp.45)
-  ret void
+  %cmp.45 = icmp eq <8 x i32> %1, zeroinitializer
+  ret <8 x i1> %cmp.45
 }
 
-define dso_local void @foo_4_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
-; CHECK-LABEL: foo_4_ne:
+define void @pr120593_4_ne(ptr %c, ptr %b, i32 %mask) {
+; CHECK-LABEL: pr120593_4_ne:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %ecx, %k1
-; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
-; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
-; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    kmovb %edx, %k1
+; CHECK-NEXT:    vmovups (%rsi), %xmm0 {%k1} {z}
 ; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
 ; CHECK-NEXT:    retq
 entry:
@@ -110,22 +75,18 @@ entry:
   %.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
   %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer
   %1 = and <4 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8>
-  %hir.cmp.45 = icmp ne <4 x i32> %1, zeroinitializer
-  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
-  %3 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %a, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
-  %4 = fadd reassoc nsz arcp contract afn <4 x float> %2, %3
-  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %c, i32 4, <4 x i1> %hir.cmp.45)
+  %cmp.45 = icmp ne <4 x i32> %1, zeroinitializer
+  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %cmp.45, <4 x float> poison)
+  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %2, ptr %c, i32 4, <4 x i1> %cmp.45)
   ret void
 }
 
-define dso_local void @foo_4_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
-; CHECK-LABEL: foo_4_eq:
+define void @pr120593_4_eq(ptr %c, ptr %b, i32 %mask) {
+; CHECK-LABEL: pr120593_4_eq:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    kmovb %edx, %k0
 ; CHECK-NEXT:    knotb %k0, %k1
-; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
-; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
-; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups (%rsi), %xmm0 {%k1} {z}
 ; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
 ; CHECK-NEXT:    retq
 entry:
@@ -133,73 +94,37 @@ entry:
   %.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
   %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer
   %1 = and <4 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8>
-  %hir.cmp.45 = icmp eq <4 x i32> %1, zeroinitializer
-  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
-  %3 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %a, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
-  %4 = fadd reassoc nsz arcp contract afn <4 x float> %2, %3
-  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %c, i32 4, <4 x i1> %hir.cmp.45)
+  %cmp.45 = icmp eq <4 x i32> %1, zeroinitializer
+  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %cmp.45, <4 x float> poison)
+  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %2, ptr %c, i32 4, <4 x i1> %cmp.45)
   ret void
 }
 
-define dso_local void @foo_2_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
-; CHECK-LABEL: foo_2_ne:
+define <2 x i1> @pr120593_2_ne(i32 %mask) {
+; CHECK-LABEL: pr120593_2_ne:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %ecx, %k0
-; CHECK-NEXT:    kshiftlb $6, %k0, %k0
-; CHECK-NEXT:    kshiftrb $6, %k0, %k1
-; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
-; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
-; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
-; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    kmovb %edi, %k0
+; CHECK-NEXT:    vpmovm2q %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
   %.splatinsert = insertelement <2 x i32> poison, i32 %mask, i64 0
   %.splat = shufflevector <2 x i32> %.splatinsert, <2 x i32> poison, <2 x i32> zeroinitializer
   %0 = and <2 x i32> %.splat, <i32 1, i32 2>
-  %hir.cmp.44 = icmp ne <2 x i32> %0, zeroinitializer
-  %1 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %b, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
-  %2 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %a, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
-  %3 = fadd reassoc nsz arcp contract afn <2 x float> %1, %2
-  tail call void @llvm.masked.store.v2f32.p0(<2 x float> %3, ptr %c, i32 4, <2 x i1> %hir.cmp.44)
-  ret void
+  %cmp.44 = icmp ne <2 x i32> %0, zeroinitializer
+  ret <2 x i1> %cmp.44
 }
 
-define dso_local void @foo_2_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
-; CHECK-LABEL: foo_2_eq:
+define <2 x i1> @pr120593_2_eq(i32 %mask) {
+; CHECK-LABEL: pr120593_2_eq:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    kmovb %edi, %k0
 ; CHECK-NEXT:    knotb %k0, %k0
-; CHECK-NEXT:    kshiftlb $6, %k0, %k0
-; CHECK-NEXT:    kshiftrb $6, %k0, %k1
-; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
-; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
-; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
-; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    vpmovm2q %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
   %.splatinsert = insertelement <2 x i32> poison, i32 %mask, i64 0
   %.splat = shufflevector <2 x i32> %.splatinsert, <2 x i32> poison, <2 x i32> zeroinitializer
   %0 = and <2 x i32> %.splat, <i32 1, i32 2>
-  %hir.cmp.44 = icmp eq <2 x i32> %0, zeroinitializer
-  %1 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %b, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
-  %2 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %a, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
-  %3 = fadd reassoc nsz arcp contract afn <2 x float> %1, %2
-  tail call void @llvm.masked.store.v2f32.p0(<2 x float> %3, ptr %c, i32 4, <2 x i1> %hir.cmp.44)
-  ret void
+  %cmp.44 = icmp eq <2 x i32> %0, zeroinitializer
+  ret <2 x i1> %cmp.44
 }
-
-declare <2 x float> @llvm.masked.load.v2f32.p0(ptr nocapture, i32 immarg, <2 x i1>, <2 x float>) #1
-
-declare void @llvm.masked.store.v2f32.p0(<2 x float>, ptr nocapture, i32 immarg, <2 x i1>) #2
-
-declare <4 x float> @llvm.masked.load.v4f32.p0(ptr nocapture, i32 immarg, <4 x i1>, <4 x float>) #1
-
-declare void @llvm.masked.store.v4f32.p0(<4 x float>, ptr nocapture, i32 immarg, <4 x i1>) #2
-
-declare <8 x float> @llvm.masked.load.v8f32.p0(ptr nocapture, i32 immarg, <8 x i1>, <8 x float>)
-
-declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr nocapture, i32 immarg, <8 x i1>)
-
-declare <16 x float> @llvm.masked.load.v16f32.p0(ptr nocapture, i32 immarg, <16 x i1>, <16 x float>)
-
-declare void @llvm.masked.store.v16f32.p0(<16 x float>, ptr nocapture, i32 immarg, <16 x i1>)

>From 85b9945ef9cde51353f6e96bd68270b62c542d06 Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Tue, 24 Dec 2024 15:24:33 +0530
Subject: [PATCH 4/4] Combine to KMOV

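Form the KMOV in a DAG combine (combineSetCC in X86ISelLowering.cpp) instead of matching it during instruction selection (tryVPTESTMOrKMOV in X86ISelDAGToDAG.cpp).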
---
 llvm/lib/Target/X86/X86ISelDAGToDAG.cpp | 78 ++++----------------
 llvm/lib/Target/X86/X86ISelLowering.cpp | 95 +++++++++++++++++++++++++
 llvm/test/CodeGen/X86/kmov.ll           | 22 +++---
 llvm/test/CodeGen/X86/pr78897.ll        |  6 +-
 4 files changed, 122 insertions(+), 79 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index daf3e01506374b..9b340a778b36ad 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -592,7 +592,7 @@ namespace {
     bool matchVPTERNLOG(SDNode *Root, SDNode *ParentA, SDNode *ParentB,
                         SDNode *ParentC, SDValue A, SDValue B, SDValue C,
                         uint8_t Imm);
-    bool tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc, SDValue Mask);
+    bool tryVPTESTM(SDNode *Root, SDValue Setcc, SDValue Mask);
     bool tryMatchBitSelect(SDNode *N);
 
     MachineSDNode *emitPCMPISTR(unsigned ROpc, unsigned MOpc, bool MayFoldLoad,
@@ -4898,10 +4898,10 @@ VPTESTM_CASE(v32i16, WZ##SUFFIX)
 #undef VPTESTM_CASE
 }
 
-// Try to create VPTESTM or KMOV instruction. If InMask is not null, it will be
-// used to form a masked operation.
-bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
-                                       SDValue InMask) {
+// Try to create VPTESTM instruction. If InMask is not null, it will be used
+// to form a masked operation.
+bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
+                                 SDValue InMask) {
   assert(Subtarget->hasAVX512() && "Expected AVX512!");
   assert(Setcc.getSimpleValueType().getVectorElementType() == MVT::i1 &&
          "Unexpected VT!");
@@ -4976,69 +4976,12 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
     return tryFoldBroadcast(Root, P, L, Base, Scale, Index, Disp, Segment);
   };
 
-  auto canUseKMOV = [&]() {
-    if (Src0.getOpcode() != X86ISD::VBROADCAST &&
-        Src0.getOpcode() != X86ISD::VBROADCAST_LOAD)
-      return false;
-
-    if (Src1.getOpcode() != ISD::LOAD ||
-        Src1.getOperand(1).getOpcode() != X86ISD::Wrapper ||
-        Src1.getOperand(1).getOperand(0).getOpcode() != ISD::TargetConstantPool)
-      return false;
-
-    const auto *ConstPool =
-        dyn_cast<ConstantPoolSDNode>(Src1.getOperand(1).getOperand(0));
-    if (!ConstPool)
-      return false;
-
-    const auto *ConstVec = ConstPool->getConstVal();
-    const auto *ConstVecType = dyn_cast<FixedVectorType>(ConstVec->getType());
-    if (!ConstVecType)
-      return false;
-
-    for (unsigned I = 0, E = ConstVecType->getNumElements(); I != E; ++I) {
-      const auto *Element = ConstVec->getAggregateElement(I);
-      if (llvm::isa<llvm::UndefValue>(Element)) {
-        for (unsigned J = I + 1; J != E; ++J) {
-          if (!llvm::isa<llvm::UndefValue>(ConstVec->getAggregateElement(J)))
-            return false;
-        }
-        return I != 0;
-      }
-
-      if (Element->getUniqueInteger() != 1 << I)
-        return false;
-    }
-
-    return true;
-  };
-
   // We can only fold loads if the sources are unique.
   bool CanFoldLoads = Src0 != Src1;
 
   bool FoldedLoad = false;
   SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
-  SDLoc dl(Root);
-  bool IsTestN = CC == ISD::SETEQ;
-  MachineSDNode *CNode;
-  MVT ResVT = Setcc.getSimpleValueType();
   if (CanFoldLoads) {
-    if (canUseKMOV()) {
-      auto Op = Src0.getOpcode() == X86ISD::VBROADCAST ? Src0.getOperand(0)
-                                                       : Src0.getOperand(1);
-      if (Op.getSimpleValueType() == MVT::i8)
-        Op = SDValue(CurDAG->getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Op));
-      CNode = CurDAG->getMachineNode(
-          ResVT.getVectorNumElements() <= 8 ? X86::KMOVBkr : X86::KMOVWkr, dl,
-          ResVT, Op);
-      if (IsTestN)
-        CNode = CurDAG->getMachineNode(
-            ResVT.getVectorNumElements() <= 8 ? X86::KNOTBkk : X86::KNOTWkk, dl,
-            ResVT, SDValue(CNode, 0));
-      ReplaceUses(SDValue(Root, 0), SDValue(CNode, 0));
-      CurDAG->RemoveDeadNode(Root);
-      return true;
-    }
     FoldedLoad = tryFoldLoadOrBCast(Root, N0.getNode(), Src1, Tmp0, Tmp1, Tmp2,
                                     Tmp3, Tmp4);
     if (!FoldedLoad) {
@@ -5054,6 +4997,9 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
 
   bool IsMasked = InMask.getNode() != nullptr;
 
+  SDLoc dl(Root);
+
+  MVT ResVT = Setcc.getSimpleValueType();
   MVT MaskVT = ResVT;
   if (Widen) {
     // Widen the inputs using insert_subreg or copy_to_regclass.
@@ -5078,9 +5024,11 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
     }
   }
 
+  bool IsTestN = CC == ISD::SETEQ;
   unsigned Opc = getVPTESTMOpc(CmpVT, IsTestN, FoldedLoad, FoldedBCast,
                                IsMasked);
 
+  MachineSDNode *CNode;
   if (FoldedLoad) {
     SDVTList VTs = CurDAG->getVTList(MaskVT, MVT::Other);
 
@@ -5519,10 +5467,10 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
       SDValue N0 = Node->getOperand(0);
       SDValue N1 = Node->getOperand(1);
       if (N0.getOpcode() == ISD::SETCC && N0.hasOneUse() &&
-          tryVPTESTMOrKMOV(Node, N0, N1))
+          tryVPTESTM(Node, N0, N1))
         return;
       if (N1.getOpcode() == ISD::SETCC && N1.hasOneUse() &&
-          tryVPTESTMOrKMOV(Node, N1, N0))
+          tryVPTESTM(Node, N1, N0))
         return;
     }
 
@@ -6446,7 +6394,7 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
   }
 
   case ISD::SETCC: {
-    if (NVT.isVector() && tryVPTESTMOrKMOV(Node, SDValue(Node, 0), SDValue()))
+    if (NVT.isVector() && tryVPTESTM(Node, SDValue(Node, 0), SDValue()))
       return;
 
     break;
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2528ca553d3e97..e61bb46a683ec5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -55288,6 +55288,95 @@ static SDValue truncateAVX512SetCCNoBWI(EVT VT, EVT OpVT, SDValue LHS,
   return SDValue();
 }
 
+/// Try to fold
+///   (setcc (and (X86ISD::VBROADCAST X), (load C)), 0, setne/seteq)
+/// where C is a constant-pool vector <1 << N, 1 << (N+1), ..., undef, ...>
+/// into a direct move of the bits of (X >> N) into a mask register (KMOV),
+/// inverted for SETEQ. This removes the broadcast, the constant-pool load and
+/// the VPTESTM.
+///
+/// \param VT  The SETCC result type; only vXi1 with at most 16 lanes is
+///            handled.
+/// \param Op0 The SETCC operand that is compared against zero.
+/// \param CC  The SETCC condition; only SETNE and SETEQ are handled.
+/// \returns The replacement mask, or an empty SDValue if nothing matched.
+static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
+                                        const SDLoc &DL, SelectionDAG &DAG,
+                                        const X86Subtarget &Subtarget) {
+  if (CC != ISD::SETNE && CC != ISD::SETEQ)
+    return SDValue();
+
+  if (!Subtarget.hasAVX512())
+    return SDValue();
+
+  // The mask is materialized as v16i1 and narrowed with EXTRACT_SUBVECTOR,
+  // so only vXi1 results with at most 16 lanes can be produced here.
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i1 ||
+      VT.getVectorNumElements() > 16)
+    return SDValue();
+
+  if (Op0.getOpcode() != ISD::AND)
+    return SDValue();
+
+  // AND is commutative: accept the broadcast on either side.
+  SDValue Broadcast = Op0.getOperand(0);
+  SDValue Load = Op0.getOperand(1);
+  if (Broadcast.getOpcode() != X86ISD::VBROADCAST)
+    std::swap(Broadcast, Load);
+
+  // Only register-source broadcasts are handled. Operand 1 of an
+  // X86ISD::VBROADCAST_LOAD is the *address* of the scalar, not its value;
+  // using it as the mask bits would move the pointer into the k-register.
+  // TODO: Handle VBROADCAST_LOAD by emitting a scalar load of the element.
+  if (Broadcast.getOpcode() != X86ISD::VBROADCAST)
+    return SDValue();
+
+  // The bits are taken directly from the broadcast source, so it must be a
+  // scalar integer (not a vector source).
+  SDValue Scalar = Broadcast.getOperand(0);
+  if (!Scalar.getValueType().isScalarInteger())
+    return SDValue();
+
+  // The other AND operand must be a plain load of a constant-pool entry.
+  if (!ISD::isNormalLoad(Load.getNode()))
+    return SDValue();
+
+  SDValue Wrapper = cast<LoadSDNode>(Load)->getBasePtr();
+  if (Wrapper.getOpcode() != X86ISD::Wrapper &&
+      Wrapper.getOpcode() != X86ISD::WrapperRIP)
+    return SDValue();
+
+  const auto *ConstPool = dyn_cast<ConstantPoolSDNode>(Wrapper.getOperand(0));
+  if (!ConstPool || ConstPool->isMachineConstantPoolEntry() ||
+      ConstPool->getOffset() != 0)
+    return SDValue();
+
+  // Lane I of the constant must line up with lane I of the AND.
+  const Constant *ConstVec = ConstPool->getConstVal();
+  const auto *ConstVecType = dyn_cast<FixedVectorType>(ConstVec->getType());
+  EVT OpVT = Op0.getValueType();
+  if (!ConstVecType ||
+      ConstVecType->getNumElements() != OpVT.getVectorNumElements() ||
+      ConstVecType->getScalarSizeInBits() != OpVT.getScalarSizeInBits())
+    return SDValue();
+
+  unsigned NumElts = ConstVecType->getNumElements();
+  unsigned EltBits = ConstVecType->getScalarSizeInBits();
+
+  // Lane 0 must be a single set bit, 1 << N ...
+  const auto *First =
+      dyn_cast_or_null<ConstantInt>(ConstVec->getAggregateElement(0U));
+  if (!First || !First->getValue().isPowerOf2())
+    return SDValue();
+  unsigned N = First->getValue().logBase2();
+
+  // ... and lane I must be 1 << (N + I), optionally followed by a run of
+  // undef lanes reaching the end of the vector. The test is done on APInts:
+  // a plain `1 << (N + I)` overflows `int` for wide elements.
+  for (unsigned I = 1; I != NumElts; ++I) {
+    const Constant *Elt = ConstVec->getAggregateElement(I);
+    if (!Elt)
+      return SDValue();
+    if (isa<UndefValue>(Elt)) {
+      for (unsigned J = I + 1; J != NumElts; ++J)
+        if (!isa_and_nonnull<UndefValue>(ConstVec->getAggregateElement(J)))
+          return SDValue();
+      break;
+    }
+    const auto *CI = dyn_cast<ConstantInt>(Elt);
+    if (!CI || N + I >= EltBits || !CI->getValue().isOneBitSet(N + I))
+      return SDValue();
+  }
+
+  // Result lane I is bit (N + I) of the scalar, so shift those bits down to
+  // bit 0. No extra masking is needed: any other bits only reach lanes that
+  // are dropped by the i16 truncation / EXTRACT_SUBVECTOR below, or lanes
+  // whose constant is undef (whose result is therefore unconstrained).
+  EVT ScalarVT = Scalar.getValueType();
+  if (N != 0)
+    Scalar = DAG.getNode(ISD::SRL, DL, ScalarVT, Scalar,
+                         DAG.getShiftAmountConstant(N, ScalarVT, DL));
+
+  SDValue Mask =
+      DAG.getBitcast(MVT::v16i1, DAG.getAnyExtOrTrunc(Scalar, DL, MVT::i16));
+  if (CC == ISD::SETEQ)
+    Mask = DAG.getNOT(DL, Mask, MVT::v16i1);
+
+  if (VT != MVT::v16i1)
+    Mask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Mask,
+                       DAG.getVectorIdxConstant(0, DL));
+
+  return Mask;
+}
+
 static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
                             TargetLowering::DAGCombinerInfo &DCI,
                             const X86Subtarget &Subtarget) {
@@ -55420,6 +55509,12 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
              "Unexpected condition code!");
       return Op0.getOperand(0);
     }
+
+    if (IsVZero1) {
+      if (SDValue V =
+              combineAVX512SetCCToKMOV(VT, Op0, TmpCC, DL, DAG, Subtarget))
+        return V;
+    }
   }
 
   // Try and make unsigned vector comparison signed. On pre AVX512 targets there
diff --git a/llvm/test/CodeGen/X86/kmov.ll b/llvm/test/CodeGen/X86/kmov.ll
index f17a559012e676..ba39fc4d1af768 100644
--- a/llvm/test/CodeGen/X86/kmov.ll
+++ b/llvm/test/CodeGen/X86/kmov.ll
@@ -1,10 +1,10 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=skylake-avx512 | FileCheck %s
+; RUN: llc < %s -mtriple=x86_64-- -mcpu=x86-64-v4 | FileCheck %s
 
 define <16 x i1> @pr120593_16_ne(i32 %mask) {
 ; CHECK-LABEL: pr120593_16_ne:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovw %edi, %k0
+; CHECK-NEXT:    kmovd %edi, %k0
 ; CHECK-NEXT:    vpmovm2b %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
@@ -19,7 +19,7 @@ entry:
 define <16 x i1> @pr120593_16_eq(i32 %mask) {
 ; CHECK-LABEL: pr120593_16_eq:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovw %edi, %k0
+; CHECK-NEXT:    kmovd %edi, %k0
 ; CHECK-NEXT:    knotw %k0, %k0
 ; CHECK-NEXT:    vpmovm2b %k0, %xmm0
 ; CHECK-NEXT:    retq
@@ -35,7 +35,7 @@ entry:
 define <8 x i1> @pr120593_8_ne(i32 %mask) {
 ; CHECK-LABEL: pr120593_8_ne:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %edi, %k0
+; CHECK-NEXT:    kmovd %edi, %k0
 ; CHECK-NEXT:    vpmovm2w %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
@@ -50,7 +50,7 @@ entry:
 define <8 x i1> @pr120593_8_eq(i32 %mask) {
 ; CHECK-LABEL: pr120593_8_eq:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %edi, %k0
+; CHECK-NEXT:    kmovd %edi, %k0
 ; CHECK-NEXT:    knotb %k0, %k0
 ; CHECK-NEXT:    vpmovm2w %k0, %xmm0
 ; CHECK-NEXT:    retq
@@ -66,7 +66,7 @@ entry:
 define void @pr120593_4_ne(ptr %c, ptr %b, i32 %mask) {
 ; CHECK-LABEL: pr120593_4_ne:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %edx, %k1
+; CHECK-NEXT:    kmovd %edx, %k1
 ; CHECK-NEXT:    vmovups (%rsi), %xmm0 {%k1} {z}
 ; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
 ; CHECK-NEXT:    retq
@@ -84,8 +84,8 @@ entry:
 define void @pr120593_4_eq(ptr %c, ptr %b, i32 %mask) {
 ; CHECK-LABEL: pr120593_4_eq:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %edx, %k0
-; CHECK-NEXT:    knotb %k0, %k1
+; CHECK-NEXT:    kmovd %edx, %k0
+; CHECK-NEXT:    knotw %k0, %k1
 ; CHECK-NEXT:    vmovups (%rsi), %xmm0 {%k1} {z}
 ; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
 ; CHECK-NEXT:    retq
@@ -103,7 +103,7 @@ entry:
 define <2 x i1> @pr120593_2_ne(i32 %mask) {
 ; CHECK-LABEL: pr120593_2_ne:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %edi, %k0
+; CHECK-NEXT:    kmovd %edi, %k0
 ; CHECK-NEXT:    vpmovm2q %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
@@ -117,8 +117,8 @@ entry:
 define <2 x i1> @pr120593_2_eq(i32 %mask) {
 ; CHECK-LABEL: pr120593_2_eq:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    kmovb %edi, %k0
-; CHECK-NEXT:    knotb %k0, %k0
+; CHECK-NEXT:    kmovd %edi, %k0
+; CHECK-NEXT:    knotw %k0, %k0
 ; CHECK-NEXT:    vpmovm2q %k0, %xmm0
 ; CHECK-NEXT:    retq
 entry:
diff --git a/llvm/test/CodeGen/X86/pr78897.ll b/llvm/test/CodeGen/X86/pr78897.ll
index 0c4c03de5901ec..c3c597f4d79dec 100644
--- a/llvm/test/CodeGen/X86/pr78897.ll
+++ b/llvm/test/CodeGen/X86/pr78897.ll
@@ -223,8 +223,8 @@ define <16 x i8> @produceShuffleVectorForByte(i8 zeroext %0) nounwind {
 ; X86-AVX512-NEXT:    pushl %ebx
 ; X86-AVX512-NEXT:    pushl %edi
 ; X86-AVX512-NEXT:    pushl %esi
-; X86-AVX512-NEXT:    leal {{[0-9]+}}(%esp)
-; X86-AVX512-NEXT:    kmovw %eax, %k0
+; X86-AVX512-NEXT:    leal {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT:    kmovd %eax, %k0
 ; X86-AVX512-NEXT:    knotw %k0, %k1
 ; X86-AVX512-NEXT:    vmovdqu8 {{.*#+}} xmm0 {%k1} {z} = [17,17,17,17,17,17,17,17,u,u,u,u,u,u,u,u]
 ; X86-AVX512-NEXT:    vpextrd $1, %xmm0, %eax
@@ -257,7 +257,7 @@ define <16 x i8> @produceShuffleVectorForByte(i8 zeroext %0) nounwind {
 ;
 ; X64-AVX512-LABEL: produceShuffleVectorForByte:
 ; X64-AVX512:       # %bb.0: # %entry
-; X64-AVX512-NEXT:    kmovw %edi, %k0
+; X64-AVX512-NEXT:    kmovd %edi, %k0
 ; X64-AVX512-NEXT:    knotw %k0, %k1
 ; X64-AVX512-NEXT:    vmovdqu8 {{.*#+}} xmm0 {%k1} {z} = [17,17,17,17,17,17,17,17,u,u,u,u,u,u,u,u]
 ; X64-AVX512-NEXT:    vmovq %xmm0, %rax



More information about the llvm-commits mailing list