[Mlir-commits] [mlir] 05b1989 - [mlir][arith] Add support for `negf` to `ArithToAPFloat` (#169759)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 1 00:28:27 PST 2025
Author: Matthias Springer
Date: 2025-12-01T08:28:23Z
New Revision: 05b19895510af314a78ed42c6a969c4478a8f496
URL: https://github.com/llvm/llvm-project/commit/05b19895510af314a78ed42c6a969c4478a8f496
DIFF: https://github.com/llvm/llvm-project/commit/05b19895510af314a78ed42c6a969c4478a8f496.diff
LOG: [mlir][arith] Add support for `negf` to `ArithToAPFloat` (#169759)
Add support for lowering `arith.negf` to a call to the `_mlir_apfloat_neg` APFloat runtime wrapper.
Added:
Modified:
mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
mlir/lib/ExecutionEngine/APFloatWrappers.cpp
mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 566632bd8707f..024a97b03c14e 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -449,6 +449,49 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
SymbolOpInterface symTable;
};
+/// Lowers a scalar `arith.negf` to a call to the `_mlir_apfloat_neg` runtime
+/// function. The operand's raw bit pattern is widened to i64 and passed along
+/// with the APFloat semantics id; the returned i64 bit pattern is narrowed back
+/// to the original width and bitcast to the original float type.
+struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
+  NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(arith::NegFOp op,
+                                PatternRewriter &rewriter) const override {
+    // The operand is not guaranteed to be a scalar float (arith ops are also
+    // elementwise on vectors/tensors), so a plain `cast<FloatType>` could
+    // assert. Validate the type before touching the symbol table so that a
+    // failed match does not leave a stray runtime declaration behind.
+    auto floatTy = dyn_cast<FloatType>(op.getOperand().getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "expected scalar float operand");
+    // Bit patterns are transported through the runtime ABI as an i64.
+    unsigned width = floatTy.getWidth();
+    if (width > 64)
+      return rewriter.notifyMatchFailure(
+          op, "float types wider than 64 bits are not supported");
+
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn =
+        lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+    if (failed(fn))
+      return fn;
+
+    // Cast operand to a 64-bit integer. `arith.extui` / `arith.trunci` require
+    // a strictly wider / narrower result type, so both are skipped for 64-bit
+    // floats (emitting them would produce invalid IR).
+    rewriter.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto intWType = rewriter.getIntegerType(width);
+    Value operandBits =
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getOperand());
+    if (width < 64)
+      operandBits =
+          arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+
+    // Call APFloat function.
+    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    Value negatedBits =
+        func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                             SymbolRefAttr::get(*fn), params)
+            ->getResult(0);
+
+    // Truncate result to the original width and reinterpret it as a float.
+    if (width < 64)
+      negatedBits =
+          arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+    Value result =
+        arith::BitcastOp::create(rewriter, loc, floatTy, negatedBits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -471,7 +514,8 @@ void ArithToAPFloatConversionPass::runOnOperation() {
patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
context, "remainder", getOperation());
patterns
- .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>>(
+ .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
+ CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
context, getOperation());
patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
/*isUnsigned=*/false);
@@ -481,7 +525,6 @@ void ArithToAPFloatConversionPass::runOnOperation() {
/*isUnsigned=*/false);
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
/*isUnsigned=*/true);
- patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 77f7137264888..f2d5254be6b57 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -142,4 +142,13 @@ MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
return static_cast<int8_t>(x.compare(y));
}
+
+/// Negates a floating-point value given as a raw bit pattern.
+/// `semantics` is an `llvm::APFloatBase::Semantics` enumerator value and `a`
+/// holds the value's bits in its low `semanticsSizeInBits` bits (the compiler
+/// lowering zero-extends them to 64 bits). Returns the negated bit pattern,
+/// zero-extended to 64 bits.
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics,
+                                                        uint64_t a) {
+  auto semanticsKind = static_cast<llvm::APFloatBase::Semantics>(semantics);
+  const llvm::fltSemantics &sem =
+      llvm::APFloatBase::EnumToSemantics(semanticsKind);
+  llvm::APInt operandBits(llvm::APFloatBase::semanticsSizeInBits(sem), a);
+  llvm::APFloat value(sem, operandBits);
+  // Flip the sign in place, then hand the raw bits back to the caller.
+  value.changeSign();
+  return value.bitcastToAPInt().getZExtValue();
+}
}
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index 78ce3640ecc67..775cb5ea60f22 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -213,3 +213,13 @@ func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
%0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
return
}
+
+// -----
+
+// Scalar `arith.negf` on f32 lowers to a call to the runtime negation helper:
+// the first argument is the APFloat semantics id (2 for f32) and the second is
+// the operand's bit pattern widened to i64.
+// CHECK: func.func private @_mlir_apfloat_neg(i32, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_neg(%[[sem]], %{{.*}}) : (i32, i64) -> i64
+func.func @negf(%arg0: f32) {
+ %0 = arith.negf %arg0 : f32
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 433d058d025cf..555cc9a531966 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -43,6 +43,10 @@ func.func @entry() {
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
vector.print %cvt : f8E4M3FN
+ // CHECK-NEXT: -2.25
+ %negated = arith.negf %cvt : f8E4M3FN
+ vector.print %negated : f8E4M3FN
+
// CHECK-NEXT: 1
%cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
vector.print %cmp1 : i1
More information about the Mlir-commits mailing list