[Mlir-commits] [mlir] Prototype: APFloat runtime library for unsupported CPU floating-point types (PR #166484)

Matthias Springer llvmlistbot at llvm.org
Tue Nov 4 22:34:40 PST 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/166484

>From 350270bb006c6764b2b7ee5d53c408a67aa54c2d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 4 Nov 2025 23:45:35 +0000
Subject: [PATCH] Prototype: APFloat CPU runner

---
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   |  7 +++
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 45 ++++++++++++++++++-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 14 ++++++
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 23 ++++++++++
 mlir/lib/ExecutionEngine/APFloatWrappers.cpp  | 40 +++++++++++++++++
 mlir/lib/ExecutionEngine/CMakeLists.txt       | 12 +++++
 .../Arith/CPU/test-apfloat-emulation.mlir     | 19 ++++++++
 7 files changed, 159 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/ExecutionEngine/APFloatWrappers.cpp
 create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 8ad9ed18acebd..8564d0f4205cf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -52,6 +52,13 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
                          SymbolTableCollection *symbolTables = nullptr);
+/// Look up or declare the runtime functions `printApFloat` / `APFloat_add`.
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+                             SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
+                            SymbolTableCollection *symbolTables = nullptr);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..632e1a7f02602 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -572,6 +573,47 @@ void mlir::arith::registerConvertArithToLLVMInterface(
   });
 }
 
+/// Lowers scalar `arith.addf` on floating-point types without native LLVM
+/// support (e.g. f4E2M1FN) to a call to the APFloat runtime function
+/// `APFloat_add`. Natively supported types and vectors are rejected here.
+struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto floatTy = dyn_cast<FloatType>(op.getType());
+    if (!floatTy || LLVM::isCompatibleFloatingPointType(floatTy))
+      return rewriter.notifyMatchFailure(op, "not an emulated scalar FP type");
+    // The runtime ABI carries operands and results as i64 bit patterns.
+    if (floatTy.getWidth() > 64)
+      return rewriter.notifyMatchFailure(op, "FP type wider than 64 bits");
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    if (!moduleOp)
+      return rewriter.notifyMatchFailure(op, "op is not inside a module");
+    FailureOr<LLVM::LLVMFuncOp> adder =
+        LLVM::lookupOrCreateApFloatAddFFn(rewriter, moduleOp);
+    if (failed(adder))
+      return failure();
+
+    // Zero-extend the (integer-typed) operand bit patterns to i64.
+    Location loc = op.getLoc();
+    Type i64 = rewriter.getI64Type();
+    Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, i64, adaptor.getLhs());
+    Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, i64, adaptor.getRhs());
+    int32_t sem =
+        llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+    Value semValue = LLVM::ConstantOp::create(
+        rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(sem));
+    auto callOp = LLVM::CallOp::create(rewriter, loc, *adder,
+                                       ValueRange{semValue, lhsBits, rhsBits});
+    // Truncate the i64 result back to the bit width of the original type.
+    rewriter.replaceOpWithNewOp<LLVM::TruncOp>(
+        op, rewriter.getIntegerType(floatTy.getWidth()), callOp->getResult(0));
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pattern Population
 //===----------------------------------------------------------------------===//
@@ -586,7 +628,8 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
 
   // clang-format off
   patterns.add<
-    AddFOpLowering,
+    //AddFOpLowering,
+    FancyAddFLowering,
     AddIOpLowering,
     AndIOpLowering,
     AddUIExtendedOpLowering,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 69a317ecd101f..260c028ffd9c5 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
           return failure();
         }
       }
+    } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
+      // Print other floating-point types using the APFloat runtime library.
+      int32_t sem =
+          llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+      Value semValue = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32Type(),
+          rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+      Value floatBits =
+          rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), value);
+      printer =
+          LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
+      emitCall(rewriter, loc, printer.value(),
+               ValueRange({semValue, floatBits}));
+      return success();
     } else {
       return failure();
     }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa34897b6..8ee039be60568 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,8 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
 static constexpr llvm::StringRef kPrintBF16 = "printBF16";
 static constexpr llvm::StringRef kPrintF32 = "printF32";
 static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
+static constexpr llvm::StringRef kApFloatAddF = "APFloat_add";
 static constexpr llvm::StringRef kPrintString = "printString";
 static constexpr llvm::StringRef kPrintOpen = "printOpen";
 static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +162,27 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
       LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+                                         SymbolTableCollection *symbolTables) {
+  // Signature: void printApFloat(i32 semantics, i64 bits).
+  MLIRContext *ctx = moduleOp->getContext();
+  Type argTypes[] = {IntegerType::get(ctx, 32), IntegerType::get(ctx, 64)};
+  return lookupOrCreateReservedFn(b, moduleOp, kPrintApFloat, argTypes,
+                                  LLVM::LLVMVoidType::get(ctx), symbolTables);
+}
+
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
+                                        SymbolTableCollection *symbolTables) {
+  // Signature: i64 APFloat_add(i32 semantics, i64 lhsBits, i64 rhsBits).
+  MLIRContext *ctx = moduleOp->getContext();
+  Type i64 = IntegerType::get(ctx, 64);
+  Type argTypes[] = {IntegerType::get(ctx, 32), i64, i64};
+  return lookupOrCreateReservedFn(b, moduleOp, kApFloatAddF, argTypes, i64,
+                                  symbolTables);
+}
+
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
   return LLVM::LLVMPointerType::get(context);
 }
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
new file mode 100644
index 0000000000000..7879c75803355
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -0,0 +1,40 @@
+//===- APFloatWrappers.cpp - C wrappers around llvm::APFloat --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APFloat.h"
+#include <iostream>
+
+#if (defined(_WIN32) || defined(__CYGWIN__))
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
+#else
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
+#endif
+
+/// Rebuilds an APFloat from its semantics enum value and zero-extended bits.
+static llvm::APFloat toAPFloat(int32_t semantics, uint64_t bits) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  return llvm::APFloat(sem, llvm::APInt(bitWidth, bits));
+}
+
+extern "C" {
+/// Returns the bits of `a + b`, rounded to nearest-even; status is dropped.
+int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_add(int32_t semantics,
+                                                   uint64_t a, uint64_t b) {
+  llvm::APFloat lhs = toAPFloat(semantics, a);
+  (void)lhs.add(toAPFloat(semantics, b), llvm::RoundingMode::NearestTiesToEven);
+  return lhs.bitcastToAPInt().getZExtValue();
+}
+
+/// Prints the value converted to double; punctuation is left to the caller.
+void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
+                                                 uint64_t a) {
+  std::cout << toAPFloat(semantics, a).convertToDouble();
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdeb4dacf9278..8c09e50e4de7b 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
 # is a big dependency which most don't need.
 
 set(LLVM_OPTIONAL_SOURCES
+  APFloatWrappers.cpp
   ArmRunnerUtils.cpp
   ArmSMEStubs.cpp
   AsyncRuntime.cpp
@@ -167,6 +168,15 @@ if(LLVM_ENABLE_PIC)
   set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
   target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
 
+  add_mlir_library(mlir_apfloat_wrappers
+    SHARED
+    APFloatWrappers.cpp
+    EXCLUDE_FROM_LIBMLIR
+    # llvm::APFloat is provided by LLVM Support, so link it explicitly.
+    LINK_COMPONENTS Support)
+  set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17)
+  target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS)
+
   add_subdirectory(SparseTensor)
 
   add_mlir_library(mlir_c_runner_utils
@@ -177,6 +187,7 @@ if(LLVM_ENABLE_PIC)
     EXCLUDE_FROM_LIBMLIR
 
     LINK_LIBS PUBLIC
+    mlir_apfloat_wrappers
     mlir_float16_utils
     MLIRSparseTensorEnums
     MLIRSparseTensorRuntime
@@ -191,6 +202,7 @@ if(LLVM_ENABLE_PIC)
     EXCLUDE_FROM_LIBMLIR
 
     LINK_LIBS PUBLIC
+    mlir_apfloat_wrappers
     mlir_float16_utils
   )
   target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS)
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
new file mode 100644
index 0000000000000..5cd83688d1710
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -0,0 +1,19 @@
+// Check that arith.addf on f4E2M1FN is emulated via the APFloat runtime.
+// RUN: mlir-opt %s --convert-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:     --shared-libs=%mlir_c_runner_utils | FileCheck %s
+
+// Put rhs into a separate function so that it won't be constant-folded.
+func.func @foo() -> f4E2M1FN {
+  %cst = arith.constant 1.5 : f4E2M1FN
+  return %cst : f4E2M1FN
+}
+
+func.func @entry() {
+  %a = arith.constant 1.5 : f4E2M1FN
+  %b = func.call @foo() : () -> (f4E2M1FN)
+  %c = arith.addf %a, %b : f4E2M1FN
+  // CHECK: 3
+  vector.print %c : f4E2M1FN
+  return
+}



More information about the Mlir-commits mailing list