[Mlir-commits] [mlir] Prototype: APFloat runtime library for unsupported CPU floating-point types (PR #166484)
Matthias Springer
llvmlistbot at llvm.org
Tue Nov 4 22:34:40 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/166484
>From 350270bb006c6764b2b7ee5d53c408a67aa54c2d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 4 Nov 2025 23:45:35 +0000
Subject: [PATCH] Prototype: APFloat CPU runner
---
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 7 +++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 45 ++++++++++++++++++-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 14 ++++++
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 23 ++++++++++
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 40 +++++++++++++++++
mlir/lib/ExecutionEngine/CMakeLists.txt | 12 +++++
.../Arith/CPU/test-apfloat-emulation.mlir | 19 ++++++++
7 files changed, 159 insertions(+), 1 deletion(-)
create mode 100644 mlir/lib/ExecutionEngine/APFloatWrappers.cpp
create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 8ad9ed18acebd..8564d0f4205cf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -52,6 +52,13 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
+/// Declares (or looks up) the APFloat runtime function `printApFloat`, which
+/// takes the float semantics as an i32 and the value's bits as an i64 and
+/// returns nothing.
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+                             SymbolTableCollection *symbolTables = nullptr);
+/// Declares (or looks up) the APFloat runtime function `APFloat_add`, which
+/// takes the float semantics as an i32 and both operands' bits as i64 values
+/// and returns the result as an i64.
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
+                            SymbolTableCollection *symbolTables = nullptr);
+
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..632e1a7f02602 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
@@ -572,6 +573,47 @@ void mlir::arith::registerConvertArithToLLVMInterface(
});
}
+/// Prototype lowering of `arith.addf` to a call into the APFloat runtime
+/// library, so that the addition is emulated in software.
+///
+/// Both operands are zero-extended to i64 and passed, together with the enum
+/// value of the float semantics, to the runtime adder. The i64 result is then
+/// truncated back to the bit width of the original float type.
+struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only scalar floats can be emulated; a plain `cast<FloatType>` would
+    // assert on vector or tensor results.
+    auto floatTy = dyn_cast<FloatType>(op.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "expected scalar float type");
+
+    // The zext-to-i64 / trunc-from-i64 round trip below is only well-formed
+    // when the converted operands are integers narrower than 64 bits.
+    // NOTE(review): this presumes the type converter maps the float type to
+    // an integer of the same width -- confirm.
+    auto isNarrowInteger = [](Value v) {
+      auto intTy = dyn_cast<IntegerType>(v.getType());
+      return intTy && intTy.getWidth() < 64;
+    };
+    if (!isNarrowInteger(adaptor.getLhs()) ||
+        !isNarrowInteger(adaptor.getRhs()))
+      return rewriter.notifyMatchFailure(
+          op, "expected operands converted to integers narrower than 64 bits");
+
+    // Get APFloat adder function from runtime library.
+    auto parent = op->getParentOfType<ModuleOp>();
+    if (!parent)
+      return failure();
+    FailureOr<Operation *> adder =
+        LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
+    // The lookup can fail (e.g. on a conflicting symbol); never dereference
+    // the result unchecked.
+    if (failed(adder))
+      return failure();
+
+    // Cast operands to 64-bit integers.
+    Location loc = op.getLoc();
+    Type i64Ty = rewriter.getI64Type();
+    Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, i64Ty, adaptor.getLhs());
+    Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, i64Ty, adaptor.getRhs());
+
+    // Call software implementation of floating point addition.
+    int32_t sem =
+        llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+    Type i32Ty = rewriter.getI32Type();
+    Value semValue = LLVM::ConstantOp::create(
+        rewriter, loc, i32Ty, rewriter.getIntegerAttr(i32Ty, sem));
+    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+    auto resultOp = LLVM::CallOp::create(rewriter, loc, TypeRange(i64Ty),
+                                         SymbolRefAttr::get(*adder), params);
+
+    // Truncate result to the original width.
+    Value truncatedBits = LLVM::TruncOp::create(
+        rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()),
+        resultOp->getResult(0));
+    rewriter.replaceOp(op, truncatedBits);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
@@ -586,7 +628,8 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
// clang-format off
patterns.add<
- AddFOpLowering,
+ //AddFOpLowering,
+ FancyAddFLowering,
AddIOpLowering,
AndIOpLowering,
AddUIExtendedOpLowering,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 69a317ecd101f..260c028ffd9c5 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}
}
+ } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
+ // Print other floating-point types using the APFloat runtime library.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ Value floatBits =
+ rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), value);
+ printer =
+ LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
+ emitCall(rewriter, loc, printer.value(),
+ ValueRange({semValue, floatBits}));
+ return success();
} else {
return failure();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa34897b6..8ee039be60568 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,8 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
+static constexpr llvm::StringRef kApFloatAddF = "APFloat_add";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +162,27 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
+/// Looks up or declares the reserved runtime function `printApFloat` with
+/// signature `void(i32, i64)`.
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+                                         SymbolTableCollection *symbolTables) {
+  MLIRContext *ctx = moduleOp->getContext();
+  IntegerType semanticsTy = IntegerType::get(ctx, 32);
+  IntegerType bitsTy = IntegerType::get(ctx, 64);
+  return lookupOrCreateReservedFn(b, moduleOp, kPrintApFloat,
+                                  {semanticsTy, bitsTy},
+                                  LLVM::LLVMVoidType::get(ctx), symbolTables);
+}
+
+/// Looks up or declares the reserved runtime function `APFloat_add` with
+/// signature `i64(i32, i64, i64)`.
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
+                                        SymbolTableCollection *symbolTables) {
+  MLIRContext *ctx = moduleOp->getContext();
+  IntegerType semanticsTy = IntegerType::get(ctx, 32);
+  IntegerType bitsTy = IntegerType::get(ctx, 64);
+  return lookupOrCreateReservedFn(b, moduleOp, kApFloatAddF,
+                                  {semanticsTy, bitsTy, bitsTy}, bitsTy,
+                                  symbolTables);
+}
+
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
new file mode 100644
index 0000000000000..7879c75803355
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -0,0 +1,40 @@
+//===- APFloatWrappers.cpp - APFloat runtime wrappers ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APFloat.h"
+#include <iostream>
+
+#if (defined(_WIN32) || defined(__CYGWIN__))
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
+#else
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
+#endif
+
+extern "C" {
+
+/// Adds two floating-point values in software using `llvm::APFloat`.
+///
+/// `semantics` is an `llvm::APFloatBase::Semantics` enum value selecting the
+/// float format. `a` and `b` hold the raw bit patterns of the operands. The
+/// return value is the raw bit pattern of the sum, rounded to nearest with
+/// ties to even.
+/// NOTE(review): `getZExtValue` suggests only formats of at most 64 bits are
+/// supported -- confirm callers never pass wider semantics.
+int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_add(int32_t semantics,
+                                                   uint64_t a, uint64_t b) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a));
+  llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b));
+  // The status flags returned by `add` are deliberately discarded: this
+  // function only returns the result bits.
+  (void)lhs.add(rhs, llvm::RoundingMode::NearestTiesToEven);
+  return static_cast<int64_t>(lhs.bitcastToAPInt().getZExtValue());
+}
+
+/// Prints a floating-point value, given as raw bits `a` in the format selected
+/// by the `llvm::APFloatBase::Semantics` enum value `semantics`, to stdout
+/// followed by a newline. The value is converted to `double` for printing.
+void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
+                                                 uint64_t a) {
+  const auto semanticsKind =
+      static_cast<llvm::APFloatBase::Semantics>(semantics);
+  const llvm::fltSemantics &fltSem =
+      llvm::APFloatBase::EnumToSemantics(semanticsKind);
+  const llvm::APInt rawBits(llvm::APFloatBase::semanticsSizeInBits(fltSem), a);
+  const llvm::APFloat value(fltSem, rawBits);
+  std::cout << value.convertToDouble() << std::endl;
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdeb4dacf9278..8c09e50e4de7b 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
# is a big dependency which most don't need.
set(LLVM_OPTIONAL_SOURCES
+ APFloatWrappers.cpp
ArmRunnerUtils.cpp
ArmSMEStubs.cpp
AsyncRuntime.cpp
@@ -167,6 +168,15 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
+ add_mlir_library(mlir_apfloat_wrappers
+ SHARED
+ APFloatWrappers.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+ )
+ set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17)
+ target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS)
+
add_subdirectory(SparseTensor)
add_mlir_library(mlir_c_runner_utils
@@ -177,6 +187,7 @@ if(LLVM_ENABLE_PIC)
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
+ mlir_apfloat_wrappers
mlir_float16_utils
MLIRSparseTensorEnums
MLIRSparseTensorRuntime
@@ -191,6 +202,7 @@ if(LLVM_ENABLE_PIC)
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
+ mlir_apfloat_wrappers
mlir_float16_utils
)
target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS)
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
new file mode 100644
index 0000000000000..5cd83688d1710
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -0,0 +1,19 @@
+// Check that arith.addf and vector.print on f4E2M1FN can be lowered to LLVM.
+// NOTE: The RUN line only checks that the lowering succeeds; no output is verified.
+
+// RUN: mlir-opt %s --convert-to-llvm
+
+// Put rhs into separate function so that it won't be constant-folded.
+func.func @foo() -> f4E2M1FN {
+ %cst = arith.constant 5.0 : f4E2M1FN
+ return %cst : f4E2M1FN
+}
+
+func.func @entry() {
+ %a = arith.constant 5.0 : f4E2M1FN
+ %b = func.call @foo() : () -> (f4E2M1FN)
+ %c = arith.addf %a, %b : f4E2M1FN
+ vector.print %c : f4E2M1FN
+ return
+}
+
More information about the Mlir-commits
mailing list