[llvm] 7e5d7d2 - [NVPTX] Correctly lower extending loads for fp16 vectors.
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 23 10:46:23 PDT 2023
Author: Artem Belevich
Date: 2023-06-23T10:45:49-07:00
New Revision: 7e5d7d208f5b3efb1b645312c70dbca56405b58a
URL: https://github.com/llvm/llvm-project/commit/7e5d7d208f5b3efb1b645312c70dbca56405b58a
DIFF: https://github.com/llvm/llvm-project/commit/7e5d7d208f5b3efb1b645312c70dbca56405b58a.diff
LOG: [NVPTX] Correctly lower extending loads for fp16 vectors.
Fixes https://github.com/llvm/llvm-project/issues/63436
Improves lowering of extending FP vector loads. We were previously splitting
them unnecessarily, because the preferred-alignment check was done against the
extended result type rather than the in-memory type of the load.
Differential Revision: https://reviews.llvm.org/D153477
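
As a rough illustration (not part of the commit; the function and value names
here are made up, and the committed tests in vector-loads.ll below cover the
same shapes), the affected pattern is an fpext of a vector load of half:

  define void @fpext_v8f16_example(ptr addrspace(1) %out, ptr addrspace(1) %in) {
    ; 16-byte-aligned load of <8 x half>, widened to <8 x float>
    %v = load <8 x half>, ptr addrspace(1) %in, align 16
    %e = fpext <8 x half> %v to <8 x float>
    store <8 x float> %e, ptr addrspace(1) %out, align 32
    ret void
  }

With the old check, a load like this could be considered under-aligned (the
preferred alignment of the <8 x float> result is larger than that of the
<8 x half> actually in memory) and would be scalarized instead of staying a
vector load.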
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/test/CodeGen/NVPTX/vector-loads.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index e38fa9928802e..cb8a1867c44f0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1599,13 +1599,13 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
EVT OrigType = N->getValueType(0);
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
- if (OrigType != EltVT && LdNode) {
+ if (OrigType != EltVT &&
+ (LdNode || (OrigType.isFloatingPoint() && EltVT.isFloatingPoint()))) {
// We have an extending-load. The instruction we selected operates on the
// smaller type, but the SDNode we are replacing has the larger type. We
// need to emit a CVT to make the types match.
- bool IsSigned = LdNode->getExtensionType() == ISD::SEXTLOAD;
- unsigned CvtOpc = GetConvertOpcode(OrigType.getSimpleVT(),
- EltVT.getSimpleVT(), IsSigned);
+ unsigned CvtOpc =
+ GetConvertOpcode(OrigType.getSimpleVT(), EltVT.getSimpleVT(), LdNode);
// For each output value, apply the manual sign/zero-extension and make sure
// all users of the load go through that CVT.
@@ -3601,7 +3601,8 @@ bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
/// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
/// conversion from \p SrcTy to \p DestTy.
unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
- bool IsSigned) {
+ LoadSDNode *LdNode) {
+ bool IsSigned = LdNode && LdNode->getExtensionType() == ISD::SEXTLOAD;
switch (SrcTy.SimpleTy) {
default:
llvm_unreachable("Unhandled source type");
@@ -3649,5 +3650,14 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
case MVT::i32:
return IsSigned ? NVPTX::CVT_s32_s64 : NVPTX::CVT_u32_u64;
}
+ case MVT::f16:
+ switch (DestTy.SimpleTy) {
+ default:
+ llvm_unreachable("Unhandled dest type");
+ case MVT::f32:
+ return NVPTX::CVT_f32_f16;
+ case MVT::f64:
+ return NVPTX::CVT_f64_f16;
+ }
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 746a9de5a2019..2a8ee5089ca02 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -97,7 +97,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
- static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, bool IsSigned);
+ static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);
};
} // end namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c46ed2111258c..5c16e34660c71 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5146,7 +5146,8 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
Align Alignment = LD->getAlign();
auto &TD = DAG.getDataLayout();
- Align PrefAlign = TD.getPrefTypeAlign(ResVT.getTypeForEVT(*DAG.getContext()));
+ Align PrefAlign =
+ TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
if (Alignment < PrefAlign) {
// This load is not sufficiently aligned, so bail out and let this vector
// load be scalarized. Note that we may still be able to emit smaller
diff --git a/llvm/test/CodeGen/NVPTX/vector-loads.ll b/llvm/test/CodeGen/NVPTX/vector-loads.ll
index b700fb5e8f34c..850cef1b83de7 100644
--- a/llvm/test/CodeGen/NVPTX/vector-loads.ll
+++ b/llvm/test/CodeGen/NVPTX/vector-loads.ll
@@ -97,5 +97,58 @@ define void @foo_complex(ptr nocapture readonly align 16 dereferenceable(1342177
ret void
}
+; CHECK-LABEL: extv8f16_global_a16(
+define void @extv8f16_global_a16(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
+; CHECK: ld.global.v4.b16 {%f
+; CHECK: ld.global.v4.b16 {%f
+ %v = load <8 x half>, ptr addrspace(1) %src, align 16
+ %ext = fpext <8 x half> %v to <8 x float>
+; CHECK: st.global.v4.f32
+; CHECK: st.global.v4.f32
+ store <8 x float> %ext, ptr addrspace(1) %dst, align 16
+ ret void
+}
+
+; CHECK-LABEL: extv8f16_global_a4(
+define void @extv8f16_global_a4(ptr addrspace(1) noalias readonly align 16 %dst, ptr addrspace(1) noalias readonly align 16 %src) #0 {
+; CHECK: ld.global.v2.b16 {%f
+; CHECK: ld.global.v2.b16 {%f
+; CHECK: ld.global.v2.b16 {%f
+; CHECK: ld.global.v2.b16 {%f
+ %v = load <8 x half>, ptr addrspace(1) %src, align 4
+ %ext = fpext <8 x half> %v to <8 x float>
+; CHECK: st.global.v4.f32
+; CHECK: st.global.v4.f32
+ store <8 x float> %ext, ptr addrspace(1) %dst, align 16
+ ret void
+}
+
+
+; CHECK-LABEL: extv8f16_generic_a16(
+define void @extv8f16_generic_a16(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
+; CHECK: ld.v4.b16 {%f
+; CHECK: ld.v4.b16 {%f
+ %v = load <8 x half>, ptr %src, align 16
+ %ext = fpext <8 x half> %v to <8 x float>
+; CHECK: st.v4.f32
+; CHECK: st.v4.f32
+ store <8 x float> %ext, ptr %dst, align 16
+ ret void
+}
+
+; CHECK-LABEL: extv8f16_generic_a4(
+define void @extv8f16_generic_a4(ptr noalias readonly align 16 %dst, ptr noalias readonly align 16 %src) #0 {
+; CHECK: ld.v2.b16 {%f
+; CHECK: ld.v2.b16 {%f
+; CHECK: ld.v2.b16 {%f
+; CHECK: ld.v2.b16 {%f
+ %v = load <8 x half>, ptr %src, align 4
+ %ext = fpext <8 x half> %v to <8 x float>
+; CHECK: st.v4.f32
+; CHECK: st.v4.f32
+ store <8 x float> %ext, ptr %dst, align 16
+ ret void
+}
+
!1 = !{i32 0, i32 64}