[Mlir-commits] [mlir] [mlir][math] Add FP software implementation lowering pass: math-to-apfloat (PR #171221)

Maksim Levental llvmlistbot at llvm.org
Fri Dec 12 09:53:53 PST 2025


================
@@ -0,0 +1,198 @@
+//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+#include "Utils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+/// Lowers `math.absf` on a scalar float to a call into the APFloat runtime
+/// function `_mlir_apfloat_abs`.
+///
+/// The operand is bitcast to an integer of the same width and zero-extended
+/// to i64 (when narrower than 64 bits). It is then passed, together with the
+/// i32 semantics value from `getAPFloatSemanticsValue`, to the runtime call
+/// (built with an i64 result). The returned bits are truncated back to the
+/// original width (when narrower) and bitcast to the original float type.
+///
+/// Non-scalar operands (e.g. vectors) and floats wider than 64 bits are
+/// reported as match failures rather than converted.
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+  AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(math::AbsFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check preconditions before touching the IR: a pattern that returns
+    // failure must leave the IR unchanged, so the runtime declaration is only
+    // looked up / created once the op is known to be convertible.
+    Value operand = op.getOperand();
+    auto floatTy = dyn_cast<FloatType>(operand.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(
+          op, "only scalar float operands are supported");
+    unsigned width = floatTy.getWidth();
+    if (width > 64)
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+    if (failed(fn))
+      return fn;
+
+    // The declaration lookup may move the insertion point; create the
+    // replacement ops right before `op`.
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+
+    // Reinterpret the operand as an integer of the same width and widen it to
+    // the i64 runtime ABI. arith.extui requires a strictly wider result type,
+    // so 64-bit floats are passed through without an extension.
+    auto intWType = rewriter.getIntegerType(width);
+    Value operandBits =
+        arith::BitcastOp::create(rewriter, loc, intWType, operand);
+    if (width < 64)
+      operandBits =
+          arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+
+    // Call APFloat function.
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    Value absBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                         SymbolRefAttr::get(*fn), params)
+                        ->getResult(0);
+
+    // Narrow the result back to the original width (arith.trunci likewise
+    // requires a strictly narrower type, so skip it for 64-bit floats) and
+    // reinterpret it as the original float type.
+    Value resultBits = absBits;
+    if (width < 64)
+      resultBits = arith::TruncIOp::create(rewriter, loc, intWType, absBits);
+    rewriter.replaceOp(
+        op, arith::BitcastOp::create(rewriter, loc, floatTy, resultBits));
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+  IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
+                          SymbolOpInterface symTable,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+        APFloatName(APFloatName) {};
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Get APFloat function from runtime library.
+    auto i1 = IntegerType::get(symTable->getContext(), 1);
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    std::string funcName =
+        (llvm::Twine("_mlir_apfloat_is") + APFloatName).str();
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1);
+    if (failed(fn))
+      return fn;
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+    // Cast operands to 64-bit integers.
+    auto operand = op.getOperand();
+    auto floatTy = cast<FloatType>(operand.getType());
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    Value operandBits = arith::ExtUIOp::create(
+        rewriter, loc, i64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+    // Call APFloat function.
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i64Type),
----------------
makslevental wrote:

whoops! good catch

https://github.com/llvm/llvm-project/pull/171221


More information about the Mlir-commits mailing list