[Mlir-commits] [mlir] 6685fd8 - [mlir] Add support for TF32 as a Builtin FloatType
Mehdi Amini
llvmlistbot at llvm.org
Thu Jul 6 08:56:28 PDT 2023
Author: Jeremy Furtek
Date: 2023-07-06T08:56:07-07:00
New Revision: 6685fd82391d3e654d3b05f2d54cdcdec6e6d887
URL: https://github.com/llvm/llvm-project/commit/6685fd82391d3e654d3b05f2d54cdcdec6e6d887
DIFF: https://github.com/llvm/llvm-project/commit/6685fd82391d3e654d3b05f2d54cdcdec6e6d887.diff
LOG: [mlir] Add support for TF32 as a Builtin FloatType
This diff adds support for TF32 as a Builtin floating-point type. This
supplements the recent addition of the TF32 semantics to the LLVM APFloat class
by extending its use to MLIR.
https://reviews.llvm.org/D151923
More information on the TF32 type can be found here:
https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D153705
Added:
Modified:
mlir/include/mlir-c/BuiltinTypes.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/Types.h
mlir/lib/AsmParser/TokenKinds.def
mlir/lib/AsmParser/TypeParser.cpp
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Types.cpp
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
mlir/test/IR/attribute.mlir
mlir/test/python/ir/builtin_types.py
mlir/utils/gdb-scripts/prettyprinters.py
mlir/utils/lldb-scripts/mlirDataFormatters.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index c8ea44cd94faed..a6d8e10efbde92 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -163,6 +163,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a TF32 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void);
+
+/// Checks whether the given type is a TF32 type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type);
+
+/// Creates a TF32 type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx);
+
 //===----------------------------------------------------------------------===//
 // None type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 4dbeb418099d1e..9be71b065c09bb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -67,6 +67,7 @@ class Builder {
FloatType getFloat8E4M3B11FNUZType();
FloatType getBF16Type();
FloatType getF16Type();
+ FloatType getTF32Type();
FloatType getF32Type();
FloatType getF64Type();
FloatType getF80Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f22421aa7a428d..2fb8852cde94b4 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -44,6 +44,7 @@ class FloatType : public Type {
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
+ static FloatType getTF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
@@ -417,8 +418,8 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
inline bool FloatType::classof(Type type) {
return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type,
- Float16Type, Float32Type, Float64Type, Float80Type,
- Float128Type>(type);
+ Float16Type, FloatTF32Type, Float32Type, Float64Type,
+ Float80Type, Float128Type>(type);
}
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -449,6 +450,10 @@ inline FloatType FloatType::getF16(MLIRContext *ctx) {
return Float16Type::get(ctx);
}
+inline FloatType FloatType::getTF32(MLIRContext *ctx) {
+ return FloatTF32Type::get(ctx);
+}
+
inline FloatType FloatType::getF32(MLIRContext *ctx) {
return Float32Type::get(ctx);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 900531b1953c4b..75e85f9887f2f1 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -198,6 +198,13 @@ def Builtin_Float16 : Builtin_FloatType<"Float16"> {
let summary = "16-bit floating-point type";
}
+//===----------------------------------------------------------------------===//
+// FloatTF32Type
+
+def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32"> {
+ let summary = "TF32 floating-point type";
+}
+
//===----------------------------------------------------------------------===//
// Float32Type
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index fd5480ab57a3b0..40674988114f2e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -570,6 +570,8 @@ def F128 : F<128>;
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"$_builder.getBF16Type()">;
+def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
+ BuildableType<"$_builder.getTF32Type()">;
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
BuildableType<"$_builder.getFloat8E4M3FNType()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index ed08041964c8db..5c4e06da829d9c 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -127,6 +127,7 @@ class Type {
bool isFloat8E4M3B11FNUZ() const;
bool isBF16() const;
bool isF16() const;
+ bool isTF32() const;
bool isF32() const;
bool isF64() const;
bool isF80() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 9a632e3570fb5e..1b5aa10e4ac1dc 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -117,6 +117,7 @@ TOK_KEYWORD(step)
TOK_KEYWORD(strided)
TOK_KEYWORD(symbol)
TOK_KEYWORD(tensor)
+TOK_KEYWORD(tf32)
TOK_KEYWORD(to)
TOK_KEYWORD(true)
TOK_KEYWORD(tuple)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 6a65dda505a1c1..306e850af27bc5 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -38,6 +38,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_f8E4M3B11FNUZ:
case Token::kw_bf16:
case Token::kw_f16:
+ case Token::kw_tf32:
case Token::kw_f32:
case Token::kw_f64:
case Token::kw_f80:
@@ -313,6 +314,9 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f16:
consumeToken(Token::kw_f16);
return builder.getF16Type();
+ case Token::kw_tf32:
+ consumeToken(Token::kw_tf32);
+ return builder.getTF32Type();
case Token::kw_f32:
consumeToken(Token::kw_f32);
return builder.getF32Type();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 25307262bddbd2..caf215be85baa8 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -247,6 +247,26 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
}
};
+/// Floating Point Type subclass - TF32Type.
+class PyTF32Type : public PyConcreteType<PyTF32Type> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloatTF32TypeGetTypeID;
+ static constexpr const char *pyClassName = "FloatTF32Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirTF32TypeGet(context->get());
+ return PyTF32Type(context->getRef(), t);
+ },
+ py::arg("context") = py::none(), "Create a tf32 type.");
+ }
+};
+
/// Floating Point Type subclass - F32Type.
class PyF32Type : public PyConcreteType<PyF32Type> {
public:
@@ -754,6 +774,7 @@ void mlir::python::populateIRTypes(py::module &m) {
PyFloat8E5M2FNUZType::bind(m);
PyBF16Type::bind(m);
PyF16Type::bind(m);
+ PyTF32Type::bind(m);
PyF32Type::bind(m);
PyF64Type::bind(m);
PyNoneType::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 82c5b5a6147a4b..50266b4b523323 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -152,6 +152,16 @@ MlirType mlirF16TypeGet(MlirContext ctx) {
return wrap(FloatType::getF16(unwrap(ctx)));
}
+MlirTypeID mlirFloatTF32TypeGetTypeID() {
+ return wrap(FloatTF32Type::getTypeID());
+}
+
+bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
+
+MlirType mlirTF32TypeGet(MlirContext ctx) {
+ return wrap(FloatType::getTF32(unwrap(ctx)));
+}
+
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index c0975a6be4be92..16ca7f8c6fc13e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2433,6 +2433,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
+ .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
.Case<Float32Type>([&](Type) { os << "f32"; })
.Case<Float64Type>([&](Type) { os << "f64"; })
.Case<Float80Type>([&](Type) { os << "f80"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 22abd521206804..35940b187cd2d1 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -58,6 +58,8 @@ FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
+FloatType Builder::getTF32Type() { return FloatType::getTF32(context); }
+
FloatType Builder::getF32Type() { return FloatType::getF32(context); }
FloatType Builder::getF64Type() { return FloatType::getF64(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e29555f93e9407..b5ebaa01e61bb4 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -93,7 +93,11 @@ unsigned FloatType::getWidth() {
     return 8;
   if (llvm::isa<Float16Type, BFloat16Type>(*this))
     return 16;
+  // TF32 is a 19-bit format (1 sign, 8 exponent, 10 mantissa bits). The width
+  // must match APFloat::FloatTF32()'s sizeInBits, not the 32-bit f32 width.
+  if (llvm::isa<FloatTF32Type>(*this))
+    return 19;
   if (llvm::isa<Float32Type>(*this))
     return 32;
   if (llvm::isa<Float64Type>(*this))
     return 64;
@@ -120,6 +124,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::BFloat();
   if (llvm::isa<Float16Type>(*this))
     return APFloat::IEEEhalf();
+  if (llvm::isa<FloatTF32Type>(*this))
+    return APFloat::FloatTF32();
   if (llvm::isa<Float32Type>(*this))
     return APFloat::IEEEsingle();
   if (llvm::isa<Float64Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index cc4b33f9ca6694..a79c7fb5052e08 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -219,6 +219,7 @@ class MLIRContextImpl {
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
+ FloatTF32Type tf32Ty;
Float32Type f32Ty;
Float64Type f64Ty;
Float80Type f80Ty;
@@ -294,6 +295,7 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting)
impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
+ impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
impl->f64Ty = TypeUniquer::get<Float64Type>(this);
impl->f80Ty = TypeUniquer::get<Float80Type>(this);
@@ -960,6 +962,9 @@ BFloat16Type BFloat16Type::get(MLIRContext *context) {
Float16Type Float16Type::get(MLIRContext *context) {
return context->getImpl().f16Ty;
}
+FloatTF32Type FloatTF32Type::get(MLIRContext *context) {
+ return context->getImpl().tf32Ty;
+}
Float32Type Float32Type::get(MLIRContext *context) {
return context->getImpl().f32Ty;
}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index e376a5fd33922f..32dfef9e810495 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -47,6 +47,7 @@ bool Type::isFloat8E4M3B11FNUZ() const {
}
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
+bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 714935fe12e288..23f4687d0394d9 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -56,6 +56,7 @@ __all__ = [
"Float8E4M3B11FNUZType",
"Float8E5M2FNUZType",
"F16Type",
+ "FloatTF32Type",
"F32Type",
"F64Type",
"FlatSymbolRefAttr",
@@ -627,6 +628,14 @@ class F16Type(Type):
@staticmethod
def isinstance(arg: Any) -> bool: ...
+# TODO: Auto-generated. Audit and fix.
+class FloatTF32Type(Type):
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @staticmethod
+ def get(*args, **kwargs) -> FloatTF32Type: ...
+ @staticmethod
+ def isinstance(arg: Any) -> bool: ...
+
# TODO: Auto-generated. Audit and fix.
class F32Type(Type):
def __init__(self, cast_from_type: Type) -> None: ...
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 25d237a74f3ada..291d5832fce79a 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -64,6 +64,10 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : bf16
float_attr = 2. : bf16
} : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 2.000000e+00 : tf32
+ float_attr = 2. : tf32
+ } : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f32
float_attr = 2. : f32
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 99273bab0b4957..51a311dec94419 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -212,6 +212,8 @@ def testFloatType():
print("float:", BF16Type.get())
# CHECK: float: f16
print("float:", F16Type.get())
+ # CHECK: float: tf32
+ print("float:", FloatTF32Type.get())
# CHECK: float: f32
print("float:", F32Type.get())
# CHECK: float: f64
diff --git a/mlir/utils/gdb-scripts/prettyprinters.py b/mlir/utils/gdb-scripts/prettyprinters.py
index 9ea8bdbe86d773..45fd0837c9391e 100644
--- a/mlir/utils/gdb-scripts/prettyprinters.py
+++ b/mlir/utils/gdb-scripts/prettyprinters.py
@@ -166,6 +166,7 @@ def to_string(self):
"IndexType",
"IntegerType",
"Float16Type",
+ "FloatTF32Type",
"Float32Type",
"Float64Type",
"Float80Type",
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index 5d06b400334c8a..41f6227fe9de77 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -57,6 +57,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
"mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"',
"mlir::BFloat16Type": '"bf16"',
"mlir::Float16Type": '"f16"',
+ "mlir::FloatTF32Type": '"tf32"',
"mlir::Float32Type": '"f32"',
"mlir::Float64Type": '"f64"',
"mlir::Float80Type": '"f80"',
More information about the Mlir-commits
mailing list