[Mlir-commits] [mlir] [mlir][spirv] Add 8-bit float type emulation (PR #148811)
Md Abdullah Shahneous Bari
llvmlistbot at llvm.org
Mon Jul 28 13:53:47 PDT 2025
https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/148811
>From ff27ae8a432ba2afaf70d1974f0c149227c4d562 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 08:38:45 +0000
Subject: [PATCH 1/8] Add 8-bit float emulation for SPIR-V conversion.
SPIR-V does not support any 8-bit float types.
Therefore, 8-bit floats are emulated as 8-bit integers that carry the same bit pattern.
---
mlir/include/mlir/Conversion/Passes.td | 18 +++-
.../SPIRV/Transforms/SPIRVConversion.h | 4 +
.../ControlFlowToSPIRVPass.cpp | 2 +
.../FuncToSPIRV/FuncToSPIRVPass.cpp | 2 +
.../TensorToSPIRV/TensorToSPIRVPass.cpp | 2 +
.../SPIRV/Transforms/SPIRVConversion.cpp | 97 ++++++++++++++++++-
6 files changed, 120 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index eb18160ea2eeb..8616ba0c6df8e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -416,7 +419,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -500,7 +506,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
@@ -1163,7 +1172,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by emulating them with integer types of same bit width">
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 3d22ec918f4c5..03ae54a8ae30a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
/// The number of bits to store a boolean value.
unsigned boolNumBits{8};
+ /// Whether to emulate unsupported floats with integer types of same bit
+ /// width.
+ bool emulateUnsupportedFloatTypes{true};
+
/// How sub-byte values are storaged in memory.
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4df4912..01657cced2281 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,8 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes =
+ this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f659afb10..ca67079ce9bb1 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,8 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes =
+ this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386ea80124..309ed8b054628 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,8 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes =
+ this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 35ec0190b5a61..37dd75b586002 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+
if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
@@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
+ // Handle 8-bit floats.
+ if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+ auto bitWidth = type.getIntOrFloatBitWidth();
+ if (bitWidth == 8)
+ return bitWidth / 8;
+ else
+ return std::nullopt;
+ }
+
if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
@@ -318,6 +328,67 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness());
}
+/// Converts 8-bit float types to integer types with the same bit width.
+/// Returns a nullptr for unsupported 8-bit float types.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+ FloatType type) {
+ if (!options.emulateUnsupportedFloatTypes)
+ return nullptr;
+ // F8 types are converted to integer types with the same bit width.
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float8E8M0FNUType>(type))
+ return IntegerType::get(type.getContext(), type.getWidth());
+ LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
+ return nullptr;
+}
+
+/// Converts a sub-byte float ``type` to i32 regardless of target environment.
+/// Returns a nullptr for unsupported float types, including non sub-byte
+/// types.
+///
+/// We are treating 8 bit floats as sub-byte types here due to it's similar
+/// nature of being used as a packed format.
+
+/// Note that we don't recognize
+/// sub-byte types in `spirv::ScalarType` and use the above given that these
+/// sub-byte types are not supported at all in SPIR-V; there are no
+/// compute/storage capability for them like other supported integer types.
+
+// static Type convertPackedFLoatType(const SPIRVConversionOptions &options,
+// FloatType type) {
+
+// // F4, F6, F8 types are converted to integer types with the same bit width.
+
+// if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+// Float8E5M2FNUZType,
+// Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+// Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
+// Float8E8M0FNUType>(type))
+// auto emulatedType = IntegerType::get(type.getContext(), type.getWidth());
+
+// if (type.getWidth() > 8) {
+// LLVM_DEBUG(llvm::dbgs() << "not a packed type\n");
+// return nullptr;
+// }
+// if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
+// LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
+// return nullptr;
+// }
+
+// // if (!llvm::isPowerOf2_32(type.getWidth())) {
+// // LLVM_DEBUG(llvm::dbgs()
+// // << "unsupported non-power-of-two bitwidth in sub-byte" <<
+// type
+// // << "\n");
+// // return nullptr;
+// // }
+
+// LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
+// return IntegerType::get(type.getContext(), /*width=*/32,
+// type.getSignedness());
+// }
+
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
@@ -339,8 +410,20 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
type = cast<VectorType>(convertIndexElementType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
- // If this is not a spec allowed scalar type, try to handle sub-byte integer
- // types.
+ // If this is not a spec allowed scalar type, there are 2 scenarios,
+ // 8 bit floats or sub-byte integer types. try to handle them accrodingly.
+
+ // Hnadle 8 bit float types.
+ auto floatType = dyn_cast<FloatType>(type.getElementType());
+ if (floatType && floatType.getWidth() == 8) {
+ // If this is an 8 bit float type, try to convert it to a supported
+ // integer type.
+ if (auto convertedType = convert8BitFloatType(options, floatType)) {
+ return VectorType::get(type.getShape(), convertedType);
+ }
+ }
+
+ // Handle sub-byte integer types.
auto intType = dyn_cast<IntegerType>(type.getElementType());
if (!intType) {
LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +679,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+ // Hnadle 8 bit float types.
+ if (options.emulateUnsupportedFloatTypes && floatType &&
+ floatType.getWidth() == 8) {
+ // If this is an 8 bit float type, try to convert it to a supported
+ // integer type.
+ arrayElemType = convert8BitFloatType(options, floatType);
+ }
} else {
LLVM_DEBUG(
llvm::dbgs()
@@ -1444,6 +1535,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
+ if (floatType.getWidth() == 8)
+ return convert8BitFloatType(this->options, floatType);
return Type();
});
>From 5264873f30bf3958980bb0e41223a6ef7a1abf1a Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 08:40:31 +0000
Subject: [PATCH 2/8] Add arith.constant support.
Handles scalar and composite (vector and tensor) constants.
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 30 ++++++++++++++++---
1 file changed, 26 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d43e6816641cb..f066671efd754 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+// Get IntegerAttr from FloatAttr.
+IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -296,8 +304,16 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ if (srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +377,17 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ if (srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+ // If the source is an 8-bit float, convert it to a 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
>From af53af1a18523bc4fc3ad0bd4ac69c4eb08814d5 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 09:26:06 +0000
Subject: [PATCH 3/8] Handle all Shaped Type 8-bit floats in a similar way.
This approach keeps the code changes minimal.
---
.../SPIRV/Transforms/SPIRVConversion.cpp | 55 ++++++++++++-------
1 file changed, 35 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 37dd75b586002..1ddefb53aa94d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -343,6 +343,29 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
return nullptr;
}
+/// Returns a type with the same shape but with any 8-bit float element type
+/// converted to the same bit width integer type. This is a noop when the
+/// element type is not the 8-bit float type.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+ const SPIRVConversionOptions &options) {
+ if (!options.emulateUnsupportedFloatTypes)
+ return nullptr;
+ auto srcElementType = type.getElementType();
+ Type convertedElementType = nullptr;
+ // F8 types are converted to integer types with the same bit width.
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float8E8M0FNUType>(srcElementType))
+ convertedElementType = IntegerType::get(
+ type.getContext(), srcElementType.getIntOrFloatBitWidth());
+
+ if (!convertedElementType)
+ return type;
+
+ return type.clone(convertedElementType);
+}
+
/// Converts a sub-byte float ``type` to i32 regardless of target environment.
/// Returns a nullptr for unsupported float types, including non sub-byte
/// types.
@@ -408,22 +431,11 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
+ type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
- // If this is not a spec allowed scalar type, there are 2 scenarios,
- // 8 bit floats or sub-byte integer types. try to handle them accrodingly.
-
- // Hnadle 8 bit float types.
- auto floatType = dyn_cast<FloatType>(type.getElementType());
- if (floatType && floatType.getWidth() == 8) {
- // If this is an 8 bit float type, try to convert it to a supported
- // integer type.
- if (auto convertedType = convert8BitFloatType(options, floatType)) {
- return VectorType::get(type.getShape(), convertedType);
- }
- }
-
- // Handle sub-byte integer types.
+ // If this is not a spec allowed scalar type, try to handle sub-byte integer
+ // types.
auto intType = dyn_cast<IntegerType>(type.getElementType());
if (!intType) {
LLVM_DEBUG(llvm::dbgs()
@@ -516,6 +528,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}
type = cast<TensorType>(convertIndexElementType(type, options));
+ type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
@@ -681,12 +694,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
arrayElemType = type.getElementType();
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
// Hnadle 8 bit float types.
- if (options.emulateUnsupportedFloatTypes && floatType &&
- floatType.getWidth() == 8) {
- // If this is an 8 bit float type, try to convert it to a supported
- // integer type.
- arrayElemType = convert8BitFloatType(options, floatType);
- }
+ type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+ arrayElemType = type.getElementType();
+ // if (options.emulateUnsupportedFloatTypes && floatType &&
+ // floatType.getWidth() == 8) {
+ // // If this is an 8 bit float type, try to convert it to a supported
+ // // integer type.
+ // arrayElemType = convert8BitFloatType(options, floatType);
+ // }
} else {
LLVM_DEBUG(
llvm::dbgs()
>From 04a2da6006f1c3f7efca5ecff8fc64b16fb24986 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 09:29:29 +0000
Subject: [PATCH 4/8] Remove commented out code.
---
.../SPIRV/Transforms/SPIRVConversion.cpp | 46 -------------------
1 file changed, 46 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 1ddefb53aa94d..e00ebfd272bf7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -366,52 +366,6 @@ convertShaped8BitFloatType(ShapedType type,
return type.clone(convertedElementType);
}
-/// Converts a sub-byte float ``type` to i32 regardless of target environment.
-/// Returns a nullptr for unsupported float types, including non sub-byte
-/// types.
-///
-/// We are treating 8 bit floats as sub-byte types here due to it's similar
-/// nature of being used as a packed format.
-
-/// Note that we don't recognize
-/// sub-byte types in `spirv::ScalarType` and use the above given that these
-/// sub-byte types are not supported at all in SPIR-V; there are no
-/// compute/storage capability for them like other supported integer types.
-
-// static Type convertPackedFLoatType(const SPIRVConversionOptions &options,
-// FloatType type) {
-
-// // F4, F6, F8 types are converted to integer types with the same bit width.
-
-// if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
-// Float8E5M2FNUZType,
-// Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
-// Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
-// Float8E8M0FNUType>(type))
-// auto emulatedType = IntegerType::get(type.getContext(), type.getWidth());
-
-// if (type.getWidth() > 8) {
-// LLVM_DEBUG(llvm::dbgs() << "not a packed type\n");
-// return nullptr;
-// }
-// if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
-// LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
-// return nullptr;
-// }
-
-// // if (!llvm::isPowerOf2_32(type.getWidth())) {
-// // LLVM_DEBUG(llvm::dbgs()
-// // << "unsupported non-power-of-two bitwidth in sub-byte" <<
-// type
-// // << "\n");
-// // return nullptr;
-// // }
-
-// LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
-// return IntegerType::get(type.getContext(), /*width=*/32,
-// type.getSignedness());
-// }
-
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
>From f9f7517b1e10017828b3156ac1b13453024c3f76 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 09:37:11 +0000
Subject: [PATCH 5/8] Remove unnecessary commented out code.
---
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 6 ------
1 file changed, 6 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index e00ebfd272bf7..3580f7a61ae7e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -650,12 +650,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
// Hnadle 8 bit float types.
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
arrayElemType = type.getElementType();
- // if (options.emulateUnsupportedFloatTypes && floatType &&
- // floatType.getWidth() == 8) {
- // // If this is an 8 bit float type, try to convert it to a supported
- // // integer type.
- // arrayElemType = convert8BitFloatType(options, floatType);
- // }
} else {
LLVM_DEBUG(
llvm::dbgs()
>From 03dd0c53225b4af4d70ce3071b86ecc549997dfc Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 10:05:25 +0000
Subject: [PATCH 6/8] Fix clang-format issue.
---
.../Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp | 3 +--
mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp | 3 +--
mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp | 3 +--
3 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 01657cced2281..56b6181018153 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,8 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
- options.emulateUnsupportedFloatTypes =
- this->emulateUnsupportedFloatTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index ca67079ce9bb1..c0439a4033eac 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,8 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
- options.emulateUnsupportedFloatTypes =
- this->emulateUnsupportedFloatTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index 309ed8b054628..8cd650e649008 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,8 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
- options.emulateUnsupportedFloatTypes =
- this->emulateUnsupportedFloatTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
>From 8e63d485d34cc75134361545fbbc965100a31442 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 20:00:17 +0000
Subject: [PATCH 7/8] Add test case & make arith-to-spirv use emulation flag.
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 +++-
.../SPIRV/Transforms/SPIRVConversion.cpp | 4 +-
.../ArithToSPIRV/arith-to-spirv.mlir | 11 ++++
.../FuncToSPIRV/types-to-spirv.mlir | 54 +++++++++++++++++++
4 files changed, 74 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index f066671efd754..a9257ceba8f58 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -306,7 +306,9 @@ struct ConstantCompositeOpPattern final
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
Attribute dstAttr = nullptr;
// Handle 8-bit float conversion to 8-bit integer.
- if (srcElemType.getIntOrFloatBitWidth() == 8 &&
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
isa<IntegerType>(dstElemType)) {
dstAttr =
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
@@ -381,7 +383,9 @@ struct ConstantScalarOpPattern final
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
dstType.getIntOrFloatBitWidth() == 8) {
// If the source is an 8-bit float, convert it to a 8-bit integer.
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
@@ -1374,6 +1378,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 3580f7a61ae7e..4a0ec19b86690 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -345,12 +345,12 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
/// Returns a type with the same shape but with any 8-bit float element type
/// converted to the same bit width integer type. This is a noop when the
-/// element type is not the 8-bit float type.
+/// element type is not the 8-bit float type or emulation flag is set to false.
static ShapedType
convertShaped8BitFloatType(ShapedType type,
const SPIRVConversionOptions &options) {
if (!options.emulateUnsupportedFloatTypes)
- return nullptr;
+ return type;
auto srcElementType = type.getElementType();
Type convertedElementType = nullptr;
// F8 types are converted to integer types with the same bit width.
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 1abe0fd2ec468..751e727534efe 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -559,6 +559,17 @@ func.func @constant() {
return
}
+// CHECK-LABEL: @constant_8bit_float
+func.func @constant_8bit_float() {
+ // CHECK: spirv.Constant 56 : i8
+ %cst = arith.constant 1.0 : f8E4M3
+ // CHECK: spirv.Constant dense<56> : vector<4xi8>
+ %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
+ // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+ %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
+ return
+}
+
// CHECK-LABEL: @constant_16bit
func.func @constant_16bit() {
// CHECK: spirv.Constant 4 : i16
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 1737f4a906bf8..0c77c88334572 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
// RUN: FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
+// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
//===----------------------------------------------------------------------===//
// Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
} // end module
+
+
+// -----
+
+// Check that 8-bit float types are emulated as i8.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
+} {
+
+ // CHECK: spirv.func @float8_to_integer8
+ // CHECK-SAME: (%arg0: i8
+ // CHECK-SAME: %arg1: i8
+ // CHECK-SAME: %arg2: i8
+ // CHECK-SAME: %arg3: i8
+ // CHECK-SAME: %arg4: i8
+ // CHECK-SAME: %arg5: i8
+ // CHECK-SAME: %arg6: i8
+ // CHECK-SAME: %arg7: i8
+ // CHECK-SAME: %arg8: vector<4xi8>
+ // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
+ // CHECK-SAME: %arg10: !spirv.array<4 x i8>
+ // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
+ // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
+ // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
+ // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
+ // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
+ // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
+ // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
+ // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
+ // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
+ // UNSUPPORTED_FLOAT-SAME: ) {
+
+ func.func @float8_to_integer8(
+ %arg0: f8E5M2, // CHECK-NOT: f8E5M2
+ %arg1: f8E4M3, // CHECK-NOT: f8E4M3
+ %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
+ %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
+ %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
+ %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
+ %arg6: f8E3M4, // CHECK-NOT: f8E3M4
+ %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
+ %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
+ %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
+ %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
+ ) {
+ // CHECK: spirv.Return
+ return
+ }
+}
>From 3cb5796e3b3da890ae7f60e6cd2dc205a167a743 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Mon, 21 Jul 2025 19:35:37 +0000
Subject: [PATCH 8/8] Address review comments.
---
mlir/include/mlir/Conversion/Passes.td | 20 +++++++++++--------
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 ++++++---
.../SPIRV/Transforms/SPIRVConversion.cpp | 8 +++-----
.../ArithToSPIRV/arith-to-spirv.mlir | 6 ++++++
4 files changed, 27 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8616ba0c6df8e..2d1b855b15f10 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -197,8 +197,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
"Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
- "bool", /*default=*/"true",
- "Emulate unsupported float types by emulating them with integer types of same bit width">
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by representing them with integer "
+ "types of same bit width">
];
}
@@ -421,8 +422,9 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
- "bool", /*default=*/"true",
- "Emulate unsupported float types by emulating them with integer types of same bit width">
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by representing them with integer "
+ "types of same bit width">
];
}
@@ -508,8 +510,9 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
- "bool", /*default=*/"true",
- "Emulate unsupported float types by emulating them with integer types of same bit width">
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by representing them with integer "
+ "types of same bit width">
];
}
@@ -1174,8 +1177,9 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
- "bool", /*default=*/"true",
- "Emulate unsupported float types by emulating them with integer types of same bit width">
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by representing them with integer "
+ "types of same bit width">
];
}
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a9257ceba8f58..265293b83f84c 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,9 +99,12 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
-// Get IntegerAttr from FloatAttr.
-IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
- ConversionPatternRewriter &rewriter) {
+// Get an IntegerAttr from a FloatAttr by bitcasting the float value, so the
+// underlying bit pattern is preserved. Useful for converting float constants
+// (e.g. emulated 8-bit floats) to integer constants of the same bit width.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
APFloat floatVal = floatAttr.getValue();
APInt intVal = floatVal.bitcastToAPInt();
return rewriter.getIntegerAttr(dstType, intVal);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4a0ec19b86690..8f4c4cc027798 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -169,7 +169,6 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
-
if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
@@ -188,8 +187,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
auto bitWidth = type.getIntOrFloatBitWidth();
if (bitWidth == 8)
return bitWidth / 8;
- else
- return std::nullopt;
+ return std::nullopt;
}
if (auto complexType = dyn_cast<ComplexType>(type)) {
@@ -339,7 +337,7 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
Float8E8M0FNUType>(type))
return IntegerType::get(type.getContext(), type.getWidth());
- LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
+ LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
return nullptr;
}
@@ -351,7 +349,7 @@ convertShaped8BitFloatType(ShapedType type,
const SPIRVConversionOptions &options) {
if (!options.emulateUnsupportedFloatTypes)
return type;
- auto srcElementType = type.getElementType();
+ Type srcElementType = type.getElementType();
Type convertedElementType = nullptr;
// F8 types are converted to integer types with the same bit width.
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 751e727534efe..6e2352e706acc 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -563,10 +563,16 @@ func.func @constant() {
func.func @constant_8bit_float() {
// CHECK: spirv.Constant 56 : i8
%cst = arith.constant 1.0 : f8E4M3
+ // CHECK: spirv.Constant 56 : i8
+ %cst_i8 = arith.bitcast %cst : f8E4M3 to i8
// CHECK: spirv.Constant dense<56> : vector<4xi8>
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
+ // CHECK: spirv.Constant dense<56> : vector<4xi8>
+ %cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
+ // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+ %cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
return
}
More information about the Mlir-commits
mailing list