[Mlir-commits] [mlir] [mlir][IR] Turn `FloatType` into a type interface (PR #118891)
Matthias Springer
llvmlistbot at llvm.org
Mon Dec 16 01:43:41 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/118891
>From 1ebf66db89409d866f1535c1cfa9c62746822168 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 5 Dec 2024 23:37:45 +0100
Subject: [PATCH] [mlir][IR] Turn `FloatType` into a type interface
This makes it possible to add new floating-point types in downstream projects. It also removes one place where all existing floating-point types had to be hard-coded (`FloatType::classof`).
---
mlir/include/mlir/IR/BuiltinTypeInterfaces.h | 9 ++
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 59 ++++++++++
mlir/include/mlir/IR/BuiltinTypes.h | 56 ---------
mlir/include/mlir/IR/BuiltinTypes.td | 17 ++-
mlir/lib/IR/BuiltinTypeInterfaces.cpp | 13 +++
mlir/lib/IR/BuiltinTypes.cpp | 106 ++++++++----------
mlir/unittests/IR/InterfaceAttachmentTest.cpp | 2 +-
7 files changed, 138 insertions(+), 124 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index ed5e5ca22c5958..e8011b5488dc98 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -11,6 +11,15 @@
#include "mlir/IR/Types.h"
+namespace llvm {
+struct fltSemantics;
+} // namespace llvm
+
+namespace mlir {
+class FloatType;
+class MLIRContext;
+} // namespace mlir
+
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
#endif // MLIR_IR_BUILTINTYPEINTERFACES_H
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index c9dcd546cf67c2..c36b738e38f42a 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,6 +16,65 @@
include "mlir/IR/OpBase.td"
+def FloatTypeInterface : TypeInterface<"FloatType"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This type interface should be implemented by all floating-point types. It
+ defines the LLVM APFloat semantics and provides a few helper functions.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the APFloat semantics for this floating-point type.
+ }],
+ /*retTy=*/"const ::llvm::fltSemantics &",
+ /*methodName=*/"getFloatSemantics",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a float type with bitwidth scaled by `scale`. Returns a "null"
+ float type if the scaled element type cannot be represented.
+ }],
+ /*retTy=*/"::mlir::FloatType",
+ /*methodName=*/"scaleElementBitwidth",
+ /*args=*/(ins "unsigned":$scale),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return ::mlir::FloatType();"
+ >
+ ];
+
+ let extraClassDeclaration = [{
+ // Convenience factories.
+ static FloatType getBF16(MLIRContext *ctx);
+ static FloatType getF16(MLIRContext *ctx);
+ static FloatType getF32(MLIRContext *ctx);
+ static FloatType getTF32(MLIRContext *ctx);
+ static FloatType getF64(MLIRContext *ctx);
+ static FloatType getF80(MLIRContext *ctx);
+ static FloatType getF128(MLIRContext *ctx);
+ static FloatType getFloat8E5M2(MLIRContext *ctx);
+ static FloatType getFloat8E4M3(MLIRContext *ctx);
+ static FloatType getFloat8E4M3FN(MLIRContext *ctx);
+ static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E3M4(MLIRContext *ctx);
+ static FloatType getFloat4E2M1FN(MLIRContext *ctx);
+ static FloatType getFloat6E2M3FN(MLIRContext *ctx);
+ static FloatType getFloat6E3M2FN(MLIRContext *ctx);
+ static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
+
+ /// Return the bitwidth of this float type.
+ unsigned getWidth();
+
+ /// Return the width of the mantissa of this type.
+ /// The width includes the integer bit.
+ unsigned getFPMantissaWidth();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// MemRefElementTypeInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 7f9c470ffec304..2b3c2b6d1753dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -25,7 +25,6 @@ struct fltSemantics;
namespace mlir {
class AffineExpr;
class AffineMap;
-class FloatType;
class IndexType;
class IntegerType;
class MemRefType;
@@ -44,52 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
-//===----------------------------------------------------------------------===//
-// FloatType
-//===----------------------------------------------------------------------===//
-
-class FloatType : public Type {
-public:
- using Type::Type;
-
- // Convenience factories.
- static FloatType getBF16(MLIRContext *ctx);
- static FloatType getF16(MLIRContext *ctx);
- static FloatType getF32(MLIRContext *ctx);
- static FloatType getTF32(MLIRContext *ctx);
- static FloatType getF64(MLIRContext *ctx);
- static FloatType getF80(MLIRContext *ctx);
- static FloatType getF128(MLIRContext *ctx);
- static FloatType getFloat8E5M2(MLIRContext *ctx);
- static FloatType getFloat8E4M3(MLIRContext *ctx);
- static FloatType getFloat8E4M3FN(MLIRContext *ctx);
- static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E3M4(MLIRContext *ctx);
- static FloatType getFloat4E2M1FN(MLIRContext *ctx);
- static FloatType getFloat6E2M3FN(MLIRContext *ctx);
- static FloatType getFloat6E3M2FN(MLIRContext *ctx);
- static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Return the bitwidth of this float type.
- unsigned getWidth();
-
- /// Return the width of the mantissa of this type.
- /// The width includes the integer bit.
- unsigned getFPMantissaWidth();
-
- /// Get or create a new FloatType with bitwidth scaled by `scale`.
- /// Return null if the scaled element type cannot be represented.
- FloatType scaleElementBitwidth(unsigned scale);
-
- /// Return the floating semantics of this float type.
- const llvm::fltSemantics &getFloatSemantics();
-};
-
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
@@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
llvm::isa<MemRefElementTypeInterface>(type);
}
-inline bool FloatType::classof(Type type) {
- return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
- Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
- Float8E5M2FNUZType, Float8E4M3FNUZType,
- Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
- BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
- Float64Type, Float80Type, Float128Type>(type);
-}
-
inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
return Float4E2M1FNType::get(ctx);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dca228097d782d..fc50b28c09e41c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -79,8 +79,12 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
//===----------------------------------------------------------------------===//
// Base class for Builtin dialect float types.
-class Builtin_FloatType<string name, string mnemonic>
- : Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
+class Builtin_FloatType<string name, string mnemonic,
+ list<string> declaredInterfaceMethods = []>
+ : Builtin_Type<name, mnemonic, /*traits=*/[
+ DeclareTypeInterfaceMethods<
+ FloatTypeInterface,
+ ["getFloatSemantics"] # declaredInterfaceMethods>]> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
@@ -322,14 +326,16 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
//===----------------------------------------------------------------------===//
// BFloat16Type
-def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16"> {
+def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
+ /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "bfloat16 floating-point type";
}
//===----------------------------------------------------------------------===//
// Float16Type
-def Builtin_Float16 : Builtin_FloatType<"Float16", "f16"> {
+def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
+ /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "16-bit floating-point type";
}
@@ -343,7 +349,8 @@ def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
//===----------------------------------------------------------------------===//
// Float32Type
-def Builtin_Float32 : Builtin_FloatType<"Float32", "f32"> {
+def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
+ /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "32-bit floating-point type";
}
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed3..c663f6c9094604 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
@@ -19,6 +20,18 @@ using namespace mlir::detail;
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+unsigned FloatType::getWidth() {
+ return APFloat::semanticsSizeInBits(getFloatSemantics());
+}
+
+unsigned FloatType::getFPMantissaWidth() {
+ return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6546234429c8cb..41b794bc0aec59 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -87,72 +87,54 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
}
//===----------------------------------------------------------------------===//
-// Float Type
-//===----------------------------------------------------------------------===//
-
-unsigned FloatType::getWidth() {
- return APFloat::semanticsSizeInBits(getFloatSemantics());
-}
-
-/// Returns the floating semantics for the given type.
-const llvm::fltSemantics &FloatType::getFloatSemantics() {
- if (llvm::isa<Float4E2M1FNType>(*this))
- return APFloat::Float4E2M1FN();
- if (llvm::isa<Float6E2M3FNType>(*this))
- return APFloat::Float6E2M3FN();
- if (llvm::isa<Float6E3M2FNType>(*this))
- return APFloat::Float6E3M2FN();
- if (llvm::isa<Float8E5M2Type>(*this))
- return APFloat::Float8E5M2();
- if (llvm::isa<Float8E4M3Type>(*this))
- return APFloat::Float8E4M3();
- if (llvm::isa<Float8E4M3FNType>(*this))
- return APFloat::Float8E4M3FN();
- if (llvm::isa<Float8E5M2FNUZType>(*this))
- return APFloat::Float8E5M2FNUZ();
- if (llvm::isa<Float8E4M3FNUZType>(*this))
- return APFloat::Float8E4M3FNUZ();
- if (llvm::isa<Float8E4M3B11FNUZType>(*this))
- return APFloat::Float8E4M3B11FNUZ();
- if (llvm::isa<Float8E3M4Type>(*this))
- return APFloat::Float8E3M4();
- if (llvm::isa<Float8E8M0FNUType>(*this))
- return APFloat::Float8E8M0FNU();
- if (llvm::isa<BFloat16Type>(*this))
- return APFloat::BFloat();
- if (llvm::isa<Float16Type>(*this))
- return APFloat::IEEEhalf();
- if (llvm::isa<FloatTF32Type>(*this))
- return APFloat::FloatTF32();
- if (llvm::isa<Float32Type>(*this))
- return APFloat::IEEEsingle();
- if (llvm::isa<Float64Type>(*this))
- return APFloat::IEEEdouble();
- if (llvm::isa<Float80Type>(*this))
- return APFloat::x87DoubleExtended();
- if (llvm::isa<Float128Type>(*this))
- return APFloat::IEEEquad();
- llvm_unreachable("non-floating point type used");
-}
-
-FloatType FloatType::scaleElementBitwidth(unsigned scale) {
- if (!scale)
- return FloatType();
- MLIRContext *ctx = getContext();
- if (isF16() || isBF16()) {
- if (scale == 2)
- return FloatType::getF32(ctx);
- if (scale == 4)
- return FloatType::getF64(ctx);
+// Float Types
+//===----------------------------------------------------------------------===//
+
+// Mapping from MLIR FloatType to APFloat semantics.
+#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
+ const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
+ return APFloat::SEM(); \
}
- if (isF32())
- if (scale == 2)
- return FloatType::getF64(ctx);
+FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
+FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
+FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
+FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
+FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
+FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
+FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
+FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
+FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
+FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
+FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
+FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
+#undef FLOAT_TYPE_SEMANTICS
+
+FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
+ if (scale == 2)
+ return FloatType::getF32(getContext());
+ if (scale == 4)
+ return FloatType::getF64(getContext());
return FloatType();
}
-unsigned FloatType::getFPMantissaWidth() {
- return APFloat::semanticsPrecision(getFloatSemantics());
+FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
+ if (scale == 2)
+ return FloatType::getF32(getContext());
+ if (scale == 4)
+ return FloatType::getF64(getContext());
+ return FloatType();
+}
+
+FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
+ if (scale == 2)
+ return FloatType::getF64(getContext());
+ return FloatType();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index b6066dd5685dc6..1b5d3b8c31bd22 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -43,7 +43,7 @@ struct Model
/// overrides default methods.
struct OverridingModel
: public TestExternalTypeInterface::ExternalModel<OverridingModel,
- FloatType> {
+ Float32Type> {
unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
return type.getIntOrFloatBitWidth() + arg;
}
More information about the Mlir-commits
mailing list