[Mlir-commits] [mlir] Prototype: APFloat runtime library for unsupported CPU floating-point types (PR #166618)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 8 22:32:59 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

Lower floating-point arithmetic operations on types that LLVM does not support to calls into the execution engine runtime library.

---

Patch is 23.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166618.diff


16 Files Affected:

- (added) mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h (+21) 
- (modified) mlir/include/mlir/Conversion/Passes.h (+1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+16) 
- (modified) mlir/include/mlir/Dialect/Func/Utils/Utils.h (+8) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h (+4) 
- (added) mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp (+113) 
- (added) mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt (+17) 
- (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+1) 
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+14) 
- (modified) mlir/lib/Dialect/Func/Utils/Utils.cpp (+42) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (+11) 
- (added) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+71) 
- (modified) mlir/lib/ExecutionEngine/CMakeLists.txt (+12) 
- (added) mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir (+38) 
- (added) mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir (+19) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
new file mode 100644
index 0000000000000..64a42a228199e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
@@ -0,0 +1,21 @@
+//===- ArithToAPFloat.h - Arith to APFloat impl conversion ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 40d866ec7bf10..82bdfd02661a6 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
 #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 70e3e45c225db..2cd7d14f5517b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -186,6 +186,22 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToAPFloat
+//===----------------------------------------------------------------------===//
+
+def ArithToAPFloatConversionPass
+    : Pass<"convert-arith-to-apfloat", "ModuleOp"> {
+  let summary = "Convert Arith ops to APFloat runtime library calls";
+  let description = [{
+    This pass converts supported Arith ops to APFloat-based runtime library
+    calls (APFloatWrappers.cpp). APFloat is a software implementation of
+    floating-point arithmetic operations.
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+  let options = [];
+}
+
 //===----------------------------------------------------------------------===//
 // ArithToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 3576126a487ac..9c9973cf84368 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -60,6 +60,14 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
 deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
                         mlir::ModuleOp moduleOp);
 
+/// Look up or create a FuncOp `name` of type `resultTypes`(`paramTypes`).
+/// Return a failure if an existing FuncOp has a different signature.
+FailureOr<FuncOp>
+lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+                 ArrayRef<Type> paramTypes = {},
+                 ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
+                 SymbolTableCollection *symbolTables = nullptr);
+
 } // namespace func
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 8ad9ed18acebd..b09d32022e348 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
                          SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+                             SymbolTableCollection *symbolTables = nullptr);
+
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
new file mode 100644
index 0000000000000..62074625033ce
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -0,0 +1,113 @@
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/Verifier.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+static FailureOr<Operation *>
+lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+                       SymbolTableCollection *symbolTables = nullptr) {
+  // Runtime signature: i64 _mlir_apfloat_<name>(i32, i64, i64).
+  MLIRContext *ctx = moduleOp->getContext();
+  Type i32Ty = IntegerType::get(ctx, 32);
+  Type i64Ty = IntegerType::get(ctx, 64);
+  return lookupOrCreateFn(
+      b, moduleOp, (llvm::Twine("_mlir_apfloat_") + name).str(),
+      {i32Ty, i64Ty, i64Ty}, {i64Ty}, /*setPrivate=*/true, symbolTables);
+}
+
+template <typename OpTy>
+static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, ModuleOp module,
+                                     OpTy op, StringRef apfloatName) {
+  // Resolve (or declare) the `_mlir_apfloat_*` function that will be called.
+  FailureOr<Operation *> callee =
+      lookupOrCreateBinaryFn(rewriter, module, apfloatName);
+  if (failed(callee))
+    return op->emitError("failed to lookup or create APFloat function");
+
+  Location loc = op.getLoc();
+  auto fpType = cast<FloatType>(op.getType());
+  auto bitsType = rewriter.getIntegerType(fpType.getWidth());
+  auto i64Type = rewriter.getI64Type();
+  auto i32Type = rewriter.getI32Type();
+
+  // Reinterpret each operand as an integer and zero-extend it to i64.
+  Value lhsRaw = arith::BitcastOp::create(rewriter, loc, bitsType, op.getLhs());
+  Value lhsBits = arith::ExtUIOp::create(rewriter, loc, i64Type, lhsRaw);
+  Value rhsRaw = arith::BitcastOp::create(rewriter, loc, bitsType, op.getRhs());
+  Value rhsBits = arith::ExtUIOp::create(rewriter, loc, i64Type, rhsRaw);
+
+  // The float semantics are passed to the callee as an i32 enum constant.
+  int32_t semEnum =
+      llvm::APFloatBase::SemanticsToEnum(fpType.getFloatSemantics());
+  Value semValue = arith::ConstantOp::create(
+      rewriter, loc, i32Type, rewriter.getIntegerAttr(i32Type, semEnum));
+  // Call APFloat function.
+  SmallVector<Value> callArgs = {semValue, lhsBits, rhsBits};
+  auto call = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                   SymbolRefAttr::get(*callee), callArgs);
+
+  // Truncate the i64 result to the original width and cast back to float.
+  Value resultBits =
+      arith::TruncIOp::create(rewriter, loc, bitsType, call->getResult(0));
+  Value result = arith::BitcastOp::create(rewriter, loc, fpType, resultBits);
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+namespace {
+/// Rewrites arith.addf/subf/mulf/divf into calls to the `_mlir_apfloat_*`
+/// runtime functions (see rewriteBinaryOp).
+struct ArithToAPFloatConversionPass final
+    : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
+  using impl::ArithToAPFloatConversionPassBase<
+      ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    IRRewriter rewriter(getOperation()->getContext());
+    // Each name below must match a `_mlir_apfloat_<name>` runtime function.
+    WalkResult status = module->walk([&](Operation *op) {
+      rewriter.setInsertionPoint(op);
+      LogicalResult result =
+          llvm::TypeSwitch<Operation *, LogicalResult>(op)
+              .Case<arith::AddFOp>([&](arith::AddFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "add");
+              })
+              .Case<arith::SubFOp>([&](arith::SubFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "subtract");
+              })
+              .Case<arith::MulFOp>([&](arith::MulFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "multiply");
+              })
+              .Case<arith::DivFOp>([&](arith::DivFOp op) {
+                return rewriteBinaryOp(rewriter, module, op, "divide");
+              })
+              .Default([](Operation *) { return success(); });
+      return failed(result) ? WalkResult::interrupt() : WalkResult::advance();
+    });
+    if (status.wasInterrupted())
+      return signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..b0d1e46b3655f
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRArithToAPFloat
+  ArithToAPFloat.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArithTransforms
+  MLIRFuncDialect
+  )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index b6099902cc337..f2bacc3399144 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8fff3f9..613dc6d242ceb 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
 add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToAPFloat)
 add_subdirectory(ArithToArmSME)
 add_subdirectory(ArithToEmitC)
 add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 69a317ecd101f..c747e1b59558a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
           return failure();
         }
       }
+    } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
+      // Print other floating-point types using the APFloat runtime library.
+      int32_t sem =
+          llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+      Value semValue = LLVM::ConstantOp::create(
+          rewriter, loc, rewriter.getI32Type(),
+          rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+      Value floatBits =
+          LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
+      printer =
+          LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
+      emitCall(rewriter, loc, printer.value(),
+               ValueRange({semValue, floatBits}));
+      return success();
     } else {
       return failure();
     }
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index b4cb0932ef631..e187e62cf6555 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -254,3 +254,45 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
 
   return std::make_pair(*newFuncOpOrFailure, newCallOp);
 }
+
+FailureOr<func::FuncOp>
+func::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+                       ArrayRef<Type> paramTypes, ArrayRef<Type> resultTypes,
+                       bool setPrivate, SymbolTableCollection *symbolTables) {
+  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+         "expected SymbolTable operation");
+
+  // Use the caller's symbol table collection for the lookup when provided.
+  FuncOp existing =
+      symbolTables
+          ? symbolTables->lookupSymbolIn<FuncOp>(
+                moduleOp, StringAttr::get(moduleOp->getContext(), name))
+          : llvm::dyn_cast_or_null<FuncOp>(
+                SymbolTable::lookupSymbolIn(moduleOp, name));
+
+  FunctionType expectedType =
+      FunctionType::get(b.getContext(), paramTypes, resultTypes);
+  // Reuse an existing function only if its signature matches the request.
+  if (existing) {
+    if (existing.getFunctionType() == expectedType)
+      return existing;
+    existing.emitError("redefinition of function '")
+        << name << "' of different type " << expectedType << " is prohibited";
+    return failure();
+  }
+
+  // Otherwise create the function at the start of the module's first block.
+  OpBuilder::InsertionGuard guard(b);
+  assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
+  Block &body = moduleOp->getRegion(0).front();
+  b.setInsertionPointToStart(&body);
+  FuncOp created = FuncOp::create(b, moduleOp->getLoc(), name, expectedType);
+  if (setPrivate)
+    created.setPrivate();
+  if (symbolTables) {
+    SymbolTable &table = symbolTables->getSymbolTable(moduleOp);
+    table.insert(created, body.begin());
+  }
+
+  return created;
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa34897b6..160b6ae89215c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
 static constexpr llvm::StringRef kPrintBF16 = "printBF16";
 static constexpr llvm::StringRef kPrintF32 = "printF32";
 static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
 static constexpr llvm::StringRef kPrintString = "printString";
 static constexpr llvm::StringRef kPrintOpen = "printOpen";
 static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
       LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+                                         SymbolTableCollection *symbolTables) {
+  MLIRContext *ctx = moduleOp->getContext();
+  Type i32Ty = IntegerType::get(ctx, 32);
+  Type i64Ty = IntegerType::get(ctx, 64);
+  return lookupOrCreateReservedFn(b, moduleOp, kPrintApFloat, {i32Ty, i64Ty},
+                                  LLVM::LLVMVoidType::get(ctx), symbolTables);
+}
+
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
   return LLVM::LLVMPointerType::get(context);
 }
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
new file mode 100644
index 0000000000000..e8f57f231ce5c
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -0,0 +1,71 @@
+//===- APFloatWrappers.cpp - Software implementation of FP arithmetic -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APFloat.h"
+
+#include <iostream>
+
+#if (defined(_WIN32) || defined(__CYGWIN__))
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
+#else
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
+#endif
+
+/// Defines `_mlir_apfloat_<OP>`: binary operation without rounding mode.
+#define APFLOAT_BINARY_OP(OP)                                                  \
+  int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP(                   \
+      int32_t semantics, uint64_t a, uint64_t b) {                             \
+    const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(        \
+        static_cast<llvm::APFloatBase::Semantics>(semantics));                 \
+    unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);           \
+    llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a));                          \
+    llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b));                          \
+    lhs.OP(rhs); /* The opStatus result is discarded. */                       \
+    return lhs.bitcastToAPInt().getZExtValue();                                \
+  }
+
+/// Defines `_mlir_apfloat_<OP>`: binary operation with a fixed rounding mode.
+#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE)                     \
+  int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP(                   \
+      int32_t semantics, uint64_t a, uint64_t b) {                             \
+    const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(        \
+        static_cast<llvm::APFloatBase::Semantics>(semantics));                 \
+    unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);           \
+    llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a));                          \
+    llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b));                          \
+    lhs.OP(rhs, ROUNDING_MODE); /* The opStatus result is discarded. */        \
+    return lhs.bitcastToAPInt().getZExtValue();                                \
+  }
+
+extern "C" {
+
+#define BIN_OPS_WITH_ROUNDING(X)                                               \
+  X(add, llvm::RoundingMode::NearestTiesToEven)                                \
+  X(subtract, llvm::RoundingMode::NearestTiesToEven)                           \
+  X(multiply, llvm::RoundingMode::NearestTiesToEven)                           \
+  X(divide, llvm::RoundingMode::NearestTiesToEven)
+
+BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE)
+#undef BIN_OPS_WITH_ROUNDING
+#undef APFLOAT_BINARY_OP_ROUNDING_MODE
+// Operations that are invoked without a rounding-mode argument.
+APFLOAT_BINARY_OP(remainder)
+APFLOAT_BINARY_OP(mod)
+
+#undef APFLOAT_BINARY_OP
+
+void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
+                                                 uint64_t a) {
+  // Rebuild the value from its raw bits, then print it as a host double.
+  auto semEnum = static_cast<llvm::APFloatBase::Semantics>(semantics);
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(semEnum);
+  llvm::APInt rawBits(llvm::APFloatBase::semanticsSizeInBits(sem), a);
+  llvm::APFloat value(sem, rawBits);
+  fprintf(stdout, "%lg", value.convertToDouble());
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdeb4dacf9278..8c09e50e4de7b 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
 # is a big dependency which most don't need.
 
 set(LLVM_OPTIONAL_SOURCES
+  APFloatWrappers.cpp
   ArmRunnerUtils.cpp
   ArmSMEStubs.cpp
   AsyncRuntime.cpp
@@ -167,6 +168,15 @@ if(LLVM_ENABLE_PIC)
   set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
   target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
 
+  add_mlir_library(mlir_apfloat_wrappers
+    SHARED
+    APFloatWrappers.cpp
+
+    EXCLUDE_FROM_LIBMLIR
+    )
+  set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17)
+  target_compile_definitions(mlir_apfloat...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/166618


More information about the Mlir-commits mailing list