[llvm] [LLVM][CodeGen][SVE] Implement nxvf32 fpround to nxvbf16. (PR #107420)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 9 09:32:16 PDT 2024


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/107420

>From a74b99c808b33c693cfd978d8c17ba74a619af13 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 30 Aug 2024 15:59:10 +0100
Subject: [PATCH 1/2] [LLVM][CodeGen][SVE] Implement nxvf32 fpround to nxvbf16.

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  50 ++++++-
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |   2 +-
 llvm/lib/Target/AArch64/SVEInstrFormats.td    |   6 +-
 .../test/CodeGen/AArch64/sve-bf16-converts.ll | 129 +++++++++++++++++-
 4 files changed, 180 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c0671dd1f0087c..79c451a17c93b7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1664,6 +1664,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
+      setOperationAction(ISD::FP_ROUND, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4332,14 +4333,57 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
 SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
                                              SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
-
   bool IsStrict = Op->isStrictFPOpcode();
   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
   EVT SrcVT = SrcVal.getValueType();
   bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
 
+  if (VT.isScalableVector()) {
+    if (VT.getScalarType() != MVT::bf16)
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+    SDLoc DL(Op);
+    constexpr EVT I32 = MVT::nxv4i32;
+    auto ImmV = [&](int I) -> SDValue { return DAG.getConstant(I, DL, I32); };
+
+    SDValue NaN;
+    SDValue Narrow;
+
+    if (SrcVT == MVT::nxv2f32 || SrcVT == MVT::nxv4f32) {
+      if (Subtarget->hasBF16())
+        return LowerToPredicatedOp(Op, DAG,
+                                   AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+      Narrow = getSVESafeBitCast(I32, SrcVal, DAG);
+
+      // Set the quiet bit.
+      if (!DAG.isKnownNeverSNaN(SrcVal))
+        NaN = DAG.getNode(ISD::OR, DL, I32, Narrow, ImmV(0x400000));
+    } else
+      return SDValue();
+
+    if (!Trunc) {
+      SDValue Lsb = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+      Lsb = DAG.getNode(ISD::AND, DL, I32, Lsb, ImmV(1));
+      SDValue RoundingBias = DAG.getNode(ISD::ADD, DL, I32, Lsb, ImmV(0x7fff));
+      Narrow = DAG.getNode(ISD::ADD, DL, I32, Narrow, RoundingBias);
+    }
+
+    // Don't round if we had a NaN, we don't want to turn 0x7fffffff into
+    // 0x80000000.
+    if (NaN) {
+      EVT I1 = I32.changeElementType(MVT::i1);
+      EVT CondVT = VT.changeElementType(MVT::i1);
+      SDValue IsNaN = DAG.getSetCC(DL, CondVT, SrcVal, SrcVal, ISD::SETUO);
+      IsNaN = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, I1, IsNaN);
+      Narrow = DAG.getSelect(DL, I32, IsNaN, NaN, Narrow);
+    }
+
+    // Now that we have rounded, shift the bits into position.
+    Narrow = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+    return getSVESafeBitCast(VT, Narrow, DAG);
+  }
+
   if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPRoundToSVE(Op, DAG);
 
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 692cd66d38437d..b45c5111df8e5b 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2397,7 +2397,7 @@ let Predicates = [HasBF16, HasSVEorSME] in {
   defm BFMLALT_ZZZ : sve2_fp_mla_long<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt>;
   defm BFMLALB_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b100, "bfmlalb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalb_lane_v2>;
   defm BFMLALT_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt_lane_v2>;
-  defm BFCVT_ZPmZ   : sve_bfloat_convert<0b1, "bfcvt",   int_aarch64_sve_fcvt_bf16f32>;
+  defm BFCVT_ZPmZ   : sve_bfloat_convert<0b1, "bfcvt",   int_aarch64_sve_fcvt_bf16f32, AArch64fcvtr_mt>;
   defm BFCVTNT_ZPmZ : sve_bfloat_convert<0b0, "bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32>;
 } // End HasBF16, HasSVEorSME
 
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index d6d503171a41e6..f06543676e0242 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -8811,9 +8811,13 @@ class sve_bfloat_convert<bit N, string asm>
   let mayRaiseFPException = 1;
 }
 
-multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op> {
+multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op,
+                              SDPatternOperator ir_op = null_frag> {
   def NAME : sve_bfloat_convert<N, asm>;
+
   def : SVE_3_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8i1, nxv4f32, !cast<Instruction>(NAME)>;
+  def : SVE_1_Op_Passthru_Round_Pat<nxv4bf16, ir_op, nxv4i1, nxv4f32, !cast<Instruction>(NAME)>;
+  def : SVE_1_Op_Passthru_Round_Pat<nxv2bf16, ir_op, nxv2i1, nxv2f32, !cast<Instruction>(NAME)>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
index d72f92c1dac1ff..4104a43cb8e917 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
@@ -1,9 +1,15 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sve                  < %s | FileCheck %s
-; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+; RUN: llc -mattr=+sve                          < %s | FileCheck %s
+; RUN: llc -mattr=+sve --enable-no-nans-fp-math < %s | FileCheck %s --check-prefixes=CHECK,NOBF16NNAN
+; RUN: llc -mattr=+sve,+bf16                    < %s | FileCheck %s --check-prefixes=CHECK,BF16
+; RUN: llc -mattr=+sme -force-streaming         < %s | FileCheck %s --check-prefixes=CHECK,BF16
 
 target triple = "aarch64-unknown-linux-gnu"
 
+; NOTE: "fptrunc <# x double> to <# x bfloat>" is not yet supported because it
+; needs a down convert that rounds to odd, which base SVE lacks. Such IR
+; triggers the usual failure (a crash) when unrolling a scalable vector.
+
 define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %a) {
 ; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
 ; CHECK:       // %bb.0:
@@ -87,3 +93,122 @@ define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %a
   %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x double>
   ret <vscale x 8 x double> %res
 }
+
+define <vscale x 2 x bfloat> @fptrunc_nxv2f32_to_nxv2bf16(<vscale x 2 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16-NEXT:    ptrue p0.d
+; NOBF16-NEXT:    fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16-NEXT:    add z1.s, z0.s, z1.s
+; NOBF16-NEXT:    orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT:    add z1.s, z2.s, z1.s
+; NOBF16-NEXT:    sel z0.s, p0, z0.s, z1.s
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; NOBF16NNAN:       // %bb.0:
+; NOBF16NNAN-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16NNAN-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16NNAN-NEXT:    add z0.s, z0.s, z1.s
+; NOBF16NNAN-NEXT:    add z0.s, z2.s, z0.s
+; NOBF16NNAN-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.d
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 2 x float> %a to <vscale x 2 x bfloat>
+  ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fptrunc_nxv4f32_to_nxv4bf16(<vscale x 4 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16-NEXT:    ptrue p0.s
+; NOBF16-NEXT:    fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16-NEXT:    add z1.s, z0.s, z1.s
+; NOBF16-NEXT:    orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT:    add z1.s, z2.s, z1.s
+; NOBF16-NEXT:    sel z0.s, p0, z0.s, z1.s
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; NOBF16NNAN:       // %bb.0:
+; NOBF16NNAN-NEXT:    mov z1.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT:    lsr z2.s, z0.s, #16
+; NOBF16NNAN-NEXT:    and z2.s, z2.s, #0x1
+; NOBF16NNAN-NEXT:    add z0.s, z0.s, z1.s
+; NOBF16NNAN-NEXT:    add z0.s, z2.s, z0.s
+; NOBF16NNAN-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.s
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 4 x float> %a to <vscale x 4 x bfloat>
+  ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fptrunc_nxv8f32_to_nxv8bf16(<vscale x 8 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; NOBF16:       // %bb.0:
+; NOBF16-NEXT:    mov z2.s, #32767 // =0x7fff
+; NOBF16-NEXT:    lsr z3.s, z1.s, #16
+; NOBF16-NEXT:    lsr z4.s, z0.s, #16
+; NOBF16-NEXT:    ptrue p0.s
+; NOBF16-NEXT:    and z3.s, z3.s, #0x1
+; NOBF16-NEXT:    and z4.s, z4.s, #0x1
+; NOBF16-NEXT:    fcmuo p1.s, p0/z, z1.s, z1.s
+; NOBF16-NEXT:    add z5.s, z1.s, z2.s
+; NOBF16-NEXT:    add z2.s, z0.s, z2.s
+; NOBF16-NEXT:    fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT:    orr z1.s, z1.s, #0x400000
+; NOBF16-NEXT:    orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT:    add z3.s, z3.s, z5.s
+; NOBF16-NEXT:    add z2.s, z4.s, z2.s
+; NOBF16-NEXT:    sel z1.s, p1, z1.s, z3.s
+; NOBF16-NEXT:    sel z0.s, p0, z0.s, z2.s
+; NOBF16-NEXT:    lsr z1.s, z1.s, #16
+; NOBF16-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16-NEXT:    uzp1 z0.h, z0.h, z1.h
+; NOBF16-NEXT:    ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; NOBF16NNAN:       // %bb.0:
+; NOBF16NNAN-NEXT:    mov z2.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT:    lsr z3.s, z1.s, #16
+; NOBF16NNAN-NEXT:    lsr z4.s, z0.s, #16
+; NOBF16NNAN-NEXT:    and z3.s, z3.s, #0x1
+; NOBF16NNAN-NEXT:    and z4.s, z4.s, #0x1
+; NOBF16NNAN-NEXT:    add z1.s, z1.s, z2.s
+; NOBF16NNAN-NEXT:    add z0.s, z0.s, z2.s
+; NOBF16NNAN-NEXT:    add z1.s, z3.s, z1.s
+; NOBF16NNAN-NEXT:    add z0.s, z4.s, z0.s
+; NOBF16NNAN-NEXT:    lsr z1.s, z1.s, #16
+; NOBF16NNAN-NEXT:    lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT:    uzp1 z0.h, z0.h, z1.h
+; NOBF16NNAN-NEXT:    ret
+;
+; BF16-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; BF16:       // %bb.0:
+; BF16-NEXT:    ptrue p0.s
+; BF16-NEXT:    bfcvt z1.h, p0/m, z1.s
+; BF16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT:    uzp1 z0.h, z0.h, z1.h
+; BF16-NEXT:    ret
+  %res = fptrunc <vscale x 8 x float> %a to <vscale x 8 x bfloat>
+  ret <vscale x 8 x bfloat> %res
+}

>From a5afaf5a6bfa2b6d0eae7b525beae11f9cbaae28 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Mon, 9 Sep 2024 16:31:01 +0000
Subject: [PATCH 2/2] Restore missing check-prefixes option.

---
 llvm/test/CodeGen/AArch64/sve-bf16-converts.ll | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
index 4104a43cb8e917..d63f7e6f3242e0 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sve                          < %s | FileCheck %s
+; RUN: llc -mattr=+sve                          < %s | FileCheck %s --check-prefixes=CHECK,NOBF16
 ; RUN: llc -mattr=+sve --enable-no-nans-fp-math < %s | FileCheck %s --check-prefixes=CHECK,NOBF16NNAN
 ; RUN: llc -mattr=+sve,+bf16                    < %s | FileCheck %s --check-prefixes=CHECK,BF16
 ; RUN: llc -mattr=+sme -force-streaming         < %s | FileCheck %s --check-prefixes=CHECK,BF16



More information about the llvm-commits mailing list