[llvm] 1084b32 - [ARM] Always replace FP16 bitcasts with VMOVhr or VMOVrh

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 28 08:13:37 PDT 2020


Author: David Green
Date: 2020-04-28T16:12:53+01:00
New Revision: 1084b323396048bbabbf4b11a173e5926eaeb8c6

URL: https://github.com/llvm/llvm-project/commit/1084b323396048bbabbf4b11a173e5926eaeb8c6
DIFF: https://github.com/llvm/llvm-project/commit/1084b323396048bbabbf4b11a173e5926eaeb8c6.diff

LOG: [ARM] Always replace FP16 bitcasts with VMOVhr or VMOVrh

This changes the logic for lowering fp16 bitcasts so that it always produces
either a VMOVhr or a VMOVrh, instead of only doing so when certain
surrounding nodes are present. To keep performing the same optimisations,
demanded-bits and known-bits information has been added for these nodes.

Differential Revision: https://reviews.llvm.org/D78587

Added: 
    

Modified: 
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/test/CodeGen/Thumb2/mve-intrinsics/vminvq.ll
    llvm/test/CodeGen/Thumb2/mve-vdup.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index ee0050ea9430..e6428f55d01f 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -5752,57 +5752,25 @@ static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG,
   SDLoc dl(N);
   SDValue Op = N->getOperand(0);
 
-  // This function is only supposed to be called for i64 types, either as the
-  // source or destination of the bit convert.
+  // This function is only supposed to be called for i16 and i64 types, either
+  // as the source or destination of the bit convert.
   EVT SrcVT = Op.getValueType();
   EVT DstVT = N->getValueType(0);
-  const bool HasFullFP16 = Subtarget->hasFullFP16();
 
   if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
-    if (!HasFullFP16)
+    if (!Subtarget->hasFullFP16())
       return SDValue();
-    // SoftFP: read half-precision arguments:
-    //
-    // t2: i32,ch = ...
-    //        t7: i16 = truncate t2 <~~~~ Op
-    //      t8: f16 = bitcast t7    <~~~~ N
-    //
-    if (Op.getOperand(0).getValueType() == MVT::i32)
-      return DAG.getNode(ARMISD::VMOVhr, SDLoc(Op),
-                         MVT::f16, Op.getOperand(0));
-
-    return SDValue();
+    // f16 bitcast i16 -> VMOVhr
+    return DAG.getNode(ARMISD::VMOVhr, SDLoc(N), MVT::f16,
+                       DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), MVT::i32, Op));
   }
 
-  // Half-precision return values
   if (SrcVT == MVT::f16 && DstVT == MVT::i16) {
-    if (!HasFullFP16)
+    if (!Subtarget->hasFullFP16())
       return SDValue();
-    //
-    //          t11: f16 = fadd t8, t10
-    //        t12: i16 = bitcast t11       <~~~ SDNode N
-    //      t13: i32 = zero_extend t12
-    //    t16: ch,glue = CopyToReg t0, Register:i32 %r0, t13
-    //  t17: ch = ARMISD::RET_FLAG t16, Register:i32 %r0, t16:1
-    //
-    // transform this into:
-    //
-    //    t20: i32 = ARMISD::VMOVrh t11
-    //  t16: ch,glue = CopyToReg t0, Register:i32 %r0, t20
-    //
-    auto ZeroExtend = N->use_begin();
-    if (N->use_size() != 1 || ZeroExtend->getOpcode() != ISD::ZERO_EXTEND ||
-        ZeroExtend->getValueType(0) != MVT::i32)
-      return SDValue();
-
-    auto Copy = ZeroExtend->use_begin();
-    if (Copy->getOpcode() == ISD::CopyToReg &&
-        Copy->use_begin()->getOpcode() == ARMISD::RET_FLAG) {
-      SDValue Cvt = DAG.getNode(ARMISD::VMOVrh, SDLoc(Op), MVT::i32, Op);
-      DAG.ReplaceAllUsesWith(*ZeroExtend, &Cvt);
-      return Cvt;
-    }
-    return SDValue();
+    // i16 bitcast f16 -> VMOVrh
+    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), MVT::i16,
+                       DAG.getNode(ARMISD::VMOVrh, SDLoc(N), MVT::i32, Op));
   }
 
   if (!(SrcVT == MVT::i64 || DstVT == MVT::i64))
@@ -13019,16 +12987,25 @@ static SDValue PerformVMOVhrCombine(SDNode *N, TargetLowering::DAGCombinerInfo &
   //     t2: f32,ch = CopyFromReg t0, Register:f32 %0
   //   t5: i32 = bitcast t2
   // t18: f16 = ARMISD::VMOVhr t5
-  SDValue BC = N->getOperand(0);
-  if (BC->getOpcode() != ISD::BITCAST)
-    return SDValue();
-  SDValue Copy = BC->getOperand(0);
-  if (Copy.getValueType() != MVT::f32 || Copy->getOpcode() != ISD::CopyFromReg)
-    return SDValue();
+  SDValue Op0 = N->getOperand(0);
+  if (Op0->getOpcode() == ISD::BITCAST) {
+    SDValue Copy = Op0->getOperand(0);
+    if (Copy.getValueType() == MVT::f32 &&
+        Copy->getOpcode() == ISD::CopyFromReg) {
+      SDValue Ops[] = {Copy->getOperand(0), Copy->getOperand(1)};
+      SDValue NewCopy =
+          DCI.DAG.getNode(ISD::CopyFromReg, SDLoc(N), MVT::f16, Ops);
+      return NewCopy;
+    }
+  }
 
-  SDValue Ops[] = {Copy->getOperand(0), Copy->getOperand(1)};
-  SDValue NewCopy = DCI.DAG.getNode(ISD::CopyFromReg, SDLoc(N), MVT::f16, Ops);
-  return NewCopy;
+  // Only the bottom 16 bits of the source register are used.
+  APInt DemandedMask = APInt::getLowBitsSet(32, 16);
+  const TargetLowering &TLI = DCI.DAG.getTargetLoweringInfo();
+  if (TLI.SimplifyDemandedBits(Op0, DemandedMask, DCI))
+    return SDValue(N, 0);
+
+  return SDValue();
 }
 
 static SDValue PerformVMOVrhCombine(SDNode *N,
@@ -16393,6 +16370,12 @@ void ARMTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     assert(DstSz == Known.getBitWidth());
     break;
   }
+  case ARMISD::VMOVrh: {
+    KnownBits KnownOp = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
+    assert(KnownOp.getBitWidth() == 16);
+    Known = KnownOp.zext(32);
+    break;
+  }
   }
 }
 

diff  --git a/llvm/test/CodeGen/Thumb2/mve-intrinsics/vminvq.ll b/llvm/test/CodeGen/Thumb2/mve-intrinsics/vminvq.ll
index fd1daef4b9ec..578f9f003f55 100644
--- a/llvm/test/CodeGen/Thumb2/mve-intrinsics/vminvq.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-intrinsics/vminvq.ll
@@ -220,14 +220,11 @@ entry:
 define arm_aapcs_vfpcc float @test_vminnmvq_f16(float %a.coerce, <8 x half> %b) {
 ; CHECK-LABEL: test_vminnmvq_f16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vmov r0, s0
 ; CHECK-NEXT:    vminnmv.f16 r0, q1
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %a.coerce to i32
@@ -255,14 +252,11 @@ entry:
 define arm_aapcs_vfpcc float @test_vminnmavq_f16(float %a.coerce, <8 x half> %b) {
 ; CHECK-LABEL: test_vminnmavq_f16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vmov r0, s0
 ; CHECK-NEXT:    vminnmav.f16 r0, q1
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %a.coerce to i32
@@ -290,14 +284,11 @@ entry:
 define arm_aapcs_vfpcc float @test_vmaxnmvq_f16(float %a.coerce, <8 x half> %b) {
 ; CHECK-LABEL: test_vmaxnmvq_f16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vmov r0, s0
 ; CHECK-NEXT:    vmaxnmv.f16 r0, q1
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %a.coerce to i32
@@ -325,14 +316,11 @@ entry:
 define arm_aapcs_vfpcc float @test_vmaxnmavq_f16(float %a.coerce, <8 x half> %b) {
 ; CHECK-LABEL: test_vmaxnmavq_f16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vmov r0, s0
 ; CHECK-NEXT:    vmaxnmav.f16 r0, q1
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %a.coerce to i32
@@ -648,16 +636,13 @@ entry:
 define arm_aapcs_vfpcc float @test_vminnmvq_p_f16(float %a.coerce, <8 x half> %b, i16 zeroext %p) {
 ; CHECK-LABEL: test_vminnmvq_p_f16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vmov r1, s0
 ; CHECK-NEXT:    vmsr p0, r0
 ; CHECK-NEXT:    vpst
 ; CHECK-NEXT:    vminnmvt.f16 r1, q1
 ; CHECK-NEXT:    vmov s0, r1
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %a.coerce to i32
@@ -691,16 +676,13 @@ entry:
 define arm_aapcs_vfpcc float @test_vminnmavq_p_f16(float %a.coerce, <8 x half> %b, i16 zeroext %p) {
 ; CHECK-LABEL: test_vminnmavq_p_f16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vmov r1, s0
 ; CHECK-NEXT:    vmsr p0, r0
 ; CHECK-NEXT:    vpst
 ; CHECK-NEXT:    vminnmavt.f16 r1, q1
 ; CHECK-NEXT:    vmov s0, r1
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %a.coerce to i32
@@ -734,16 +716,13 @@ entry:
 define arm_aapcs_vfpcc float @test_vmaxnmvq_p_f16(float %a.coerce, <8 x half> %b, i16 zeroext %p) {
 ; CHECK-LABEL: test_vmaxnmvq_p_f16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vmov r1, s0
 ; CHECK-NEXT:    vmsr p0, r0
 ; CHECK-NEXT:    vpst
 ; CHECK-NEXT:    vmaxnmvt.f16 r1, q1
 ; CHECK-NEXT:    vmov s0, r1
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %a.coerce to i32
@@ -777,16 +756,13 @@ entry:
 define arm_aapcs_vfpcc float @test_vmaxnmavq_p_f16(float %a.coerce, <8 x half> %b, i16 zeroext %p) {
 ; CHECK-LABEL: test_vmaxnmavq_p_f16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vmov r1, s0
 ; CHECK-NEXT:    vmsr p0, r0
 ; CHECK-NEXT:    vpst
 ; CHECK-NEXT:    vmaxnmavt.f16 r1, q1
 ; CHECK-NEXT:    vmov s0, r1
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vmov s0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %a.coerce to i32

diff  --git a/llvm/test/CodeGen/Thumb2/mve-vdup.ll b/llvm/test/CodeGen/Thumb2/mve-vdup.ll
index ae91b52e1d54..78bd610958d7 100644
--- a/llvm/test/CodeGen/Thumb2/mve-vdup.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-vdup.ll
@@ -127,15 +127,11 @@ entry:
 define arm_aapcs_vfpcc <8 x half> @vdup_f16_bc(half* %src1, half* %src2) {
 ; CHECK-LABEL: vdup_f16_bc:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .pad #4
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vldr.16 s0, [r1]
 ; CHECK-NEXT:    vldr.16 s2, [r0]
 ; CHECK-NEXT:    vadd.f16 s0, s2, s0
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r0, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r0, s0
 ; CHECK-NEXT:    vdup.16 q0, r0
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = load half, half *%src1, align 2
@@ -260,16 +256,12 @@ entry:
 define arm_aapcs_vfpcc half @vdup_f16_extract(half* %src1, half* %src2) {
 ; CHECK-LABEL: vdup_f16_extract:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .pad #4
-; CHECK-NEXT:    sub sp, #4
 ; CHECK-NEXT:    vldr.16 s0, [r2]
 ; CHECK-NEXT:    vldr.16 s2, [r1]
 ; CHECK-NEXT:    vadd.f16 s0, s2, s0
-; CHECK-NEXT:    vstr.16 s0, [sp, #2]
-; CHECK-NEXT:    ldrh.w r1, [sp, #2]
+; CHECK-NEXT:    vmov.f16 r1, s0
 ; CHECK-NEXT:    vdup.16 q0, r1
 ; CHECK-NEXT:    vstr.16 s1, [r0]
-; CHECK-NEXT:    add sp, #4
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = load half, half *%src1, align 2


        


More information about the llvm-commits mailing list