[llvm-branch-commits] [llvm] [SelectionDAG] Split vector types for atomic load (PR #120640)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue May 6 08:04:29 PDT 2025


https://github.com/jofrn updated https://github.com/llvm/llvm-project/pull/120640

>From cd4402a7d285dbe90539e4cf13680cf773bdfdba Mon Sep 17 00:00:00 2001
From: jofrn <jofernau at amd.com>
Date: Thu, 19 Dec 2024 16:25:55 -0500
Subject: [PATCH] [SelectionDAG] Split vector types for atomic load

Vector types that aren't widened are split
so that a single ATOMIC_LOAD is issued for the entire vector at once.
This change utilizes the load vectorization infrastructure in
SelectionDAG in order to group the vectors. This enables SelectionDAG
to translate vectors with element type bfloat or half.

commit-id:3a045357
---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |   1 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  39 ++++
 llvm/test/CodeGen/X86/atomic-load-store.ll    | 171 ++++++++++++++++++
 3 files changed, 211 insertions(+)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index bdfa5f7741ad3..d8f402f529632 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -960,6 +960,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_FPOp_MultiType(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_IS_FPCLASS(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD, SDValue &Lo,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 6b3467573a0a2..d321df77fbce0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1172,6 +1172,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SplitVecRes_STEP_VECTOR(N, Lo, Hi);
     break;
   case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break;
+  case ISD::ATOMIC_LOAD:
+    SplitVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N), Lo, Hi);
+    break;
   case ISD::LOAD:
     SplitVecRes_LOAD(cast<LoadSDNode>(N), Lo, Hi);
     break;
@@ -1421,6 +1424,42 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SetSplitVector(SDValue(N, ResNo), Lo, Hi);
 }
 
+void DAGTypeLegalizer::SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD, SDValue &Lo,
+                                               SDValue &Hi) { // Load once as an integer, then carve out Lo/Hi.
+  EVT LoVT, HiVT; // Result types of the two halves written to Lo and Hi.
+  SDLoc dl(LD);
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(LD->getValueType(0));
+
+  ISD::LoadExtType ExtType = LD->getExtensionType(); // Forwarded unchanged to the new load.
+  SDValue Ch = LD->getChain();
+  SDValue Ptr = LD->getBasePtr();
+  EVT MemoryVT = LD->getMemoryVT();
+
+  EVT LoMemVT, HiMemVT; // Only LoMemVT is read below; HiMemVT is never used.
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  EVT IntVT = // Integer type as wide as the whole original result value.
+      EVT::getIntegerVT(*DAG.getContext(), LD->getValueType(0).getSizeInBits());
+  EVT MemIntVT = // Twice LoMemVT's width. NOTE(review): presumes equal halves - confirm.
+      EVT::getIntegerVT(*DAG.getContext(), 2 * LoMemVT.getSizeInBits());
+  SDValue ALD = DAG.getAtomicLoad(ExtType, dl, MemIntVT, IntVT, Ch, Ptr,
+                                  LD->getMemOperand()); // Reuses LD's chain, pointer and memory operand.
+
+  EVT LoIntVT = EVT::getIntegerVT(*DAG.getContext(), LoVT.getSizeInBits()); // Integer of LoVT's width.
+  EVT HiIntVT = EVT::getIntegerVT(*DAG.getContext(), HiVT.getSizeInBits()); // Integer of HiVT's width.
+  SDValue ExtractLo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, LoIntVT, ALD,
+                                  DAG.getIntPtrConstant(0, dl)); // Part index 0 feeds Lo.
+  SDValue ExtractHi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiIntVT, ALD,
+                                  DAG.getIntPtrConstant(1, dl)); // Part index 1 feeds Hi.
+
+  Lo = DAG.getBitcast(LoVT, ExtractLo); // Reinterpret each integer piece as its vector half.
+  Hi = DAG.getBitcast(HiVT, ExtractHi);
+
+  // Legalize the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(LD, 1), ALD.getValue(1)); // Value 1 of both nodes is used as the chain here.
+}
+
 void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
                                         MachinePointerInfo &MPI, SDValue &Ptr,
                                         uint64_t *ScaledOffset) {
diff --git a/llvm/test/CodeGen/X86/atomic-load-store.ll b/llvm/test/CodeGen/X86/atomic-load-store.ll
index 935d058a52f8f..42b0955824293 100644
--- a/llvm/test/CodeGen/X86/atomic-load-store.ll
+++ b/llvm/test/CodeGen/X86/atomic-load-store.ll
@@ -204,6 +204,68 @@ define <2 x float> @atomic_vec2_float_align(ptr %x) {
   ret <2 x float> %ret
 }
 
+define <2 x half> @atomic_vec2_half(ptr %x) {
+; CHECK3-LABEL: atomic_vec2_half:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movl (%rdi), %eax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK3-NEXT:    shrl $16, %eax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec2_half:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movl (%rdi), %eax
+; CHECK0-NEXT:    movl %eax, %ecx
+; CHECK0-NEXT:    shrl $16, %ecx
+; CHECK0-NEXT:    movw %cx, %dx
+; CHECK0-NEXT:    ## implicit-def: $ecx
+; CHECK0-NEXT:    movw %dx, %cx
+; CHECK0-NEXT:    ## implicit-def: $xmm1
+; CHECK0-NEXT:    pinsrw $0, %ecx, %xmm1
+; CHECK0-NEXT:    movw %ax, %cx
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK0-NEXT:    retq
+  %ret = load atomic <2 x half>, ptr %x acquire, align 4
+  ret <2 x half> %ret
+}
+
+define <2 x bfloat> @atomic_vec2_bfloat(ptr %x) {
+; CHECK3-LABEL: atomic_vec2_bfloat:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movl (%rdi), %eax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK3-NEXT:    shrl $16, %eax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec2_bfloat:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movl (%rdi), %eax
+; CHECK0-NEXT:    movl %eax, %ecx
+; CHECK0-NEXT:    shrl $16, %ecx
+; CHECK0-NEXT:    ## kill: def $cx killed $cx killed $ecx
+; CHECK0-NEXT:    movw %ax, %dx
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %dx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm1
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK0-NEXT:    retq
+  %ret = load atomic <2 x bfloat>, ptr %x acquire, align 4
+  ret <2 x bfloat> %ret
+}
+
 define <1 x ptr> @atomic_vec1_ptr(ptr %x) nounwind {
 ; CHECK3-LABEL: atomic_vec1_ptr:
 ; CHECK3:       ## %bb.0:
@@ -376,6 +438,115 @@ define <4 x i16> @atomic_vec4_i16(ptr %x) nounwind {
   ret <4 x i16> %ret
 }
 
+define <4 x half> @atomic_vec4_half(ptr %x) nounwind {
+; CHECK3-LABEL: atomic_vec4_half:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movq (%rdi), %rax
+; CHECK3-NEXT:    movl %eax, %ecx
+; CHECK3-NEXT:    shrl $16, %ecx
+; CHECK3-NEXT:    pinsrw $0, %ecx, %xmm1
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK3-NEXT:    movq %rax, %rcx
+; CHECK3-NEXT:    shrq $32, %rcx
+; CHECK3-NEXT:    pinsrw $0, %ecx, %xmm2
+; CHECK3-NEXT:    shrq $48, %rax
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm3
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm3[0],xmm2[1],xmm3[1],xmm2[2],xmm3[2],xmm2[3],xmm3[3]
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK3-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec4_half:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movq (%rdi), %rax
+; CHECK0-NEXT:    movl %eax, %ecx
+; CHECK0-NEXT:    shrl $16, %ecx
+; CHECK0-NEXT:    movw %cx, %dx
+; CHECK0-NEXT:    ## implicit-def: $ecx
+; CHECK0-NEXT:    movw %dx, %cx
+; CHECK0-NEXT:    ## implicit-def: $xmm2
+; CHECK0-NEXT:    pinsrw $0, %ecx, %xmm2
+; CHECK0-NEXT:    movw %ax, %dx
+; CHECK0-NEXT:    ## implicit-def: $ecx
+; CHECK0-NEXT:    movw %dx, %cx
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %ecx, %xmm0
+; CHECK0-NEXT:    movq %rax, %rcx
+; CHECK0-NEXT:    shrq $32, %rcx
+; CHECK0-NEXT:    movw %cx, %dx
+; CHECK0-NEXT:    ## implicit-def: $ecx
+; CHECK0-NEXT:    movw %dx, %cx
+; CHECK0-NEXT:    ## implicit-def: $xmm1
+; CHECK0-NEXT:    pinsrw $0, %ecx, %xmm1
+; CHECK0-NEXT:    shrq $48, %rax
+; CHECK0-NEXT:    movw %ax, %cx
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm3
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm3
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm3[0],xmm1[1],xmm3[1],xmm1[2],xmm3[2],xmm1[3],xmm3[3]
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; CHECK0-NEXT:    unpcklps {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; CHECK0-NEXT:    retq
+  %ret = load atomic <4 x half>, ptr %x acquire, align 8
+  ret <4 x half> %ret
+}
+
+define <4 x bfloat> @atomic_vec4_bfloat(ptr %x) nounwind {
+; CHECK3-LABEL: atomic_vec4_bfloat:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movq (%rdi), %rax
+; CHECK3-NEXT:    movq %rax, %rcx
+; CHECK3-NEXT:    movq %rax, %rdx
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK3-NEXT:    ## kill: def $eax killed $eax killed $rax
+; CHECK3-NEXT:    shrl $16, %eax
+; CHECK3-NEXT:    shrq $32, %rcx
+; CHECK3-NEXT:    shrq $48, %rdx
+; CHECK3-NEXT:    pinsrw $0, %edx, %xmm1
+; CHECK3-NEXT:    pinsrw $0, %ecx, %xmm2
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
+; CHECK3-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK3-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; CHECK3-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec4_bfloat:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movq (%rdi), %rax
+; CHECK0-NEXT:    movl %eax, %ecx
+; CHECK0-NEXT:    shrl $16, %ecx
+; CHECK0-NEXT:    ## kill: def $cx killed $cx killed $ecx
+; CHECK0-NEXT:    movw %ax, %dx
+; CHECK0-NEXT:    movq %rax, %rsi
+; CHECK0-NEXT:    shrq $32, %rsi
+; CHECK0-NEXT:    ## kill: def $si killed $si killed $rsi
+; CHECK0-NEXT:    shrq $48, %rax
+; CHECK0-NEXT:    movw %ax, %di
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %di, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %si, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm1
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm1
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %dx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm0
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm0
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    ## implicit-def: $xmm2
+; CHECK0-NEXT:    pinsrw $0, %eax, %xmm2
+; CHECK0-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; CHECK0-NEXT:    unpcklps {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; CHECK0-NEXT:    retq
+  %ret = load atomic <4 x bfloat>, ptr %x acquire, align 8
+  ret <4 x bfloat> %ret
+}
+
 define <4 x float> @atomic_vec4_float_align(ptr %x) nounwind {
 ; CHECK-LABEL: atomic_vec4_float_align:
 ; CHECK:       ## %bb.0:



More information about the llvm-branch-commits mailing list