[llvm] 778fa4e - [AArch64] Add some basic handling for bf16 constants.

David Green via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 31 13:32:02 PDT 2023


Author: David Green
Date: 2023-07-31T21:31:56+01:00
New Revision: 778fa4edaf207bd2fef3635ceb8782e325ded76a

URL: https://github.com/llvm/llvm-project/commit/778fa4edaf207bd2fef3635ceb8782e325ded76a
DIFF: https://github.com/llvm/llvm-project/commit/778fa4edaf207bd2fef3635ceb8782e325ded76a.diff

LOG: [AArch64] Add some basic handling for bf16 constants.

This adds some basic handling for bf16 constants, attempting to treat them a
lot like fp16 constants where it can. Zero immediates get lowered to FMOVH0;
other constants use FMOVHi when they can be encoded as an FMOV immediate, and
are otherwise lowered to FMOVWHr(MOVi32imm). Without +fullfp16, non-zero
constants get expanded (loaded from the constant pool). This may not always be
optimal, but it fixes a gap in our lowering. See
llvm/test/CodeGen/AArch64/f16-imm.ll for the equivalent fp16 test.

Differential Revision: https://reviews.llvm.org/D156649

Added: 
    llvm/test/CodeGen/AArch64/bf16-imm.ll

Modified: 
    llvm/lib/CodeGen/TargetLoweringBase.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64InstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index badb7fe533330f..68a4616fe4b833 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -921,7 +921,7 @@ void TargetLoweringBase::initActions() {
   // Legal, in which case all fp constants are legal, or use isFPImmLegal()
   // to optimize expansions for certain constants.
   setOperationAction(ISD::ConstantFP,
-                     {MVT::f16, MVT::f32, MVT::f64, MVT::f80, MVT::f128},
+                     {MVT::bf16, MVT::f16, MVT::f32, MVT::f64, MVT::f80, MVT::f128},
                      Expand);
 
   // These library functions default to expand.

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7547eae576a24f..0407252d50d12e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1097,6 +1097,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     if (Subtarget->hasFullFP16()) {
       setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
+      setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
 
       setOperationAction(ISD::SINT_TO_FP, MVT::v8i8, Custom);
       setOperationAction(ISD::UINT_TO_FP, MVT::v8i8, Custom);
@@ -9789,7 +9790,7 @@ bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
     IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero();
   else if (VT == MVT::f32)
     IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero();
-  else if (VT == MVT::f16)
+  else if (VT == MVT::f16 || VT == MVT::bf16)
     IsLegal =
         (Subtarget->hasFullFP16() && AArch64_AM::getFP16Imm(ImmInt) != -1) ||
         Imm.isPosZero();

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 578a7645da21c9..ac7b208e3f075c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -1306,6 +1306,11 @@ def fpimm16 : Operand<f16>,
   let PrintMethod = "printFPImmOperand";
 }
 
+def fpimmbf16 : Operand<bf16>,
+                FPImmLeaf<bf16, [{
+      return AArch64_AM::getFP16Imm(Imm) != -1;
+    }], fpimm16XForm>;
+
 def fpimm32 : Operand<f32>,
               FPImmLeaf<f32, [{
       return AArch64_AM::getFP32Imm(Imm) != -1;

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 3450ed29d1426e..565d629841b940 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -4355,16 +4355,23 @@ def FMOVS0 : Pseudo<(outs FPR32:$Rd), (ins), [(set f32:$Rd, (fpimm0))]>,
 def FMOVD0 : Pseudo<(outs FPR64:$Rd), (ins), [(set f64:$Rd, (fpimm0))]>,
     Sched<[WriteF]>;
 }
+
 // Similarly add aliases
 def : InstAlias<"fmov $Rd, #0.0", (FMOVWHr FPR16:$Rd, WZR), 0>,
     Requires<[HasFullFP16]>;
 def : InstAlias<"fmov $Rd, #0.0", (FMOVWSr FPR32:$Rd, WZR), 0>;
 def : InstAlias<"fmov $Rd, #0.0", (FMOVXDr FPR64:$Rd, XZR), 0>;
 
-// Pattern for FP16 immediates
+def : Pat<(bf16 fpimm0),
+          (FMOVH0)>;
+
+// Pattern for FP16 and BF16 immediates
 let Predicates = [HasFullFP16] in {
   def : Pat<(f16 fpimm:$in),
-    (FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 f16:$in)))>;
+            (FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 f16:$in)))>;
+
+  def : Pat<(bf16 fpimm:$in),
+            (FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 bf16:$in)))>;
 }
 
 //===----------------------------------------------------------------------===//
@@ -4617,6 +4624,11 @@ let isReMaterializable = 1, isAsCheapAsAMove = 1 in {
 defm FMOV : FPMoveImmediate<"fmov">;
 }
 
+let Predicates = [HasFullFP16] in {
+  def : Pat<(bf16 fpimmbf16:$in),
+            (FMOVHi (fpimm16XForm bf16:$in))>;
+}
+
 //===----------------------------------------------------------------------===//
 // Advanced SIMD two vector instructions.
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/test/CodeGen/AArch64/bf16-imm.ll b/llvm/test/CodeGen/AArch64/bf16-imm.ll
new file mode 100644
index 00000000000000..450bf286d8d783
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/bf16-imm.ll
@@ -0,0 +1,121 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64 -mattr=+fullfp16 | FileCheck %s --check-prefixes=CHECK,CHECK-FP16
+; RUN: llc < %s -mtriple=aarch64 -mattr=-fullfp16 | FileCheck %s --check-prefixes=CHECK,CHECK-NOFP16
+
+define bfloat @Const0() {
+; CHECK-LABEL: Const0:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi d0, #0000000000000000
+; CHECK-NEXT:    ret
+entry:
+  ret bfloat 0xR0000
+}
+
+define bfloat @Const1() {
+; CHECK-FP16-LABEL: Const1:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    fmov h0, #1.00000000
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const1:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI1_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI1_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR3C00
+}
+
+define bfloat @Const2() {
+; CHECK-FP16-LABEL: Const2:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    fmov h0, #0.12500000
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const2:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI2_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI2_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR3000
+}
+
+define bfloat @Const3() {
+; CHECK-FP16-LABEL: Const3:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    fmov h0, #30.00000000
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const3:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI3_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI3_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR4F80
+}
+
+define bfloat @Const4() {
+; CHECK-FP16-LABEL: Const4:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    fmov h0, #31.00000000
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const4:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI4_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI4_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR4FC0
+}
+
+define bfloat @Const5() {
+; CHECK-FP16-LABEL: Const5:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    mov w8, #12272 // =0x2ff0
+; CHECK-FP16-NEXT:    fmov h0, w8
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const5:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI5_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI5_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR2FF0
+}
+
+define bfloat @Const6() {
+; CHECK-FP16-LABEL: Const6:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    mov w8, #20417 // =0x4fc1
+; CHECK-FP16-NEXT:    fmov h0, w8
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const6:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI6_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI6_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR4FC1
+}
+
+define bfloat @Const7() {
+; CHECK-FP16-LABEL: Const7:
+; CHECK-FP16:       // %bb.0: // %entry
+; CHECK-FP16-NEXT:    mov w8, #20480 // =0x5000
+; CHECK-FP16-NEXT:    fmov h0, w8
+; CHECK-FP16-NEXT:    ret
+;
+; CHECK-NOFP16-LABEL: Const7:
+; CHECK-NOFP16:       // %bb.0: // %entry
+; CHECK-NOFP16-NEXT:    adrp x8, .LCPI7_0
+; CHECK-NOFP16-NEXT:    ldr h0, [x8, :lo12:.LCPI7_0]
+; CHECK-NOFP16-NEXT:    ret
+entry:
+  ret bfloat 0xR5000
+}
+


        


More information about the llvm-commits mailing list