[Mlir-commits] [mlir] 6685fd8 - [mlir] Add support for TF32 as a Builtin FloatType

Mehdi Amini llvmlistbot at llvm.org
Thu Jul 6 08:56:28 PDT 2023


Author: Jeremy Furtek
Date: 2023-07-06T08:56:07-07:00
New Revision: 6685fd82391d3e654d3b05f2d54cdcdec6e6d887

URL: https://github.com/llvm/llvm-project/commit/6685fd82391d3e654d3b05f2d54cdcdec6e6d887
DIFF: https://github.com/llvm/llvm-project/commit/6685fd82391d3e654d3b05f2d54cdcdec6e6d887.diff

LOG: [mlir] Add support for TF32 as a Builtin FloatType

This diff adds support for TF32 as an MLIR builtin floating-point type. This
supplements the recent addition of the TF32 semantics to the LLVM APFloat class
by extending their use to MLIR.

https://reviews.llvm.org/D151923

More information on the TF32 type can be found here:

https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D153705

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/Types.h
    mlir/lib/AsmParser/TokenKinds.def
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Types.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/IR/attribute.mlir
    mlir/test/python/ir/builtin_types.py
    mlir/utils/gdb-scripts/prettyprinters.py
    mlir/utils/lldb-scripts/mlirDataFormatters.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index c8ea44cd94faed..a6d8e10efbde92 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -163,6 +163,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a TF32 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void);
+
+/// Checks whether the given type is a TF32 type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type);
+
+/// Creates a TF32 type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx);
+
 //===----------------------------------------------------------------------===//
 // None type.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 4dbeb418099d1e..9be71b065c09bb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -67,6 +67,7 @@ class Builder {
   FloatType getFloat8E4M3B11FNUZType();
   FloatType getBF16Type();
   FloatType getF16Type();
+  FloatType getTF32Type();
   FloatType getF32Type();
   FloatType getF64Type();
   FloatType getF80Type();

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f22421aa7a428d..2fb8852cde94b4 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -44,6 +44,7 @@ class FloatType : public Type {
   static FloatType getBF16(MLIRContext *ctx);
   static FloatType getF16(MLIRContext *ctx);
   static FloatType getF32(MLIRContext *ctx);
+  static FloatType getTF32(MLIRContext *ctx);
   static FloatType getF64(MLIRContext *ctx);
   static FloatType getF80(MLIRContext *ctx);
   static FloatType getF128(MLIRContext *ctx);
@@ -417,8 +418,8 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 inline bool FloatType::classof(Type type) {
   return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
                    Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type,
-                   Float16Type, Float32Type, Float64Type, Float80Type,
-                   Float128Type>(type);
+                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
+                   Float80Type, Float128Type>(type);
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -449,6 +450,10 @@ inline FloatType FloatType::getF16(MLIRContext *ctx) {
   return Float16Type::get(ctx);
 }
 
+inline FloatType FloatType::getTF32(MLIRContext *ctx) {
+  return FloatTF32Type::get(ctx);
+}
+
 inline FloatType FloatType::getF32(MLIRContext *ctx) {
   return Float32Type::get(ctx);
 }

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 900531b1953c4b..75e85f9887f2f1 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -198,6 +198,13 @@ def Builtin_Float16 : Builtin_FloatType<"Float16"> {
   let summary = "16-bit floating-point type";
 }
 
+//===----------------------------------------------------------------------===//
+// FloatTF32Type
+
+def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32"> {
+  let summary = "TF32 floating-point type";
+}
+
 //===----------------------------------------------------------------------===//
 // Float32Type
 

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index fd5480ab57a3b0..40674988114f2e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -570,6 +570,8 @@ def F128 : F<128>;
 
 def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
            BuildableType<"$_builder.getBF16Type()">;
+def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
+           BuildableType<"$_builder.getTF32Type()">;
 def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
                BuildableType<"$_builder.getFloat8E4M3FNType()">;
 def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index ed08041964c8db..5c4e06da829d9c 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -127,6 +127,7 @@ class Type {
   bool isFloat8E4M3B11FNUZ() const;
   bool isBF16() const;
   bool isF16() const;
+  bool isTF32() const;
   bool isF32() const;
   bool isF64() const;
   bool isF80() const;

diff  --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 9a632e3570fb5e..1b5aa10e4ac1dc 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -117,6 +117,7 @@ TOK_KEYWORD(step)
 TOK_KEYWORD(strided)
 TOK_KEYWORD(symbol)
 TOK_KEYWORD(tensor)
+TOK_KEYWORD(tf32)
 TOK_KEYWORD(to)
 TOK_KEYWORD(true)
 TOK_KEYWORD(tuple)

diff  --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 6a65dda505a1c1..306e850af27bc5 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -38,6 +38,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_f8E4M3B11FNUZ:
   case Token::kw_bf16:
   case Token::kw_f16:
+  case Token::kw_tf32:
   case Token::kw_f32:
   case Token::kw_f64:
   case Token::kw_f80:
@@ -313,6 +314,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f16:
     consumeToken(Token::kw_f16);
     return builder.getF16Type();
+  case Token::kw_tf32:
+    consumeToken(Token::kw_tf32);
+    return builder.getTF32Type();
   case Token::kw_f32:
     consumeToken(Token::kw_f32);
     return builder.getF32Type();

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 25307262bddbd2..caf215be85baa8 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -247,6 +247,26 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
   }
 };
 
+/// Floating Point Type subclass - TF32Type.
+class PyTF32Type : public PyConcreteType<PyTF32Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatTF32TypeGetTypeID;
+  static constexpr const char *pyClassName = "FloatTF32Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirTF32TypeGet(context->get());
+          return PyTF32Type(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a tf32 type.");
+  }
+};
+
 /// Floating Point Type subclass - F32Type.
 class PyF32Type : public PyConcreteType<PyF32Type> {
 public:
@@ -754,6 +774,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyFloat8E5M2FNUZType::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
+  PyTF32Type::bind(m);
   PyF32Type::bind(m);
   PyF64Type::bind(m);
   PyNoneType::bind(m);

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 82c5b5a6147a4b..50266b4b523323 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -152,6 +152,16 @@ MlirType mlirF16TypeGet(MlirContext ctx) {
   return wrap(FloatType::getF16(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloatTF32TypeGetTypeID() {
+  return wrap(FloatTF32Type::getTypeID());
+}
+
+bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
+
+MlirType mlirTF32TypeGet(MlirContext ctx) {
+  return wrap(FloatType::getTF32(unwrap(ctx)));
+}
+
 MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
 
 bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index c0975a6be4be92..16ca7f8c6fc13e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2433,6 +2433,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
+      .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })
       .Case<Float64Type>([&](Type) { os << "f64"; })
       .Case<Float80Type>([&](Type) { os << "f80"; })

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 22abd521206804..35940b187cd2d1 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -58,6 +58,8 @@ FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
 
+FloatType Builder::getTF32Type() { return FloatType::getTF32(context); }
+
 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
 
 FloatType Builder::getF64Type() { return FloatType::getF64(context); }

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e29555f93e9407..b5ebaa01e61bb4 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -93,7 +93,7 @@ unsigned FloatType::getWidth() {
     return 8;
   if (llvm::isa<Float16Type, BFloat16Type>(*this))
     return 16;
-  if (llvm::isa<Float32Type>(*this))
+  if (llvm::isa<Float32Type, FloatTF32Type>(*this))
     return 32;
   if (llvm::isa<Float64Type>(*this))
     return 64;
@@ -120,6 +120,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::BFloat();
   if (llvm::isa<Float16Type>(*this))
     return APFloat::IEEEhalf();
+  if (llvm::isa<FloatTF32Type>(*this))
+    return APFloat::FloatTF32();
   if (llvm::isa<Float32Type>(*this))
     return APFloat::IEEEsingle();
   if (llvm::isa<Float64Type>(*this))

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index cc4b33f9ca6694..a79c7fb5052e08 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -219,6 +219,7 @@ class MLIRContextImpl {
   Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
+  FloatTF32Type tf32Ty;
   Float32Type f32Ty;
   Float64Type f64Ty;
   Float80Type f80Ty;
@@ -294,6 +295,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
+  impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
   impl->f64Ty = TypeUniquer::get<Float64Type>(this);
   impl->f80Ty = TypeUniquer::get<Float80Type>(this);
@@ -960,6 +962,9 @@ BFloat16Type BFloat16Type::get(MLIRContext *context) {
 Float16Type Float16Type::get(MLIRContext *context) {
   return context->getImpl().f16Ty;
 }
+FloatTF32Type FloatTF32Type::get(MLIRContext *context) {
+  return context->getImpl().tf32Ty;
+}
 Float32Type Float32Type::get(MLIRContext *context) {
   return context->getImpl().f32Ty;
 }

diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index e376a5fd33922f..32dfef9e810495 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -47,6 +47,7 @@ bool Type::isFloat8E4M3B11FNUZ() const {
 }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
+bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
 bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
 bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
 bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 714935fe12e288..23f4687d0394d9 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -56,6 +56,7 @@ __all__ = [
     "Float8E4M3B11FNUZType",
     "Float8E5M2FNUZType",
     "F16Type",
+    "FloatTF32Type",
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
@@ -627,6 +628,14 @@ class F16Type(Type):
     @staticmethod
     def isinstance(arg: Any) -> bool: ...
 
+# TODO: Auto-generated. Audit and fix.
+class FloatTF32Type(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> FloatTF32Type: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
 # TODO: Auto-generated. Audit and fix.
 class F32Type(Type):
     def __init__(self, cast_from_type: Type) -> None: ...

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 25d237a74f3ada..291d5832fce79a 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -64,6 +64,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : bf16
     float_attr = 2. : bf16
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : tf32
+    float_attr = 2. : tf32
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f32
     float_attr = 2. : f32

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 99273bab0b4957..51a311dec94419 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -212,6 +212,8 @@ def testFloatType():
         print("float:", BF16Type.get())
         # CHECK: float: f16
         print("float:", F16Type.get())
+        # CHECK: float: tf32
+        print("float:", FloatTF32Type.get())
         # CHECK: float: f32
         print("float:", F32Type.get())
         # CHECK: float: f64

diff  --git a/mlir/utils/gdb-scripts/prettyprinters.py b/mlir/utils/gdb-scripts/prettyprinters.py
index 9ea8bdbe86d773..45fd0837c9391e 100644
--- a/mlir/utils/gdb-scripts/prettyprinters.py
+++ b/mlir/utils/gdb-scripts/prettyprinters.py
@@ -166,6 +166,7 @@ def to_string(self):
     "IndexType",
     "IntegerType",
     "Float16Type",
+    "FloatTF32Type",
     "Float32Type",
     "Float64Type",
     "Float80Type",

diff  --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index 5d06b400334c8a..41f6227fe9de77 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -57,6 +57,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"',
     "mlir::BFloat16Type": '"bf16"',
     "mlir::Float16Type": '"f16"',
+    "mlir::FloatTF32Type": '"tf32"',
     "mlir::Float32Type": '"f32"',
     "mlir::Float64Type": '"f64"',
     "mlir::Float80Type": '"f80"',


        


More information about the Mlir-commits mailing list