[Mlir-commits] [mlir] 73eecc9 - [mlir] Convert 8-bit float types to i8

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Jun 26 10:42:04 PDT 2023


Author: Krzysztof Drewniak
Date: 2023-06-26T17:42:00Z
New Revision: 73eecc9ca4db5730a77e7f4144d93696c9a1c5a3

URL: https://github.com/llvm/llvm-project/commit/73eecc9ca4db5730a77e7f4144d93696c9a1c5a3
DIFF: https://github.com/llvm/llvm-project/commit/73eecc9ca4db5730a77e7f4144d93696c9a1c5a3.diff

LOG: [mlir] Convert 8-bit float types to i8

Whereas LLVM currently doesn't have any types for 8-bit floats, and
whereas existing 8-bit float APIs (for instance, the AMDGCN
intrinsics) take such floats as (packed) bytes, translate the MLIR
8-bit float types to i8 during LLVM lowering.

To avoid special-casing arith.constant to bitcast constants to their
integer form, amend the MLIR to LLVM translator to turn 8-bit float
constants into i8 constants with the same bit pattern (by use of
APFloat's bitcastToAPInt method), e.g. 1.5 : f8E4M3FN becomes i8 60.

This change can be reverted once LLVM has 8-bit float types.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D153160

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 1eb5661bb387e..79a68e875f045 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -184,7 +184,8 @@ class LLVMTypeConverter : public TypeConverter {
 
   /// Convert a floating point type: `f16` to `f16`, `f32` to
   /// `f32` and `f64` to `f64`.  `bf16` is not supported
-  /// by LLVM.
+  /// by LLVM. 8-bit float types are converted to 8-bit integers as this is how
+  /// all LLVM backends that support them currently represent them.
   Type convertFloatType(FloatType type);
 
   /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,

diff  --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b518665581bb7..e764c8d30b600 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -539,14 +539,6 @@ struct ConvertAMDGPUToROCDLPass
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
-  // ROCDL supports fp8 types in some contexts, but there is no LLVM-level f8
-  // type. Therefore, for this target, declare f8 to be equal to i8.
-  converter.addConversion([](FloatType type) -> std::optional<Type> {
-    if (type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ())
-      return IntegerType::get(type.getContext(), 8);
-    return std::nullopt;
-  });
-
   patterns.add<LDSBarrierOpLowering>(converter);
   patterns.add<
       RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index aac6e60b4f50d..21ef3076a5a47 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -193,7 +193,12 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
   return IntegerType::get(&getContext(), type.getWidth());
 }
 
-Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; }
+Type LLVMTypeConverter::convertFloatType(FloatType type) {
+  if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() ||
+      type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ())
+    return IntegerType::get(&getContext(), type.getWidth());
+  return type;
+}
 
 // Convert a `ComplexType` to an LLVM type. The result is a complex number
 // struct with entries for the

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 9f072191804e0..685b031dd509b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -360,6 +360,12 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
         llvmType,
         intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
   if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+    const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
+    // Special case for 8-bit floats, which are represented by integers due to
+    // the lack of native fp8 types in LLVM at the moment.
+    if (APFloat::getSizeInBits(sem) == 8 && llvmType->isIntegerTy(8))
+      return llvm::ConstantInt::get(llvmType,
+                                    floatAttr.getValue().bitcastToAPInt());
     if (llvmType !=
         llvm::Type::getFloatingPointTy(llvmType->getContext(),
                                        floatAttr.getValue().getSemantics())) {

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index d23b6a7c289cc..6365eabb34165 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -55,6 +55,21 @@ llvm.mlir.global internal constant @string_const("foobar") : !llvm.array<6 x i8>
 // CHECK: @int_global_undef = internal global i64 undef
 llvm.mlir.global internal @int_global_undef() : i64
 
+// CHECK: @f8E4M3FN_global_as_i8 = internal global i8 60
+llvm.mlir.global internal @f8E4M3FN_global_as_i8(1.5 : f8E4M3FN) : i8
+
+// CHECK: @f8E5M2_global_as_i8 = internal global i8 62
+llvm.mlir.global internal @f8E5M2_global_as_i8(1.5 : f8E5M2) : i8
+
+// CHECK: @f8E4M3FNUZ_global_as_i8 = internal global i8 68
+llvm.mlir.global internal @f8E4M3FNUZ_global_as_i8(1.5 : f8E4M3FNUZ) : i8
+
+// CHECK: @f8E5M2FNUZ_global_as_i8 = internal global i8 66
+llvm.mlir.global internal @f8E5M2FNUZ_global_as_i8(1.5 : f8E5M2FNUZ) : i8
+
+// CHECK: @f8E4M3B11FNUZ_global_as_i8 = internal global i8 92
+llvm.mlir.global internal @f8E4M3B11FNUZ_global_as_i8(1.5 : f8E4M3B11FNUZ) : i8
+
 // CHECK: @explicit_undef = global i32 undef
 llvm.mlir.global external @explicit_undef() : i32 {
   %0 = llvm.mlir.undef : i32


        


More information about the Mlir-commits mailing list