[llvm-branch-commits] [mlir] [mlir][arith] Add support for `fptosi`, `fptoui` to `ArithToAPFloat` (PR #169277)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Nov 23 20:22:05 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add support for `arith.fptosi` and `arith.fptoui`.
Depends on #169275.
---
Full diff: https://github.com/llvm/llvm-project/pull/169277.diff
4 Files Affected:
- (modified) mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp (+58)
- (modified) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+14)
- (modified) mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir (+26)
- (modified) mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir (+10)
``````````diff
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 90e6e674da519..1fe698f1c8902 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -185,6 +185,60 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
SymbolOpInterface symTable;
};
+template <typename OpTy>
+struct FpToIntConversion final : OpRewritePattern<OpTy> {
+ FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
+ bool isUnsigned, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ isUnsigned(isUnsigned){};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
+ {i32Type, i32Type, i1Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ Location loc = op.getLoc();
+ auto inFloatTy = cast<FloatType>(op.getOperand().getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outIntTy = cast<IntegerType>(op.getType());
+ Value outWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {inSemValue, outWidthValue, isUnsignedValue,
+ operandBits};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntTy,
+ resultOp->getResult(0));
+ rewriter.replaceOp(op, truncatedBits);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ bool isUnsigned;
+};
+
namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -208,6 +262,10 @@ void ArithToAPFloatConversionPass::runOnOperation() {
context, "remainder", getOperation());
patterns.add<FpToFpConversion<arith::ExtFOp>>(context, getOperation());
patterns.add<FpToFpConversion<arith::TruncFOp>>(context, getOperation());
+ patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
+ /*isUnsigned=*/false);
+ patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
+ /*isUnsigned=*/true);
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 511b05ea380f0..632fe9cf2269d 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -20,6 +20,7 @@
// APFloatBase::Semantics enum value.
//
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APSInt.h"
#ifdef _WIN32
#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT
@@ -101,4 +102,17 @@ _mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) {
llvm::APInt result = val.bitcastToAPInt();
return result.getZExtValue();
}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int(
+ int32_t semantics, int32_t resultWidth, bool isUnsigned, uint64_t a) {
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ unsigned inputWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+ llvm::APFloat val(sem, llvm::APInt(inputWidth, a));
+ llvm::APSInt result(resultWidth, isUnsigned);
+ bool isExact;
+ // TODO: Custom rounding modes are not supported yet.
+ val.convertToInteger(result, llvm::RoundingMode::NearestTiesToEven, &isExact);
+ return result.getZExtValue();
+}
}
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index 038acbfc965a2..f1acfd5e5618a 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -148,3 +148,29 @@ func.func @truncf(%arg0: bf16) {
%0 = arith.truncf %arg0 : bf16 to f4E2M1FN
return
}
+
+// -----
+
+// Signed f16 -> i4 conversion: expected to lower to a call of the runtime
+// wrapper with the "is unsigned" flag set to false, followed by a truncation
+// of the i64 result to i4.
+// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32
+// CHECK: %[[out_width:.*]] = arith.constant 4 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant false
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+// CHECK: arith.trunci %[[res]] : i64 to i4
+func.func @fptosi(%arg0: f16) {
+ %0 = arith.fptosi %arg0 : f16 to i4
+ return
+}
+
+// -----
+
+// Unsigned f16 -> i4 conversion: same lowering as the signed case, except
+// that the "is unsigned" flag passed to the runtime wrapper is true.
+// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32
+// CHECK: %[[out_width:.*]] = arith.constant 4 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant true
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+// CHECK: arith.trunci %[[res]] : i64 to i4
+func.func @fptoui(%arg0: f16) {
+ %0 = arith.fptoui %arg0 : f16 to i4
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 51976434d2be2..5e93945c3eb60 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -43,5 +43,15 @@ func.func @entry() {
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
vector.print %cvt : f8E4M3FN
+ // CHECK-NEXT: 1
+ // Bit pattern: 01, interpreted as signed integer: 1
+ %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2
+ vector.print %cvt_int_signed : i2
+
+ // CHECK-NEXT: -2
+ // Bit pattern: 10, interpreted as signed integer: -2
+ %cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2
+ vector.print %cvt_int_unsigned : i2
+
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/169277
More information about the llvm-branch-commits mailing list