[llvm] [SelectionDAG][X86] Split via Concat <n x T> vector types for atomic load (PR #120640)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 2 12:25:43 PST 2025


https://github.com/jofrn updated https://github.com/llvm/llvm-project/pull/120640

>From 76e068350de93eee06b3cc4616adadf33e9d84c0 Mon Sep 17 00:00:00 2001
From: jofrn <jofernau at amd.com>
Date: Thu, 19 Dec 2024 16:25:55 -0500
Subject: [PATCH] [SelectionDAG][X86] Split via Concat <n x T> vector types for
 atomic load

Vector types that aren't widened are 'split' via CONCAT_VECTORS
so that a single ATOMIC_LOAD is issued for the entire vector at once.
This change utilizes the load vectorization infrastructure in
SelectionDAG in order to group the vectors. This enables SelectionDAG
to translate atomic loads of vectors whose element type is bfloat or half.

commit-id:3a045357
---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |   1 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  32 ++++
 llvm/test/CodeGen/X86/atomic-load-store.ll    | 171 ++++++++++++++++++
 3 files changed, 204 insertions(+)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 593d511f50f5c3..da7622db4bde23 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -948,6 +948,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_FPOp_MultiType(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_IS_FPCLASS(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD);
   void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD, SDValue &Lo,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index f794e4f35c8343..789866d31bf356 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1152,6 +1152,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SplitVecRes_STEP_VECTOR(N, Lo, Hi);
     break;
   case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break;
+  case ISD::ATOMIC_LOAD:
+    SplitVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N));
+    break;
   case ISD::LOAD:
     SplitVecRes_LOAD(cast<LoadSDNode>(N), Lo, Hi);
     break;
@@ -1395,6 +1398,35 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SetSplitVector(SDValue(N, ResNo), Lo, Hi);
 }
 
+/// Legalizes a vector ATOMIC_LOAD whose result type needs splitting. A single
+/// integer-vector ATOMIC_LOAD is emitted and the result rebuilt via CONCAT.
+void DAGTypeLegalizer::SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD) {
+  SDLoc dl(LD);
+  EVT MemoryVT = LD->getMemoryVT();
+  unsigned NumElts = MemoryVT.getVectorMinNumElements();
+  // Match the integer element width to the original element type.
+  EVT IntMemoryVT = MemoryVT.changeVectorElementTypeToInteger();
+  EVT IntEltVT = IntMemoryVT.getVectorElementType();
+  EVT ElemVT = EVT::getVectorVT(*DAG.getContext(),
+                                MemoryVT.getVectorElementType(), 1);
+
+  // Create a single atomic to load all the elements at once.
+  SDValue Atomic = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, IntMemoryVT, IntMemoryVT,
+                                 LD->getChain(), LD->getBasePtr(),
+                                 LD->getMemOperand());
+
+  // Bitcast each integer lane to a one-element vector and concatenate them.
+  SmallVector<SDValue, 4> Ops;
+  for (unsigned i = 0; i < NumElts; ++i)
+    Ops.push_back(DAG.getBitcast(
+        ElemVT, DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, IntEltVT, Atomic,
+                            DAG.getVectorIdxConstant(i, dl))));
+  ReplaceValueWith(SDValue(LD, 0),
+                   DAG.getNode(ISD::CONCAT_VECTORS, dl, MemoryVT, Ops));
+  // Chain users must depend on the new load, so hand them its chain result.
+  ReplaceValueWith(SDValue(LD, 1), Atomic.getValue(1));
+}
+
 void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
                                         MachinePointerInfo &MPI, SDValue &Ptr,
                                         uint64_t *ScaledOffset) {
diff --git a/llvm/test/CodeGen/X86/atomic-load-store.ll b/llvm/test/CodeGen/X86/atomic-load-store.ll
index 935d058a52f8fe..42b09558242934 100644
--- a/llvm/test/CodeGen/X86/atomic-load-store.ll
+++ b/llvm/test/CodeGen/X86/atomic-load-store.ll
@@ -204,6 +204,68 @@ define <2 x float> @atomic_vec2_float_align(ptr %x) {
   ret <2 x float> %ret
 }
 
+define <2 x half> @atomic_vec2_half(ptr %x) {
+; CHECK3-LABEL: atomic_vec2_half:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movl (%rdi), %eax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK3-NEXT:    shrl $16, %eax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec2_half:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movl (%rdi), %eax
+; CHECK0-NEXT:    movl %eax, %ecx
+; CHECK0-NEXT:    shrl $16, %ecx
+; CHECK0-NEXT:    movw %cx, %dx
+; CHECK0-NEXT:    ## implicit-def: $ecx
+; CHECK0-NEXT:    movw %dx, %cx
+; CHECK0-NEXT:    ## implicit-def: $xmm1
+; CHECK0-NEXT:    pinsrw $0, %ecx, %xmm1
+; CHECK0-NEXT:    movw %ax, %cx
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK0-NEXT:    retq
+  %ret = load atomic <2 x half>, ptr %x acquire, align 4
+  ret <2 x half> %ret
+}
+
+define <2 x bfloat> @atomic_vec2_bfloat(ptr %x) {
+; CHECK3-LABEL: atomic_vec2_bfloat:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movl (%rdi), %eax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK3-NEXT:    shrl $16, %eax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec2_bfloat:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movl (%rdi), %eax
+; CHECK0-NEXT:    movl %eax, %ecx
+; CHECK0-NEXT:    shrl $16, %ecx
+; CHECK0-NEXT:    ## kill: def $cx killed $cx killed $ecx
+; CHECK0-NEXT:    movw %ax, %dx
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %dx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm1
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK0-NEXT:    retq
+  %ret = load atomic <2 x bfloat>, ptr %x acquire, align 4
+  ret <2 x bfloat> %ret
+}
+
 define <1 x ptr> @atomic_vec1_ptr(ptr %x) nounwind {
 ; CHECK3-LABEL: atomic_vec1_ptr:
 ; CHECK3:       ## %bb.0:
@@ -376,6 +438,115 @@ define <4 x i16> @atomic_vec4_i16(ptr %x) nounwind {
   ret <4 x i16> %ret
 }
 
+define <4 x half> @atomic_vec4_half(ptr %x) nounwind {
+; CHECK3-LABEL: atomic_vec4_half:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movq (%rdi), %rax
+; CHECK3-NEXT:    movl %eax, %ecx
+; CHECK3-NEXT:    shrl $16, %ecx
+; CHECK3-NEXT:    pinsrw $0, %ecx, %xmm1
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK3-NEXT:    movq %rax, %rcx
+; CHECK3-NEXT:    shrq $32, %rcx
+; CHECK3-NEXT:    pinsrw $0, %ecx, %xmm2
+; CHECK3-NEXT:    shrq $48, %rax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm3
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm3[0],xmm2[1],xmm3[1],xmm2[2],xmm3[2],xmm2[3],xmm3[3]
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK3-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec4_half:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movq (%rdi), %rax
+; CHECK0-NEXT:    movl %eax, %ecx
+; CHECK0-NEXT:    shrl $16, %ecx
+; CHECK0-NEXT:    movw %cx, %dx
+; CHECK0-NEXT:    ## implicit-def: $ecx
+; CHECK0-NEXT:    movw %dx, %cx
+; CHECK0-NEXT:    ## implicit-def: $xmm2
+; CHECK0-NEXT:    pinsrw $0, %ecx, %xmm2
+; CHECK0-NEXT:    movw %ax, %dx
+; CHECK0-NEXT:    ## implicit-def: $ecx
+; CHECK0-NEXT:    movw %dx, %cx
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %ecx, %xmm0
+; CHECK0-NEXT:    movq %rax, %rcx
+; CHECK0-NEXT:    shrq $32, %rcx
+; CHECK0-NEXT:    movw %cx, %dx
+; CHECK0-NEXT:    ## implicit-def: $ecx
+; CHECK0-NEXT:    movw %dx, %cx
+; CHECK0-NEXT:    ## implicit-def: $xmm1
+; CHECK0-NEXT:    pinsrw $0, %ecx, %xmm1
+; CHECK0-NEXT:    shrq $48, %rax
+; CHECK0-NEXT:    movw %ax, %cx
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm3
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm3
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm3[0],xmm1[1],xmm3[1],xmm1[2],xmm3[2],xmm1[3],xmm3[3]
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; CHECK0-NEXT:    unpcklps {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; CHECK0-NEXT:    retq
+  %ret = load atomic <4 x half>, ptr %x acquire, align 8
+  ret <4 x half> %ret
+}
+
+define <4 x bfloat> @atomic_vec4_bfloat(ptr %x) nounwind {
+; CHECK3-LABEL: atomic_vec4_bfloat:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movq (%rdi), %rax
+; CHECK3-NEXT:    movq %rax, %rcx
+; CHECK3-NEXT:    movq %rax, %rdx
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK3-NEXT:    ## kill: def $eax killed $eax killed $rax
+; CHECK3-NEXT:    shrl $16, %eax
+; CHECK3-NEXT:    shrq $32, %rcx
+; CHECK3-NEXT:    shrq $48, %rdx
+; CHECK3-NEXT:    pinsrw $0, %edx, %xmm1
+; CHECK3-NEXT:    pinsrw $0, %ecx, %xmm2
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK3-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec4_bfloat:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movq (%rdi), %rax
+; CHECK0-NEXT:    movl %eax, %ecx
+; CHECK0-NEXT:    shrl $16, %ecx
+; CHECK0-NEXT:    ## kill: def $cx killed $cx killed $ecx
+; CHECK0-NEXT:    movw %ax, %dx
+; CHECK0-NEXT:    movq %rax, %rsi
+; CHECK0-NEXT:    shrq $32, %rsi
+; CHECK0-NEXT:    ## kill: def $si killed $si killed $rsi
+; CHECK0-NEXT:    shrq $48, %rax
+; CHECK0-NEXT:    movw %ax, %di
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %di, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %si, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm1
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %dx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm2
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm2
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; CHECK0-NEXT:    unpcklps {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; CHECK0-NEXT:    retq
+  %ret = load atomic <4 x bfloat>, ptr %x acquire, align 8
+  ret <4 x bfloat> %ret
+}
+
 define <4 x float> @atomic_vec4_float_align(ptr %x) nounwind {
 ; CHECK-LABEL: atomic_vec4_float_align:
 ; CHECK:       ## %bb.0:



More information about the llvm-commits mailing list