[llvm] 7e5d7d2 - [NVPTX] Correctly lower extending loads for fp16 vectors.

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 23 10:46:23 PDT 2023


Author: Artem Belevich
Date: 2023-06-23T10:45:49-07:00
New Revision: 7e5d7d208f5b3efb1b645312c70dbca56405b58a

URL: https://github.com/llvm/llvm-project/commit/7e5d7d208f5b3efb1b645312c70dbca56405b58a
DIFF: https://github.com/llvm/llvm-project/commit/7e5d7d208f5b3efb1b645312c70dbca56405b58a.diff

LOG: [NVPTX] Correctly lower extending loads for fp16 vectors.

Fixes https://github.com/llvm/llvm-project/issues/63436

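Also improves lowering of extending FP vector loads: the alignment check in
ReplaceLoadVector now uses the in-memory vector type instead of the wider
result type, so these loads are no longer split unnecessarily.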

Differential Revision: https://reviews.llvm.org/D153477

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
    llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/test/CodeGen/NVPTX/vector-loads.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index e38fa9928802e..cb8a1867c44f0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1599,13 +1599,13 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   EVT OrigType = N->getValueType(0);
   LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
 
-  if (OrigType != EltVT && LdNode) {
+  if (OrigType != EltVT &&
+      (LdNode || (OrigType.isFloatingPoint() && EltVT.isFloatingPoint()))) {
     // We have an extending-load. The instruction we selected operates on the
     // smaller type, but the SDNode we are replacing has the larger type. We
     // need to emit a CVT to make the types match.
-    bool IsSigned = LdNode->getExtensionType() == ISD::SEXTLOAD;
-    unsigned CvtOpc = GetConvertOpcode(OrigType.getSimpleVT(),
-                                       EltVT.getSimpleVT(), IsSigned);
+    unsigned CvtOpc =
+        GetConvertOpcode(OrigType.getSimpleVT(), EltVT.getSimpleVT(), LdNode);
 
     // For each output value, apply the manual sign/zero-extension and make sure
     // all users of the load go through that CVT.
@@ -3601,7 +3601,8 @@ bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
 /// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
 /// conversion from \p SrcTy to \p DestTy.
 unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
-                                             bool IsSigned) {
+                                             LoadSDNode *LdNode) {
+  bool IsSigned = LdNode && LdNode->getExtensionType() == ISD::SEXTLOAD;
   switch (SrcTy.SimpleTy) {
   default:
     llvm_unreachable("Unhandled source type");
@@ -3649,5 +3650,14 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
     case MVT::i32:
       return IsSigned ? NVPTX::CVT_s32_s64 : NVPTX::CVT_u32_u64;
     }
+  case MVT::f16:
+    switch (DestTy.SimpleTy) {
+    default:
+      llvm_unreachable("Unhandled dest type");
+    case MVT::f32:
+      return NVPTX::CVT_f32_f16;
+    case MVT::f64:
+      return NVPTX::CVT_f64_f16;
+    }
   }
 }

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 746a9de5a2019..2a8ee5089ca02 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -97,7 +97,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
 
   bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
 
-  static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, bool IsSigned);
+  static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);
 };
 } // end namespace llvm
 

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c46ed2111258c..5c16e34660c71 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5146,7 +5146,8 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
 
   Align Alignment = LD->getAlign();
   auto &TD = DAG.getDataLayout();
-  Align PrefAlign = TD.getPrefTypeAlign(ResVT.getTypeForEVT(*DAG.getContext()));
+  Align PrefAlign =
+      TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
   if (Alignment < PrefAlign) {
     // This load is not sufficiently aligned, so bail out and let this vector
     // load be scalarized.  Note that we may still be able to emit smaller

diff  --git a/llvm/test/CodeGen/NVPTX/vector-loads.ll b/llvm/test/CodeGen/NVPTX/vector-loads.ll
index b700fb5e8f34c..850cef1b83de7 100644
--- a/llvm/test/CodeGen/NVPTX/vector-loads.ll
+++ b/llvm/test/CodeGen/NVPTX/vector-loads.ll
@@ -97,5 +97,58 @@ define void @foo_complex(ptr nocapture readonly align 16 dereferenceable(1342177
   ret void
 }
 
+; CHECK-LABEL: extv8f16_global_a16(
+define void @extv8f16_global_a16(ptr addrspace(1) noalias align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
+; CHECK: ld.global.v4.b16 {%f
+; CHECK: ld.global.v4.b16 {%f
+  %v = load <8 x half>, ptr addrspace(1) %src, align 16 ; 16-byte aligned source: expect two v4 loads
+  %ext = fpext <8 x half> %v to <8 x float>
+; CHECK: st.global.v4.f32
+; CHECK: st.global.v4.f32
+  store <8 x float> %ext, ptr addrspace(1) %dst, align 16 ; %dst is written here, so it must not be marked readonly
+  ret void
+}
+
+; CHECK-LABEL: extv8f16_global_a4(
+define void @extv8f16_global_a4(ptr addrspace(1) noalias align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
+; CHECK: ld.global.v2.b16 {%f
+; CHECK: ld.global.v2.b16 {%f
+; CHECK: ld.global.v2.b16 {%f
+; CHECK: ld.global.v2.b16 {%f
+  %v = load <8 x half>, ptr addrspace(1) %src, align 4 ; only 4-byte aligned: expect four v2 loads
+  %ext = fpext <8 x half> %v to <8 x float>
+; CHECK: st.global.v4.f32
+; CHECK: st.global.v4.f32
+  store <8 x float> %ext, ptr addrspace(1) %dst, align 16 ; %dst is written here, so it must not be marked readonly
+  ret void
+}
+
+
+; CHECK-LABEL: extv8f16_generic_a16(
+define void @extv8f16_generic_a16(ptr noalias align 16 %dst, ptr noalias readonly align 16 %src) #0 {
+; CHECK: ld.v4.b16 {%f
+; CHECK: ld.v4.b16 {%f
+  %v = load <8 x half>, ptr %src, align 16 ; 16-byte aligned source: expect two v4 loads
+  %ext = fpext <8 x half> %v to <8 x float>
+; CHECK: st.v4.f32
+; CHECK: st.v4.f32
+  store <8 x float> %ext, ptr %dst, align 16 ; %dst is written here, so it must not be marked readonly
+  ret void
+}
+
+; CHECK-LABEL: extv8f16_generic_a4(
+define void @extv8f16_generic_a4(ptr noalias align 16 %dst, ptr noalias readonly align 16 %src) #0 {
+; CHECK: ld.v2.b16 {%f
+; CHECK: ld.v2.b16 {%f
+; CHECK: ld.v2.b16 {%f
+; CHECK: ld.v2.b16 {%f
+  %v = load <8 x half>, ptr %src, align 4 ; only 4-byte aligned: expect four v2 loads
+  %ext = fpext <8 x half> %v to <8 x float>
+; CHECK: st.v4.f32
+; CHECK: st.v4.f32
+  store <8 x float> %ext, ptr %dst, align 16 ; %dst is written here, so it must not be marked readonly
+  ret void
+}
+
 
 !1 = !{i32 0, i32 64}


        


More information about the llvm-commits mailing list