[llvm] cb1fed3 - [NVPTX] Correctly guard int -> bf16 on PTX version and SM version
David Majnemer via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 21 12:27:38 PST 2024
Author: David Majnemer
Date: 2024-02-21T20:26:07Z
New Revision: cb1fed3a89e0cdc2660edaada1f0868cae3b7bcf
URL: https://github.com/llvm/llvm-project/commit/cb1fed3a89e0cdc2660edaada1f0868cae3b7bcf
DIFF: https://github.com/llvm/llvm-project/commit/cb1fed3a89e0cdc2660edaada1f0868cae3b7bcf.diff
LOG: [NVPTX] Correctly guard int -> bf16 on PTX version and SM version
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index fc6c642acbc073..7d2fe78d142292 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -788,13 +788,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// sm_80 only has conversions between f32 and bf16. Custom lower all other
// bf16 conversions.
- if (STI.hasBF16Math() &&
- (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+ if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
setOperationAction(
{ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
VT, Custom);
}
+ setOperationAction(
+ {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
+ MVT::bf16, Custom);
}
setOperationAction(ISD::FROUND, MVT::f16, Promote);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 40d82ebecbed35..55a1955a7f497e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3247,23 +3247,23 @@ def : Pat<(f16 (uint_to_fp Int64Regs:$a)),
// sint -> bf16
def : Pat<(bf16 (sint_to_fp Int1Regs:$a)),
- (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+ (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (sint_to_fp Int16Regs:$a)),
- (CVT_bf16_s16 Int16Regs:$a, CvtRN)>;
+ (CVT_bf16_s16 Int16Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (sint_to_fp Int32Regs:$a)),
- (CVT_bf16_s32 Int32Regs:$a, CvtRN)>;
+ (CVT_bf16_s32 Int32Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (sint_to_fp Int64Regs:$a)),
- (CVT_bf16_s64 Int64Regs:$a, CvtRN)>;
+ (CVT_bf16_s64 Int64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
// uint -> bf16
def : Pat<(bf16 (uint_to_fp Int1Regs:$a)),
- (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+ (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (uint_to_fp Int16Regs:$a)),
- (CVT_bf16_u16 Int16Regs:$a, CvtRN)>;
+ (CVT_bf16_u16 Int16Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (uint_to_fp Int32Regs:$a)),
- (CVT_bf16_u32 Int32Regs:$a, CvtRN)>;
+ (CVT_bf16_u32 Int32Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (uint_to_fp Int64Regs:$a)),
- (CVT_bf16_u64 Int64Regs:$a, CvtRN)>;
+ (CVT_bf16_u64 Int64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
// sint -> f32
def : Pat<(f32 (sint_to_fp Int1Regs:$a)),
More information about the llvm-commits mailing list