[llvm] cb1fed3 - [NVPTX] Correctly guard int -> bf16 on PTX version and SM version

David Majnemer via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 12:27:38 PST 2024


Author: David Majnemer
Date: 2024-02-21T20:26:07Z
New Revision: cb1fed3a89e0cdc2660edaada1f0868cae3b7bcf

URL: https://github.com/llvm/llvm-project/commit/cb1fed3a89e0cdc2660edaada1f0868cae3b7bcf
DIFF: https://github.com/llvm/llvm-project/commit/cb1fed3a89e0cdc2660edaada1f0868cae3b7bcf.diff

LOG: [NVPTX] Correctly guard int -> bf16 on PTX version and SM version

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index fc6c642acbc073..7d2fe78d142292 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -788,13 +788,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // sm_80 only has conversions between f32 and bf16. Custom lower all other
   // bf16 conversions.
-  if (STI.hasBF16Math() &&
-      (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+  if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
     for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
       setOperationAction(
           {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
           VT, Custom);
     }
+    setOperationAction(
+        {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
+        MVT::bf16, Custom);
   }
 
   setOperationAction(ISD::FROUND, MVT::f16, Promote);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 40d82ebecbed35..55a1955a7f497e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3247,23 +3247,23 @@ def : Pat<(f16 (uint_to_fp Int64Regs:$a)),
 
 // sint -> bf16
 def : Pat<(bf16 (sint_to_fp Int1Regs:$a)),
-          (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+          (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 def : Pat<(bf16 (sint_to_fp Int16Regs:$a)),
-          (CVT_bf16_s16 Int16Regs:$a, CvtRN)>;
+          (CVT_bf16_s16 Int16Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 def : Pat<(bf16 (sint_to_fp Int32Regs:$a)),
-          (CVT_bf16_s32 Int32Regs:$a, CvtRN)>;
+          (CVT_bf16_s32 Int32Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 def : Pat<(bf16 (sint_to_fp Int64Regs:$a)),
-          (CVT_bf16_s64 Int64Regs:$a, CvtRN)>;
+          (CVT_bf16_s64 Int64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 
 // uint -> bf16
 def : Pat<(bf16 (uint_to_fp Int1Regs:$a)),
-          (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+          (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 def : Pat<(bf16 (uint_to_fp Int16Regs:$a)),
-          (CVT_bf16_u16 Int16Regs:$a, CvtRN)>;
+          (CVT_bf16_u16 Int16Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 def : Pat<(bf16 (uint_to_fp Int32Regs:$a)),
-          (CVT_bf16_u32 Int32Regs:$a, CvtRN)>;
+          (CVT_bf16_u32 Int32Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 def : Pat<(bf16 (uint_to_fp Int64Regs:$a)),
-          (CVT_bf16_u64 Int64Regs:$a, CvtRN)>;
+          (CVT_bf16_u64 Int64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 
 // sint -> f32
 def : Pat<(f32 (sint_to_fp Int1Regs:$a)),


        


More information about the llvm-commits mailing list