[Mlir-commits] [mlir] [mlir][IR] Turn `FloatType` into a type interface (PR #118891)
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 5 16:42:45 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/118891
>From c996d3fac3c3a4a0a7c3615c39101a4bcf31c8fa Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 5 Dec 2024 23:37:45 +0100
Subject: [PATCH] [mlir][IR] Turn `FloatType` into a type interface
This makes it possible to add new floating-point types in downstream projects. It also removes one place where all existing floating-point types had to be hard-coded (`FloatType::classof`).
---
mlir/include/mlir/IR/BuiltinTypeInterfaces.h | 9 ++
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 46 ++++++++++
mlir/include/mlir/IR/BuiltinTypes.h | 56 -----------
mlir/include/mlir/IR/BuiltinTypes.td | 4 +-
mlir/lib/IR/BuiltinTypeInterfaces.cpp | 29 ++++++
mlir/lib/IR/BuiltinTypes.cpp | 92 ++++++-------------
mlir/unittests/IR/InterfaceAttachmentTest.cpp | 2 +-
7 files changed, 114 insertions(+), 124 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index ed5e5ca22c5958..e8011b5488dc98 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -11,6 +11,15 @@
#include "mlir/IR/Types.h"
+namespace llvm {
+struct fltSemantics; // Return type of the getFloatSemantics() interface method.
+} // namespace llvm
+
+namespace mlir {
+class FloatType;   // The interface class defined by the .h.inc included below.
+class MLIRContext; // Parameter type of the FloatType convenience factories.
+} // namespace mlir
+
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
#endif // MLIR_IR_BUILTINTYPEINTERFACES_H
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index c9dcd546cf67c2..8b0242672dfdb4 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,6 +16,52 @@
include "mlir/IR/OpBase.td"
+def FloatTypeInterface : TypeInterface<"FloatType"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This type interface must be implemented by every floating-point type. It
+ exposes the type's LLVM APFloat semantics and provides a few helper functions.
+ }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns the APFloat semantics for this floating-point type.
+ }], "const llvm::fltSemantics &", "getFloatSemantics", (ins)>,
+ ];
+
+ let extraClassDeclaration = [{
+ // Convenience factories for the builtin floating-point types.
+ static FloatType getBF16(MLIRContext *ctx);
+ static FloatType getF16(MLIRContext *ctx);
+ static FloatType getF32(MLIRContext *ctx);
+ static FloatType getTF32(MLIRContext *ctx);
+ static FloatType getF64(MLIRContext *ctx);
+ static FloatType getF80(MLIRContext *ctx);
+ static FloatType getF128(MLIRContext *ctx);
+ static FloatType getFloat8E5M2(MLIRContext *ctx);
+ static FloatType getFloat8E4M3(MLIRContext *ctx);
+ static FloatType getFloat8E4M3FN(MLIRContext *ctx);
+ static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E3M4(MLIRContext *ctx);
+ static FloatType getFloat4E2M1FN(MLIRContext *ctx);
+ static FloatType getFloat6E2M3FN(MLIRContext *ctx);
+ static FloatType getFloat6E3M2FN(MLIRContext *ctx);
+ static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
+
+ /// Return the bitwidth of this float type.
+ unsigned getWidth();
+
+ /// Return the width of the mantissa of this type.
+ /// The width includes the integer bit.
+ unsigned getFPMantissaWidth();
+
+ /// Get or create a new FloatType with bitwidth scaled by `scale`.
+ /// Return null if the scaled element type cannot be represented.
+ FloatType scaleElementBitwidth(unsigned scale);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// MemRefElementTypeInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 7f9c470ffec304..2b3c2b6d1753dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -25,7 +25,6 @@ struct fltSemantics;
namespace mlir {
class AffineExpr;
class AffineMap;
-class FloatType;
class IndexType;
class IntegerType;
class MemRefType;
@@ -44,52 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
-//===----------------------------------------------------------------------===//
-// FloatType
-//===----------------------------------------------------------------------===//
-
-class FloatType : public Type {
-public:
- using Type::Type;
-
- // Convenience factories.
- static FloatType getBF16(MLIRContext *ctx);
- static FloatType getF16(MLIRContext *ctx);
- static FloatType getF32(MLIRContext *ctx);
- static FloatType getTF32(MLIRContext *ctx);
- static FloatType getF64(MLIRContext *ctx);
- static FloatType getF80(MLIRContext *ctx);
- static FloatType getF128(MLIRContext *ctx);
- static FloatType getFloat8E5M2(MLIRContext *ctx);
- static FloatType getFloat8E4M3(MLIRContext *ctx);
- static FloatType getFloat8E4M3FN(MLIRContext *ctx);
- static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E3M4(MLIRContext *ctx);
- static FloatType getFloat4E2M1FN(MLIRContext *ctx);
- static FloatType getFloat6E2M3FN(MLIRContext *ctx);
- static FloatType getFloat6E3M2FN(MLIRContext *ctx);
- static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Return the bitwidth of this float type.
- unsigned getWidth();
-
- /// Return the width of the mantissa of this type.
- /// The width includes the integer bit.
- unsigned getFPMantissaWidth();
-
- /// Get or create a new FloatType with bitwidth scaled by `scale`.
- /// Return null if the scaled element type cannot be represented.
- FloatType scaleElementBitwidth(unsigned scale);
-
- /// Return the floating semantics of this float type.
- const llvm::fltSemantics &getFloatSemantics();
-};
-
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
@@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
llvm::isa<MemRefElementTypeInterface>(type);
}
-inline bool FloatType::classof(Type type) {
- return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
- Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
- Float8E5M2FNUZType, Float8E4M3FNUZType,
- Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
- BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
- Float64Type, Float80Type, Float128Type>(type);
-}
-
inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
return Float4E2M1FNType::get(ctx);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dca228097d782d..a0afda4e3b465e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -80,7 +80,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
// Base class for Builtin dialect float types.
class Builtin_FloatType<string name, string mnemonic>
- : Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
+ : Builtin_Type<name, mnemonic, /*traits=*/[
+ DeclareTypeInterfaceMethods<FloatTypeInterface,
+ ["getFloatSemantics"]>]> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed3..1374e889833283 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
@@ -19,6 +20,34 @@ using namespace mlir::detail;
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+unsigned FloatType::getWidth() {
+ return APFloat::semanticsSizeInBits(getFloatSemantics());
+}
+
+FloatType FloatType::scaleElementBitwidth(unsigned scale) {
+ if (!scale)
+ return FloatType();
+ MLIRContext *ctx = getContext();
+ if (isF16() || isBF16()) {
+ if (scale == 2)
+ return FloatType::getF32(ctx);
+ if (scale == 4)
+ return FloatType::getF64(ctx);
+ }
+ // Scaling f32 by 2 yields f64; every other combination is unrepresentable.
+ if (isF32() && scale == 2)
+ return FloatType::getF64(ctx);
+ return FloatType();
+}
+
+unsigned FloatType::getFPMantissaWidth() {
+ return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6546234429c8cb..81e154328a4a2e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -87,73 +87,33 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
}
//===----------------------------------------------------------------------===//
-// Float Type
-//===----------------------------------------------------------------------===//
-
-unsigned FloatType::getWidth() {
- return APFloat::semanticsSizeInBits(getFloatSemantics());
-}
-
-/// Returns the floating semantics for the given type.
-const llvm::fltSemantics &FloatType::getFloatSemantics() {
- if (llvm::isa<Float4E2M1FNType>(*this))
- return APFloat::Float4E2M1FN();
- if (llvm::isa<Float6E2M3FNType>(*this))
- return APFloat::Float6E2M3FN();
- if (llvm::isa<Float6E3M2FNType>(*this))
- return APFloat::Float6E3M2FN();
- if (llvm::isa<Float8E5M2Type>(*this))
- return APFloat::Float8E5M2();
- if (llvm::isa<Float8E4M3Type>(*this))
- return APFloat::Float8E4M3();
- if (llvm::isa<Float8E4M3FNType>(*this))
- return APFloat::Float8E4M3FN();
- if (llvm::isa<Float8E5M2FNUZType>(*this))
- return APFloat::Float8E5M2FNUZ();
- if (llvm::isa<Float8E4M3FNUZType>(*this))
- return APFloat::Float8E4M3FNUZ();
- if (llvm::isa<Float8E4M3B11FNUZType>(*this))
- return APFloat::Float8E4M3B11FNUZ();
- if (llvm::isa<Float8E3M4Type>(*this))
- return APFloat::Float8E3M4();
- if (llvm::isa<Float8E8M0FNUType>(*this))
- return APFloat::Float8E8M0FNU();
- if (llvm::isa<BFloat16Type>(*this))
- return APFloat::BFloat();
- if (llvm::isa<Float16Type>(*this))
- return APFloat::IEEEhalf();
- if (llvm::isa<FloatTF32Type>(*this))
- return APFloat::FloatTF32();
- if (llvm::isa<Float32Type>(*this))
- return APFloat::IEEEsingle();
- if (llvm::isa<Float64Type>(*this))
- return APFloat::IEEEdouble();
- if (llvm::isa<Float80Type>(*this))
- return APFloat::x87DoubleExtended();
- if (llvm::isa<Float128Type>(*this))
- return APFloat::IEEEquad();
- llvm_unreachable("non-floating point type used");
-}
-
-FloatType FloatType::scaleElementBitwidth(unsigned scale) {
- if (!scale)
- return FloatType();
- MLIRContext *ctx = getContext();
- if (isF16() || isBF16()) {
- if (scale == 2)
- return FloatType::getF32(ctx);
- if (scale == 4)
- return FloatType::getF64(ctx);
- }
- if (isF32())
- if (scale == 2)
- return FloatType::getF64(ctx);
- return FloatType();
-}
+// Float Types
+//===----------------------------------------------------------------------===//
-unsigned FloatType::getFPMantissaWidth() {
- return APFloat::semanticsPrecision(getFloatSemantics());
-}
+// Define getFloatSemantics() for each builtin float type by returning its APFloat semantics.
+#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
+ const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
+ return APFloat::SEM(); \
+ }
+FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
+FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
+FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
+FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
+FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
+FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
+FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
+FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
+FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
+FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
+FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
+FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
+#undef FLOAT_TYPE_SEMANTICS
//===----------------------------------------------------------------------===//
// FunctionType
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index b6066dd5685dc6..1b5d3b8c31bd22 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -43,7 +43,7 @@ struct Model
/// overrides default methods.
struct OverridingModel
: public TestExternalTypeInterface::ExternalModel<OverridingModel,
- FloatType> {
+ Float32Type> {
unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
return type.getIntOrFloatBitWidth() + arg;
}
More information about the Mlir-commits
mailing list