[Mlir-commits] [mlir] [mlir] Add FP software implementation lowering pass: `arith-to-apfloat` (PR #166618)

Maksim Levental llvmlistbot at llvm.org
Mon Nov 10 13:02:42 PST 2025


================
@@ -0,0 +1,161 @@
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "arith-to-apfloat"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+/// Create a `func.func` declaration called `name` with signature `funcT` and
+/// place it as the first operation of the first block in region 0 of
+/// `symTable`. The builder's insertion point is restored on return.
+///
+/// \param setPrivate   mark the new declaration's visibility as private.
+/// \param symbolTables optional symbol-table cache; when non-null, the new op
+///                     is also registered in the cached table for `symTable`.
+/// \returns the newly created declaration.
+static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+                           StringRef name, FunctionType funcT, bool setPrivate,
+                           SymbolTableCollection *symbolTables = nullptr) {
+  Region &bodyRegion = symTable->getRegion(0);
+  assert(!bodyRegion.empty() && "expected non-empty region");
+  Block &entryBlock = bodyRegion.front();
+
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(&entryBlock);
+  FuncOp decl = FuncOp::create(b, symTable->getLoc(), name, funcT);
+  if (setPrivate)
+    decl.setPrivate();
+
+  // Without a cache there is nothing else to keep in sync.
+  if (!symbolTables)
+    return decl;
+
+  // Keep the cached symbol table consistent with the IR we just created.
+  symbolTables->getSymbolTable(symTable).insert(decl, entryBlock.begin());
+  return decl;
+}
+
+/// Find, or declare if missing, the runtime entry point
+/// `__mlir_apfloat_<name>` used to lower a binary floating-point operation.
+///
+/// The runtime function has signature `(i32, i64, i64) -> i64`:
+///   - argument 0: APFloat semantics identifier
+///   - argument 1: bit pattern of the left-hand side operand
+///   - argument 2: bit pattern of the right-hand side operand
+///
+/// A new declaration is created (as a private function) only when no symbol
+/// with that name exists. If a symbol with that name exists but has a
+/// different signature, failure is returned.
+static FailureOr<FuncOp>
+lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                       SymbolTableCollection *symbolTables = nullptr) {
+  MLIRContext *ctx = symTable->getContext();
+  auto i32Ty = IntegerType::get(ctx, 32);
+  auto i64Ty = IntegerType::get(ctx, 64);
+  FunctionType expectedSig =
+      FunctionType::get(b.getContext(), {i32Ty, i64Ty, i64Ty}, {i64Ty});
+
+  std::string symbolName = (llvm::Twine("__mlir_apfloat_") + name).str();
+  FailureOr<FuncOp> existing =
+      lookupFnDecl(symTable, symbolName, expectedSig, symbolTables);
+
+  // A symbol with this name exists, but its signature is not the expected one.
+  if (failed(existing))
+    return existing;
+
+  // An existing declaration with the right signature can be reused as-is.
+  if (FuncOp found = *existing)
+    return found;
+
+  return createFnDecl(b, symTable, symbolName, expectedSig,
+                      /*setPrivate=*/true, symbolTables);
+}
+
+/// Rewrite a binary floating-point arithmetic operation into a call to the
+/// APFloat runtime function `__mlir_apfloat_<APFloatName>`.
+///
+/// For a scalar float of bit width W (W <= 64), the rewrite is:
+///   1. bitcast each operand to iW, then zero-extend it to i64 (when W < 64);
+///   2. call the runtime with (semantics-id : i32, lhs : i64, rhs : i64);
+///   3. truncate the i64 result to iW (when W < 64), then bitcast it back to
+///      the original float type.
+///
+/// Vector/tensor operands and floats wider than 64 bits are not supported;
+/// the pattern fails to match them and leaves the IR unchanged.
+template <typename OpTy, const char *APFloatName>
+struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+  BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit,
+                                   SymbolOpInterface symTable)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Do all applicability checks before touching the IR. A pattern that
+    // returns failure must not have modified anything, and that includes
+    // creating the runtime function declaration.
+    auto floatTy = dyn_cast<FloatType>(op.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(
+          op, "only scalar floating-point types are supported");
+    unsigned width = floatTy.getWidth();
+    if (width > 64)
+      return rewriter.notifyMatchFailure(
+          op, "floating-point types wider than 64 bits are not supported");
+
+    // Get APFloat function from runtime library.
+    FailureOr<FuncOp> fn =
+        lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
+    if (failed(fn))
+      return failure();
+
+    rewriter.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    Type i64Ty = rewriter.getI64Type();
+    Type intWTy = rewriter.getIntegerType(width);
+
+    // Reinterpret a float operand as its raw bits, widened to i64.
+    // arith.extui requires a strictly wider result type, so a 64-bit pattern
+    // is passed through unchanged.
+    auto toI64Bits = [&](Value operand) -> Value {
+      Value bits = arith::BitcastOp::create(rewriter, loc, intWTy, operand);
+      if (width == 64)
+        return bits;
+      return arith::ExtUIOp::create(rewriter, loc, i64Ty, bits);
+    };
+    Value lhsBits = toI64Bits(op.getLhs());
+    Value rhsBits = toI64Bits(op.getRhs());
+
+    // Tell the runtime which float format the bits encode.
+    int32_t sem =
+        llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+    Value semValue = arith::ConstantOp::create(
+        rewriter, loc, rewriter.getI32Type(),
+        rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+    auto callOp = func::CallOp::create(rewriter, loc, TypeRange(i64Ty),
+                                       SymbolRefAttr::get(*fn), params);
+
+    // Narrow the i64 result back to the float's width (arith.trunci needs a
+    // strictly narrower result type), then reinterpret it as the float type.
+    Value resultBits = callOp->getResult(0);
+    if (width < 64)
+      resultBits = arith::TruncIOp::create(rewriter, loc, intWTy, resultBits);
+    rewriter.replaceOp(
+        op, arith::BitcastOp::create(rewriter, loc, floatTy, resultBits));
+    return success();
+  }
+
+  /// Operation in which the runtime function declarations are looked up or
+  /// created. The pass in this file passes its root operation here.
+  SymbolOpInterface symTable;
+};
+
+namespace {
+struct ArithToAPFloatConversionPass final
+    : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    static const char add[] = "add";
+    static const char subtract[] = "subtract";
+    static const char multiply[] = "multiply";
+    static const char divide[] = "divide";
+    static const char remainder[] = "remainder";
+    patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, add>,
+                 BinaryArithOpToAPFloatConversion<arith::SubFOp, subtract>,
+                 BinaryArithOpToAPFloatConversion<arith::MulFOp, multiply>,
+                 BinaryArithOpToAPFloatConversion<arith::DivFOp, divide>,
+                 BinaryArithOpToAPFloatConversion<arith::RemFOp, remainder>>(
+        context, 1, getOperation());
+    LogicalResult result = success();
+    ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
+      if (diag.getSeverity() == DiagnosticSeverity::Error) {
+        result = failure();
+      }
+      // NB: if you don't return failure, no other diag handlers will fire (see
+      // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
+      return failure();
+    });
----------------
makslevental wrote:

My 2 cents: leaving it decomposed is better - this is not hard to write (once you know how diagnostic handlers work 😜).

But it's up to you - you're the code owner/designer of that. If you wish, I can refactor this into the rewriter. Maybe the "abstraction" could be just passing the callback.

https://github.com/llvm/llvm-project/pull/166618


More information about the Mlir-commits mailing list