[Mlir-commits] [mlir] [mlir] Add FP software implementation lowering pass: `arith-to-apfloat` (PR #166618)
Maksim Levental
llvmlistbot at llvm.org
Sun Nov 9 23:33:34 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/166618
>From 949a4e961aea4ad1289006372e8adaa675ec8562 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 4 Nov 2025 23:45:35 +0000
Subject: [PATCH 01/12] Prototype: APFloat CPU runner
---
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 7 +++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 45 ++++++++++++++++++-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 14 ++++++
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 23 ++++++++++
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 40 +++++++++++++++++
mlir/lib/ExecutionEngine/CMakeLists.txt | 12 +++++
.../Arith/CPU/test-apfloat-emulation.mlir | 19 ++++++++
7 files changed, 159 insertions(+), 1 deletion(-)
create mode 100644 mlir/lib/ExecutionEngine/APFloatWrappers.cpp
create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 8ad9ed18acebd..8564d0f4205cf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -52,6 +52,13 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index b6099902cc337..336c1b4b2824e 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
@@ -594,6 +595,47 @@ void mlir::arith::registerConvertArithToLLVMInterface(
});
}
+struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Get APFloat adder function from runtime library.
+ auto parent = op->getParentOfType<ModuleOp>();
+ if (!parent)
+ return failure();
+ FailureOr<Operation *> adder =
+ LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
+ auto floatTy = cast<FloatType>(op.getType());
+
+ // Cast operands to 64-bit integers.
+ Location loc = op.getLoc();
+ Value lhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(),
+ adaptor.getLhs());
+ Value rhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(),
+ adaptor.getRhs());
+
+ // Call software implementation of floating point addition.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp =
+ LLVM::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*adder), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits = rewriter.create<LLVM::TruncOp>(
+ loc, rewriter.getIntegerType(floatTy.getWidth()),
+ resultOp->getResult(0));
+ rewriter.replaceOp(op, truncatedBits);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
@@ -608,7 +650,8 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
// clang-format off
patterns.add<
- AddFOpLowering,
+ //AddFOpLowering,
+ FancyAddFLowering,
AddIOpLowering,
AndIOpLowering,
AddUIExtendedOpLowering,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 69a317ecd101f..260c028ffd9c5 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}
}
+ } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
+ // Print other floating-point types using the APFloat runtime library.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ Value floatBits =
+ rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), value);
+ printer =
+ LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
+ emitCall(rewriter, loc, printer.value(),
+ ValueRange({semValue, floatBits}));
+ return success();
} else {
return failure();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa34897b6..8ee039be60568 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,8 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
+static constexpr llvm::StringRef kApFloatAddF = "APFloat_add";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +162,27 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kPrintApFloat,
+ {IntegerType::get(moduleOp->getContext(), 32),
+ IntegerType::get(moduleOp->getContext(), 64)},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
+}
+
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kApFloatAddF,
+ {IntegerType::get(moduleOp->getContext(), 32),
+ IntegerType::get(moduleOp->getContext(), 64),
+ IntegerType::get(moduleOp->getContext(), 64)},
+ IntegerType::get(moduleOp->getContext(), 64), symbolTables);
+}
+
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
new file mode 100644
index 0000000000000..7879c75803355
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -0,0 +1,40 @@
+//===- APFloatWrappers.cpp - Software FP emulation via llvm::APFloat ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APFloat.h"
+#include <iostream>
+
+#if (defined(_WIN32) || defined(__CYGWIN__))
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
+#else
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
+#endif
+
+extern "C" {
+
+int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_add(int32_t semantics,
+ uint64_t a, uint64_t b) {
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+ llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a));
+ llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b));
+ auto status = lhs.add(rhs, llvm::RoundingMode::NearestTiesToEven);
+ return lhs.bitcastToAPInt().getZExtValue();
+}
+
+void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
+ uint64_t a) {
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+ llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+ double d = x.convertToDouble();
+ std::cout << d << std::endl;
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdeb4dacf9278..8c09e50e4de7b 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
# is a big dependency which most don't need.
set(LLVM_OPTIONAL_SOURCES
+ APFloatWrappers.cpp
ArmRunnerUtils.cpp
ArmSMEStubs.cpp
AsyncRuntime.cpp
@@ -167,6 +168,15 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
+ add_mlir_library(mlir_apfloat_wrappers
+ SHARED
+ APFloatWrappers.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+ )
+ set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17)
+ target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS)
+
add_subdirectory(SparseTensor)
add_mlir_library(mlir_c_runner_utils
@@ -177,6 +187,7 @@ if(LLVM_ENABLE_PIC)
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
+ mlir_apfloat_wrappers
mlir_float16_utils
MLIRSparseTensorEnums
MLIRSparseTensorRuntime
@@ -191,6 +202,7 @@ if(LLVM_ENABLE_PIC)
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
+ mlir_apfloat_wrappers
mlir_float16_utils
)
target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS)
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
new file mode 100644
index 0000000000000..5cd83688d1710
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -0,0 +1,19 @@
+// Check that arith.addf on f4E2M1FN lowers to LLVM via the APFloat wrappers.
+// We do not check any poison or UB values, as it is not possible to catch them.
+
+// RUN: mlir-opt %s --convert-to-llvm
+
+// Put rhs into separate function so that it won't be constant-folded.
+func.func @foo() -> f4E2M1FN {
+ %cst = arith.constant 5.0 : f4E2M1FN
+ return %cst : f4E2M1FN
+}
+
+func.func @entry() {
+ %a = arith.constant 5.0 : f4E2M1FN
+ %b = func.call @foo() : () -> (f4E2M1FN)
+ %c = arith.addf %a, %b : f4E2M1FN
+ vector.print %c : f4E2M1FN
+ return
+}
+
>From a93ba6c8662664817e1d02f02e598ca4f7957854 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 5 Nov 2025 11:13:09 -0800
Subject: [PATCH 02/12] check float cast
---
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 336c1b4b2824e..28e29cb7d0751 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -605,9 +605,11 @@ struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
auto parent = op->getParentOfType<ModuleOp>();
if (!parent)
return failure();
+ auto floatTy = dyn_cast<FloatType>(op.getType());
+ if (!floatTy)
+ return failure();
FailureOr<Operation *> adder =
LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
- auto floatTy = cast<FloatType>(op.getType());
// Cast operands to 64-bit integers.
Location loc = op.getLoc();
>From 0055cb09435119be00918bb132a4f7282d3efa81 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 5 Nov 2025 13:26:59 -0800
Subject: [PATCH 03/12] fix creates
---
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 28e29cb7d0751..3026a70684c46 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -613,16 +613,16 @@ struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
// Cast operands to 64-bit integers.
Location loc = op.getLoc();
- Value lhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(),
- adaptor.getLhs());
- Value rhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(),
- adaptor.getRhs());
+ Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
+ adaptor.getLhs());
+ Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
+ adaptor.getRhs());
// Call software implementation of floating point addition.
int32_t sem =
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- Value semValue = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
+ Value semValue = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp =
@@ -630,8 +630,8 @@ struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
SymbolRefAttr::get(*adder), params);
// Truncate result to the original width.
- Value truncatedBits = rewriter.create<LLVM::TruncOp>(
- loc, rewriter.getIntegerType(floatTy.getWidth()),
+ Value truncatedBits = LLVM::TruncOp::create(
+ rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()),
resultOp->getResult(0));
rewriter.replaceOp(op, truncatedBits);
return success();
>From 38ae089a6e5cc1a665998398564a9a64b01e9511 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 5 Nov 2025 13:40:18 -0800
Subject: [PATCH 04/12] check fp8 types
---
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 3026a70684c46..472c12eedf827 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -605,9 +605,13 @@ struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
auto parent = op->getParentOfType<ModuleOp>();
if (!parent)
return failure();
- auto floatTy = dyn_cast<FloatType>(op.getType());
- if (!floatTy)
+ if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+ Float8E5M2FNUZType, Float8E4M3FNUZType,
+ Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
+ Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
+ op.getType()))
return failure();
+ auto floatTy = cast<FloatType>(op.getType());
FailureOr<Operation *> adder =
LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
@@ -652,7 +656,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
// clang-format off
patterns.add<
- //AddFOpLowering,
+ AddFOpLowering,
FancyAddFLowering,
AddIOpLowering,
AndIOpLowering,
>From 0f3b82046b72e6540189215b2f0f757947e9d495 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 5 Nov 2025 19:25:10 -0800
Subject: [PATCH 05/12] add X-macros
---
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 20 ++++-
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 33 +++++---
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 75 ++++++++++++++++---
3 files changed, 104 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 8564d0f4205cf..01f7f75c210ef 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -55,9 +55,23 @@ lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
-FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
- SymbolTableCollection *symbolTables = nullptr);
+
+#define APFLOAT_BIN_OPS(X) \
+ X(add) \
+ X(subtract) \
+ X(multiply) \
+ X(divide) \
+ X(remainder) \
+ X(mod)
+
+#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
+ FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloat##OP##Fn( \
+ OpBuilder &b, Operation *moduleOp, \
+ SymbolTableCollection *symbolTables = nullptr);
+
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
+
+#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 8ee039be60568..cb6ee76f8cbfb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -31,7 +31,14 @@ static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
-static constexpr llvm::StringRef kApFloatAddF = "APFloat_add";
+
+#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
+
+#define APFLOAT_EXTERN_NAME(OP) \
+ static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "APFloat_" #OP;
+
+APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
+
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -172,16 +179,20 @@ mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
-FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
- SymbolTableCollection *symbolTables) {
- return lookupOrCreateReservedFn(
- b, moduleOp, kApFloatAddF,
- {IntegerType::get(moduleOp->getContext(), 32),
- IntegerType::get(moduleOp->getContext(), 64),
- IntegerType::get(moduleOp->getContext(), 64)},
- IntegerType::get(moduleOp->getContext(), 64), symbolTables);
-}
+#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \
+ FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateApFloat##OP##Fn( \
+ OpBuilder &b, Operation *moduleOp, \
+ SymbolTableCollection *symbolTables) { \
+ return lookupOrCreateReservedFn( \
+ b, moduleOp, APFLOAT_EXTERN_K(OP), \
+ {IntegerType::get(moduleOp->getContext(), 32), \
+ IntegerType::get(moduleOp->getContext(), 64), \
+ IntegerType::get(moduleOp->getContext(), 64)}, \
+ IntegerType::get(moduleOp->getContext(), 64), symbolTables); \
+ }
+
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
+#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 7879c75803355..8d2848bd7cf77 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -7,26 +7,81 @@
//===----------------------------------------------------------------------===//
#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/Debug.h"
+
#include <iostream>
+#define DEBUG_TYPE "mlir-apfloat-wrapper"
+
#if (defined(_WIN32) || defined(__CYGWIN__))
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
#else
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
#endif
+static std::string_view
+apFloatOpStatusToStr(llvm::APFloatBase::opStatus status) {
+ switch (status) {
+ case llvm::APFloatBase::opOK:
+ return "OK";
+ case llvm::APFloatBase::opInvalidOp:
+ return "InvalidOp";
+ case llvm::APFloatBase::opDivByZero:
+ return "DivByZero";
+ case llvm::APFloatBase::opOverflow:
+ return "Overflow";
+ case llvm::APFloatBase::opUnderflow:
+ return "Underflow";
+ case llvm::APFloatBase::opInexact:
+ return "Inexact";
+ }
+ llvm::report_fatal_error("unhandled llvm::APFloatBase::opStatus variant");
+}
+
+#define APFLOAT_BINARY_OP(OP) \
+ int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \
+ int32_t semantics, uint64_t a, uint64_t b) { \
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
+ static_cast<llvm::APFloatBase::Semantics>(semantics)); \
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
+ llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
+ llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
+ llvm::APFloatBase::opStatus status = lhs.OP(rhs); \
+ assert(status == llvm::APFloatBase::opOK && "expected " #OP \
+ " opstatus to be OK"); \
+ return lhs.bitcastToAPInt().getZExtValue(); \
+ }
+
+#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
+ int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \
+ int32_t semantics, uint64_t a, uint64_t b) { \
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
+ static_cast<llvm::APFloatBase::Semantics>(semantics)); \
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
+ llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
+ llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
+ llvm::APFloatBase::opStatus status = lhs.OP(rhs, ROUNDING_MODE); \
+ assert(status == llvm::APFloatBase::opOK && "expected " #OP \
+ " opstatus to be OK"); \
+ return lhs.bitcastToAPInt().getZExtValue(); \
+ }
+
extern "C" {
-int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_add(int32_t semantics,
- uint64_t a, uint64_t b) {
- const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
- static_cast<llvm::APFloatBase::Semantics>(semantics));
- unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
- llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a));
- llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b));
- auto status = lhs.add(rhs, llvm::RoundingMode::NearestTiesToEven);
- return lhs.bitcastToAPInt().getZExtValue();
-}
+#define BIN_OPS_WITH_ROUNDING(X) \
+ X(add, llvm::RoundingMode::NearestTiesToEven) \
+ X(subtract, llvm::RoundingMode::NearestTiesToEven) \
+ X(multiply, llvm::RoundingMode::NearestTiesToEven) \
+ X(divide, llvm::RoundingMode::NearestTiesToEven)
+
+BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE)
+#undef BIN_OPS_WITH_ROUNDING
+#undef APFLOAT_BINARY_OP_ROUNDING_MODE
+
+APFLOAT_BINARY_OP(remainder)
+APFLOAT_BINARY_OP(mod)
+
+#undef APFLOAT_BINARY_OP
void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
uint64_t a) {
>From b713f607b27984b2eff561169f51b941b2ee2b06 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 6 Nov 2025 15:14:09 -0800
Subject: [PATCH 06/12] add arith-to-apfloat
---
.../ArithToAPFloat/ArithToAPFloat.h | 28 ++++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 13 ++
mlir/include/mlir/Dialect/Func/Utils/Utils.h | 8 ++
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 17 ---
.../ArithToAPFloat/ArithToAPFloat.cpp | 136 ++++++++++++++++++
.../Conversion/ArithToAPFloat/CMakeLists.txt | 17 +++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 48 -------
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Dialect/Func/Utils/Utils.cpp | 42 ++++++
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 23 ---
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 22 ---
.../Arith/CPU/test-apfloat-emulation.mlir | 21 ++-
13 files changed, 266 insertions(+), 111 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
create mode 100644 mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
create mode 100644 mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
new file mode 100644
index 0000000000000..a5df4647f1acc
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
@@ -0,0 +1,28 @@
+//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+
+class DialectRegistry;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+void populateArithToAPFloatConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 40d866ec7bf10..82bdfd02661a6 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 70e3e45c225db..2bcd2870949f3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -186,6 +186,19 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToAPFloat
+//===----------------------------------------------------------------------===//
+
+def ArithToAPFloatConversionPass : Pass<"convert-arith-to-apfloat"> {
+ let summary = "Convert Arith dialect ops on FP8 types to APFloat lib calls";
+ let description = [{
+ This pass converts supported Arith ops which manipulate FP8 typed values to APFloat lib calls.
+ }];
+ let dependentDialects = ["func::FuncDialect"];
+ let options = [];
+}
+
//===----------------------------------------------------------------------===//
// ArithToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 3576126a487ac..9c9973cf84368 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -60,6 +60,14 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
mlir::ModuleOp moduleOp);
+/// Create a FuncOp with signature `resultTypes`(`paramTypes`) and name `name`.
+/// Return a failure if the FuncOp found has an unexpected signature.
+FailureOr<FuncOp>
+lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes = {},
+ ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
+ SymbolTableCollection *symbolTables = nullptr);
+
} // namespace func
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 01f7f75c210ef..b09d32022e348 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -56,23 +56,6 @@ FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
-#define APFLOAT_BIN_OPS(X) \
- X(add) \
- X(subtract) \
- X(multiply) \
- X(divide) \
- X(remainder) \
- X(mod)
-
-#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
- FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloat##OP##Fn( \
- OpBuilder &b, Operation *moduleOp, \
- SymbolTableCollection *symbolTables = nullptr);
-
-APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
-
-#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
-
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
new file mode 100644
index 0000000000000..bc451e88eb3bd
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -0,0 +1,136 @@
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+#define APFLOAT_BIN_OPS(X) \
+ X(add) \
+ X(subtract) \
+ X(multiply) \
+ X(divide) \
+ X(remainder) \
+ X(mod)
+
+#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
+
+#define APFLOAT_EXTERN_NAME(OP) \
+ static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "_mlir_" \
+ "apfloat_" #OP;
+
+namespace mlir::func {
+#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
+ FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
+ OpBuilder &b, Operation *moduleOp, \
+ SymbolTableCollection *symbolTables = nullptr);
+
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
+
+#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
+
+APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
+
+#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \
+ FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
+ OpBuilder &b, Operation *moduleOp, \
+ SymbolTableCollection *symbolTables) { \
+ return lookupOrCreateFn(b, moduleOp, APFLOAT_EXTERN_K(OP), \
+ {IntegerType::get(moduleOp->getContext(), 32), \
+ IntegerType::get(moduleOp->getContext(), 64), \
+ IntegerType::get(moduleOp->getContext(), 64)}, \
+ {IntegerType::get(moduleOp->getContext(), 64)}, \
+ /*setPrivate*/ true, symbolTables); \
+ }
+
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
+#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
+} // namespace mlir::func
+
+struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::AddFOp op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat adder function from runtime library.
+ auto parent = op->getParentOfType<ModuleOp>();
+ if (!parent)
+ return failure();
+ if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+ Float8E5M2FNUZType, Float8E4M3FNUZType,
+ Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
+ Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
+ op.getType()))
+ return failure();
+ FailureOr<Operation *> adder = lookupOrCreateApFloataddFn(rewriter, parent);
+
+ // Cast operands to 64-bit integers.
+ Location loc = op.getLoc();
+ auto floatTy = cast<FloatType>(op.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+
+ // Call software implementation of floating point addition.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*adder), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ rewriter.replaceAllUsesWith(
+ op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ return success();
+ }
+};
+
+void arith::populateArithToAPFloatConversionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FancyAddFLowering>(patterns.getContext());
+}
+
+namespace {
+struct ArithToAPFloatConversionPass final
+ : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
+ using impl::ArithToAPFloatConversionPassBase<
+ ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ RewritePatternSet patterns(op->getContext());
+ arith::populateArithToAPFloatConversionPatterns(patterns);
+ if (failed(applyPatternsGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..b0d1e46b3655f
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Conversion library implementing the `convert-arith-to-apfloat` pass.
+add_mlir_conversion_library(MLIRArithToAPFloat
+  ArithToAPFloat.cpp
+
+  # Public header: mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArithTransforms
+  MLIRFuncDialect
+  )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 472c12eedf827..f2bacc3399144 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -595,53 +595,6 @@ void mlir::arith::registerConvertArithToLLVMInterface(
});
}
-struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Get APFloat adder function from runtime library.
- auto parent = op->getParentOfType<ModuleOp>();
- if (!parent)
- return failure();
- if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
- Float8E5M2FNUZType, Float8E4M3FNUZType,
- Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
- Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
- op.getType()))
- return failure();
- auto floatTy = cast<FloatType>(op.getType());
- FailureOr<Operation *> adder =
- LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
-
- // Cast operands to 64-bit integers.
- Location loc = op.getLoc();
- Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
- adaptor.getLhs());
- Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
- adaptor.getRhs());
-
- // Call software implementation of floating point addition.
- int32_t sem =
- llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- Value semValue = LLVM::ConstantOp::create(
- rewriter, loc, rewriter.getI32Type(),
- rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
- auto resultOp =
- LLVM::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*adder), params);
-
- // Truncate result to the original width.
- Value truncatedBits = LLVM::TruncOp::create(
- rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()),
- resultOp->getResult(0));
- rewriter.replaceOp(op, truncatedBits);
- return success();
- }
-};
-
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
@@ -657,7 +610,6 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
// clang-format off
patterns.add<
AddFOpLowering,
- FancyAddFLowering,
AddIOpLowering,
AndIOpLowering,
AddUIExtendedOpLowering,
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8fff3f9..613dc6d242ceb 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index b4cb0932ef631..e187e62cf6555 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -254,3 +254,45 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
return std::make_pair(*newFuncOpOrFailure, newCallOp);
}
+
+/// Return the `FuncOp` called `name` in the symbol table op `moduleOp`,
+/// creating it at the start of the first block of `moduleOp`'s first region
+/// when no such symbol exists yet.
+///
+/// The expected function type is built from `paramTypes` and `resultTypes`.
+/// If a `FuncOp` with that name already exists but has a different type, an
+/// error is emitted on it and failure is returned. A newly created function is
+/// marked private when `setPrivate` is true. When `symbolTables` is provided,
+/// it is used for the lookup and the new function is registered in it.
+FailureOr<func::FuncOp>
+func::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+                       ArrayRef<Type> paramTypes, ArrayRef<Type> resultTypes,
+                       bool setPrivate, SymbolTableCollection *symbolTables) {
+  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+         "expected SymbolTable operation");
+
+  // Resolve `name`, going through the symbol table collection if we have one.
+  FuncOp existing =
+      symbolTables
+          ? symbolTables->lookupSymbolIn<FuncOp>(
+                moduleOp, StringAttr::get(moduleOp->getContext(), name))
+          : llvm::dyn_cast_or_null<FuncOp>(
+                SymbolTable::lookupSymbolIn(moduleOp, name));
+
+  FunctionType expectedType =
+      FunctionType::get(b.getContext(), paramTypes, resultTypes);
+
+  // An existing function is only acceptable if its type matches exactly.
+  if (existing) {
+    if (existing.getFunctionType() == expectedType)
+      return existing;
+    existing.emitError("redefinition of function '")
+        << name << "' of different type " << expectedType << " is prohibited";
+    return failure();
+  }
+
+  // Nothing found: create the function at the top of the module body. The
+  // guard restores the caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(b);
+  Region &body = moduleOp->getRegion(0);
+  assert(!body.empty() && "expected non-empty region");
+  Block &entry = body.front();
+  b.setInsertionPointToStart(&entry);
+  FuncOp created = FuncOp::create(b, moduleOp->getLoc(), name, expectedType);
+  if (setPrivate)
+    created.setPrivate();
+  if (symbolTables)
+    symbolTables->getSymbolTable(moduleOp).insert(created, entry.begin());
+
+  return created;
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index cb6ee76f8cbfb..160b6ae89215c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -31,14 +31,6 @@ static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
-
-#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
-
-#define APFLOAT_EXTERN_NAME(OP) \
- static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "APFloat_" #OP;
-
-APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
-
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -179,21 +171,6 @@ mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
-#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \
- FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateApFloat##OP##Fn( \
- OpBuilder &b, Operation *moduleOp, \
- SymbolTableCollection *symbolTables) { \
- return lookupOrCreateReservedFn( \
- b, moduleOp, APFLOAT_EXTERN_K(OP), \
- {IntegerType::get(moduleOp->getContext(), 32), \
- IntegerType::get(moduleOp->getContext(), 64), \
- IntegerType::get(moduleOp->getContext(), 64)}, \
- IntegerType::get(moduleOp->getContext(), 64), symbolTables); \
- }
-
-APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
-#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
-
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 8d2848bd7cf77..a5049436d03c1 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -7,37 +7,15 @@
//===----------------------------------------------------------------------===//
#include "llvm/ADT/APFloat.h"
-#include "llvm/Support/Debug.h"
#include <iostream>
-#define DEBUG_TYPE "mlir-apfloat-wrapper"
-
#if (defined(_WIN32) || defined(__CYGWIN__))
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
#else
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
#endif
-static std::string_view
-apFloatOpStatusToStr(llvm::APFloatBase::opStatus status) {
- switch (status) {
- case llvm::APFloatBase::opOK:
- return "OK";
- case llvm::APFloatBase::opInvalidOp:
- return "InvalidOp";
- case llvm::APFloatBase::opDivByZero:
- return "DivByZero";
- case llvm::APFloatBase::opOverflow:
- return "Overflow";
- case llvm::APFloatBase::opUnderflow:
- return "Underflow";
- case llvm::APFloatBase::opInexact:
- return "Inexact";
- }
- llvm::report_fatal_error("unhandled llvm::APFloatBase::opStatus variant");
-}
-
#define APFLOAT_BINARY_OP(OP) \
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \
int32_t semantics, uint64_t a, uint64_t b) { \
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 5cd83688d1710..d4c2394474b15 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -1,7 +1,7 @@
// Check that the ceildivsi lowering is correct.
// We do not check any poison or UB values, as it is not possible to catch them.
-// RUN: mlir-opt %s --convert-to-llvm
+// RUN: mlir-opt %s --convert-arith-to-apfloat
// Put rhs into separate function so that it won't be constant-folded.
func.func @foo() -> f4E2M1FN {
@@ -17,3 +17,22 @@ func.func @entry() {
return
}
+// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
+
+// CHECK-LABEL: func.func @foo() -> f4E2M1FN {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4.000000e+00 : f4E2M1FN
+// CHECK: return %[[CONSTANT_0]] : f4E2M1FN
+// CHECK: }
+
+// CHECK-LABEL: func.func @entry() {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 18 : i32
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 6 : i64
+// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f4E2M1FN
+// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[VAL_0]] : f4E2M1FN to i4
+// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i4 to i64
+// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_0]], %[[EXTUI_0]], %[[CONSTANT_1]]) : (i32, i64, i64) -> i64
+// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i4
+// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[TRUNCI_0]] : i4 to f4E2M1FN
+// CHECK: vector.print %[[BITCAST_1]] : f4E2M1FN
+// CHECK: return
+// CHECK: }
\ No newline at end of file
>From 78df4a889512cdf017135ec2b75594428b96af52 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 6 Nov 2025 19:43:26 -0800
Subject: [PATCH 07/12] fix ConvertVectorToLLVM.cpp
---
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 260c028ffd9c5..c747e1b59558a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1658,11 +1658,11 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// Print other floating-point types using the APFloat runtime library.
int32_t sem =
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- Value semValue = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
+ Value semValue = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
Value floatBits =
- rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), value);
+ LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
printer =
LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
emitCall(rewriter, loc, printer.value(),
>From 118006443745ff068993233e7b4111f29c6bef12 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 9 Nov 2025 06:04:21 +0000
Subject: [PATCH 08/12] walk instead of dialect conversion
---
.../ArithToAPFloat/ArithToAPFloat.h | 9 +-
mlir/include/mlir/Conversion/Passes.td | 9 +-
.../ArithToAPFloat/ArithToAPFloat.cpp | 189 +++++++++---------
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 36 ++--
.../ArithToApfloat/arith-to-apfloat.mlir | 88 ++++++++
.../Arith/CPU/test-apfloat-emulation.mlir | 43 ++--
6 files changed, 219 insertions(+), 155 deletions(-)
create mode 100644 mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
index a5df4647f1acc..64a42a228199e 100644
--- a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
+++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
@@ -12,17 +12,10 @@
#include <memory>
namespace mlir {
-
-class DialectRegistry;
-class RewritePatternSet;
class Pass;
#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
-
-namespace arith {
-void populateArithToAPFloatConversionPatterns(RewritePatternSet &patterns);
-} // namespace arith
} // namespace mlir
-#endif // MLIR_CONVERSION_ARITHTOAPFloat_ARITHTOAPFloat_H
+#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2bcd2870949f3..2cd7d14f5517b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -190,10 +190,13 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
// ArithToAPFloat
//===----------------------------------------------------------------------===//
-def ArithToAPFloatConversionPass : Pass<"convert-arith-to-apfloat"> {
- let summary = "Convert Arith dialect ops on FP8 types to APFloat lib calls";
+def ArithToAPFloatConversionPass
+ : Pass<"convert-arith-to-apfloat", "ModuleOp"> {
+ let summary = "Convert Arith ops to APFloat runtime library calls";
let description = [{
- This pass converts supported Arith ops which manipulate FP8 typed values to APFloat lib calls.
+ This pass converts supported Arith ops to APFloat-based runtime library
+ calls (APFloatWrappers.cpp). APFloat is a software implementation of
+ floating-point arithmetic operations.
}];
let dependentDialects = ["func::FuncDialect"];
let options = [];
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index bc451e88eb3bd..ee752b48eff9b 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -1,4 +1,4 @@
-//===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion ---------===//
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -13,7 +13,8 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/IR/Verifier.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -23,100 +24,66 @@ namespace mlir {
using namespace mlir;
using namespace mlir::func;
-#define APFLOAT_BIN_OPS(X) \
- X(add) \
- X(subtract) \
- X(multiply) \
- X(divide) \
- X(remainder) \
- X(mod)
-
-#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
-
-#define APFLOAT_EXTERN_NAME(OP) \
- static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "_mlir_" \
- "apfloat_" #OP;
-
-namespace mlir::func {
-#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
- FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
- OpBuilder &b, Operation *moduleOp, \
- SymbolTableCollection *symbolTables = nullptr);
-
-APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
-
-#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
-
-APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
-
-#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \
- FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
- OpBuilder &b, Operation *moduleOp, \
- SymbolTableCollection *symbolTables) { \
- return lookupOrCreateFn(b, moduleOp, APFLOAT_EXTERN_K(OP), \
- {IntegerType::get(moduleOp->getContext(), 32), \
- IntegerType::get(moduleOp->getContext(), 64), \
- IntegerType::get(moduleOp->getContext(), 64)}, \
- {IntegerType::get(moduleOp->getContext(), 64)}, \
- /*setPrivate*/ true, symbolTables); \
- }
-
-APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
-#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
-} // namespace mlir::func
-
-struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arith::AddFOp op,
- PatternRewriter &rewriter) const override {
- // Get APFloat adder function from runtime library.
- auto parent = op->getParentOfType<ModuleOp>();
- if (!parent)
- return failure();
- if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
- Float8E5M2FNUZType, Float8E4M3FNUZType,
- Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
- Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
- op.getType()))
- return failure();
- FailureOr<Operation *> adder = lookupOrCreateApFloataddFn(rewriter, parent);
-
- // Cast operands to 64-bit integers.
- Location loc = op.getLoc();
- auto floatTy = cast<FloatType>(op.getType());
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
- Value lhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
- Value rhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
-
- // Call software implementation of floating point addition.
- int32_t sem =
- llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- Value semValue = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32Type(),
- rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*adder), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceAllUsesWith(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
- return success();
- }
-};
+/// Look up, or create if missing, the runtime library function implementing a
+/// binary arithmetic operation. The symbol is named `_mlir_apfloat_<name>`,
+/// is created private and has the signature `(i32, i64, i64) -> i64`:
+///
+///   Parameter 1: APFloat semantics
+///   Parameter 2: Left-hand side operand
+///   Parameter 3: Right-hand side operand
+///
+/// Returns a failure if a function with that name is found but has an
+/// unexpected signature.
+///
+static FailureOr<Operation *>
+lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+                       SymbolTableCollection *symbolTables = nullptr) {
+  MLIRContext *ctx = moduleOp->getContext();
+  Type i32Ty = IntegerType::get(ctx, 32);
+  Type i64Ty = IntegerType::get(ctx, 64);
+  std::string symbol = (llvm::Twine("_mlir_apfloat_") + name).str();
+  return lookupOrCreateFn(b, moduleOp, symbol, {i32Ty, i64Ty, i64Ty}, {i64Ty},
+                          /*setPrivate=*/true, symbolTables);
+}
-void arith::populateArithToAPFloatConversionPatterns(
- RewritePatternSet &patterns) {
- patterns.add<FancyAddFLowering>(patterns.getContext());
+/// Rewrite a binary arithmetic operation to an APFloat function call.
+///
+/// Both operands are bitcast to an integer of the float's bit width,
+/// zero-extended to i64 if narrower, and passed together with the APFloat
+/// semantics enum value to the runtime function `_mlir_apfloat_<apfloatName>`.
+/// The i64 result is truncated back (if needed), bitcast to the original float
+/// type and replaces `op`.
+///
+/// Ops whose result type is not a scalar float of at most 64 bits (e.g.
+/// vectors, tensors or wider floats) cannot go through the i64-based runtime
+/// interface. They are left unchanged and success is returned. Failure is
+/// returned, after emitting an error on `op`, only if the runtime function
+/// cannot be looked up or created.
+template <typename OpTy>
+static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, ModuleOp module,
+                                     OpTy op, StringRef apfloatName) {
+  // Check the type first so that unsupported ops neither assert in a cast nor
+  // cause an unused runtime function declaration to be created.
+  auto floatTy = dyn_cast<FloatType>(op.getType());
+  if (!floatTy || floatTy.getWidth() > 64)
+    return success();
+
+  // Get APFloat function from runtime library.
+  FailureOr<Operation *> fn =
+      lookupOrCreateBinaryFn(rewriter, module, apfloatName);
+  if (failed(fn))
+    return op->emitError("failed to lookup or create APFloat function");
+
+  // Cast operands to 64-bit integers. `arith.extui` / `arith.trunci` require
+  // an actual width change, so they are only emitted for floats narrower than
+  // 64 bits.
+  Location loc = op.getLoc();
+  auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+  auto int64Type = rewriter.getI64Type();
+  const bool needsResize = floatTy.getWidth() < 64;
+  auto toI64 = [&](Value operand) -> Value {
+    Value bits = arith::BitcastOp::create(rewriter, loc, intWType, operand);
+    if (!needsResize)
+      return bits;
+    return arith::ExtUIOp::create(rewriter, loc, int64Type, bits);
+  };
+  Value lhsBits = toI64(op.getLhs());
+  Value rhsBits = toI64(op.getRhs());
+
+  // Call APFloat function.
+  int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+  Value semValue = arith::ConstantOp::create(
+      rewriter, loc, rewriter.getI32Type(),
+      rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+  SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+  auto resultOp =
+      func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+                           SymbolRefAttr::get(*fn), params);
+
+  // Truncate result to the original width.
+  Value resultBits = resultOp->getResult(0);
+  if (needsResize)
+    resultBits = arith::TruncIOp::create(rewriter, loc, intWType, resultBits);
+  rewriter.replaceOp(
+      op, arith::BitcastOp::create(rewriter, loc, floatTy, resultBits));
+  return success();
}
namespace {
@@ -126,10 +93,34 @@ struct ArithToAPFloatConversionPass final
ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
+  /// Walk all operations nested in the module and rewrite the supported
+  /// binary `arith` floating-point ops (addf, subf, mulf, divf, remf) into
+  /// calls to the matching `_mlir_apfloat_*` runtime function. The walk is
+  /// interrupted and the pass signals failure as soon as one rewrite fails.
void runOnOperation() override {
- Operation *op = getOperation();
- RewritePatternSet patterns(op->getContext());
- arith::populateArithToAPFloatConversionPatterns(patterns);
- if (failed(applyPatternsGreedily(op, std::move(patterns))))
+    ModuleOp module = getOperation();
+    IRRewriter rewriter(getOperation()->getContext());
+    WalkResult status = module->walk([&](Operation *op) {
+      // New IR is inserted right before the op that is being rewritten.
+      rewriter.setInsertionPoint(op);
+      // The string passed to rewriteBinaryOp is the suffix of the runtime
+      // function name; ops without a case are left as they are.
+      LogicalResult result =
+          llvm::TypeSwitch<Operation *, LogicalResult>(op)
+              .Case<arith::AddFOp>([&](arith::AddFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "add");
+              })
+              .Case<arith::SubFOp>([&](arith::SubFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "subtract");
+              })
+              .Case<arith::MulFOp>([&](arith::MulFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "multiply");
+              })
+              .Case<arith::DivFOp>([&](arith::DivFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "divide");
+              })
+              .Case<arith::RemFOp>([&](arith::RemFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "remainder");
+              })
+              .Default([](Operation *op) { return success(); });
+      if (failed(result))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    if (status.wasInterrupted())
return signalPassFailure();
}
};
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index a5049436d03c1..685edfc0fd082 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -1,46 +1,55 @@
-//===- ArmRunnerUtils.cpp - Utilities for configuring architecture properties //
+//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-
+//
+// This file exposes the APFloat infrastructure to MLIR programs as a runtime
+// library. APFloat is a software implementation of floating-point arithmetic.
+//
+// On the MLIR side, floating-point values must be bitcasted to 64-bit integers
+// before calling a runtime function. If a floating-point type has less than
+// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an
+// integer.
+//
+// Runtime functions receive the floating-point operands of the arithmetic
+// operation in the form of 64-bit integers, along with the APFloat semantics
+// in the form of a 32-bit integer, which will be interpreted as an
+// APFloatBase::Semantics enum value.
+//
#include "llvm/ADT/APFloat.h"
-#include <iostream>
-
#if (defined(_WIN32) || defined(__CYGWIN__))
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
#else
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
#endif
+/// Binary operations without rounding mode.
#define APFLOAT_BINARY_OP(OP) \
- int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \
+ int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
int32_t semantics, uint64_t a, uint64_t b) { \
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
static_cast<llvm::APFloatBase::Semantics>(semantics)); \
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
- llvm::APFloatBase::opStatus status = lhs.OP(rhs); \
- assert(status == llvm::APFloatBase::opOK && "expected " #OP \
- " opstatus to be OK"); \
+ lhs.OP(rhs); \
return lhs.bitcastToAPInt().getZExtValue(); \
}
+/// Binary operations with rounding mode.
#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
- int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \
+ int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
int32_t semantics, uint64_t a, uint64_t b) { \
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
static_cast<llvm::APFloatBase::Semantics>(semantics)); \
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
- llvm::APFloatBase::opStatus status = lhs.OP(rhs, ROUNDING_MODE); \
- assert(status == llvm::APFloatBase::opOK && "expected " #OP \
- " opstatus to be OK"); \
+ lhs.OP(rhs, ROUNDING_MODE); \
return lhs.bitcastToAPInt().getZExtValue(); \
}
@@ -57,7 +66,6 @@ BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE)
#undef APFLOAT_BINARY_OP_ROUNDING_MODE
APFLOAT_BINARY_OP(remainder)
-APFLOAT_BINARY_OP(mod)
#undef APFLOAT_BINARY_OP
@@ -68,6 +76,6 @@ void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
double d = x.convertToDouble();
- std::cout << d << std::endl;
+ fprintf(stdout, "%lg", d);
}
}
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
new file mode 100644
index 0000000000000..a1e0a382c85ab
--- /dev/null
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
+
+// CHECK-LABEL: func.func @foo() -> f8E4M3FN {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN
+// CHECK: return %[[CONSTANT_0]] : f8E4M3FN
+// CHECK: }
+
+// CHECK-LABEL: func.func @full_example() {
+// CHECK: %[[cst:.*]] = arith.constant 1.375000e+00 : f8E4M3FN
+// CHECK: %[[rhs:.*]] = call @foo() : () -> f8E4M3FN
+// CHECK: %[[lhs_casted:.*]] = arith.bitcast %[[cst]] : f8E4M3FN to i8
+// CHECK: %[[lhs_ext:.*]] = arith.extui %[[lhs_casted]] : i8 to i64
+// CHECK: %[[rhs_casted:.*]] = arith.bitcast %[[rhs]] : f8E4M3FN to i8
+// CHECK: %[[rhs_ext:.*]] = arith.extui %[[rhs_casted]] : i8 to i64
+// CHECK: %[[c10_i32:.*]] = arith.constant 10 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_add(%[[c10_i32]], %[[lhs_ext]], %[[rhs_ext]]) : (i32, i64, i64) -> i64
+// CHECK: %[[res_trunc:.*]] = arith.trunci %[[res]] : i64 to i8
+// CHECK: %[[res_casted:.*]] = arith.bitcast %[[res_trunc]] : i8 to f8E4M3FN
+// CHECK: vector.print %[[res_casted]] : f8E4M3FN
+// CHECK: return
+// CHECK: }
+
+// Put rhs into separate function so that it won't be constant-folded.
+func.func @foo() -> f8E4M3FN {
+ %cst = arith.constant 2.2 : f8E4M3FN
+ return %cst : f8E4M3FN
+}
+
+func.func @full_example() {
+ %a = arith.constant 1.4 : f8E4M3FN
+ %b = func.call @foo() : () -> (f8E4M3FN)
+ %c = arith.addf %a, %b : f8E4M3FN
+
+ vector.print %c : f8E4M3FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.addf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_subtract(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.subf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_multiply(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.mulf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_divide(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.divf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_remainder(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.remf %arg0, %arg1 : f4E2M1FN
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index d4c2394474b15..20c37f0ac8a25 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -1,38 +1,19 @@
-// Check that the ceildivsi lowering is correct.
-// We do not check any poison or UB values, as it is not possible to catch them.
-
-// RUN: mlir-opt %s --convert-arith-to-apfloat
+// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s
// Put rhs into separate function so that it won't be constant-folded.
-func.func @foo() -> f4E2M1FN {
- %cst = arith.constant 5.0 : f4E2M1FN
- return %cst : f4E2M1FN
+func.func @foo() -> f8E4M3FN {
+ %cst = arith.constant 2.2 : f8E4M3FN
+ return %cst : f8E4M3FN
}
func.func @entry() {
- %a = arith.constant 5.0 : f4E2M1FN
- %b = func.call @foo() : () -> (f4E2M1FN)
- %c = arith.addf %a, %b : f4E2M1FN
- vector.print %c : f4E2M1FN
+ %a = arith.constant 1.4 : f8E4M3FN
+ %b = func.call @foo() : () -> (f8E4M3FN)
+ %c = arith.addf %a, %b : f8E4M3FN
+
+ // CHECK: 3.5
+ vector.print %c : f8E4M3FN
return
}
-
-// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
-
-// CHECK-LABEL: func.func @foo() -> f4E2M1FN {
-// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4.000000e+00 : f4E2M1FN
-// CHECK: return %[[CONSTANT_0]] : f4E2M1FN
-// CHECK: }
-
-// CHECK-LABEL: func.func @entry() {
-// CHECK: %[[CONSTANT_0:.*]] = arith.constant 18 : i32
-// CHECK: %[[CONSTANT_1:.*]] = arith.constant 6 : i64
-// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f4E2M1FN
-// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[VAL_0]] : f4E2M1FN to i4
-// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i4 to i64
-// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_0]], %[[EXTUI_0]], %[[CONSTANT_1]]) : (i32, i64, i64) -> i64
-// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i4
-// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[TRUNCI_0]] : i4 to f4E2M1FN
-// CHECK: vector.print %[[BITCAST_1]] : f4E2M1FN
-// CHECK: return
-// CHECK: }
\ No newline at end of file
>From 930a664a8792c1f29fa470d568ddf92a2ccb0e49 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 9 Nov 2025 23:30:38 +0000
Subject: [PATCH 09/12] improve test case
---
.../Arith/CPU/test-apfloat-emulation.mlir | 33 ++++++++++++++-----
1 file changed, 24 insertions(+), 9 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 20c37f0ac8a25..a2b3eb73a60b8 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -1,19 +1,34 @@
+// Case 1: All floating-point arithmetic is lowered through APFloat.
// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \
-// RUN: mlir-runner -e entry --entry-point-result=void \
-// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s
+
+// Case 2: Only unsupported arithmetic (f8E4M3FN) is lowered through APFloat.
+// Arithmetic on f32 is lowered directly to LLVM.
+// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \
+// RUN: --convert-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s
// Put rhs into separate function so that it won't be constant-folded.
-func.func @foo() -> f8E4M3FN {
- %cst = arith.constant 2.2 : f8E4M3FN
- return %cst : f8E4M3FN
+func.func @foo() -> (f8E4M3FN, f32) {
+ %cst1 = arith.constant 2.2 : f8E4M3FN
+ %cst2 = arith.constant 2.2 : f32
+ return %cst1, %cst2 : f8E4M3FN, f32
}
func.func @entry() {
- %a = arith.constant 1.4 : f8E4M3FN
- %b = func.call @foo() : () -> (f8E4M3FN)
- %c = arith.addf %a, %b : f8E4M3FN
+ %a1 = arith.constant 1.4 : f8E4M3FN
+ %a2 = arith.constant 1.4 : f32
+ %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32)
+ %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM
+ %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM
// CHECK: 3.5
- vector.print %c : f8E4M3FN
+ vector.print %c1 : f8E4M3FN
+
+ // CHECK: 3.6
+ vector.print %c2 : f32
+
return
}
>From b644bfb28c17f3d86e638093c3df7ddb3a54c9ba Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 9 Nov 2025 23:35:17 +0000
Subject: [PATCH 10/12] address comments
---
mlir/include/mlir/Dialect/Func/Utils/Utils.h | 8 ++---
.../ArithToAPFloat/ArithToAPFloat.cpp | 30 +++++++++----------
mlir/lib/Dialect/Func/Utils/Utils.cpp | 7 +++--
3 files changed, 22 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 9c9973cf84368..22c53a239b524 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -63,10 +63,10 @@ deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
/// Create a FuncOp with signature `resultTypes`(`paramTypes`)` and name `name`.
/// Return a failure if the FuncOp found has unexpected signature.
FailureOr<FuncOp>
-lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
- ArrayRef<Type> paramTypes = {},
- ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
- SymbolTableCollection *symbolTables = nullptr);
+lookupOrCreateFnDecl(OpBuilder &b, Operation *moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes = {},
+ ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
+ SymbolTableCollection *symbolTables = nullptr);
} // namespace func
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index ee752b48eff9b..1b8f3018cf787 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -37,13 +37,12 @@ using namespace mlir::func;
static FailureOr<Operation *>
lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
- return lookupOrCreateFn(b, moduleOp,
- (llvm::Twine("_mlir_apfloat_") + name).str(),
- {IntegerType::get(moduleOp->getContext(), 32),
- IntegerType::get(moduleOp->getContext(), 64),
- IntegerType::get(moduleOp->getContext(), 64)},
- {IntegerType::get(moduleOp->getContext(), 64)},
- /*setPrivate=*/true, symbolTables);
+ auto i32Type = IntegerType::get(moduleOp->getContext(), 32);
+ auto i64Type = IntegerType::get(moduleOp->getContext(), 64);
+ return lookupOrCreateFnDecl(b, moduleOp,
+ (llvm::Twine("_mlir_apfloat_") + name).str(),
+ {i32Type, i64Type, i64Type}, {i64Type},
+ /*setPrivate=*/true, symbolTables);
}
/// Rewrite a binary arithmetic operation to an APFloat function call.
@@ -89,31 +88,30 @@ static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, ModuleOp module,
namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
- using impl::ArithToAPFloatConversionPassBase<
- ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
+ using Base::Base;
void runOnOperation() override {
- ModuleOp module = getOperation();
+ ModuleOp moduleOp = getOperation();
IRRewriter rewriter(getOperation()->getContext());
SmallVector<arith::AddFOp> addOps;
- WalkResult status = module->walk([&](Operation *op) {
+ WalkResult status = moduleOp->walk([&](Operation *op) {
rewriter.setInsertionPoint(op);
LogicalResult result =
llvm::TypeSwitch<Operation *, LogicalResult>(op)
.Case<arith::AddFOp>([&](arith::AddFOp op) {
- return rewriteBinaryOp(rewriter, module, op, "add");
+ return rewriteBinaryOp(rewriter, moduleOp, op, "add");
})
.Case<arith::SubFOp>([&](arith::SubFOp op) {
- return rewriteBinaryOp(rewriter, module, op, "subtract");
+ return rewriteBinaryOp(rewriter, moduleOp, op, "subtract");
})
.Case<arith::MulFOp>([&](arith::MulFOp op) {
- return rewriteBinaryOp(rewriter, module, op, "multiply");
+ return rewriteBinaryOp(rewriter, moduleOp, op, "multiply");
})
.Case<arith::DivFOp>([&](arith::DivFOp op) {
- return rewriteBinaryOp(rewriter, module, op, "divide");
+ return rewriteBinaryOp(rewriter, moduleOp, op, "divide");
})
.Case<arith::RemFOp>([&](arith::RemFOp op) {
- return rewriteBinaryOp(rewriter, module, op, "remainder");
+ return rewriteBinaryOp(rewriter, moduleOp, op, "remainder");
})
.Default([](Operation *op) { return success(); });
if (failed(result))
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index e187e62cf6555..8070e9d6a0a95 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -256,9 +256,10 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
}
FailureOr<func::FuncOp>
-func::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
- ArrayRef<Type> paramTypes, ArrayRef<Type> resultTypes,
- bool setPrivate, SymbolTableCollection *symbolTables) {
+func::lookupOrCreateFnDecl(OpBuilder &b, Operation *moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes,
+ ArrayRef<Type> resultTypes, bool setPrivate,
+ SymbolTableCollection *symbolTables) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
>From d94e2d1961b3388df4d4136d93f43d3ca261c8fe Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sun, 9 Nov 2025 18:49:37 -0800
Subject: [PATCH 11/12] address the rest of the comments
---
.../ArithToAPFloat/ArithToAPFloat.cpp | 129 +++++++++---------
1 file changed, 62 insertions(+), 67 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 1b8f3018cf787..c3c38a61100b9 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -12,9 +12,9 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
-
-#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -24,7 +24,7 @@ namespace mlir {
using namespace mlir;
using namespace mlir::func;
-/// Helper function to lookup or create the symbol for a runtime library
+/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
/// Parameter 1: APFloat semantics
@@ -46,44 +46,55 @@ lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
}
/// Rewrite a binary arithmetic operation to an APFloat function call.
-template <typename OpTy>
-static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, ModuleOp module,
- OpTy op, StringRef apfloatName) {
- // Get APFloat function from runtime library.
- FailureOr<Operation *> fn =
- lookupOrCreateBinaryFn(rewriter, module, apfloatName);
- if (failed(fn))
- return op->emitError("failed to lookup or create APFloat function");
+template <typename OpTy, const char *APFloatName>
+struct ArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
- // Cast operands to 64-bit integers.
- Location loc = op.getLoc();
- auto floatTy = cast<FloatType>(op.getType());
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
- Value lhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
- Value rhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ auto moduleOp = op->template getParentOfType<ModuleOp>();
+ if (!moduleOp) {
+ op.emitError("arith op must be contained within a builtin.module");
+ return failure();
+ }
+ // Get APFloat function from runtime library.
+ FailureOr<Operation *> fn =
+ lookupOrCreateBinaryFn(rewriter, moduleOp, APFloatName);
+ if (failed(fn))
+ return op->emitError("failed to lookup or create APFloat function");
- // Call APFloat function.
- int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- Value semValue = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32Type(),
- rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ Location loc = op.getLoc();
+ auto floatTy = cast<FloatType>(op.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
- // Truncate result to the original width.
- Value truncatedBits =
- arith::TruncIOp::create(rewriter, loc, intWType, resultOp->getResult(0));
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
- return success();
-}
+ // Call APFloat function.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ rewriter.replaceOp(
+ op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ return success();
+ }
+};
namespace {
struct ArithToAPFloatConversionPass final
@@ -91,35 +102,19 @@ struct ArithToAPFloatConversionPass final
using Base::Base;
void runOnOperation() override {
- ModuleOp moduleOp = getOperation();
- IRRewriter rewriter(getOperation()->getContext());
- SmallVector<arith::AddFOp> addOps;
- WalkResult status = moduleOp->walk([&](Operation *op) {
- rewriter.setInsertionPoint(op);
- LogicalResult result =
- llvm::TypeSwitch<Operation *, LogicalResult>(op)
- .Case<arith::AddFOp>([&](arith::AddFOp op) {
- return rewriteBinaryOp(rewriter, moduleOp, op, "add");
- })
- .Case<arith::SubFOp>([&](arith::SubFOp op) {
- return rewriteBinaryOp(rewriter, moduleOp, op, "subtract");
- })
- .Case<arith::MulFOp>([&](arith::MulFOp op) {
- return rewriteBinaryOp(rewriter, moduleOp, op, "multiply");
- })
- .Case<arith::DivFOp>([&](arith::DivFOp op) {
- return rewriteBinaryOp(rewriter, moduleOp, op, "divide");
- })
- .Case<arith::RemFOp>([&](arith::RemFOp op) {
- return rewriteBinaryOp(rewriter, moduleOp, op, "remainder");
- })
- .Default([](Operation *op) { return success(); });
- if (failed(result))
- return WalkResult::interrupt();
- return WalkResult::advance();
- });
- if (status.wasInterrupted())
- return signalPassFailure();
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ static const char add[] = "add";
+ static const char subtract[] = "subtract";
+ static const char multiply[] = "multiply";
+ static const char divide[] = "divide";
+ static const char remainder[] = "remainder";
+ patterns.add<ArithOpToAPFloatConversion<arith::AddFOp, add>,
+ ArithOpToAPFloatConversion<arith::SubFOp, subtract>,
+ ArithOpToAPFloatConversion<arith::MulFOp, multiply>,
+ ArithOpToAPFloatConversion<arith::DivFOp, divide>,
+ ArithOpToAPFloatConversion<arith::RemFOp, remainder>>(context);
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
}
};
} // namespace
>From 7d29b71b19a4115f650aa5a700bf4494a033c765 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sun, 9 Nov 2025 23:33:15 -0800
Subject: [PATCH 12/12] use notifyMatchFailure
---
mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index c3c38a61100b9..6891361353f16 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -54,8 +54,8 @@ struct ArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
PatternRewriter &rewriter) const override {
auto moduleOp = op->template getParentOfType<ModuleOp>();
if (!moduleOp) {
- op.emitError("arith op must be contained within a builtin.module");
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "arith op must be contained within a builtin.module");
}
// Get APFloat function from runtime library.
FailureOr<Operation *> fn =
More information about the Mlir-commits
mailing list