[Mlir-commits] [mlir] 75751f3 - Reapply "Reapply "[mlir] Add FP software implementation lowering pass: `arith-to-apfloat` (#166618)" (#167431)" (#167436)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 11 08:50:41 PST 2025
Author: Maksim Levental
Date: 2025-11-11T08:50:37-08:00
New Revision: 75751f33a8702d14184d435d544f38149514b63f
URL: https://github.com/llvm/llvm-project/commit/75751f33a8702d14184d435d544f38149514b63f
DIFF: https://github.com/llvm/llvm-project/commit/75751f33a8702d14184d435d544f38149514b63f.diff
LOG: Reapply "Reapply "[mlir] Add FP software implementation lowering pass: `arith-to-apfloat` (#166618)" (#167431)" (#167436)
Reland https://github.com/llvm/llvm-project/pull/166618 by fixing
missing symbol issues by explicitly loading
`--shared-libs=%mlir_apfloat_wrappers` as well as
`--shared-libs=%mlir_c_runner_utils`.
Added:
mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
mlir/lib/ExecutionEngine/APFloatWrappers.cpp
mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/Func/Utils/Utils.h
mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Func/Utils/Utils.cpp
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
mlir/lib/ExecutionEngine/CMakeLists.txt
mlir/test/lit.cfg.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
new file mode 100644
index 0000000000000..64a42a228199e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
@@ -0,0 +1,21 @@
+//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+// Forward declaration only; this header does not pull in the Pass definition.
+class Pass;
+
+// Select the ArithToAPFloatConversionPass declarations from the generated
+// conversion pass include.
+#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 40d866ec7bf10..82bdfd02661a6 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 70e3e45c225db..79bc380dbcb7a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -186,6 +186,21 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToAPFloat
+//===----------------------------------------------------------------------===//
+
+// Module-level pass exposed on the command line as `convert-arith-to-apfloat`.
+def ArithToAPFloatConversionPass
+ : Pass<"convert-arith-to-apfloat", "ModuleOp"> {
+ let summary = "Convert Arith ops to APFloat runtime library calls";
+ let description = [{
+ This pass converts supported Arith ops to APFloat-based runtime library
+ calls (APFloatWrappers.cpp). APFloat is a software implementation of
+ floating-point arithmetic operations.
+ }];
+ // The lowering emits func dialect declarations and calls.
+ let dependentDialects = ["func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArithToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 3576126a487ac..00d50874a2e8d 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -60,6 +60,13 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
mlir::ModuleOp moduleOp);
+/// Look up a FuncOp named `name` in `symTable` and check that its function
+/// type is `funcT`. Returns a null FuncOp (wrapped in success) if no FuncOp
+/// with that name is found, and a failure (after emitting an error on the
+/// found op) if the FuncOp is found but with a different signature. When
+/// `symbolTables` is non-null it is used to perform the lookup.
+FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
+ FunctionType funcT,
+ SymbolTableCollection *symbolTables = nullptr);
+
} // namespace func
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 8ad9ed18acebd..b09d32022e348 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
+/// Looks up or creates the declaration of the `printApFloat` runtime function,
+/// which takes an i32 (APFloat semantics enum value) and an i64 (the value's
+/// bits) and returns void.
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
new file mode 100644
index 0000000000000..01fd5b278aca4
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -0,0 +1,158 @@
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+/// Create a `func.func` declaration called `name` with type `funcT` at the
+/// very beginning of the first region of `symTable`.
+///
+/// The declaration is marked private when `setPrivate` is true. If
+/// `symbolTables` is provided, the new function is also registered in the
+/// cached symbol table of `symTable`. The builder's insertion point is
+/// restored before returning.
+static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+                           StringRef name, FunctionType funcT, bool setPrivate,
+                           SymbolTableCollection *symbolTables = nullptr) {
+  Region &body = symTable->getRegion(0);
+  assert(!body.empty() && "expected non-empty region");
+  Block &entry = body.front();
+
+  OpBuilder::InsertionGuard restoreInsertionPoint(b);
+  b.setInsertionPointToStart(&entry);
+  FuncOp decl = FuncOp::create(b, symTable->getLoc(), name, funcT);
+  if (setPrivate)
+    decl.setPrivate();
+
+  if (!symbolTables)
+    return decl;
+
+  // Keep the cached symbol table in sync with the IR.
+  symbolTables->getSymbolTable(symTable).insert(decl, entry.begin());
+  return decl;
+}
+
+/// Find or declare the runtime library function `_mlir_apfloat_<name>` that
+/// implements a binary floating-point operation.
+///
+/// The function has the signature `(i32, i64, i64) -> i64`:
+///   * i32: APFloat semantics
+///   * i64: left-hand side operand
+///   * i64: right-hand side operand
+///
+/// A failure is returned when a function with that name already exists but
+/// has a different signature. A newly created declaration is private.
+///
+static FailureOr<FuncOp>
+lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                       SymbolTableCollection *symbolTables = nullptr) {
+  MLIRContext *ctx = symTable->getContext();
+  Type i32 = IntegerType::get(ctx, 32);
+  Type i64 = IntegerType::get(ctx, 64);
+  FunctionType funcT =
+      FunctionType::get(b.getContext(), {i32, i64, i64}, {i64});
+  std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+
+  FailureOr<FuncOp> existing =
+      lookupFnDecl(symTable, funcName, funcT, symbolTables);
+  // Either the lookup failed on a type mismatch, or a matching declaration
+  // already exists; in both cases the lookup result is the answer.
+  if (failed(existing) || *existing)
+    return existing;
+
+  return createFnDecl(b, symTable, funcName, funcT,
+                      /*setPrivate=*/true, symbolTables);
+}
+
+/// Rewrite a binary arithmetic operation to an APFloat function call.
+///
+/// For a scalar float type of width W <= 64 the op is replaced by:
+///   1. bitcast of both operands to iW, zero-extended to i64 when W < 64,
+///   2. a call to `_mlir_apfloat_<APFloatName>(sem, lhs, rhs) -> i64`, where
+///      `sem` is the APFloat semantics enum value of the float type,
+///   3. truncation of the result to iW (when W < 64) and a bitcast back to
+///      the float type.
+///
+/// Ops whose result is not a scalar float (e.g. vectors or tensors) or whose
+/// float type is wider than 64 bits do not match and are left untouched; no
+/// runtime function declaration is created for them.
+template <typename OpTy, const char *APFloatName>
+struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+  BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit,
+                                   SymbolOpInterface symTable)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // arith float ops also accept vectors/tensors; only scalars are handled.
+    auto floatTy = dyn_cast<FloatType>(op.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(
+          op, "only scalar floating-point types are supported");
+    // Operands and results are passed to the runtime as 64-bit integers.
+    unsigned width = floatTy.getWidth();
+    if (width > 64)
+      return rewriter.notifyMatchFailure(
+          op, "floating-point types wider than 64 bits are not supported");
+
+    // Get APFloat function from runtime library.
+    FailureOr<FuncOp> fn =
+        lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
+    if (failed(fn))
+      return failure();
+
+    rewriter.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    Type intWType = rewriter.getIntegerType(width);
+    Type int64Type = rewriter.getI64Type();
+
+    // Cast an operand to a 64-bit integer. arith.extui requires a strictly
+    // wider result type, so it is skipped for 64-bit floats.
+    auto toI64Bits = [&](Value operand) -> Value {
+      Value bits = arith::BitcastOp::create(rewriter, loc, intWType, operand);
+      if (width < 64)
+        bits = arith::ExtUIOp::create(rewriter, loc, int64Type, bits);
+      return bits;
+    };
+    Value lhsBits = toI64Bits(op.getLhs());
+    Value rhsBits = toI64Bits(op.getRhs());
+
+    // Call APFloat function.
+    int32_t sem =
+        llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+    Value semValue = arith::ConstantOp::create(
+        rewriter, loc, rewriter.getI32Type(),
+        rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+    auto resultOp =
+        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+                             SymbolRefAttr::get(*fn), params);
+
+    // Truncate result to the original width (arith.trunci likewise requires
+    // a strictly narrower result type).
+    Value resultBits = resultOp->getResult(0);
+    if (width < 64)
+      resultBits =
+          arith::TruncIOp::create(rewriter, loc, intWType, resultBits);
+    rewriter.replaceOp(
+        op, arith::BitcastOp::create(rewriter, loc, floatTy, resultBits));
+    return success();
+  }
+
+  /// Symbol table op (the module) in which runtime functions are declared.
+  SymbolOpInterface symTable;
+};
+
+namespace {
+/// Pass that rewrites arith.addf/subf/mulf/divf/remf into calls to the
+/// `_mlir_apfloat_*` runtime functions. The pass fails if any error
+/// diagnostic is emitted while the patterns are applied.
+struct ArithToAPFloatConversionPass final
+    : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    // Runtime function name suffixes. They are passed as non-type template
+    // arguments, hence the static storage duration.
+    static const char kAdd[] = "add";
+    static const char kSubtract[] = "subtract";
+    static const char kMultiply[] = "multiply";
+    static const char kDivide[] = "divide";
+    static const char kRemainder[] = "remainder";
+
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patternSet(ctx);
+    patternSet.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, kAdd>,
+                   BinaryArithOpToAPFloatConversion<arith::SubFOp, kSubtract>,
+                   BinaryArithOpToAPFloatConversion<arith::MulFOp, kMultiply>,
+                   BinaryArithOpToAPFloatConversion<arith::DivFOp, kDivide>,
+                   BinaryArithOpToAPFloatConversion<arith::RemFOp, kRemainder>>(
+        ctx, 1, getOperation());
+
+    // Remember whether any error was reported while rewriting.
+    bool sawError = false;
+    ScopedDiagnosticHandler errorTracker(ctx, [&sawError](Diagnostic &diag) {
+      sawError |= diag.getSeverity() == DiagnosticSeverity::Error;
+      // NB: the handler must return failure, otherwise no other diagnostic
+      // handlers will fire (see
+      // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
+      return failure();
+    });
+    walkAndApplyPatterns(getOperation(), std::move(patternSet));
+    if (sawError)
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..b5ec49c087163
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -0,0 +1,18 @@
+# Library implementing the arith -> APFloat runtime call conversion pass.
+add_mlir_conversion_library(MLIRArithToAPFloat
+  ArithToAPFloat.cpp
+
+  # Public headers of this conversion live next to ArithToAPFloat.h.
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArithTransforms
+  MLIRFuncDialect
+  MLIRFuncUtils
+  )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index b6099902cc337..f2bacc3399144 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8fff3f9..613dc6d242ceb 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 69a317ecd101f..c747e1b59558a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}
}
+ } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
+ // Print other floating-point types using the APFloat runtime library.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ Value floatBits =
+ LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
+ printer =
+ LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
+ emitCall(rewriter, loc, printer.value(),
+ ValueRange({semValue, floatBits}));
+ return success();
} else {
return failure();
}
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index b4cb0932ef631..d6dfd0229963c 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
return std::make_pair(*newFuncOpOrFailure, newCallOp);
}
+
+/// Look up a FuncOp named `name` in `symTable` and verify that its type is
+/// `funcT`.
+///
+/// Returns:
+///   * success holding a null FuncOp if no FuncOp with that name exists,
+///   * success holding the FuncOp if it exists with the expected type,
+///   * failure (after emitting an error on the found op) if it exists with a
+///     different type.
+///
+/// The lookup goes through `symbolTables` when it is non-null, and through
+/// `SymbolTable::lookupSymbolIn` otherwise.
+FailureOr<func::FuncOp>
+func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
+                   FunctionType funcT, SymbolTableCollection *symbolTables) {
+  FuncOp func;
+  if (symbolTables) {
+    func = symbolTables->lookupSymbolIn<FuncOp>(
+        symTable, StringAttr::get(symTable->getContext(), name));
+  } else {
+    func = llvm::dyn_cast_or_null<FuncOp>(
+        SymbolTable::lookupSymbolIn(symTable, name));
+  }
+
+  // Not found: report success with a null op so that callers can create it.
+  if (!func)
+    return func;
+
+  mlir::FunctionType foundFuncT = func.getFunctionType();
+  // Check that the signature of the found function is the expected one.
+  if (funcT != foundFuncT) {
+    return func.emitError("matched function '")
+           << name << "' but with different type: " << foundFuncT
+           << " (expected " << funcT << ")";
+  }
+  return func;
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa34897b6..160b6ae89215c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
+/// Looks up or creates the `printApFloat` runtime function declaration with
+/// parameters (i32, i64) and a void result.
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+                                         SymbolTableCollection *symbolTables) {
+  MLIRContext *ctx = moduleOp->getContext();
+  Type semanticsTy = IntegerType::get(ctx, 32);
+  Type valueBitsTy = IntegerType::get(ctx, 64);
+  Type voidTy = LLVM::LLVMVoidType::get(ctx);
+  return lookupOrCreateReservedFn(b, moduleOp, kPrintApFloat,
+                                  {semanticsTy, valueBitsTy}, voidTy,
+                                  symbolTables);
+}
+
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
new file mode 100644
index 0000000000000..0a05f7369e556
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -0,0 +1,89 @@
+//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file exposes the APFloat infrastructure to MLIR programs as a runtime
+// library. APFloat is a software implementation of floating point arithmetics.
+//
+// On the MLIR side, floating-point values must be bitcasted to 64-bit integers
+// before calling a runtime function. If a floating-point type has less than
+// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an
+// integer.
+//
+// Runtime functions receive the floating-point operands of the arithmetic
+// operation in the form of 64-bit integers, along with the APFloat semantics
+// in the form of a 32-bit integer, which will be interpreted as an
+// APFloatBase::Semantics enum value.
+//
+#include "llvm/ADT/APFloat.h"
+
+#ifdef _WIN32
+#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT
+#ifdef mlir_apfloat_wrappers_EXPORTS
+// We are building this library
+#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllexport)
+#else
+// We are using this library
+#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllimport)
+#endif // mlir_apfloat_wrappers_EXPORTS
+#endif // MLIR_APFLOAT_WRAPPERS_EXPORT
+#else
+// Non-windows: use visibility attributes.
+#define MLIR_APFLOAT_WRAPPERS_EXPORT __attribute__((visibility("default")))
+#endif // _WIN32
+
+/// Defines `_mlir_apfloat_<OP>` for a binary APFloat operation that takes no
+/// rounding mode. `a` and `b` hold the operand bits; the bits of the result
+/// are returned.
+#define APFLOAT_BINARY_OP(OP)                                                  \
+  MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP(                     \
+      int32_t semantics, uint64_t a, uint64_t b) {                             \
+    const auto semanticsEnum =                                                 \
+        static_cast<llvm::APFloatBase::Semantics>(semantics);                  \
+    const llvm::fltSemantics &fltSem =                                         \
+        llvm::APFloatBase::EnumToSemantics(semanticsEnum);                     \
+    const unsigned numBits = llvm::APFloatBase::semanticsSizeInBits(fltSem);   \
+    llvm::APFloat result(fltSem, llvm::APInt(numBits, a));                     \
+    const llvm::APFloat operand(fltSem, llvm::APInt(numBits, b));              \
+    result.OP(operand);                                                        \
+    return result.bitcastToAPInt().getZExtValue();                             \
+  }
+
+/// Defines `_mlir_apfloat_<OP>` for a binary APFloat operation that is
+/// evaluated with the given ROUNDING_MODE. `a` and `b` hold the operand bits;
+/// the bits of the result are returned.
+#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE)                     \
+  MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP(                     \
+      int32_t semantics, uint64_t a, uint64_t b) {                             \
+    const auto semanticsEnum =                                                 \
+        static_cast<llvm::APFloatBase::Semantics>(semantics);                  \
+    const llvm::fltSemantics &fltSem =                                         \
+        llvm::APFloatBase::EnumToSemantics(semanticsEnum);                     \
+    const unsigned numBits = llvm::APFloatBase::semanticsSizeInBits(fltSem);   \
+    llvm::APFloat result(fltSem, llvm::APInt(numBits, a));                     \
+    const llvm::APFloat operand(fltSem, llvm::APInt(numBits, b));              \
+    result.OP(operand, ROUNDING_MODE);                                         \
+    return result.bitcastToAPInt().getZExtValue();                             \
+  }
+
+extern "C" {
+
+// add/subtract/multiply/divide are evaluated with round-to-nearest,
+// ties-to-even.
+APFLOAT_BINARY_OP_ROUNDING_MODE(add, llvm::RoundingMode::NearestTiesToEven)
+APFLOAT_BINARY_OP_ROUNDING_MODE(subtract,
+                                llvm::RoundingMode::NearestTiesToEven)
+APFLOAT_BINARY_OP_ROUNDING_MODE(multiply,
+                                llvm::RoundingMode::NearestTiesToEven)
+APFLOAT_BINARY_OP_ROUNDING_MODE(divide, llvm::RoundingMode::NearestTiesToEven)
+#undef APFLOAT_BINARY_OP_ROUNDING_MODE
+
+// remainder takes no rounding mode.
+APFLOAT_BINARY_OP(remainder)
+#undef APFLOAT_BINARY_OP
+
+/// Prints to stdout the floating-point value whose bits are `a`, interpreted
+/// with the APFloat semantics identified by `semantics`. The value is
+/// converted to a double and printed with "%lg".
+MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) {
+  const auto semanticsEnum =
+      static_cast<llvm::APFloatBase::Semantics>(semantics);
+  const llvm::fltSemantics &fltSem =
+      llvm::APFloatBase::EnumToSemantics(semanticsEnum);
+  const unsigned numBits = llvm::APFloatBase::semanticsSizeInBits(fltSem);
+  llvm::APFloat value(fltSem, llvm::APInt(numBits, a));
+  fprintf(stdout, "%lg", value.convertToDouble());
+}
+} // extern "C"
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdeb4dacf9278..8c09e50e4de7b 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
# is a big dependency which most don't need.
set(LLVM_OPTIONAL_SOURCES
+ APFloatWrappers.cpp
ArmRunnerUtils.cpp
ArmSMEStubs.cpp
AsyncRuntime.cpp
@@ -167,6 +168,15 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
+ add_mlir_library(mlir_apfloat_wrappers
+ SHARED
+ APFloatWrappers.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+ )
+ set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17)
+ target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS)
+
add_subdirectory(SparseTensor)
add_mlir_library(mlir_c_runner_utils
@@ -177,6 +187,7 @@ if(LLVM_ENABLE_PIC)
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
+ mlir_apfloat_wrappers
mlir_float16_utils
MLIRSparseTensorEnums
MLIRSparseTensorRuntime
@@ -191,6 +202,7 @@ if(LLVM_ENABLE_PIC)
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
+ mlir_apfloat_wrappers
mlir_float16_utils
)
target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS)
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
new file mode 100644
index 0000000000000..797f42c37a26f
--- /dev/null
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -0,0 +1,128 @@
+// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
+
+// CHECK-LABEL: func.func @foo() -> f8E4M3FN {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN
+// CHECK: return %[[CONSTANT_0]] : f8E4M3FN
+// CHECK: }
+
+// CHECK-LABEL: func.func @bar() -> f6E3M2FN {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3.000000e+00 : f6E3M2FN
+// CHECK: return %[[CONSTANT_0]] : f6E3M2FN
+// CHECK: }
+
+// Illustrate that both f8E4M3FN and f6E3M2FN calling the same _mlir_apfloat_add is fine
+// because each gets its own semantics enum and gets bitcast/extui/trunci to its own width.
+// CHECK-LABEL: func.func @full_example() {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN
+// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f8E4M3FN
+// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
+// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64
+// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[VAL_0]] : f8E4M3FN to i8
+// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64
+// // fltSemantics semantics for f8E4M3FN
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32
+// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64
+// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8
+// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN
+// CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN
+
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.500000e+00 : f6E3M2FN
+// CHECK: %[[VAL_2:.*]] = call @bar() : () -> f6E3M2FN
+// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f6E3M2FN to i6
+// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i6 to i64
+// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[VAL_2]] : f6E3M2FN to i6
+// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64
+// // fltSemantics semantics for f6E3M2FN
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32
+// CHECK: %[[VAL_3:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64
+// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6
+// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN
+// CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN
+// CHECK: return
+// CHECK: }
+
+// Put rhs into separate function so that it won't be constant-folded.
+// Returns the f8E4M3FN constant written as 2.2 (printed back as 2.25).
+func.func @foo() -> f8E4M3FN {
+ %cst = arith.constant 2.2 : f8E4M3FN
+ return %cst : f8E4M3FN
+}
+
+// Same as @foo, for f6E3M2FN: returns the constant written as 3.2 (printed
+// back as 3.0).
+func.func @bar() -> f6E3M2FN {
+ %cst = arith.constant 3.2 : f6E3M2FN
+ return %cst : f6E3M2FN
+}
+
+// Adds a constant to the result of @foo (f8E4M3FN) and to the result of @bar
+// (f6E3M2FN) and prints both sums. Both additions are expected to lower to
+// the same @_mlir_apfloat_add runtime function.
+func.func @full_example() {
+ %a = arith.constant 1.4 : f8E4M3FN
+ %b = func.call @foo() : () -> (f8E4M3FN)
+ %c = arith.addf %a, %b : f8E4M3FN
+ vector.print %c : f8E4M3FN
+
+ %d = arith.constant 2.4 : f6E3M2FN
+ %e = func.call @bar() : () -> (f6E3M2FN)
+ %f = arith.addf %d, %e : f6E3M2FN
+ vector.print %f : f6E3M2FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+// arith.addf on f4E2M1FN function arguments.
+func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.addf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// Test decl collision (different type)
+// expected-error@+1{{matched function '_mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}}
+func.func private @_mlir_apfloat_add(i32, i32, f32) -> index
+// The declaration above deliberately uses a signature other than
+// (i32, i64, i64) -> i64, so lowering the addf below must report an error on
+// it. Keep it directly below the diagnostic annotation that targets it.
+func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.addf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_subtract(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+// arith.subf on f4E2M1FN function arguments.
+func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.subf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_multiply(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+// arith.mulf on f4E2M1FN function arguments.
+// NOTE(review): the function is named @subf but exercises arith.mulf; looks
+// like a copy-paste leftover, consider renaming to @mulf.
+func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.mulf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_divide(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+// arith.divf on f4E2M1FN function arguments.
+// NOTE(review): the function is named @subf but exercises arith.divf; looks
+// like a copy-paste leftover, consider renaming to @divf.
+func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.divf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_remainder(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+// arith.remf on f4E2M1FN function arguments.
+func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.remf %arg0, %arg1 : f4E2M1FN
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
new file mode 100644
index 0000000000000..2768afe0834b5
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -0,0 +1,36 @@
+// Case 1: All floating-point arithmetics is lowered through APFloat.
+// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Case 2: Only unsupported arithmetics (f8E4M3FN) is lowered through APFloat.
+// Arithmetics on f32 is lowered directly to LLVM.
+// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \
+// RUN: --convert-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Put rhs into separate function so that it won't be constant-folded.
+// Returns the constant 2.2 once as f8E4M3FN and once as f32.
+func.func @foo() -> (f8E4M3FN, f32) {
+ %cst1 = arith.constant 2.2 : f8E4M3FN
+ %cst2 = arith.constant 2.2 : f32
+ return %cst1, %cst2 : f8E4M3FN, f32
+}
+
+// Entry point run by mlir-runner: adds 1.4 to each value returned by @foo,
+// once in f8E4M3FN and once in f32, and prints both sums.
+func.func @entry() {
+ %a1 = arith.constant 1.4 : f8E4M3FN
+ %a2 = arith.constant 1.4 : f32
+ %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32)
+ %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM
+ %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM
+
+ // CHECK: 3.5
+ vector.print %c1 : f8E4M3FN
+
+ // CHECK: 3.6
+ vector.print %c2 : f32
+
+ return
+}
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 6ff12d66523f5..4a38ed605be0c 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -208,6 +208,7 @@ def find_real_python_interpreter():
add_runtime("mlir_c_runner_utils"),
add_runtime("mlir_async_runtime"),
add_runtime("mlir_float16_utils"),
+ add_runtime("mlir_apfloat_wrappers"),
"mlir-linalg-ods-yaml-gen",
"mlir-reduce",
"mlir-pdll",
More information about the Mlir-commits
mailing list