[Mlir-commits] [llvm] [mlir] Rename f8E4M3 to f8E4M3FN in mlir.extras.types py package (PR #97102)

Alexander Pivovarov llvmlistbot at llvm.org
Fri Jun 28 12:42:04 PDT 2024


https://github.com/apivovarov created https://github.com/llvm/llvm-project/pull/97102

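Currently, `f8E4M3` in the `mlir.extras.types` Python package is mapped to `Float8E4M3FNType`.

This PR renames `f8E4M3` to `f8E4M3FN` so that the name accurately reflects the actual type. It also updates comments in `APFloat.cpp` and `APFloatTest.cpp` that said E4M3 where E4M3FN (or E4M3FNUZ) was meant.

This rename is needed to avoid a name conflict with an upcoming PR that will add the IEEE 754-style `Float8E4M3Type`.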
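
To illustrate why the two names need to stay distinct, here is a minimal Python sketch that decodes the same 8-bit pattern under both conventions. It assumes a 1-4-3 bit layout with exponent bias 7 for both variants, and the `decode_e4m3` helper is hypothetical (it is not part of `mlir.extras.types` or APFloat). Under the E4M3FN rules described in `APFloat.cpp`, only an all-1s exponent and mantissa is NaN, so the largest finite value is 448. Under IEEE 754-style rules (assumed here for the upcoming type), the all-1s exponent is reserved for infinity and NaN, which caps finite values at 240.

```python
import math

def decode_e4m3(bits: int, variant: str) -> float:
    """Decode an 8-bit E4M3 pattern: 1 sign, 4 exponent, 3 mantissa bits, bias 7.

    variant="fn":   E4M3FN rules; no infinities, NaN only when the exponent
                    and mantissa are all 1s.
    variant="ieee": IEEE 754-style rules (assumed); an all-1s exponent encodes
                    infinity (mantissa 0) or NaN (mantissa non-zero).
    """
    if not 0 <= bits <= 0xFF:
        raise ValueError("expected an 8-bit pattern")
    if variant not in ("fn", "ieee"):
        raise ValueError(f"unknown variant: {variant!r}")
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> 3) & 0xF
    man = bits & 0x7
    if exp == 0xF:
        if variant == "ieee":
            return sign * math.inf if man == 0 else math.nan
        if man == 0x7:
            return math.nan  # E4M3FN: only S.1111.111 is NaN
        # Otherwise E4M3FN treats exponent 15 as an ordinary normal exponent.
    if exp == 0:
        # Subnormal: 0.mmm * 2**(1 - bias)
        return sign * (man / 8) * 2.0 ** -6
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)

print(decode_e4m3(0x7E, "fn"))    # 448.0
print(decode_e4m3(0x7E, "ieee"))  # nan
```

The pattern `0x7E` (exponent 1111, mantissa 110) is the largest finite E4M3FN value but a NaN under the IEEE-style rules, so one Python name cannot sensibly cover both types.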

@makslevental Maksim, could you review this PR?

From 3f3a0307dd32d870bda2268c9a7428427119f6ef Mon Sep 17 00:00:00 2001
From: Alexander Pivovarov <pivovaa at amazon.com>
Date: Fri, 28 Jun 2024 19:33:21 +0000
Subject: [PATCH] Rename f8E4M3 to f8E4M3FN in mlir.extras.types py package

---
 llvm/lib/Support/APFloat.cpp       |  4 ++--
 llvm/unittests/ADT/APFloatTest.cpp | 12 ++++++------
 mlir/python/mlir/extras/types.py   |  2 +-
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 47618bc325951..3017a9b976658 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -83,8 +83,8 @@ enum class fltNanEncoding {
   // exponent is all 1s and the significand is non-zero.
   IEEE,
 
-  // Represents the behavior in the Float8E4M3 floating point type where NaN is
-  // represented by having the exponent and mantissa set to all 1s.
+  // Represents the behavior in the Float8E4M3FN floating point type where NaN
+  // is represented by having the exponent and mantissa set to all 1s.
   // This behavior matches the FP8 E4M3 type described in
   // https://arxiv.org/abs/2209.05433. We treat both signed and unsigned NaNs
   // as non-signalling, although the paper does not state whether the NaN
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index cf6bbd313c6c6..86a25f4394e19 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -5508,8 +5508,8 @@ TEST(APFloatTest, ConvertE4M3FNToE5M2) {
   EXPECT_TRUE(losesInfo);
   EXPECT_EQ(status, APFloat::opInexact);
 
-  // Convert E4M3 denormal to E5M2 normal. Should not be truncated, despite the
-  // destination format having one fewer significand bit
+  // Convert E4M3FN denormal to E5M2 normal. Should not be truncated, despite
+  // the destination format having one fewer significand bit
   test = APFloat(APFloat::Float8E4M3FN(), "0x1.Cp-7");
   status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
                         &losesInfo);
@@ -5647,8 +5647,8 @@ TEST(APFloatTest, Float8E4M3FNAdd) {
     int category;
     APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven;
   } AdditionTests[] = {
-      // Test addition operations involving NaN, overflow, and the max E4M3
-      // value (448) because E4M3 differs from IEEE-754 types in these regards
+      // Test addition operations involving NaN, overflow, and the max E4M3FN
+      // value (448) because E4M3FN differs from IEEE-754 types in these regards
       {FromStr("448"), FromStr("16"), "448", APFloat::opInexact,
        APFloat::fcNormal},
       {FromStr("448"), FromStr("18"), "NaN",
@@ -6278,8 +6278,8 @@ TEST(APFloatTest, ConvertE4M3FNUZToE5M2FNUZ) {
   EXPECT_TRUE(losesInfo);
   EXPECT_EQ(status, APFloat::opInexact);
 
-  // Convert E4M3 denormal to E5M2 normal. Should not be truncated, despite the
-  // destination format having one fewer significand bit
+  // Convert E4M3FNUZ denormal to E5M2 normal. Should not be truncated, despite
+  // the destination format having one fewer significand bit
   losesInfo = true;
   test = APFloat(APFloat::Float8E4M3FNUZ(), "0x1.Cp-8");
   status = test.convert(APFloat::Float8E5M2FNUZ(), APFloat::rmNearestTiesToEven,
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index db9e8229fb288..b93c46b172a9a 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -68,7 +68,7 @@ def ui(width):
 bf16 = lambda: BF16Type.get()
 
 f8E5M2 = lambda: Float8E5M2Type.get()
-f8E4M3 = lambda: Float8E4M3FNType.get()
+f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
 
 none = lambda: NoneType.get()

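The E4M3FN-to-E5M2 conversion test relies on the fact that a subnormal with a 3-bit significand needs only 2 fraction bits once it is renormalized in a format with a wider exponent range. Below is a minimal sketch of that check, assuming a hypothetical `fp8_fields` helper that models only precision (it ignores overflow and NaN/Inf encodings) and the usual exponent biases of 7 for E4M3FN and 15 for E5M2:

```python
import math
from typing import Optional, Tuple

def fp8_fields(value: float, man_bits: int, bias: int) -> Optional[Tuple[int, int]]:
    """Return (exponent_field, mantissa_field) if the positive finite `value`
    fits the format's precision exactly, else None.

    Hypothetical helper: overflow and NaN/Inf encodings are ignored; only
    the number of significand bits and the subnormal boundary are modeled.
    """
    if value <= 0 or not math.isfinite(value):
        raise ValueError("expected a positive finite value")
    m, e = math.frexp(value)          # value == m * 2**e, with 0.5 <= m < 1
    exponent = e - 1                  # value == (2*m) * 2**exponent, 1 <= 2*m < 2
    min_exponent = 1 - bias           # smallest normal exponent
    if exponent < min_exponent:
        # Subnormal: value == (mantissa / 2**man_bits) * 2**min_exponent
        exponent_field = 0
        mantissa = math.ldexp(value, man_bits - min_exponent)
    else:
        exponent_field = exponent + bias
        mantissa = math.ldexp(2 * m - 1, man_bits)
    if not mantissa.is_integer():
        return None                   # rounding would be required
    return exponent_field, int(mantissa)

x = float.fromhex("0x1.Cp-7")
print(fp8_fields(x, man_bits=3, bias=7))   # E4M3FN: (0, 7), subnormal 0b111
print(fp8_fields(x, man_bits=2, bias=15))  # E5M2:   (8, 3), normal 0b11
```

In this model `0x1.Cp-7` is stored as the E4M3FN subnormal `0.111 × 2^-6` and becomes the E5M2 normal `1.11 × 2^-7`, so no significand bit is dropped even though E5M2 has one fewer mantissa bit.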

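The `Float8E4M3FNAdd` cases around 448 follow from round-to-nearest-ties-to-even at the top of the E4M3FN range, where representable values are 32 apart and the next step past 448 would land on the NaN encoding. Below is a minimal sketch, assuming the exact sum is rounded once; `round_to_e4m3fn` is a hypothetical helper that handles only non-negative finite inputs:

```python
import math

E4M3FN_MAX = 448.0  # largest finite E4M3FN value (S.1111.110)

def round_to_e4m3fn(value: float) -> float:
    """Round a non-negative finite value to E4M3FN, round-to-nearest-ties-to-even.

    Hypothetical helper. Results that round past 448 become NaN because
    E4M3FN has no infinity encoding.
    """
    if value < 0 or not math.isfinite(value):
        raise ValueError("expected a non-negative finite value")
    if value == 0:
        return 0.0
    exponent = max(math.frexp(value)[1] - 1, -6)  # clamp to the subnormal range
    step = 2.0 ** (exponent - 3)                   # spacing with 3 mantissa bits
    rounded = round(value / step) * step           # round() breaks ties to even
    return math.nan if rounded > E4M3FN_MAX else float(rounded)

print(round_to_e4m3fn(448 + 16))  # 448.0
print(round_to_e4m3fn(448 + 18))  # nan
```

With this model, 448 + 16 = 464 is a tie between 448 and 480 and rounds to the even significand (448), while 448 + 18 = 466 rounds up past 448 and becomes NaN, matching the expected results in the test.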
