[Mlir-commits] [mlir] [mlir][IR] Turn `FloatType` into a type interface (PR #118891)

Matthias Springer llvmlistbot at llvm.org
Mon Dec 16 01:43:41 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/118891

>From 1ebf66db89409d866f1535c1cfa9c62746822168 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 5 Dec 2024 23:37:45 +0100
Subject: [PATCH] [mlir][IR] Turn `FloatType` into a type interface

This makes it possible to add new floating-point types in downstream projects. It also removes one place where all existing floating-point types had to be hard-coded (`FloatType::classof`).
---
 mlir/include/mlir/IR/BuiltinTypeInterfaces.h  |   9 ++
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td |  59 ++++++++++
 mlir/include/mlir/IR/BuiltinTypes.h           |  56 ---------
 mlir/include/mlir/IR/BuiltinTypes.td          |  17 ++-
 mlir/lib/IR/BuiltinTypeInterfaces.cpp         |  13 +++
 mlir/lib/IR/BuiltinTypes.cpp                  | 106 ++++++++----------
 mlir/unittests/IR/InterfaceAttachmentTest.cpp |   2 +-
 7 files changed, 138 insertions(+), 124 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index ed5e5ca22c5958..e8011b5488dc98 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -11,6 +11,15 @@
 
 #include "mlir/IR/Types.h"
 
+namespace llvm {
+struct fltSemantics;
+} // namespace llvm
+
+namespace mlir {
+class FloatType;
+class MLIRContext;
+} // namespace mlir
+
 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
 
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_H
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index c9dcd546cf67c2..c36b738e38f42a 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,6 +16,65 @@
 
 include "mlir/IR/OpBase.td"
 
+def FloatTypeInterface : TypeInterface<"FloatType"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This type interface should be implemented by all floating-point types. It
+    defines the LLVM APFloat semantics and provides a few helper functions.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the APFloat semantics for this floating-point type.
+      }],
+      /*retTy=*/"const ::llvm::fltSemantics &",
+      /*methodName=*/"getFloatSemantics",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns a float type with bitwidth scaled by `scale`. Returns a "null"
+        float type if the scaled element type cannot be represented.
+      }],
+      /*retTy=*/"::mlir::FloatType",
+      /*methodName=*/"scaleElementBitwidth",
+      /*args=*/(ins "unsigned":$scale),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return ::mlir::FloatType();"
+    >
+  ];
+
+  let extraClassDeclaration = [{
+    // Convenience factories.
+    static FloatType getBF16(MLIRContext *ctx);
+    static FloatType getF16(MLIRContext *ctx);
+    static FloatType getF32(MLIRContext *ctx);
+    static FloatType getTF32(MLIRContext *ctx);
+    static FloatType getF64(MLIRContext *ctx);
+    static FloatType getF80(MLIRContext *ctx);
+    static FloatType getF128(MLIRContext *ctx);
+    static FloatType getFloat8E5M2(MLIRContext *ctx);
+    static FloatType getFloat8E4M3(MLIRContext *ctx);
+    static FloatType getFloat8E4M3FN(MLIRContext *ctx);
+    static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E3M4(MLIRContext *ctx);
+    static FloatType getFloat4E2M1FN(MLIRContext *ctx);
+    static FloatType getFloat6E2M3FN(MLIRContext *ctx);
+    static FloatType getFloat6E3M2FN(MLIRContext *ctx);
+    static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
+
+    /// Return the bitwidth of this float type.
+    unsigned getWidth();
+
+    /// Return the width of the mantissa of this type.
+    /// The width includes the integer bit.
+    unsigned getFPMantissaWidth();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefElementTypeInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 7f9c470ffec304..2b3c2b6d1753dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -25,7 +25,6 @@ struct fltSemantics;
 namespace mlir {
 class AffineExpr;
 class AffineMap;
-class FloatType;
 class IndexType;
 class IntegerType;
 class MemRefType;
@@ -44,52 +43,6 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
-//===----------------------------------------------------------------------===//
-// FloatType
-//===----------------------------------------------------------------------===//
-
-class FloatType : public Type {
-public:
-  using Type::Type;
-
-  // Convenience factories.
-  static FloatType getBF16(MLIRContext *ctx);
-  static FloatType getF16(MLIRContext *ctx);
-  static FloatType getF32(MLIRContext *ctx);
-  static FloatType getTF32(MLIRContext *ctx);
-  static FloatType getF64(MLIRContext *ctx);
-  static FloatType getF80(MLIRContext *ctx);
-  static FloatType getF128(MLIRContext *ctx);
-  static FloatType getFloat8E5M2(MLIRContext *ctx);
-  static FloatType getFloat8E4M3(MLIRContext *ctx);
-  static FloatType getFloat8E4M3FN(MLIRContext *ctx);
-  static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E3M4(MLIRContext *ctx);
-  static FloatType getFloat4E2M1FN(MLIRContext *ctx);
-  static FloatType getFloat6E2M3FN(MLIRContext *ctx);
-  static FloatType getFloat6E3M2FN(MLIRContext *ctx);
-  static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Return the bitwidth of this float type.
-  unsigned getWidth();
-
-  /// Return the width of the mantissa of this type.
-  /// The width includes the integer bit.
-  unsigned getFPMantissaWidth();
-
-  /// Get or create a new FloatType with bitwidth scaled by `scale`.
-  /// Return null if the scaled element type cannot be represented.
-  FloatType scaleElementBitwidth(unsigned scale);
-
-  /// Return the floating semantics of this float type.
-  const llvm::fltSemantics &getFloatSemantics();
-};
-
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//
@@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
          llvm::isa<MemRefElementTypeInterface>(type);
 }
 
-inline bool FloatType::classof(Type type) {
-  return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
-                   Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
-                   Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
-                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
-                   Float64Type, Float80Type, Float128Type>(type);
-}
-
 inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
   return Float4E2M1FNType::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dca228097d782d..fc50b28c09e41c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -79,8 +79,12 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
 //===----------------------------------------------------------------------===//
 
 // Base class for Builtin dialect float types.
-class Builtin_FloatType<string name, string mnemonic>
-    : Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
+class Builtin_FloatType<string name, string mnemonic,
+                        list<string> declaredInterfaceMethods = []>
+    : Builtin_Type<name, mnemonic, /*traits=*/[
+        DeclareTypeInterfaceMethods<
+            FloatTypeInterface,
+            ["getFloatSemantics"] # declaredInterfaceMethods>]> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
   }];
@@ -322,14 +326,16 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
-def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16"> {
+def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
+    /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
   let summary = "bfloat16 floating-point type";
 }
 
 //===----------------------------------------------------------------------===//
 // Float16Type
 
-def Builtin_Float16 : Builtin_FloatType<"Float16", "f16"> {
+def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
+    /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
   let summary = "16-bit floating-point type";
 }
 
@@ -343,7 +349,8 @@ def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
 //===----------------------------------------------------------------------===//
 // Float32Type
 
-def Builtin_Float32 : Builtin_FloatType<"Float32", "f32"> {
+def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
+    /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
   let summary = "32-bit floating-point type";
 }
 
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed3..c663f6c9094604 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 
 using namespace mlir;
@@ -19,6 +20,18 @@ using namespace mlir::detail;
 
 #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+unsigned FloatType::getWidth() {
+  return APFloat::semanticsSizeInBits(getFloatSemantics());
+}
+
+unsigned FloatType::getFPMantissaWidth() {
+  return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6546234429c8cb..41b794bc0aec59 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -87,72 +87,54 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 }
 
 //===----------------------------------------------------------------------===//
-// Float Type
-//===----------------------------------------------------------------------===//
-
-unsigned FloatType::getWidth() {
-  return APFloat::semanticsSizeInBits(getFloatSemantics());
-}
-
-/// Returns the floating semantics for the given type.
-const llvm::fltSemantics &FloatType::getFloatSemantics() {
-  if (llvm::isa<Float4E2M1FNType>(*this))
-    return APFloat::Float4E2M1FN();
-  if (llvm::isa<Float6E2M3FNType>(*this))
-    return APFloat::Float6E2M3FN();
-  if (llvm::isa<Float6E3M2FNType>(*this))
-    return APFloat::Float6E3M2FN();
-  if (llvm::isa<Float8E5M2Type>(*this))
-    return APFloat::Float8E5M2();
-  if (llvm::isa<Float8E4M3Type>(*this))
-    return APFloat::Float8E4M3();
-  if (llvm::isa<Float8E4M3FNType>(*this))
-    return APFloat::Float8E4M3FN();
-  if (llvm::isa<Float8E5M2FNUZType>(*this))
-    return APFloat::Float8E5M2FNUZ();
-  if (llvm::isa<Float8E4M3FNUZType>(*this))
-    return APFloat::Float8E4M3FNUZ();
-  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
-    return APFloat::Float8E4M3B11FNUZ();
-  if (llvm::isa<Float8E3M4Type>(*this))
-    return APFloat::Float8E3M4();
-  if (llvm::isa<Float8E8M0FNUType>(*this))
-    return APFloat::Float8E8M0FNU();
-  if (llvm::isa<BFloat16Type>(*this))
-    return APFloat::BFloat();
-  if (llvm::isa<Float16Type>(*this))
-    return APFloat::IEEEhalf();
-  if (llvm::isa<FloatTF32Type>(*this))
-    return APFloat::FloatTF32();
-  if (llvm::isa<Float32Type>(*this))
-    return APFloat::IEEEsingle();
-  if (llvm::isa<Float64Type>(*this))
-    return APFloat::IEEEdouble();
-  if (llvm::isa<Float80Type>(*this))
-    return APFloat::x87DoubleExtended();
-  if (llvm::isa<Float128Type>(*this))
-    return APFloat::IEEEquad();
-  llvm_unreachable("non-floating point type used");
-}
-
-FloatType FloatType::scaleElementBitwidth(unsigned scale) {
-  if (!scale)
-    return FloatType();
-  MLIRContext *ctx = getContext();
-  if (isF16() || isBF16()) {
-    if (scale == 2)
-      return FloatType::getF32(ctx);
-    if (scale == 4)
-      return FloatType::getF64(ctx);
+// Float Types
+//===----------------------------------------------------------------------===//
+
+// Mapping from MLIR FloatType to APFloat semantics.
+#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)                                        \
+  const llvm::fltSemantics &TYPE::getFloatSemantics() const {                  \
+    return APFloat::SEM();                                                     \
   }
-  if (isF32())
-    if (scale == 2)
-      return FloatType::getF64(ctx);
+FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
+FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
+FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
+FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
+FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
+FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
+FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
+FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
+FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
+FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
+FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
+FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
+#undef FLOAT_TYPE_SEMANTICS
+
+FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
+  if (scale == 2)
+    return FloatType::getF32(getContext());
+  if (scale == 4)
+    return FloatType::getF64(getContext());
   return FloatType();
 }
 
-unsigned FloatType::getFPMantissaWidth() {
-  return APFloat::semanticsPrecision(getFloatSemantics());
+FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
+  if (scale == 2)
+    return FloatType::getF32(getContext());
+  if (scale == 4)
+    return FloatType::getF64(getContext());
+  return FloatType();
+}
+
+FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
+  if (scale == 2)
+    return FloatType::getF64(getContext());
+  return FloatType();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index b6066dd5685dc6..1b5d3b8c31bd22 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -43,7 +43,7 @@ struct Model
 /// overrides default methods.
 struct OverridingModel
     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
-                                                      FloatType> {
+                                                      Float32Type> {
   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
     return type.getIntOrFloatBitWidth() + arg;
   }



More information about the Mlir-commits mailing list