[clang] 10b6567 - [AArch64]: BFloat MatMul Intrinsics&CodeGen

Luke Geeson via cfe-commits cfe-commits at lists.llvm.org
Tue Jun 16 07:23:48 PDT 2020


Author: Luke Geeson
Date: 2020-06-16T15:23:30+01:00
New Revision: 10b6567f49778f49ea81ff36269fc0fbc033d7ad

URL: https://github.com/llvm/llvm-project/commit/10b6567f49778f49ea81ff36269fc0fbc033d7ad
DIFF: https://github.com/llvm/llvm-project/commit/10b6567f49778f49ea81ff36269fc0fbc033d7ad.diff

LOG: [AArch64]: BFloat MatMul Intrinsics&CodeGen

This patch upstreams support for the BFloat16 matrix-multiplication,
dot-product and multiply-add-long intrinsics, and the code generation that
lowers them from __bf16 operands to AArch64 instructions. This includes the
corresponding IR intrinsics. Unit tests are provided as needed. AArch32
intrinsics and code generation will follow in a later patch.

This patch is part of a series implementing the BFloat16 extension of the
Armv8.6-A architecture, as detailed here:

https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/arm-architecture-developments-armv8-6-a

The bfloat type and its properties are specified in the Arm Architecture
Reference Manual:

https://developer.arm.com/docs/ddi0487/latest/arm-architecture-reference-manual-armv8-for-armv8-a-architecture-profile

The following people contributed to this patch:

 - Luke Geeson
 - Momchil Velikov
 - Mikhail Maltsev
 - Luke Cheeseman

Reviewers: SjoerdMeijer, t.p.northover, sdesmalen, labrinea, miyuki,
stuij

Reviewed By: miyuki, stuij

Subscribers: kristof.beyls, hiraditya, danielkiss, cfe-commits,
llvm-commits, miyuki, chill, pbarrio, stuij

Tags: #clang, #llvm

Differential Revision: https://reviews.llvm.org/D80752

Change-Id: I174f0fd0f600d04e3799b06a7da88973c6c0703f

Added: 
    clang/test/CodeGen/aarch64-bf16-dotprod-intrinsics.c
    llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll

Modified: 
    clang/include/clang/Basic/arm_neon.td
    clang/lib/CodeGen/CGBuiltin.cpp
    llvm/include/llvm/IR/IntrinsicsAArch64.td
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64InstrInfo.td

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td
index ffdf08ea494a..289f5ea47b92 100644
--- a/clang/include/clang/Basic/arm_neon.td
+++ b/clang/include/clang/Basic/arm_neon.td
@@ -244,6 +244,22 @@ def OP_SUDOT_LNQ
     : Op<(call "vusdot", $p0,
           (cast "8", "U", (call_mangled "splat_lane", (bitcast "int32x4_t", $p2), $p3)), $p1)>;
 
+def OP_BFDOT_LN   // vbfdot($p0, $p1, 32-bit lane $p3 of 64-bit $p2 splatted, recast to $p1's type)
+    : Op<(call "vbfdot", $p0, $p1,
+          (bitcast $p1, (call_mangled "splat_lane", (bitcast "float32x2_t", $p2), $p3)))>;
+
+def OP_BFDOT_LNQ  // As OP_BFDOT_LN, but the 32-bit lane is taken from a 128-bit $p2
+    : Op<(call "vbfdot", $p0, $p1,
+          (bitcast $p1, (call_mangled "splat_lane", (bitcast "float32x4_t", $p2), $p3)))>;
+
+def OP_BFMLALB_LN // vbfmlalb($p0, $p1, element $p3 of $p2 broadcast to $p1's type)
+    : Op<(call "vbfmlalb", $p0, $p1,
+          (dup_typed $p1, (call "vget_lane", $p2, $p3)))>;
+
+def OP_BFMLALT_LN // As OP_BFMLALB_LN, but expands to vbfmlalt
+    : Op<(call "vbfmlalt", $p0, $p1,
+          (dup_typed $p1, (call "vget_lane", $p2, $p3)))>;
+
 //===----------------------------------------------------------------------===//
 // Auxiliary Instructions
 //===----------------------------------------------------------------------===//
@@ -1847,6 +1863,25 @@ let ArchGuard = "defined(__ARM_FEATURE_MATMUL_INT8)" in {
   }
 }
 
+let ArchGuard = "defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)" in {
+  def VDOT_BF : SInst<"vbfdot", "..BB", "fQf">;  // f: 64-bit, Qf: 128-bit accumulator
+  def VDOT_LANE_BF : SOpInst<"vbfdot_lane", "..B(Bq)I", "fQf", OP_BFDOT_LN>;
+  def VDOT_LANEQ_BF : SOpInst<"vbfdot_laneq", "..B(BQ)I", "fQf", OP_BFDOT_LNQ> {
+    let isLaneQ = 1;
+  }
+  // Matrix multiply-accumulate: 128-bit (Qf) form only.
+  def VFMMLA_BF : SInst<"vbfmmla", "..BB", "Qf">;
+  // Multiply-add long into a float32x4_t (Qf) accumulator.
+  def VFMLALB_BF : SInst<"vbfmlalb", "..BB", "Qf">;
+  def VFMLALT_BF : SInst<"vbfmlalt", "..BB", "Qf">;
+  // Indexed forms: _lane indexes a 64-bit vector (Bq), _laneq a 128-bit one (BQ).
+  def VFMLALB_LANE_BF : SOpInst<"vbfmlalb_lane", "..B(Bq)I", "Qf", OP_BFMLALB_LN>;
+  def VFMLALB_LANEQ_BF : SOpInst<"vbfmlalb_laneq", "..B(BQ)I", "Qf", OP_BFMLALB_LN>;
+
+  def VFMLALT_LANE_BF : SOpInst<"vbfmlalt_lane", "..B(Bq)I", "Qf", OP_BFMLALT_LN>;
+  def VFMLALT_LANEQ_BF : SOpInst<"vbfmlalt_laneq", "..B(BQ)I", "Qf", OP_BFMLALT_LN>;
+}
+
 // v8.3-A Vector complex addition intrinsics
 let ArchGuard = "defined(__ARM_FEATURE_COMPLEX) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)" in {
   def VCADD_ROT90_FP16   : SInst<"vcadd_rot90", "...", "h">;

diff  --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 209b5a2b00e3..c3cfed34eeba 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -4970,6 +4970,11 @@ static const ARMVectorIntrinsicInfo AArch64SIMDIntrinsicMap[] = {
   NEONMAP1(vaeseq_v, aarch64_crypto_aese, 0),
   NEONMAP1(vaesimcq_v, aarch64_crypto_aesimc, 0),
   NEONMAP1(vaesmcq_v, aarch64_crypto_aesmc, 0),
+  NEONMAP1(vbfdot_v, aarch64_neon_bfdot, 0),
+  NEONMAP1(vbfdotq_v, aarch64_neon_bfdot, 0),
+  NEONMAP1(vbfmlalbq_v, aarch64_neon_bfmlalb, 0),
+  NEONMAP1(vbfmlaltq_v, aarch64_neon_bfmlalt, 0),
+  NEONMAP1(vbfmmlaq_v, aarch64_neon_bfmmla, 0),
   NEONMAP1(vcadd_rot270_v, aarch64_neon_vcadd_rot270, Add1ArgType),
   NEONMAP1(vcadd_rot90_v, aarch64_neon_vcadd_rot90, Add1ArgType),
   NEONMAP1(vcaddq_rot270_v, aarch64_neon_vcadd_rot270, Add1ArgType),
@@ -6141,6 +6146,32 @@ Value *CodeGenFunction::EmitCommonNeonBuiltinExpr(
     llvm::Type *Tys[2] = { Ty, InputTy };
     return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vusdot");
   }
+  case NEON::BI__builtin_neon_vbfdot_v:
+  case NEON::BI__builtin_neon_vbfdotq_v:
+  case NEON::BI__builtin_neon_vbfmmlaq_v:
+  case NEON::BI__builtin_neon_vbfmlalbq_v:
+  case NEON::BI__builtin_neon_vbfmlaltq_v: {
+    // The BF16 intrinsics are overloaded on the accumulator type (Ty) and on
+    // an i8 vector of the same total width that carries the packed bfloat16
+    // operands.
+    llvm::Type *InputTy =
+        llvm::VectorType::get(Int8Ty, Ty->getPrimitiveSizeInBits() / 8);
+    llvm::Type *Tys[2] = {Ty, InputTy};
+    // Only the value-name hint of the emitted call differs per builtin.
+    const char *Name = "vbfdot";
+    switch (BuiltinID) {
+    case NEON::BI__builtin_neon_vbfmmlaq_v:
+      Name = "vbfmmla";
+      break;
+    case NEON::BI__builtin_neon_vbfmlalbq_v:
+      Name = "vbfmlalb";
+      break;
+    case NEON::BI__builtin_neon_vbfmlaltq_v:
+      Name = "vbfmlalt";
+      break;
+    }
+    return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, Name);
+  }
   }
 
   assert(Int && "Expected valid intrinsic number");

diff  --git a/clang/test/CodeGen/aarch64-bf16-dotprod-intrinsics.c b/clang/test/CodeGen/aarch64-bf16-dotprod-intrinsics.c
new file mode 100644
index 000000000000..22e1396787ce
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-bf16-dotprod-intrinsics.c
@@ -0,0 +1,146 @@
+// RUN: %clang_cc1 -triple aarch64-arm-none-eabi -target-feature +neon -target-feature +bf16 \
+// RUN: -disable-O0-optnone -emit-llvm %s -o - | opt -S -mem2reg -instcombine | FileCheck %s
+
+#include <arm_neon.h>
+
+// CHECK-LABEL: @test_vbfdot_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[A:%.*]] = bitcast <4 x bfloat> %a to <8 x i8>
+// CHECK-NEXT:   [[B:%.*]] = bitcast <4 x bfloat> %b to <8 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> [[A]], <8 x i8> [[B]])
+// CHECK-NEXT:   ret <2 x float> [[RES]]
+float32x2_t test_vbfdot_f32(float32x2_t r, bfloat16x4_t a, bfloat16x4_t b) {
+  return vbfdot_f32(r, a, b);
+}
+
+// CHECK-LABEL: @test_vbfdotq_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[B:%.*]] = bitcast <8 x bfloat> %b to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[B]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfdotq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+  return vbfdotq_f32(r, a, b);
+}
+
+// CHECK-LABEL: @test_vbfdot_lane_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[B:%.*]] = bitcast <4 x bfloat> %b to <2 x float>
+// CHECK-NEXT:   [[LANE:%.*]] = shufflevector <2 x float> [[B]], <2 x float> undef, <2 x i32> zeroinitializer
+// CHECK-NEXT:   [[A:%.*]] = bitcast <4 x bfloat> %a to <8 x i8>
+// CHECK-NEXT:   [[L:%.*]] = bitcast <2 x float> [[LANE]] to <8 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> [[A]], <8 x i8> [[L]])
+// CHECK-NEXT:   ret <2 x float> [[RES]]
+float32x2_t test_vbfdot_lane_f32(float32x2_t r, bfloat16x4_t a, bfloat16x4_t b) {
+  return vbfdot_lane_f32(r, a, b, 0);
+}
+
+// CHECK-LABEL: @test_vbfdotq_laneq_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[B:%.*]] = bitcast <8 x bfloat> %b to <4 x float>
+// CHECK-NEXT:   [[LANE:%.*]] = shufflevector <4 x float> [[B]], <4 x float> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[L:%.*]] = bitcast <4 x float> [[LANE]] to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[L]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfdotq_laneq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+  return vbfdotq_laneq_f32(r, a, b, 3);
+}
+
+// CHECK-LABEL: @test_vbfdot_laneq_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[B:%.*]] = bitcast <8 x bfloat> %b to <4 x float>
+// CHECK-NEXT:   [[LANE:%.*]] = shufflevector <4 x float> [[B]], <4 x float> undef, <2 x i32> <i32 3, i32 3>
+// CHECK-NEXT:   [[A:%.*]] = bitcast <4 x bfloat> %a to <8 x i8>
+// CHECK-NEXT:   [[L:%.*]] = bitcast <2 x float> [[LANE]] to <8 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> [[A]], <8 x i8> [[L]])
+// CHECK-NEXT:   ret <2 x float> [[RES]]
+float32x2_t test_vbfdot_laneq_f32(float32x2_t r, bfloat16x4_t a, bfloat16x8_t b) {
+  return vbfdot_laneq_f32(r, a, b, 3);
+}
+
+// CHECK-LABEL: @test_vbfdotq_lane_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[B:%.*]] = bitcast <4 x bfloat> %b to <2 x float>
+// CHECK-NEXT:   [[LANE:%.*]] = shufflevector <2 x float> [[B]], <2 x float> undef, <4 x i32> zeroinitializer
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[L:%.*]] = bitcast <4 x float> [[LANE]] to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[L]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfdotq_lane_f32(float32x4_t r, bfloat16x8_t a, bfloat16x4_t b) {
+  return vbfdotq_lane_f32(r, a, b, 0);
+}
+
+// CHECK-LABEL: @test_vbfmmlaq_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[B:%.*]] = bitcast <8 x bfloat> %b to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmmla.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[B]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfmmlaq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+  return vbfmmlaq_f32(r, a, b);
+}
+
+// CHECK-LABEL: @test_vbfmlalbq_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[B:%.*]] = bitcast <8 x bfloat> %b to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[B]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfmlalbq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+  return vbfmlalbq_f32(r, a, b);
+}
+
+// CHECK-LABEL: @test_vbfmlaltq_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[B:%.*]] = bitcast <8 x bfloat> %b to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[B]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfmlaltq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+  return vbfmlaltq_f32(r, a, b);
+}
+
+// CHECK-LABEL: @test_vbfmlalbq_lane_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[DUP:%.*]] = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[D:%.*]] = bitcast <8 x bfloat> [[DUP]] to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[D]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfmlalbq_lane_f32(float32x4_t r, bfloat16x8_t a, bfloat16x4_t b) {
+  return vbfmlalbq_lane_f32(r, a, b, 0);
+}
+
+// CHECK-LABEL: @test_vbfmlalbq_laneq_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[DUP:%.*]] = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[D:%.*]] = bitcast <8 x bfloat> [[DUP]] to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[D]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfmlalbq_laneq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+  return vbfmlalbq_laneq_f32(r, a, b, 3);
+}
+
+// CHECK-LABEL: @test_vbfmlaltq_lane_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[DUP:%.*]] = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[D:%.*]] = bitcast <8 x bfloat> [[DUP]] to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[D]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfmlaltq_lane_f32(float32x4_t r, bfloat16x8_t a, bfloat16x4_t b) {
+  return vbfmlaltq_lane_f32(r, a, b, 0);
+}
+
+// CHECK-LABEL: @test_vbfmlaltq_laneq_f32(
+// CHECK-NEXT: entry:
+// CHECK-NEXT:   [[DUP:%.*]] = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+// CHECK-NEXT:   [[A:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT:   [[D:%.*]] = bitcast <8 x bfloat> [[DUP]] to <16 x i8>
+// CHECK-NEXT:   [[RES:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> [[A]], <16 x i8> [[D]])
+// CHECK-NEXT:   ret <4 x float> [[RES]]
+float32x4_t test_vbfmlaltq_laneq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+  return vbfmlaltq_laneq_f32(r, a, b, 3);
+}

diff  --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 9019b9d3be55..483afe97cc63 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -178,6 +178,12 @@ let TargetPrefix = "aarch64" in {  // All intrinsics start with "llvm.aarch64.".
     : Intrinsic<[llvm_anyvector_ty],
                 [LLVMMatchType<0>, llvm_anyvector_ty, LLVMMatchType<1>],
                 [IntrNoMem]>;
+
+  // Result/accumulator vector type, then two multiplicands of a second type.
+  class AdvSIMD_FML_Intrinsic
+    : Intrinsic<[llvm_anyvector_ty],
+                [LLVMMatchType<0>, llvm_anyvector_ty, LLVMMatchType<1>],
+                [IntrNoMem]>;
 }
 
 // Arithmetic ops
@@ -459,6 +465,11 @@ let TargetPrefix = "aarch64", IntrProperties = [IntrNoMem] in {
   def int_aarch64_neon_smmla : AdvSIMD_MatMul_Intrinsic;
   def int_aarch64_neon_usmmla : AdvSIMD_MatMul_Intrinsic;
   def int_aarch64_neon_usdot : AdvSIMD_Dot_Intrinsic;
+  // BF16 (Armv8.6-A): dot product, matrix multiply-accumulate, multiply-add long.
+  def int_aarch64_neon_bfdot : AdvSIMD_Dot_Intrinsic;
+  def int_aarch64_neon_bfmmla : AdvSIMD_MatMul_Intrinsic;
+  def int_aarch64_neon_bfmlalb : AdvSIMD_FML_Intrinsic;
+  def int_aarch64_neon_bfmlalt : AdvSIMD_FML_Intrinsic;
 
   // v8.2-A FP16 Fused Multiply-Add Long
   def int_aarch64_neon_fmlal : AdvSIMD_FP16FML_Intrinsic;

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 713bf0bf3cad..8f5202af96e4 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7815,16 +7815,19 @@ let mayStore = 0, mayLoad = 0, hasSideEffects = 0 in {
 class BaseSIMDThreeSameVectorBFDot<bit Q, bit U, string asm, string kind1,
                                    string kind2, RegisterOperand RegType,
                                    ValueType AccumType, ValueType InputType>
-  : BaseSIMDThreeSameVectorTied<Q, U, 0b010, 0b11111, RegType, asm, kind1, []> {
+  : BaseSIMDThreeSameVectorTied<Q, U, 0b010, 0b11111, RegType, asm, kind1,
+      [(set (AccumType RegType:$dst),  // $dst = bfdot($Rd, $Rn, $Rm)
+            (int_aarch64_neon_bfdot (AccumType RegType:$Rd),
+                (InputType RegType:$Rn), (InputType RegType:$Rm)))]> {
   let AsmString = !strconcat(asm,
                              "{\t$Rd" # kind1 # ", $Rn" # kind2 #
                                ", $Rm" # kind2 # "}");
 }
 
 multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
-  def v4f16 : BaseSIMDThreeSameVectorBFDot<0, U, asm, ".2s", ".4h", V64,
+  def v4bf16 : BaseSIMDThreeSameVectorBFDot<0, U, asm, ".2s", ".4h", V64,
                                            v2f32, v8i8>;
-  def v8f16 : BaseSIMDThreeSameVectorBFDot<1, U, asm, ".4s", ".8h", V128,
+  def v8bf16 : BaseSIMDThreeSameVectorBFDot<1, U, asm, ".4s", ".8h", V128,
                                            v4f32, v16i8>;
 }
 
@@ -7837,7 +7840,13 @@ class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
   : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
                         RegType, RegType, V128, VectorIndexS,
                         asm, "", dst_kind, lhs_kind, rhs_kind,
-        []> {
+        [(set (AccumType RegType:$dst),
+              (AccumType (int_aarch64_neon_bfdot
+                                 (AccumType RegType:$Rd),
+                                 (InputType RegType:$Rn),
+                                 (InputType (bitconvert (AccumType
+                                    (AArch64duplane32 (v4f32 V128:$Rm),  // splat 32-bit lane $idx
+                                        VectorIndexS:$idx)))))))]> { // same 2-bit index type as the operand
 
   bits<2> idx;
   let Inst{21}    = idx{0};  // L
@@ -7846,23 +7855,30 @@ class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
 
 multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
 
-  def v4f16  : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
+  def v4bf16  : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
                                                ".2h", V64, v2f32, v8i8>;
-  def v8f16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
+  def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
                                               ".2h", V128, v4f32, v16i8>;
 }
 
-class SIMDBF16MLAL<bit Q, string asm>
+class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
   : BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
-              []> { // TODO: Add intrinsics
+              [(set (v4f32 V128:$dst), (OpNode (v4f32 V128:$Rd),  // OpNode: bfmlalb/bfmlalt intrinsic
+                                               (v16i8 V128:$Rn),  // bf16 operands passed as v16i8
+                                               (v16i8 V128:$Rm)))]> {
   let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h}");
 }
 
-class SIMDBF16MLALIndex<bit Q, string asm>
+class SIMDBF16MLALIndex<bit Q, string asm, SDPatternOperator OpNode>
   : I<(outs V128:$dst),
       (ins V128:$Rd, V128:$Rn, V128_lo:$Rm, VectorIndexH:$idx), asm,
       "{\t$Rd.4s, $Rn.8h, $Rm.h$idx}", "$Rd = $dst",
-          []>, // TODO: Add intrinsics
+          [(set (v4f32 V128:$dst),
+                (v4f32 (OpNode (v4f32 V128:$Rd),
+                               (v16i8 V128:$Rn),
+                               (v16i8 (bitconvert (v8bf16
+                                  (AArch64duplane16 (v8bf16 V128_lo:$Rm),  // broadcast 16-bit element $idx
+                                      VectorIndexH:$idx)))))))]>,
     Sched<[WriteV]> {
   bits<5> Rd;
   bits<5> Rn;
@@ -7884,7 +7900,10 @@ class SIMDBF16MLALIndex<bit Q, string asm>
 class SIMDThreeSameVectorBF16MatrixMul<string asm>
   : BaseSIMDThreeSameVectorTied<1, 1, 0b010, 0b11101,
                                 V128, asm, ".4s",
-                          []> {
+                          [(set (v4f32 V128:$dst),  // v4f32 accumulator, bf16 operands as v16i8
+                                (int_aarch64_neon_bfmmla (v4f32 V128:$Rd),
+                                                         (v16i8 V128:$Rn),
+                                                         (v16i8 V128:$Rm)))]> {
   let AsmString = !strconcat(asm, "{\t$Rd", ".4s", ", $Rn", ".8h",
                                     ", $Rm", ".8h", "}");
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 8716ffb412d1..b56c5d9ff851 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -784,10 +784,10 @@ let Predicates = [HasBF16] in {
 defm BFDOT       : SIMDThreeSameVectorBFDot<1, "bfdot">;
 defm BF16DOTlane : SIMDThreeSameVectorBF16DotI<0, "bfdot">;
 def BFMMLA       : SIMDThreeSameVectorBF16MatrixMul<"bfmmla">;
-def BFMLALB      : SIMDBF16MLAL<0, "bfmlalb">;
-def BFMLALT      : SIMDBF16MLAL<1, "bfmlalt">;
-def BFMLALBIdx   : SIMDBF16MLALIndex<0, "bfmlalb">;
-def BFMLALTIdx   : SIMDBF16MLALIndex<1, "bfmlalt">;
+def BFMLALB      : SIMDBF16MLAL<0, "bfmlalb", int_aarch64_neon_bfmlalb>;  // vector form
+def BFMLALT      : SIMDBF16MLAL<1, "bfmlalt", int_aarch64_neon_bfmlalt>;  // vector form
+def BFMLALBIdx   : SIMDBF16MLALIndex<0, "bfmlalb", int_aarch64_neon_bfmlalb>;  // by element, $Rm in V128_lo
+def BFMLALTIdx   : SIMDBF16MLALIndex<1, "bfmlalt", int_aarch64_neon_bfmlalt>;  // by element, $Rm in V128_lo
 def BFCVTN       : SIMD_BFCVTN;
 def BFCVTN2      : SIMD_BFCVTN2;
 def BFCVT        : BF16ToSinglePrecision<"bfcvt">;

diff  --git a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
new file mode 100644
index 000000000000..96513115f2d9
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -0,0 +1,176 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple aarch64-arm-none-eabi  -mattr=+bf16 %s -o - | FileCheck %s
+
+define <2 x float> @test_vbfdot_f32(<2 x float> %r, <4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfdot v0.2s, v1.4h, v2.4h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <4 x bfloat> %a to <8 x i8>
+  %1 = bitcast <4 x bfloat> %b to <8 x i8>
+  %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %0, <8 x i8> %1)
+  ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfdot v0.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %b to <16 x i8>
+  %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfdot1.i
+}
+
+define <2 x float> @test_vbfdot_lane_f32(<2 x float> %r, <4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_lane_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK:    bfdot v0.2s, v1.4h, v2.2h[0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <4 x bfloat> %b to <2 x float>
+  %shuffle = shufflevector <2 x float> %0, <2 x float> undef, <2 x i32> zeroinitializer
+  %1 = bitcast <4 x bfloat> %a to <8 x i8>
+  %2 = bitcast <2 x float> %shuffle to <8 x i8>
+  %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %1, <8 x i8> %2)
+  ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_laneq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfdot v0.4s, v1.8h, v2.2h[3]
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %b to <4 x float>
+  %shuffle = shufflevector <4 x float> %0, <4 x float> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
+  %1 = bitcast <8 x bfloat> %a to <16 x i8>
+  %2 = bitcast <4 x float> %shuffle to <16 x i8>
+  %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %1, <16 x i8> %2)
+  ret <4 x float> %vbfdot1.i
+}
+
+define <2 x float> @test_vbfdot_laneq_f32(<2 x float> %r, <4 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_laneq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfdot v0.2s, v1.4h, v2.2h[3]
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %b to <4 x float>
+  %shuffle = shufflevector <4 x float> %0, <4 x float> undef, <2 x i32> <i32 3, i32 3>
+  %1 = bitcast <4 x bfloat> %a to <8 x i8>
+  %2 = bitcast <2 x float> %shuffle to <8 x i8>
+  %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %1, <8 x i8> %2)
+  ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_lane_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK:    bfdot v0.4s, v1.8h, v2.2h[0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <4 x bfloat> %b to <2 x float>
+  %shuffle = shufflevector <2 x float> %0, <2 x float> undef, <4 x i32> zeroinitializer
+  %1 = bitcast <8 x bfloat> %a to <16 x i8>
+  %2 = bitcast <4 x float> %shuffle to <16 x i8>
+  %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %1, <16 x i8> %2)
+  ret <4 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfmmlaq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmmlaq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmmla v0.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %b to <16 x i8>
+  %vbfmmla1.i = tail call <4 x float> @llvm.aarch64.neon.bfmmla.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmmla1.i
+}
+
+define <4 x float> @test_vbfmlalbq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmlalb v0.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %b to <16 x i8>
+  %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlaltq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmlalt v0.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %b to <16 x i8>
+  %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalt1.i
+}
+
+define <4 x float> @test_vbfmlalbq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_lane_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK:    bfmlalb v0.4s, v1.8h, v2.h[0]
+; CHECK-NEXT:    ret
+entry:
+  %vecinit35 = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+  %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlalbq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_laneq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmlalb v0.4s, v1.8h, v2.h[3]
+; CHECK-NEXT:    ret
+entry:
+  %vecinit35 = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+  %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlaltq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_lane_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK:    bfmlalt v0.4s, v1.8h, v2.h[0]
+; CHECK-NEXT:    ret
+entry:
+  %vecinit35 = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+  %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalt1.i
+}
+
+define <4 x float> @test_vbfmlaltq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_laneq_f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    bfmlalt v0.4s, v1.8h, v2.h[3]
+; CHECK-NEXT:    ret
+entry:
+  %vecinit35 = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+  %0 = bitcast <8 x bfloat> %a to <16 x i8>
+  %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+  %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+  ret <4 x float> %vbfmlalt1.i
+}
+
+declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float>, <8 x i8>, <8 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmmla.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)


        


More information about the cfe-commits mailing list