[Mlir-commits] [mlir] Revert "[mlir] Add FP software implementation lowering pass: `arith-to-apfloat` (#166618)" (PR #167429)
Maksim Levental
llvmlistbot at llvm.org
Mon Nov 10 16:51:52 PST 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/167429
This reverts commit 222f4e494a0cd9515c242fd083c2776772734385.
From f2bf54caa9dfceffa1e260c914289da47a571eab Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 10 Nov 2025 16:51:09 -0800
Subject: [PATCH] Revert "[mlir] Add FP software implementation lowering pass:
`arith-to-apfloat` (#166618)"
This reverts commit 222f4e494a0cd9515c242fd083c2776772734385.
---
.../ArithToAPFloat/ArithToAPFloat.h | 21 ---
mlir/include/mlir/Conversion/Passes.h | 1 -
mlir/include/mlir/Conversion/Passes.td | 15 --
mlir/include/mlir/Dialect/Func/Utils/Utils.h | 7 -
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 4 -
.../ArithToAPFloat/ArithToAPFloat.cpp | 161 ------------------
.../Conversion/ArithToAPFloat/CMakeLists.txt | 17 --
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 1 -
mlir/lib/Conversion/CMakeLists.txt | 1 -
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 14 --
mlir/lib/Dialect/Func/Utils/Utils.cpp | 25 ---
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 11 --
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 81 ---------
mlir/lib/ExecutionEngine/CMakeLists.txt | 12 --
.../ArithToApfloat/arith-to-apfloat.mlir | 128 --------------
.../Arith/CPU/test-apfloat-emulation.mlir | 34 ----
16 files changed, 533 deletions(-)
delete mode 100644 mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
delete mode 100644 mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
delete mode 100644 mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
delete mode 100644 mlir/lib/ExecutionEngine/APFloatWrappers.cpp
delete mode 100644 mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
delete mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
deleted file mode 100644
index 64a42a228199e..0000000000000
--- a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
+++ /dev/null
@@ -1,21 +0,0 @@
-//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
-//
-// Part of the APFloat Project, under the Apache License v2.0 with APFloat
-// Exceptions. See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
-#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
-
-#include <memory>
-
-namespace mlir {
-class Pass;
-
-#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..40d866ec7bf10 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,7 +12,6 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
-#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 79bc380dbcb7a..70e3e45c225db 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -186,21 +186,6 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
];
}
-//===----------------------------------------------------------------------===//
-// ArithToAPFloat
-//===----------------------------------------------------------------------===//
-
-def ArithToAPFloatConversionPass
- : Pass<"convert-arith-to-apfloat", "ModuleOp"> {
- let summary = "Convert Arith ops to APFloat runtime library calls";
- let description = [{
- This pass converts supported Arith ops to APFloat-based runtime library
- calls (APFloatWrappers.cpp). APFloat is a software implementation of
- floating-point arithmetic operations.
- }];
- let dependentDialects = ["func::FuncDialect"];
-}
-
//===----------------------------------------------------------------------===//
// ArithToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..3576126a487ac 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -60,13 +60,6 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
mlir::ModuleOp moduleOp);
-/// Look up a FuncOp with signature `resultTypes`(`paramTypes`)` and name
-/// `name`. Return a failure if the FuncOp is found but with a different
-/// signature.
-FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
- FunctionType funcT,
- SymbolTableCollection *symbolTables = nullptr);
-
} // namespace func
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index b09d32022e348..8ad9ed18acebd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -52,10 +52,6 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
-FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
- SymbolTableCollection *symbolTables = nullptr);
-
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
deleted file mode 100644
index 012e934d3050f..0000000000000
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ /dev/null
@@ -1,161 +0,0 @@
-//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Utils/Utils.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Verifier.h"
-#include "mlir/Transforms/WalkPatternRewriteDriver.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "arith-to-apfloat"
-
-namespace mlir {
-#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::func;
-
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, FunctionType funcT, bool setPrivate,
- SymbolTableCollection *symbolTables = nullptr) {
- OpBuilder::InsertionGuard g(b);
- assert(!symTable->getRegion(0).empty() && "expected non-empty region");
- b.setInsertionPointToStart(&symTable->getRegion(0).front());
- FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
- if (setPrivate)
- funcOp.setPrivate();
- if (symbolTables) {
- SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
- symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
- }
- return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function for a binary arithmetic operation.
-///
-/// Parameter 1: APFloat semantics
-/// Parameter 2: Left-hand side operand
-/// Parameter 3: Right-hand side operand
-///
-/// This function will return a failure if the function is found but has an
-/// unexpected signature.
-///
-static FailureOr<FuncOp>
-lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
- SymbolTableCollection *symbolTables = nullptr) {
- auto i32Type = IntegerType::get(symTable->getContext(), 32);
- auto i64Type = IntegerType::get(symTable->getContext(), 64);
-
- std::string funcName = (llvm::Twine("__mlir_apfloat_") + name).str();
- FunctionType funcT =
- FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
- FailureOr<FuncOp> func =
- lookupFnDecl(symTable, funcName, funcT, symbolTables);
- // Failed due to type mismatch.
- if (failed(func))
- return func;
- // Successfully matched existing decl.
- if (*func)
- return *func;
-
- return createFnDecl(b, symTable, funcName, funcT,
- /*setPrivate=*/true, symbolTables);
-}
-
-/// Rewrite a binary arithmetic operation to an APFloat function call.
-template <typename OpTy, const char *APFloatName>
-struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
- BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit,
- SymbolOpInterface symTable)
- : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {};
-
- LogicalResult matchAndRewrite(OpTy op,
- PatternRewriter &rewriter) const override {
- // Get APFloat function from runtime library.
- FailureOr<FuncOp> fn =
- lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
- if (failed(fn))
- return fn;
-
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
- Location loc = op.getLoc();
- auto floatTy = cast<FloatType>(op.getType());
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
- Value lhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
- Value rhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
-
- // Call APFloat function.
- int32_t sem =
- llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- Value semValue = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32Type(),
- rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
- return success();
- }
-
- SymbolOpInterface symTable;
-};
-
-namespace {
-struct ArithToAPFloatConversionPass final
- : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
- using Base::Base;
-
- void runOnOperation() override {
- MLIRContext *context = &getContext();
- RewritePatternSet patterns(context);
- static const char add[] = "add";
- static const char subtract[] = "subtract";
- static const char multiply[] = "multiply";
- static const char divide[] = "divide";
- static const char remainder[] = "remainder";
- patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, add>,
- BinaryArithOpToAPFloatConversion<arith::SubFOp, subtract>,
- BinaryArithOpToAPFloatConversion<arith::MulFOp, multiply>,
- BinaryArithOpToAPFloatConversion<arith::DivFOp, divide>,
- BinaryArithOpToAPFloatConversion<arith::RemFOp, remainder>>(
- context, 1, getOperation());
- LogicalResult result = success();
- ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
- if (diag.getSeverity() == DiagnosticSeverity::Error) {
- result = failure();
- }
- // NB: if you don't return failure, no other diag handlers will fire (see
- // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
- return failure();
- });
- walkAndApplyPatterns(getOperation(), std::move(patterns));
- if (failed(result))
- return signalPassFailure();
- }
-};
-} // namespace
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
deleted file mode 100644
index b0d1e46b3655f..0000000000000
--- a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-add_mlir_conversion_library(MLIRArithToAPFloat
- ArithToAPFloat.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
-
- DEPENDS
- MLIRConversionPassIncGen
-
- LINK_COMPONENTS
- Core
-
- LINK_LIBS PUBLIC
- MLIRArithDialect
- MLIRArithTransforms
- MLIRFuncDialect
- )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index f2bacc3399144..b6099902cc337 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,7 +14,6 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
-#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 613dc6d242ceb..bebf1b8fff3f9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,7 +2,6 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
-add_subdirectory(ArithToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c747e1b59558a..69a317ecd101f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1654,20 +1654,6 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}
}
- } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
- // Print other floating-point types using the APFloat runtime library.
- int32_t sem =
- llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- Value semValue = LLVM::ConstantOp::create(
- rewriter, loc, rewriter.getI32Type(),
- rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
- Value floatBits =
- LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
- printer =
- LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
- emitCall(rewriter, loc, printer.value(),
- ValueRange({semValue, floatBits}));
- return success();
} else {
return failure();
}
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index d6dfd0229963c..b4cb0932ef631 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -254,28 +254,3 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
return std::make_pair(*newFuncOpOrFailure, newCallOp);
}
-
-FailureOr<func::FuncOp>
-func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
- FunctionType funcT, SymbolTableCollection *symbolTables) {
- FuncOp func;
- if (symbolTables) {
- func = symbolTables->lookupSymbolIn<FuncOp>(
- symTable, StringAttr::get(symTable->getContext(), name));
- } else {
- func = llvm::dyn_cast_or_null<FuncOp>(
- SymbolTable::lookupSymbolIn(symTable, name));
- }
-
- if (!func)
- return func;
-
- mlir::FunctionType foundFuncT = func.getFunctionType();
- // Assert the signature of the found function is same as expected
- if (funcT != foundFuncT) {
- return func.emitError("matched function '")
- << name << "' but with different type: " << foundFuncT
- << " (expected " << funcT << ")";
- }
- return func;
-}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 160b6ae89215c..feaffa34897b6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,7 +30,6 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
-static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -161,16 +160,6 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
-FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
- SymbolTableCollection *symbolTables) {
- return lookupOrCreateReservedFn(
- b, moduleOp, kPrintApFloat,
- {IntegerType::get(moduleOp->getContext(), 32),
- IntegerType::get(moduleOp->getContext(), 64)},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
-}
-
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
deleted file mode 100644
index 85ea0986cde5b..0000000000000
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ /dev/null
@@ -1,81 +0,0 @@
-//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file exposes the APFloat infrastructure to MLIR programs as a runtime
-// library. APFloat is a software implementation of floating point arithmetics.
-//
-// On the MLIR side, floating-point values must be bitcasted to 64-bit integers
-// before calling a runtime function. If a floating-point type has less than
-// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an
-// integer.
-//
-// Runtime functions receive the floating-point operands of the arithmeic
-// operation in the form of 64-bit integers, along with the APFloat semantics
-// in the form of a 32-bit integer, which will be interpreted as an
-// APFloatBase::Semantics enum value.
-//
-#include "llvm/ADT/APFloat.h"
-
-#if (defined(_WIN32) || defined(__CYGWIN__))
-#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
-#else
-#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
-#endif
-
-/// Binary operations without rounding mode.
-#define APFLOAT_BINARY_OP(OP) \
- int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP( \
- int32_t semantics, uint64_t a, uint64_t b) { \
- const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
- static_cast<llvm::APFloatBase::Semantics>(semantics)); \
- unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
- llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
- llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
- lhs.OP(rhs); \
- return lhs.bitcastToAPInt().getZExtValue(); \
- }
-
-/// Binary operations with rounding mode.
-#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
- int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP( \
- int32_t semantics, uint64_t a, uint64_t b) { \
- const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
- static_cast<llvm::APFloatBase::Semantics>(semantics)); \
- unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
- llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
- llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
- lhs.OP(rhs, ROUNDING_MODE); \
- return lhs.bitcastToAPInt().getZExtValue(); \
- }
-
-extern "C" {
-
-#define BIN_OPS_WITH_ROUNDING(X) \
- X(add, llvm::RoundingMode::NearestTiesToEven) \
- X(subtract, llvm::RoundingMode::NearestTiesToEven) \
- X(multiply, llvm::RoundingMode::NearestTiesToEven) \
- X(divide, llvm::RoundingMode::NearestTiesToEven)
-
-BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE)
-#undef BIN_OPS_WITH_ROUNDING
-#undef APFLOAT_BINARY_OP_ROUNDING_MODE
-
-APFLOAT_BINARY_OP(remainder)
-
-#undef APFLOAT_BINARY_OP
-
-void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
- uint64_t a) {
- const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
- static_cast<llvm::APFloatBase::Semantics>(semantics));
- unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
- llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
- double d = x.convertToDouble();
- fprintf(stdout, "%lg", d);
-}
-}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index 8c09e50e4de7b..fdeb4dacf9278 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,7 +2,6 @@
# is a big dependency which most don't need.
set(LLVM_OPTIONAL_SOURCES
- APFloatWrappers.cpp
ArmRunnerUtils.cpp
ArmSMEStubs.cpp
AsyncRuntime.cpp
@@ -168,15 +167,6 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
- add_mlir_library(mlir_apfloat_wrappers
- SHARED
- APFloatWrappers.cpp
-
- EXCLUDE_FROM_LIBMLIR
- )
- set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17)
- target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS)
-
add_subdirectory(SparseTensor)
add_mlir_library(mlir_c_runner_utils
@@ -187,7 +177,6 @@ if(LLVM_ENABLE_PIC)
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
- mlir_apfloat_wrappers
mlir_float16_utils
MLIRSparseTensorEnums
MLIRSparseTensorRuntime
@@ -202,7 +191,6 @@ if(LLVM_ENABLE_PIC)
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
- mlir_apfloat_wrappers
mlir_float16_utils
)
target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS)
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
deleted file mode 100644
index fe4d28a56f808..0000000000000
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ /dev/null
@@ -1,128 +0,0 @@
-// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL: func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64
-
-// CHECK-LABEL: func.func @foo() -> f8E4M3FN {
-// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN
-// CHECK: return %[[CONSTANT_0]] : f8E4M3FN
-// CHECK: }
-
-// CHECK-LABEL: func.func @bar() -> f6E3M2FN {
-// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3.000000e+00 : f6E3M2FN
-// CHECK: return %[[CONSTANT_0]] : f6E3M2FN
-// CHECK: }
-
-// Illustrate that both f8E4M3FN and f6E3M2FN calling the same __mlir_apfloat_add is fine
-// because each gets its own semantics enum and gets bitcast/extui/trunci to its own width.
-// CHECK-LABEL: func.func @full_example() {
-// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN
-// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f8E4M3FN
-// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
-// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64
-// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[VAL_0]] : f8E4M3FN to i8
-// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64
-// // fltSemantics semantics for f8E4M3FN
-// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32
-// CHECK: %[[VAL_1:.*]] = call @__mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64
-// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8
-// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN
-// CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN
-
-// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.500000e+00 : f6E3M2FN
-// CHECK: %[[VAL_2:.*]] = call @bar() : () -> f6E3M2FN
-// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f6E3M2FN to i6
-// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i6 to i64
-// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[VAL_2]] : f6E3M2FN to i6
-// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64
-// // fltSemantics semantics for f6E3M2FN
-// CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32
-// CHECK: %[[VAL_3:.*]] = call @__mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64
-// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6
-// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN
-// CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN
-// CHECK: return
-// CHECK: }
-
-// Put rhs into separate function so that it won't be constant-folded.
-func.func @foo() -> f8E4M3FN {
- %cst = arith.constant 2.2 : f8E4M3FN
- return %cst : f8E4M3FN
-}
-
-func.func @bar() -> f6E3M2FN {
- %cst = arith.constant 3.2 : f6E3M2FN
- return %cst : f6E3M2FN
-}
-
-func.func @full_example() {
- %a = arith.constant 1.4 : f8E4M3FN
- %b = func.call @foo() : () -> (f8E4M3FN)
- %c = arith.addf %a, %b : f8E4M3FN
- vector.print %c : f8E4M3FN
-
- %d = arith.constant 2.4 : f6E3M2FN
- %e = func.call @bar() : () -> (f6E3M2FN)
- %f = arith.addf %d, %e : f6E3M2FN
- vector.print %f : f6E3M2FN
- return
-}
-
-// -----
-
-// CHECK: func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64
-// CHECK: %[[sem:.*]] = arith.constant 18 : i32
-// CHECK: call @__mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
-func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
- %0 = arith.addf %arg0, %arg1 : f4E2M1FN
- return
-}
-
-// -----
-
-// Test decl collision (different type)
-// expected-error at +1{{matched function '__mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}}
-func.func private @__mlir_apfloat_add(i32, i32, f32) -> index
-func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
- %0 = arith.addf %arg0, %arg1 : f4E2M1FN
- return
-}
-
-// -----
-
-// CHECK: func.func private @__mlir_apfloat_subtract(i32, i64, i64) -> i64
-// CHECK: %[[sem:.*]] = arith.constant 18 : i32
-// CHECK: call @__mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
-func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
- %0 = arith.subf %arg0, %arg1 : f4E2M1FN
- return
-}
-
-// -----
-
-// CHECK: func.func private @__mlir_apfloat_multiply(i32, i64, i64) -> i64
-// CHECK: %[[sem:.*]] = arith.constant 18 : i32
-// CHECK: call @__mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
-func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
- %0 = arith.mulf %arg0, %arg1 : f4E2M1FN
- return
-}
-
-// -----
-
-// CHECK: func.func private @__mlir_apfloat_divide(i32, i64, i64) -> i64
-// CHECK: %[[sem:.*]] = arith.constant 18 : i32
-// CHECK: call @__mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
-func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
- %0 = arith.divf %arg0, %arg1 : f4E2M1FN
- return
-}
-
-// -----
-
-// CHECK: func.func private @__mlir_apfloat_remainder(i32, i64, i64) -> i64
-// CHECK: %[[sem:.*]] = arith.constant 18 : i32
-// CHECK: call @__mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
-func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
- %0 = arith.remf %arg0, %arg1 : f4E2M1FN
- return
-}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
deleted file mode 100644
index a2b3eb73a60b8..0000000000000
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ /dev/null
@@ -1,34 +0,0 @@
-// Case 1: All floating-point arithmetics is lowered through APFloat.
-// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \
-// RUN: mlir-runner -e entry --entry-point-result=void \
-// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s
-
-// Case 2: Only unsupported arithmetics (f8E4M3FN) is lowered through APFloat.
-// Arithmetics on f32 is lowered directly to LLVM.
-// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \
-// RUN: --convert-to-llvm --reconcile-unrealized-casts | \
-// RUN: mlir-runner -e entry --entry-point-result=void \
-// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s
-
-// Put rhs into separate function so that it won't be constant-folded.
-func.func @foo() -> (f8E4M3FN, f32) {
- %cst1 = arith.constant 2.2 : f8E4M3FN
- %cst2 = arith.constant 2.2 : f32
- return %cst1, %cst2 : f8E4M3FN, f32
-}
-
-func.func @entry() {
- %a1 = arith.constant 1.4 : f8E4M3FN
- %a2 = arith.constant 1.4 : f32
- %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32)
- %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM
- %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM
-
- // CHECK: 3.5
- vector.print %c1 : f8E4M3FN
-
- // CHECK: 3.6
- vector.print %c2 : f32
-
- return
-}
More information about the Mlir-commits mailing list