[Mlir-commits] [mlir] [mlir][math] Add FP software implementation lowering pass: math-to-apfloat (PR #171221)

Maksim Levental llvmlistbot at llvm.org
Wed Dec 17 10:28:45 PST 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/171221

>From 148280d765d689990e7d90317be2fb369736e7ca Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 8 Dec 2025 14:48:11 -0800
Subject: [PATCH 1/6] [mlir][math] Add FP software implementation lowering
 pass: math-to-apfloat

---
 .../Conversion/MathToAPFloat/MathToAPFloat.h  |  21 +++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  15 ++
 mlir/include/mlir/Dialect/Func/Utils/Utils.h  |  16 ++
 .../ArithToAPFloat.cpp                        |  92 +++--------
 .../ArithAndMathToAPFloat/CMakeLists.txt      |  49 ++++++
 .../ArithAndMathToAPFloat/MathToAPFloat.cpp   | 148 ++++++++++++++++++
 .../ArithAndMathToAPFloat/Utils.cpp           |  22 +++
 .../Conversion/ArithAndMathToAPFloat/Utils.h  |  21 +++
 .../Conversion/ArithToAPFloat/CMakeLists.txt  |  19 ---
 mlir/lib/Conversion/CMakeLists.txt            |   2 +-
 mlir/lib/Dialect/Func/Utils/Utils.cpp         |  39 +++++
 mlir/lib/ExecutionEngine/APFloatWrappers.cpp  |  48 +++++-
 .../Math/CPU/test-apfloat-emulation.mlir      |  32 ++++
 14 files changed, 437 insertions(+), 88 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
 rename mlir/lib/Conversion/{ArithToAPFloat => ArithAndMathToAPFloat}/ArithToAPFloat.cpp (88%)
 create mode 100644 mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
 create mode 100644 mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
 create mode 100644 mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
 delete mode 100644 mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
 create mode 100644 mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir

diff --git a/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..86179a1611d5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..05ec2f8ce2538 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -44,6 +44,7 @@
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
 #include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+    : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+  let summary = "Convert Math ops to APFloat runtime library calls";
+  let description = [{
+    This pass converts supported Math ops to APFloat-based runtime library
+    calls (APFloatWrappers.cpp). APFloat is a software implementation of
+    floating-point arithmetic operations.
+  }];
+  let dependentDialects = ["arith::ArithDialect", "func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..079c1f461b6ed 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
                                FunctionType funcT,
                                SymbolTableCollection *symbolTables = nullptr);
 
+/// Create a FuncOp decl `name` of type `funcT` at the start of the first block
+/// of the `symTable` op, marking it private if `setPrivate` is set. If
+/// `symbolTables` is provided, the decl is also inserted into it.
+FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                    FunctionType funcT, bool setPrivate,
+                    SymbolTableCollection *symbolTables = nullptr);
+
+/// Look up a FuncOp decl `name` that takes `paramTypes` and returns a single
+/// `resultType` (i64 if `resultType` is null), creating a private decl if none
+/// exists. Fails if an existing decl named `name` has a different type.
+FailureOr<FuncOp>
+lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                     TypeRange paramTypes,
+                     SymbolTableCollection *symbolTables = nullptr,
+                     Type resultType = {});
+
 } // namespace func
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
similarity index 88%
rename from mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
rename to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 79816fc6e3bf1..b9ba94ef08098 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+#include "Utils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -25,47 +26,6 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::func;
 
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
-                           StringRef name, FunctionType funcT, bool setPrivate,
-                           SymbolTableCollection *symbolTables = nullptr) {
-  OpBuilder::InsertionGuard g(b);
-  assert(!symTable->getRegion(0).empty() && "expected non-empty region");
-  b.setInsertionPointToStart(&symTable->getRegion(0).front());
-  FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
-  if (setPrivate)
-    funcOp.setPrivate();
-  if (symbolTables) {
-    SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
-    symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
-  }
-  return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Returns an int64_t, unless a
-/// different result type is specified.
-static FailureOr<FuncOp>
-lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
-                        StringRef name, TypeRange paramTypes,
-                        SymbolTableCollection *symbolTables = nullptr,
-                        Type resultType = {}) {
-  if (!resultType)
-    resultType = IntegerType::get(symTable->getContext(), 64);
-  std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
-  auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
-  FailureOr<FuncOp> func =
-      lookupFnDecl(symTable, funcName, funcT, symbolTables);
-  // Failed due to type mismatch.
-  if (failed(func))
-    return func;
-  // Successfully matched existing decl.
-  if (*func)
-    return *func;
-
-  return createFnDecl(b, symTable, funcName, funcT,
-                      /*setPrivate=*/true, symbolTables);
-}
-
 /// Helper function to look up or create the symbol for a runtime library
 /// function for a binary arithmetic operation.
 ///
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
                        SymbolTableCollection *symbolTables = nullptr) {
   auto i32Type = IntegerType::get(symTable->getContext(), 32);
   auto i64Type = IntegerType::get(symTable->getContext(), 64);
-  return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
-                                 symbolTables);
-}
-
-static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
-  int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
-  return arith::ConstantOp::create(b, loc, b.getI32Type(),
-                                   b.getIntegerAttr(b.getI32Type(), sem));
+  std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+  return lookupOrCreateFnDecl(b, symTable, funcName,
+                              {i32Type, i64Type, i64Type}, symbolTables);
 }
 
 /// Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
               arith::BitcastOp::create(rewriter, loc, intWType, rhs));
 
           // Call APFloat function.
-          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
           SmallVector<Value> params = {semValue, lhsBits, rhsBits};
           auto resultOp = func::CallOp::create(rewriter, loc,
                                                TypeRange(rewriter.getI64Type()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
-        rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+    FailureOr<FuncOp> fn =
+        lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
+                             {i32Type, i32Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
               arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
 
           // Call APFloat function.
-          Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+          Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
           auto outFloatTy = cast<FloatType>(resultType);
-          Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+          Value outSemValue =
+              getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
           std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
           auto resultOp = func::CallOp::create(rewriter, loc,
                                                TypeRange(rewriter.getI64Type()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
     FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
-                                {i32Type, i32Type, i1Type, i64Type});
+        lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
+                             {i32Type, i32Type, i1Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
               arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
 
           // Call APFloat function.
-          Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+          Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
           auto outIntTy = cast<IntegerType>(resultType);
           Value outWidthValue = arith::ConstantOp::create(
               rewriter, loc, i32Type,
@@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
     auto i1Type = IntegerType::get(symTable->getContext(), 1);
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
-                                {i32Type, i32Type, i1Type, i64Type});
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_convert_from_int",
+        {i32Type, i32Type, i1Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
 
           // Call APFloat function.
           auto outFloatTy = cast<FloatType>(resultType);
-          Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+          Value outSemValue =
+              getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
           Value inWidthValue = arith::ConstantOp::create(
               rewriter, loc, i32Type,
               rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
     FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "compare",
-                                {i32Type, i64Type, i64Type}, nullptr, i8Type);
+        lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
+                             {i32Type, i64Type, i64Type}, nullptr, i8Type);
     if (failed(fn))
       return fn;
 
@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
               arith::BitcastOp::create(rewriter, loc, intWType, rhs));
 
           // Call APFloat function.
-          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
           SmallVector<Value> params = {semValue, lhsBits, rhsBits};
           Value comparisonResult =
               func::CallOp::create(rewriter, loc, TypeRange(i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
               arith::BitcastOp::create(rewriter, loc, intWType, operand1));
 
           // Call APFloat function.
-          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
           SmallVector<Value> params = {semValue, operandBits};
           Value negatedBits =
               func::CallOp::create(rewriter, loc, TypeRange(i64Type),
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..cc8e61a87addc
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
@@ -0,0 +1,49 @@
+add_mlir_library(ArithAndMathToAPFloatUtils
+  Utils.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+)
+
+add_mlir_conversion_library(MLIRArithToAPFloat
+  ArithToAPFloat.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  ArithAndMathToAPFloatUtils
+  MLIRArithDialect
+  MLIRArithTransforms
+  MLIRFuncDialect
+  MLIRFuncUtils
+  MLIRVectorDialect
+  )
+
+add_mlir_conversion_library(MLIRMathToAPFloat
+  MathToAPFloat.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  ArithAndMathToAPFloatUtils
+  MLIRMathDialect
+  MLIRFuncDialect
+  MLIRFuncUtils
+  )
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..e540747ac0abd
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,192 @@
+//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+#include "Utils.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+/// Returns the type of `operand` if it can be handed to the APFloat runtime,
+/// i.e. it is a scalar float whose bit pattern fits in the 64-bit integer the
+/// runtime functions take. Returns a null type otherwise (e.g. for vectors).
+static FloatType getSupportedFloatType(Value operand) {
+  auto floatTy = dyn_cast<FloatType>(operand.getType());
+  if (!floatTy || floatTy.getWidth() > 64)
+    return FloatType();
+  return floatTy;
+}
+
+/// Bitcasts `operand` (of type `floatTy`) to an integer of the same width and
+/// zero-extends it to i64, the encoding the APFloat runtime expects.
+static Value packFloatBits(OpBuilder &b, Location loc, Value operand,
+                           FloatType floatTy) {
+  Value bits = arith::BitcastOp::create(
+      b, loc, b.getIntegerType(floatTy.getWidth()), operand);
+  // arith.extui requires a strictly wider result type, so 64-bit values are
+  // passed through unchanged.
+  if (floatTy.getWidth() == 64)
+    return bits;
+  return arith::ExtUIOp::create(b, loc, b.getI64Type(), bits);
+}
+
+/// Inverse of `packFloatBits`: truncates the i64 `bits` to the width of
+/// `floatTy` and bitcasts the result back to `floatTy`.
+static Value unpackFloatBits(OpBuilder &b, Location loc, Value bits,
+                             FloatType floatTy) {
+  // arith.trunci requires a strictly narrower result type.
+  if (floatTy.getWidth() != 64)
+    bits = arith::TruncIOp::create(
+        b, loc, b.getIntegerType(floatTy.getWidth()), bits);
+  return arith::BitcastOp::create(b, loc, floatTy, bits);
+}
+
+namespace {
+/// Lowers `math.absf` on a supported scalar float to a call to the
+/// `_mlir_apfloat_abs` runtime function.
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+  AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(math::AbsFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check the operand type before touching the symbol table so that
+    // unsupported ops neither crash nor leave behind an unused declaration.
+    Value operand = op.getOperand();
+    FloatType floatTy = getSupportedFloatType(operand);
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(
+          op, "only scalar floats of at most 64 bits are supported");
+
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+    if (failed(fn))
+      return fn;
+
+    // Call APFloat function on the i64-encoded operand.
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    Value operandBits = packFloatBits(rewriter, loc, operand, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    Value absBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                         SymbolRefAttr::get(*fn), params)
+                        ->getResult(0);
+
+    // Decode the result back to the original float type.
+    rewriter.replaceOp(op, unpackFloatBits(rewriter, loc, absBits, floatTy));
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
+/// Lowers a `math.is*` classification op (`OpTy`) on a supported scalar float
+/// to a call to the `_mlir_apfloat_is<apFloatName>` runtime function, which
+/// returns the i1 result directly.
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+  IsOpToAPFloatConversion(MLIRContext *context, const char *apFloatName,
+                          SymbolOpInterface symTable,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+        apFloatName(apFloatName) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Check the operand type before touching the symbol table so that
+    // unsupported ops neither crash nor leave behind an unused declaration.
+    Value operand = op.getOperand();
+    FloatType floatTy = getSupportedFloatType(operand);
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(
+          op, "only scalar floats of at most 64 bits are supported");
+
+    // Get APFloat function from runtime library.
+    auto i1Type = IntegerType::get(symTable->getContext(), 1);
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    std::string funcName =
+        (llvm::Twine("_mlir_apfloat_is") + apFloatName).str();
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1Type);
+    if (failed(fn))
+      return fn;
+
+    // Call APFloat function. The call's result type must match the i1 result
+    // type of the callee declared above.
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    Value operandBits = packFloatBits(rewriter, loc, operand, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1Type),
+                                              SymbolRefAttr::get(*fn), params);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+  /// Name suffix after `_mlir_apfloat_is`, e.g. "nan".
+  const char *apFloatName;
+};
+
+struct MathToAPFloatConversionPass final
+    : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+
+void MathToAPFloatConversionPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+
+  patterns.add<AbsFOpToAPFloatConversion>(context, getOperation());
+  patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(context, "finite",
+                                                          getOperation());
+  patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(context, "infinite",
+                                                       getOperation());
+  patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(context, "nan",
+                                                       getOperation());
+  patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context, "normal",
+                                                          getOperation());
+
+  // Record whether any error diagnostic is emitted while rewriting so the
+  // pass can signal failure.
+  LogicalResult result = success();
+  ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
+    if (diag.getSeverity() == DiagnosticSeverity::Error)
+      result = failure();
+    // NB: if you don't return failure, no other diag handlers will fire (see
+    // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
+    return failure();
+  });
+  walkAndApplyPatterns(getOperation(), std::move(patterns));
+  if (failed(result))
+    return signalPassFailure();
+}
+} // namespace
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
new file mode 100644
index 0000000000000..2b5857367dc40
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
@@ -0,0 +1,26 @@
+//===- Utils.cpp - Utils for APFloat Conversion ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/APFloat.h"
+
+/// Materializes the `llvm::APFloatBase::Semantics` enum value of `floatTy` as
+/// an i32 `arith.constant`, the form in which the APFloat runtime functions
+/// receive the floating-point format.
+mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
+                                           FloatType floatTy) {
+  int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+  return arith::ConstantOp::create(b, loc, b.getI32Type(),
+                                   b.getIntegerAttr(b.getI32Type(), sem));
+}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
new file mode 100644
index 0000000000000..5f11d24261b43
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
@@ -0,0 +1,24 @@
+//===- Utils.h - Utils for APFloat Conversion -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
+#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
+
+namespace mlir {
+class Value;
+class OpBuilder;
+class Location;
+class FloatType;
+
+/// Creates an i32 `arith.constant` at the insertion point of `b` holding the
+/// `llvm::APFloatBase::Semantics` enum value of `floatTy`. This is the format
+/// identifier passed to the `_mlir_apfloat_*` runtime functions.
+Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
deleted file mode 100644
index 31fce7a4de8a2..0000000000000
--- a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-add_mlir_conversion_library(MLIRArithToAPFloat
-  ArithToAPFloat.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRArithDialect
-  MLIRArithTransforms
-  MLIRFuncDialect
-  MLIRFuncUtils
-  MLIRVectorDialect
-  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 613dc6d242ceb..2ed10effb53da 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,7 +2,7 @@ add_subdirectory(AffineToStandard)
 add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
-add_subdirectory(ArithToAPFloat)
+add_subdirectory(ArithAndMathToAPFloat)
 add_subdirectory(ArithToArmSME)
 add_subdirectory(ArithToEmitC)
 add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index d6dfd0229963c..0a56817b704ff 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -279,3 +279,42 @@ func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
   }
   return func;
 }
+
+func::FuncOp func::createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+                                StringRef name, FunctionType funcT,
+                                bool setPrivate,
+                                SymbolTableCollection *symbolTables) {
+  Region &body = symTable->getRegion(0);
+  assert(!body.empty() && "expected non-empty region");
+  // The decl is placed at the very start of the symbol table's first block.
+  Block &entry = body.front();
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(&entry);
+  auto decl = func::FuncOp::create(b, symTable->getLoc(), name, funcT);
+  if (setPrivate)
+    decl.setPrivate();
+  if (symbolTables)
+    symbolTables->getSymbolTable(symTable).insert(decl, entry.begin());
+  return decl;
+}
+
+FailureOr<func::FuncOp>
+func::lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+                           StringRef name, TypeRange paramTypes,
+                           SymbolTableCollection *symbolTables,
+                           Type resultType) {
+  // Runtime library functions return i64 unless told otherwise.
+  if (!resultType)
+    resultType = IntegerType::get(symTable->getContext(), 64);
+  auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
+  FailureOr<func::FuncOp> existing =
+      lookupFnDecl(symTable, name, funcT, symbolTables);
+  // Failed due to type mismatch with an existing decl of the same name.
+  if (failed(existing))
+    return failure();
+  // Successfully matched an existing decl; reuse it.
+  if (*existing)
+    return *existing;
+  return createFnDecl(b, symTable, name, funcT, /*setPrivate=*/true,
+                      symbolTables);
+}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index f3e38eb8ffa2d..0c076af20dea7 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -143,7 +143,8 @@ MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
   return static_cast<int8_t>(x.compare(y));
 }
 
-MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) {
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics,
+                                                        uint64_t a) {
   const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
       static_cast<llvm::APFloatBase::Semantics>(semantics));
   unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
@@ -152,6 +153,57 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint6
   return x.bitcastToAPInt().getZExtValue();
 }
 
+/// Returns the bit pattern of the absolute value of the float whose bit
+/// pattern is `a` in the format identified by `semantics`.
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_abs(int32_t semantics,
+                                                        uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  return llvm::abs(x).bitcastToAPInt().getZExtValue();
+}
+
+/// Returns whether the float with bit pattern `a` (in `semantics`) is finite.
+MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isfinite(int32_t semantics,
+                                                         uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  return x.isFinite();
+}
+
+/// Returns whether the float with bit pattern `a` (in `semantics`) is +/-inf.
+MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isinfinite(int32_t semantics,
+                                                           uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  return x.isInfinity();
+}
+
+/// Returns whether the float with bit pattern `a` (in `semantics`) is normal.
+MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnormal(int32_t semantics,
+                                                         uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  return x.isNormal();
+}
+
+/// Returns whether the float with bit pattern `a` (in `semantics`) is a NaN.
+MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnan(int32_t semantics,
+                                                      uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  return x.isNaN();
+}
+
 /// Min/max operations.
 #define APFLOAT_MIN_MAX_OP(OP)                                                 \
   MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP(                    \
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
new file mode 100644
index 0000000000000..aca8a432a53b9
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
@@ -0,0 +1,32 @@
+// REQUIRES: system-linux || system-darwin
+// TODO: Run only on Linux until we figure out how to build
+// mlir_apfloat_wrappers in a platform-independent way.
+
+// Case 1: All floating-point arithmetics is lowered through APFloat.
+// RUN: mlir-opt %s --convert-math-to-apfloat --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN:             --shared-libs=%mlir_c_runner_utils \
+// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Case 2: Only unsupported arithmetics (f8E4M3FN) is lowered through APFloat.
+//         Arithmetics on f32 is lowered directly to LLVM.
+// RUN: mlir-opt %s --convert-to-llvm --convert-math-to-apfloat \
+// RUN:          --convert-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN:             --shared-libs=%mlir_c_runner_utils \
+// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+func.func @entry() {
+  %neg14fp8 = arith.constant -1.4 : f8E4M3FN
+  %neg14fp32 = arith.constant 1.4 : f32
+
+  // CHECK: 1.375
+  %c2 = math.absf %neg14fp8 : f8E4M3FN
+  vector.print %c2 : f8E4M3FN
+
+  // CHECK: 1.4
+  %c3 = math.absf %neg14fp32 : f32
+  vector.print %c3 : f32
+
+  return
+}

>From 9d11dde48dab8f332e61ef151db59914e2720445 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 11:37:24 -0800
Subject: [PATCH 2/6] not working (print is wrong?)

---
 .../ArithAndMathToAPFloat/MathToAPFloat.cpp   | 51 +++++++++++++++++++
 mlir/lib/ExecutionEngine/APFloatWrappers.cpp  | 23 +++++++++
 .../Math/CPU/test-apfloat-emulation.mlir      |  9 ++++
 3 files changed, 83 insertions(+)

diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index e540747ac0abd..4c8764ba1d6b0 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -110,6 +110,56 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
   const char *APFloatName;
 };
 
+struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
+  FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                           PatternBenefit benefit = 1)
+      : OpRewritePattern<math::FmaOp>(context, benefit), symTable(symTable) {};
+
+  LogicalResult matchAndRewrite(math::FmaOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_fused_multiply_add",
+        {i32Type, i64Type, i64Type, i64Type}); // semantics id + a, b, c
+    if (failed(fn))
+      return fn;
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+
+    // Cast operands to 64-bit integers.
+    auto floatTy = cast<FloatType>(op.getResult().getType());
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    auto int64Type = rewriter.getI64Type();
+    Value operand = arith::ExtUIOp::create(
+        rewriter, loc, int64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
+    Value multiplicand = arith::ExtUIOp::create(
+        rewriter, loc, int64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getB()));
+    Value addend = arith::ExtUIOp::create(
+        rewriter, loc, int64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
+
+    // Call the runtime wrapper with the semantics id and the three operands.
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operand, multiplicand, addend};
+    auto resultOp =
+        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+                             SymbolRefAttr::get(*fn), params);
+
+    // Truncate the i64 result back to the float's width and bitcast it.
+    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+                                                  resultOp->getResult(0));
+    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+  const char *APFloatName;
+};
+
 namespace {
 struct MathToAPFloatConversionPass final
     : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
@@ -131,6 +181,7 @@ void MathToAPFloatConversionPass::runOnOperation() {
                                                        getOperation());
   patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context, "normal",
                                                           getOperation());
+  patterns.add<FmaOpToAPFloatConversion>(context, getOperation());
 
   LogicalResult result = success();
   ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 0c076af20dea7..254590a0d8566 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -21,6 +21,7 @@
 //
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/Support/Debug.h"
 
 #ifdef _WIN32
 #ifndef MLIR_APFLOAT_WRAPPERS_EXPORT
@@ -198,6 +199,28 @@ MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnan(int32_t semantics,
   return x.isNaN();
 }
 
+MLIR_APFLOAT_WRAPPERS_EXPORT bool
+_mlir_apfloat_fused_multiply_add(int32_t semantics, uint64_t operand,
+                                 uint64_t multiplicand, uint64_t addend) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics)); // semantics enum
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat operand_(sem, llvm::APInt(bitWidth, operand)); // holds result
+  llvm::APFloat multiplicand_(sem, llvm::APInt(bitWidth, multiplicand));
+  llvm::APFloat addend_(sem, llvm::APInt(bitWidth, addend));
+  llvm::detail::opStatus stat = operand_.fusedMultiplyAdd(
+      multiplicand_, addend_, llvm::RoundingMode::NearestTiesToEven);
+
+  ////////////
+  operand_.print(llvm::dbgs());
+  llvm::dbgs() << "\n";
+  ////////////
+
+  assert(stat == llvm::APFloatBase::opOK &&
+         "expected fusedMultiplyAdd status to be OK");
+  return operand_.bitcastToAPInt().getZExtValue();
+}
+
 /// Min/max operations.
 #define APFLOAT_MIN_MAX_OP(OP)                                                 \
   MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP(                    \
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
index aca8a432a53b9..892b970a2796d 100644
--- a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
@@ -28,5 +28,14 @@ func.func @entry() {
   %c3 = math.absf %neg14fp32 : f32
   vector.print %c3 : f32
 
+  // see llvm/unittests/ADT/APFloatTest::TEST(APFloatTest, Float8E8M0FNUFMA)
+  %twof8E8M0FNU = arith.constant 2.0 : f8E8M0FNU
+  %fourf8E8M0FNU = arith.constant 4.0 : f8E8M0FNU
+  %eightf8E8M0FNU = arith.constant 8.0 : f8E8M0FNU
+
+  // CHECK: 16
+  %c4 = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU
+  // vector.print %c4 : f8E8M0FNU
+
   return
 }

>From 9d203b2f2e929b934b1a567218095d9bde514df2 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 16:25:44 -0800
Subject: [PATCH 3/6] working

---
 .../Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp    | 1 -
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp             | 7 +++++--
 mlir/lib/ExecutionEngine/APFloatWrappers.cpp              | 8 +-------
 .../Dialect/Math/CPU/test-apfloat-emulation.mlir          | 4 ++--
 4 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 4c8764ba1d6b0..8d9abde951182 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -157,7 +157,6 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
   }
 
   SymbolOpInterface symTable;
-  const char *APFloatName;
 };
 
 namespace {
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index a08cc98e4d5bf..e5f8763127a1b 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -37,7 +37,9 @@ using ConvertFMFMathToLLVMPattern =
     VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
                                FailOnUnsupportedFP>;
 
-using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
+using AbsFOpLowering =
+    ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
+                                /*FailOnUnsupportedFP=*/true>;
 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
 using CopySignOpLowering =
     ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
@@ -52,7 +54,8 @@ using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
 using FloorOpLowering =
     ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
-using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
+using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp,
+                                                  /*FailOnUnsupportedFP=*/true>;
 using Log10OpLowering =
     ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 254590a0d8566..9deb900fbe35d 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -199,7 +199,7 @@ MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnan(int32_t semantics,
   return x.isNaN();
 }
 
-MLIR_APFLOAT_WRAPPERS_EXPORT bool
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t
 _mlir_apfloat_fused_multiply_add(int32_t semantics, uint64_t operand,
                                  uint64_t multiplicand, uint64_t addend) {
   const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
@@ -210,12 +210,6 @@ _mlir_apfloat_fused_multiply_add(int32_t semantics, uint64_t operand,
   llvm::APFloat addend_(sem, llvm::APInt(bitWidth, addend));
   llvm::detail::opStatus stat = operand_.fusedMultiplyAdd(
       multiplicand_, addend_, llvm::RoundingMode::NearestTiesToEven);
-
-  ////////////
-  operand_.print(llvm::dbgs());
-  llvm::dbgs() << "\n";
-  ////////////
-
   assert(stat == llvm::APFloatBase::opOK &&
          "expected fusedMultiplyAdd status to be OK");
   return operand_.bitcastToAPInt().getZExtValue();
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
index 892b970a2796d..261308740208b 100644
--- a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
@@ -8,7 +8,7 @@
 // RUN:             --shared-libs=%mlir_c_runner_utils \
 // RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
 
-// Case 2: Only unsupported arithmetics (f8E4M3FN) is lowered through APFloat.
+// Case 2: Only unsupported arithmetics is lowered through APFloat.
 //         Arithmetics on f32 is lowered directly to LLVM.
 // RUN: mlir-opt %s --convert-to-llvm --convert-math-to-apfloat \
 // RUN:          --convert-to-llvm --reconcile-unrealized-casts | \
@@ -35,7 +35,7 @@ func.func @entry() {
 
   // CHECK: 16
   %c4 = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU
-  // vector.print %c4 : f8E8M0FNU
+  vector.print %c4 : f8E8M0FNU
 
   return
 }

>From 6e118f3f8a2c9033db6aebbd4609008676901722 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 12 Dec 2025 15:05:24 -0800
Subject: [PATCH 4/6] address comment

---
 .../ArithAndMathToAPFloat/CMakeLists.txt           |  2 +-
 .../ArithAndMathToAPFloat/MathToAPFloat.cpp        | 14 +++++++++++++-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
index cc8e61a87addc..bad8226ac88ec 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
@@ -4,7 +4,7 @@ add_mlir_library(ArithAndMathToAPFloatUtils
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
-)
+  )
 
 add_mlir_conversion_library(MLIRArithToAPFloat
   ArithToAPFloat.cpp
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 8d9abde951182..9cd5a41daf7d8 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -45,6 +45,10 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
     // Cast operands to 64-bit integers.
     auto operand = op.getOperand();
     auto floatTy = cast<FloatType>(operand.getType());
+    if (floatTy.getIntOrFloatBitWidth() > 64) {
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+    }
     auto intWType = rewriter.getIntegerType(floatTy.getWidth());
     Value operandBits = arith::ExtUIOp::create(
         rewriter, loc, i64Type,
@@ -93,6 +97,10 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
     // Cast operands to 64-bit integers.
     auto operand = op.getOperand();
     auto floatTy = cast<FloatType>(operand.getType());
+    if (floatTy.getIntOrFloatBitWidth() > 64) {
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+    }
     auto intWType = rewriter.getIntegerType(floatTy.getWidth());
     Value operandBits = arith::ExtUIOp::create(
         rewriter, loc, i64Type,
@@ -101,7 +109,7 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
     // Call APFloat function.
     Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
     SmallVector<Value> params = {semValue, operandBits};
-    rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i64Type),
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1),
                                               SymbolRefAttr::get(*fn), params);
     return success();
   }
@@ -130,6 +138,10 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
 
     // Cast operands to 64-bit integers.
     auto floatTy = cast<FloatType>(op.getResult().getType());
+    if (floatTy.getIntOrFloatBitWidth() > 64) {
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+    }
     auto intWType = rewriter.getIntegerType(floatTy.getWidth());
     auto int64Type = rewriter.getI64Type();
     Value operand = arith::ExtUIOp::create(

>From 8cba26c0fd26bb52bad75a2c3c167939779a3848 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 12 Dec 2025 16:38:22 -0800
Subject: [PATCH 5/6] add lit test

---
 .../arith-to-apfloat.mlir                     |  0
 .../math-to-apfloat.mlir                      | 66 ++++++++++++++++++
 .../Math/CPU/test-apfloat-emulation.mlir      | 69 ++++++++++++++-----
 3 files changed, 118 insertions(+), 17 deletions(-)
 rename mlir/test/Conversion/{ArithToApfloat => ArithAndMathToAPFloat}/arith-to-apfloat.mlir (100%)
 create mode 100644 mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir

diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
similarity index 100%
rename from mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
rename to mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir
diff --git a/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir b/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir
new file mode 100644
index 0000000000000..ee5940f86ec01
--- /dev/null
+++ b/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s --convert-math-to-apfloat | FileCheck %s
+
+func.func @full_example() {
+  %neg14fp8 = arith.constant -1.4 : f8E4M3FN  // rounds to -1.375
+  %abs = math.absf %neg14fp8 : f8E4M3FN  // calls _mlir_apfloat_abs
+
+  // See llvm/unittests/ADT/APFloatTest.cpp: TEST(APFloatTest, Float8E8M0FNUFMA)
+  %twof8E8M0FNU = arith.constant 2.0 : f8E8M0FNU
+  %fourf8E8M0FNU = arith.constant 4.0 : f8E8M0FNU
+  %eightf8E8M0FNU = arith.constant 8.0 : f8E8M0FNU
+  %fma = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU
+
+  %isinf = math.isinf %neg14fp8 : f8E4M3FN  // calls _mlir_apfloat_isinfinite
+  %isnan = math.isnan %neg14fp8 : f8E4M3FN
+  %isnormal = math.isnormal %neg14fp8 : f8E4M3FN
+  %isfinite = math.isfinite %neg14fp8 : f8E4M3FN
+
+  return
+}
+
+// CHECK-LABEL:   func.func private @_mlir_apfloat_isfinite(i32, i64) -> i1
+// CHECK:         func.func private @_mlir_apfloat_isnormal(i32, i64) -> i1
+// CHECK:         func.func private @_mlir_apfloat_isnan(i32, i64) -> i1
+// CHECK:         func.func private @_mlir_apfloat_isinfinite(i32, i64) -> i1
+// CHECK:         func.func private @_mlir_apfloat_fused_multiply_add(i32, i64, i64, i64) -> i64
+// CHECK:         func.func private @_mlir_apfloat_abs(i32, i64) -> i64
+
+// CHECK-LABEL:   func.func @full_example() {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant -1.375000e+00 : f8E4M3FN
+// CHECK:           %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
+// CHECK:           %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 10 : i32
+// CHECK:           %[[VAL_0:.*]] = call @_mlir_apfloat_abs(%[[CONSTANT_1]], %[[EXTUI_0]]) : (i32, i64) -> i64
+// CHECK:           %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_0]] : i64 to i8
+// CHECK:           %[[BITCAST_1:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN
+// CHECK:           %[[CONSTANT_2:.*]] = arith.constant 2.000000e+00 : f8E8M0FNU
+// CHECK:           %[[CONSTANT_3:.*]] = arith.constant 4.000000e+00 : f8E8M0FNU
+// CHECK:           %[[CONSTANT_4:.*]] = arith.constant 8.000000e+00 : f8E8M0FNU
+// CHECK:           %[[BITCAST_2:.*]] = arith.bitcast %[[CONSTANT_3]] : f8E8M0FNU to i8
+// CHECK:           %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_2]] : i8 to i64
+// CHECK:           %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f8E8M0FNU to i8
+// CHECK:           %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i8 to i64
+// CHECK:           %[[BITCAST_4:.*]] = arith.bitcast %[[CONSTANT_4]] : f8E8M0FNU to i8
+// CHECK:           %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i8 to i64
+// CHECK:           %[[CONSTANT_5:.*]] = arith.constant 15 : i32
+// CHECK:           %[[VAL_1:.*]] = call @_mlir_apfloat_fused_multiply_add(%[[CONSTANT_5]], %[[EXTUI_1]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64, i64) -> i64
+// CHECK:           %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_1]] : i64 to i8
+// CHECK:           %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i8 to f8E8M0FNU
+// CHECK:           %[[BITCAST_6:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
+// CHECK:           %[[EXTUI_4:.*]] = arith.extui %[[BITCAST_6]] : i8 to i64
+// CHECK:           %[[CONSTANT_6:.*]] = arith.constant 10 : i32
+// CHECK:           %[[VAL_2:.*]] = call @_mlir_apfloat_isinfinite(%[[CONSTANT_6]], %[[EXTUI_4]]) : (i32, i64) -> i1
+// CHECK:           %[[BITCAST_7:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
+// CHECK:           %[[EXTUI_5:.*]] = arith.extui %[[BITCAST_7]] : i8 to i64
+// CHECK:           %[[CONSTANT_7:.*]] = arith.constant 10 : i32
+// CHECK:           %[[VAL_3:.*]] = call @_mlir_apfloat_isnan(%[[CONSTANT_7]], %[[EXTUI_5]]) : (i32, i64) -> i1
+// CHECK:           %[[BITCAST_8:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
+// CHECK:           %[[EXTUI_6:.*]] = arith.extui %[[BITCAST_8]] : i8 to i64
+// CHECK:           %[[CONSTANT_8:.*]] = arith.constant 10 : i32
+// CHECK:           %[[VAL_4:.*]] = call @_mlir_apfloat_isnormal(%[[CONSTANT_8]], %[[EXTUI_6]]) : (i32, i64) -> i1
+// CHECK:           %[[BITCAST_9:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
+// CHECK:           %[[EXTUI_7:.*]] = arith.extui %[[BITCAST_9]] : i8 to i64
+// CHECK:           %[[CONSTANT_9:.*]] = arith.constant 10 : i32
+// CHECK:           %[[VAL_5:.*]] = call @_mlir_apfloat_isfinite(%[[CONSTANT_9]], %[[EXTUI_7]]) : (i32, i64) -> i1
+// CHECK:           return
+// CHECK:         }
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
index 261308740208b..f713743c36b1e 100644
--- a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
@@ -4,38 +4,73 @@
 
 // Case 1: All floating-point arithmetics is lowered through APFloat.
 // RUN: mlir-opt %s --convert-math-to-apfloat --convert-to-llvm | \
-// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: mlir-runner -e entryfp8 --entry-point-result=void \
 // RUN:             --shared-libs=%mlir_c_runner_utils \
-// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s --check-prefix=CHECK-FP8
 
 // Case 2: Only unsupported arithmetics is lowered through APFloat.
 //         Arithmetics on f32 is lowered directly to LLVM.
 // RUN: mlir-opt %s --convert-to-llvm --convert-math-to-apfloat \
 // RUN:          --convert-to-llvm --reconcile-unrealized-casts | \
-// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: mlir-runner -e entryfp32 --entry-point-result=void \
 // RUN:             --shared-libs=%mlir_c_runner_utils \
-// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s --check-prefix=CHECK-FP32
 
-func.func @entry() {
+func.func @entryfp8() {
   %neg14fp8 = arith.constant -1.4 : f8E4M3FN
-  %neg14fp32 = arith.constant 1.4 : f32
-
-  // CHECK: 1.375
-  %c2 = math.absf %neg14fp8 : f8E4M3FN
-  vector.print %c2 : f8E4M3FN
-
-  // CHECK: 1.4
-  %c3 = math.absf %neg14fp32 : f32
-  vector.print %c3 : f32
+  %abs = math.absf %neg14fp8 : f8E4M3FN
+  // CHECK-FP8: 1.375
+  vector.print %abs : f8E4M3FN
 
   // see llvm/unittests/ADT/APFloatTest::TEST(APFloatTest, Float8E8M0FNUFMA)
   %twof8E8M0FNU = arith.constant 2.0 : f8E8M0FNU
   %fourf8E8M0FNU = arith.constant 4.0 : f8E8M0FNU
   %eightf8E8M0FNU = arith.constant 8.0 : f8E8M0FNU
+  %fma = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU
+  // CHECK-FP8: 16
+  vector.print %fma : f8E8M0FNU
+
+  // CHECK-FP8: 0
+  %isinf = math.isinf %neg14fp8 : f8E4M3FN
+  vector.print %isinf : i1
+  // CHECK-FP8: 0
+  %isnan = math.isnan %neg14fp8 : f8E4M3FN
+  vector.print %isnan : i1
+  %isnormal = math.isnormal %neg14fp8 : f8E4M3FN
+  // CHECK-FP8: 1
+  vector.print %isnormal : i1
+  %isfinite = math.isfinite %neg14fp8 : f8E4M3FN
+  // CHECK-FP8: 1
+  vector.print %isfinite : i1
+
+  return
+}
+
+func.func @entryfp32() {
+  %neg14 = arith.constant -1.4 : f32
+  %abs = math.absf %neg14 : f32
+  // CHECK-FP32: 1.4
+  vector.print %abs : f32
+
+  %two = arith.constant 2.0 : f32
+  %four = arith.constant 4.0 : f32
+  %eight = arith.constant 8.0 : f32
+  %fma = math.fma %four, %two, %eight : f32
+  // CHECK-FP32: 16
+  vector.print %fma : f32
 
-  // CHECK: 16
-  %c4 = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU
-  vector.print %c4 : f8E8M0FNU
+  // CHECK-FP32: 0
+  %isinf = math.isinf %neg14 : f32
+  vector.print %isinf : i1
+  // CHECK-FP32: 0
+  %isnan = math.isnan %neg14 : f32
+  vector.print %isnan : i1
+  %isnormal = math.isnormal %neg14 : f32
+  // CHECK-FP32: 1
+  vector.print %isnormal : i1
+  %isfinite = math.isfinite %neg14 : f32
+  // CHECK-FP32: 1
+  vector.print %isfinite : i1
 
   return
 }

>From 8e33e3101fe87f655acce93df15062c3cbcfeff1 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 10:22:22 -0800
Subject: [PATCH 6/6] address comments

---
 .../ArithAndMathToAPFloat/MathToAPFloat.cpp   |  49 ++++----
 .../math-to-apfloat.mlir                      |   2 +-
 .../Math/CPU/test-apfloat-emulation.mlir      | 108 ++++++++----------
 3 files changed, 80 insertions(+), 79 deletions(-)

diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 9cd5a41daf7d8..20d82863c518e 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -33,6 +33,16 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {

   LogicalResult matchAndRewrite(math::AbsFOp op,
                                 PatternRewriter &rewriter) const override {
+    // Only scalar float types of at most 64 bits are supported.
+    auto operand = op.getOperand();
+    auto floatTy = dyn_cast<FloatType>(operand.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op,
+                                         "only scalar FloatTypes supported");
+    if (floatTy.getIntOrFloatBitWidth() > 64) {
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+    }
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -42,13 +52,6 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
       return fn;
     Location loc = op.getLoc();
     rewriter.setInsertionPoint(op);
-    // Cast operands to 64-bit integers.
-    auto operand = op.getOperand();
-    auto floatTy = cast<FloatType>(operand.getType());
-    if (floatTy.getIntOrFloatBitWidth() > 64) {
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-    }
     auto intWType = rewriter.getIntegerType(floatTy.getWidth());
     Value operandBits = arith::ExtUIOp::create(
         rewriter, loc, i64Type,
@@ -82,6 +85,16 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {

   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
+    // Only scalar float types of at most 64 bits are supported.
+    auto operand = op.getOperand();
+    auto floatTy = dyn_cast<FloatType>(operand.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op,
+                                         "only scalar FloatTypes supported");
+    if (floatTy.getIntOrFloatBitWidth() > 64) {
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+    }
     // Get APFloat function from runtime library.
     auto i1 = IntegerType::get(symTable->getContext(), 1);
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -94,13 +107,6 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
       return fn;
     Location loc = op.getLoc();
     rewriter.setInsertionPoint(op);
-    // Cast operands to 64-bit integers.
-    auto operand = op.getOperand();
-    auto floatTy = cast<FloatType>(operand.getType());
-    if (floatTy.getIntOrFloatBitWidth() > 64) {
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-    }
     auto intWType = rewriter.getIntegerType(floatTy.getWidth());
     Value operandBits = arith::ExtUIOp::create(
         rewriter, loc, i64Type,
@@ -125,6 +131,15 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {

   LogicalResult matchAndRewrite(math::FmaOp op,
                                 PatternRewriter &rewriter) const override {
+    // Only scalar float types of at most 64 bits are supported.
+    auto floatTy = dyn_cast<FloatType>(op.getResult().getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op,
+                                         "only scalar FloatTypes supported");
+    if (floatTy.getIntOrFloatBitWidth() > 64) {
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+    }

     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -136,12 +151,6 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
     Location loc = op.getLoc();
     rewriter.setInsertionPoint(op);
 
-    // Cast operands to 64-bit integers.
-    auto floatTy = cast<FloatType>(op.getResult().getType());
-    if (floatTy.getIntOrFloatBitWidth() > 64) {
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-    }
     auto intWType = rewriter.getIntegerType(floatTy.getWidth());
     auto int64Type = rewriter.getI64Type();
     Value operand = arith::ExtUIOp::create(
diff --git a/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir b/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir
index ee5940f86ec01..e52eb5866093c 100644
--- a/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir
@@ -63,4 +63,4 @@ func.func @full_example() {
 // CHECK:           %[[CONSTANT_9:.*]] = arith.constant 10 : i32
 // CHECK:           %[[VAL_5:.*]] = call @_mlir_apfloat_isfinite(%[[CONSTANT_9]], %[[EXTUI_7]]) : (i32, i64) -> i1
 // CHECK:           return
-// CHECK:         }
\ No newline at end of file
+// CHECK:         }
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
index f713743c36b1e..c890b470b563a 100644
--- a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
@@ -1,76 +1,68 @@
 // REQUIRES: system-linux || system-darwin
-// TODO: Run only on Linux until we figure out how to build
+// TODO: Run only on Linux and macOS until we figure out how to build
 // mlir_apfloat_wrappers in a platform-independent way.
 
-// Case 1: All floating-point arithmetics is lowered through APFloat.
-// RUN: mlir-opt %s --convert-math-to-apfloat --convert-to-llvm | \
-// RUN: mlir-runner -e entryfp8 --entry-point-result=void \
+// RUN: mlir-opt %s --convert-math-to-apfloat --convert-to-llvm  | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
 // RUN:             --shared-libs=%mlir_c_runner_utils \
-// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s --check-prefix=CHECK-FP8
+// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
 
-// Case 2: Only unsupported arithmetics is lowered through APFloat.
-//         Arithmetics on f32 is lowered directly to LLVM.
-// RUN: mlir-opt %s --convert-to-llvm --convert-math-to-apfloat \
-// RUN:          --convert-to-llvm --reconcile-unrealized-casts | \
-// RUN: mlir-runner -e entryfp32 --entry-point-result=void \
-// RUN:             --shared-libs=%mlir_c_runner_utils \
-// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s --check-prefix=CHECK-FP32
+func.func @entry() {
+
+  // FP8
 
-func.func @entryfp8() {
   %neg14fp8 = arith.constant -1.4 : f8E4M3FN
-  %abs = math.absf %neg14fp8 : f8E4M3FN
-  // CHECK-FP8: 1.375
-  vector.print %abs : f8E4M3FN
+  %absfp8 = math.absf %neg14fp8 : f8E4M3FN
+  // CHECK: 1.375
+  vector.print %absfp8 : f8E4M3FN
 
   // see llvm/unittests/ADT/APFloatTest::TEST(APFloatTest, Float8E8M0FNUFMA)
   %twof8E8M0FNU = arith.constant 2.0 : f8E8M0FNU
   %fourf8E8M0FNU = arith.constant 4.0 : f8E8M0FNU
   %eightf8E8M0FNU = arith.constant 8.0 : f8E8M0FNU
-  %fma = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU
-  // CHECK-FP8: 16
-  vector.print %fma : f8E8M0FNU
-
-  // CHECK-FP8: 0
-  %isinf = math.isinf %neg14fp8 : f8E4M3FN
-  vector.print %isinf : i1
-  // CHECK-FP8: 0
-  %isnan = math.isnan %neg14fp8 : f8E4M3FN
-  vector.print %isnan : i1
-  %isnormal = math.isnormal %neg14fp8 : f8E4M3FN
-  // CHECK-FP8: 1
-  vector.print %isnormal : i1
-  %isfinite = math.isfinite %neg14fp8 : f8E4M3FN
-  // CHECK-FP8: 1
-  vector.print %isfinite : i1
-
-  return
-}
+  %fmafp8 = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU
+  // CHECK: 16
+  vector.print %fmafp8 : f8E8M0FNU
 
-func.func @entryfp32() {
-  %neg14 = arith.constant -1.4 : f32
-  %abs = math.absf %neg14 : f32
-  // CHECK-FP32: 1.4
-  vector.print %abs : f32
+  // CHECK: 0
+  %isinffp8 = math.isinf %neg14fp8 : f8E4M3FN
+  vector.print %isinffp8 : i1
+  // CHECK: 0
+  %isnanfp8 = math.isnan %neg14fp8 : f8E4M3FN
+  vector.print %isnanfp8 : i1
+  %isnormalfp8 = math.isnormal %neg14fp8 : f8E4M3FN
+  // CHECK: 1
+  vector.print %isnormalfp8 : i1
+  %isfinitefp8 = math.isfinite %neg14fp8 : f8E4M3FN
+  // CHECK: 1
+  vector.print %isfinitefp8 : i1
+  
+  // FP32
+  
+  %neg14fp32 = arith.constant -1.4 : f32
+  %absfp32 = math.absf %neg14fp32 : f32
+  // CHECK: 1.4
+  vector.print %absfp32 : f32
 
-  %two = arith.constant 2.0 : f32
-  %four = arith.constant 4.0 : f32
-  %eight = arith.constant 8.0 : f32
-  %fma = math.fma %four, %two, %eight : f32
-  // CHECK-FP32: 16
-  vector.print %fma : f32
+  %twofp32 = arith.constant 2.0 : f32
+  %fourfp32 = arith.constant 4.0 : f32
+  %eightfp32 = arith.constant 8.0 : f32
+  %fmafp32 = math.fma %fourfp32, %twofp32, %eightfp32 : f32
+  // CHECK: 16
+  vector.print %fmafp32 : f32
 
-  // CHECK-FP32: 0
-  %isinf = math.isinf %neg14 : f32
-  vector.print %isinf : i1
-  // CHECK-FP32: 0
-  %isnan = math.isnan %neg14 : f32
-  vector.print %isnan : i1
-  %isnormal = math.isnormal %neg14 : f32
-  // CHECK-FP32: 1
-  vector.print %isnormal : i1
-  %isfinite = math.isfinite %neg14 : f32
-  // CHECK-FP32: 1
-  vector.print %isfinite : i1
+  // CHECK: 0
+  %isinffp32 = math.isinf %neg14fp32 : f32
+  vector.print %isinffp32 : i1
+  // CHECK: 0
+  %isnanfp32 = math.isnan %neg14fp32 : f32
+  vector.print %isnanfp32 : i1
+  %isnormalfp32 = math.isnormal %neg14fp32 : f32
+  // CHECK: 1
+  vector.print %isnormalfp32 : i1
+  %isfinitefp32 = math.isfinite %neg14fp32 : f32
+  // CHECK: 1
+  vector.print %isfinitefp32 : i1
 
   return
 }



More information about the Mlir-commits mailing list