[llvm] [LLVM][CodeGen][SVE] Implement nxvbf16 fpextend to nxvf32/nxvf64. (PR #107253)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 4 08:12:06 PDT 2024


https://github.com/paulwalker-arm created https://github.com/llvm/llvm-project/pull/107253

NOTE: There are no dedicated SVE instructions for this conversion, but bf16->f32 is just a left shift by 16 because the two formats share the same exponent range. From f32, the existing convert instructions can be used to reach wider types.

>From 0faa07aeb6e7e0e25a2fa8fa9efdd47f0baeff73 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 30 Aug 2024 15:56:23 +0100
Subject: [PATCH] [LLVM][CodeGen][SVE] Implement nxvbf16 fpextend to
 nxvf32/nxvf64.

NOTE: There are no dedicated SVE instructions for this conversion, but
bf16->f32 is just a left shift by 16 because the two formats share the
same exponent range. From f32, existing convert instructions can be used.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 23 ++++-
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  7 +-
 .../test/CodeGen/AArch64/sve-bf16-converts.ll | 89 +++++++++++++++++++
 3 files changed, 117 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve-bf16-converts.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5e3f9364ac3e12..a57878d18b2b7f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1663,6 +1663,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+      setOperationAction(ISD::FP_EXTEND, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4298,8 +4299,28 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
 SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
                                               SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
+  if (VT.isScalableVector()) {
+    SDValue SrcVal = Op.getOperand(0);
+
+    if (SrcVal.getValueType().getScalarType() == MVT::bf16) {
+      // bf16 and f32 share the same exponent range so the conversion requires
+      // them to be aligned with the new mantissa bits zeroed. This is just a
+      // left shift that is best to isel directly.
+      if (VT == MVT::nxv2f32 || VT == MVT::nxv4f32)
+        return Op;
+
+      if (VT != MVT::nxv2f64)
+        return SDValue();
+
+      // Break other conversions in two with the first part converting to f32
+      // and the second using native f32->VT instructions.
+      SDLoc DL(Op);
+      return DAG.getNode(ISD::FP_EXTEND, DL, VT,
+                         DAG.getNode(ISD::FP_EXTEND, DL, MVT::nxv2f32, SrcVal));
+    }
+
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
+  }
 
   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPExtendToSVE(Op, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index af8ddb49b0ac66..ef006be9d02354 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2320,7 +2320,12 @@ let Predicates = [HasSVEorSME] in {
   def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
             (FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
 
-  // Signed integer -> Floating-point 
+  // bf16 and f32 share the same exponent range, so bf16->f32 is a left
+  // shift by 16 that moves the bf16 bits into the top of each 32-bit lane.
+  def : Pat<(nxv4f32 (fpextend nxv4bf16:$op)), (LSL_ZZI_S $op, (i32 16))>;
+  def : Pat<(nxv2f32 (fpextend nxv2bf16:$op)), (LSL_ZZI_S $op, (i32 16))>;
+
+  // Signed integer -> Floating-point
   def : Pat<(nxv2f16 (AArch64scvtf_mt (nxv2i1 (SVEAllActive):$Pg),
                       (sext_inreg nxv2i64:$Zs, nxv2i16), nxv2f16:$Zd)),
             (SCVTF_ZPmZ_HtoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
new file mode 100644
index 00000000000000..d72f92c1dac1ff
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
@@ -0,0 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve                  < %s | FileCheck %s
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %src) {
+; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ret
+  %ext = fpext <vscale x 2 x bfloat> %src to <vscale x 2 x float>
+  ret <vscale x 2 x float> %ext
+}
+
+define <vscale x 4 x float> @fpext_nxv4bf16_to_nxv4f32(<vscale x 4 x bfloat> %src) {
+; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ret
+  %ext = fpext <vscale x 4 x bfloat> %src to <vscale x 4 x float>
+  ret <vscale x 4 x float> %ext
+}
+
+define <vscale x 8 x float> @fpext_nxv8bf16_to_nxv8f32(<vscale x 8 x bfloat> %src) {
+; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z2.s, z0.h
+; CHECK-NEXT:    lsl z0.s, z1.s, #16
+; CHECK-NEXT:    lsl z1.s, z2.s, #16
+; CHECK-NEXT:    ret
+  %ext = fpext <vscale x 8 x bfloat> %src to <vscale x 8 x float>
+  ret <vscale x 8 x float> %ext
+}
+
+define <vscale x 2 x double> @fpext_nxv2bf16_to_nxv2f64(<vscale x 2 x bfloat> %src) {
+; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcvt z0.d, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %ext = fpext <vscale x 2 x bfloat> %src to <vscale x 2 x double>
+  ret <vscale x 2 x double> %ext
+}
+
+define <vscale x 4 x double> @fpext_nxv4bf16_to_nxv4f64(<vscale x 4 x bfloat> %src) {
+; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z2.s, z0.s, #16
+; CHECK-NEXT:    movprfx z0, z1
+; CHECK-NEXT:    fcvt z0.d, p0/m, z1.s
+; CHECK-NEXT:    movprfx z1, z2
+; CHECK-NEXT:    fcvt z1.d, p0/m, z2.s
+; CHECK-NEXT:    ret
+  %ext = fpext <vscale x 4 x bfloat> %src to <vscale x 4 x double>
+  ret <vscale x 4 x double> %ext
+}
+
+define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %src) {
+; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z2.s, z2.s, #16
+; CHECK-NEXT:    lsl z3.s, z3.s, #16
+; CHECK-NEXT:    lsl z4.s, z0.s, #16
+; CHECK-NEXT:    fcvt z1.d, p0/m, z1.s
+; CHECK-NEXT:    movprfx z0, z2
+; CHECK-NEXT:    fcvt z0.d, p0/m, z2.s
+; CHECK-NEXT:    movprfx z2, z3
+; CHECK-NEXT:    fcvt z2.d, p0/m, z3.s
+; CHECK-NEXT:    movprfx z3, z4
+; CHECK-NEXT:    fcvt z3.d, p0/m, z4.s
+; CHECK-NEXT:    ret
+  %ext = fpext <vscale x 8 x bfloat> %src to <vscale x 8 x double>
+  ret <vscale x 8 x double> %ext
+}



More information about the llvm-commits mailing list