[Mlir-commits] [mlir] [mlir][math] Add FP software implementation lowering pass: math-to-apfloat (PR #171221)

Maksim Levental llvmlistbot at llvm.org
Mon Dec 8 15:53:27 PST 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/171221

>From 1e43eb136b5ffcc0ef86878ed03212d4f4ca151b Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 8 Dec 2025 14:48:11 -0800
Subject: [PATCH] [mlir][math] Add FP software implementation lowering pass:
 math-to-apfloat

---
 .../Conversion/MathToAPFloat/MathToAPFloat.h  | 21 ++++++
 mlir/include/mlir/Conversion/Passes.h         |  1 +
 mlir/include/mlir/Conversion/Passes.td        | 15 +++++
 mlir/include/mlir/Dialect/Func/Utils/Utils.h  | 16 +++++
 .../ArithToAPFloat/ArithToAPFloat.cpp         | 67 ++++---------------
 mlir/lib/Conversion/CMakeLists.txt            |  1 +
 .../Conversion/MathToAPFloat/CMakeLists.txt   | 17 +++++
 .../MathToAPFloat/MathToAPFloat.cpp           | 52 ++++++++++++++
 mlir/lib/Dialect/Func/Utils/Utils.cpp         | 38 +++++++++++
 9 files changed, 175 insertions(+), 53 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
 create mode 100644 mlir/lib/Conversion/MathToAPFloat/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MathToAPFloat/MathToAPFloat.cpp

diff --git a/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..86179a1611d5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..05ec2f8ce2538 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -44,6 +44,7 @@
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
 #include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+    : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+  let summary = "Convert Math ops to APFloat runtime library calls";
+  let description = [{
+    This pass converts supported Math ops to APFloat-based runtime library
+    calls (APFloatWrappers.cpp). APFloat is a software implementation of
+    floating-point arithmetic operations.
+  }];
+  let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..c04f61c869c49 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -67,6 +67,26 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
                                FunctionType funcT,
                                SymbolTableCollection *symbolTables = nullptr);
 
+/// Create a `func.func` declaration named `name` with type `funcT` at the
+/// start of the first block of `symTable`'s first region, which must be
+/// non-empty. If `setPrivate` is true, the declaration is made private. If
+/// `symbolTables` is provided, the new symbol is also inserted into the
+/// SymbolTable that `symbolTables` holds for `symTable`.
+FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                    FunctionType funcT, bool setPrivate,
+                    SymbolTableCollection *symbolTables = nullptr);
+
+/// Look up the function `name` in `symTable`, or create a private declaration
+/// for it if no such symbol exists. The function takes `paramTypes` and
+/// returns a single `resultType`, which defaults to i64 when null. Returns
+/// failure if a symbol named `name` already exists with a different function
+/// type.
+FailureOr<FuncOp>
+lookupOrCreateFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                 TypeRange paramTypes,
+                 SymbolTableCollection *symbolTables = nullptr,
+                 Type resultType = {});
+
 } // namespace func
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 79816fc6e3bf1..103a9529eab44 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -25,47 +25,6 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::func;
 
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
-                           StringRef name, FunctionType funcT, bool setPrivate,
-                           SymbolTableCollection *symbolTables = nullptr) {
-  OpBuilder::InsertionGuard g(b);
-  assert(!symTable->getRegion(0).empty() && "expected non-empty region");
-  b.setInsertionPointToStart(&symTable->getRegion(0).front());
-  FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
-  if (setPrivate)
-    funcOp.setPrivate();
-  if (symbolTables) {
-    SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
-    symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
-  }
-  return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Returns an int64_t, unless a
-/// different result type is specified.
-static FailureOr<FuncOp>
-lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
-                        StringRef name, TypeRange paramTypes,
-                        SymbolTableCollection *symbolTables = nullptr,
-                        Type resultType = {}) {
-  if (!resultType)
-    resultType = IntegerType::get(symTable->getContext(), 64);
-  std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
-  auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
-  FailureOr<FuncOp> func =
-      lookupFnDecl(symTable, funcName, funcT, symbolTables);
-  // Failed due to type mismatch.
-  if (failed(func))
-    return func;
-  // Successfully matched existing decl.
-  if (*func)
-    return *func;
-
-  return createFnDecl(b, symTable, funcName, funcT,
-                      /*setPrivate=*/true, symbolTables);
-}
-
 /// Helper function to look up or create the symbol for a runtime library
 /// function for a binary arithmetic operation.
 ///
@@ -81,8 +40,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
                        SymbolTableCollection *symbolTables = nullptr) {
   auto i32Type = IntegerType::get(symTable->getContext(), 32);
   auto i64Type = IntegerType::get(symTable->getContext(), 64);
-  return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
-                                 symbolTables);
+  std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+  return lookupOrCreateFn(b, symTable, funcName, {i32Type, i64Type, i64Type},
+                          symbolTables);
 }
 
 static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
@@ -231,8 +191,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
-        rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+    FailureOr<FuncOp> fn =
+        lookupOrCreateFn(rewriter, symTable, "_mlir_apfloat_convert",
+                         {i32Type, i32Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -289,8 +250,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
     FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
-                                {i32Type, i32Type, i1Type, i64Type});
+        lookupOrCreateFn(rewriter, symTable, "_mlir_apfloat_convert_to_int",
+                         {i32Type, i32Type, i1Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -351,8 +312,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
     FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
-                                {i32Type, i32Type, i1Type, i64Type});
+        lookupOrCreateFn(rewriter, symTable, "_mlir_apfloat_convert_from_int",
+                         {i32Type, i32Type, i1Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -421,8 +382,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
     FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "compare",
-                                {i32Type, i64Type, i64Type}, nullptr, i8Type);
+        lookupOrCreateFn(rewriter, symTable, "_mlir_apfloat_compare",
+                         {i32Type, i64Type, i64Type}, nullptr, i8Type);
     if (failed(fn))
       return fn;
 
@@ -569,8 +530,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+    FailureOr<FuncOp> fn = lookupOrCreateFn(
+        rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
     if (failed(fn))
       return fn;
 
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 613dc6d242ceb..3c59fbda6810a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -35,6 +35,7 @@ add_subdirectory(IndexToLLVM)
 add_subdirectory(IndexToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
+add_subdirectory(MathToAPFloat)
 add_subdirectory(MathToEmitC)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
diff --git a/mlir/lib/Conversion/MathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/MathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..454b71b1ef160
--- /dev/null
+++ b/mlir/lib/Conversion/MathToAPFloat/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMathToAPFloat
+  MathToAPFloat.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRMathDialect
+  MLIRFuncDialect
+  MLIRFuncUtils
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/MathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/MathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..35a4f94e4be24
--- /dev/null
+++ b/mlir/lib/Conversion/MathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,57 @@
+//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+namespace {
+/// Driver for the `convert-math-to-apfloat` pass (anchored on ModuleOp).
+struct MathToAPFloatConversionPass final
+    : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+
+void MathToAPFloatConversionPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  // TODO: populate `patterns` with the math-op lowerings. No patterns are
+  // registered yet, so no math ops are rewritten by this pass.
+  RewritePatternSet patterns(context);
+  // Record whether any error diagnostic is emitted while the patterns are
+  // applied, so that the pass can be marked as failed afterwards.
+  LogicalResult result = success();
+  ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
+    if (diag.getSeverity() == DiagnosticSeverity::Error) {
+      result = failure();
+    }
+    // NB: if you don't return failure, no other diag handlers will fire (see
+    // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
+    return failure();
+  });
+  walkAndApplyPatterns(getOperation(), std::move(patterns));
+  if (failed(result))
+    return signalPassFailure();
+}
+} // namespace
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index d6dfd0229963c..d7dc89bb8c050 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -279,3 +279,44 @@ func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
   }
   return func;
 }
+
+func::FuncOp func::createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+                                StringRef name, FunctionType funcT,
+                                bool setPrivate,
+                                SymbolTableCollection *symbolTables) {
+  OpBuilder::InsertionGuard g(b);
+  assert(!symTable->getRegion(0).empty() && "expected non-empty region");
+  b.setInsertionPointToStart(&symTable->getRegion(0).front());
+  func::FuncOp funcOp =
+      func::FuncOp::create(b, symTable->getLoc(), name, funcT);
+  if (setPrivate)
+    funcOp.setPrivate();
+  // Keep the caller's cached symbol table (if any) in sync with the new decl.
+  if (symbolTables) {
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
+    symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
+  }
+  return funcOp;
+}
+
+FailureOr<func::FuncOp>
+func::lookupOrCreateFn(OpBuilder &b, SymbolOpInterface symTable,
+                       StringRef name, TypeRange paramTypes,
+                       SymbolTableCollection *symbolTables, Type resultType) {
+  // Default to a single i64 result when no result type is given.
+  if (!resultType)
+    resultType = IntegerType::get(symTable->getContext(), 64);
+  auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
+  FailureOr<func::FuncOp> fnDecl =
+      lookupFnDecl(symTable, name, funcT, symbolTables);
+  // Failed due to type mismatch.
+  if (failed(fnDecl))
+    return fnDecl;
+  // Successfully matched existing decl.
+  if (*fnDecl)
+    return *fnDecl;
+
+  // No existing symbol with this name: create a private declaration.
+  return createFnDecl(b, symTable, name, funcT,
+                      /*setPrivate=*/true, symbolTables);
+}



More information about the Mlir-commits mailing list