[Mlir-commits] [mlir] [mlir][IR] Remove builder API + caching for low-precision FP types (PR #123321)
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 17 09:10:52 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/123321
>From 2e1833f4db4df710abd2e402a6453659803f6fe7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 17 Jan 2025 12:03:57 +0100
Subject: [PATCH] [mlir][IR] Remove builder API + caching for low-precision FP
types
---
mlir/include/mlir/IR/Builders.h | 11 ----
mlir/include/mlir/IR/BuiltinTypes.td | 20 ++++---
mlir/include/mlir/IR/CommonTypeConstraints.td | 26 ++++-----
mlir/lib/AsmParser/TypeParser.cpp | 36 ++++++------
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 32 +++++------
mlir/lib/IR/Builders.cpp | 38 -------------
mlir/lib/IR/MLIRContext.cpp | 55 -------------------
7 files changed, 60 insertions(+), 158 deletions(-)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index daea2a23d6fbed..cd8d3ee0af72b0 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -61,17 +61,6 @@ class Builder {
Attribute metadata = Attribute());
// Types.
- FloatType getFloat4E2M1FNType();
- FloatType getFloat6E2M3FNType();
- FloatType getFloat6E3M2FNType();
- FloatType getFloat8E5M2Type();
- FloatType getFloat8E4M3Type();
- FloatType getFloat8E4M3FNType();
- FloatType getFloat8E5M2FNUZType();
- FloatType getFloat8E4M3FNUZType();
- FloatType getFloat8E4M3B11FNUZType();
- FloatType getFloat8E3M4Type();
- FloatType getFloat8E8M0FNUType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index fc50b28c09e41c..4f09d2e41e7ceb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,12 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+}
+
+// Float types that are cached in MLIRContext.
+class Builtin_CachedFloatType<string name, string mnemonic,
+ list<string> declaredInterfaceMethods = []>
+ : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
@@ -326,7 +332,7 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
//===----------------------------------------------------------------------===//
// BFloat16Type
-def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
+def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "bfloat16 floating-point type";
}
@@ -334,7 +340,7 @@ def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
//===----------------------------------------------------------------------===//
// Float16Type
-def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
+def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "16-bit floating-point type";
}
@@ -342,14 +348,14 @@ def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
//===----------------------------------------------------------------------===//
// FloatTF32Type
-def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
+def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> {
let summary = "TF32 floating-point type";
}
//===----------------------------------------------------------------------===//
// Float32Type
-def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
+def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "32-bit floating-point type";
}
@@ -357,21 +363,21 @@ def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
//===----------------------------------------------------------------------===//
// Float64Type
-def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> {
+def Builtin_Float64 : Builtin_CachedFloatType<"Float64", "f64"> {
let summary = "64-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float80Type
-def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> {
+def Builtin_Float80 : Builtin_CachedFloatType<"Float80", "f80"> {
let summary = "80-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float128Type
-def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> {
+def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
let summary = "128-bit floating-point type";
}
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index b9f8c1ed19470d..6f52195c1d7c92 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -330,31 +330,31 @@ def F80 : F<80>;
def F128 : F<128>;
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
- BuildableType<"$_builder.getBF16Type()">;
+ BuildableType<"$_builder.getType<BFloat16Type>()">;
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
- BuildableType<"$_builder.getTF32Type()">;
+ BuildableType<"$_builder.getType<FloatTF32Type>()">;
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
- BuildableType<"$_builder.getFloat8E4M3FNType()">;
+ BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
- BuildableType<"$_builder.getFloat8E5M2Type()">;
+ BuildableType<"$_builder.getType<Float8E5M2Type>()">;
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
- BuildableType<"$_builder.getFloat8E4M3Type()">;
+ BuildableType<"$_builder.getType<Float8E4M3Type>()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
- BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
+ BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
- BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
+ BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
- BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
+ BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
- BuildableType<"$_builder.getFloat8E3M4Type()">;
+ BuildableType<"$_builder.getType<Float8E3M4Type>()">;
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
- BuildableType<"$_builder.getFloat4E2M1FNType()">;
+ BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
- BuildableType<"$_builder.getFloat6E2M3FNType()">;
+ BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
- BuildableType<"$_builder.getFloat6E3M2FNType()">;
+ BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
- BuildableType<"$_builder.getFloat8E8M0FNUType()">;
+ BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
"complex-type", "::mlir::ComplexType">;
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index c614eb39b364be..21bb0ec3d0d515 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -309,58 +309,58 @@ Type Parser::parseNonFunctionType() {
// float-type
case Token::kw_f4E2M1FN:
consumeToken(Token::kw_f4E2M1FN);
- return builder.getFloat4E2M1FNType();
+ return builder.getType<Float4E2M1FNType>();
case Token::kw_f6E2M3FN:
consumeToken(Token::kw_f6E2M3FN);
- return builder.getFloat6E2M3FNType();
+ return builder.getType<Float6E2M3FNType>();
case Token::kw_f6E3M2FN:
consumeToken(Token::kw_f6E3M2FN);
- return builder.getFloat6E3M2FNType();
+ return builder.getType<Float6E3M2FNType>();
case Token::kw_f8E5M2:
consumeToken(Token::kw_f8E5M2);
- return builder.getFloat8E5M2Type();
+ return builder.getType<Float8E5M2Type>();
case Token::kw_f8E4M3:
consumeToken(Token::kw_f8E4M3);
- return builder.getFloat8E4M3Type();
+ return builder.getType<Float8E4M3Type>();
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
- return builder.getFloat8E4M3FNType();
+ return builder.getType<Float8E4M3FNType>();
case Token::kw_f8E5M2FNUZ:
consumeToken(Token::kw_f8E5M2FNUZ);
- return builder.getFloat8E5M2FNUZType();
+ return builder.getType<Float8E5M2FNUZType>();
case Token::kw_f8E4M3FNUZ:
consumeToken(Token::kw_f8E4M3FNUZ);
- return builder.getFloat8E4M3FNUZType();
+ return builder.getType<Float8E4M3FNUZType>();
case Token::kw_f8E4M3B11FNUZ:
consumeToken(Token::kw_f8E4M3B11FNUZ);
- return builder.getFloat8E4M3B11FNUZType();
+ return builder.getType<Float8E4M3B11FNUZType>();
case Token::kw_f8E3M4:
consumeToken(Token::kw_f8E3M4);
- return builder.getFloat8E3M4Type();
+ return builder.getType<Float8E3M4Type>();
case Token::kw_f8E8M0FNU:
consumeToken(Token::kw_f8E8M0FNU);
- return builder.getFloat8E8M0FNUType();
+ return builder.getType<Float8E8M0FNUType>();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
- return builder.getBF16Type();
+ return builder.getType<BFloat16Type>();
case Token::kw_f16:
consumeToken(Token::kw_f16);
- return builder.getF16Type();
+ return builder.getType<Float16Type>();
case Token::kw_tf32:
consumeToken(Token::kw_tf32);
- return builder.getTF32Type();
+ return builder.getType<FloatTF32Type>();
case Token::kw_f32:
consumeToken(Token::kw_f32);
- return builder.getF32Type();
+ return builder.getType<Float32Type>();
case Token::kw_f64:
consumeToken(Token::kw_f64);
- return builder.getF64Type();
+ return builder.getType<Float64Type>();
case Token::kw_f80:
consumeToken(Token::kw_f80);
- return builder.getF80Type();
+ return builder.getType<Float80Type>();
case Token::kw_f128:
consumeToken(Token::kw_f128);
- return builder.getF128Type();
+ return builder.getType<Float128Type>();
// index-type
case Token::kw_index:
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 0fa7d321844113..39c9005e449e38 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -361,22 +361,22 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
- .Case("f4E2M1FN", b.getFloat4E2M1FNType())
- .Case("f6E2M3FN", b.getFloat6E2M3FNType())
- .Case("f6E3M2FN", b.getFloat6E3M2FNType())
- .Case("f8E5M2", b.getFloat8E5M2Type())
- .Case("f8E4M3", b.getFloat8E4M3Type())
- .Case("f8E4M3FN", b.getFloat8E4M3FNType())
- .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
- .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
- .Case("f8E3M4", b.getFloat8E3M4Type())
- .Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
- .Case("bf16", b.getBF16Type())
- .Case("f16", b.getF16Type())
- .Case("f32", b.getF32Type())
- .Case("f64", b.getF64Type())
- .Case("f80", b.getF80Type())
- .Case("f128", b.getF128Type())
+ .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
+ .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
+ .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
+ .Case("f8E5M2", b.getType<Float8E5M2Type>())
+ .Case("f8E4M3", b.getType<Float8E4M3Type>())
+ .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
+ .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
+ .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
+ .Case("f8E3M4", b.getType<Float8E3M4Type>())
+ .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
+ .Case("bf16", b.getType<BFloat16Type>())
+ .Case("f16", b.getType<Float16Type>())
+ .Case("f32", b.getType<Float32Type>())
+ .Case("f64", b.getType<Float64Type>())
+ .Case("f80", b.getType<Float80Type>())
+ .Case("f128", b.getType<Float128Type>())
.Default(std::nullopt);
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8439b063f2634b..d57a7ca07ede58 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,44 +34,6 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//
-FloatType Builder::getFloat4E2M1FNType() {
- return Float4E2M1FNType::get(context);
-}
-
-FloatType Builder::getFloat6E2M3FNType() {
- return Float6E2M3FNType::get(context);
-}
-
-FloatType Builder::getFloat6E3M2FNType() {
- return Float6E3M2FNType::get(context);
-}
-
-FloatType Builder::getFloat8E5M2Type() { return Float8E5M2Type::get(context); }
-
-FloatType Builder::getFloat8E4M3Type() { return Float8E4M3Type::get(context); }
-
-FloatType Builder::getFloat8E4M3FNType() {
- return Float8E4M3FNType::get(context);
-}
-
-FloatType Builder::getFloat8E5M2FNUZType() {
- return Float8E5M2FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E4M3FNUZType() {
- return Float8E4M3FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E4M3B11FNUZType() {
- return Float8E4M3B11FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E3M4Type() { return Float8E3M4Type::get(context); }
-
-FloatType Builder::getFloat8E8M0FNUType() {
- return Float8E8M0FNUType::get(context);
-}
-
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
FloatType Builder::getF16Type() { return Float16Type::get(context); }
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b9e745fdf4a13e..87782e84dd6e4a 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -221,17 +221,6 @@ class MLIRContextImpl {
llvm::DenseMap<StringRef, AbstractType *> nameToType;
/// Cached Type Instances.
- Float4E2M1FNType f4E2M1FNTy;
- Float6E2M3FNType f6E2M3FNTy;
- Float6E3M2FNType f6E3M2FNTy;
- Float8E5M2Type f8E5M2Ty;
- Float8E4M3Type f8E4M3Ty;
- Float8E4M3FNType f8E4M3FNTy;
- Float8E5M2FNUZType f8E5M2FNUZTy;
- Float8E4M3FNUZType f8E4M3FNUZTy;
- Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
- Float8E3M4Type f8E3M4Ty;
- Float8E8M0FNUType f8E8M0FNUTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
FloatTF32Type tf32Ty;
@@ -317,17 +306,6 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
//// Types.
/// Floating-point Types.
- impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
- impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
- impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
- impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
- impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
- impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
- impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
- impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
- impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
- impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
- impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1044,39 +1022,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
-Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
- return context->getImpl().f4E2M1FNTy;
-}
-Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
- return context->getImpl().f6E2M3FNTy;
-}
-Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
- return context->getImpl().f6E3M2FNTy;
-}
-Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
- return context->getImpl().f8E5M2Ty;
-}
-Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
- return context->getImpl().f8E4M3Ty;
-}
-Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
- return context->getImpl().f8E4M3FNTy;
-}
-Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
- return context->getImpl().f8E5M2FNUZTy;
-}
-Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
- return context->getImpl().f8E4M3FNUZTy;
-}
-Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
- return context->getImpl().f8E4M3B11FNUZTy;
-}
-Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
- return context->getImpl().f8E3M4Ty;
-}
-Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
- return context->getImpl().f8E8M0FNUTy;
-}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}
More information about the Mlir-commits
mailing list