[Mlir-commits] [mlir] [mlir][math] Add FP software implementation lowering pass: math-to-apfloat (PR #171221)
Maksim Levental
llvmlistbot at llvm.org
Mon Dec 8 14:49:52 PST 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/171221
…float
>From 49f813aa114208b1f7232e8f28b0a9098e2d9477 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 8 Dec 2025 14:48:11 -0800
Subject: [PATCH] [mlir][math] Add FP software implementation lowering pass:
math-to-apfloat
---
.../Conversion/MathToAPFloat/MathToAPFloat.h | 21 ++
mlir/include/mlir/Conversion/Passes.td | 15 ++
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../Conversion/MathToAPFloat/CMakeLists.txt | 17 ++
.../MathToAPFloat/MathToAPFloat.cpp | 185 ++++++++++++++++++
5 files changed, 239 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
create mode 100644 mlir/lib/Conversion/MathToAPFloat/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MathToAPFloat/MathToAPFloat.cpp
diff --git a/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..86179a1611d5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+// Forward declaration used by the tablegen-generated pass declarations below
+// (presumably the factory returns std::unique_ptr<Pass>; hence <memory>).
+class Pass;
+
+// Pull in the tablegen-generated declarations for
+// MathToAPFloatConversionPass, defined in mlir/Conversion/Passes.td. The
+// macro must be defined immediately before including Passes.h.inc.
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+    : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+  let summary = "Convert Math ops to APFloat runtime library calls";
+  let description = [{
+    This pass converts supported Math ops to APFloat-based runtime library
+    calls (APFloatWrappers.cpp). APFloat is a software implementation of
+    floating-point arithmetic operations.
+  }];
+  // The conversion helpers create arith.constant ops, func.func declarations
+  // and vector.to_elements / vector.from_elements ops, so all of those
+  // dialects must be loaded before the pass runs.
+  let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
+                           "math::MathDialect", "vector::VectorDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MathToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 613dc6d242ceb..3c59fbda6810a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -35,6 +35,7 @@ add_subdirectory(IndexToLLVM)
add_subdirectory(IndexToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
+add_subdirectory(MathToAPFloat)
add_subdirectory(MathToEmitC)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
diff --git a/mlir/lib/Conversion/MathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/MathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..454b71b1ef160
--- /dev/null
+++ b/mlir/lib/Conversion/MathToAPFloat/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRMathToAPFloat
+  MathToAPFloat.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # MathToAPFloat.cpp directly uses arith.constant, vector.to_elements /
+  # vector.from_elements and walkAndApplyPatterns (MLIRTransformUtils).
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRFuncUtils
+  MLIRMathDialect
+  MLIRTransformUtils
+  MLIRVectorDialect
+  )
diff --git a/mlir/lib/Conversion/MathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/MathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..954096907a21b
--- /dev/null
+++ b/mlir/lib/Conversion/MathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,185 @@
+//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+// Pull in the tablegen-generated definition of
+// impl::MathToAPFloatConversionPassBase, which the pass below derives from.
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+/// Creates a `func.func` declaration named `name` with type `funcT` at the
+/// start of the first block of `symTable`'s first region. The declaration is
+/// made private when `setPrivate` is true. When `symbolTables` is non-null,
+/// the new symbol is also registered with the cached symbol table for
+/// `symTable`.
+static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+                           StringRef name, FunctionType funcT, bool setPrivate,
+                           SymbolTableCollection *symbolTables = nullptr) {
+  Region &body = symTable->getRegion(0);
+  assert(!body.empty() && "expected non-empty region");
+  Block &entry = body.front();
+
+  // Restore the caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(&entry);
+  FuncOp decl = FuncOp::create(b, symTable->getLoc(), name, funcT);
+  if (setPrivate)
+    decl.setPrivate();
+
+  // The op is already in the block; this keeps the caller's symbol-table
+  // cache in sync with the newly created symbol.
+  if (symbolTables)
+    symbolTables->getSymbolTable(symTable).insert(decl, entry.begin());
+  return decl;
+}
+
+/// Returns the declaration of the runtime library function
+/// `_mlir_apfloat_<name>`. It takes `paramTypes` and returns a single value
+/// of type `resultType`, or i64 when `resultType` is null. If no symbol with
+/// that name exists yet, a private declaration is created at the start of
+/// `symTable`. Returns failure if the symbol exists with a different type.
+static FailureOr<FuncOp>
+lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
+                        StringRef name, TypeRange paramTypes,
+                        SymbolTableCollection *symbolTables = nullptr,
+                        Type resultType = {}) {
+  Type retTy = resultType;
+  if (!retTy)
+    retTy = IntegerType::get(symTable->getContext(), 64);
+
+  std::string symName = llvm::Twine("_mlir_apfloat_").concat(name).str();
+  auto fnTy = FunctionType::get(b.getContext(), paramTypes, {retTy});
+
+  // A failure means a symbol with this name exists but its type does not
+  // match. A null FuncOp means no declaration exists yet.
+  FailureOr<FuncOp> existing =
+      lookupFnDecl(symTable, symName, fnTy, symbolTables);
+  if (failed(existing) || *existing)
+    return existing;
+
+  return createFnDecl(b, symTable, symName, fnTy, /*setPrivate=*/true,
+                      symbolTables);
+}
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function implementing a binary floating-point arithmetic operation. The
+/// function is named `_mlir_apfloat_<name>` and returns an i64.
+///
+/// Parameter 1: APFloat semantics (i32)
+/// Parameter 2: Left-hand side operand (i64)
+/// Parameter 3: Right-hand side operand (i64)
+///
+/// This function will return a failure if the function is found but has an
+/// unexpected signature.
+static FailureOr<FuncOp>
+lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                       SymbolTableCollection *symbolTables = nullptr) {
+  auto i32Type = IntegerType::get(symTable->getContext(), 32);
+  auto i64Type = IntegerType::get(symTable->getContext(), 64);
+  return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
+                                 symbolTables);
+}
+
+/// Materializes an i32 `arith.constant` holding the
+/// llvm::APFloatBase::SemanticsToEnum value for the semantics of `floatTy`.
+static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
+  const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
+  int32_t semEnum = llvm::APFloatBase::SemanticsToEnum(semantics);
+  IntegerType i32Ty = b.getI32Type();
+  return arith::ConstantOp::create(b, loc, i32Ty,
+                                   b.getIntegerAttr(i32Ty, semEnum));
+}
+
+/// Applies `fn` element-wise.
+///
+/// If `operand1` is not a vector, `fn(operand1, operand2, resultType)` is
+/// returned directly. Otherwise:
+/// - both operands are unpacked with `vector.to_elements`;
+/// - `fn` is called once per pair of scalars, with the element type of
+///   `resultType` (which must then be a vector type of the same shape);
+/// - the scalar results are repacked with `vector.from_elements`.
+///
+/// `operand2` is optional. When it is null, `fn` receives a null Value as its
+/// second argument. When present, it must have the same vector type as
+/// `operand1`, or also be a scalar. `fn` must be callable as
+/// `Value(Value, Value, Type)`.
+///
+/// Scalable vectors are not supported: they cannot be unrolled into a
+/// statically known number of scalars.
+template <typename Fn>
+static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
+                                Value operand1, Value operand2, Type resultType,
+                                Fn fn) {
+  auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+  // Sanity check: Operand types must match.
+  assert((!operand2 || vecTy1 == dyn_cast<VectorType>(operand2.getType())) &&
+         "expected same vector types");
+  if (!vecTy1) {
+    // Not a vector. Call the function directly.
+    return fn(operand1, operand2, resultType);
+  }
+  assert(!vecTy1.isScalable() && "scalable vectors cannot be unrolled");
+
+  // Prepare scalar operands.
+  ResultRange scalars1 =
+      vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+  SmallVector<Value> scalars2;
+  if (!operand2) {
+    // No second operand. Pair every element with a null Value.
+    scalars2.assign(vecTy1.getNumElements(), Value());
+  } else {
+    llvm::append_range(
+        scalars2,
+        vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+  }
+
+  // Call the function for each pair of scalar operands.
+  Type resultElemTy = cast<VectorType>(resultType).getElementType();
+  SmallVector<Value> results;
+  results.reserve(vecTy1.getNumElements());
+  for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2))
+    results.push_back(fn(scalar1, scalar2, resultElemTy));
+
+  // Package the results into a vector with the same shape as the operands.
+  return vector::FromElementsOp::create(
+      rewriter, loc,
+      vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+      results);
+}
+
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or fixed-length
+///    vectors thereof). Scalable vectors are rejected because vector values
+///    are unrolled element by element (see forEachScalarValue), which needs a
+///    statically known element count.
+/// 2. The bitwidth of the operands / results must be <= 64.
+static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
+  for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+    Type type = value.getType();
+    if (auto vecTy = dyn_cast<VectorType>(type)) {
+      if (vecTy.isScalable())
+        return rewriter.notifyMatchFailure(
+            op, "scalable vectors are not supported");
+      type = vecTy.getElementType();
+    }
+    if (!type.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "only integers and floats (or vectors thereof) are supported");
+    if (type.getIntOrFloatBitWidth() > 64)
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+  }
+  return success();
+}
+
+namespace {
+/// Pass intended to lower supported Math dialect ops to calls into the
+/// APFloat runtime library (see the pass description in Passes.td).
+struct MathToAPFloatConversionPass final
+    : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+
+void MathToAPFloatConversionPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+
+  // Remember whether any error diagnostic is emitted while the patterns run,
+  // so that the pass can be marked as failed afterwards.
+  bool sawError = false;
+  ScopedDiagnosticHandler errorTracker(ctx, [&sawError](Diagnostic &diag) {
+    if (diag.getSeverity() == DiagnosticSeverity::Error)
+      sawError = true;
+    // Report the diagnostic as unhandled. Only then does the engine forward
+    // it to the remaining handlers (see
+    // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
+    return failure();
+  });
+
+  // NOTE(review): no patterns are added to this set yet, so the walk below
+  // does not rewrite anything.
+  RewritePatternSet patterns(ctx);
+  walkAndApplyPatterns(getOperation(), std::move(patterns));
+  if (sawError)
+    signalPassFailure();
+}
+} // namespace
More information about the Mlir-commits
mailing list