[llvm] 5e1a9d3 - [ARM] Add lowering for bf16 neon vtrn, vzip and vuzp.

David Green via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 2 07:34:46 PDT 2022


Author: David Green
Date: 2022-10-02T15:34:37+01:00
New Revision: 5e1a9d319d2ee5d59a151e1d82f8f23e6cf27466

URL: https://github.com/llvm/llvm-project/commit/5e1a9d319d2ee5d59a151e1d82f8f23e6cf27466
DIFF: https://github.com/llvm/llvm-project/commit/5e1a9d319d2ee5d59a151e1d82f8f23e6cf27466.diff

LOG: [ARM] Add lowering for bf16 neon vtrn, vzip and vuzp.

These nodes are selected in DAG-to-DAG instruction selection, where the
opcode is better chosen by element size than by the exact element type.

Added: 
    

Modified: 
    llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
    llvm/test/CodeGen/ARM/bf16-shuffle.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp b/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
index afe16a3cd55c..9ba79b0cd1dd 100644
--- a/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
+++ b/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
@@ -3608,7 +3608,22 @@ void ARMDAGToDAGISel::SelectCMPZ(SDNode *N, bool &SwitchEQNEToPLMI) {
                      Range->second + (31 - Range->first));
     ReplaceNode(And.getNode(), NewN);
   }
+}
 
+// Pick a NEON shuffle opcode by element size rather than element type.
+// Opc64/Opc128 list the 8/16/32-bit opcodes for D-/Q-register vectors.
+static unsigned getVectorShuffleOpcode(EVT VT, const unsigned (&Opc64)[3],
+                                       const unsigned (&Opc128)[3]) {
+  assert((VT.is64BitVector() || VT.is128BitVector()) &&
+         "Unexpected vector shuffle length");
+  const unsigned *Opcs = VT.is64BitVector() ? Opc64 : Opc128;
+  switch (VT.getScalarSizeInBits()) {
+  default:
+    llvm_unreachable("Unexpected vector shuffle element size");
+  case 8:  return Opcs[0];
+  case 16: return Opcs[1];
+  case 32: return Opcs[2];
+  }
 }
 
 void ARMDAGToDAGISel::Select(SDNode *N) {
@@ -4248,72 +4263,38 @@ void ARMDAGToDAGISel::Select(SDNode *N) {
     // Other cases are autogenerated.
     break;
   }
-
   case ARMISD::VZIP: {
-    unsigned Opc = 0;
     EVT VT = N->getValueType(0);
-    switch (VT.getSimpleVT().SimpleTy) {
-    default: return;
-    case MVT::v8i8:  Opc = ARM::VZIPd8; break;
-    case MVT::v4f16:
-    case MVT::v4i16: Opc = ARM::VZIPd16; break;
-    case MVT::v2f32:
     // vzip.32 Dd, Dm is a pseudo-instruction expanded to vtrn.32 Dd, Dm.
-    case MVT::v2i32: Opc = ARM::VTRNd32; break;
-    case MVT::v16i8: Opc = ARM::VZIPq8; break;
-    case MVT::v8f16:
-    case MVT::v8i16: Opc = ARM::VZIPq16; break;
-    case MVT::v4f32:
-    case MVT::v4i32: Opc = ARM::VZIPq32; break;
-    }
+    unsigned Opc64[] = {ARM::VZIPd8, ARM::VZIPd16, ARM::VTRNd32};
+    unsigned Opc128[] = {ARM::VZIPq8, ARM::VZIPq16, ARM::VZIPq32};
+    unsigned Opc = getVectorShuffleOpcode(VT, Opc64, Opc128);
     SDValue Pred = getAL(CurDAG, dl);
     SDValue PredReg = CurDAG->getRegister(0, MVT::i32);
-    SDValue Ops[] = { N->getOperand(0), N->getOperand(1), Pred, PredReg };
+    SDValue Ops[] = {N->getOperand(0), N->getOperand(1), Pred, PredReg};
     ReplaceNode(N, CurDAG->getMachineNode(Opc, dl, VT, VT, Ops));
     return;
   }
   case ARMISD::VUZP: {
-    unsigned Opc = 0;
     EVT VT = N->getValueType(0);
-    switch (VT.getSimpleVT().SimpleTy) {
-    default: return;
-    case MVT::v8i8:  Opc = ARM::VUZPd8; break;
-    case MVT::v4f16:
-    case MVT::v4i16: Opc = ARM::VUZPd16; break;
-    case MVT::v2f32:
     // vuzp.32 Dd, Dm is a pseudo-instruction expanded to vtrn.32 Dd, Dm.
-    case MVT::v2i32: Opc = ARM::VTRNd32; break;
-    case MVT::v16i8: Opc = ARM::VUZPq8; break;
-    case MVT::v8f16:
-    case MVT::v8i16: Opc = ARM::VUZPq16; break;
-    case MVT::v4f32:
-    case MVT::v4i32: Opc = ARM::VUZPq32; break;
-    }
+    unsigned Opc64[] = {ARM::VUZPd8, ARM::VUZPd16, ARM::VTRNd32};
+    unsigned Opc128[] = {ARM::VUZPq8, ARM::VUZPq16, ARM::VUZPq32};
+    unsigned Opc = getVectorShuffleOpcode(VT, Opc64, Opc128);
     SDValue Pred = getAL(CurDAG, dl);
     SDValue PredReg = CurDAG->getRegister(0, MVT::i32);
-    SDValue Ops[] = { N->getOperand(0), N->getOperand(1), Pred, PredReg };
+    SDValue Ops[] = {N->getOperand(0), N->getOperand(1), Pred, PredReg};
     ReplaceNode(N, CurDAG->getMachineNode(Opc, dl, VT, VT, Ops));
     return;
   }
   case ARMISD::VTRN: {
-    unsigned Opc = 0;
     EVT VT = N->getValueType(0);
-    switch (VT.getSimpleVT().SimpleTy) {
-    default: return;
-    case MVT::v8i8:  Opc = ARM::VTRNd8; break;
-    case MVT::v4f16:
-    case MVT::v4i16: Opc = ARM::VTRNd16; break;
-    case MVT::v2f32:
-    case MVT::v2i32: Opc = ARM::VTRNd32; break;
-    case MVT::v16i8: Opc = ARM::VTRNq8; break;
-    case MVT::v8f16:
-    case MVT::v8i16: Opc = ARM::VTRNq16; break;
-    case MVT::v4f32:
-    case MVT::v4i32: Opc = ARM::VTRNq32; break;
-    }
+    unsigned Opc64[] = {ARM::VTRNd8, ARM::VTRNd16, ARM::VTRNd32};
+    unsigned Opc128[] = {ARM::VTRNq8, ARM::VTRNq16, ARM::VTRNq32};
+    unsigned Opc = getVectorShuffleOpcode(VT, Opc64, Opc128);
     SDValue Pred = getAL(CurDAG, dl);
     SDValue PredReg = CurDAG->getRegister(0, MVT::i32);
-    SDValue Ops[] = { N->getOperand(0), N->getOperand(1), Pred, PredReg };
+    SDValue Ops[] = {N->getOperand(0), N->getOperand(1), Pred, PredReg};
     ReplaceNode(N, CurDAG->getMachineNode(Opc, dl, VT, VT, Ops));
     return;
   }

diff  --git a/llvm/test/CodeGen/ARM/bf16-shuffle.ll b/llvm/test/CodeGen/ARM/bf16-shuffle.ll
index 14660fa5faba..47ac45caea36 100644
--- a/llvm/test/CodeGen/ARM/bf16-shuffle.ll
+++ b/llvm/test/CodeGen/ARM/bf16-shuffle.ll
@@ -34,59 +34,83 @@ entry:
   ret <8 x bfloat> %3
 }
 
-;define dso_local %struct.float16x4x2_t @test_vzip_bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
-;entry:
-;  %vzip.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
-;  %vzip1.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 2, i32 6, i32 3, i32 7>
-;  %.fca.0.0.insert = insertvalue %struct.float16x4x2_t undef, <4 x bfloat> %vzip.i, 0, 0
-;  %.fca.0.1.insert = insertvalue %struct.float16x4x2_t %.fca.0.0.insert, <4 x bfloat> %vzip1.i, 0, 1
-;  ret %struct.float16x4x2_t %.fca.0.1.insert
-;}
+define dso_local %struct.float16x4x2_t @test_vzip_bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vzip_bf16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vzip.16 d0, d1
+; CHECK-NEXT:    bx lr
+entry:
+  %vzip.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
+  %vzip1.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 2, i32 6, i32 3, i32 7>
+  %.fca.0.0.insert = insertvalue %struct.float16x4x2_t undef, <4 x bfloat> %vzip.i, 0, 0
+  %.fca.0.1.insert = insertvalue %struct.float16x4x2_t %.fca.0.0.insert, <4 x bfloat> %vzip1.i, 0, 1
+  ret %struct.float16x4x2_t %.fca.0.1.insert
+}
 
-;define dso_local %struct.float16x8x2_t @test_vzipq_bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
-;entry:
-;  %vzip.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 0, i32 8, i32 1, i32 9, i32 2, i32 10, i32 3, i32 11>
-;  %vzip1.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 4, i32 12, i32 5, i32 13, i32 6, i32 14, i32 7, i32 15>
-;  %.fca.0.0.insert = insertvalue %struct.float16x8x2_t undef, <8 x bfloat> %vzip.i, 0, 0
-;  %.fca.0.1.insert = insertvalue %struct.float16x8x2_t %.fca.0.0.insert, <8 x bfloat> %vzip1.i, 0, 1
-;  ret %struct.float16x8x2_t %.fca.0.1.insert
-;}
+define dso_local %struct.float16x8x2_t @test_vzipq_bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vzipq_bf16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vzip.16 q0, q1
+; CHECK-NEXT:    bx lr
+entry:
+  %vzip.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 0, i32 8, i32 1, i32 9, i32 2, i32 10, i32 3, i32 11>
+  %vzip1.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 4, i32 12, i32 5, i32 13, i32 6, i32 14, i32 7, i32 15>
+  %.fca.0.0.insert = insertvalue %struct.float16x8x2_t undef, <8 x bfloat> %vzip.i, 0, 0
+  %.fca.0.1.insert = insertvalue %struct.float16x8x2_t %.fca.0.0.insert, <8 x bfloat> %vzip1.i, 0, 1
+  ret %struct.float16x8x2_t %.fca.0.1.insert
+}
 
-;define dso_local %struct.float16x4x2_t @test_vuzp_bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
-;entry:
-;  %vuzp.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-;  %vuzp1.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-;  %.fca.0.0.insert = insertvalue %struct.float16x4x2_t undef, <4 x bfloat> %vuzp.i, 0, 0
-;  %.fca.0.1.insert = insertvalue %struct.float16x4x2_t %.fca.0.0.insert, <4 x bfloat> %vuzp1.i, 0, 1
-;  ret %struct.float16x4x2_t %.fca.0.1.insert
-;}
+define dso_local %struct.float16x4x2_t @test_vuzp_bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vuzp_bf16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vuzp.16 d0, d1
+; CHECK-NEXT:    bx lr
+entry:
+  %vuzp.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %vuzp1.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %.fca.0.0.insert = insertvalue %struct.float16x4x2_t undef, <4 x bfloat> %vuzp.i, 0, 0
+  %.fca.0.1.insert = insertvalue %struct.float16x4x2_t %.fca.0.0.insert, <4 x bfloat> %vuzp1.i, 0, 1
+  ret %struct.float16x4x2_t %.fca.0.1.insert
+}
 
-;define dso_local %struct.float16x8x2_t @test_vuzpq_bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
-;entry:
-;  %vuzp.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-;  %vuzp1.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
-;  %.fca.0.0.insert = insertvalue %struct.float16x8x2_t undef, <8 x bfloat> %vuzp.i, 0, 0
-;  %.fca.0.1.insert = insertvalue %struct.float16x8x2_t %.fca.0.0.insert, <8 x bfloat> %vuzp1.i, 0, 1
-;  ret %struct.float16x8x2_t %.fca.0.1.insert
-;}
+define dso_local %struct.float16x8x2_t @test_vuzpq_bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vuzpq_bf16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vuzp.16 q0, q1
+; CHECK-NEXT:    bx lr
+entry:
+  %vuzp.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %vuzp1.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %.fca.0.0.insert = insertvalue %struct.float16x8x2_t undef, <8 x bfloat> %vuzp.i, 0, 0
+  %.fca.0.1.insert = insertvalue %struct.float16x8x2_t %.fca.0.0.insert, <8 x bfloat> %vuzp1.i, 0, 1
+  ret %struct.float16x8x2_t %.fca.0.1.insert
+}
 
-;define dso_local %struct.float16x4x2_t @test_vtrn_bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
-;entry:
-;  %vtrn.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 0, i32 4, i32 2, i32 6>
-;  %vtrn1.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 1, i32 5, i32 3, i32 7>
-;  %.fca.0.0.insert = insertvalue %struct.float16x4x2_t undef, <4 x bfloat> %vtrn.i, 0, 0
-;  %.fca.0.1.insert = insertvalue %struct.float16x4x2_t %.fca.0.0.insert, <4 x bfloat> %vtrn1.i, 0, 1
-;  ret %struct.float16x4x2_t %.fca.0.1.insert
-;}
+define dso_local %struct.float16x4x2_t @test_vtrn_bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vtrn_bf16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vtrn.16 d0, d1
+; CHECK-NEXT:    bx lr
+entry:
+  %vtrn.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 0, i32 4, i32 2, i32 6>
+  %vtrn1.i = shufflevector <4 x bfloat> %a, <4 x bfloat> %b, <4 x i32> <i32 1, i32 5, i32 3, i32 7>
+  %.fca.0.0.insert = insertvalue %struct.float16x4x2_t undef, <4 x bfloat> %vtrn.i, 0, 0
+  %.fca.0.1.insert = insertvalue %struct.float16x4x2_t %.fca.0.0.insert, <4 x bfloat> %vtrn1.i, 0, 1
+  ret %struct.float16x4x2_t %.fca.0.1.insert
+}
 
-;define dso_local %struct.float16x8x2_t @test_vtrnq_bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
-;entry:
-;  %vtrn.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 0, i32 8, i32 2, i32 10, i32 4, i32 12, i32 6, i32 14>
-;  %vtrn1.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 1, i32 9, i32 3, i32 11, i32 5, i32 13, i32 7, i32 15>
-;  %.fca.0.0.insert = insertvalue %struct.float16x8x2_t undef, <8 x bfloat> %vtrn.i, 0, 0
-;  %.fca.0.1.insert = insertvalue %struct.float16x8x2_t %.fca.0.0.insert, <8 x bfloat> %vtrn1.i, 0, 1
-;  ret %struct.float16x8x2_t %.fca.0.1.insert
-;}
+define dso_local %struct.float16x8x2_t @test_vtrnq_bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vtrnq_bf16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vtrn.16 q0, q1
+; CHECK-NEXT:    bx lr
+entry:
+  %vtrn.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 0, i32 8, i32 2, i32 10, i32 4, i32 12, i32 6, i32 14>
+  %vtrn1.i = shufflevector <8 x bfloat> %a, <8 x bfloat> %b, <8 x i32> <i32 1, i32 9, i32 3, i32 11, i32 5, i32 13, i32 7, i32 15>
+  %.fca.0.0.insert = insertvalue %struct.float16x8x2_t undef, <8 x bfloat> %vtrn.i, 0, 0
+  %.fca.0.1.insert = insertvalue %struct.float16x8x2_t %.fca.0.0.insert, <8 x bfloat> %vtrn1.i, 0, 1
+  ret %struct.float16x8x2_t %.fca.0.1.insert
+}
 
 define dso_local <4 x bfloat> @test_vmov_n_bf16(float %a.coerce) {
 ; CHECK-NOFP16-LABEL: test_vmov_n_bf16:


        


More information about the llvm-commits mailing list