[Mlir-commits] [mlir] [mlir][arith][NFC] Use type parser instead of hard-coding type keywords (PR #186753)
Matthias Springer
llvmlistbot at llvm.org
Mon Mar 16 01:53:27 PDT 2026
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/186753
Parse type literals with the MLIR type parser instead of hard-coding the type keywords in an `llvm::StringSwitch`. The new implementation also works for non-builtin floating-point types.
Assisted by: claude-opus-4.6
From 5272633e28612ac629fe1b0e77d6af53ec5defe7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 16 Mar 2026 08:47:44 +0000
Subject: [PATCH] [mlir][arith] Use type parser instead of hard-coding type
keywords
---
mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 5 ++--
.../Transforms/EmulateUnsupportedFloats.cpp | 14 +++++-----
mlir/lib/Dialect/Arith/Utils/CMakeLists.txt | 1 +
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 27 +++++--------------
.../Transforms/ExtendToSupportedTypes.cpp | 13 ++++-----
5 files changed, 21 insertions(+), 39 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index c0b286494996b..4ebb7e16239f7 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -143,8 +143,9 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
Type resultType);
-// Map strings to float types.
-std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
+// Map strings to float types. Returns nullptr if the name is not a known
+// floating-point type.
+FloatType parseFloatType(MLIRContext *ctx, StringRef name);
} // namespace arith
} // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index d018cddeb8dc1..b6e101952676a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -132,25 +132,23 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
SmallVector<Type> sourceTypes;
Type targetType;
- std::optional<FloatType> maybeTargetType =
- arith::parseFloatType(ctx, targetTypeStr);
- if (!maybeTargetType) {
+ FloatType parsedTargetType = arith::parseFloatType(ctx, targetTypeStr);
+ if (!parsedTargetType) {
emitError(UnknownLoc::get(ctx), "could not map target type '" +
targetTypeStr +
"' to a known floating-point type");
return signalPassFailure();
}
- targetType = *maybeTargetType;
+ targetType = parsedTargetType;
for (StringRef sourceTypeStr : sourceTypeStrs) {
- std::optional<FloatType> maybeSourceType =
- arith::parseFloatType(ctx, sourceTypeStr);
- if (!maybeSourceType) {
+ FloatType sourceType = arith::parseFloatType(ctx, sourceTypeStr);
+ if (!sourceType) {
emitError(UnknownLoc::get(ctx), "could not map source type '" +
sourceTypeStr +
"' to a known floating-point type");
return signalPassFailure();
}
- sourceTypes.push_back(*maybeSourceType);
+ sourceTypes.push_back(sourceType);
}
if (sourceTypes.empty())
(void)emitOptionalWarning(
diff --git a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
index 07fa58b209b5e..b4760510fc96e 100644
--- a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArithUtils
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRAsmParser
MLIRComplexDialect
MLIRDialect
MLIRDialectUtils
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 122154566a74e..200b40f74a5f5 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -11,9 +11,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <numeric>
@@ -357,27 +359,10 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
});
}
-/// Map strings to float types.
-std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
- Builder b(ctx);
- return llvm::StringSwitch<std::optional<FloatType>>(name)
- .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
- .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
- .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
- .Case("f8E5M2", b.getType<Float8E5M2Type>())
- .Case("f8E4M3", b.getType<Float8E4M3Type>())
- .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
- .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
- .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
- .Case("f8E3M4", b.getType<Float8E3M4Type>())
- .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
- .Case("bf16", b.getType<BFloat16Type>())
- .Case("f16", b.getType<Float16Type>())
- .Case("f32", b.getType<Float32Type>())
- .Case("f64", b.getType<Float64Type>())
- .Case("f80", b.getType<Float80Type>())
- .Case("f128", b.getType<Float128Type>())
- .Default(std::nullopt);
+FloatType parseFloatType(MLIRContext *ctx, StringRef name) {
+ // Drop parser diagnostics; callers report a null (unparsable or non-float) result.
+ ScopedDiagnosticHandler handler(ctx, [](Diagnostic &) {});
+ return dyn_cast_or_null<FloatType>(mlir::parseType(name, ctx));
+}
} // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
index 9d6ad613fc945..bc262f84b26ac 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
@@ -124,28 +124,25 @@ void ExtendToSupportedTypesPass::runOnOperation() {
MLIRContext *ctx = &getContext();
// Parse target type
- std::optional<Type> maybeTargetType =
- arith::parseFloatType(ctx, targetTypeStr);
- if (!maybeTargetType.has_value()) {
+ FloatType targetType = arith::parseFloatType(ctx, targetTypeStr);
+ if (!targetType) {
emitError(UnknownLoc::get(ctx), "could not map target type '" +
targetTypeStr +
"' to a known floating-point type");
return signalPassFailure();
}
- Type targetType = maybeTargetType.value();
// Parse source types
llvm::SetVector<Type> sourceTypes;
for (const auto &extraTypeStr : extraTypeStrs) {
- std::optional<FloatType> maybeExtraType =
- arith::parseFloatType(ctx, extraTypeStr);
- if (!maybeExtraType.has_value()) {
+ FloatType extraType = arith::parseFloatType(ctx, extraTypeStr);
+ if (!extraType) {
emitError(UnknownLoc::get(ctx), "could not map source type '" +
extraTypeStr +
"' to a known floating-point type");
return signalPassFailure();
}
- sourceTypes.insert(maybeExtraType.value());
+ sourceTypes.insert(extraType);
}
// f64 and f32 are implicitly supported
Builder b(ctx);
More information about the Mlir-commits
mailing list