[llvm] Generate `kmov` for masking integers (PR #120593)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 19 23:01:53 PST 2024


https://github.com/abhishek-kaushik22 updated https://github.com/llvm/llvm-project/pull/120593

>From 822ae48049fdebc769269291868270314f30ca9a Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Thu, 19 Dec 2024 21:14:10 +0530
Subject: [PATCH 1/2] Generate `kmov` for masking integers

When an integer is used as a bit mask, the LLVM IR looks something like this:
```
%1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
  %cmp1 = icmp ne <16 x i32> %1, zeroinitializer
```
where `.splat` is a vector containing the mask in all lanes.
The assembly generated for this looks like:
```
vpbroadcastd    %ecx, %zmm0
vptestmd        .LCPI0_0(%rip), %zmm0, %k1
```
where `.LCPI0_0` is a constant table of powers of 2.
Instead of doing this, we could just move the relevant bits directly to a `k` register using a `kmov` instruction. This is faster and also reduces code size.
---
 llvm/lib/Target/X86/X86ISelDAGToDAG.cpp |  79 +++++++--
 llvm/test/CodeGen/X86/kmov.ll           | 205 ++++++++++++++++++++++++
 llvm/test/CodeGen/X86/pr78897.ll        |   4 +-
 3 files changed, 273 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/kmov.ll

diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index bb20e6ecf281b0..8c199a30dfbce7 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -592,7 +592,7 @@ namespace {
     bool matchVPTERNLOG(SDNode *Root, SDNode *ParentA, SDNode *ParentB,
                         SDNode *ParentC, SDValue A, SDValue B, SDValue C,
                         uint8_t Imm);
-    bool tryVPTESTM(SDNode *Root, SDValue Setcc, SDValue Mask);
+    bool tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc, SDValue Mask);
     bool tryMatchBitSelect(SDNode *N);
 
     MachineSDNode *emitPCMPISTR(unsigned ROpc, unsigned MOpc, bool MayFoldLoad,
@@ -4897,10 +4897,10 @@ VPTESTM_CASE(v32i16, WZ##SUFFIX)
 #undef VPTESTM_CASE
 }
 
-// Try to create VPTESTM instruction. If InMask is not null, it will be used
-// to form a masked operation.
-bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
-                                 SDValue InMask) {
+// Try to create VPTESTM or KMOV instruction. If InMask is not null, it will be
+// used to form a masked operation.
+bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
+                                       SDValue InMask) {
   assert(Subtarget->hasAVX512() && "Expected AVX512!");
   assert(Setcc.getSimpleValueType().getVectorElementType() == MVT::i1 &&
          "Unexpected VT!");
@@ -4975,12 +4975,70 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
     return tryFoldBroadcast(Root, P, L, Base, Scale, Index, Disp, Segment);
   };
 
+  auto canUseKMOV = [&]() {
+    if (Src0.getOpcode() != X86ISD::VBROADCAST)
+      return false;
+
+    if (Src1.getOpcode() != ISD::LOAD ||
+        Src1.getOperand(1).getOpcode() != X86ISD::Wrapper ||
+        Src1.getOperand(1).getOperand(0).getOpcode() != ISD::TargetConstantPool)
+      return false;
+
+    const auto *ConstPool =
+        dyn_cast<ConstantPoolSDNode>(Src1.getOperand(1).getOperand(0));
+    if (!ConstPool)
+      return false;
+
+    const auto *ConstVec = ConstPool->getConstVal();
+    const auto *ConstVecType = dyn_cast<FixedVectorType>(ConstVec->getType());
+    if (!ConstVecType)
+      return false;
+
+    for (unsigned i = 0, e = ConstVecType->getNumElements(), k = 1; i != e;
+         ++i, k *= 2) {
+      const auto *Element = ConstVec->getAggregateElement(i);
+      if (llvm::isa<llvm::UndefValue>(Element)) {
+        for (unsigned j = i + 1; j != e; ++j) {
+          if (!llvm::isa<llvm::UndefValue>(ConstVec->getAggregateElement(j)))
+            return false;
+        }
+        return i != 0;
+      }
+
+      if (Element->getUniqueInteger() != k) {
+        return false;
+      }
+    }
+
+    return true;
+  };
+
   // We can only fold loads if the sources are unique.
   bool CanFoldLoads = Src0 != Src1;
 
   bool FoldedLoad = false;
   SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
+  SDLoc dl(Root);
+  bool IsTestN = CC == ISD::SETEQ;
+  MachineSDNode *CNode;
+  MVT ResVT = Setcc.getSimpleValueType();
   if (CanFoldLoads) {
+    if (canUseKMOV()) {
+      auto Op = Src0.getOperand(0);
+      if (Op.getSimpleValueType() == MVT::i8) {
+        Op = SDValue(CurDAG->getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Op));
+      }
+      CNode = CurDAG->getMachineNode(
+          ResVT.getVectorNumElements() <= 8 ? X86::KMOVBkr : X86::KMOVWkr, dl,
+          ResVT, Op);
+      if (IsTestN)
+        CNode = CurDAG->getMachineNode(
+            ResVT.getVectorNumElements() <= 8 ? X86::KNOTBkk : X86::KNOTWkk, dl,
+            ResVT, SDValue(CNode, 0));
+      ReplaceUses(SDValue(Root, 0), SDValue(CNode, 0));
+      CurDAG->RemoveDeadNode(Root);
+      return true;
+    }
     FoldedLoad = tryFoldLoadOrBCast(Root, N0.getNode(), Src1, Tmp0, Tmp1, Tmp2,
                                     Tmp3, Tmp4);
     if (!FoldedLoad) {
@@ -4996,9 +5054,6 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
 
   bool IsMasked = InMask.getNode() != nullptr;
 
-  SDLoc dl(Root);
-
-  MVT ResVT = Setcc.getSimpleValueType();
   MVT MaskVT = ResVT;
   if (Widen) {
     // Widen the inputs using insert_subreg or copy_to_regclass.
@@ -5023,11 +5078,9 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
     }
   }
 
-  bool IsTestN = CC == ISD::SETEQ;
   unsigned Opc = getVPTESTMOpc(CmpVT, IsTestN, FoldedLoad, FoldedBCast,
                                IsMasked);
 
-  MachineSDNode *CNode;
   if (FoldedLoad) {
     SDVTList VTs = CurDAG->getVTList(MaskVT, MVT::Other);
 
@@ -5466,10 +5519,10 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
       SDValue N0 = Node->getOperand(0);
       SDValue N1 = Node->getOperand(1);
       if (N0.getOpcode() == ISD::SETCC && N0.hasOneUse() &&
-          tryVPTESTM(Node, N0, N1))
+          tryVPTESTMOrKMOV(Node, N0, N1))
         return;
       if (N1.getOpcode() == ISD::SETCC && N1.hasOneUse() &&
-          tryVPTESTM(Node, N1, N0))
+          tryVPTESTMOrKMOV(Node, N1, N0))
         return;
     }
 
@@ -6393,7 +6446,7 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
   }
 
   case ISD::SETCC: {
-    if (NVT.isVector() && tryVPTESTM(Node, SDValue(Node, 0), SDValue()))
+    if (NVT.isVector() && tryVPTESTMOrKMOV(Node, SDValue(Node, 0), SDValue()))
       return;
 
     break;
diff --git a/llvm/test/CodeGen/X86/kmov.ll b/llvm/test/CodeGen/X86/kmov.ll
new file mode 100644
index 00000000000000..6d72a8923c5ab3
--- /dev/null
+++ b/llvm/test/CodeGen/X86/kmov.ll
@@ -0,0 +1,205 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=skylake-avx512 | FileCheck %s
+
+define dso_local void @foo_16_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_16_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovw %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %zmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %zmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %zmm1, %zmm0, %zmm0
+; CHECK-NEXT:    vmovups %zmm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <16 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <16 x i32> %.splatinsert, <16 x i32> poison, <16 x i32> zeroinitializer
+  %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
+  %hir.cmp.45 = icmp ne <16 x i32> %1, zeroinitializer
+  %2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
+  tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
+  ret void
+}
+
+; Function Attrs: mustprogress nounwind uwtable
+define dso_local void @foo_16_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_16_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovw %ecx, %k0
+; CHECK-NEXT:    knotw %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %zmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %zmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %zmm1, %zmm0, %zmm0
+; CHECK-NEXT:    vmovups %zmm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <16 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <16 x i32> %.splatinsert, <16 x i32> poison, <16 x i32> zeroinitializer
+  %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
+  %hir.cmp.45 = icmp eq <16 x i32> %1, zeroinitializer
+  %2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
+  tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_8_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_8_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %ymm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %ymm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vmovups %ymm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <8 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <8 x i32> %.splatinsert, <8 x i32> poison, <8 x i32> zeroinitializer
+  %1 = and <8 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128>
+  %hir.cmp.45 = icmp ne <8 x i32> %1, zeroinitializer
+  %2 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %b, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %3 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %a, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <8 x float> %2, %3
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %4, ptr %c, i32 4, <8 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_8_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_8_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %ymm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %ymm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vmovups %ymm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <8 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <8 x i32> %.splatinsert, <8 x i32> poison, <8 x i32> zeroinitializer
+  %1 = and <8 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128>
+  %hir.cmp.45 = icmp eq <8 x i32> %1, zeroinitializer
+  %2 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %b, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %3 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %a, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <8 x float> %2, %3
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %4, ptr %c, i32 4, <8 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_4_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_4_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer
+  %1 = and <4 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8>
+  %hir.cmp.45 = icmp ne <4 x i32> %1, zeroinitializer
+  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %3 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %a, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <4 x float> %2, %3
+  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %c, i32 4, <4 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_4_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_4_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer
+  %1 = and <4 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8>
+  %hir.cmp.45 = icmp eq <4 x i32> %1, zeroinitializer
+  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %3 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %a, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <4 x float> %2, %3
+  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %c, i32 4, <4 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_2_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_2_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    kshiftlb $6, %k0, %k0
+; CHECK-NEXT:    kshiftrb $6, %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %.splatinsert = insertelement <2 x i32> poison, i32 %mask, i64 0
+  %.splat = shufflevector <2 x i32> %.splatinsert, <2 x i32> poison, <2 x i32> zeroinitializer
+  %0 = and <2 x i32> %.splat, <i32 1, i32 2>
+  %hir.cmp.44 = icmp ne <2 x i32> %0, zeroinitializer
+  %1 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %b, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %2 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %a, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %3 = fadd reassoc nsz arcp contract afn <2 x float> %1, %2
+  tail call void @llvm.masked.store.v2f32.p0(<2 x float> %3, ptr %c, i32 4, <2 x i1> %hir.cmp.44)
+  ret void
+}
+
+define dso_local void @foo_2_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_2_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k0
+; CHECK-NEXT:    kshiftlb $6, %k0, %k0
+; CHECK-NEXT:    kshiftrb $6, %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %.splatinsert = insertelement <2 x i32> poison, i32 %mask, i64 0
+  %.splat = shufflevector <2 x i32> %.splatinsert, <2 x i32> poison, <2 x i32> zeroinitializer
+  %0 = and <2 x i32> %.splat, <i32 1, i32 2>
+  %hir.cmp.44 = icmp eq <2 x i32> %0, zeroinitializer
+  %1 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %b, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %2 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %a, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %3 = fadd reassoc nsz arcp contract afn <2 x float> %1, %2
+  tail call void @llvm.masked.store.v2f32.p0(<2 x float> %3, ptr %c, i32 4, <2 x i1> %hir.cmp.44)
+  ret void
+}
+
+declare <2 x float> @llvm.masked.load.v2f32.p0(ptr nocapture, i32 immarg, <2 x i1>, <2 x float>) #1
+
+declare void @llvm.masked.store.v2f32.p0(<2 x float>, ptr nocapture, i32 immarg, <2 x i1>) #2
+
+declare <4 x float> @llvm.masked.load.v4f32.p0(ptr nocapture, i32 immarg, <4 x i1>, <4 x float>) #1
+
+declare void @llvm.masked.store.v4f32.p0(<4 x float>, ptr nocapture, i32 immarg, <4 x i1>) #2
+
+declare <8 x float> @llvm.masked.load.v8f32.p0(ptr nocapture, i32 immarg, <8 x i1>, <8 x float>)
+
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr nocapture, i32 immarg, <8 x i1>)
+
+declare <16 x float> @llvm.masked.load.v16f32.p0(ptr nocapture, i32 immarg, <16 x i1>, <16 x float>)
+
+declare void @llvm.masked.store.v16f32.p0(<16 x float>, ptr nocapture, i32 immarg, <16 x i1>)
diff --git a/llvm/test/CodeGen/X86/pr78897.ll b/llvm/test/CodeGen/X86/pr78897.ll
index 56e4ec2bc8ecbb..38a1800df956b5 100644
--- a/llvm/test/CodeGen/X86/pr78897.ll
+++ b/llvm/test/CodeGen/X86/pr78897.ll
@@ -256,8 +256,8 @@ define <16 x i8> @produceShuffleVectorForByte(i8 zeroext %0) nounwind {
 ;
 ; X64-AVX512-LABEL: produceShuffleVectorForByte:
 ; X64-AVX512:       # %bb.0: # %entry
-; X64-AVX512-NEXT:    vpbroadcastb %edi, %xmm0
-; X64-AVX512-NEXT:    vptestnmb {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %k1
+; X64-AVX512-NEXT:    kmovw %edi, %k0
+; X64-AVX512-NEXT:    knotw %k0, %k1
 ; X64-AVX512-NEXT:    vmovdqu8 {{.*#+}} xmm0 {%k1} {z} = [17,17,17,17,17,17,17,17,u,u,u,u,u,u,u,u]
 ; X64-AVX512-NEXT:    vmovq %xmm0, %rax
 ; X64-AVX512-NEXT:    movabsq $1229782938247303440, %rcx # imm = 0x1111111111111110

>From 3f39f655f46dc7f9f3a5d3218ecffcaeb782fccb Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Fri, 20 Dec 2024 12:29:36 +0530
Subject: [PATCH 2/2] Review Changes

---
 llvm/lib/Target/X86/X86ISelDAGToDAG.cpp | 23 +++++++++++------------
 llvm/test/CodeGen/X86/pr78897.ll        |  5 +++--
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 8c199a30dfbce7..054ff6743b9a5b 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4976,7 +4976,8 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
   };
 
   auto canUseKMOV = [&]() {
-    if (Src0.getOpcode() != X86ISD::VBROADCAST)
+    if (Src0.getOpcode() != X86ISD::VBROADCAST &&
+        Src0.getOpcode() != X86ISD::VBROADCAST_LOAD)
       return false;
 
     if (Src1.getOpcode() != ISD::LOAD ||
@@ -4994,20 +4995,18 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
     if (!ConstVecType)
       return false;
 
-    for (unsigned i = 0, e = ConstVecType->getNumElements(), k = 1; i != e;
-         ++i, k *= 2) {
-      const auto *Element = ConstVec->getAggregateElement(i);
+    for (unsigned I = 0, E = ConstVecType->getNumElements(); I != E; ++I) {
+      const auto *Element = ConstVec->getAggregateElement(I);
       if (llvm::isa<llvm::UndefValue>(Element)) {
-        for (unsigned j = i + 1; j != e; ++j) {
-          if (!llvm::isa<llvm::UndefValue>(ConstVec->getAggregateElement(j)))
+        for (unsigned J = I + 1; J != E; ++J) {
+          if (!llvm::isa<llvm::UndefValue>(ConstVec->getAggregateElement(J)))
             return false;
         }
-        return i != 0;
+        return I != 0;
       }
 
-      if (Element->getUniqueInteger() != k) {
+      if (Element->getUniqueInteger() != 1 << I)
         return false;
-      }
     }
 
     return true;
@@ -5024,10 +5023,10 @@ bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
   MVT ResVT = Setcc.getSimpleValueType();
   if (CanFoldLoads) {
     if (canUseKMOV()) {
-      auto Op = Src0.getOperand(0);
-      if (Op.getSimpleValueType() == MVT::i8) {
+      auto Op = Src0.getOpcode() == X86ISD::VBROADCAST ? Src0.getOperand(0)
+                                                       : Src0.getOperand(1);
+      if (Op.getSimpleValueType() == MVT::i8)
         Op = SDValue(CurDAG->getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Op));
-      }
       CNode = CurDAG->getMachineNode(
           ResVT.getVectorNumElements() <= 8 ? X86::KMOVBkr : X86::KMOVWkr, dl,
           ResVT, Op);
diff --git a/llvm/test/CodeGen/X86/pr78897.ll b/llvm/test/CodeGen/X86/pr78897.ll
index 38a1800df956b5..0c4c03de5901ec 100644
--- a/llvm/test/CodeGen/X86/pr78897.ll
+++ b/llvm/test/CodeGen/X86/pr78897.ll
@@ -223,8 +223,9 @@ define <16 x i8> @produceShuffleVectorForByte(i8 zeroext %0) nounwind {
 ; X86-AVX512-NEXT:    pushl %ebx
 ; X86-AVX512-NEXT:    pushl %edi
 ; X86-AVX512-NEXT:    pushl %esi
-; X86-AVX512-NEXT:    vpbroadcastb {{[0-9]+}}(%esp), %xmm0
-; X86-AVX512-NEXT:    vptestnmb {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %k1
+; X86-AVX512-NEXT:    leal {{[0-9]+}}(%esp)
+; X86-AVX512-NEXT:    kmovw %eax, %k0
+; X86-AVX512-NEXT:    knotw %k0, %k1
 ; X86-AVX512-NEXT:    vmovdqu8 {{.*#+}} xmm0 {%k1} {z} = [17,17,17,17,17,17,17,17,u,u,u,u,u,u,u,u]
 ; X86-AVX512-NEXT:    vpextrd $1, %xmm0, %eax
 ; X86-AVX512-NEXT:    vmovd %xmm0, %edx



More information about the llvm-commits mailing list