[llvm] e4888a3 - [X86][BF16] Enable __bf16 for x86 targets.
Phoebe Wang via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 9 18:41:20 PDT 2022
Author: Freddy Ye
Date: 2022-08-10T09:00:47+08:00
New Revision: e4888a37d36780872d685c68ef8b26b2e14d6d39
URL: https://github.com/llvm/llvm-project/commit/e4888a37d36780872d685c68ef8b26b2e14d6d39
DIFF: https://github.com/llvm/llvm-project/commit/e4888a37d36780872d685c68ef8b26b2e14d6d39.diff
LOG: [X86][BF16] Enable __bf16 for x86 targets.
The X86 psABI has been updated to support the __bf16 type, whose ABI is the
same as that of FP16 (_Float16). See https://discourse.llvm.org/t/patch-add-optional-bfloat16-support/63149
Reviewed By: pengfei
Differential Revision: https://reviews.llvm.org/D130964
Added:
clang/test/CodeGen/X86/bfloat-abi.c
clang/test/CodeGen/X86/bfloat-half-abi.c
clang/test/CodeGen/X86/bfloat-mangle.cpp
Modified:
clang/docs/LanguageExtensions.rst
clang/lib/Basic/Targets/X86.cpp
clang/lib/Basic/Targets/X86.h
clang/lib/CodeGen/TargetInfo.cpp
clang/test/Sema/vector-decl-crash.c
llvm/include/llvm/IR/Type.h
Removed:
################################################################################
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index 52931cc9232ce..a6d02b2e02c15 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -756,6 +756,10 @@ performing the operation, and then truncating to ``_Float16``.
``__bf16`` is purely a storage format; it is currently only supported on the following targets:
* 32-bit ARM
* 64-bit ARM (AArch64)
+* X86 (see below)
+
+On X86 targets, ``__bf16`` is supported as long as SSE2 is available, which
+includes all 64-bit and all recent 32-bit processors.
``__fp16`` is a storage and interchange format only. This means that values of
``__fp16`` are immediately promoted to (at least) ``float`` when used in arithmetic
diff --git a/clang/lib/Basic/Targets/X86.cpp b/clang/lib/Basic/Targets/X86.cpp
index 210c8451e7886..7a3cb662a91ff 100644
--- a/clang/lib/Basic/Targets/X86.cpp
+++ b/clang/lib/Basic/Targets/X86.cpp
@@ -358,6 +358,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
HasFloat16 = SSELevel >= SSE2;
+ HasBFloat16 = SSELevel >= SSE2;
+
MMX3DNowEnum ThreeDNowLevel = llvm::StringSwitch<MMX3DNowEnum>(Feature)
.Case("+3dnowa", AMD3DNowAthlon)
.Case("+3dnow", AMD3DNow)
diff --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h
index 0affa58b2f4c0..ed0864aec6d2d 100644
--- a/clang/lib/Basic/Targets/X86.h
+++ b/clang/lib/Basic/Targets/X86.h
@@ -156,6 +156,8 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
public:
X86TargetInfo(const llvm::Triple &Triple, const TargetOptions &)
: TargetInfo(Triple) {
+ BFloat16Width = BFloat16Align = 16;
+ BFloat16Format = &llvm::APFloat::BFloat();
LongDoubleFormat = &llvm::APFloat::x87DoubleExtended();
AddrSpaceMap = &X86AddrSpaceMap;
HasStrictFP = true;
@@ -396,6 +398,8 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
uint64_t getPointerAlignV(unsigned AddrSpace) const override {
return getPointerWidthV(AddrSpace);
}
+
+ /// Itanium vendor-extended type mangling for __bf16: 'u' + <source-name>.
+ const char *getBFloat16Mangling() const override { return "u6__bf16"; }
};
// X86-32 generic target
diff --git a/clang/lib/CodeGen/TargetInfo.cpp b/clang/lib/CodeGen/TargetInfo.cpp
index c283712c74d35..bbc40ad3a50a7 100644
--- a/clang/lib/CodeGen/TargetInfo.cpp
+++ b/clang/lib/CodeGen/TargetInfo.cpp
@@ -2861,7 +2861,7 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase, Class &Lo,
} else if (k >= BuiltinType::Bool && k <= BuiltinType::LongLong) {
Current = Integer;
} else if (k == BuiltinType::Float || k == BuiltinType::Double ||
- k == BuiltinType::Float16) {
+ k == BuiltinType::Float16 || k == BuiltinType::BFloat16) {
Current = SSE;
} else if (k == BuiltinType::LongDouble) {
const llvm::fltSemantics *LDF = &getTarget().getLongDoubleFormat();
@@ -2992,7 +2992,8 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase, Class &Lo,
Current = Integer;
else if (Size <= 128)
Lo = Hi = Integer;
- } else if (ET->isFloat16Type() || ET == getContext().FloatTy) {
+ } else if (ET->isFloat16Type() || ET == getContext().FloatTy ||
+ ET->isBFloat16Type()) {
Current = SSE;
} else if (ET == getContext().DoubleTy) {
Lo = Hi = SSE;
@@ -3464,9 +3465,9 @@ GetSSETypeAtOffset(llvm::Type *IRType, unsigned IROffset,
if (SourceSize > T0Size)
T1 = getFPTypeAtOffset(IRType, IROffset + T0Size, TD);
if (T1 == nullptr) {
- // Check if IRType is a half + float. float type will be in IROffset+4 due
+ // Check if IRType is a half/bfloat + float. float type will be in IROffset+4 due
// to its alignment.
- if (T0->isHalfTy() && SourceSize > 4)
+ if (T0->is16bitFPTy() && SourceSize > 4)
T1 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
// If we can't get a second FP type, return a simple half or float.
// avx512fp16-abi.c:pr51813_2 shows it works to return float for
@@ -3478,7 +3479,7 @@ GetSSETypeAtOffset(llvm::Type *IRType, unsigned IROffset,
if (T0->isFloatTy() && T1->isFloatTy())
return llvm::FixedVectorType::get(T0, 2);
- if (T0->isHalfTy() && T1->isHalfTy()) {
+ if (T0->is16bitFPTy() && T1->is16bitFPTy()) {
llvm::Type *T2 = nullptr;
if (SourceSize > 4)
T2 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
@@ -3487,7 +3488,7 @@ GetSSETypeAtOffset(llvm::Type *IRType, unsigned IROffset,
return llvm::FixedVectorType::get(T0, 4);
}
- if (T0->isHalfTy() || T1->isHalfTy())
+ if (T0->is16bitFPTy() || T1->is16bitFPTy())
return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);
return llvm::Type::getDoubleTy(getVMContext());
diff --git a/clang/test/CodeGen/X86/bfloat-abi.c b/clang/test/CodeGen/X86/bfloat-abi.c
new file mode 100644
index 0000000000000..42250791848ac
--- /dev/null
+++ b/clang/test/CodeGen/X86/bfloat-abi.c
@@ -0,0 +1,149 @@
+// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm -target-feature +sse2 < %s | FileCheck %s --check-prefixes=CHECK
+
+struct bfloat1 {
+ __bf16 a;
+};
+
+struct bfloat1 h1(__bf16 a) {
+ // CHECK: define{{.*}}bfloat @
+ struct bfloat1 x;
+ x.a = a;
+ return x;
+}
+
+struct bfloat2 {
+ __bf16 a;
+ __bf16 b;
+};
+
+struct bfloat2 h2(__bf16 a, __bf16 b) {
+ // CHECK: define{{.*}}<2 x bfloat> @
+ struct bfloat2 x;
+ x.a = a;
+ x.b = b;
+ return x;
+}
+
+struct bfloat3 {
+ __bf16 a;
+ __bf16 b;
+ __bf16 c;
+};
+
+struct bfloat3 h3(__bf16 a, __bf16 b, __bf16 c) {
+ // CHECK: define{{.*}}<4 x bfloat> @
+ struct bfloat3 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ return x;
+}
+
+struct bfloat4 {
+ __bf16 a;
+ __bf16 b;
+ __bf16 c;
+ __bf16 d;
+};
+
+struct bfloat4 h4(__bf16 a, __bf16 b, __bf16 c, __bf16 d) {
+ // CHECK: define{{.*}}<4 x bfloat> @
+ struct bfloat4 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ x.d = d;
+ return x;
+}
+
+struct floatbfloat {
+ float a;
+ __bf16 b;
+};
+
+struct floatbfloat fh(float a, __bf16 b) {
+ // CHECK: define{{.*}}<4 x half> @
+ struct floatbfloat x;
+ x.a = a;
+ x.b = b;
+ return x;
+}
+
+struct floatbfloat2 {
+ float a;
+ __bf16 b;
+ __bf16 c;
+};
+
+struct floatbfloat2 fh2(float a, __bf16 b, __bf16 c) {
+ // CHECK: define{{.*}}<4 x half> @
+ struct floatbfloat2 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ return x;
+}
+
+struct bfloatfloat {
+ __bf16 a;
+ float b;
+};
+
+struct bfloatfloat hf(__bf16 a, float b) {
+ // CHECK: define{{.*}}<4 x half> @
+ struct bfloatfloat x;
+ x.a = a;
+ x.b = b;
+ return x;
+}
+
+struct bfloat2float {
+ __bf16 a;
+ __bf16 b;
+ float c;
+};
+
+struct bfloat2float h2f(__bf16 a, __bf16 b, float c) {
+ // CHECK: define{{.*}}<4 x bfloat> @
+ struct bfloat2float x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ return x;
+}
+
+struct floatbfloat3 {
+ float a;
+ __bf16 b;
+ __bf16 c;
+ __bf16 d;
+};
+
+struct floatbfloat3 fh3(float a, __bf16 b, __bf16 c, __bf16 d) {
+ // CHECK: define{{.*}}{ <4 x half>, bfloat } @
+ struct floatbfloat3 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ x.d = d;
+ return x;
+}
+
+struct bfloat5 {
+ __bf16 a;
+ __bf16 b;
+ __bf16 c;
+ __bf16 d;
+ __bf16 e;
+};
+
+struct bfloat5 h5(__bf16 a, __bf16 b, __bf16 c, __bf16 d, __bf16 e) {
+ // CHECK: define{{.*}}{ <4 x bfloat>, bfloat } @
+ struct bfloat5 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ x.d = d;
+ x.e = e;
+ return x;
+}
diff --git a/clang/test/CodeGen/X86/bfloat-half-abi.c b/clang/test/CodeGen/X86/bfloat-half-abi.c
new file mode 100644
index 0000000000000..42250791848ac
--- /dev/null
+++ b/clang/test/CodeGen/X86/bfloat-half-abi.c
@@ -0,0 +1,149 @@
+// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm -target-feature +sse2 < %s | FileCheck %s --check-prefixes=CHECK
+
+struct bfloat1 {
+ __bf16 a;
+};
+
+struct bfloat1 h1(__bf16 a) {
+ // CHECK: define{{.*}}bfloat @
+ struct bfloat1 x;
+ x.a = a;
+ return x;
+}
+
+struct bfloat2 {
+ __bf16 a;
+ __bf16 b;
+};
+
+struct bfloat2 h2(__bf16 a, __bf16 b) {
+ // CHECK: define{{.*}}<2 x bfloat> @
+ struct bfloat2 x;
+ x.a = a;
+ x.b = b;
+ return x;
+}
+
+struct bfloat3 {
+ __bf16 a;
+ __bf16 b;
+ __bf16 c;
+};
+
+struct bfloat3 h3(__bf16 a, __bf16 b, __bf16 c) {
+ // CHECK: define{{.*}}<4 x bfloat> @
+ struct bfloat3 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ return x;
+}
+
+struct bfloat4 {
+ __bf16 a;
+ __bf16 b;
+ __bf16 c;
+ __bf16 d;
+};
+
+struct bfloat4 h4(__bf16 a, __bf16 b, __bf16 c, __bf16 d) {
+ // CHECK: define{{.*}}<4 x bfloat> @
+ struct bfloat4 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ x.d = d;
+ return x;
+}
+
+struct floatbfloat {
+ float a;
+ __bf16 b;
+};
+
+struct floatbfloat fh(float a, __bf16 b) {
+ // CHECK: define{{.*}}<4 x half> @
+ struct floatbfloat x;
+ x.a = a;
+ x.b = b;
+ return x;
+}
+
+struct floatbfloat2 {
+ float a;
+ __bf16 b;
+ __bf16 c;
+};
+
+struct floatbfloat2 fh2(float a, __bf16 b, __bf16 c) {
+ // CHECK: define{{.*}}<4 x half> @
+ struct floatbfloat2 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ return x;
+}
+
+struct bfloatfloat {
+ __bf16 a;
+ float b;
+};
+
+struct bfloatfloat hf(__bf16 a, float b) {
+ // CHECK: define{{.*}}<4 x half> @
+ struct bfloatfloat x;
+ x.a = a;
+ x.b = b;
+ return x;
+}
+
+struct bfloat2float {
+ __bf16 a;
+ __bf16 b;
+ float c;
+};
+
+struct bfloat2float h2f(__bf16 a, __bf16 b, float c) {
+ // CHECK: define{{.*}}<4 x bfloat> @
+ struct bfloat2float x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ return x;
+}
+
+struct floatbfloat3 {
+ float a;
+ __bf16 b;
+ __bf16 c;
+ __bf16 d;
+};
+
+struct floatbfloat3 fh3(float a, __bf16 b, __bf16 c, __bf16 d) {
+ // CHECK: define{{.*}}{ <4 x half>, bfloat } @
+ struct floatbfloat3 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ x.d = d;
+ return x;
+}
+
+struct bfloat5 {
+ __bf16 a;
+ __bf16 b;
+ __bf16 c;
+ __bf16 d;
+ __bf16 e;
+};
+
+struct bfloat5 h5(__bf16 a, __bf16 b, __bf16 c, __bf16 d, __bf16 e) {
+ // CHECK: define{{.*}}{ <4 x bfloat>, bfloat } @
+ struct bfloat5 x;
+ x.a = a;
+ x.b = b;
+ x.c = c;
+ x.d = d;
+ x.e = e;
+ return x;
+}
diff --git a/clang/test/CodeGen/X86/bfloat-mangle.cpp b/clang/test/CodeGen/X86/bfloat-mangle.cpp
new file mode 100644
index 0000000000000..2892a76d8d910
--- /dev/null
+++ b/clang/test/CodeGen/X86/bfloat-mangle.cpp
@@ -0,0 +1,5 @@
+// RUN: %clang_cc1 -triple i386-unknown-unknown -target-feature +sse2 -emit-llvm -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -target-feature +sse2 -emit-llvm -o - %s | FileCheck %s
+
+// CHECK: define {{.*}}void @_Z3foou6__bf16(bfloat noundef %b)
+void foo(__bf16 b) {}
diff --git a/clang/test/Sema/vector-decl-crash.c b/clang/test/Sema/vector-decl-crash.c
index 5e4b098fee2d3..fafe34133de43 100644
--- a/clang/test/Sema/vector-decl-crash.c
+++ b/clang/test/Sema/vector-decl-crash.c
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 %s -fsyntax-only -verify -triple x86_64-unknown-unknown
+// RUN: %clang_cc1 %s -fsyntax-only -verify -triple riscv64-unknown-unknown
// GH50171
// This would previously crash when __bf16 was not a supported type.
diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 52fce3409b7ae..e6d351babb563 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -144,6 +144,11 @@ class Type {
/// Return true if this is 'bfloat', a 16-bit bfloat type.
bool isBFloatTy() const { return getTypeID() == BFloatTyID; }
+ /// Return true if this is 'half' or 'bfloat', a 16-bit floating-point type.
+ bool is16bitFPTy() const {
+ return getTypeID() == BFloatTyID || getTypeID() == HalfTyID;
+ }
+
/// Return true if this is 'float', a 32-bit IEEE fp type.
bool isFloatTy() const { return getTypeID() == FloatTyID; }
More information about the llvm-commits
mailing list