[Mlir-commits] [mlir] [mlir][math] Add FP software implementation lowering pass: math-to-apfloat (PR #171221)
Maksim Levental
llvmlistbot at llvm.org
Wed Dec 10 11:37:37 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/171221
>From 7533d5615d194ec40fbe74c9aa5fc8d50f519461 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 8 Dec 2025 14:48:11 -0800
Subject: [PATCH 1/2] [mlir][math] Add FP software implementation lowering
pass: math-to-apfloat
---
.../Conversion/MathToAPFloat/MathToAPFloat.h | 21 +++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 15 ++
mlir/include/mlir/Dialect/Func/Utils/Utils.h | 16 ++
.../ArithToAPFloat.cpp | 92 +++--------
.../ArithAndMathToAPFloat/CMakeLists.txt | 49 ++++++
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 148 ++++++++++++++++++
.../ArithAndMathToAPFloat/Utils.cpp | 22 +++
.../Conversion/ArithAndMathToAPFloat/Utils.h | 21 +++
.../Conversion/ArithToAPFloat/CMakeLists.txt | 19 ---
mlir/lib/Conversion/CMakeLists.txt | 2 +-
mlir/lib/Dialect/Func/Utils/Utils.cpp | 39 +++++
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 48 +++++-
mlir/lib/ExecutionEngine/CMakeLists.txt | 10 +-
.../Math/CPU/test-apfloat-emulation.mlir | 32 ++++
15 files changed, 442 insertions(+), 93 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
rename mlir/lib/Conversion/{ArithToAPFloat => ArithAndMathToAPFloat}/ArithToAPFloat.cpp (88%)
create mode 100644 mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
create mode 100644 mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
create mode 100644 mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
delete mode 100644 mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
create mode 100644 mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
diff --git a/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..86179a1611d5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+// Forward declaration required by the generated pass-creation declaration
+// below, so it must stay ahead of the .inc include.
+class Pass;
+
+// Pull in the tablegen-generated declarations for MathToAPFloatConversionPass
+// (`convert-math-to-apfloat`, defined in Passes.td).
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..05ec2f8ce2538 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -44,6 +44,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+    : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+  let summary = "Convert Math ops to APFloat runtime library calls";
+  let description = [{
+    This pass converts supported Math ops to APFloat-based runtime library
+    calls (APFloatWrappers.cpp). APFloat is a software implementation of
+    floating-point arithmetic operations.
+  }];
+  // The rewrite patterns materialize arith ops (bitcast, extui, trunci,
+  // constant) around the runtime calls, so the arith dialect must be loaded.
+  let dependentDialects = ["arith::ArithDialect", "math::MathDialect",
+                           "func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MathToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..079c1f461b6ed 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT,
SymbolTableCollection *symbolTables = nullptr);
+/// Create a FuncOp decl named `name` with type `funcT` at the start of the
+/// first block of `symTable`'s region, marking it private if `setPrivate` is
+/// set. If `symbolTables` is provided, the decl is also inserted into the
+/// cached SymbolTable for `symTable`.
+FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                    FunctionType funcT, bool setPrivate,
+                    SymbolTableCollection *symbolTables = nullptr);
+
+/// Look up the FuncOp decl `name` in `symTable`, creating a private one if it
+/// does not exist yet. The decl takes `paramTypes` and returns `resultType`
+/// (i64 when `resultType` is null). Returns failure if a decl named `name`
+/// already exists with a different function type.
+FailureOr<FuncOp>
+lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                     TypeRange paramTypes,
+                     SymbolTableCollection *symbolTables = nullptr,
+                     Type resultType = {});
+
} // namespace func
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
similarity index 88%
rename from mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
rename to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 79816fc6e3bf1..b9ba94ef08098 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -25,47 +26,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::func;
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, FunctionType funcT, bool setPrivate,
- SymbolTableCollection *symbolTables = nullptr) {
- OpBuilder::InsertionGuard g(b);
- assert(!symTable->getRegion(0).empty() && "expected non-empty region");
- b.setInsertionPointToStart(&symTable->getRegion(0).front());
- FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
- if (setPrivate)
- funcOp.setPrivate();
- if (symbolTables) {
- SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
- symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
- }
- return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Returns an int64_t, unless a
-/// different result type is specified.
-static FailureOr<FuncOp>
-lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, TypeRange paramTypes,
- SymbolTableCollection *symbolTables = nullptr,
- Type resultType = {}) {
- if (!resultType)
- resultType = IntegerType::get(symTable->getContext(), 64);
- std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
- auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
- FailureOr<FuncOp> func =
- lookupFnDecl(symTable, funcName, funcT, symbolTables);
- // Failed due to type mismatch.
- if (failed(func))
- return func;
- // Successfully matched existing decl.
- if (*func)
- return *func;
-
- return createFnDecl(b, symTable, funcName, funcT,
- /*setPrivate=*/true, symbolTables);
-}
-
/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
- symbolTables);
-}
-
-static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
- int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- return arith::ConstantOp::create(b, loc, b.getI32Type(),
- b.getIntegerAttr(b.getI32Type(), sem));
+ std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+ return lookupOrCreateFnDecl(b, symTable, funcName,
+ {i32Type, i64Type, i64Type}, symbolTables);
}
/// Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
- rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+ FailureOr<FuncOp> fn =
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
+ {i32Type, i32Type, i64Type});
if (failed(fn))
return fn;
@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
- {i32Type, i32Type, i1Type, i64Type});
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outIntTy = cast<IntegerType>(resultType);
Value outWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
@@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
- {i32Type, i32Type, i1Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
// Call APFloat function.
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
Value inWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "compare",
- {i32Type, i64Type, i64Type}, nullptr, i8Type);
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
+ {i32Type, i64Type, i64Type}, nullptr, i8Type);
if (failed(fn))
return fn;
@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
Value comparisonResult =
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
if (failed(fn))
return fn;
@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, operand1));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
Value negatedBits =
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..cc8e61a87addc
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
@@ -0,0 +1,49 @@
+# Helpers shared by the Arith and Math APFloat lowerings.
+add_mlir_library(MLIRArithAndMathToAPFloatUtils
+  Utils.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRIR
+)
+
+add_mlir_conversion_library(MLIRArithToAPFloat
+  ArithToAPFloat.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithAndMathToAPFloatUtils
+  MLIRArithDialect
+  MLIRArithTransforms
+  MLIRFuncDialect
+  MLIRFuncUtils
+  MLIRVectorDialect
+  )
+
+add_mlir_conversion_library(MLIRMathToAPFloat
+  MathToAPFloat.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # MathToAPFloat.cpp creates arith ops directly and includes VectorOps.h.
+  LINK_LIBS PUBLIC
+  MLIRArithAndMathToAPFloatUtils
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRFuncUtils
+  MLIRMathDialect
+  MLIRVectorDialect
+  )
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..e540747ac0abd
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,148 @@
+//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+#include "Utils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+namespace {
+/// Lowers `math.absf` on a scalar float of at most 64 bits to a call to the
+/// `_mlir_apfloat_abs` runtime function. The operand is bitcast to an integer
+/// of the float's width, zero-extended to i64 and passed together with the
+/// APFloat semantics enum; the i64 result is narrowed and bitcast back to the
+/// original float type. Other operand types (e.g. vectors) are not matched.
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+  AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(math::AbsFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Reject unsupported types before touching the symbol table so that a
+    // failed match does not leave an unused runtime declaration behind.
+    // `cast<FloatType>` would assert on vector operands, and the runtime ABI
+    // only carries 64 bits of payload.
+    Value operand = op.getOperand();
+    auto floatTy = dyn_cast<FloatType>(operand.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "expected scalar float operand");
+    if (floatTy.getWidth() > 64)
+      return rewriter.notifyMatchFailure(op, "float type wider than 64 bits");
+
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+    if (failed(fn))
+      return fn;
+
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+    // Reinterpret the float as an integer and widen it to i64. arith.extui
+    // requires a strictly wider result, so 64-bit floats skip the extension.
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    Value operandBits =
+        arith::BitcastOp::create(rewriter, loc, intWType, operand);
+    if (floatTy.getWidth() < 64)
+      operandBits =
+          arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+
+    // Call APFloat function.
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    Value absBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                         SymbolRefAttr::get(*fn), params)
+                        ->getResult(0);
+
+    // Narrow the result back to the original width (nothing to do for 64-bit
+    // floats, for the same reason the extension was skipped).
+    if (floatTy.getWidth() < 64)
+      absBits = arith::TruncIOp::create(rewriter, loc, intWType, absBits);
+    rewriter.replaceOp(
+        op, arith::BitcastOp::create(rewriter, loc, floatTy, absBits));
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+} // namespace
+
+namespace {
+/// Lowers a float classification op (`math.isfinite`, `math.isinf`,
+/// `math.isnan`, `math.isnormal`) on a scalar float of at most 64 bits to a
+/// call to the runtime function `_mlir_apfloat_is<apFloatName>`. That function
+/// takes the APFloat semantics enum plus the float bits zero-extended to i64,
+/// and returns i1.
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+  IsOpToAPFloatConversion(MLIRContext *context, const char *apFloatName,
+                          SymbolOpInterface symTable,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+        funcName((llvm::Twine("_mlir_apfloat_is") + apFloatName).str()) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Reject unsupported types before touching the symbol table so that a
+    // failed match does not leave an unused runtime declaration behind.
+    Value operand = op.getOperand();
+    auto floatTy = dyn_cast<FloatType>(operand.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "expected scalar float operand");
+    if (floatTy.getWidth() > 64)
+      return rewriter.notifyMatchFailure(op, "float type wider than 64 bits");
+
+    // Get APFloat function from runtime library.
+    auto i1Type = IntegerType::get(symTable->getContext(), 1);
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1Type);
+    if (failed(fn))
+      return fn;
+
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+    // Reinterpret the float as an integer and widen it to i64. arith.extui
+    // requires a strictly wider result, so 64-bit floats skip the extension.
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    Value operandBits =
+        arith::BitcastOp::create(rewriter, loc, intWType, operand);
+    if (floatTy.getWidth() < 64)
+      operandBits =
+          arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+
+    // Call APFloat function. The call must produce i1 to match both the
+    // declaration above and the result type of the op being replaced.
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1Type),
+                                              SymbolRefAttr::get(*fn), params);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+  /// Full runtime symbol name, computed once at pattern construction.
+  std::string funcName;
+};
+} // namespace
+
+namespace {
+/// Module pass that rewrites supported `math` ops into calls to the
+/// `_mlir_apfloat_*` runtime functions, declaring them in the module.
+struct MathToAPFloatConversionPass final
+    : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+
+void MathToAPFloatConversionPass::runOnOperation() {
+  auto module = getOperation();
+  MLIRContext *ctx = &getContext();
+
+  // The module itself is the symbol table that receives the runtime decls.
+  RewritePatternSet patterns(ctx);
+  patterns.add<AbsFOpToAPFloatConversion>(ctx, module);
+  patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(ctx, "finite",
+                                                          module);
+  patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(ctx, "infinite",
+                                                       module);
+  patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(ctx, "nan", module);
+  patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(ctx, "normal",
+                                                          module);
+
+  // Track whether any error diagnostic is emitted while rewriting; if so the
+  // pass fails. The handler deliberately returns failure() so the diagnostic
+  // still reaches the previously registered handlers (see
+  // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
+  bool sawError = false;
+  ScopedDiagnosticHandler errorTracker(ctx, [&sawError](Diagnostic &diag) {
+    if (diag.getSeverity() == DiagnosticSeverity::Error)
+      sawError = true;
+    return failure();
+  });
+
+  walkAndApplyPatterns(module, std::move(patterns));
+  if (sawError)
+    signalPassFailure();
+}
+} // namespace
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
new file mode 100644
index 0000000000000..2b5857367dc40
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
@@ -0,0 +1,22 @@
+//===- Utils.cpp - Utils for APFloat Conversion ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Value.h"
+
+mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
+                                           FloatType floatTy) {
+  // The runtime wrappers identify the float format by its
+  // llvm::APFloatBase::Semantics enum value, passed as an i32 constant.
+  IntegerType i32Ty = b.getI32Type();
+  const auto semEnum = static_cast<int32_t>(
+      llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()));
+  return arith::ConstantOp::create(b, loc, i32Ty,
+                                   b.getIntegerAttr(i32Ty, semEnum));
+}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
new file mode 100644
index 0000000000000..5f11d24261b43
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
@@ -0,0 +1,21 @@
+//===- Utils.h - Utils for APFloat Conversion - C++ -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
+#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
+
+namespace mlir {
+class Value;
+class OpBuilder;
+class Location;
+class FloatType;
+
+/// Create, via `b`, an i32 `arith.constant` at `loc` holding the
+/// llvm::APFloatBase::Semantics enum value of `floatTy`'s float semantics.
+/// The `_mlir_apfloat_*` runtime functions take this value to decode their
+/// integer-encoded float arguments.
+Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
deleted file mode 100644
index 31fce7a4de8a2..0000000000000
--- a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-add_mlir_conversion_library(MLIRArithToAPFloat
- ArithToAPFloat.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
-
- DEPENDS
- MLIRConversionPassIncGen
-
- LINK_COMPONENTS
- Core
-
- LINK_LIBS PUBLIC
- MLIRArithDialect
- MLIRArithTransforms
- MLIRFuncDialect
- MLIRFuncUtils
- MLIRVectorDialect
- )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 613dc6d242ceb..2ed10effb53da 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,7 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
-add_subdirectory(ArithToAPFloat)
+add_subdirectory(ArithAndMathToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index d6dfd0229963c..0a56817b704ff 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -279,3 +279,42 @@ func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
}
return func;
}
+
+func::FuncOp func::createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+                                StringRef name, FunctionType funcT,
+                                bool setPrivate,
+                                SymbolTableCollection *symbolTables) {
+  assert(!symTable->getRegion(0).empty() && "expected non-empty region");
+  Block &body = symTable->getRegion(0).front();
+
+  // Build the declaration at the top of the symbol table's body; the guard
+  // restores the caller's insertion point when this function returns.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(&body);
+  auto decl = func::FuncOp::create(b, symTable->getLoc(), name, funcT);
+  if (setPrivate)
+    decl.setPrivate();
+
+  // Keep the caller's cached symbol table (if any) in sync with the new decl.
+  if (symbolTables)
+    symbolTables->getSymbolTable(symTable).insert(decl, body.begin());
+  return decl;
+}
+
+/// Look up the decl `name` in `symTable`, creating a private one with
+/// signature `(paramTypes) -> resultType` if none exists. `resultType`
+/// defaults to i64. Fails if an existing decl has a different type.
+FailureOr<func::FuncOp>
+func::lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+                           StringRef name, TypeRange paramTypes,
+                           SymbolTableCollection *symbolTables,
+                           Type resultType) {
+  if (!resultType)
+    resultType = IntegerType::get(symTable->getContext(), 64);
+  auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
+  FailureOr<func::FuncOp> existing =
+      lookupFnDecl(symTable, name, funcT, symbolTables);
+  // A decl with this name exists but its type does not match.
+  if (failed(existing))
+    return failure();
+  // Reuse the matching decl.
+  if (*existing)
+    return *existing;
+  // No decl yet: create a private one.
+  return createFnDecl(b, symTable, name, funcT,
+                      /*setPrivate=*/true, symbolTables);
+}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index f3e38eb8ffa2d..0c076af20dea7 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -143,7 +143,8 @@ MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
return static_cast<int8_t>(x.compare(y));
}
-MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) {
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics,
+ uint64_t a) {
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
static_cast<llvm::APFloatBase::Semantics>(semantics));
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
@@ -152,6 +153,51 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint6
return x.bitcastToAPInt().getZExtValue();
}
+/// Rebuild an APFloat from `a`, the raw bit pattern of a float whose format
+/// is given by `semantics` (an llvm::APFloatBase::Semantics enum value). The
+/// MLIR lowering passes the float bits zero-extended to 64 bits.
+static llvm::APFloat decodeAPFloatBits(int32_t semantics, uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  return llvm::APFloat(sem, llvm::APInt(bitWidth, a));
+}
+
+/// Returns the zero-extended bits of llvm::abs() of the decoded value.
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_abs(int32_t semantics,
+                                                        uint64_t a) {
+  return llvm::abs(decodeAPFloatBits(semantics, a))
+      .bitcastToAPInt()
+      .getZExtValue();
+}
+
+/// Classification wrappers: each returns the matching APFloat predicate
+/// (isFinite / isInfinity / isNormal / isNaN) of the decoded value.
+MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isfinite(int32_t semantics,
+                                                         uint64_t a) {
+  return decodeAPFloatBits(semantics, a).isFinite();
+}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isinfinite(int32_t semantics,
+                                                           uint64_t a) {
+  return decodeAPFloatBits(semantics, a).isInfinity();
+}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnormal(int32_t semantics,
+                                                         uint64_t a) {
+  return decodeAPFloatBits(semantics, a).isNormal();
+}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnan(int32_t semantics,
+                                                      uint64_t a) {
+  return decodeAPFloatBits(semantics, a).isNaN();
+}
+
/// Min/max operations.
#define APFLOAT_MIN_MAX_OP(OP) \
MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index a6153523a5e97..3ce4079f16644 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -168,8 +168,8 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
- if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
- # TODO: This support library is only used on Linux builds until we figure
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Darwin")
+ # TODO: This support library is only used on Linux and Mac builds until we figure
# out how to hide LLVM symbols in a way that works for all platforms.
add_mlir_library(mlir_apfloat_wrappers
SHARED
@@ -185,7 +185,7 @@ if(LLVM_ENABLE_PIC)
)
target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS)
# Hide LLVM symbols to avoid ODR violations.
- target_link_options(mlir_apfloat_wrappers PRIVATE "-Wl,--exclude-libs,ALL")
+ target_link_options(mlir_apfloat_wrappers PRIVATE $<$<PLATFORM_ID:Linux>:LINKER:--exclude-libs,ALL>)
endif()
add_subdirectory(SparseTensor)
@@ -205,8 +205,8 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_c_runner_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_c_runner_utils PRIVATE mlir_c_runner_utils_EXPORTS)
- # Conditionally link apfloat wrappers only on Linux.
- if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ # Conditionally link apfloat wrappers only on Linux and Mac.
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Darwin")
target_link_libraries(mlir_c_runner_utils PUBLIC mlir_apfloat_wrappers)
endif()
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
new file mode 100644
index 0000000000000..aca8a432a53b9
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
@@ -0,0 +1,32 @@
+// REQUIRES: system-linux || system-darwin
+// TODO: Run only on Linux and macOS until we figure out how to build
+// mlir_apfloat_wrappers in a platform-independent way.
+
+// Case 1: All math ops are lowered through APFloat.
+// RUN: mlir-opt %s --convert-math-to-apfloat --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Case 2: Only math on types without native LLVM support (f8E4M3FN) is
+// lowered through APFloat; math on f32 is lowered directly to LLVM.
+// RUN: mlir-opt %s --convert-to-llvm --convert-math-to-apfloat \
+// RUN: --convert-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+func.func @entry() {
+ %neg14fp8 = arith.constant -1.4 : f8E4M3FN
+ %neg14fp32 = arith.constant -1.4 : f32
+
+ // CHECK: 1.375
+ %c2 = math.absf %neg14fp8 : f8E4M3FN
+ vector.print %c2 : f8E4M3FN
+
+ // CHECK: 1.4
+ %c3 = math.absf %neg14fp32 : f32
+ vector.print %c3 : f32
+
+ return
+}
>From 8e34b3a281bd82a105df155f2ad3c6a9f3208930 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 11:37:24 -0800
Subject: [PATCH 2/2] not working (print is wrong?)
---
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 51 +++++++++++++++++++
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 23 +++++++++
.../Math/CPU/test-apfloat-emulation.mlir | 9 ++++
mlir/test/lit.cfg.py | 4 +-
4 files changed, 85 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index e540747ac0abd..4c8764ba1d6b0 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -110,6 +110,56 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
const char *APFloatName;
};
+/// Lowers scalar `math.fma` to `_mlir_apfloat_fused_multiply_add`, passing
+/// operand bit patterns zero-extended to i64 plus the semantics enum.
+struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
+  FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                           PatternBenefit benefit = 1)
+      : OpRewritePattern<math::FmaOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(math::FmaOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only scalar floats of <= 64 bits can be passed as i64 bit patterns.
+    auto floatTy = dyn_cast<FloatType>(op.getResult().getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "expected a scalar float type");
+    if (floatTy.getWidth() > 64)
+      return rewriter.notifyMatchFailure(op, "float wider than 64 bits");
+    Type i32Type = rewriter.getI32Type();
+    Type i64Type = rewriter.getI64Type();
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_fused_multiply_add",
+        {i32Type, i64Type, i64Type, i64Type});
+    if (failed(fn))
+      return fn;
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+
+    // Reinterpret each operand as an integer of the same width, then
+    // zero-extend it to i64 to match the wrapper's uint64_t parameters.
+    Type intWType = rewriter.getIntegerType(floatTy.getWidth());
+    auto toI64Bits = [&](Value v) -> Value {
+      Value bits = arith::BitcastOp::create(rewriter, loc, intWType, v);
+      return arith::ExtUIOp::create(rewriter, loc, i64Type, bits);
+    };
+    Value operand = toI64Bits(op.getA());
+    Value multiplicand = toI64Bits(op.getB());
+    Value addend = toI64Bits(op.getC());
+
+    // Call the wrapper, then truncate the i64 result and bitcast it back.
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operand, multiplicand, addend};
+    auto resultOp = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                         SymbolRefAttr::get(*fn), params);
+    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+                                                  resultOp->getResult(0));
+    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
namespace {
struct MathToAPFloatConversionPass final
: impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
@@ -131,6 +181,7 @@ void MathToAPFloatConversionPass::runOnOperation() {
getOperation());
patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context, "normal",
getOperation());
+ patterns.add<FmaOpToAPFloatConversion>(context, getOperation());
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 0c076af20dea7..254590a0d8566 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -21,6 +21,7 @@
//
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APSInt.h"
+#include "llvm/Support/Debug.h"
#ifdef _WIN32
#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT
@@ -198,6 +199,28 @@ MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnan(int32_t semantics,
return x.isNaN();
}
+/// Computes `operand * multiplicand + addend` in the given semantics with a
+/// single rounding (NearestTiesToEven) and returns the result's bit pattern.
+/// Must return uint64_t: the lowering declares this function with an i64
+/// result and truncates it back to the float's width.
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t
+_mlir_apfloat_fused_multiply_add(int32_t semantics, uint64_t operand,
+                                 uint64_t multiplicand, uint64_t addend) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat result(sem, llvm::APInt(bitWidth, operand));
+  llvm::APFloat multiplicandVal(sem, llvm::APInt(bitWidth, multiplicand));
+  llvm::APFloat addendVal(sem, llvm::APInt(bitWidth, addend));
+  // The status is intentionally discarded. opInexact is the normal outcome
+  // whenever the exact value must be rounded (common for 8-bit formats), and
+  // overflow/underflow/invalid are IEEE exceptions, not wrapper bugs; in all
+  // cases `result` holds the IEEE-defined value, so asserting opOK is wrong.
+  (void)result.fusedMultiplyAdd(multiplicandVal, addendVal,
+                                llvm::RoundingMode::NearestTiesToEven);
+  return result.bitcastToAPInt().getZExtValue();
+}
+
/// Min/max operations.
#define APFLOAT_MIN_MAX_OP(OP) \
MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
index aca8a432a53b9..892b970a2796d 100644
--- a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
@@ -28,5 +28,14 @@ func.func @entry() {
%c3 = math.absf %neg14fp32 : f32
vector.print %c3 : f32
+ // See TEST(APFloatTest, Float8E8M0FNUFMA) in llvm/unittests/ADT/APFloatTest.cpp.
+ %twof8E8M0FNU = arith.constant 2.0 : f8E8M0FNU
+ %fourf8E8M0FNU = arith.constant 4.0 : f8E8M0FNU
+ %eightf8E8M0FNU = arith.constant 8.0 : f8E8M0FNU
+
+ // CHECK: 16
+ %c4 = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU
+ vector.print %c4 : f8E8M0FNU
+
return
}
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 675ded35d98f3..9c5bee169efe0 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -214,8 +214,8 @@ def find_real_python_interpreter():
"not",
]
-if "Linux" in config.host_os:
- # TODO: Run only on Linux until we figure out how to build
+if "Linux" in config.host_os or "Darwin" in config.host_os:
+ # TODO: Only register this runtime on Linux and macOS (Darwin) until we figure out how to build
# mlir_apfloat_wrappers in a platform-independent way.
tools.extend([add_runtime("mlir_apfloat_wrappers")])
More information about the Mlir-commits
mailing list