[Mlir-commits] [mlir] 2f086f2 - [APFloat] Add E4M3B11FNUZ

David Majnemer llvmlistbot at llvm.org
Fri Mar 24 13:07:15 PDT 2023


Author: David Majnemer
Date: 2023-03-24T20:06:40Z
New Revision: 2f086f265bf97fe6543fb199f4ef874ca3522479

URL: https://github.com/llvm/llvm-project/commit/2f086f265bf97fe6543fb199f4ef874ca3522479
DIFF: https://github.com/llvm/llvm-project/commit/2f086f265bf97fe6543fb199f4ef874ca3522479.diff

LOG: [APFloat] Add E4M3B11FNUZ

X. Sun et al. (https://dl.acm.org/doi/10.5555/3454287.3454728) published
a paper showing that an FP format with 4 bits of exponent, 3 bits of
significand and an exponent bias of 11 would work quite well for ML
applications.

Google hardware supports a variant of this format where 0x80 is used to
represent NaN, as in the Float8E4M3FNUZ format. Just like the
Float8E4M3FNUZ format, this format does not support -0 and values which
would map to it will become +0.

This format is proposed for inclusion in OpenXLA's StableHLO dialect: https://github.com/openxla/stablehlo/pull/1308

As part of inclusion in that dialect, APFloat needs to know how to
handle this format.

Differential Revision: https://reviews.llvm.org/D146441

Added: 
    

Modified: 
    clang/lib/AST/MicrosoftMangle.cpp
    llvm/include/llvm/ADT/APFloat.h
    llvm/lib/Support/APFloat.cpp
    llvm/unittests/ADT/APFloatTest.cpp
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/Types.h
    mlir/lib/AsmParser/TokenKinds.def
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Types.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/IR/attribute.mlir
    mlir/test/python/ir/builtin_types.py
    mlir/utils/lldb-scripts/mlirDataFormatters.py

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index b8a916a72f750..d8c837bcade02 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -845,6 +845,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
   case APFloat::S_Float8E4M3FN:
   case APFloat::S_Float8E5M2FNUZ:
   case APFloat::S_Float8E4M3FNUZ:
+  case APFloat::S_Float8E4M3B11FNUZ:
     llvm_unreachable("Tried to mangle unexpected APFloat semantics");
   }
 

diff  --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index c3064651fbc56..d7fcf02decea7 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -177,6 +177,13 @@ struct APFloatBase {
     // This format's exponent bias is 8, instead of the 7 (2 ** (4 - 1) - 1)
     // that IEEE precedent would imply.
     S_Float8E4M3FNUZ,
+    // 8-bit floating point number mostly following IEEE-754 conventions
+    // and bit layout S1E4M3 with expanded range and with no infinity or signed
+    // zero.
+    // NaN is represented as negative zero. (FN -> Finite, UZ -> unsigned zero).
+    // This format's exponent bias is 11, instead of the 7 (2 ** (4 - 1) - 1)
+    // that IEEE precedent would imply.
+    S_Float8E4M3B11FNUZ,
 
     S_x87DoubleExtended,
     S_MaxSemantics = S_x87DoubleExtended,
@@ -195,6 +202,7 @@ struct APFloatBase {
   static const fltSemantics &Float8E5M2FNUZ() LLVM_READNONE;
   static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
   static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
+  static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
   static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
 
   /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -590,6 +598,7 @@ class IEEEFloat final : public APFloatBase {
   APInt convertFloat8E5M2FNUZAPFloatToAPInt() const;
   APInt convertFloat8E4M3FNAPFloatToAPInt() const;
   APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
+  APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
   void initFromBFloatAPInt(const APInt &api);
@@ -602,6 +611,7 @@ class IEEEFloat final : public APFloatBase {
   void initFromFloat8E5M2FNUZAPInt(const APInt &api);
   void initFromFloat8E4M3FNAPInt(const APInt &api);
   void initFromFloat8E4M3FNUZAPInt(const APInt &api);
+  void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
 
   void assign(const IEEEFloat &);
   void copySignificand(const IEEEFloat &);

diff  --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 05053828bfc18..97c811a18e8af 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -60,8 +60,9 @@ enum class fltNonfiniteBehavior {
   IEEE754,
 
   // This behavior is present in the Float8ExMyFN* types (Float8E4M3FN,
-  // Float8E5M2FNUZ, and Float8E4M3FNUZ). There is no representation for Inf,
-  // and operations that would ordinarily produce Inf produce NaN instead.
+  // Float8E5M2FNUZ, Float8E4M3FNUZ, and Float8E4M3B11FNUZ). There is no
+  // representation for Inf, and operations that would ordinarily produce Inf
+  // produce NaN instead.
   // The details of the NaN representation(s) in this form are determined by the
   // `fltNanEncoding` enum. We treat all NaNs as quiet, as the available
   // encodings do not distinguish between signalling and quiet NaN.
@@ -138,6 +139,13 @@ struct fltSemantics {
       8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
   static const fltSemantics semFloat8E4M3FNUZ = {
       7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
+  static const fltSemantics semFloat8E4M3B11FNUZ = {
+      4,
+      -10,
+      4,
+      8,
+      fltNonfiniteBehavior::NanOnly,
+      fltNanEncoding::NegativeZero};
   static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
   static const fltSemantics semBogus = {0, 0, 0, 0};
 
@@ -201,6 +209,8 @@ struct fltSemantics {
       return Float8E4M3FN();
     case S_Float8E4M3FNUZ:
       return Float8E4M3FNUZ();
+    case S_Float8E4M3B11FNUZ:
+      return Float8E4M3B11FNUZ();
     case S_x87DoubleExtended:
       return x87DoubleExtended();
     }
@@ -229,6 +239,8 @@ struct fltSemantics {
       return S_Float8E4M3FN;
     else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ())
       return S_Float8E4M3FNUZ;
+    else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ())
+      return S_Float8E4M3B11FNUZ;
     else if (&Sem == &llvm::APFloat::x87DoubleExtended())
       return S_x87DoubleExtended;
     else
@@ -259,6 +271,9 @@ struct fltSemantics {
   const fltSemantics &APFloatBase::Float8E4M3FNUZ() {
     return semFloat8E4M3FNUZ;
   }
+  const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
+    return semFloat8E4M3B11FNUZ;
+  }
   const fltSemantics &APFloatBase::x87DoubleExtended() {
     return semX87DoubleExtended;
   }
@@ -3709,6 +3724,33 @@ APInt IEEEFloat::convertFloat8E4M3FNUZAPFloatToAPInt() const {
                    (mysignificand & 0x7)));
 }
 
+APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const {
+  assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ);
+  assert(partCount() == 1);
+
+  uint32_t myexponent, mysignificand;
+
+  if (isFiniteNonZero()) {
+    myexponent = exponent + 11; // bias
+    mysignificand = (uint32_t)*significandParts();
+    if (myexponent == 1 && !(mysignificand & 0x8))
+      myexponent = 0; // denormal
+  } else if (category == fcZero) {
+    myexponent = 0;
+    mysignificand = 0;
+  } else if (category == fcInfinity) {
+    myexponent = 0;
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    myexponent = 0;
+    mysignificand = (uint32_t)*significandParts();
+  }
+
+  return APInt(8, (((sign & 1) << 7) | ((myexponent & 0xf) << 3) |
+                   (mysignificand & 0x7)));
+}
+
 // This function creates an APInt that is just a bit map of the floating
 // point constant as it would appear in memory.  It is not a conversion,
 // and treating the result as a normal integer is unlikely to be useful.
@@ -3744,6 +3786,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FNUZ)
     return convertFloat8E4M3FNUZAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ)
+    return convertFloat8E4M3B11FNUZAPFloatToAPInt();
+
   assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
          "unknown format!");
   return convertF80LongDoubleAPFloatToAPInt();
@@ -4077,6 +4122,32 @@ void IEEEFloat::initFromFloat8E4M3FNUZAPInt(const APInt &api) {
   }
 }
 
+void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) {
+  uint32_t i = (uint32_t)*api.getRawData();
+  uint32_t myexponent = (i >> 3) & 0xf;
+  uint32_t mysignificand = i & 0x7;
+
+  initialize(&semFloat8E4M3B11FNUZ);
+  assert(partCount() == 1);
+
+  sign = i >> 7;
+  if (myexponent == 0 && mysignificand == 0 && sign == 0) {
+    makeZero(sign);
+  } else if (myexponent == 0 && mysignificand == 0 && sign == 1) {
+    category = fcNaN;
+    exponent = exponentNaN();
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal;
+    exponent = myexponent - 11; // bias
+    *significandParts() = mysignificand;
+    if (myexponent == 0) // denormal
+      exponent = -10;
+    else
+      *significandParts() |= 0x8; // integer bit
+  }
+}
+
 /// Treat api as containing the bits of a floating point number.
 void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
   assert(api.getBitWidth() == Sem->sizeInBits);
@@ -4102,6 +4173,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
     return initFromFloat8E4M3FNAPInt(api);
   if (Sem == &semFloat8E4M3FNUZ)
     return initFromFloat8E4M3FNUZAPInt(api);
+  if (Sem == &semFloat8E4M3B11FNUZ)
+    return initFromFloat8E4M3B11FNUZAPInt(api);
 
   llvm_unreachable(nullptr);
 }

diff  --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index 79c6a3f2c53ee..cbf59acd77ee0 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -1346,6 +1346,10 @@ TEST(APFloatTest, makeNaN) {
     {               0x80ULL, APFloat::Float8E4M3FNUZ(), false, true,            0xaaULL },
     {               0x80ULL, APFloat::Float8E4M3FNUZ(), true, false,            0xaaULL },
     {               0x80ULL, APFloat::Float8E4M3FNUZ(), true, true,             0xaaULL },
+    {               0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, false,        0xaaULL },
+    {               0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, true,         0xaaULL },
+    {               0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, false,         0xaaULL },
+    {               0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, true,          0xaaULL },
       // clang-format on
   };
 
@@ -1774,6 +1778,8 @@ TEST(APFloatTest, getLargest) {
             APFloat::getLargest(APFloat::Float8E4M3FNUZ()).convertToDouble());
   EXPECT_EQ(57344,
             APFloat::getLargest(APFloat::Float8E5M2FNUZ()).convertToDouble());
+  EXPECT_EQ(
+      30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble());
 }
 
 TEST(APFloatTest, getSmallest) {
@@ -1818,6 +1824,13 @@ TEST(APFloatTest, getSmallest) {
   EXPECT_TRUE(test.isFiniteNonZero());
   EXPECT_TRUE(test.isDenormal());
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  test = APFloat::getSmallest(APFloat::Float8E4M3B11FNUZ(), false);
+  expected = APFloat(APFloat::Float8E4M3B11FNUZ(), "0x0.2p-10");
+  EXPECT_FALSE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_TRUE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
 }
 
 TEST(APFloatTest, getSmallestNormalized) {
@@ -1884,6 +1897,14 @@ TEST(APFloatTest, getSmallestNormalized) {
   EXPECT_FALSE(test.isDenormal());
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
   EXPECT_TRUE(test.isSmallestNormalized());
+
+  test = APFloat::getSmallestNormalized(APFloat::Float8E4M3B11FNUZ(), false);
+  expected = APFloat(APFloat::Float8E4M3B11FNUZ(), "0x1.0p-10");
+  EXPECT_FALSE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_FALSE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  EXPECT_TRUE(test.isSmallestNormalized());
 }
 
 TEST(APFloatTest, getZero) {
@@ -1913,7 +1934,9 @@ TEST(APFloatTest, getZero) {
       {&APFloat::Float8E4M3FN(), false, true, {0, 0}, 1},
       {&APFloat::Float8E4M3FN(), true, true, {0x80ULL, 0}, 1},
       {&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1},
-      {&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1}};
+      {&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1},
+      {&APFloat::Float8E4M3B11FNUZ(), false, false, {0, 0}, 1},
+      {&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1}};
   const unsigned NumGetZeroTests = std::size(GetZeroTest);
   for (unsigned i = 0; i < NumGetZeroTests; ++i) {
     APFloat test = APFloat::getZero(*GetZeroTest[i].semantics,
@@ -1944,14 +1967,14 @@ TEST(APFloatTest, copySign) {
   EXPECT_TRUE(APFloat(42.0).bitwiseIsEqual(
       APFloat::copySign(APFloat(42.0), APFloat(1.0))));
   // For floating-point formats with unsigned 0, copySign() to a zero is a noop
-  EXPECT_TRUE(
-      APFloat::getZero(APFloat::Float8E4M3FNUZ())
-          .bitwiseIsEqual(APFloat::copySign(
-              APFloat::getZero(APFloat::Float8E4M3FNUZ()), APFloat(-1.0))));
-  EXPECT_TRUE(
-      APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true)
-          .bitwiseIsEqual(APFloat::copySign(
-              APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true), APFloat(1.0))));
+  for (APFloat::Semantics S :
+       {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
+    const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+    EXPECT_TRUE(APFloat::getZero(Sem).bitwiseIsEqual(
+        APFloat::copySign(APFloat::getZero(Sem), APFloat(-1.0))));
+    EXPECT_TRUE(APFloat::getNaN(Sem, true).bitwiseIsEqual(
+        APFloat::copySign(APFloat::getNaN(Sem, true), APFloat(1.0))));
+  }
 }
 
 TEST(APFloatTest, convert) {
@@ -2073,17 +2096,18 @@ TEST(APFloatTest, Float8UZConvert) {
       {APFloat::getSNaN(APFloat::IEEEsingle(), true), APFloat::opInvalidOp},
       {APFloat::getInf(APFloat::IEEEsingle(), false), APFloat::opInexact},
       {APFloat::getInf(APFloat::IEEEsingle(), true), APFloat::opInexact}};
-  for (auto [toTest, expectedRes] : toNaNTests) {
-    llvm::SmallString<16> value;
-    toTest.toString(value);
-    SCOPED_TRACE("toTest = " + value);
-    for (const fltSemantics *sem :
-         {&APFloat::Float8E4M3FNUZ(), &APFloat::Float8E5M2FNUZ()}) {
-      SCOPED_TRACE("Semantics = " +
-                   std::to_string(APFloat::SemanticsToEnum(*sem)));
+  for (APFloat::Semantics S :
+       {APFloat::S_Float8E5M2FNUZ, APFloat::S_Float8E4M3FNUZ,
+        APFloat::S_Float8E4M3B11FNUZ}) {
+    const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+    SCOPED_TRACE("Semantics = " + std::to_string(S));
+    for (auto [toTest, expectedRes] : toNaNTests) {
+      llvm::SmallString<16> value;
+      toTest.toString(value);
+      SCOPED_TRACE("toTest = " + value);
       losesInfo = false;
       APFloat test = toTest;
-      EXPECT_EQ(test.convert(*sem, APFloat::rmNearestTiesToAway, &losesInfo),
+      EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo),
                 expectedRes);
       EXPECT_TRUE(test.isNaN());
       EXPECT_TRUE(test.isNegative());
@@ -2092,37 +2116,34 @@ TEST(APFloatTest, Float8UZConvert) {
       EXPECT_EQ(0x80, test.bitcastToAPInt());
       EXPECT_TRUE(losesInfo);
     }
-  }
 
-  // Negative zero conversions are information losing.
-  losesInfo = false;
-  APFloat test = APFloat::getZero(APFloat::IEEEsingle(), true);
-  EXPECT_EQ(test.convert(APFloat::Float8E5M2FNUZ(),
-                         APFloat::rmNearestTiesToAway, &losesInfo),
-            APFloat::opInexact);
-  EXPECT_TRUE(test.isZero());
-  EXPECT_FALSE(test.isNegative());
-  EXPECT_TRUE(losesInfo);
-  EXPECT_EQ(0x0, test.bitcastToAPInt());
-
-  losesInfo = true;
-  test = APFloat::getZero(APFloat::IEEEsingle(), false);
-  EXPECT_EQ(test.convert(APFloat::Float8E5M2FNUZ(),
-                         APFloat::rmNearestTiesToAway, &losesInfo),
-            APFloat::opOK);
-  EXPECT_TRUE(test.isZero());
-  EXPECT_FALSE(test.isNegative());
-  EXPECT_FALSE(losesInfo);
-  EXPECT_EQ(0x0, test.bitcastToAPInt());
+    // Negative zero conversions are information losing.
+    losesInfo = false;
+    APFloat test = APFloat::getZero(APFloat::IEEEsingle(), true);
+    EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo),
+              APFloat::opInexact);
+    EXPECT_TRUE(test.isZero());
+    EXPECT_FALSE(test.isNegative());
+    EXPECT_TRUE(losesInfo);
+    EXPECT_EQ(0x0, test.bitcastToAPInt());
+
+    losesInfo = true;
+    test = APFloat::getZero(APFloat::IEEEsingle(), false);
+    EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo),
+              APFloat::opOK);
+    EXPECT_TRUE(test.isZero());
+    EXPECT_FALSE(test.isNegative());
+    EXPECT_FALSE(losesInfo);
+    EXPECT_EQ(0x0, test.bitcastToAPInt());
 
-  // Except in casts between ourselves.
-  losesInfo = true;
-  test = APFloat::getZero(APFloat::Float8E5M2FNUZ());
-  EXPECT_EQ(test.convert(APFloat::Float8E4M3FNUZ(),
-                         APFloat::rmNearestTiesToAway, &losesInfo),
-            APFloat::opOK);
-  EXPECT_FALSE(losesInfo);
-  EXPECT_EQ(0x0, test.bitcastToAPInt());
+    // Except in casts between ourselves.
+    losesInfo = true;
+    test = APFloat::getZero(Sem);
+    EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo),
+              APFloat::opOK);
+    EXPECT_FALSE(losesInfo);
+    EXPECT_EQ(0x0, test.bitcastToAPInt());
+  }
 }
 
 TEST(APFloatTest, PPCDoubleDouble) {
@@ -5003,7 +5024,7 @@ TEST(APFloatTest, Float8ExhaustivePair) {
   // Test each pair of 8-bit floats with non-standard semantics
   for (APFloat::Semantics Sem :
        {APFloat::S_Float8E4M3FN, APFloat::S_Float8E5M2FNUZ,
-        APFloat::S_Float8E4M3FNUZ}) {
+        APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
     const llvm::fltSemantics &S = APFloat::EnumToSemantics(Sem);
     for (int i = 0; i < 256; i++) {
       for (int j = 0; j < 256; j++) {
@@ -5483,50 +5504,54 @@ TEST(APFloatTest, UnsignedZeroArithmeticSpecial) {
   // cases and so are not repeated here.
 
   // The IEEE round towards negative rule doesn't apply
-  APFloat test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ());
-  APFloat rhs = test;
-  EXPECT_EQ(test.subtract(rhs, APFloat::rmTowardNegative), APFloat::opOK);
-  EXPECT_TRUE(test.isZero());
-  EXPECT_FALSE(test.isNegative());
-
-  // Multiplication of (small) * (-small) is +0
-  test = APFloat::getSmallestNormalized(APFloat::Float8E4M3FNUZ());
-  rhs = -test;
-  EXPECT_EQ(test.multiply(rhs, APFloat::rmNearestTiesToAway),
-            APFloat::opInexact | APFloat::opUnderflow);
-  EXPECT_TRUE(test.isZero());
-  EXPECT_FALSE(test.isNegative());
+  for (APFloat::Semantics S :
+       {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
+    const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+    APFloat test = APFloat::getSmallest(Sem);
+    APFloat rhs = test;
+    EXPECT_EQ(test.subtract(rhs, APFloat::rmTowardNegative), APFloat::opOK);
+    EXPECT_TRUE(test.isZero());
+    EXPECT_FALSE(test.isNegative());
 
-  // Dividing the negatize float_min by anything gives +0
-  test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true);
-  rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0");
-  EXPECT_EQ(test.divide(rhs, APFloat::rmNearestTiesToEven),
-            APFloat::opInexact | APFloat::opUnderflow);
-  EXPECT_TRUE(test.isZero());
-  EXPECT_FALSE(test.isNegative());
+    // Multiplication of (small) * (-small) is +0
+    test = APFloat::getSmallestNormalized(Sem);
+    rhs = -test;
+    EXPECT_EQ(test.multiply(rhs, APFloat::rmNearestTiesToAway),
+              APFloat::opInexact | APFloat::opUnderflow);
+    EXPECT_TRUE(test.isZero());
+    EXPECT_FALSE(test.isNegative());
 
-  // Remainder can't copy sign because there's only one zero
-  test = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0");
-  rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0");
-  EXPECT_EQ(test.remainder(rhs), APFloat::opOK);
-  EXPECT_TRUE(test.isZero());
-  EXPECT_FALSE(test.isNegative());
+    // Dividing the negative float_min by anything gives +0
+    test = APFloat::getSmallest(Sem, true);
+    rhs = APFloat(Sem, "2.0");
+    EXPECT_EQ(test.divide(rhs, APFloat::rmNearestTiesToEven),
+              APFloat::opInexact | APFloat::opUnderflow);
+    EXPECT_TRUE(test.isZero());
+    EXPECT_FALSE(test.isNegative());
 
-  // And same for mod
-  test = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0");
-  rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0");
-  EXPECT_EQ(test.mod(rhs), APFloat::opOK);
-  EXPECT_TRUE(test.isZero());
-  EXPECT_FALSE(test.isNegative());
+    // Remainder can't copy sign because there's only one zero
+    test = APFloat(Sem, "-4.0");
+    rhs = APFloat(Sem, "2.0");
+    EXPECT_EQ(test.remainder(rhs), APFloat::opOK);
+    EXPECT_TRUE(test.isZero());
+    EXPECT_FALSE(test.isNegative());
 
-  // FMA correctly handles both the multiply and add parts of all this
-  test = APFloat(APFloat::Float8E4M3FNUZ(), "2.0");
-  rhs = test;
-  APFloat addend = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0");
-  EXPECT_EQ(test.fusedMultiplyAdd(rhs, addend, APFloat::rmTowardNegative),
-            APFloat::opOK);
-  EXPECT_TRUE(test.isZero());
-  EXPECT_FALSE(test.isNegative());
+    // And same for mod
+    test = APFloat(Sem, "-4.0");
+    rhs = APFloat(Sem, "2.0");
+    EXPECT_EQ(test.mod(rhs), APFloat::opOK);
+    EXPECT_TRUE(test.isZero());
+    EXPECT_FALSE(test.isNegative());
+
+    // FMA correctly handles both the multiply and add parts of all this
+    test = APFloat(Sem, "2.0");
+    rhs = test;
+    APFloat addend = APFloat(Sem, "-4.0");
+    EXPECT_EQ(test.fusedMultiplyAdd(rhs, addend, APFloat::rmTowardNegative),
+              APFloat::opOK);
+    EXPECT_TRUE(test.isZero());
+    EXPECT_FALSE(test.isNegative());
+  }
 }
 
 TEST(APFloatTest, Float8E5M2FNUZAdd) {
@@ -5590,7 +5615,8 @@ TEST(APFloatTest, Float8UnsignedZeroExhaustive) {
     const double largest;
     const double smallest;
   } const exhaustiveTests[] = {{&APFloat::Float8E5M2FNUZ(), 57344., 0x1.0p-17},
-                               {&APFloat::Float8E4M3FNUZ(), 240., 0x1.0p-10}};
+                               {&APFloat::Float8E4M3FNUZ(), 240., 0x1.0p-10},
+                               {&APFloat::Float8E4M3B11FNUZ(), 30., 0x1.0p-13}};
   for (const auto &testInfo : exhaustiveTests) {
     const fltSemantics &sem = *testInfo.semantics;
     SCOPED_TRACE("Semantics=" + std::to_string(APFloat::SemanticsToEnum(sem)));
@@ -5634,71 +5660,79 @@ TEST(APFloatTest, Float8UnsignedZeroExhaustive) {
 }
 
 TEST(APFloatTest, Float8E4M3FNUZNext) {
-  APFloat test(APFloat::Float8E4M3FNUZ(), APFloat::uninitialized);
-  APFloat expected(APFloat::Float8E4M3FNUZ(), APFloat::uninitialized);
-
-  // 1. NextUp of largest bit pattern is nan
-  test = APFloat::getLargest(APFloat::Float8E4M3FNUZ());
-  expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ());
-  EXPECT_EQ(test.next(false), APFloat::opOK);
-  EXPECT_FALSE(test.isInfinity());
-  EXPECT_FALSE(test.isZero());
-  EXPECT_TRUE(test.isNaN());
-  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  for (APFloat::Semantics S :
+       {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
+    const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+    APFloat test(Sem, APFloat::uninitialized);
+    APFloat expected(Sem, APFloat::uninitialized);
+
+    // 1. NextUp of largest bit pattern is nan
+    test = APFloat::getLargest(Sem);
+    expected = APFloat::getNaN(Sem);
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_FALSE(test.isInfinity());
+    EXPECT_FALSE(test.isZero());
+    EXPECT_TRUE(test.isNaN());
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
 
-  // 2. NextUp of smallest negative denormal is +0
-  test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true);
-  expected = APFloat::getZero(APFloat::Float8E4M3FNUZ(), false);
-  EXPECT_EQ(test.next(false), APFloat::opOK);
-  EXPECT_FALSE(test.isNegZero());
-  EXPECT_TRUE(test.isPosZero());
-  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+    // 2. NextUp of smallest negative denormal is +0
+    test = APFloat::getSmallest(Sem, true);
+    expected = APFloat::getZero(Sem, false);
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_FALSE(test.isNegZero());
+    EXPECT_TRUE(test.isPosZero());
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
 
-  // 3. nextDown of negative of largest value is NaN
-  test = APFloat::getLargest(APFloat::Float8E4M3FNUZ(), true);
-  expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ());
-  EXPECT_EQ(test.next(true), APFloat::opOK);
-  EXPECT_FALSE(test.isInfinity());
-  EXPECT_FALSE(test.isZero());
-  EXPECT_TRUE(test.isNaN());
-  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+    // 3. nextDown of negative of largest value is NaN
+    test = APFloat::getLargest(Sem, true);
+    expected = APFloat::getNaN(Sem);
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_FALSE(test.isInfinity());
+    EXPECT_FALSE(test.isZero());
+    EXPECT_TRUE(test.isNaN());
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
 
-  // 4. nextDown of +0 is smallest negative denormal
-  test = APFloat::getZero(APFloat::Float8E4M3FNUZ(), false);
-  expected = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true);
-  EXPECT_EQ(test.next(true), APFloat::opOK);
-  EXPECT_FALSE(test.isZero());
-  EXPECT_TRUE(test.isDenormal());
-  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+    // 4. nextDown of +0 is smallest negative denormal
+    test = APFloat::getZero(Sem, false);
+    expected = APFloat::getSmallest(Sem, true);
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_FALSE(test.isZero());
+    EXPECT_TRUE(test.isDenormal());
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
 
-  // 5. nextUp of NaN is NaN
-  test = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), false);
-  expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true);
-  EXPECT_EQ(test.next(false), APFloat::opOK);
-  EXPECT_TRUE(test.isNaN());
+    // 5. nextUp of NaN is NaN
+    test = APFloat::getNaN(Sem, false);
+    expected = APFloat::getNaN(Sem, true);
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_TRUE(test.isNaN());
 
-  // 6. nextDown of NaN is NaN
-  test = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), false);
-  expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true);
-  EXPECT_EQ(test.next(true), APFloat::opOK);
-  EXPECT_TRUE(test.isNaN());
+    // 6. nextDown of NaN is NaN
+    test = APFloat::getNaN(Sem, false);
+    expected = APFloat::getNaN(Sem, true);
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_TRUE(test.isNaN());
+  }
 }
 
 TEST(APFloatTest, Float8E4M3FNUZChangeSign) {
-  APFloat test = APFloat(APFloat::Float8E4M3FNUZ(), "1.0");
-  APFloat expected = APFloat(APFloat::Float8E4M3FNUZ(), "-1.0");
-  test.changeSign();
-  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  for (APFloat::Semantics S :
+       {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
+    const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+    APFloat test = APFloat(Sem, "1.0");
+    APFloat expected = APFloat(Sem, "-1.0");
+    test.changeSign();
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
 
-  test = APFloat::getZero(APFloat::Float8E4M3FNUZ());
-  expected = test;
-  test.changeSign();
-  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+    test = APFloat::getZero(Sem);
+    expected = test;
+    test.changeSign();
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
 
-  test = APFloat::getNaN(APFloat::Float8E4M3FNUZ());
-  expected = test;
-  test.changeSign();
-  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+    test = APFloat::getNaN(Sem);
+    expected = test;
+    test.changeSign();
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
 }
 
 TEST(APFloatTest, Float8E4M3FNUZFromString) {

diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 8b855d8c39a4d..2b7606f3d9caf 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -95,6 +95,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
 
+/// Checks whether the given type is an f8E4M3B11FNUZ type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
+
+/// Creates an f8E4M3B11FNUZ type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index f970d89dd410f..7197b1364bbce 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -64,6 +64,7 @@ class Builder {
   FloatType getFloat8E4M3FNType();
   FloatType getFloat8E5M2FNUZType();
   FloatType getFloat8E4M3FNUZType();
+  FloatType getFloat8E4M3B11FNUZType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 33995f34ee39b..baee29c554c3d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -49,6 +49,7 @@ class FloatType : public Type {
   static FloatType getFloat8E4M3FN(MLIRContext *ctx);
   static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
+  static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -376,9 +377,10 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
-                  Float8E4M3FNUZType, BFloat16Type, Float16Type, Float32Type,
-                  Float64Type, Float80Type, Float128Type>();
+  return type
+      .isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+           Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type, Float16Type,
+           Float32Type, Float64Type, Float80Type, Float128Type>();
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -397,6 +399,10 @@ inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {
   return Float8E4M3FNUZType::get(ctx);
 }
 
+inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) {
+  return Float8E4M3B11FNUZType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4d6120ff9df88..3e41625c0b688 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -162,6 +162,28 @@ def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E4M3B11FNUZType
+
+def Builtin_Float8E4M3B11FNUZ : Builtin_FloatType<"Float8E4M3B11FNUZ"> {
+  let summary = "8-bit floating point with 3 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it follows
+    similar conventions, with the exception that there are no infinity values,
+    no negative zero, and only one NaN representation. This type has the
+    following characteristics:
+
+      * bit encoding: S1E4M3
+      * exponent bias: 11
+      * infinities: Not supported
+      * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
+      * denormals when exponent is 0
+
+    Related to: https://dl.acm.org/doi/10.5555/3454287.3454728
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 527ccc05e090d..554f02675363a 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -490,6 +490,8 @@ def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
              BuildableType<"$_builder.getFloat8E5M2Type()">;
 def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
                  BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
+def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
+                 BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
 def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
 

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index fc377bee3cefc..adefc2908b552 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -124,6 +124,7 @@ class Type {
   bool isFloat8E4M3FN() const;
   bool isFloat8E5M2FNUZ() const;
   bool isFloat8E4M3FNUZ() const;
+  bool isFloat8E4M3B11FNUZ() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;

diff  --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 0e666c792b9de..9a632e3570fb5 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -97,6 +97,7 @@ TOK_KEYWORD(f8E5M2)
 TOK_KEYWORD(f8E4M3FN)
 TOK_KEYWORD(f8E5M2FNUZ)
 TOK_KEYWORD(f8E4M3FNUZ)
+TOK_KEYWORD(f8E4M3B11FNUZ)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)

diff  --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 47078c1ba0472..737767ce9101b 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -35,6 +35,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_f8E4M3FN:
   case Token::kw_f8E5M2FNUZ:
   case Token::kw_f8E4M3FNUZ:
+  case Token::kw_f8E4M3B11FNUZ:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -303,6 +304,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E4M3FNUZ:
     consumeToken(Token::kw_f8E4M3FNUZ);
     return builder.getFloat8E4M3FNUZType();
+  case Token::kw_f8E4M3B11FNUZ:
+    consumeToken(Token::kw_f8E4M3B11FNUZ);
+    return builder.getFloat8E4M3B11FNUZType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 2166bab902a13..6d381b1c01701 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -157,6 +157,24 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
   }
 };
 
+/// Floating Point Type subclass - Float8E4M3B11FNUZ.
+class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
+          return PyFloat8E4M3B11FNUZType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type.");
+  }
+};
+
 /// Floating Point Type subclass - Float8E5M2FNUZ.
 class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
 public:

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index aea1221200af6..2468c05463f4a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -100,6 +100,14 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
 }
 
+bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
+  return unwrap(type).isFloat8E4M3B11FNUZ();
+}
+
+MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75448955f3123..1820803f24258 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2428,6 +2428,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
       .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
       .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
+      .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8eab32b201a04..9203943123470 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -49,6 +49,10 @@ FloatType Builder::getFloat8E4M3FNUZType() {
   return FloatType::getFloat8E4M3FNUZ(context);
 }
 
+FloatType Builder::getFloat8E4M3B11FNUZType() {
+  return FloatType::getFloat8E4M3B11FNUZ(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6e6c6b9683c78..7e95ca137fdd1 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -89,7 +89,7 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 
 unsigned FloatType::getWidth() {
   if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
-          Float8E4M3FNUZType>())
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
     return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
@@ -114,6 +114,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::Float8E5M2FNUZ();
   if (isa<Float8E4M3FNUZType>())
     return APFloat::Float8E4M3FNUZ();
+  if (isa<Float8E4M3B11FNUZType>())
+    return APFloat::Float8E4M3B11FNUZ();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 29d49edf9efb9..daa4a6af63020 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -214,6 +214,7 @@ class MLIRContextImpl {
   Float8E4M3FNType f8E4M3FNTy;
   Float8E5M2FNUZType f8E5M2FNUZTy;
   Float8E4M3FNUZType f8E4M3FNUZTy;
+  Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -288,6 +289,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
   impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
   impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
+  impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -892,6 +894,9 @@ Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
 Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
   return context->getImpl().f8E4M3FNUZTy;
 }
+Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
+  return context->getImpl().f8E4M3B11FNUZTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }

diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index e739786bd3990..d3d1d860d5d32 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -38,6 +38,7 @@ bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
 bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
 bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
 bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
+bool Type::isFloat8E4M3B11FNUZ() const { return isa<Float8E4M3B11FNUZType>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 7d5ff23f60ab7..75b25bd8c1c9d 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -53,6 +53,7 @@ __all__ = [
     "Float8E4M3FNType",
     "Float8E5M2Type",
     "Float8E4M3FNUZType",
+    "Float8E4M3B11FNUZType",
     "Float8E5M2FNUZType",
     "F16Type",
     "F32Type",
@@ -602,6 +603,13 @@ class Float8E4M3FNUZType(Type):
     @staticmethod
     def isinstance(arg: Any) -> bool: ...
 
+class Float8E4M3B11FNUZType(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E4M3B11FNUZType: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
 class Float8E5M2FNUZType(Type):
     def __init__(self, cast_from_type: Type) -> None: ...
     @staticmethod

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index de840f950f458..c296507868cbc 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -52,6 +52,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
     float_attr = 2. : f8E4M3FNUZ
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
+    float_attr = 2. : f8E4M3B11FNUZ
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 594cc6620e396..e383a78f40b8a 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -202,6 +202,8 @@ def testFloatType():
     print("float:", Float8E5M2FNUZType.get())
     # CHECK: float: f8E4M3FNUZ
     print("float:", Float8E4M3FNUZType.get())
+    # CHECK: float: f8E4M3B11FNUZ
+    print("float:", Float8E4M3B11FNUZType.get())
     # CHECK: float: bf16
     print("float:", BF16Type.get())
     # CHECK: float: f16

diff  --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index 908a734f6e30d..bfd76a7d0ca28 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -54,6 +54,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::Float8E4M3FNType": '"f8E4M3FN"',
     "mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"',
     "mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
+    "mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"',
     "mlir::BFloat16Type": '"bf16"',
     "mlir::Float16Type": '"f16"',
     "mlir::Float32Type": '"f32"',


        


More information about the Mlir-commits mailing list