[Mlir-commits] [mlir] [mlir][math] Add FP software implementation lowering pass: math-to-apfloat (PR #171221)
Maksim Levental
llvmlistbot at llvm.org
Fri Dec 12 09:53:53 PST 2025
================
@@ -0,0 +1,198 @@
+//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+#include "Utils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+/// Lowers scalar `math.absf` to a call into the APFloat runtime function
+/// `_mlir_apfloat_abs(i32 semantics, i64 bits) -> i64`.
+///
+/// The operand is bitcast to an integer of the same width, zero-extended to
+/// i64 when narrower, and passed to the runtime together with the APFloat
+/// semantics value. The returned bits are truncated back to the original
+/// width (when narrower than 64) and bitcast to the original float type.
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+  AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(math::AbsFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check preconditions before touching the IR: a pattern that reports
+    // failure must leave the IR unmodified. Vector operands are not handled
+    // by this pattern, and the runtime ABI transports at most 64 bits.
+    auto floatTy = dyn_cast<FloatType>(op.getOperand().getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "expected scalar float operand");
+    unsigned bitWidth = floatTy.getWidth();
+    if (bitWidth > 64)
+      return rewriter.notifyMatchFailure(
+          op, "bitwidth > 64 bits is not supported for APFloat");
+
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+    if (failed(fn))
+      return failure();
+
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+
+    // Reinterpret the operand as raw bits and widen them to the i64 runtime
+    // ABI. arith.extui requires a strictly wider result type, so the
+    // extension is skipped for 64-bit floats.
+    auto intWType = rewriter.getIntegerType(bitWidth);
+    Value operandBits =
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getOperand());
+    if (bitWidth < 64)
+      operandBits = arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+
+    // Call APFloat function.
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    Value absBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                         SymbolRefAttr::get(*fn), params)
+                        ->getResult(0);
+
+    // Narrow the result back to the original width (arith.trunci likewise
+    // requires a strictly narrower result type) and reinterpret it as float.
+    if (bitWidth < 64)
+      absBits = arith::TruncIOp::create(rewriter, loc, intWType, absBits);
+    rewriter.replaceOp(
+        op, arith::BitcastOp::create(rewriter, loc, floatTy, absBits));
+    return success();
+  }
+
+  /// Symbol-table op in which the runtime function declaration is looked up
+  /// or created.
+  SymbolOpInterface symTable;
+};
+
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+ IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
+ SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ APFloatName(APFloatName) {};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i1 = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ std::string funcName =
+ (llvm::Twine("_mlir_apfloat_is") + APFloatName).str();
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1);
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i64Type),
----------------
makslevental wrote:
whoops! good catch
https://github.com/llvm/llvm-project/pull/171221
More information about the Mlir-commits
mailing list