[clang] 10b6567 - [AArch64]: BFloat MatMul Intrinsics&CodeGen
Luke Geeson via cfe-commits
cfe-commits at lists.llvm.org
Tue Jun 16 07:23:48 PDT 2020
Author: Luke Geeson
Date: 2020-06-16T15:23:30+01:00
New Revision: 10b6567f49778f49ea81ff36269fc0fbc033d7ad
URL: https://github.com/llvm/llvm-project/commit/10b6567f49778f49ea81ff36269fc0fbc033d7ad
DIFF: https://github.com/llvm/llvm-project/commit/10b6567f49778f49ea81ff36269fc0fbc033d7ad.diff
LOG: [AArch64]: BFloat MatMul Intrinsics&CodeGen
This patch upstreams support for the BFloat16 matrix-multiply, dot-product and
multiply-add-long intrinsics (BFMMLA, BFDOT, BFMLALB/BFMLALT), and their code
generation from __bf16, to AArch64. This includes the IR intrinsics. Unit tests
are provided as needed. AArch32 intrinsics and code generation will follow in a
later patch.
This patch is part of a series implementing the BFloat16 extension of the
Armv8.6-A architecture, as detailed here:
https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/arm-architecture-developments-armv8-6-a
The bfloat type and its properties are specified in the Arm Architecture
Reference Manual:
https://developer.arm.com/docs/ddi0487/latest/arm-architecture-reference-manual-armv8-for-armv8-a-architecture-profile
The following people contributed to this patch:
- Luke Geeson
- Momchil Velikov
- Mikhail Maltsev
- Luke Cheeseman
Reviewers: SjoerdMeijer, t.p.northover, sdesmalen, labrinea, miyuki,
stuij
Reviewed By: miyuki, stuij
Subscribers: kristof.beyls, hiraditya, danielkiss, cfe-commits,
llvm-commits, miyuki, chill, pbarrio, stuij
Tags: #clang, #llvm
Differential Revision: https://reviews.llvm.org/D80752
Change-Id: I174f0fd0f600d04e3799b06a7da88973c6c0703f
Added:
clang/test/CodeGen/aarch64-bf16-dotprod-intrinsics.c
llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
Modified:
clang/include/clang/Basic/arm_neon.td
clang/lib/CodeGen/CGBuiltin.cpp
llvm/include/llvm/IR/IntrinsicsAArch64.td
llvm/lib/Target/AArch64/AArch64InstrFormats.td
llvm/lib/Target/AArch64/AArch64InstrInfo.td
Removed:
################################################################################
diff --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td
index ffdf08ea494a..289f5ea47b92 100644
--- a/clang/include/clang/Basic/arm_neon.td
+++ b/clang/include/clang/Basic/arm_neon.td
@@ -244,6 +244,22 @@ def OP_SUDOT_LNQ
: Op<(call "vusdot", $p0,
(cast "8", "U", (call_mangled "splat_lane", (bitcast "int32x4_t", $p2), $p3)), $p1)>;
+def OP_BFDOT_LN // Lane form of vbfdot: splat 32-bit lane $p3 of $p2 (viewed as float32x2_t), recast to $p1's type.
+ : Op<(call "vbfdot", $p0, $p1,
+ (bitcast $p1, (call_mangled "splat_lane", (bitcast "float32x2_t", $p2), $p3)))>;
+
+def OP_BFDOT_LNQ // As OP_BFDOT_LN, but $p2 is a 128-bit vector (viewed as float32x4_t).
+ : Op<(call "vbfdot", $p0, $p1,
+ (bitcast $p1, (call_mangled "splat_lane", (bitcast "float32x4_t", $p2), $p3)))>;
+
+def OP_BFMLALB_LN // Lane form of vbfmlalb: broadcast element $p3 of $p2 to $p1's type.
+ : Op<(call "vbfmlalb", $p0, $p1,
+ (dup_typed $p1, (call "vget_lane", $p2, $p3)))>;
+
+def OP_BFMLALT_LN // Lane form of vbfmlalt: broadcast element $p3 of $p2 to $p1's type.
+ : Op<(call "vbfmlalt", $p0, $p1,
+ (dup_typed $p1, (call "vget_lane", $p2, $p3)))>;
+
//===----------------------------------------------------------------------===//
// Auxiliary Instructions
//===----------------------------------------------------------------------===//
@@ -1847,6 +1863,25 @@ let ArchGuard = "defined(__ARM_FEATURE_MATMUL_INT8)" in {
}
}
+let ArchGuard = "defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)" in { // Guarded by the BF16 vector-arithmetic feature macro.
+ def VDOT_BF : SInst<"vbfdot", "..BB", "fQf">; // 64-bit (f) and 128-bit (Qf) forms.
+ def VDOT_LANE_BF : SOpInst<"vbfdot_lane", "..B(Bq)I", "fQf", OP_BFDOT_LN>;
+ def VDOT_LANEQ_BF : SOpInst<"vbfdot_laneq", "..B(BQ)I", "fQf", OP_BFDOT_LNQ> {
+ let isLaneQ = 1;
+ }
+
+ def VFMMLA_BF : SInst<"vbfmmla", "..BB", "Qf">; // 128-bit (Qf) form only.
+
+ def VFMLALB_BF : SInst<"vbfmlalb", "..BB", "Qf">; // 128-bit (Qf) forms only.
+ def VFMLALT_BF : SInst<"vbfmlalt", "..BB", "Qf">;
+
+ def VFMLALB_LANE_BF : SOpInst<"vbfmlalb_lane", "..B(Bq)I", "Qf", OP_BFMLALB_LN>;
+ def VFMLALB_LANEQ_BF : SOpInst<"vbfmlalb_laneq", "..B(BQ)I", "Qf", OP_BFMLALB_LN>; // NOTE(review): unlike VDOT_LANEQ_BF, no isLaneQ = 1 here; confirm intended.
+
+ def VFMLALT_LANE_BF : SOpInst<"vbfmlalt_lane", "..B(Bq)I", "Qf", OP_BFMLALT_LN>;
+ def VFMLALT_LANEQ_BF : SOpInst<"vbfmlalt_laneq", "..B(BQ)I", "Qf", OP_BFMLALT_LN>; // NOTE(review): no isLaneQ = 1 here either; confirm intended.
+}
+
// v8.3-A Vector complex addition intrinsics
let ArchGuard = "defined(__ARM_FEATURE_COMPLEX) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)" in {
def VCADD_ROT90_FP16 : SInst<"vcadd_rot90", "...", "h">;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 209b5a2b00e3..c3cfed34eeba 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -4970,6 +4970,11 @@ static const ARMVectorIntrinsicInfo AArch64SIMDIntrinsicMap[] = {
NEONMAP1(vaeseq_v, aarch64_crypto_aese, 0),
NEONMAP1(vaesimcq_v, aarch64_crypto_aesimc, 0),
NEONMAP1(vaesmcq_v, aarch64_crypto_aesmc, 0),
+ NEONMAP1(vbfdot_v, aarch64_neon_bfdot, 0),
+ NEONMAP1(vbfdotq_v, aarch64_neon_bfdot, 0),
+ NEONMAP1(vbfmlalbq_v, aarch64_neon_bfmlalb, 0),
+ NEONMAP1(vbfmlaltq_v, aarch64_neon_bfmlalt, 0),
+ NEONMAP1(vbfmmlaq_v, aarch64_neon_bfmmla, 0),
NEONMAP1(vcadd_rot270_v, aarch64_neon_vcadd_rot270, Add1ArgType),
NEONMAP1(vcadd_rot90_v, aarch64_neon_vcadd_rot90, Add1ArgType),
NEONMAP1(vcaddq_rot270_v, aarch64_neon_vcadd_rot270, Add1ArgType),
@@ -6141,6 +6146,27 @@ Value *CodeGenFunction::EmitCommonNeonBuiltinExpr(
llvm::Type *Tys[2] = { Ty, InputTy };
return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vusdot");
}
+  case NEON::BI__builtin_neon_vbfdot_v:
+  case NEON::BI__builtin_neon_vbfdotq_v:
+  case NEON::BI__builtin_neon_vbfmmlaq_v:
+  case NEON::BI__builtin_neon_vbfmlalbq_v:
+  case NEON::BI__builtin_neon_vbfmlaltq_v: {
+    // The BF16 intrinsics are overloaded on the f32 accumulator type (Ty) and
+    // on the bf16 operand type, which is passed as an i8 vector with the same
+    // total bit width as the accumulator.
+    llvm::Type *InputTy =
+        llvm::VectorType::get(Int8Ty, Ty->getPrimitiveSizeInBits() / 8);
+    llvm::Type *Tys[2] = { Ty, InputTy };
+    // Name the emitted call after the specific BF16 operation.
+    const char *Name = "vbfdot";
+    if (BuiltinID == NEON::BI__builtin_neon_vbfmmlaq_v)
+      Name = "vbfmmla";
+    else if (BuiltinID == NEON::BI__builtin_neon_vbfmlalbq_v)
+      Name = "vbfmlalb";
+    else if (BuiltinID == NEON::BI__builtin_neon_vbfmlaltq_v)
+      Name = "vbfmlalt";
+    return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, Name);
+  }
}
assert(Int && "Expected valid intrinsic number");
diff --git a/clang/test/CodeGen/aarch64-bf16-dotprod-intrinsics.c b/clang/test/CodeGen/aarch64-bf16-dotprod-intrinsics.c
new file mode 100644
index 000000000000..22e1396787ce
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-bf16-dotprod-intrinsics.c
@@ -0,0 +1,149 @@
+// RUN: %clang_cc1 -triple aarch64-arm-none-eabi -target-feature +neon -target-feature +bf16 \
+// RUN: -disable-O0-optnone -emit-llvm %s -o - | opt -S -mem2reg -instcombine | FileCheck %s
+
+// Local value names are captured with FileCheck variables instead of being
+// matched literally, so the assertions do not depend on naming details.
+
+#include <arm_neon.h>
+
+// CHECK-LABEL: test_vbfdot_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x bfloat> %a to <8 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x bfloat> %b to <8 x i8>
+// CHECK-NEXT: [[VBFDOT:%.*]] = {{(tail )?}}call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> [[TMP0]], <8 x i8> [[TMP1]])
+// CHECK-NEXT: ret <2 x float> [[VBFDOT]]
+float32x2_t test_vbfdot_f32(float32x2_t r, bfloat16x4_t a, bfloat16x4_t b) {
+ return vbfdot_f32(r, a, b);
+}
+
+// CHECK-LABEL: test_vbfdotq_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> %b to <16 x i8>
+// CHECK-NEXT: [[VBFDOT:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+// CHECK-NEXT: ret <4 x float> [[VBFDOT]]
+float32x4_t test_vbfdotq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b){
+ return vbfdotq_f32(r, a, b);
+}
+
+// CHECK-LABEL: test_vbfdot_lane_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[CAST:%.*]] = bitcast <4 x bfloat> %b to <2 x float>
+// CHECK-NEXT: [[LANE:%.*]] = shufflevector <2 x float> [[CAST]], <2 x float> undef, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x bfloat> %a to <8 x i8>
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x float> [[LANE]] to <8 x i8>
+// CHECK-NEXT: [[VBFDOT:%.*]] = {{(tail )?}}call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> [[TMP1]], <8 x i8> [[TMP2]])
+// CHECK-NEXT: ret <2 x float> [[VBFDOT]]
+float32x2_t test_vbfdot_lane_f32(float32x2_t r, bfloat16x4_t a, bfloat16x4_t b){
+ return vbfdot_lane_f32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vbfdotq_laneq_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[CAST:%.*]] = bitcast <8 x bfloat> %b to <4 x float>
+// CHECK-NEXT: [[LANE:%.*]] = shufflevector <4 x float> [[CAST]], <4 x float> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast <4 x float> [[LANE]] to <16 x i8>
+// CHECK-NEXT: [[VBFDOT:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP1]], <16 x i8> [[TMP2]])
+// CHECK-NEXT: ret <4 x float> [[VBFDOT]]
+float32x4_t test_vbfdotq_laneq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+ return vbfdotq_laneq_f32(r, a, b, 3);
+}
+
+// CHECK-LABEL: test_vbfdot_laneq_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[CAST:%.*]] = bitcast <8 x bfloat> %b to <4 x float>
+// CHECK-NEXT: [[LANE:%.*]] = shufflevector <4 x float> [[CAST]], <4 x float> undef, <2 x i32> <i32 3, i32 3>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x bfloat> %a to <8 x i8>
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x float> [[LANE]] to <8 x i8>
+// CHECK-NEXT: [[VBFDOT:%.*]] = {{(tail )?}}call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> [[TMP1]], <8 x i8> [[TMP2]])
+// CHECK-NEXT: ret <2 x float> [[VBFDOT]]
+float32x2_t test_vbfdot_laneq_f32(float32x2_t r, bfloat16x4_t a, bfloat16x8_t b) {
+ return vbfdot_laneq_f32(r, a, b, 3);
+}
+
+// CHECK-LABEL: test_vbfdotq_lane_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[CAST:%.*]] = bitcast <4 x bfloat> %b to <2 x float>
+// CHECK-NEXT: [[LANE:%.*]] = shufflevector <2 x float> [[CAST]], <2 x float> undef, <4 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast <4 x float> [[LANE]] to <16 x i8>
+// CHECK-NEXT: [[VBFDOT:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP1]], <16 x i8> [[TMP2]])
+// CHECK-NEXT: ret <4 x float> [[VBFDOT]]
+float32x4_t test_vbfdotq_lane_f32(float32x4_t r, bfloat16x8_t a, bfloat16x4_t b) {
+ return vbfdotq_lane_f32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vbfmmlaq_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> %b to <16 x i8>
+// CHECK-NEXT: [[VBFMMLA:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmmla.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+// CHECK-NEXT: ret <4 x float> [[VBFMMLA]]
+float32x4_t test_vbfmmlaq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+ return vbfmmlaq_f32(r, a, b);
+}
+
+// CHECK-LABEL: test_vbfmlalbq_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> %b to <16 x i8>
+// CHECK-NEXT: [[VBFMLALB:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+// CHECK-NEXT: ret <4 x float> [[VBFMLALB]]
+float32x4_t test_vbfmlalbq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+ return vbfmlalbq_f32(r, a, b);
+}
+
+// CHECK-LABEL: test_vbfmlaltq_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> %b to <16 x i8>
+// CHECK-NEXT: [[VBFMLALT:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+// CHECK-NEXT: ret <4 x float> [[VBFMLALT]]
+float32x4_t test_vbfmlaltq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+ return vbfmlaltq_f32(r, a, b);
+}
+
+// CHECK-LABEL: test_vbfmlalbq_lane_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[VECINIT:%.*]] = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> [[VECINIT]] to <16 x i8>
+// CHECK-NEXT: [[VBFMLALB:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+// CHECK-NEXT: ret <4 x float> [[VBFMLALB]]
+float32x4_t test_vbfmlalbq_lane_f32(float32x4_t r, bfloat16x8_t a, bfloat16x4_t b) {
+ return vbfmlalbq_lane_f32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vbfmlalbq_laneq_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[VECINIT:%.*]] = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> [[VECINIT]] to <16 x i8>
+// CHECK-NEXT: [[VBFMLALB:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+// CHECK-NEXT: ret <4 x float> [[VBFMLALB]]
+float32x4_t test_vbfmlalbq_laneq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+ return vbfmlalbq_laneq_f32(r, a, b, 3);
+}
+
+// CHECK-LABEL: test_vbfmlaltq_lane_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[VECINIT:%.*]] = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> [[VECINIT]] to <16 x i8>
+// CHECK-NEXT: [[VBFMLALT:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+// CHECK-NEXT: ret <4 x float> [[VBFMLALT]]
+float32x4_t test_vbfmlaltq_lane_f32(float32x4_t r, bfloat16x8_t a, bfloat16x4_t b) {
+ return vbfmlaltq_lane_f32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vbfmlaltq_laneq_f32
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[VECINIT:%.*]] = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> %a to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x bfloat> [[VECINIT]] to <16 x i8>
+// CHECK-NEXT: [[VBFMLALT:%.*]] = {{(tail )?}}call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+// CHECK-NEXT: ret <4 x float> [[VBFMLALT]]
+float32x4_t test_vbfmlaltq_laneq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b) {
+ return vbfmlaltq_laneq_f32(r, a, b, 3);
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 9019b9d3be55..483afe97cc63 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -178,6 +178,12 @@ let TargetPrefix = "aarch64" in { // All intrinsics start with "llvm.aarch64.".
: Intrinsic<[llvm_anyvector_ty],
[LLVMMatchType<0>, llvm_anyvector_ty, LLVMMatchType<1>],
[IntrNoMem]>;
+
+ class AdvSIMD_FML_Intrinsic // Overloaded on the accumulator vector type and on the type shared by both multiplicands.
+ : Intrinsic<[llvm_anyvector_ty],
+ [LLVMMatchType<0>, llvm_anyvector_ty, LLVMMatchType<1>],
+ [IntrNoMem]>; // NOTE(review): same signature as the class closed just above; reuse it if a distinct name is not needed.
+
}
// Arithmetic ops
@@ -459,6 +465,11 @@ let TargetPrefix = "aarch64", IntrProperties = [IntrNoMem] in {
def int_aarch64_neon_smmla : AdvSIMD_MatMul_Intrinsic;
def int_aarch64_neon_usmmla : AdvSIMD_MatMul_Intrinsic;
def int_aarch64_neon_usdot : AdvSIMD_Dot_Intrinsic;
+ def int_aarch64_neon_bfdot : AdvSIMD_Dot_Intrinsic; // e.g. llvm.aarch64.neon.bfdot.v2f32.v8i8
+ def int_aarch64_neon_bfmmla : AdvSIMD_MatMul_Intrinsic; // e.g. llvm.aarch64.neon.bfmmla.v4f32.v16i8
+ def int_aarch64_neon_bfmlalb : AdvSIMD_FML_Intrinsic; // e.g. llvm.aarch64.neon.bfmlalb.v4f32.v16i8
+ def int_aarch64_neon_bfmlalt : AdvSIMD_FML_Intrinsic; // e.g. llvm.aarch64.neon.bfmlalt.v4f32.v16i8
+
// v8.2-A FP16 Fused Multiply-Add Long
def int_aarch64_neon_fmlal : AdvSIMD_FP16FML_Intrinsic;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 713bf0bf3cad..8f5202af96e4 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7815,16 +7815,19 @@ let mayStore = 0, mayLoad = 0, hasSideEffects = 0 in {
class BaseSIMDThreeSameVectorBFDot<bit Q, bit U, string asm, string kind1,
string kind2, RegisterOperand RegType,
ValueType AccumType, ValueType InputType>
- : BaseSIMDThreeSameVectorTied<Q, U, 0b010, 0b11111, RegType, asm, kind1, []> {
+ : BaseSIMDThreeSameVectorTied<Q, U, 0b010, 0b11111, RegType, asm, kind1, [(set (AccumType RegType:$dst),
+ (int_aarch64_neon_bfdot (AccumType RegType:$Rd),
+ (InputType RegType:$Rn),
+ (InputType RegType:$Rm)))]> {
let AsmString = !strconcat(asm,
"{\t$Rd" # kind1 # ", $Rn" # kind2 #
", $Rm" # kind2 # "}");
}
multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
- def v4f16 : BaseSIMDThreeSameVectorBFDot<0, U, asm, ".2s", ".4h", V64,
+ def v4bf16 : BaseSIMDThreeSameVectorBFDot<0, U, asm, ".2s", ".4h", V64,
v2f32, v8i8>;
- def v8f16 : BaseSIMDThreeSameVectorBFDot<1, U, asm, ".4s", ".8h", V128,
+ def v8bf16 : BaseSIMDThreeSameVectorBFDot<1, U, asm, ".4s", ".8h", V128,
v4f32, v16i8>;
}
@@ -7837,7 +7840,13 @@ class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
: BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
RegType, RegType, V128, VectorIndexS,
asm, "", dst_kind, lhs_kind, rhs_kind,
- []> {
+ [(set (AccumType RegType:$dst),
+ (AccumType (int_aarch64_neon_bfdot
+ (AccumType RegType:$Rd),
+ (InputType RegType:$Rn),
+ (InputType (bitconvert (AccumType
+ (AArch64duplane32 (v4f32 V128:$Rm),
+ VectorIndexS:$idx)))))))]> { // Index class must match the VectorIndexS operand above (2-bit idx).
bits<2> idx;
let Inst{21} = idx{0}; // L
@@ -7846,23 +7855,30 @@ class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
- def v4f16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
+ def v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
".2h", V64, v2f32, v8i8>;
- def v8f16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
+ def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
".2h", V128, v4f32, v16i8>;
}
-class SIMDBF16MLAL<bit Q, string asm>
+class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
: BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
- []> { // TODO: Add intrinsics
+ [(set (v4f32 V128:$dst), (OpNode (v4f32 V128:$Rd),
+ (v16i8 V128:$Rn),
+ (v16i8 V128:$Rm)))]> { // OpNode: the intrinsic this instruction is selected from.
let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h}");
}
-class SIMDBF16MLALIndex<bit Q, string asm>
+class SIMDBF16MLALIndex<bit Q, string asm, SDPatternOperator OpNode>
: I<(outs V128:$dst),
(ins V128:$Rd, V128:$Rn, V128_lo:$Rm, VectorIndexH:$idx), asm,
"{\t$Rd.4s, $Rn.8h, $Rm.h$idx}", "$Rd = $dst",
- []>, // TODO: Add intrinsics
+ [(set (v4f32 V128:$dst),
+ (v4f32 (OpNode (v4f32 V128:$Rd),
+ (v16i8 V128:$Rn),
+ (v16i8 (bitconvert (v8bf16
+ (AArch64duplane16 (v8bf16 V128_lo:$Rm),
+ VectorIndexH:$idx)))))))]>,
Sched<[WriteV]> {
bits<5> Rd;
bits<5> Rn;
@@ -7884,7 +7900,10 @@ class SIMDBF16MLALIndex<bit Q, string asm>
class SIMDThreeSameVectorBF16MatrixMul<string asm>
: BaseSIMDThreeSameVectorTied<1, 1, 0b010, 0b11101,
V128, asm, ".4s",
- []> {
+ [(set (v4f32 V128:$dst),
+ (int_aarch64_neon_bfmmla (v4f32 V128:$Rd),
+ (v16i8 V128:$Rn),
+ (v16i8 V128:$Rm)))]> {
let AsmString = !strconcat(asm, "{\t$Rd", ".4s", ", $Rn", ".8h",
", $Rm", ".8h", "}");
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 8716ffb412d1..b56c5d9ff851 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -784,10 +784,10 @@ let Predicates = [HasBF16] in {
defm BFDOT : SIMDThreeSameVectorBFDot<1, "bfdot">;
defm BF16DOTlane : SIMDThreeSameVectorBF16DotI<0, "bfdot">;
def BFMMLA : SIMDThreeSameVectorBF16MatrixMul<"bfmmla">;
-def BFMLALB : SIMDBF16MLAL<0, "bfmlalb">;
-def BFMLALT : SIMDBF16MLAL<1, "bfmlalt">;
-def BFMLALBIdx : SIMDBF16MLALIndex<0, "bfmlalb">;
-def BFMLALTIdx : SIMDBF16MLALIndex<1, "bfmlalt">;
+def BFMLALB : SIMDBF16MLAL<0, "bfmlalb", int_aarch64_neon_bfmlalb>; // Q=0 (B form), selected from llvm.aarch64.neon.bfmlalb.
+def BFMLALT : SIMDBF16MLAL<1, "bfmlalt", int_aarch64_neon_bfmlalt>; // Q=1 (T form), selected from llvm.aarch64.neon.bfmlalt.
+def BFMLALBIdx : SIMDBF16MLALIndex<0, "bfmlalb", int_aarch64_neon_bfmlalb>; // Indexed B form: matches bfmlalb with a lane-splatted operand.
+def BFMLALTIdx : SIMDBF16MLALIndex<1, "bfmlalt", int_aarch64_neon_bfmlalt>; // Indexed T form: matches bfmlalt with a lane-splatted operand.
def BFCVTN : SIMD_BFCVTN;
def BFCVTN2 : SIMD_BFCVTN2;
def BFCVT : BF16ToSinglePrecision<"bfcvt">;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
new file mode 100644
index 000000000000..96513115f2d9
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -0,0 +1,177 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple aarch64-arm-none-eabi -mattr=+bf16 %s -o - | FileCheck %s
+
+define <2 x float> @test_vbfdot_f32(<2 x float> %r, <4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfdot v0.2s, v1.4h, v2.4h
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <4 x bfloat> %a to <8 x i8>
+ %1 = bitcast <4 x bfloat> %b to <8 x i8>
+ %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %0, <8 x i8> %1)
+ ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfdot v0.4s, v1.8h, v2.8h
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <8 x bfloat> %a to <16 x i8>
+ %1 = bitcast <8 x bfloat> %b to <16 x i8>
+ %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+ ret <4 x float> %vbfdot1.i
+}
+
+define <2 x float> @test_vbfdot_lane_f32(<2 x float> %r, <4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_lane_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK: bfdot v0.2s, v1.4h, v2.2h[0]
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <4 x bfloat> %b to <2 x float>
+ %shuffle = shufflevector <2 x float> %0, <2 x float> undef, <2 x i32> zeroinitializer
+ %1 = bitcast <4 x bfloat> %a to <8 x i8>
+ %2 = bitcast <2 x float> %shuffle to <8 x i8>
+ %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %1, <8 x i8> %2)
+ ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_laneq_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfdot v0.4s, v1.8h, v2.2h[3]
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <8 x bfloat> %b to <4 x float>
+ %shuffle = shufflevector <4 x float> %0, <4 x float> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
+ %1 = bitcast <8 x bfloat> %a to <16 x i8>
+ %2 = bitcast <4 x float> %shuffle to <16 x i8>
+ %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %1, <16 x i8> %2)
+ ret <4 x float> %vbfdot1.i
+}
+
+define <2 x float> @test_vbfdot_laneq_f32(<2 x float> %r, <4 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdot_laneq_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfdot v0.2s, v1.4h, v2.2h[3]
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <8 x bfloat> %b to <4 x float>
+ %shuffle = shufflevector <4 x float> %0, <4 x float> undef, <2 x i32> <i32 3, i32 3>
+ %1 = bitcast <4 x bfloat> %a to <8 x i8>
+ %2 = bitcast <2 x float> %shuffle to <8 x i8>
+ %vbfdot1.i = tail call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float> %r, <8 x i8> %1, <8 x i8> %2)
+ ret <2 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfdotq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_lane_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK: bfdot v0.4s, v1.8h, v2.2h[0]
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <4 x bfloat> %b to <2 x float>
+ %shuffle = shufflevector <2 x float> %0, <2 x float> undef, <4 x i32> zeroinitializer
+ %1 = bitcast <8 x bfloat> %a to <16 x i8>
+ %2 = bitcast <4 x float> %shuffle to <16 x i8>
+ %vbfdot1.i = tail call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float> %r, <16 x i8> %1, <16 x i8> %2)
+ ret <4 x float> %vbfdot1.i
+}
+
+define <4 x float> @test_vbfmmlaq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmmlaq_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfmmla v0.4s, v1.8h, v2.8h
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <8 x bfloat> %a to <16 x i8>
+ %1 = bitcast <8 x bfloat> %b to <16 x i8>
+ %vbfmmla1.i = tail call <4 x float> @llvm.aarch64.neon.bfmmla.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+ ret <4 x float> %vbfmmla1.i
+}
+
+define <4 x float> @test_vbfmlalbq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfmlalb v0.4s, v1.8h, v2.8h
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <8 x bfloat> %a to <16 x i8>
+ %1 = bitcast <8 x bfloat> %b to <16 x i8>
+ %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+ ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlaltq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfmlalt v0.4s, v1.8h, v2.8h
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <8 x bfloat> %a to <16 x i8>
+ %1 = bitcast <8 x bfloat> %b to <16 x i8>
+ %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+ ret <4 x float> %vbfmlalt1.i
+}
+
+define <4 x float> @test_vbfmlalbq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_lane_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK: bfmlalb v0.4s, v1.8h, v2.h[0]
+; CHECK-NEXT: ret
+entry:
+ %vecinit35 = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+ %0 = bitcast <8 x bfloat> %a to <16 x i8>
+ %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+ %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+ ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlalbq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlalbq_laneq_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfmlalb v0.4s, v1.8h, v2.h[3]
+; CHECK-NEXT: ret
+entry:
+ %vecinit35 = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+ %0 = bitcast <8 x bfloat> %a to <16 x i8>
+ %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+ %vbfmlalb1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+ ret <4 x float> %vbfmlalb1.i
+}
+
+define <4 x float> @test_vbfmlaltq_lane_f32(<4 x float> %r, <8 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_lane_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK: bfmlalt v0.4s, v1.8h, v2.h[0]
+; CHECK-NEXT: ret
+entry:
+ %vecinit35 = shufflevector <4 x bfloat> %b, <4 x bfloat> undef, <8 x i32> zeroinitializer
+ %0 = bitcast <8 x bfloat> %a to <16 x i8>
+ %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+ %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+ ret <4 x float> %vbfmlalt1.i
+}
+
+define <4 x float> @test_vbfmlaltq_laneq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfmlaltq_laneq_f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: bfmlalt v0.4s, v1.8h, v2.h[3]
+; CHECK-NEXT: ret
+entry:
+ %vecinit35 = shufflevector <8 x bfloat> %b, <8 x bfloat> undef, <8 x i32> <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
+ %0 = bitcast <8 x bfloat> %a to <16 x i8>
+ %1 = bitcast <8 x bfloat> %vecinit35 to <16 x i8>
+ %vbfmlalt1.i = tail call <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float> %r, <16 x i8> %0, <16 x i8> %1)
+ ret <4 x float> %vbfmlalt1.i
+}
+
+; No attribute group is attached: intrinsic declarations take their attributes (e.g. IntrNoMem) from the intrinsic definitions.
+declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v8i8(<2 x float>, <8 x i8>, <8 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmmla.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmlalb.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
+declare <4 x float> @llvm.aarch64.neon.bfmlalt.v4f32.v16i8(<4 x float>, <16 x i8>, <16 x i8>)
More information about the cfe-commits mailing list