[Mlir-commits] [mlir] 73eecc9 - [mlir] Convert 8-bit float types to i8
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Jun 26 10:42:04 PDT 2023
Author: Krzysztof Drewniak
Date: 2023-06-26T17:42:00Z
New Revision: 73eecc9ca4db5730a77e7f4144d93696c9a1c5a3
URL: https://github.com/llvm/llvm-project/commit/73eecc9ca4db5730a77e7f4144d93696c9a1c5a3
DIFF: https://github.com/llvm/llvm-project/commit/73eecc9ca4db5730a77e7f4144d93696c9a1c5a3.diff
LOG: [mlir] Convert 8-bit float types to i8
Because LLVM currently has no 8-bit floating-point types, and because
existing 8-bit float APIs (for instance, the AMDGCN intrinsics) take such
floats as (packed) bytes, translate the MLIR 8-bit float types to i8
during LLVM lowering.
To avoid special-casing arith.constant to bitcast such constants to
their integer form, amend the MLIR-to-LLVM translator to turn 8-bit
float constants into i8 constants with the same bit pattern (using
APFloat's bitcastToAPInt() method).
This change can be reverted once LLVM has 8-bit float types.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D153160
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/LLVMIR/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 1eb5661bb387e..79a68e875f045 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -184,7 +184,8 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a floating point type: `f16` to `f16`, `f32` to
/// `f32` and `f64` to `f64`. `bf16` is not supported
- /// by LLVM.
+ /// by LLVM. 8-bit float types are converted to 8-bit integers as this is how
+ /// all LLVM backends that support them currently represent them.
Type convertFloatType(FloatType type);
/// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b518665581bb7..e764c8d30b600 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -539,14 +539,6 @@ struct ConvertAMDGPUToROCDLPass
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
Chipset chipset) {
- // ROCDL supports fp8 types in some contexts, but there is no LLVM-level f8
- // type. Therefore, for this target, declare f8 to be equal to i8.
- converter.addConversion([](FloatType type) -> std::optional<Type> {
- if (type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ())
- return IntegerType::get(type.getContext(), 8);
- return std::nullopt;
- });
-
patterns.add<LDSBarrierOpLowering>(converter);
patterns.add<
RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index aac6e60b4f50d..21ef3076a5a47 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -193,7 +193,12 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
return IntegerType::get(&getContext(), type.getWidth());
}
-Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; }
+Type LLVMTypeConverter::convertFloatType(FloatType type) {
+ if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() ||
+ type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ())
+ return IntegerType::get(&getContext(), type.getWidth());
+ return type;
+}
// Convert a `ComplexType` to an LLVM type. The result is a complex number
// struct with entries for the
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 9f072191804e0..685b031dd509b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -360,6 +360,12 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvmType,
intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+ const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
+ // Special case for 8-bit floats, which are represented by integers due to
+ // the lack of native fp8 types in LLVM at the moment.
+ if (APFloat::getSizeInBits(sem) == 8 && llvmType->isIntegerTy(8))
+ return llvm::ConstantInt::get(llvmType,
+ floatAttr.getValue().bitcastToAPInt());
if (llvmType !=
llvm::Type::getFloatingPointTy(llvmType->getContext(),
floatAttr.getValue().getSemantics())) {
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index d23b6a7c289cc..6365eabb34165 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -55,6 +55,21 @@ llvm.mlir.global internal constant @string_const("foobar") : !llvm.array<6 x i8>
// CHECK: @int_global_undef = internal global i64 undef
llvm.mlir.global internal @int_global_undef() : i64
+// CHECK: @f8E4M3FN_global_as_i8 = internal global i8 60
+llvm.mlir.global internal @f8E4M3FN_global_as_i8(1.5 : f8E4M3FN) : i8
+
+// CHECK: @f8E5M2_global_as_i8 = internal global i8 62
+llvm.mlir.global internal @f8E5M2_global_as_i8(1.5 : f8E5M2) : i8
+
+// CHECK: @f8E4M3FNUZ_global_as_i8 = internal global i8 68
+llvm.mlir.global internal @f8E4M3FNUZ_global_as_i8(1.5 : f8E4M3FNUZ) : i8
+
+// CHECK: @f8E5M2FNUZ_global_as_i8 = internal global i8 66
+llvm.mlir.global internal @f8E5M2FNUZ_global_as_i8(1.5 : f8E5M2FNUZ) : i8
+
+// CHECK: @f8E4M3B11FNUZ_global_as_i8 = internal global i8 92
+llvm.mlir.global internal @f8E4M3B11FNUZ_global_as_i8(1.5 : f8E4M3B11FNUZ) : i8
+
// CHECK: @explicit_undef = global i32 undef
llvm.mlir.global external @explicit_undef() : i32 {
%0 = llvm.mlir.undef : i32
More information about the Mlir-commits
mailing list