[Mlir-commits] [mlir] [mlir][IR] Turn `FloatType` into a type interface (PR #118891)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 5 16:48:43 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This makes it possible to add new MLIR floating-point types in downstream projects. (Adding new APFloat semantics in downstream projects is not possible yet, so parsing, printing, and converting float literals of newly added types are not supported.)

Also removes one place where we had to hard-code all existing floating-point types (`FloatType::classof`). See the discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361

No measurable change in compilation time for these lit tests:
```
Benchmark 1: mlir-opt ./mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir -split-input-file -convert-vector-to-llvm -o /dev/null
  BEFORE
  Time (mean ± σ):     248.4 ms ±   3.2 ms    [User: 237.0 ms, System: 20.1 ms]
  Range (min … max):   243.3 ms … 255.9 ms    30 runs

  AFTER
  Time (mean ± σ):     246.8 ms ±   3.2 ms    [User: 233.2 ms, System: 21.8 ms]
  Range (min … max):   240.2 ms … 252.1 ms    30 runs


Benchmark 2: mlir-opt ./mlir/test/Dialect/Arith/canonicalize.mlir -split-input-file -canonicalize -o /dev/null
  BEFORE
  Time (mean ± σ):      37.3 ms ±   1.8 ms    [User: 31.6 ms, System: 30.4 ms]
  Range (min … max):    34.6 ms …  42.0 ms    200 runs

  AFTER
  Time (mean ± σ):      37.5 ms ±   2.0 ms    [User: 31.5 ms, System: 29.2 ms]
  Range (min … max):    34.5 ms …  43.0 ms    200 runs


Benchmark 3: mlir-opt ./mlir/test/Dialect/Tensor/canonicalize.mlir -split-input-file -canonicalize -allow-unregistered-dialect -o /dev/null
  BEFORE
  Time (mean ± σ):     152.2 ms ±   2.5 ms    [User: 140.1 ms, System: 12.2 ms]
  Range (min … max):   147.6 ms … 161.8 ms    200 runs

  AFTER
  Time (mean ± σ):     151.9 ms ±   2.7 ms    [User: 140.5 ms, System: 11.5 ms]
  Range (min … max):   147.2 ms … 159.1 ms    200 runs
```


---
Full diff: https://github.com/llvm/llvm-project/pull/118891.diff


7 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.h (+9) 
- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+46) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (-56) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+3-1) 
- (modified) mlir/lib/IR/BuiltinTypeInterfaces.cpp (+29) 
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+26-66) 
- (modified) mlir/unittests/IR/InterfaceAttachmentTest.cpp (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index ed5e5ca22c5958..e8011b5488dc98 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -11,6 +11,15 @@
 
 #include "mlir/IR/Types.h"
 
+namespace llvm {
+struct fltSemantics;
+} // namespace llvm
+
+namespace mlir {
+class FloatType;
+class MLIRContext;
+} // namespace mlir
+
 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
 
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_H
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index c9dcd546cf67c2..8b0242672dfdb4 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,6 +16,52 @@
 
 include "mlir/IR/OpBase.td"
 
+def FloatTypeInterface : TypeInterface<"FloatType"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This type interface should be implemented by all floating-point types. It
+    defines the LLVM APFloat semantics and provides a few helper functions.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns the APFloat semantics for this floating-point type.
+    }], "const llvm::fltSemantics &", "getFloatSemantics", (ins)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // Convenience factories.
+    static FloatType getBF16(MLIRContext *ctx);
+    static FloatType getF16(MLIRContext *ctx);
+    static FloatType getF32(MLIRContext *ctx);
+    static FloatType getTF32(MLIRContext *ctx);
+    static FloatType getF64(MLIRContext *ctx);
+    static FloatType getF80(MLIRContext *ctx);
+    static FloatType getF128(MLIRContext *ctx);
+    static FloatType getFloat8E5M2(MLIRContext *ctx);
+    static FloatType getFloat8E4M3(MLIRContext *ctx);
+    static FloatType getFloat8E4M3FN(MLIRContext *ctx);
+    static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E3M4(MLIRContext *ctx);
+    static FloatType getFloat4E2M1FN(MLIRContext *ctx);
+    static FloatType getFloat6E2M3FN(MLIRContext *ctx);
+    static FloatType getFloat6E3M2FN(MLIRContext *ctx);
+    static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
+
+    /// Return the bitwidth of this float type.
+    unsigned getWidth();
+
+    /// Return the width of the mantissa of this type.
+    /// The width includes the integer bit.
+    unsigned getFPMantissaWidth();
+
+    /// Get or create a new FloatType with bitwidth scaled by `scale`.
+    /// Return null if the scaled element type cannot be represented.
+    FloatType scaleElementBitwidth(unsigned scale);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefElementTypeInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 7f9c470ffec304..2b3c2b6d1753dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -25,7 +25,6 @@ struct fltSemantics;
 namespace mlir {
 class AffineExpr;
 class AffineMap;
-class FloatType;
 class IndexType;
 class IntegerType;
 class MemRefType;
@@ -44,52 +43,6 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
-//===----------------------------------------------------------------------===//
-// FloatType
-//===----------------------------------------------------------------------===//
-
-class FloatType : public Type {
-public:
-  using Type::Type;
-
-  // Convenience factories.
-  static FloatType getBF16(MLIRContext *ctx);
-  static FloatType getF16(MLIRContext *ctx);
-  static FloatType getF32(MLIRContext *ctx);
-  static FloatType getTF32(MLIRContext *ctx);
-  static FloatType getF64(MLIRContext *ctx);
-  static FloatType getF80(MLIRContext *ctx);
-  static FloatType getF128(MLIRContext *ctx);
-  static FloatType getFloat8E5M2(MLIRContext *ctx);
-  static FloatType getFloat8E4M3(MLIRContext *ctx);
-  static FloatType getFloat8E4M3FN(MLIRContext *ctx);
-  static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E3M4(MLIRContext *ctx);
-  static FloatType getFloat4E2M1FN(MLIRContext *ctx);
-  static FloatType getFloat6E2M3FN(MLIRContext *ctx);
-  static FloatType getFloat6E3M2FN(MLIRContext *ctx);
-  static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Return the bitwidth of this float type.
-  unsigned getWidth();
-
-  /// Return the width of the mantissa of this type.
-  /// The width includes the integer bit.
-  unsigned getFPMantissaWidth();
-
-  /// Get or create a new FloatType with bitwidth scaled by `scale`.
-  /// Return null if the scaled element type cannot be represented.
-  FloatType scaleElementBitwidth(unsigned scale);
-
-  /// Return the floating semantics of this float type.
-  const llvm::fltSemantics &getFloatSemantics();
-};
-
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//
@@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
          llvm::isa<MemRefElementTypeInterface>(type);
 }
 
-inline bool FloatType::classof(Type type) {
-  return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
-                   Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
-                   Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
-                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
-                   Float64Type, Float80Type, Float128Type>(type);
-}
-
 inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
   return Float4E2M1FNType::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dca228097d782d..a0afda4e3b465e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -80,7 +80,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
 
 // Base class for Builtin dialect float types.
 class Builtin_FloatType<string name, string mnemonic>
-    : Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
+    : Builtin_Type<name, mnemonic, /*traits=*/[
+        DeclareTypeInterfaceMethods<FloatTypeInterface,
+                                    ["getFloatSemantics"]>]> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
   }];
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed3..1374e889833283 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 
 using namespace mlir;
@@ -19,6 +20,34 @@ using namespace mlir::detail;
 
 #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+unsigned FloatType::getWidth() {
+  return APFloat::semanticsSizeInBits(getFloatSemantics());
+}
+
+FloatType FloatType::scaleElementBitwidth(unsigned scale) {
+  if (!scale)
+    return FloatType();
+  MLIRContext *ctx = getContext();
+  if (isF16() || isBF16()) {
+    if (scale == 2)
+      return FloatType::getF32(ctx);
+    if (scale == 4)
+      return FloatType::getF64(ctx);
+  }
+  if (isF32())
+    if (scale == 2)
+      return FloatType::getF64(ctx);
+  return FloatType();
+}
+
+unsigned FloatType::getFPMantissaWidth() {
+  return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6546234429c8cb..81e154328a4a2e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -87,73 +87,33 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 }
 
 //===----------------------------------------------------------------------===//
-// Float Type
-//===----------------------------------------------------------------------===//
-
-unsigned FloatType::getWidth() {
-  return APFloat::semanticsSizeInBits(getFloatSemantics());
-}
-
-/// Returns the floating semantics for the given type.
-const llvm::fltSemantics &FloatType::getFloatSemantics() {
-  if (llvm::isa<Float4E2M1FNType>(*this))
-    return APFloat::Float4E2M1FN();
-  if (llvm::isa<Float6E2M3FNType>(*this))
-    return APFloat::Float6E2M3FN();
-  if (llvm::isa<Float6E3M2FNType>(*this))
-    return APFloat::Float6E3M2FN();
-  if (llvm::isa<Float8E5M2Type>(*this))
-    return APFloat::Float8E5M2();
-  if (llvm::isa<Float8E4M3Type>(*this))
-    return APFloat::Float8E4M3();
-  if (llvm::isa<Float8E4M3FNType>(*this))
-    return APFloat::Float8E4M3FN();
-  if (llvm::isa<Float8E5M2FNUZType>(*this))
-    return APFloat::Float8E5M2FNUZ();
-  if (llvm::isa<Float8E4M3FNUZType>(*this))
-    return APFloat::Float8E4M3FNUZ();
-  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
-    return APFloat::Float8E4M3B11FNUZ();
-  if (llvm::isa<Float8E3M4Type>(*this))
-    return APFloat::Float8E3M4();
-  if (llvm::isa<Float8E8M0FNUType>(*this))
-    return APFloat::Float8E8M0FNU();
-  if (llvm::isa<BFloat16Type>(*this))
-    return APFloat::BFloat();
-  if (llvm::isa<Float16Type>(*this))
-    return APFloat::IEEEhalf();
-  if (llvm::isa<FloatTF32Type>(*this))
-    return APFloat::FloatTF32();
-  if (llvm::isa<Float32Type>(*this))
-    return APFloat::IEEEsingle();
-  if (llvm::isa<Float64Type>(*this))
-    return APFloat::IEEEdouble();
-  if (llvm::isa<Float80Type>(*this))
-    return APFloat::x87DoubleExtended();
-  if (llvm::isa<Float128Type>(*this))
-    return APFloat::IEEEquad();
-  llvm_unreachable("non-floating point type used");
-}
-
-FloatType FloatType::scaleElementBitwidth(unsigned scale) {
-  if (!scale)
-    return FloatType();
-  MLIRContext *ctx = getContext();
-  if (isF16() || isBF16()) {
-    if (scale == 2)
-      return FloatType::getF32(ctx);
-    if (scale == 4)
-      return FloatType::getF64(ctx);
-  }
-  if (isF32())
-    if (scale == 2)
-      return FloatType::getF64(ctx);
-  return FloatType();
-}
+// Float Types
+//===----------------------------------------------------------------------===//
 
-unsigned FloatType::getFPMantissaWidth() {
-  return APFloat::semanticsPrecision(getFloatSemantics());
-}
+// Mapping from MLIR FloatType to APFloat semantics.
+#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)                                        \
+  const llvm::fltSemantics &TYPE::getFloatSemantics() const {                  \
+    return APFloat::SEM();                                                     \
+  }
+FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
+FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
+FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
+FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
+FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
+FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
+FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
+FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
+FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
+FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
+FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
+FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
+#undef FLOAT_TYPE_SEMANTICS
 
 //===----------------------------------------------------------------------===//
 // FunctionType
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index b6066dd5685dc6..1b5d3b8c31bd22 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -43,7 +43,7 @@ struct Model
 /// overrides default methods.
 struct OverridingModel
     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
-                                                      FloatType> {
+                                                      Float32Type> {
   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
     return type.getIntOrFloatBitWidth() + arg;
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/118891


More information about the Mlir-commits mailing list