[llvm-branch-commits] [llvm] 658d9e5 - [AArch64] Add some basic handling for bf16 constants.

Tobias Hieta via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Aug 2 00:14:43 PDT 2023


Author: David Green
Date: 2023-08-02T09:09:59+02:00
New Revision: 658d9e565c9b46dafa78a8423158c01ed64f758d

URL: https://github.com/llvm/llvm-project/commit/658d9e565c9b46dafa78a8423158c01ed64f758d
DIFF: https://github.com/llvm/llvm-project/commit/658d9e565c9b46dafa78a8423158c01ed64f758d.diff

LOG: [AArch64] Add some basic handling for bf16 constants.

This adds some basic handling for bf16 constants, attempting to treat them
like fp16 constants where possible. Positive-zero immediates get lowered to
FMOVH0. With +fullfp16, other constants use FMOVHi when their raw 16-bit
pattern is encodable as an fp16 FMOV immediate, and FMOVWHr(MOVi32imm)
otherwise. Without +fullfp16, other constants are expanded to constant-pool
loads. This may not always be optimal, but fixes a gap in our lowering. See
llvm/test/CodeGen/AArch64/f16-imm.ll for the equivalent fp16 test.

Differential Revision: https://reviews.llvm.org/D156649

(cherry picked from commit 778fa4edaf207bd2fef3635ceb8782e325ded76a)

Added: 
    llvm/test/CodeGen/AArch64/bf16-imm.ll

Modified: 
    llvm/lib/CodeGen/TargetLoweringBase.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64InstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index badb7fe533330f..68a4616fe4b833 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -921,7 +921,7 @@ void TargetLoweringBase::initActions() {
   // Legal, in which case all fp constants are legal, or use isFPImmLegal()
   // to optimize expansions for certain constants.
   setOperationAction(ISD::ConstantFP,
-                     {MVT::f16, MVT::f32, MVT::f64, MVT::f80, MVT::f128},
+                     {MVT::bf16, MVT::f16, MVT::f32, MVT::f64, MVT::f80, MVT::f128},
                      Expand);
 
   // These library functions default to expand.

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 13df87af6c7b58..c8a461a924b485 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1091,6 +1091,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     if (Subtarget->hasFullFP16()) {
       setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
+      setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
 
       setOperationAction(ISD::SINT_TO_FP, MVT::v8i8, Custom);
       setOperationAction(ISD::UINT_TO_FP, MVT::v8i8, Custom);
@@ -9757,7 +9758,7 @@ bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
     IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero();
   else if (VT == MVT::f32)
     IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero();
-  else if (VT == MVT::f16)
+  else if (VT == MVT::f16 || VT == MVT::bf16)
     IsLegal =
         (Subtarget->hasFullFP16() && AArch64_AM::getFP16Imm(ImmInt) != -1) ||
         Imm.isPosZero();

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index d39fd69f9e0ee4..39135df285c238 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -1306,6 +1306,11 @@ def fpimm16 : Operand<f16>,
   let PrintMethod = "printFPImmOperand";
 }
 
+def fpimmbf16 : Operand<bf16>,
+                FPImmLeaf<bf16, [{
+      return AArch64_AM::getFP16Imm(Imm) != -1;
+    }], fpimm16XForm>;
+
 def fpimm32 : Operand<f32>,
               FPImmLeaf<f32, [{
       return AArch64_AM::getFP32Imm(Imm) != -1;

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 3450ed29d1426e..565d629841b940 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -4355,16 +4355,23 @@ def FMOVS0 : Pseudo<(outs FPR32:$Rd), (ins), [(set f32:$Rd, (fpimm0))]>,
 def FMOVD0 : Pseudo<(outs FPR64:$Rd), (ins), [(set f64:$Rd, (fpimm0))]>,
     Sched<[WriteF]>;
 }
+
 // Similarly add aliases
 def : InstAlias<"fmov $Rd, #0.0", (FMOVWHr FPR16:$Rd, WZR), 0>,
     Requires<[HasFullFP16]>;
 def : InstAlias<"fmov $Rd, #0.0", (FMOVWSr FPR32:$Rd, WZR), 0>;
 def : InstAlias<"fmov $Rd, #0.0", (FMOVXDr FPR64:$Rd, XZR), 0>;
 
-// Pattern for FP16 immediates
+def : Pat<(bf16 fpimm0),
+          (FMOVH0)>;
+
+// Pattern for FP16 and BF16 immediates
 let Predicates = [HasFullFP16] in {
   def : Pat<(f16 fpimm:$in),
-    (FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 f16:$in)))>;
+            (FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 f16:$in)))>;
+
+  def : Pat<(bf16 fpimm:$in),
+            (FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 bf16:$in)))>;
 }
 
 //===----------------------------------------------------------------------===//
@@ -4617,6 +4624,11 @@ let isReMaterializable = 1, isAsCheapAsAMove = 1 in {
 defm FMOV : FPMoveImmediate<"fmov">;
 }
 
+let Predicates = [HasFullFP16] in {
+  def : Pat<(bf16 fpimmbf16:$in),
+            (FMOVHi (fpimm16XForm bf16:$in))>;
+}
+
 //===----------------------------------------------------------------------===//
 // Advanced SIMD two vector instructions.
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/test/CodeGen/AArch64/bf16-imm.ll b/llvm/test/CodeGen/AArch64/bf16-imm.ll
new file mode 100644
index 00000000000000..450bf286d8d783
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/bf16-imm.ll
@@ -0,0 +1,121 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64 -mattr=+fullfp16 | FileCheck %s --check-prefixes=CHECK,CHECK-FP16
+; RUN: llc < %s -mtriple=aarch64 -mattr=-fullfp16 | FileCheck %s --check-prefixes=CHECK,CHECK-NOFP16
+
+define bfloat @Const0() {
+; CHECK-LABEL: Const0:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi d0, #0000000000000000
+; CHECK-NEXT:    ret
+entry:
+  ret bfloat 0xR0000
+}
+
+define bfloat @Const1() {
+; CHECK-FP16-LABEL: Const1:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    fmov h0, #1.00000000
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const1:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI1_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI1_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR3C00
+}
+
+define bfloat @Const2() {
+; CHECK-FP16-LABEL: Const2:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    fmov h0, #0.12500000
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const2:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI2_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI2_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR3000
+}
+
+define bfloat @Const3() {
+; CHECK-FP16-LABEL: Const3:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    fmov h0, #30.00000000
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const3:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI3_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI3_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR4F80
+}
+
+define bfloat @Const4() {
+; CHECK-FP16-LABEL: Const4:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    fmov h0, #31.00000000
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const4:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI4_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI4_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR4FC0
+}
+
+define bfloat @Const5() {
+; CHECK-FP16-LABEL: Const5:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    mov w8, #12272 // =0x2ff0
+; CHECK-FP16-NEXT:    fmov h0, w8
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const5:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI5_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI5_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR2FF0
+}
+
+define bfloat @Const6() {
+; CHECK-FP16-LABEL: Const6:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    mov w8, #20417 // =0x4fc1
+; CHECK-FP16-NEXT:    fmov h0, w8
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const6:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI6_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI6_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR4FC1
+}
+
+define bfloat @Const7() {
+; CHECK-FP16-LABEL: Const7:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    mov w8, #20480 // =0x5000
+; CHECK-FP16-NEXT:    fmov h0, w8
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const7:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI7_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI7_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR5000
+}
+


        


More information about the llvm-branch-commits mailing list