[Mlir-commits] [mlir] [SPIRV] 8-bit float type emulation (PR #148811)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Tue Jul 15 03:02:36 PDT 2025


https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/148811

>From 7b01513fa1a3d7ebe0b61d2de3ca9b4644ebd1a0 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 08:38:45 +0000
Subject: [PATCH 1/5] Add 8-bit float emulation for SPIR-V conversion.

SPIR-V does not support any 8-bit floats.
Therefore, 8-bit floats are emulated as 8-bit integers.
---
 mlir/include/mlir/Conversion/Passes.td        | 18 +++-
 .../SPIRV/Transforms/SPIRVConversion.h        |  4 +
 .../ControlFlowToSPIRVPass.cpp                |  2 +
 .../FuncToSPIRV/FuncToSPIRVPass.cpp           |  2 +
 .../TensorToSPIRV/TensorToSPIRVPass.cpp       |  2 +
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 97 ++++++++++++++++++-
 6 files changed, 120 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 50c67da91a4af..0eb9720351027 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by "
            "the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -404,7 +407,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -488,7 +494,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -1151,7 +1160,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 3d22ec918f4c5..03ae54a8ae30a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
   /// The number of bits to store a boolean value.
   unsigned boolNumBits{8};
 
+  /// Whether to emulate unsupported floats with integer types of same bit
+  /// width.
+  bool emulateUnsupportedFloatTypes{true};
+
   /// How sub-byte values are storaged in memory.
   SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
 
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4df4912..01657cced2281 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,8 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
 
   SPIRVConversionOptions options;
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+  options.emulateUnsupportedFloatTypes =
+      this->emulateUnsupportedFloatTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
   // TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f659afb10..ca67079ce9bb1 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,8 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
 
   SPIRVConversionOptions options;
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+  options.emulateUnsupportedFloatTypes =
+      this->emulateUnsupportedFloatTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386ea80124..309ed8b054628 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,8 @@ class ConvertTensorToSPIRVPass
 
     SPIRVConversionOptions options;
     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+    options.emulateUnsupportedFloatTypes =
+      this->emulateUnsupportedFloatTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
     RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..c8c97e7d79188 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
 // SPIR-V dialect. Keeping it local till the use case arises.
 static std::optional<int64_t>
 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+
   if (isa<spirv::ScalarType>(type)) {
     auto bitWidth = type.getIntOrFloatBitWidth();
     // According to the SPIR-V spec:
@@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
     return bitWidth / 8;
   }
 
+  // Handle 8-bit floats.
+  if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+    auto bitWidth = type.getIntOrFloatBitWidth();
+    if (bitWidth == 8)
+      return bitWidth / 8;
+    else
+      return std::nullopt;
+  }
+
   if (auto complexType = dyn_cast<ComplexType>(type)) {
     auto elementSize = getTypeNumBytes(options, complexType.getElementType());
     if (!elementSize)
@@ -318,6 +328,67 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
                           type.getSignedness());
 }
 
+/// Converts 8-bit float types to integer types with the same bit width.
+/// Returns nullptr if emulation is disabled or the 8-bit float is not handled.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+                                 FloatType type) {
+  if (!options.emulateUnsupportedFloatTypes)
+    return nullptr;
+  // Each listed F8 format is emulated by an integer type of the same width.
+  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float8E8M0FNUType>(type))
+    return IntegerType::get(type.getContext(), type.getWidth());
+  LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
+  return nullptr;
+}
+
+/// Converts a sub-byte float ``type` to i32 regardless of target environment.
+/// Returns a nullptr for unsupported float types, including non sub-byte
+/// types.
+///
+/// We are treating 8 bit floats as sub-byte types here due to it's similar
+/// nature of being used as a packed format.
+
+/// Note that we don't recognize
+/// sub-byte types in `spirv::ScalarType` and use the above given that these
+/// sub-byte types are not supported at all in SPIR-V; there are no
+/// compute/storage capability for them like other supported integer types.
+
+// static Type convertPackedFLoatType(const SPIRVConversionOptions &options,
+//                                    FloatType type) {
+
+//   // F4, F6, F8 types are converted to integer types with the same bit width.
+
+//   if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+//   Float8E5M2FNUZType,
+//           Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+//           Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
+//           Float8E8M0FNUType>(type))
+//     auto emulatedType = IntegerType::get(type.getContext(), type.getWidth());
+
+//   if (type.getWidth() > 8) {
+//     LLVM_DEBUG(llvm::dbgs() << "not a packed type\n");
+//     return nullptr;
+//   }
+//   if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
+//     LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
+//     return nullptr;
+//   }
+
+//   // if (!llvm::isPowerOf2_32(type.getWidth())) {
+//   //   LLVM_DEBUG(llvm::dbgs()
+//   //              << "unsupported non-power-of-two bitwidth in sub-byte" <<
+//   type
+//   //              << "\n");
+//   //   return nullptr;
+//   // }
+
+//   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
+//   return IntegerType::get(type.getContext(), /*width=*/32,
+//                           type.getSignedness());
+// }
+
 /// Returns a type with the same shape but with any index element type converted
 /// to the matching integer type. This is a noop when the element type is not
 /// the index type.
@@ -339,8 +410,20 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   type = cast<VectorType>(convertIndexElementType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
-    // If this is not a spec allowed scalar type, try to handle sub-byte integer
-    // types.
+    // If this is not a spec allowed scalar type, there are 2 scenarios,
+    // 8 bit floats or sub-byte integer types. try to handle them accrodingly.
+
+    // Hnadle 8 bit float types.
+    auto floatType = dyn_cast<FloatType>(type.getElementType());
+    if (floatType && floatType.getWidth() == 8) {
+      // If this is an 8 bit float type, try to convert it to a supported
+      // integer type.
+      if (auto convertedType = convert8BitFloatType(options, floatType)) {
+        return VectorType::get(type.getShape(), convertedType);
+      }
+    }
+
+    // Handle sub-byte integer types.
     auto intType = dyn_cast<IntegerType>(type.getElementType());
     if (!intType) {
       LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +679,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
     type = cast<MemRefType>(convertIndexElementType(type, options));
     arrayElemType = type.getElementType();
+  } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+    // Hnadle 8 bit float types.
+    if (options.emulateUnsupportedFloatTypes && floatType &&
+        floatType.getWidth() == 8) {
+      // If this is an 8 bit float type, try to convert it to a supported
+      // integer type.
+      arrayElemType = convert8BitFloatType(options, floatType);
+    }
   } else {
     LLVM_DEBUG(
         llvm::dbgs()
@@ -1439,6 +1530,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   addConversion([this](FloatType floatType) -> std::optional<Type> {
     if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
       return convertScalarType(this->targetEnv, this->options, scalarType);
+    if (floatType.getWidth() == 8)
+      return convert8BitFloatType(this->options, floatType);
     return Type();
   });
 

>From 0c0912312c94188b291f5f347f00ed13703b6321 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 08:40:31 +0000
Subject: [PATCH 2/5] Add arith.constant support.

Handles both scalar and vector constants.
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 30 ++++++++++++++++---
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..ccd3560addf9c 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
   return builder.getF32FloatAttr(dstVal.convertToFloat());
 }
 
+/// Bitcasts `floatAttr` to an IntegerAttr of `dstType` (same bit width).
+static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr,
+                                               Type dstType,
+                                               Builder &builder) {
+  APInt intVal = floatAttr.getValue().bitcastToAPInt();
+  return builder.getIntegerAttr(dstType, intVal);
+}
+
 /// Returns true if the given `type` is a boolean scalar or vector type.
 static bool isBoolScalarOrVector(Type type) {
   assert(type && "Not a valid type");
@@ -296,8 +304,16 @@ struct ConstantCompositeOpPattern final
       SmallVector<Attribute, 8> elements;
       if (isa<FloatType>(srcElemType)) {
         for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
-          FloatAttr dstAttr =
-              convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+          Attribute dstAttr = nullptr;
+          // Handle 8-bit float conversion to 8-bit integer.
+          if (srcElemType.getIntOrFloatBitWidth() == 8 &&
+              isa<IntegerType>(dstElemType)) {
+            dstAttr =
+                getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+          } else {
+            dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+                                       rewriter);
+          }
           if (!dstAttr)
             return failure();
           elements.push_back(dstAttr);
@@ -361,11 +377,17 @@ struct ConstantScalarOpPattern final
     // Floating-point types.
     if (isa<FloatType>(srcType)) {
       auto srcAttr = cast<FloatAttr>(cstAttr);
-      auto dstAttr = srcAttr;
+      Attribute dstAttr = srcAttr;
 
       // Floating-point types not supported in the target environment are all
       // converted to float type.
-      if (srcType != dstType) {
+      if (srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+          dstType.getIntOrFloatBitWidth() == 8) {
+        // If the source is an 8-bit float, convert it to a 8-bit integer.
+        dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+        if (!dstAttr)
+          return failure();
+      } else if (srcType != dstType) {
         dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
         if (!dstAttr)
           return failure();

>From 2de9adfd121026ec764f9c7a14ae7dc4905c08ea Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 09:26:06 +0000
Subject: [PATCH 3/5] Handle all Shaped Type 8-bit floats in a similar way.

This approach minimizes code changes.
---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 55 ++++++++++++-------
 1 file changed, 35 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c8c97e7d79188..285ee11749fe7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -343,6 +343,29 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
   return nullptr;
 }
 
+/// Returns a type with the same shape but with any 8-bit float element type
+/// converted to the same bit width integer type. This is a noop when the
+/// element type is not an 8-bit float type or when
+/// `options.emulateUnsupportedFloatTypes` is false.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+                           const SPIRVConversionOptions &options) {
+  // Callers `cast<>` the result unconditionally, so never return null here;
+  // with emulation disabled the type is passed through unchanged.
+  if (!options.emulateUnsupportedFloatTypes)
+    return type;
+  Type srcElementType = type.getElementType();
+  // F8 types are converted to integer types with the same bit width.
+  if (!isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+           Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+           Float8E8M0FNUType>(srcElementType))
+    return type;
+
+  Type convertedElementType = IntegerType::get(
+      type.getContext(), srcElementType.getIntOrFloatBitWidth());
+  return type.clone(convertedElementType);
+}
+
 /// Converts a sub-byte float ``type` to i32 regardless of target environment.
 /// Returns a nullptr for unsupported float types, including non sub-byte
 /// types.
@@ -408,22 +431,11 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
                   const SPIRVConversionOptions &options, VectorType type,
                   std::optional<spirv::StorageClass> storageClass = {}) {
   type = cast<VectorType>(convertIndexElementType(type, options));
+  type = cast<VectorType>(convertShaped8BitFloatType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
-    // If this is not a spec allowed scalar type, there are 2 scenarios,
-    // 8 bit floats or sub-byte integer types. try to handle them accrodingly.
-
-    // Hnadle 8 bit float types.
-    auto floatType = dyn_cast<FloatType>(type.getElementType());
-    if (floatType && floatType.getWidth() == 8) {
-      // If this is an 8 bit float type, try to convert it to a supported
-      // integer type.
-      if (auto convertedType = convert8BitFloatType(options, floatType)) {
-        return VectorType::get(type.getShape(), convertedType);
-      }
-    }
-
-    // Handle sub-byte integer types.
+    // If this is not a spec allowed scalar type, try to handle sub-byte integer
+    // types.
     auto intType = dyn_cast<IntegerType>(type.getElementType());
     if (!intType) {
       LLVM_DEBUG(llvm::dbgs()
@@ -516,6 +528,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
   }
 
   type = cast<TensorType>(convertIndexElementType(type, options));
+  type = cast<TensorType>(convertShaped8BitFloatType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
     LLVM_DEBUG(llvm::dbgs()
@@ -681,12 +694,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     arrayElemType = type.getElementType();
   } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
     // Hnadle 8 bit float types.
-    if (options.emulateUnsupportedFloatTypes && floatType &&
-        floatType.getWidth() == 8) {
-      // If this is an 8 bit float type, try to convert it to a supported
-      // integer type.
-      arrayElemType = convert8BitFloatType(options, floatType);
-    }
+    type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+    arrayElemType = type.getElementType();
+    // if (options.emulateUnsupportedFloatTypes && floatType &&
+    //     floatType.getWidth() == 8) {
+    //   // If this is an 8 bit float type, try to convert it to a supported
+    //   // integer type.
+    //   arrayElemType = convert8BitFloatType(options, floatType);
+    // }
   } else {
     LLVM_DEBUG(
         llvm::dbgs()

>From b90c58b911fbb95cf79f03d454cd8fff76639124 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 09:29:29 +0000
Subject: [PATCH 4/5] Remove commented out code.

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 46 -------------------
 1 file changed, 46 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 285ee11749fe7..3fc6ac6410e64 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -366,52 +366,6 @@ convertShaped8BitFloatType(ShapedType type,
   return type.clone(convertedElementType);
 }
 
-/// Converts a sub-byte float ``type` to i32 regardless of target environment.
-/// Returns a nullptr for unsupported float types, including non sub-byte
-/// types.
-///
-/// We are treating 8 bit floats as sub-byte types here due to it's similar
-/// nature of being used as a packed format.
-
-/// Note that we don't recognize
-/// sub-byte types in `spirv::ScalarType` and use the above given that these
-/// sub-byte types are not supported at all in SPIR-V; there are no
-/// compute/storage capability for them like other supported integer types.
-
-// static Type convertPackedFLoatType(const SPIRVConversionOptions &options,
-//                                    FloatType type) {
-
-//   // F4, F6, F8 types are converted to integer types with the same bit width.
-
-//   if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
-//   Float8E5M2FNUZType,
-//           Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
-//           Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
-//           Float8E8M0FNUType>(type))
-//     auto emulatedType = IntegerType::get(type.getContext(), type.getWidth());
-
-//   if (type.getWidth() > 8) {
-//     LLVM_DEBUG(llvm::dbgs() << "not a packed type\n");
-//     return nullptr;
-//   }
-//   if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
-//     LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
-//     return nullptr;
-//   }
-
-//   // if (!llvm::isPowerOf2_32(type.getWidth())) {
-//   //   LLVM_DEBUG(llvm::dbgs()
-//   //              << "unsupported non-power-of-two bitwidth in sub-byte" <<
-//   type
-//   //              << "\n");
-//   //   return nullptr;
-//   // }
-
-//   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
-//   return IntegerType::get(type.getContext(), /*width=*/32,
-//                           type.getSignedness());
-// }
-
 /// Returns a type with the same shape but with any index element type converted
 /// to the matching integer type. This is a noop when the element type is not
 /// the index type.

>From 3755267cf45eeca6bade403e1446854ad2f69979 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 15 Jul 2025 09:37:11 +0000
Subject: [PATCH 5/5] Remove unnecessary commented out code.

---
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 3fc6ac6410e64..588461b4f5c47 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -650,12 +650,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     // Hnadle 8 bit float types.
     type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
     arrayElemType = type.getElementType();
-    // if (options.emulateUnsupportedFloatTypes && floatType &&
-    //     floatType.getWidth() == 8) {
-    //   // If this is an 8 bit float type, try to convert it to a supported
-    //   // integer type.
-    //   arrayElemType = convert8BitFloatType(options, floatType);
-    // }
   } else {
     LLVM_DEBUG(
         llvm::dbgs()



More information about the Mlir-commits mailing list