[llvm] [SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 (PR #155645)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 15 06:14:38 PDT 2025
https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/155645
>From 8ad95c0666d869662701f8813e9a3d6a7ceecd33 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 27 Aug 2025 03:42:59 -0700
Subject: [PATCH] [SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16
---
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 4 +++-
.../Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h | 5 +++
llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 3 +-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 22 +++++++++-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 3 ++
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 2 +-
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 40 ++++++++++++++++---
.../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 30 ++++++++++++++
llvm/test/CodeGen/SPIRV/basic_float_types.ll | 24 ++++++++++-
.../extensions/SPV_KHR_bfloat16/bfloat16.ll | 22 ++++++++++
.../bfloat16_cooperative_matrix.ll | 22 ++++++++++
.../SPV_KHR_bfloat16/bfloat16_dot.ll | 21 ++++++++++
12 files changed, 187 insertions(+), 11 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 768e3713f78e2..12b735e053bde 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,7 +2765,9 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
+ // SPIR-V re-derives bfloat from the IR type in SPIRVGlobalRegistry
+ // (OpTypeFloat 16 + BFloat16KHR encoding), so it skips this bail-out.
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
return false;
const CallInst &CI = cast<CallInst>(U);
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index c2c08f8831307..d76180ce97e9e 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -232,6 +232,11 @@ namespace SpecConstantOpOperands {
#include "SPIRVGenTables.inc"
} // namespace SpecConstantOpOperands
+namespace FPEncoding {
+#define GET_FPEncoding_DECL
+#include "SPIRVGenTables.inc"
+} // namespace FPEncoding
+
struct ExtendedBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..993de9e9f64ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
- SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+ SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+ {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cfe24c84941a9..115766ce886c7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -203,6 +203,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
});
}
+SPIRVType *
+SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::FPEncoding::FPEncoding FPEncode) {
+ return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+ .addDef(createTypeVReg(MIRBuilder))
+ .addImm(Width)
+ .addImm(FPEncode);
+ });
+}
+
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
@@ -1041,8 +1053,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
- if (Ty->isFloatingPointTy())
- return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+ if (Ty->isFloatingPointTy()) {
+ // bfloat and half are both 16 bits wide; bfloat is told apart by the
+ // BFloat16KHR FP Encoding operand (SPV_KHR_bfloat16).
+ if (Ty->isBFloatTy())
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder,
+ SPIRV::FPEncoding::BFloat16KHR);
+ return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+ }
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7ef812828b7cc..a648defa0a888 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -438,6 +438,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+ SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
+ SPIRV::FPEncoding::FPEncoding FPEncode);
+
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 8d10cd0ffb3dd..496dcba17c10d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
"$type = OpTypeFloat $width">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index b7e371d190866..a95f393b75605 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1222,6 +1222,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
}
}
+static bool isBFloat16Type(const SPIRVType *TypeDef) {
+ return TypeDef && TypeDef->getNumOperands() == 3 &&
+ TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+ TypeDef->getOperand(1).getImm() == 16 &&
+ TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
+}
+
void addInstrRequirements(const MachineInstr &MI,
SPIRV::RequirementHandler &Reqs,
const SPIRVSubtarget &ST) {
@@ -1261,12 +1268,29 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Int8);
break;
}
+ case SPIRV::OpDot: {
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+ if (isBFloat16Type(TypeDef))
+ Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+ break;
+ }
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
- else if (BitWidth == 16)
- Reqs.addCapability(SPIRV::Capability::Float16);
+ else if (BitWidth == 16) {
+ if (isBFloat16Type(&MI)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
+ report_fatal_error("OpTypeFloat type with bfloat requires the "
+ "following SPIR-V extension: SPV_KHR_bfloat16",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+ } else {
+ Reqs.addCapability(SPIRV::Capability::Float16);
+ }
+ }
break;
}
case SPIRV::OpTypeVector: {
@@ -1286,8 +1310,9 @@ void addInstrRequirements(const MachineInstr &MI,
assert(MI.getOperand(2).isReg());
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
- if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
- TypeDef->getOperand(1).getImm() == 16)
+ // Only half needs Float16Buffer; bfloat has its own BFloat16TypeKHR cap.
+ if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+ TypeDef->getOperand(1).getImm() == 16 && !isBFloat16Type(TypeDef))
Reqs.addCapability(SPIRV::Capability::Float16Buffer);
break;
}
@@ -1593,7 +1618,7 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
}
break;
- case SPIRV::OpTypeCooperativeMatrixKHR:
+ case SPIRV::OpTypeCooperativeMatrixKHR: {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
report_fatal_error(
"OpTypeCooperativeMatrixKHR type requires the "
@@ -1601,7 +1626,12 @@ void addInstrRequirements(const MachineInstr &MI,
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+ if (isBFloat16Type(TypeDef))
+ Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
break;
+ }
case SPIRV::OpArithmeticFenceEXT:
if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
report_fatal_error("OpArithmeticFenceEXT requires the "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..501bcb94af2ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory;
def CooperativeMatrixOperandsOperand : OperandCategory;
def SpecConstantOpOperandsOperand : OperandCategory;
def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
+def FPEncodingOperand : OperandCategory;
//===----------------------------------------------------------------------===//
// Definition of the Environments
@@ -382,6 +383,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +596,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
@@ -1996,3 +2001,28 @@ defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400,
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+
+//===----------------------------------------------------------------------===//
+// Multiclass used to define FPEncoding enum values and at the same time
+// SymbolicOperand entries with extension and capability requirements.
+//===----------------------------------------------------------------------===//
+def FPEncoding : GenericEnum, Operand<i32> {
+ let FilterClass = "FPEncoding";
+ let NameField = "Name";
+ let ValueField = "Value";
+ let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
+}
+
+class FPEncoding<string name, bits<32> value> {
+ string Name = name;
+ bits<32> Value = value;
+}
+
+multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions, list<Capability> reqCapabilities>{
+ def NAME : FPEncoding<NAME, value>;
+ defm : SymbolicOperandWithRequirements<
+ FPEncodingOperand, value, NAME, 0, 0,
+ reqExtensions, reqCapabilities, []>;
+}
+
+defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
diff --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
index dfee1ace2205d..112bf1bea9d50 100644
--- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll
+++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
@@ -7,8 +7,10 @@ entry:
; CHECK-DAG: OpCapability Float16
; CHECK-DAG: OpCapability Float64
+; CHECK-DAG: OpCapability BFloat16TypeKHR
-; CHECK-DAG: %[[#half:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#half:]] = OpTypeFloat 16{{$}}
+; CHECK-DAG: %[[#bfloat:]] = OpTypeFloat 16 0{{$}}
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
@@ -16,6 +18,10 @@ entry:
; CHECK-DAG: %[[#v3half:]] = OpTypeVector %[[#half]] 3
; CHECK-DAG: %[[#v4half:]] = OpTypeVector %[[#half]] 4
+; CHECK-DAG: %[[#v2bfloat:]] = OpTypeVector %[[#bfloat]] 2
+; CHECK-DAG: %[[#v3bfloat:]] = OpTypeVector %[[#bfloat]] 3
+; CHECK-DAG: %[[#v4bfloat:]] = OpTypeVector %[[#bfloat]] 4
+
; CHECK-DAG: %[[#v2float:]] = OpTypeVector %[[#float]] 2
; CHECK-DAG: %[[#v3float:]] = OpTypeVector %[[#float]] 3
; CHECK-DAG: %[[#v4float:]] = OpTypeVector %[[#float]] 4
@@ -25,11 +31,15 @@ entry:
; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4
; CHECK-DAG: %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]]
+; CHECK-DAG: %[[#ptr_Function_bfloat:]] = OpTypePointer Function %[[#bfloat]]
; CHECK-DAG: %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]]
; CHECK-DAG: %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]]
; CHECK-DAG: %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]]
; CHECK-DAG: %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]]
; CHECK-DAG: %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]]
+; CHECK-DAG: %[[#ptr_Function_v2bfloat:]] = OpTypePointer Function %[[#v2bfloat]]
+; CHECK-DAG: %[[#ptr_Function_v3bfloat:]] = OpTypePointer Function %[[#v3bfloat]]
+; CHECK-DAG: %[[#ptr_Function_v4bfloat:]] = OpTypePointer Function %[[#v4bfloat]]
; CHECK-DAG: %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]]
; CHECK-DAG: %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]]
; CHECK-DAG: %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
@@ -40,6 +50,9 @@ entry:
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function
%half_Val = alloca half, align 2
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_bfloat]] Function
+ %bfloat_Val = alloca bfloat, align 2
+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function
%float_Val = alloca float, align 4
@@ -55,6 +68,15 @@ entry:
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function
%half4_Val = alloca <4 x half>, align 8
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2bfloat]] Function
+ %bfloat2_Val = alloca <2 x bfloat>, align 4
+
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v3bfloat]] Function
+ %bfloat3_Val = alloca <3 x bfloat>, align 8
+
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4bfloat]] Function
+ %bfloat4_Val = alloca <4 x bfloat>, align 8
+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function
%float2_Val = alloca <2 x float>, align 8
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
new file mode 100644
index 0000000000000..22668e71fb257
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
@@ -0,0 +1,22 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s --implicit-check-not="OpCapability Float16"
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test() {
+entry:
+ %addr1 = alloca bfloat
+ %addr2 = alloca <2 x bfloat>
+ %data1 = load bfloat, ptr %addr1
+ %data2 = load <2 x bfloat>, ptr %addr2
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
new file mode 100644
index 0000000000000..d47b5d7440d18
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll
@@ -0,0 +1,22 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability CooperativeMatrixKHR
+; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_KHR_cooperative_matrix"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]] %[[#]] %[[#]] %[[#]] %[[#]]
+; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
+
+define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
+entry:
+ %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+ %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
+ %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 1.0)
+ store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
+ ret void
+}
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
new file mode 100644
index 0000000000000..4c248fea5c7f1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll
@@ -0,0 +1,21 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16DotProductKHR
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
+; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
+; CHECK: %[[#]] = OpDot %[[#BFLOAT]] %[[#]] %[[#]]
+
+declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
+
+define spir_kernel void @test() {
+entry:
+ %addrA = alloca <2 x bfloat>
+ %addrB = alloca <2 x bfloat>
+ %dataA = load <2 x bfloat>, ptr %addrA
+ %dataB = load <2 x bfloat>, ptr %addrB
+ %call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
+ ret void
+}
More information about the llvm-commits
mailing list