[llvm] 607485f - [LLVM][SVE] Lower bfloat extends the same as other types. (#129544)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 4 03:53:47 PST 2025


Author: Paul Walker
Date: 2025-03-04T11:53:43Z
New Revision: 607485f81c8bbfcf91ecb5a71a6323fb2bc367d9

URL: https://github.com/llvm/llvm-project/commit/607485f81c8bbfcf91ecb5a71a6323fb2bc367d9
DIFF: https://github.com/llvm/llvm-project/commit/607485f81c8bbfcf91ecb5a71a6323fb2bc367d9.diff

LOG: [LLVM][SVE] Lower bfloat extends the same as other types. (#129544)

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 32f2f5de060d2..2dca8c0da4756 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4507,18 +4507,9 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
   if (VT.isScalableVector()) {
     SDValue SrcVal = Op.getOperand(0);
 
-    if (SrcVal.getValueType().getScalarType() == MVT::bf16) {
-      // bf16 and f32 share the same exponent range so the conversion requires
-      // them to be aligned with the new mantissa bits zero'd. This is just a
-      // left shift that is best to isel directly.
-      if (VT == MVT::nxv2f32 || VT == MVT::nxv4f32)
-        return Op;
-
-      if (VT != MVT::nxv2f64)
-        return SDValue();
-
-      // Break other conversions in two with the first part converting to f32
-      // and the second using native f32->VT instructions.
+    if (VT == MVT::nxv2f64 && SrcVal.getValueType() == MVT::nxv2bf16) {
+      // Break conversion in two with the first part converting to f32 and the
+      // second using native f32->VT instructions.
       SDLoc DL(Op);
       return DAG.getNode(ISD::FP_EXTEND, DL, VT,
                          DAG.getNode(ISD::FP_EXTEND, DL, MVT::nxv2f32, SrcVal));

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index fd38bc22a4987..3ee71c14c6bd4 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -345,7 +345,7 @@ def AArch64fclamp : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm),
 
 def SDT_AArch64FCVT : SDTypeProfile<1, 3, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>,
-  SDTCVecEltisVT<1,i1>
+  SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>, SDTCisSameAs<0,3>
 ]>;
 
 def SDT_AArch64FCVTR : SDTypeProfile<1, 4, [
@@ -2377,9 +2377,9 @@ let Predicates = [HasSVE_or_SME] in {
   def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
             (FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
 
-  def : Pat<(nxv4f32 (fpextend nxv4bf16:$op)),
+  def : Pat<(nxv4f32 (AArch64fcvte_mt (SVEAnyPredicate), nxv4bf16:$op, undef)),
             (LSL_ZZI_S $op, (i32 16))>;
-  def : Pat<(nxv2f32 (fpextend nxv2bf16:$op)),
+  def : Pat<(nxv2f32 (AArch64fcvte_mt (SVEAnyPredicate), nxv2bf16:$op, undef)),
             (LSL_ZZI_S $op, (i32 16))>;
 
   // Signed integer -> Floating-point


        


More information about the llvm-commits mailing list