[Mlir-commits] [mlir] f494346 - [mlir][IR] Remove builder API + caching for low-precision FP types (#123321)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 18 01:38:55 PST 2025


Author: Matthias Springer
Date: 2025-01-18T10:38:51+01:00
New Revision: f4943464d769e2eacd5c54dfaaf0468788abeb84

URL: https://github.com/llvm/llvm-project/commit/f4943464d769e2eacd5c54dfaaf0468788abeb84
DIFF: https://github.com/llvm/llvm-project/commit/f4943464d769e2eacd5c54dfaaf0468788abeb84.diff

LOG: [mlir][IR] Remove builder API + caching for low-precision FP types (#123321)

Remove the builder API (e.g., `b.getFloat4E2M1FNType()`) and the per-context
cached instances in `MLIRContext` for low-precision FP types. These types are
still uniqued (and therefore effectively cached) by the type uniquer.

For details, see:
https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28

Note for LLVM integration: Use `b.getType<Float4E2M1FNType>()` or
`Float4E2M1FNType::get(b.getContext())` instead of
`b.getFloat4E2M1FNType()`.

Added: 
    

Modified: 
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/CommonTypeConstraints.td
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/lib/Dialect/Arith/Utils/Utils.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/MLIRContext.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index daea2a23d6fbed..cd8d3ee0af72b0 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -61,17 +61,6 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
-  FloatType getFloat4E2M1FNType();
-  FloatType getFloat6E2M3FNType();
-  FloatType getFloat6E3M2FNType();
-  FloatType getFloat8E5M2Type();
-  FloatType getFloat8E4M3Type();
-  FloatType getFloat8E4M3FNType();
-  FloatType getFloat8E5M2FNUZType();
-  FloatType getFloat8E4M3FNUZType();
-  FloatType getFloat8E4M3B11FNUZType();
-  FloatType getFloat8E3M4Type();
-  FloatType getFloat8E8M0FNUType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index fc50b28c09e41c..4f09d2e41e7ceb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,12 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+}
+
+// Float types that are cached in MLIRContext.
+class Builtin_CachedFloatType<string name, string mnemonic,
+                              list<string> declaredInterfaceMethods = []>
+    : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
   }];
@@ -326,7 +332,7 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
-def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
+def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
     /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
   let summary = "bfloat16 floating-point type";
 }
@@ -334,7 +340,7 @@ def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
 //===----------------------------------------------------------------------===//
 // Float16Type
 
-def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
+def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16",
     /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
   let summary = "16-bit floating-point type";
 }
@@ -342,14 +348,14 @@ def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
 //===----------------------------------------------------------------------===//
 // FloatTF32Type
 
-def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
+def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> {
   let summary = "TF32 floating-point type";
 }
 
 //===----------------------------------------------------------------------===//
 // Float32Type
 
-def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
+def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32",
     /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
   let summary = "32-bit floating-point type";
 }
@@ -357,21 +363,21 @@ def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
 //===----------------------------------------------------------------------===//
 // Float64Type
 
-def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> {
+def Builtin_Float64 : Builtin_CachedFloatType<"Float64", "f64"> {
   let summary = "64-bit floating-point type";
 }
 
 //===----------------------------------------------------------------------===//
 // Float80Type
 
-def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> {
+def Builtin_Float80 : Builtin_CachedFloatType<"Float80", "f80"> {
   let summary = "80-bit floating-point type";
 }
 
 //===----------------------------------------------------------------------===//
 // Float128Type
 
-def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> {
+def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
   let summary = "128-bit floating-point type";
 }
 

diff  --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index b9f8c1ed19470d..6f52195c1d7c92 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -330,31 +330,31 @@ def F80 : F<80>;
 def F128 : F<128>;
 
 def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
-           BuildableType<"$_builder.getBF16Type()">;
+           BuildableType<"$_builder.getType<BFloat16Type>()">;
 def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
-           BuildableType<"$_builder.getTF32Type()">;
+           BuildableType<"$_builder.getType<FloatTF32Type>()">;
 def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
-               BuildableType<"$_builder.getFloat8E4M3FNType()">;
+               BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
 def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
-             BuildableType<"$_builder.getFloat8E5M2Type()">;
+             BuildableType<"$_builder.getType<Float8E5M2Type>()">;
 def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
-             BuildableType<"$_builder.getFloat8E4M3Type()">;
+             BuildableType<"$_builder.getType<Float8E4M3Type>()">;
 def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
-                 BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
+                 BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
 def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
-                 BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
+                 BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
 def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
-                 BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
+                 BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
 def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
-             BuildableType<"$_builder.getFloat8E3M4Type()">;
+             BuildableType<"$_builder.getType<Float8E3M4Type>()">;
 def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
-               BuildableType<"$_builder.getFloat4E2M1FNType()">;
+               BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
 def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
-               BuildableType<"$_builder.getFloat6E2M3FNType()">;
+               BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
 def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
-               BuildableType<"$_builder.getFloat6E3M2FNType()">;
+               BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
 def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
-                BuildableType<"$_builder.getFloat8E8M0FNUType()">;
+                BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
 
 def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                       "complex-type", "::mlir::ComplexType">;

diff  --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index c614eb39b364be..21bb0ec3d0d515 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -309,58 +309,58 @@ Type Parser::parseNonFunctionType() {
   // float-type
   case Token::kw_f4E2M1FN:
     consumeToken(Token::kw_f4E2M1FN);
-    return builder.getFloat4E2M1FNType();
+    return builder.getType<Float4E2M1FNType>();
   case Token::kw_f6E2M3FN:
     consumeToken(Token::kw_f6E2M3FN);
-    return builder.getFloat6E2M3FNType();
+    return builder.getType<Float6E2M3FNType>();
   case Token::kw_f6E3M2FN:
     consumeToken(Token::kw_f6E3M2FN);
-    return builder.getFloat6E3M2FNType();
+    return builder.getType<Float6E3M2FNType>();
   case Token::kw_f8E5M2:
     consumeToken(Token::kw_f8E5M2);
-    return builder.getFloat8E5M2Type();
+    return builder.getType<Float8E5M2Type>();
   case Token::kw_f8E4M3:
     consumeToken(Token::kw_f8E4M3);
-    return builder.getFloat8E4M3Type();
+    return builder.getType<Float8E4M3Type>();
   case Token::kw_f8E4M3FN:
     consumeToken(Token::kw_f8E4M3FN);
-    return builder.getFloat8E4M3FNType();
+    return builder.getType<Float8E4M3FNType>();
   case Token::kw_f8E5M2FNUZ:
     consumeToken(Token::kw_f8E5M2FNUZ);
-    return builder.getFloat8E5M2FNUZType();
+    return builder.getType<Float8E5M2FNUZType>();
   case Token::kw_f8E4M3FNUZ:
     consumeToken(Token::kw_f8E4M3FNUZ);
-    return builder.getFloat8E4M3FNUZType();
+    return builder.getType<Float8E4M3FNUZType>();
   case Token::kw_f8E4M3B11FNUZ:
     consumeToken(Token::kw_f8E4M3B11FNUZ);
-    return builder.getFloat8E4M3B11FNUZType();
+    return builder.getType<Float8E4M3B11FNUZType>();
   case Token::kw_f8E3M4:
     consumeToken(Token::kw_f8E3M4);
-    return builder.getFloat8E3M4Type();
+    return builder.getType<Float8E3M4Type>();
   case Token::kw_f8E8M0FNU:
     consumeToken(Token::kw_f8E8M0FNU);
-    return builder.getFloat8E8M0FNUType();
+    return builder.getType<Float8E8M0FNUType>();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
-    return builder.getBF16Type();
+    return builder.getType<BFloat16Type>();
   case Token::kw_f16:
     consumeToken(Token::kw_f16);
-    return builder.getF16Type();
+    return builder.getType<Float16Type>();
   case Token::kw_tf32:
     consumeToken(Token::kw_tf32);
-    return builder.getTF32Type();
+    return builder.getType<FloatTF32Type>();
   case Token::kw_f32:
     consumeToken(Token::kw_f32);
-    return builder.getF32Type();
+    return builder.getType<Float32Type>();
   case Token::kw_f64:
     consumeToken(Token::kw_f64);
-    return builder.getF64Type();
+    return builder.getType<Float64Type>();
   case Token::kw_f80:
     consumeToken(Token::kw_f80);
-    return builder.getF80Type();
+    return builder.getType<Float80Type>();
   case Token::kw_f128:
     consumeToken(Token::kw_f128);
-    return builder.getF128Type();
+    return builder.getType<Float128Type>();
 
   // index-type
   case Token::kw_index:

diff  --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 0fa7d321844113..39c9005e449e38 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -361,22 +361,22 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
 std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
   Builder b(ctx);
   return llvm::StringSwitch<std::optional<FloatType>>(name)
-      .Case("f4E2M1FN", b.getFloat4E2M1FNType())
-      .Case("f6E2M3FN", b.getFloat6E2M3FNType())
-      .Case("f6E3M2FN", b.getFloat6E3M2FNType())
-      .Case("f8E5M2", b.getFloat8E5M2Type())
-      .Case("f8E4M3", b.getFloat8E4M3Type())
-      .Case("f8E4M3FN", b.getFloat8E4M3FNType())
-      .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
-      .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
-      .Case("f8E3M4", b.getFloat8E3M4Type())
-      .Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
-      .Case("bf16", b.getBF16Type())
-      .Case("f16", b.getF16Type())
-      .Case("f32", b.getF32Type())
-      .Case("f64", b.getF64Type())
-      .Case("f80", b.getF80Type())
-      .Case("f128", b.getF128Type())
+      .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
+      .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
+      .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
+      .Case("f8E5M2", b.getType<Float8E5M2Type>())
+      .Case("f8E4M3", b.getType<Float8E4M3Type>())
+      .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
+      .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
+      .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
+      .Case("f8E3M4", b.getType<Float8E3M4Type>())
+      .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
+      .Case("bf16", b.getType<BFloat16Type>())
+      .Case("f16", b.getType<Float16Type>())
+      .Case("f32", b.getType<Float32Type>())
+      .Case("f64", b.getType<Float64Type>())
+      .Case("f80", b.getType<Float80Type>())
+      .Case("f128", b.getType<Float128Type>())
       .Default(std::nullopt);
 }
 

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8439b063f2634b..d57a7ca07ede58 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,44 +34,6 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
-FloatType Builder::getFloat4E2M1FNType() {
-  return Float4E2M1FNType::get(context);
-}
-
-FloatType Builder::getFloat6E2M3FNType() {
-  return Float6E2M3FNType::get(context);
-}
-
-FloatType Builder::getFloat6E3M2FNType() {
-  return Float6E3M2FNType::get(context);
-}
-
-FloatType Builder::getFloat8E5M2Type() { return Float8E5M2Type::get(context); }
-
-FloatType Builder::getFloat8E4M3Type() { return Float8E4M3Type::get(context); }
-
-FloatType Builder::getFloat8E4M3FNType() {
-  return Float8E4M3FNType::get(context);
-}
-
-FloatType Builder::getFloat8E5M2FNUZType() {
-  return Float8E5M2FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E4M3FNUZType() {
-  return Float8E4M3FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E4M3B11FNUZType() {
-  return Float8E4M3B11FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E3M4Type() { return Float8E3M4Type::get(context); }
-
-FloatType Builder::getFloat8E8M0FNUType() {
-  return Float8E8M0FNUType::get(context);
-}
-
 FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
 
 FloatType Builder::getF16Type() { return Float16Type::get(context); }

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b9e745fdf4a13e..87782e84dd6e4a 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -221,17 +221,6 @@ class MLIRContextImpl {
   llvm::DenseMap<StringRef, AbstractType *> nameToType;
 
   /// Cached Type Instances.
-  Float4E2M1FNType f4E2M1FNTy;
-  Float6E2M3FNType f6E2M3FNTy;
-  Float6E3M2FNType f6E3M2FNTy;
-  Float8E5M2Type f8E5M2Ty;
-  Float8E4M3Type f8E4M3Ty;
-  Float8E4M3FNType f8E4M3FNTy;
-  Float8E5M2FNUZType f8E5M2FNUZTy;
-  Float8E4M3FNUZType f8E4M3FNUZTy;
-  Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
-  Float8E3M4Type f8E3M4Ty;
-  Float8E8M0FNUType f8E8M0FNUTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   FloatTF32Type tf32Ty;
@@ -317,17 +306,6 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
 
   //// Types.
   /// Floating-point Types.
-  impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
-  impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
-  impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
-  impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
-  impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
-  impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
-  impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
-  impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
-  impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
-  impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
-  impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1044,39 +1022,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
-Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
-  return context->getImpl().f4E2M1FNTy;
-}
-Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
-  return context->getImpl().f6E2M3FNTy;
-}
-Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
-  return context->getImpl().f6E3M2FNTy;
-}
-Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
-  return context->getImpl().f8E5M2Ty;
-}
-Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
-  return context->getImpl().f8E4M3Ty;
-}
-Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
-  return context->getImpl().f8E4M3FNTy;
-}
-Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
-  return context->getImpl().f8E5M2FNUZTy;
-}
-Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
-  return context->getImpl().f8E4M3FNUZTy;
-}
-Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
-  return context->getImpl().f8E4M3B11FNUZTy;
-}
-Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
-  return context->getImpl().f8E3M4Ty;
-}
-Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
-  return context->getImpl().f8E8M0FNUTy;
-}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }


        


More information about the Mlir-commits mailing list