[llvm] [LLVM][CodeGen][SVE2] Implement nxvf64 fpround to nxvbf16. (PR #111012)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 3 09:09:57 PDT 2024


https://github.com/paulwalker-arm created https://github.com/llvm/llvm-project/pull/111012

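NOTE: This lowering is restricted to SVE2 (or streaming SVE via SME) because FCVTX, which is required to perform the necessary two-step rounding, is only available there. FCVTX narrows f64 to f32 using round-to-odd, so the subsequent f32 -> bf16 rounding yields the correctly rounded result without double-rounding errors.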

>From f33497ab131f6c4be74c0b2276680f196f05c3e4 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 30 Aug 2024 18:18:03 +0100
Subject: [PATCH] [LLVM][CodeGen][SVE2] Implement nxvf64 fpround to nxvbf16.

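NOTE: This lowering is restricted to SVE2 (or streaming SVE via SME)
because FCVTX, which is required to perform the necessary two-step
rounding, is only available there. FCVTX narrows f64 to f32 using
round-to-odd, so the subsequent f32 -> bf16 rounding yields the
correctly rounded result without double-rounding errors.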
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  14 +++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   1 +
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |   3 +-
 llvm/lib/Target/AArch64/AArch64Subtarget.h    |  10 +-
 llvm/lib/Target/AArch64/SVEInstrFormats.td    |   4 +-
 .../CodeGen/AArch64/sve2-bf16-converts.ll     | 115 ++++++++++++++++++
 6 files changed, 142 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve2-bf16-converts.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e55e9989e6565c..aa8e16cc98122f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -268,6 +268,7 @@ static bool isMergePassthruOpcode(unsigned Opc) {
   case AArch64ISD::FP_EXTEND_MERGE_PASSTHRU:
   case AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU:
   case AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU:
+  case AArch64ISD::FCVTX_MERGE_PASSTHRU:
   case AArch64ISD::FCVTZU_MERGE_PASSTHRU:
   case AArch64ISD::FCVTZS_MERGE_PASSTHRU:
   case AArch64ISD::FSQRT_MERGE_PASSTHRU:
@@ -2622,6 +2623,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU)
+    MAKE_CASE(AArch64ISD::FCVTX_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::FCVTZU_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::FCVTZS_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
@@ -4363,6 +4365,18 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
       // Set the quiet bit.
       if (!DAG.isKnownNeverSNaN(SrcVal))
         NaN = DAG.getNode(ISD::OR, DL, I32, Narrow, ImmV(0x400000));
+    } else if (SrcVT == MVT::nxv2f64 &&
+               (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
+      SDValue Pg = getPredicateForVector(DAG, DL, MVT::nxv2f32);
+      Narrow = DAG.getNode(AArch64ISD::FCVTX_MERGE_PASSTHRU, DL, MVT::nxv2f32,
+                           Pg, SrcVal, DAG.getUNDEF(MVT::nxv2f32));
+
+      if (Subtarget->hasBF16())
+        return DAG.getNode(AArch64ISD::FP_ROUND_MERGE_PASSTHRU, DL, VT, Pg,
+                           Narrow, DAG.getTargetConstant(0, DL, MVT::i64),
+                           DAG.getUNDEF(VT));
+
+      Narrow = getSVESafeBitCast(I32, Narrow, DAG);
     } else
       return SDValue();
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 480bf60360bf55..1bae7562f459a5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -158,6 +158,7 @@ enum NodeType : unsigned {
   FP_EXTEND_MERGE_PASSTHRU,
   UINT_TO_FP_MERGE_PASSTHRU,
   SINT_TO_FP_MERGE_PASSTHRU,
+  FCVTX_MERGE_PASSTHRU,
   FCVTZU_MERGE_PASSTHRU,
   FCVTZS_MERGE_PASSTHRU,
   SIGN_EXTEND_INREG_MERGE_PASSTHRU,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 76362768e0aa6b..53d9473975a235 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -357,6 +357,7 @@ def AArch64fcvtr_mt  : SDNode<"AArch64ISD::FP_ROUND_MERGE_PASSTHRU", SDT_AArch64
 def AArch64fcvte_mt  : SDNode<"AArch64ISD::FP_EXTEND_MERGE_PASSTHRU", SDT_AArch64FCVT>;
 def AArch64ucvtf_mt  : SDNode<"AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU", SDT_AArch64FCVT>;
 def AArch64scvtf_mt  : SDNode<"AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU", SDT_AArch64FCVT>;
+def AArch64fcvtx_mt  : SDNode<"AArch64ISD::FCVTX_MERGE_PASSTHRU", SDT_AArch64FCVT>;
 def AArch64fcvtzu_mt : SDNode<"AArch64ISD::FCVTZU_MERGE_PASSTHRU", SDT_AArch64FCVT>;
 def AArch64fcvtzs_mt : SDNode<"AArch64ISD::FCVTZS_MERGE_PASSTHRU", SDT_AArch64FCVT>;
 
@@ -3779,7 +3780,7 @@ let Predicates = [HasSVE2orSME, UseExperimentalZeroingPseudos] in {
 let Predicates = [HasSVE2orSME] in {
   // SVE2 floating-point convert precision
   defm FCVTXNT_ZPmZ : sve2_fp_convert_down_odd_rounding_top<"fcvtxnt", "int_aarch64_sve_fcvtxnt">;
-  defm FCVTX_ZPmZ   : sve2_fp_convert_down_odd_rounding<"fcvtx",       "int_aarch64_sve_fcvtx">;
+  defm FCVTX_ZPmZ   : sve2_fp_convert_down_odd_rounding<"fcvtx",       "int_aarch64_sve_fcvtx", AArch64fcvtx_mt>;
   defm FCVTNT_ZPmZ  : sve2_fp_convert_down_narrow<"fcvtnt",            "int_aarch64_sve_fcvtnt">;
   defm FCVTLT_ZPmZ  : sve2_fp_convert_up_long<"fcvtlt",                "int_aarch64_sve_fcvtlt">;
 
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index accfb49c6fbe3a..9856415361e50d 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -188,10 +188,14 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
            (hasSMEFA64() || (!isStreaming() && !isStreamingCompatible()));
   }
 
-  /// Returns true if the target has access to either the full range of SVE instructions,
-  /// or the streaming-compatible subset of SVE instructions.
+  /// Returns true if the target has SME and is in streaming mode, and thus has
+  /// access to the streaming-compatible subset of SVE instructions.
+  bool isStreamingSVEAvailable() const { return hasSME() && isStreaming(); }
+
+  /// Returns true if the target has access to either the full range of SVE
+  /// instructions, or the streaming-compatible subset of SVE instructions.
   bool isSVEorStreamingSVEAvailable() const {
-    return hasSVE() || (hasSME() && isStreaming());
+    return hasSVE() || isStreamingSVEAvailable();
   }
 
   unsigned getMinVectorRegisterBitWidth() const {
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 13c2a90a963f8c..121e19ac0397fe 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -3059,9 +3059,11 @@ multiclass sve2_fp_un_pred_zeroing_hsd<SDPatternOperator op> {
   def : SVE_1_Op_PassthruZero_Pat<nxv2i64, op, nxv2i1, nxv2f64, !cast<Pseudo>(NAME # _D_ZERO)>;
 }
 
-multiclass sve2_fp_convert_down_odd_rounding<string asm, string op> {
+multiclass sve2_fp_convert_down_odd_rounding<string asm, string op, SDPatternOperator ir_op = null_frag> {
   def _DtoS : sve_fp_2op_p_zd<0b0001010, asm, ZPR64, ZPR32, ElementSizeD>;
+
   def : SVE_3_Op_Pat<nxv4f32, !cast<SDPatternOperator>(op # _f32f64), nxv4f32, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _DtoS)>;
+  def : SVE_1_Op_Passthru_Pat<nxv2f32, ir_op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _DtoS)>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve2-bf16-converts.ll b/llvm/test/CodeGen/AArch64/sve2-bf16-converts.ll
new file mode 100644
index 00000000000000..c67b1700843aec
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2-bf16-converts.ll
@@ -0,0 +1,115 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve2                 < %s | FileCheck %s --check-prefixes=NOBF16
+; RUN: llc -mattr=+sve2,+bf16           < %s | FileCheck %s --check-prefixes=BF16
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s --check-prefixes=BF16
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <vscale x 2 x bfloat> @fptrunc_nxv2f64_to_nxv2bf16(<vscale x 2 x double> %a) {
+; NOBF16-LABEL: fptrunc_nxv2f64_to_nxv2bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    ptrue p0.d
+; NOBF16-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT:    fcvtx z0.s, p0/m, z0.d
+; NOBF16-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16-NEXT:    add z0.s, z0.s, z1.s
+; NOBF16-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16-NEXT:    add z0.s, z2.s, z0.s
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv2f64_to_nxv2bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.d
+; BF16-NEXT:    fcvtx z0.s, p0/m, z0.d
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 2 x double> %a to <vscale x 2 x bfloat>
+  ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fptrunc_nxv4f64_to_nxv4bf16(<vscale x 4 x double> %a) {
+; NOBF16-LABEL: fptrunc_nxv4f64_to_nxv4bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    ptrue p0.d
+; NOBF16-NEXT:    mov z2.s, #32767 // =0x7fff
+; NOBF16-NEXT:    fcvtx z1.s, p0/m, z1.d
+; NOBF16-NEXT:    fcvtx z0.s, p0/m, z0.d
+; NOBF16-NEXT:    lsr z3.s, z1.s, #16
+; NOBF16-NEXT:    lsr z4.s, z0.s, #16
+; NOBF16-NEXT:    add z1.s, z1.s, z2.s
+; NOBF16-NEXT:    add z0.s, z0.s, z2.s
+; NOBF16-NEXT:    and z3.s, z3.s, #0x1
+; NOBF16-NEXT:    and z4.s, z4.s, #0x1
+; NOBF16-NEXT:    add z1.s, z3.s, z1.s
+; NOBF16-NEXT:    add z0.s, z4.s, z0.s
+; NOBF16-NEXT:    lsr z1.s, z1.s, #16
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    uzp1 z0.s, z0.s, z1.s
+; NOBF16-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv4f64_to_nxv4bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.d
+; BF16-NEXT:    fcvtx z1.s, p0/m, z1.d
+; BF16-NEXT:    fcvtx z0.s, p0/m, z0.d
+; BF16-NEXT:    bfcvt z1.h, p0/m, z1.s
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    uzp1 z0.s, z0.s, z1.s
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 4 x double> %a to <vscale x 4 x bfloat>
+  ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fptrunc_nxv8f64_to_nxv8bf16(<vscale x 8 x double> %a) {
+; NOBF16-LABEL: fptrunc_nxv8f64_to_nxv8bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    ptrue p0.d
+; NOBF16-NEXT:    mov z4.s, #32767 // =0x7fff
+; NOBF16-NEXT:    fcvtx z3.s, p0/m, z3.d
+; NOBF16-NEXT:    fcvtx z2.s, p0/m, z2.d
+; NOBF16-NEXT:    fcvtx z1.s, p0/m, z1.d
+; NOBF16-NEXT:    fcvtx z0.s, p0/m, z0.d
+; NOBF16-NEXT:    lsr z5.s, z3.s, #16
+; NOBF16-NEXT:    lsr z6.s, z2.s, #16
+; NOBF16-NEXT:    lsr z7.s, z1.s, #16
+; NOBF16-NEXT:    lsr z24.s, z0.s, #16
+; NOBF16-NEXT:    add z3.s, z3.s, z4.s
+; NOBF16-NEXT:    add z2.s, z2.s, z4.s
+; NOBF16-NEXT:    add z1.s, z1.s, z4.s
+; NOBF16-NEXT:    add z0.s, z0.s, z4.s
+; NOBF16-NEXT:    and z5.s, z5.s, #0x1
+; NOBF16-NEXT:    and z6.s, z6.s, #0x1
+; NOBF16-NEXT:    and z7.s, z7.s, #0x1
+; NOBF16-NEXT:    and z24.s, z24.s, #0x1
+; NOBF16-NEXT:    add z3.s, z5.s, z3.s
+; NOBF16-NEXT:    add z2.s, z6.s, z2.s
+; NOBF16-NEXT:    add z1.s, z7.s, z1.s
+; NOBF16-NEXT:    add z0.s, z24.s, z0.s
+; NOBF16-NEXT:    lsr z3.s, z3.s, #16
+; NOBF16-NEXT:    lsr z2.s, z2.s, #16
+; NOBF16-NEXT:    lsr z1.s, z1.s, #16
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    uzp1 z2.s, z2.s, z3.s
+; NOBF16-NEXT:    uzp1 z0.s, z0.s, z1.s
+; NOBF16-NEXT:    uzp1 z0.h, z0.h, z2.h
+; NOBF16-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv8f64_to_nxv8bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.d
+; BF16-NEXT:    fcvtx z3.s, p0/m, z3.d
+; BF16-NEXT:    fcvtx z2.s, p0/m, z2.d
+; BF16-NEXT:    fcvtx z1.s, p0/m, z1.d
+; BF16-NEXT:    fcvtx z0.s, p0/m, z0.d
+; BF16-NEXT:    bfcvt z3.h, p0/m, z3.s
+; BF16-NEXT:    bfcvt z2.h, p0/m, z2.s
+; BF16-NEXT:    bfcvt z1.h, p0/m, z1.s
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    uzp1 z2.s, z2.s, z3.s
+; BF16-NEXT:    uzp1 z0.s, z0.s, z1.s
+; BF16-NEXT:    uzp1 z0.h, z0.h, z2.h
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 8 x double> %a to <vscale x 8 x bfloat>
+  ret <vscale x 8 x bfloat> %res
+}



More information about the llvm-commits mailing list