[Mlir-commits] [mlir] [mlir][spirv] Add SPV_EXT_float8 support (PR #179246)
Davide Grohmann
llvmlistbot at llvm.org
Tue Feb 3 04:27:10 PST 2026
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/179246
>From a8e8c0644c5a8a672dc7992a581b682d1a8420c4 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 30 Jan 2026 14:02:09 +0100
Subject: [PATCH] [mlir][spirv] Add SPV_EXT_float8 support
Reference: https://github.khronos.org/SPIRV-Registry/extensions/EXT/SPV_EXT_float8.html
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I39c400f4a828c76c3b7c937d1e2795def8e465ab
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 32 +++++++++--
mlir/include/mlir/IR/Builders.h | 2 +
mlir/include/mlir/IR/Types.h | 3 ++
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 14 +++++
mlir/lib/IR/Builders.cpp | 4 ++
mlir/lib/IR/Types.cpp | 2 +
.../SPIRV/Deserialization/Deserializer.cpp | 54 ++++++++++++-------
.../Target/SPIRV/Serialization/Serializer.cpp | 15 +++++-
.../test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/availability.mlir | 1 +
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 6 +--
mlir/test/Dialect/SPIRV/IR/types.mlir | 12 +++++
mlir/test/Target/SPIRV/constant.mlir | 40 +++++++++++++-
14 files changed, 156 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8b9f51d38374c..2f189c64300ae 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -360,6 +360,7 @@ def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image
def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
def SPV_EXT_replicated_composites : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
+def SPV_EXT_float8 : I32EnumAttrCase<"SPV_EXT_float8", 1014>;
def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -449,7 +450,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
- SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
+ SPV_EXT_mesh_shader, SPV_EXT_replicated_composites, SPV_EXT_float8,
SPV_ARM_tensors, SPV_ARM_graph,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -1486,6 +1487,12 @@ def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
];
}
+def SPIRV_C_Float8EXT : I32EnumAttrCase<"Float8EXT", 4212> {
+ list<Availability> availability = [
+ Extension<[SPV_EXT_float8]>
+ ];
+}
+
def SPIRV_CapabilityAttr :
SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
@@ -1583,7 +1590,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
- SPIRV_C_TensorFloat32RoundingINTEL
+ SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_Float8EXT
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -3287,9 +3294,24 @@ def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
Capability<[SPIRV_C_BFloat16TypeKHR]>
];
}
+
+def SPIRV_FPE_Float8E4M3EXT : I32EnumAttrCase<"Float8E4M3EXT", 4214> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_Float8EXT]>
+ ];
+}
+
+def SPIRV_FPE_Float8E5M2EXT : I32EnumAttrCase<"Float8E5M2EXT", 4215> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_Float8EXT]>
+ ];
+}
+
def SPIRV_FPEncodingAttr :
SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
- SPIRV_FPE_BFloat16KHR
+ SPIRV_FPE_BFloat16KHR,
+ SPIRV_FPE_Float8E4M3EXT,
+ SPIRV_FPE_Float8E5M2EXT,
]>;
def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
@@ -4248,9 +4270,11 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float16 : TypeAlias<F16, "Float16">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
+def SPIRV_Float8E4M3EXT : TypeAlias<F8E4M3FN, "Float8E4M3">;
+def SPIRV_Float8E5M2EXT : TypeAlias<F8E5M2, "Float8E5M2">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
+def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Float8E4M3EXT, SPIRV_Float8E5M2EXT]>;
def SPIRV_Vector : VectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16],
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
// Component type check is done in the type parser for the following SPIR-V
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3ba6818204ba0..a6cb1456544b9 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -62,6 +62,8 @@ class Builder {
// Types.
FloatType getF8E8M0Type();
+ FloatType getF8E4M3FNType();
+ FloatType getF8E5M2Type();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 4ffdbfa5b1224..97583a93f6157 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -116,6 +116,9 @@ class Type {
bool isF64() const;
bool isF80() const;
bool isF128() const;
+ bool isF8E4M3FN() const;
+ bool isF8E5M2() const;
+
/// Return true if this is an float type (with the specified width).
bool isFloat() const;
bool isFloat(unsigned width) const;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 342a47cdefbf0..fa54824dceac4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -551,6 +551,11 @@ void TypeExtensionVisitor::addConcrete(ScalarType type) {
extensions.push_back(ext);
}
+ if (isa<Float8E4M3FNType>(type) || isa<Float8E5M2Type>(type)) {
+ static constexpr auto ext = Extension::SPV_EXT_float8;
+ extensions.push_back(ext);
+ }
+
// 8- or 16-bit integer/floating-point numbers will require extra extensions
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
// SPV_KHR_8bit_storage for more details.
@@ -648,6 +653,15 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
} else {
assert(isa<FloatType>(type));
switch (bitwidth) {
+ case 8: {
+ if (isa<Float8E4M3FNType>(type) || isa<Float8E5M2Type>(type)) {
+ static constexpr auto cap = Capability::Float8EXT;
+ capabilities.push_back(cap);
+ } else {
+ llvm_unreachable("invalid 8-bit float type to getCapabilities");
+ }
+ break;
+ }
case 16: {
if (isa<BFloat16Type>(type)) {
static constexpr auto cap = Capability::BFloat16TypeKHR;
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8f199b60fccdc..cf64954751d5e 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
+FloatType Builder::getF8E4M3FNType() { return Float8E4M3FNType::get(context); }
+
+FloatType Builder::getF8E5M2Type() { return Float8E5M2Type::get(context); }
+
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
FloatType Builder::getF16Type() { return Float16Type::get(context); }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 765b787d3d17a..ec10a5ce9e2e7 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -41,6 +41,8 @@ bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
+bool Type::isF8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
+bool Type::isF8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 8d1f9c26fe596..3ceaa9189898d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1094,30 +1094,38 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
uint32_t bitWidth = operands[1];
Type floatTy;
- switch (bitWidth) {
- case 16:
- floatTy = opBuilder.getF16Type();
- break;
- case 32:
- floatTy = opBuilder.getF32Type();
- break;
- case 64:
- floatTy = opBuilder.getF64Type();
- break;
- default:
- return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
- << bitWidth;
+ if (operands.size() == 2) {
+ switch (bitWidth) {
+ case 16:
+ floatTy = opBuilder.getF16Type();
+ break;
+ case 32:
+ floatTy = opBuilder.getF32Type();
+ break;
+ case 64:
+ floatTy = opBuilder.getF64Type();
+ break;
+ default:
+ return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
+ << bitWidth;
+ }
}
if (operands.size() == 3) {
- if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
+ if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
+ bitWidth == 16)
+ floatTy = opBuilder.getBF16Type();
+ else if (spirv::FPEncoding(operands[2]) ==
+ spirv::FPEncoding::Float8E4M3EXT &&
+ bitWidth == 8)
+ floatTy = opBuilder.getF8E4M3FNType();
+ else if (spirv::FPEncoding(operands[2]) ==
+ spirv::FPEncoding::Float8E5M2EXT &&
+ bitWidth == 8)
+ floatTy = opBuilder.getF8E5M2Type();
+ else
return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
- << operands[2];
- if (bitWidth != 16)
- return emitError(unknownLoc,
- "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
- << bitWidth << " (expected 16)";
- floatTy = opBuilder.getBF16Type();
+ << operands[2] << " and bitWidth " << bitWidth;
}
typeMap[operands[0]] = floatTy;
@@ -1734,6 +1742,12 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
} else if (floatType.isBF16()) {
APInt data(16, operands[2]);
value = APFloat(APFloat::BFloat(), data);
+ } else if (floatType.isF8E4M3FN()) {
+ APInt data(8, operands[2]);
+ value = APFloat(APFloat::Float8E4M3FN(), data);
+ } else if (floatType.isF8E5M2()) {
+ APInt data(8, operands[2]);
+ value = APFloat(APFloat::Float8E5M2(), data);
}
auto attr = opBuilder.getFloatAttr(floatType, value);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 840c9c990f9c6..d9ffc4131d3c3 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -599,6 +599,15 @@ LogicalResult Serializer::prepareBasicType(
if (floatType.isBF16()) {
operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
}
+ if (floatType.isF8E4M3FN()) {
+ operands.push_back(
+ static_cast<uint32_t>(spirv::FPEncoding::Float8E4M3EXT));
+ }
+ if (floatType.isF8E5M2()) {
+ operands.push_back(
+ static_cast<uint32_t>(spirv::FPEncoding::Float8E5M2EXT));
+ }
+
return success();
}
@@ -1253,8 +1262,10 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
- } else if (semantics == &APFloat::IEEEhalf() ||
- semantics == &APFloat::BFloat()) {
+ } else if (llvm::is_contained({&APFloat::IEEEhalf(), &APFloat::BFloat(),
+ &APFloat::Float8E4M3FN(),
+ &APFloat::Float8E5M2()},
+ semantics)) {
uint32_t word =
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index c703274bda579..7f1b84123151a 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -348,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
- // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 4ef242bdc5b16..d25c858228844 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -307,6 +307,7 @@ func.func @constant_composite_replicate() -> () {
spirv.Return
}
+
//===----------------------------------------------------------------------===//
// GraphARM ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 9323518f50373..6e4126172f670 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
// -----
func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
- // expected-error @+1 {{op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
+ // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
return %0: vector<4x2xi1>
}
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 2c5dc8b9f3b0f..7e37826795d83 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
//===----------------------------------------------------------------------===//
func.func @ccr_result_not_composite() -> () {
- // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
+ // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
%0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
return
}
@@ -360,7 +360,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// expected-error @+1 {{'spirv.module' cannot contain external functions without 'Import' linkage_attributes (LinkageAttributes)}}
spirv.func @outside.func.without.linkage(%arg0 : i8) -> () "Pure"
spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return}
@@ -477,7 +477,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: linkage_attributes = #spirv.linkage_attributes<linkage_name = "outSideGlobalVar1", linkage_type = <Import>>
spirv.GlobalVariable @var1 {
linkage_attributes=#spirv.linkage_attributes<
- linkage_name="outSideGlobalVar1",
+ linkage_name="outSideGlobalVar1",
linkage_type=<Import>
>
} : !spirv.ptr<f32, Private>
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index f350e56255983..145a291343504 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -631,3 +631,15 @@ func.func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>) -> ()
// expected-error @+1 {{arm.tensors do not support zero dimensions}}
func.func private @arm_tensor_type_zero_dim(!spirv.arm.tensor<0xi32>) -> ()
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Float8_EXT
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @type_f8E4M3FN(f8E4M3FN)
+func.func private @type_f8E4M3FN(f8E4M3FN) -> ()
+
+// CHECK: func private @type_f8E5M2(f8E5M2)
+func.func private @type_f8E5M2(f8E5M2) -> ()
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 4838d3c510757..cc7c93824c6e3 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -5,8 +5,8 @@
// we cannot use splits.
spirv.module Logical Vulkan requires #spirv.vce<v1.3,
- [VulkanMemoryModel, Shader, Int64, Int16, Int8, Float64, Float16, CooperativeMatrixKHR, TensorsARM, Linkage],
- [SPV_KHR_vulkan_memory_model, SPV_KHR_cooperative_matrix, SPV_ARM_tensors]> {
+ [VulkanMemoryModel, Shader, Int64, Int16, Int8, Float64, Float16, BFloat16TypeKHR, Float8EXT, CooperativeMatrixKHR, Linkage],
+ [SPV_KHR_vulkan_memory_model, SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16, SPV_EXT_float8]> {
// CHECK-LABEL: @bool_const
spirv.func @bool_const() -> () "None" {
// CHECK: spirv.Constant true
@@ -161,6 +161,42 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3,
spirv.Return
}
+ // CHECK-LABEL: @bf16
+ spirv.func @bf16() -> () "None" {
+ // CHECK: spirv.Constant 5.120000e+02 : bf16
+ %0 = spirv.Constant 512. : bf16
+ // CHECK: spirv.Constant -5.120000e+02 : bf16
+ %1 = spirv.Constant -512. : bf16
+
+ %2 = spirv.FConvert %0 : bf16 to f32
+ %3 = spirv.FConvert %1 : bf16 to f32
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @f8E4M3FN
+ spirv.func @f8E4M3FN() -> () "None" {
+ // CHECK: spirv.Constant 1.280000e+02 : f8E4M3FN
+ %0 = spirv.Constant 127. : f8E4M3FN
+ // CHECK: spirv.Constant -1.280000e+02 : f8E4M3FN
+ %1 = spirv.Constant -127. : f8E4M3FN
+
+ %2 = spirv.FConvert %0 : f8E4M3FN to f32
+ %3 = spirv.FConvert %1 : f8E4M3FN to f32
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @f8E5M2
+ spirv.func @f8E5M2() -> () "None" {
+ // CHECK: spirv.Constant 1.280000e+02 : f8E5M2
+ %0 = spirv.Constant 127. : f8E5M2
+ // CHECK: spirv.Constant -1.280000e+02 : f8E5M2
+ %1 = spirv.Constant -127. : f8E5M2
+
+ %2 = spirv.FConvert %0 : f8E5M2 to f32
+ %3 = spirv.FConvert %1 : f8E5M2 to f32
+ spirv.Return
+ }
+
// CHECK-LABEL: @bool_vector_const
spirv.func @bool_vector_const() -> () "None" {
// CHECK: spirv.Constant dense<false> : vector<2xi1>
More information about the Mlir-commits mailing list