[Mlir-commits] [mlir] 0a0960d - [mlir][spirv] Add bfloat16 support (#141458)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 13 07:14:49 PDT 2025
Author: Darren Wihandi
Date: 2025-06-13T10:14:45-04:00
New Revision: 0a0960dac69fc88a3c8bd5e2099f8d45b0292c78
URL: https://github.com/llvm/llvm-project/commit/0a0960dac69fc88a3c8bd5e2099f8d45b0292c78
DIFF: https://github.com/llvm/llvm-project/commit/0a0960dac69fc88a3c8bd5e2099f8d45b0292c78.diff
LOG: [mlir][spirv] Add bfloat16 support (#141458)
Adds bf16 support to SPIR-V by using the `SPV_KHR_bfloat16` extension.
The extension permits only a few kinds of operations on bf16 values:
loading from and storing to memory, conversion to/from other types,
cooperative matrix operations (including cooperative matrix arithmetic
ops), and dot products.
This PR adds the type definition and implements the basic cast
operations. Cooperative matrix ops (including their arithmetic ops) will
be added in a separate PR.
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
mlir/test/Dialect/SPIRV/IR/types.mlir
mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
mlir/test/Target/SPIRV/cast-ops.mlir
mlir/test/Target/SPIRV/logical-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index b143cf9a5f509..e413503bbd672 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
+def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
- SPV_KHR_cooperative_matrix,
+ SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
Extension<[SPV_NV_stereo_view_rendering]>
];
}
+def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
+def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
+def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
list<Availability> availability = [
@@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
- SPIRV_C_CacheControlsINTEL
+ SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
+ SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
]>;
+def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_BFloat16TypeKHR]>
+ ];
+}
+def SPIRV_FPEncodingAttr :
+ SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
+ SPIRV_FPE_BFloat16KHR
+ ]>;
+
def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
@@ -4161,10 +4190,12 @@ def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
+def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
+def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
- [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
+ [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
@@ -4187,14 +4218,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
-def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
+def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
- SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
+ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
SPIRV_AnyImage
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index b05ee0251df5b..a5c8aa8fb450c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
// -----
-def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_AnyFloat, []> {
let summary = [{
Convert value numerically from floating point to signed integer, with
round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
// -----
-def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_AnyFloat, []> {
let summary = [{
Convert value numerically from floating point to unsigned integer, with
round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
// -----
def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
- SPIRV_Float,
+ SPIRV_AnyFloat,
SPIRV_Integer,
[SignedOp]> {
let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
// -----
def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
- SPIRV_Float,
+ SPIRV_AnyFloat,
SPIRV_Integer,
[UnsignedOp]> {
let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
// -----
def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
- SPIRV_Float,
- SPIRV_Float,
+ SPIRV_AnyFloat,
+ SPIRV_AnyFloat,
[UsableInSpecConstantOp]> {
let summary = [{
Convert value numerically from one floating-point width to another
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 0cf5f0823be63..a21acef1c4b43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
// Check other allowed types
if (auto t = llvm::dyn_cast<FloatType>(type)) {
- if (type.isBF16()) {
- parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
- return Type();
- }
+ // TODO: All float types are allowed for now, but this should be fixed.
} else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 1aff43c301334..93e0c9b33c546 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -526,7 +526,7 @@ bool ScalarType::classof(Type type) {
}
bool ScalarType::isValid(FloatType type) {
- return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
+ return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
}
bool ScalarType::isValid(IntegerType type) {
@@ -535,6 +535,11 @@ bool ScalarType::isValid(IntegerType type) {
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
+ if (isa<BFloat16Type>(*this)) {
+ static const Extension ext = Extension::SPV_KHR_bfloat16;
+ extensions.push_back(ext);
+ }
+
// 8- or 16-bit integer/floating-point numbers will require extra extensions
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
// SPV_KHR_8bit_storage for more details.
@@ -640,7 +645,16 @@ void ScalarType::getCapabilities(
} else {
assert(llvm::isa<FloatType>(*this));
switch (bitwidth) {
- WIDTH_CASE(Float, 16);
+ case 16: {
+ if (isa<BFloat16Type>(*this)) {
+ static const Capability cap = Capability::BFloat16TypeKHR;
+ capabilities.push_back(cap);
+ } else {
+ static const Capability cap = Capability::Float16;
+ capabilities.push_back(cap);
+ }
+ break;
+ }
WIDTH_CASE(Float, 64);
case 32:
break;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c43d584d7b913..b9d9a9015eb61 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -867,11 +867,15 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
} break;
case spirv::Opcode::OpTypeFloat: {
- if (operands.size() != 2)
- return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
+ if (operands.size() != 2 && operands.size() != 3)
+ return emitError(unknownLoc,
+ "OpTypeFloat expects either 2 operands (type, bitwidth) "
+ "or 3 operands (type, bitwidth, encoding), but got ")
+ << operands.size();
+ uint32_t bitWidth = operands[1];
Type floatTy;
- switch (operands[1]) {
+ switch (bitWidth) {
case 16:
floatTy = opBuilder.getF16Type();
break;
@@ -883,8 +887,20 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
break;
default:
return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
- << operands[1];
+ << bitWidth;
+ }
+
+ if (operands.size() == 3) {
+ if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
+ return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
+ << operands[2];
+ if (bitWidth != 16)
+ return emitError(unknownLoc,
+ "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
+ << bitWidth << " (expected 16)";
+ floatTy = opBuilder.getBF16Type();
}
+
typeMap[operands[0]] = floatTy;
} break;
case spirv::Opcode::OpTypeVector: {
@@ -1399,6 +1415,9 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
} else if (floatType.isF16()) {
APInt data(16, operands[2]);
value = APFloat(APFloat::IEEEhalf(), data);
+ } else if (floatType.isBF16()) {
+ APInt data(16, operands[2]);
+ value = APFloat(APFloat::BFloat(), data);
}
auto attr = opBuilder.getFloatAttr(floatType, value);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 647535809554c..d258bfd852961 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
if (auto floatType = dyn_cast<FloatType>(type)) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(floatType.getWidth());
+ if (floatType.isBF16()) {
+ operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
+ }
return success();
}
@@ -1022,21 +1025,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
auto resultID = getNextID();
APFloat value = floatAttr.getValue();
+ const llvm::fltSemantics *semantics = &value.getSemantics();
auto opcode =
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
- if (&value.getSemantics() == &APFloat::IEEEsingle()) {
+ if (semantics == &APFloat::IEEEsingle()) {
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
- } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
+ } else if (semantics == &APFloat::IEEEdouble()) {
struct DoubleWord {
uint32_t word1;
uint32_t word2;
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
- } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
+ } else if (semantics == &APFloat::IEEEhalf() ||
+ semantics == &APFloat::BFloat()) {
uint32_t word =
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2..1737f4a906bf8 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -173,6 +173,12 @@ func.func @float16(%arg0: f16) { return }
// NOEMU-SAME: f64
func.func @float64(%arg0: f64) { return }
+// CHECK-LABEL: spirv.func @bfloat16
+// CHECK-SAME: f32
+// NOEMU-LABEL: func.func @bfloat16
+// NOEMU-SAME: bf16
+func.func @bfloat16(%arg0: bf16) { return }
+
// f80 is not supported by SPIR-V.
// CHECK-LABEL: func.func @float80
// CHECK-SAME: f80
@@ -206,18 +212,6 @@ func.func @float64(%arg0: f64) { return }
// -----
-// Check that bf16 is not supported.
-module attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-NOT: spirv.func @bf16_type
-func.func @bf16_type(%arg0: bf16) { return }
-
-} // end module
-
-// -----
-
//===----------------------------------------------------------------------===//
// Complex types
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 2d0c86e08de5a..d58c27598f2b8 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -12,6 +12,14 @@ func.func @fadd_scalar(%arg: f32) -> f32 {
// -----
+func.func @fadd_bf16_scalar(%arg: bf16) -> bf16 {
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.FAdd %arg, %arg : bf16
+ return %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.FDiv
//===----------------------------------------------------------------------===//
@@ -24,6 +32,14 @@ func.func @fdiv_scalar(%arg: f32) -> f32 {
// -----
+func.func @fdiv_bf16_scalar(%arg: bf16) -> bf16 {
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.FDiv %arg, %arg : bf16
+ return %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.FMod
//===----------------------------------------------------------------------===//
@@ -36,6 +52,14 @@ func.func @fmod_scalar(%arg: f32) -> f32 {
// -----
+func.func @fmod_bf16_scalar(%arg: bf16) -> bf16 {
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.FMod %arg, %arg : bf16
+ return %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.FMul
//===----------------------------------------------------------------------===//
@@ -70,6 +94,14 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 {
// -----
+func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> {
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.FMul %arg, %arg : vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
+
+// -----
+
func.func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
%0 = spirv.FMul %arg, %arg : tensor<4xf32>
@@ -90,6 +122,14 @@ func.func @fnegate_scalar(%arg: f32) -> f32 {
// -----
+func.func @fnegate_bf16_scalar(%arg: bf16) -> bf16 {
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.FNegate %arg : bf16
+ return %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.FRem
//===----------------------------------------------------------------------===//
@@ -102,6 +142,14 @@ func.func @frem_scalar(%arg: f32) -> f32 {
// -----
+func.func @frem_bf16_scalar(%arg: bf16) -> bf16 {
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.FRem %arg, %arg : bf16
+ return %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.FSub
//===----------------------------------------------------------------------===//
@@ -114,6 +162,14 @@ func.func @fsub_scalar(%arg: f32) -> f32 {
// -----
+func.func @fsub_bf16_scalar(%arg: bf16) -> bf16 {
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.FSub %arg, %arg : bf16
+ return %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IAdd
//===----------------------------------------------------------------------===//
@@ -489,3 +545,11 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3
%0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<3xf32>
return %0 : vector<3xf32>
}
+
+// -----
+
+func.func @vector_bf16_times_scalar_bf16(%vector: vector<4xbf16>, %scalar: bf16) -> vector<4xbf16> {
+ // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
+ %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xbf16>, bf16) -> vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
index cc0abd3a42dcb..661497d5fff38 100644
--- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
@@ -272,3 +272,11 @@ func.func @atomic_fadd(%ptr : !spirv.ptr<f32, StorageBuffer>, %value : f32) -> f
%0 = spirv.EXT.AtomicFAdd <Device> <Acquire|Release> %ptr, %value : !spirv.ptr<f32, StorageBuffer>
return %0 : f32
}
+
+// -----
+
+func.func @atomic_bf16_fadd(%ptr : !spirv.ptr<bf16, StorageBuffer>, %value : bf16) -> bf16 {
+ // expected-error @+1 {{op operand #1 must be 16/32/64-bit float, but got 'bf16'}}
+ %0 = spirv.EXT.AtomicFAdd <Device> <None> %ptr, %value : !spirv.ptr<bf16, StorageBuffer>
+ return %0 : bf16
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 34d0109e6bb44..4480a1f3720f2 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -110,6 +110,14 @@ func.func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
// -----
+func.func @convert_bf16_to_s32_scalar(%arg0 : bf16) -> i32 {
+ // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+ %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+ spirv.ReturnValue %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertFToU
//===----------------------------------------------------------------------===//
@@ -146,6 +154,14 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou
// -----
+func.func @convert_bf16_to_u32_scalar(%arg0 : bf16) -> i32 {
+ // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+ %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+ spirv.ReturnValue %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertSToF
//===----------------------------------------------------------------------===//
@@ -174,6 +190,14 @@ func.func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
// -----
+func.func @convert_s32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+ // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to bf16
+ %0 = spirv.ConvertSToF %arg0 : i32 to bf16
+ spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertUToF
//===----------------------------------------------------------------------===//
@@ -202,6 +226,14 @@ func.func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
// -----
+func.func @convert_u32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+ // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to bf16
+ %0 = spirv.ConvertUToF %arg0 : i32 to bf16
+ spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.FConvert
//===----------------------------------------------------------------------===//
@@ -238,6 +270,30 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
// -----
+func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+ %0 = spirv.FConvert %arg0 : bf16 to f32
+ spirv.ReturnValue %0 : f32
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_vector(%arg0 : vector<3xf32>) -> vector<3xbf16> {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : vector<3xf32> to vector<3xbf16>
+ %0 = spirv.FConvert %arg0 : vector<3xf32> to vector<3xbf16>
+ spirv.ReturnValue %0 : vector<3xbf16>
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) -> !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA> {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+ %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+ spirv.ReturnValue %0 : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.SConvert
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 3fc8dfb2767d1..e71b545de11df 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
return %0: vector<3xf32>
}
+// CHECK-LABEL: func @composite_construct_bf16_vector
+func.func @composite_construct_bf16_vector(%arg0: bf16, %arg1: bf16, %arg2 : bf16) -> vector<3xbf16> {
+ // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (bf16, bf16, bf16) -> vector<3xbf16>
+ %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (bf16, bf16, bf16) -> vector<3xbf16>
+ return %0: vector<3xbf16>
+}
+
// CHECK-LABEL: func @composite_construct_struct
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
// CHECK: spirv.CompositeConstruct
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 2b75767feaf92..642346cc40b0d 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -50,6 +50,14 @@ func.func @exp(%arg0 : i32) -> () {
// -----
+func.func @exp_bf16(%arg0 : bf16) -> () {
+ // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+ %2 = spirv.GL.Exp %arg0 : bf16
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GL.{F|S|U}{Max|Min}
//===----------------------------------------------------------------------===//
@@ -92,6 +100,15 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {
// -----
+func.func @fmaxminbf16vec(%arg0 : vector<3xbf16>, %arg1 : vector<3xbf16>) {
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %1 = spirv.GL.FMax %arg0, %arg1 : vector<3xbf16>
+ %2 = spirv.GL.FMin %arg0, %arg1 : vector<3xbf16>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GL.InverseSqrt
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index 5c24f0e6a7d33..d6c34645f5746 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -201,6 +201,14 @@ func.func @select_op_float(%arg0: i1) -> () {
return
}
+func.func @select_op_bfloat16(%arg0: i1) -> () {
+ %0 = spirv.Constant 2.0 : bf16
+ %1 = spirv.Constant 3.0 : bf16
+ // CHECK: spirv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, bf16
+ %2 = spirv.Select %arg0, %0, %1 : i1, bf16
+ return
+}
+
func.func @select_op_ptr(%arg0: i1) -> () {
%0 = spirv.Variable : !spirv.ptr<f32, Function>
%1 = spirv.Variable : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5f56de6ad1fa9..7ab94f17360d5 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -184,6 +184,14 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
// -----
+func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+ %0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
+ return %0: bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformFMax
//===----------------------------------------------------------------------===//
@@ -197,6 +205,14 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
// -----
+func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+ %0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
+ return %0: bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformFMin
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index b63a08d96e6af..c23894c62826b 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -15,6 +15,9 @@ func.func private @vector_array_type(!spirv.array< 32 x vector<4xf32> >) -> ()
// CHECK: func private @array_type_stride(!spirv.array<4 x !spirv.array<4 x f32, stride=4>, stride=128>)
func.func private @array_type_stride(!spirv.array< 4 x !spirv.array<4 x f32, stride=4>, stride = 128>) -> ()
+// CHECK: func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16>>)
+func.func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16> >) -> ()
+
// -----
// expected-error @+1 {{expected '<'}}
@@ -57,11 +60,6 @@ func.func private @tensor_type(!spirv.array<4xtensor<4xf32>>) -> ()
// -----
-// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
-func.func private @bf16_type(!spirv.array<4xbf16>) -> ()
-
-// -----
-
// expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
func.func private @i256_type(!spirv.array<4xi256>) -> ()
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index ff5ac7cea8fc6..2b237665ffc4a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -217,3 +217,17 @@ spirv.module Logical GLSL450 attributes {
spirv.GlobalVariable @data : !spirv.ptr<!spirv.struct<(i8 [0], f16 [2], i64 [4])>, Uniform>
spirv.GlobalVariable @img : !spirv.ptr<!spirv.image<f32, Buffer, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Rg32f>, UniformConstant>
}
+
+// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+spirv.module Logical GLSL450 attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,
+ #spirv.resource_limits<>
+ >
+} {
+ spirv.func @load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>) -> bf16 "None" {
+ %val = spirv.Load "StorageBuffer" %ptr : bf16
+ spirv.ReturnValue %val : bf16
+ }
+}
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index ede0bf30511ef..04a468b39b645 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -25,6 +25,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.ConvertFToS %arg0 : f64 to i32
spirv.ReturnValue %0 : i32
}
+ spirv.func @convert_bf16_to_s32(%arg0 : bf16) -> i32 "None" {
+ // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+ %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+ spirv.ReturnValue %0 : i32
+ }
spirv.func @convert_f_to_u(%arg0 : f32) -> i32 "None" {
// CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : f32 to i32
%0 = spirv.ConvertFToU %arg0 : f32 to i32
@@ -35,6 +40,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.ConvertFToU %arg0 : f64 to i32
spirv.ReturnValue %0 : i32
}
+ spirv.func @convert_bf16_to_u32(%arg0 : bf16) -> i32 "None" {
+ // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+ %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+ spirv.ReturnValue %0 : i32
+ }
spirv.func @convert_s_to_f(%arg0 : i32) -> f32 "None" {
// CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to f32
%0 = spirv.ConvertSToF %arg0 : i32 to f32
@@ -45,6 +55,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.ConvertSToF %arg0 : i64 to f32
spirv.ReturnValue %0 : f32
}
+ spirv.func @convert_s64_to_bf16(%arg0 : i64) -> bf16 "None" {
+ // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i64 to bf16
+ %0 = spirv.ConvertSToF %arg0 : i64 to bf16
+ spirv.ReturnValue %0 : bf16
+ }
spirv.func @convert_u_to_f(%arg0 : i32) -> f32 "None" {
// CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to f32
%0 = spirv.ConvertUToF %arg0 : i32 to f32
@@ -55,11 +70,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.ConvertUToF %arg0 : i64 to f32
spirv.ReturnValue %0 : f32
}
- spirv.func @f_convert(%arg0 : f32) -> f64 "None" {
+ spirv.func @convert_u64_to_bf16(%arg0 : i64) -> bf16 "None" {
+ // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i64 to bf16
+ %0 = spirv.ConvertUToF %arg0 : i64 to bf16
+ spirv.ReturnValue %0 : bf16
+ }
+ spirv.func @convert_f32_to_f64(%arg0 : f32) -> f64 "None" {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : f32 to f64
%0 = spirv.FConvert %arg0 : f32 to f64
spirv.ReturnValue %0 : f64
}
+ spirv.func @convert_f32_to_bf16(%arg0 : f32) -> bf16 "None" {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : f32 to bf16
+ %0 = spirv.FConvert %arg0 : f32 to bf16
+ spirv.ReturnValue %0 : bf16
+ }
+ spirv.func @convert_bf16_to_f32(%arg0 : bf16) -> f32 "None" {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+ %0 = spirv.FConvert %arg0 : bf16 to f32
+ spirv.ReturnValue %0 : f32
+ }
spirv.func @s_convert(%arg0 : i32) -> i64 "None" {
// CHECK: {{%.*}} = spirv.SConvert {{%.*}} : i32 to i64
%0 = spirv.SConvert %arg0 : i32 to i64
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index 16846ac84e38c..b2008719b021c 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -108,3 +108,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.Return
}
}
+
+// -----
+
+// Test select works with bf16 scalar and vectors.
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+ spirv.SpecConstant @condition_scalar = true
+ spirv.func @select_bf16() -> () "None" {
+ %0 = spirv.Constant 4.0 : bf16
+ %1 = spirv.Constant 5.0 : bf16
+ %2 = spirv.mlir.referenceof @condition_scalar : i1
+ // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, bf16
+ %3 = spirv.Select %2, %0, %1 : i1, bf16
+ %4 = spirv.Constant dense<[2.0, 3.0, 4.0, 5.0]> : vector<4xbf16>
+ %5 = spirv.Constant dense<[6.0, 7.0, 8.0, 9.0]> : vector<4xbf16>
+ // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, vector<4xbf16>
+ %6 = spirv.Select %2, %4, %5 : i1, vector<4xbf16>
+ %7 = spirv.Constant dense<[true, true, true, true]> : vector<4xi1>
+ // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : vector<4xi1>, vector<4xbf16>
+ %8 = spirv.Select %7, %4, %5 : vector<4xi1>, vector<4xbf16>
+ spirv.Return
+ }
+}
More information about the Mlir-commits
mailing list