[llvm] a9e9dd9 - [AArch64] Add bf16 select handling

David Green via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 11 06:20:40 PDT 2022


Author: David Green
Date: 2022-08-11T14:20:36+01:00
New Revision: a9e9dd9a3a44d88cda85d0b26778c80faf5355d2

URL: https://github.com/llvm/llvm-project/commit/a9e9dd9a3a44d88cda85d0b26778c80faf5355d2
DIFF: https://github.com/llvm/llvm-project/commit/a9e9dd9a3a44d88cda85d0b26778c80faf5355d2.diff

LOG: [AArch64] Add bf16 select handling

A bfloat select operation currently crashes, but is allowed from C.
This adds handling for the operation, turning it into an FCSELHrrr if
fullfp16 is present, or an FCSELSrrr if not. The FCSELSrrr is created by
using INSERT_SUBREG/EXTRACT_SUBREG to place the 16-bit bf16 value in the
low (hsub) half of an f32 register. This is a register-class change, not
a numeric conversion. The existing f32 pattern for FCSELSrrr then applies.
(I originally attempted to do this via a tablegen pattern, but it appears
that the NZCV glue is placed onto the wrong node, causing it to be
forgotten and incorrect scheduling to be emitted.)

The FCSELSrrr can also be used for fp16 selects when +fullfp16 is not
present, which helps avoid an unnecessary promotion to f32.

Differential Revision: https://reviews.llvm.org/D131253

Added: 
    llvm/test/CodeGen/AArch64/bf16-select.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/test/CodeGen/AArch64/arm64-fmax.ll
    llvm/test/CodeGen/AArch64/f16-instructions.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 305156391e1fe..771b1b40a8e6e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -402,11 +402,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::SELECT, MVT::i32, Custom);
   setOperationAction(ISD::SELECT, MVT::i64, Custom);
   setOperationAction(ISD::SELECT, MVT::f16, Custom);
+  setOperationAction(ISD::SELECT, MVT::bf16, Custom);
   setOperationAction(ISD::SELECT, MVT::f32, Custom);
   setOperationAction(ISD::SELECT, MVT::f64, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
+  setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand);
   setOperationAction(ISD::SELECT_CC, MVT::f32, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::f64, Custom);
   setOperationAction(ISD::BR_JT, MVT::Other, Custom);
@@ -603,7 +605,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   if (!Subtarget->hasFullFP16()) {
     for (auto Op :
-         {ISD::SELECT,         ISD::SELECT_CC,      ISD::SETCC,
+         {ISD::SETCC,          ISD::SELECT_CC,
           ISD::BR_CC,          ISD::FADD,           ISD::FSUB,
           ISD::FMUL,           ISD::FDIV,           ISD::FMA,
           ISD::FNEG,           ISD::FABS,           ISD::FCEIL,
@@ -8439,7 +8441,32 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
     RHS = DAG.getConstant(0, DL, CCVal.getValueType());
     CC = ISD::SETNE;
   }
-  return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+
+  // If we are lowering a f16 and we do not have fullf16, convert to a f32 in
+  // order to use FCSELSrrr
+  if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
+    TVal = SDValue(
+        DAG.getMachineNode(TargetOpcode::INSERT_SUBREG, DL, MVT::f32,
+                           DAG.getUNDEF(MVT::f32), TVal,
+                           DAG.getTargetConstant(AArch64::hsub, DL, MVT::i32)),
+        0);
+    FVal = SDValue(
+        DAG.getMachineNode(TargetOpcode::INSERT_SUBREG, DL, MVT::f32,
+                           DAG.getUNDEF(MVT::f32), FVal,
+                           DAG.getTargetConstant(AArch64::hsub, DL, MVT::i32)),
+        0);
+  }
+
+  SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+
+  if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
+    Res = SDValue(
+        DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, DL, Ty, Res,
+                           DAG.getTargetConstant(AArch64::hsub, DL, MVT::i32)),
+        0);
+  }
+
+  return Res;
 }
 
 SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 926e7305bab92..b2a458a2683e7 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -4205,6 +4205,10 @@ defm FCCMP  : FPCondComparison<0, "fccmp", AArch64fccmp>;
 
 defm FCSEL : FPCondSelect<"fcsel">;
 
+let Predicates = [HasFullFP16] in
+def : Pat<(bf16 (AArch64csel (bf16 FPR16:$Rn), (bf16 FPR16:$Rm), (i32 imm:$cond), NZCV)),
+          (FCSELHrrr FPR16:$Rn, FPR16:$Rm, imm:$cond)>;
+
 // CSEL instructions providing f128 types need to be handled by a
 // pseudo-instruction since the eventual code will need to introduce basic
 // blocks and control flow.

diff  --git a/llvm/test/CodeGen/AArch64/arm64-fmax.ll b/llvm/test/CodeGen/AArch64/arm64-fmax.ll
index 025f660d77046..24429a8d275a0 100644
--- a/llvm/test/CodeGen/AArch64/arm64-fmax.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-fmax.ll
@@ -71,13 +71,12 @@ define i64 @test_integer(i64  %in) {
 define float @test_f16(half %in) {
 ; CHECK-LABEL: test_f16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fcvt s0, h0
-; CHECK-NEXT:    movi d1, #0000000000000000
-; CHECK-NEXT:    fcmp s0, #0.0
-; CHECK-NEXT:    cset w8, lt
-; CHECK-NEXT:    cmp w8, #0
-; CHECK-NEXT:    fcsel s0, s0, s1, ne
-; CHECK-NEXT:    fcvt h0, s0
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT:    fcvt s1, h0
+; CHECK-NEXT:    adrp x8, .LCPI5_0
+; CHECK-NEXT:    ldr h2, [x8, :lo12:.LCPI5_0]
+; CHECK-NEXT:    fcmp s1, #0.0
+; CHECK-NEXT:    fcsel s0, s0, s2, lt
 ; CHECK-NEXT:    fcvt s0, h0
 ; CHECK-NEXT:    ret
   %cmp = fcmp nnan ult half %in, 0.000000e+00

diff  --git a/llvm/test/CodeGen/AArch64/bf16-select.ll b/llvm/test/CodeGen/AArch64/bf16-select.ll
new file mode 100644
index 0000000000000..e3479f49e86b6
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/bf16-select.ll
@@ -0,0 +1,64 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple aarch64-unknown-unknown -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK-BASE
+; RUN: llc < %s -mtriple aarch64-unknown-unknown -mattr=+fullfp16 -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK-FP16
+; RUN: llc < %s -mtriple aarch64-unknown-unknown -mattr=+bf16 -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK-BASE
+; RUN: llc < %s -mtriple aarch64-unknown-unknown -mattr=+bf16,+fullfp16 -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK-FP16
+
+define bfloat @test_select(bfloat %a, bfloat %b, i1 zeroext %c) {
+; CHECK-BASE-LABEL: test_select:
+; CHECK-BASE:       // %bb.0:
+; CHECK-BASE-NEXT:    // kill: def $h0 killed $h0 def $s0
+; CHECK-BASE-NEXT:    cmp w0, #0
+; CHECK-BASE-NEXT:    // kill: def $h1 killed $h1 def $s1
+; CHECK-BASE-NEXT:    fcsel s0, s0, s1, ne
+; CHECK-BASE-NEXT:    // kill: def $h0 killed $h0 killed $s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-FP16-LABEL: test_select:
+; CHECK-FP16:       // %bb.0:
+; CHECK-FP16-NEXT:    cmp w0, #0
+; CHECK-FP16-NEXT:    fcsel h0, h0, h1, ne
+; CHECK-FP16-NEXT:    ret
+  %r = select i1 %c, bfloat %a, bfloat %b
+  ret bfloat %r
+}
+
+define bfloat @test_select_fcc(bfloat %a, bfloat %b, float %c, float %d) {
+; CHECK-BASE-LABEL: test_select_fcc:
+; CHECK-BASE:       // %bb.0:
+; CHECK-BASE-NEXT:    fcmp s2, s3
+; CHECK-BASE-NEXT:    // kill: def $h0 killed $h0 def $s0
+; CHECK-BASE-NEXT:    // kill: def $h1 killed $h1 def $s1
+; CHECK-BASE-NEXT:    fcsel s0, s0, s1, ne
+; CHECK-BASE-NEXT:    // kill: def $h0 killed $h0 killed $s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-FP16-LABEL: test_select_fcc:
+; CHECK-FP16:       // %bb.0:
+; CHECK-FP16-NEXT:    fcmp s2, s3
+; CHECK-FP16-NEXT:    fcsel h0, h0, h1, ne
+; CHECK-FP16-NEXT:    ret
+  %cc = fcmp une float %c, %d
+  %r = select i1 %cc, bfloat %a, bfloat %b
+  ret bfloat %r
+}
+
+define bfloat @test_select_icc(bfloat %a, bfloat %b, i32 %c, i32 %d) {
+; CHECK-BASE-LABEL: test_select_icc:
+; CHECK-BASE:       // %bb.0:
+; CHECK-BASE-NEXT:    // kill: def $h0 killed $h0 def $s0
+; CHECK-BASE-NEXT:    cmp w0, w1
+; CHECK-BASE-NEXT:    // kill: def $h1 killed $h1 def $s1
+; CHECK-BASE-NEXT:    fcsel s0, s0, s1, ne
+; CHECK-BASE-NEXT:    // kill: def $h0 killed $h0 killed $s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-FP16-LABEL: test_select_icc:
+; CHECK-FP16:       // %bb.0:
+; CHECK-FP16-NEXT:    cmp w0, w1
+; CHECK-FP16-NEXT:    fcsel h0, h0, h1, ne
+; CHECK-FP16-NEXT:    ret
+  %cc = icmp ne i32 %c, %d
+  %r = select i1 %cc, bfloat %a, bfloat %b
+  ret bfloat %r
+}

diff  --git a/llvm/test/CodeGen/AArch64/f16-instructions.ll b/llvm/test/CodeGen/AArch64/f16-instructions.ll
index dc63b5139ca3a..63e98c4c056b7 100644
--- a/llvm/test/CodeGen/AArch64/f16-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/f16-instructions.ll
@@ -167,11 +167,8 @@ define half @test_tailcall_flipped(half %a, half %b) #0 {
 }
 
 ; CHECK-CVT-LABEL: test_select:
-; CHECK-CVT-NEXT: fcvt s1, h1
-; CHECK-CVT-NEXT: fcvt s0, h0
 ; CHECK-CVT-NEXT: cmp  w0, #0
 ; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
-; CHECK-CVT-NEXT: fcvt h0, s0
 ; CHECK-CVT-NEXT: ret
 
 ; CHECK-FP16-LABEL: test_select:
@@ -187,11 +184,8 @@ define half @test_select(half %a, half %b, i1 zeroext %c) #0 {
 ; CHECK-CVT-LABEL: test_select_cc:
 ; CHECK-CVT-DAG: fcvt s3, h3
 ; CHECK-CVT-DAG: fcvt s2, h2
-; CHECK-CVT-DAG: fcvt s1, h1
-; CHECK-CVT-DAG: fcvt s0, h0
 ; CHECK-CVT-DAG: fcmp s2, s3
 ; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
-; CHECK-CVT-NEXT: fcvt h0, s0
 ; CHECK-CVT-NEXT: ret
 
 ; CHECK-FP16-LABEL: test_select_cc:
@@ -224,11 +218,8 @@ define float @test_select_cc_f32_f16(float %a, float %b, half %c, half %d) #0 {
 }
 
 ; CHECK-CVT-LABEL: test_select_cc_f16_f32:
-; CHECK-CVT-DAG:  fcvt s0, h0
-; CHECK-CVT-DAG:  fcvt s1, h1
 ; CHECK-CVT-DAG:  fcmp s2, s3
 ; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
-; CHECK-CVT-NEXT: fcvt h0, s0
 ; CHECK-CVT-NEXT: ret
 
 ; CHECK-FP16-LABEL: test_select_cc_f16_f32:
@@ -485,16 +476,14 @@ define i1 @test_fcmp_ord(half %a, half %b) #0 {
 }
 
 ; CHECK-COMMON-LABEL: test_fccmp:
-; CHECK-CVT:      fcvt  s0, h0
-; CHECK-CVT-NEXT: fmov  s1, #8.00000000
-; CHECK-CVT-NEXT: fcmp  s0, s1
-; CHECK-CVT-NEXT: fmov  s1, #5.00000000
-; CHECK-CVT-NEXT: cset  w8, gt
-; CHECK-CVT-NEXT: fcmp  s0, s1
-; CHECK-CVT-NEXT: cset  w9, mi
-; CHECK-CVT-NEXT: tst   w8, w9
-; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
-; CHECK-CVT-NEXT: fcvt  h0, s0
+; CHECK-CVT:      fcvt  s1, h0
+; CHECK-CVT-NEXT: fmov  s2, #5.00000000
+; CHECK-CVT-NEXT: fcmp  s1, s2
+; CHECK-CVT-NEXT: fmov  s2, #8.00000000
+; CHECK-CVT-NEXT: adrp x8
+; CHECK-CVT-NEXT: fccmp s1, s2, #4, mi
+; CHECK-CVT-NEXT: ldr h1, [x8,
+; CHECK-CVT-NEXT: fcsel s0, s0, s1, gt
 ; CHECK-CVT-NEXT: str   h0, [x0]
 ; CHECK-CVT-NEXT: ret
 ; CHECK-FP16:      fmov  h1, #5.00000000


        


More information about the llvm-commits mailing list