[Mlir-commits] [mlir] [mlir][arith][NFC] Use type parser instead of hard-coding type keywords (PR #186753)

Matthias Springer llvmlistbot at llvm.org
Mon Mar 16 05:47:19 PDT 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/186753

>From bd4afdaf170c0daeec441a2c70db3ad8970212ab Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 16 Mar 2026 08:47:44 +0000
Subject: [PATCH] [mlir][arith] Use type parser instead of hard-coding type
 keywords

---
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h |  5 ++--
 .../Transforms/EmulateUnsupportedFloats.cpp   | 14 +++++-----
 mlir/lib/Dialect/Arith/Utils/CMakeLists.txt   |  1 +
 mlir/lib/Dialect/Arith/Utils/Utils.cpp        | 27 +++++--------------
 .../Transforms/ExtendToSupportedTypes.cpp     | 13 ++++-----
 .../Arith/emulate-unsupported-floats.mlir     | 26 ++++++++++--------
 6 files changed, 36 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index c0b286494996b..4ebb7e16239f7 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -143,8 +143,9 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
                     Type resultType);
 
-// Map strings to float types.
-std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
+// Parses `name` as an MLIR type. Returns a null FloatType if parsing fails or
+// if the parsed type is not a FloatType.
+FloatType parseFloatType(MLIRContext *ctx, StringRef name);
 
 } // namespace arith
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index d018cddeb8dc1..b6e101952676a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -132,25 +132,23 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
   SmallVector<Type> sourceTypes;
   Type targetType;
 
-  std::optional<FloatType> maybeTargetType =
-      arith::parseFloatType(ctx, targetTypeStr);
-  if (!maybeTargetType) {
+  FloatType parsedTargetType = arith::parseFloatType(ctx, targetTypeStr);
+  if (!parsedTargetType) {
     emitError(UnknownLoc::get(ctx), "could not map target type '" +
                                         targetTypeStr +
                                         "' to a known floating-point type");
     return signalPassFailure();
   }
-  targetType = *maybeTargetType;
+  targetType = parsedTargetType;
   for (StringRef sourceTypeStr : sourceTypeStrs) {
-    std::optional<FloatType> maybeSourceType =
-        arith::parseFloatType(ctx, sourceTypeStr);
-    if (!maybeSourceType) {
+    FloatType sourceType = arith::parseFloatType(ctx, sourceTypeStr);
+    if (!sourceType) {
       emitError(UnknownLoc::get(ctx), "could not map source type '" +
                                           sourceTypeStr +
                                           "' to a known floating-point type");
       return signalPassFailure();
     }
-    sourceTypes.push_back(*maybeSourceType);
+    sourceTypes.push_back(sourceType);
   }
   if (sourceTypes.empty())
     (void)emitOptionalWarning(
diff --git a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
index 07fa58b209b5e..b4760510fc96e 100644
--- a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArithUtils
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRAsmParser
   MLIRComplexDialect
   MLIRDialect
   MLIRDialectUtils
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 122154566a74e..200b40f74a5f5 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -11,9 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Diagnostics.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include <numeric>
@@ -357,27 +359,10 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
   });
 }
 
-/// Map strings to float types.
-std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
-  Builder b(ctx);
-  return llvm::StringSwitch<std::optional<FloatType>>(name)
-      .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
-      .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
-      .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
-      .Case("f8E5M2", b.getType<Float8E5M2Type>())
-      .Case("f8E4M3", b.getType<Float8E4M3Type>())
-      .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
-      .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
-      .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
-      .Case("f8E3M4", b.getType<Float8E3M4Type>())
-      .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
-      .Case("bf16", b.getType<BFloat16Type>())
-      .Case("f16", b.getType<Float16Type>())
-      .Case("f32", b.getType<Float32Type>())
-      .Case("f64", b.getType<Float64Type>())
-      .Case("f80", b.getType<Float80Type>())
-      .Case("f128", b.getType<Float128Type>())
-      .Default(std::nullopt);
+FloatType parseFloatType(MLIRContext *ctx, StringRef name) {
+  // Suppress diagnostics: callers handle invalid type strings themselves.
+  ScopedDiagnosticHandler handler(ctx, [](Diagnostic &) {});
+  return dyn_cast_or_null<FloatType>(mlir::parseType(name, ctx));
 }
 
 } // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
index 9d6ad613fc945..bc262f84b26ac 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
@@ -124,28 +124,25 @@ void ExtendToSupportedTypesPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
 
   // Parse target type
-  std::optional<Type> maybeTargetType =
-      arith::parseFloatType(ctx, targetTypeStr);
-  if (!maybeTargetType.has_value()) {
+  FloatType targetType = arith::parseFloatType(ctx, targetTypeStr);
+  if (!targetType) {
     emitError(UnknownLoc::get(ctx), "could not map target type '" +
                                         targetTypeStr +
                                         "' to a known floating-point type");
     return signalPassFailure();
   }
-  Type targetType = maybeTargetType.value();
 
   // Parse source types
   llvm::SetVector<Type> sourceTypes;
   for (const auto &extraTypeStr : extraTypeStrs) {
-    std::optional<FloatType> maybeExtraType =
-        arith::parseFloatType(ctx, extraTypeStr);
-    if (!maybeExtraType.has_value()) {
+    FloatType extraType = arith::parseFloatType(ctx, extraTypeStr);
+    if (!extraType) {
       emitError(UnknownLoc::get(ctx), "could not map source type '" +
                                           extraTypeStr +
                                           "' to a known floating-point type");
       return signalPassFailure();
     }
-    sourceTypes.insert(maybeExtraType.value());
+    sourceTypes.insert(extraType);
   }
   // f64 and f32 are implicitly supported
   Builder b(ctx);
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index fcd004ac554aa..41c120edc9ac8 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -1,4 +1,9 @@
-// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s
+// RUN: mlir-opt --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s
+// RUN: mlir-opt --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=!llvm.ppc_fp128" %s | FileCheck %s --check-prefix=CHECK-PPC
+
+// Arbitrary op from the LLVM dialect. It ensures that the LLVM dialect is loaded
+// before the pass parses the !llvm.ppc_fp128 target type (second RUN line).
+llvm.func @foo() {}
 
 func.func @basic_expansion(%x: bf16) -> bf16 {
 // CHECK-LABEL: @basic_expansion
@@ -9,13 +14,20 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
 // CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
 // CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
 // CHECK: return [[Y]]
+
+// CHECK-PPC-LABEL: @basic_expansion
+// CHECK-PPC-SAME: [[X:%.+]]: bf16
+// CHECK-PPC-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
+// CHECK-PPC-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to !llvm.ppc_fp128
+// CHECK-PPC-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to !llvm.ppc_fp128
+// CHECK-PPC: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : !llvm.ppc_fp128
+// CHECK-PPC: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : !llvm.ppc_fp128 to bf16
+// CHECK-PPC: return [[Y]]
   %c = arith.constant 1.0 : bf16
   %y = arith.addf %x, %c : bf16
   func.return %y : bf16
 }
 
-// -----
-
 func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
 // CHECK-LABEL: @chained
 // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
@@ -36,8 +48,6 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
   func.return %res : i1
 }
 
-// -----
-
 func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
 // CHECK-LABEL: @memops
 // CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
@@ -58,8 +68,6 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
   func.return
 }
 
-// -----
-
 func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
 // CHECK-LABEL: @vectors
 // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
@@ -73,8 +81,6 @@ func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
   func.return %ret : vector<4xf32>
 }
 
-// -----
-
 func.func @no_expansion(%x: f32) -> f32 {
 // CHECK-LABEL: @no_expansion
 // CHECK-SAME: [[X:%.+]]: f32
@@ -86,8 +92,6 @@ func.func @no_expansion(%x: f32) -> f32 {
   func.return %y : f32
 }
 
-// -----
-
 func.func @no_promote_select(%c: i1, %x: bf16, %y: bf16) -> bf16 {
 // CHECK-LABEL: @no_promote_select
 // CHECK-SAME: (%[[C:.+]]: i1, %[[X:.+]]: bf16, %[[Y:.+]]: bf16)



More information about the Mlir-commits mailing list