[Mlir-commits] [mlir] [mlir][arith] Add support for `extf`, `truncf` to `ArithToAPFloat` (PR #169275)

Matthias Springer llvmlistbot at llvm.org
Sun Nov 23 19:55:56 PST 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/169275

None

>From 23f7e745b19bf7df094f5a20b9537c971a4aaa49 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 24 Nov 2025 03:54:42 +0000
Subject: [PATCH] [mlir][arith] Add support for `extf`, `truncf` to
 `ArithToAPFloat`

---
 .../ArithToAPFloat/ArithToAPFloat.cpp         | 99 +++++++++++++++----
 mlir/lib/ExecutionEngine/APFloatWrappers.cpp  | 17 +++-
 .../ArithToApfloat/arith-to-apfloat.mlir      | 22 +++++
 .../Arith/CPU/test-apfloat-emulation.mlir     | 15 ++-
 4 files changed, 129 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 699edb188a70a..90e6e674da519 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -41,24 +41,15 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
 }
 
 /// Helper function to look up or create the symbol for a runtime library
-/// function for a binary arithmetic operation.
-///
-/// Parameter 1: APFloat semantics
-/// Parameter 2: Left-hand side operand
-/// Parameter 3: Right-hand side operand
-///
-/// This function will return a failure if the function is found but has an
-/// unexpected signature.
-///
+/// function taking `paramTypes` and returning i64; fails on signature mismatch.
 static FailureOr<FuncOp>
-lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
-                       SymbolTableCollection *symbolTables = nullptr) {
-  auto i32Type = IntegerType::get(symTable->getContext(), 32);
+lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
+                        StringRef name, TypeRange paramTypes,
+                        SymbolTableCollection *symbolTables = nullptr) {
   auto i64Type = IntegerType::get(symTable->getContext(), 64);
 
   std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
-  FunctionType funcT =
-      FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
+  FunctionType funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type});
   FailureOr<FuncOp> func =
       lookupFnDecl(symTable, funcName, funcT, symbolTables);
   // Failed due to type mismatch.
@@ -72,6 +63,31 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
                       /*setPrivate=*/true, symbolTables);
 }
 
+/// Helper function to look up or create the symbol for a runtime library
+/// function for a binary arithmetic operation.
+///
+/// Parameter 1: APFloat semantics (i32)
+/// Parameter 2: Left-hand side operand (i64)
+/// Parameter 3: Right-hand side operand (i64)
+///
+/// This function will return a failure if the function is found but has an
+/// unexpected signature.
+static FailureOr<FuncOp>
+lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                       SymbolTableCollection *symbolTables = nullptr) {
+  auto i32Type = IntegerType::get(symTable->getContext(), 32);
+  auto i64Type = IntegerType::get(symTable->getContext(), 64);
+  return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
+                                 symbolTables);
+}
+
+/// Create an i32 `arith.constant` holding the APFloat semantics of `floatTy`.
+static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
+  int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+  return arith::ConstantOp::create(b, loc, b.getI32Type(),
+                                   b.getIntegerAttr(b.getI32Type(), sem));
+}
+
 /// Rewrite a binary arithmetic operation to an APFloat function call.
 template <typename OpTy>
 struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
@@ -104,11 +120,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
         arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
 
     // Call APFloat function.
-    int32_t sem =
-        llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
-    Value semValue = arith::ConstantOp::create(
-        rewriter, loc, rewriter.getI32Type(),
-        rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
     SmallVector<Value> params = {semValue, lhsBits, rhsBits};
     auto resultOp =
         func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
@@ -126,6 +138,82 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
   const char *APFloatName;
 };
 
+/// Rewrite an FP-to-FP conversion (`arith.extf` / `arith.truncf`) to a call
+/// to the `_mlir_apfloat_convert` runtime function.
+///
+/// The operand and result bits are passed to and returned from the runtime
+/// function as i64 values, so only scalar float types with a bit width of at
+/// most 64 are supported. `arith.truncf` ops with an explicit rounding mode
+/// are rejected because the runtime function always rounds to nearest, ties
+/// to even.
+template <typename OpTy>
+struct FpToFpConversion final : OpRewritePattern<OpTy> {
+  FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
+                   PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Check all preconditions before modifying the IR: a pattern that returns
+    // failure must leave the IR untouched (this includes the function decl).
+    auto inFloatTy = dyn_cast<FloatType>(op.getOperand().getType());
+    auto outFloatTy = dyn_cast<FloatType>(op.getType());
+    if (!inFloatTy || !outFloatTy)
+      return rewriter.notifyMatchFailure(
+          op, "only scalar float types are supported");
+    if (inFloatTy.getWidth() > 64 || outFloatTy.getWidth() > 64)
+      return rewriter.notifyMatchFailure(
+          op, "float types wider than 64 bits are not supported");
+    if (auto truncOp = dyn_cast<arith::TruncFOp>(op.getOperation()))
+      if (truncOp.getRoundingmodeAttr())
+        return rewriter.notifyMatchFailure(
+            op, "custom rounding modes are not supported");
+
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
+        rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+    if (failed(fn))
+      return fn;
+
+    rewriter.setInsertionPoint(op);
+    // Bitcast the operand to an integer and zero-extend it to 64 bits.
+    // `arith.extui` requires a strictly wider result type, so 64-bit
+    // operands are passed through unchanged.
+    Location loc = op.getLoc();
+    auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+    auto int64Type = rewriter.getI64Type();
+    Value operandBits =
+        arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand());
+    if (inFloatTy.getWidth() < 64)
+      operandBits =
+          arith::ExtUIOp::create(rewriter, loc, int64Type, operandBits);
+
+    // Call APFloat function.
+    Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+    Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+    SmallVector<Value> params = {inSemValue, outSemValue, operandBits};
+    auto resultOp =
+        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+                             SymbolRefAttr::get(*fn), params);
+
+    // Truncate the result to the width of the result type. `arith.trunci`
+    // requires a strictly narrower result type, so 64-bit results are used
+    // as is.
+    Value resultBits = resultOp->getResult(0);
+    if (outFloatTy.getWidth() < 64)
+      resultBits = arith::TruncIOp::create(
+          rewriter, loc, rewriter.getIntegerType(outFloatTy.getWidth()),
+          resultBits);
+    rewriter.replaceOp(
+        op, arith::BitcastOp::create(rewriter, loc, outFloatTy, resultBits));
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
 namespace {
 struct ArithToAPFloatConversionPass final
     : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -147,6 +206,8 @@ void ArithToAPFloatConversionPass::runOnOperation() {
       context, "divide", getOperation());
   patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
       context, "remainder", getOperation());
+  patterns.add<FpToFpConversion<arith::ExtFOp>>(context, getOperation());
+  patterns.add<FpToFpConversion<arith::TruncFOp>>(context, getOperation());
   LogicalResult result = success();
   ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
     if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 0a05f7369e556..511b05ea380f0 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -51,7 +51,7 @@
 
 /// Binary operations with rounding mode.
 #define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE)                     \
-  MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP(                     \
+  MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP(                    \
       int32_t semantics, uint64_t a, uint64_t b) {                             \
     const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(        \
         static_cast<llvm::APFloatBase::Semantics>(semantics));                 \
@@ -86,4 +86,19 @@ MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) {
   double d = x.convertToDouble();
   fprintf(stdout, "%lg", d);
 }
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t
+_mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) {
+  const llvm::fltSemantics &inSem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(inSemantics));
+  const llvm::fltSemantics &outSem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(outSemantics));
+  unsigned bitWidthIn = llvm::APFloatBase::semanticsSizeInBits(inSem);
+  llvm::APFloat val(inSem, llvm::APInt(bitWidthIn, a));
+  // TODO: Support custom rounding modes; this always rounds to nearest-even.
+  bool losesInfo; // Out-parameter required by convert(); not inspected.
+  val.convert(outSem, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+  llvm::APInt result = val.bitcastToAPInt();
+  return result.getZExtValue();
+}
 }
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index 797f42c37a26f..038acbfc965a2 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -126,3 +126,33 @@ func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
   %0 = arith.remf %arg0, %arg1 : f4E2M1FN
   return
 }
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64
+// CHECK: %[[bits:.*]] = arith.bitcast %{{.*}} : f4E2M1FN to i4
+// CHECK: %[[ext:.*]] = arith.extui %[[bits]] : i4 to i64
+// CHECK: %[[sem_in:.*]] = arith.constant 18 : i32
+// CHECK: %[[sem_out:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %[[ext]]) : (i32, i32, i64) -> i64
+// CHECK: %[[trunc:.*]] = arith.trunci %[[res]] : i64 to i32
+// CHECK: arith.bitcast %[[trunc]] : i32 to f32
+func.func @extf(%arg0: f4E2M1FN) {
+  %0 = arith.extf %arg0 : f4E2M1FN to f32
+  return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64
+// CHECK: %[[bits:.*]] = arith.bitcast %{{.*}} : bf16 to i16
+// CHECK: %[[ext:.*]] = arith.extui %[[bits]] : i16 to i64
+// CHECK: %[[sem_in:.*]] = arith.constant 1 : i32
+// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %[[ext]]) : (i32, i32, i64) -> i64
+// CHECK: %[[trunc:.*]] = arith.trunci %[[res]] : i64 to i4
+// CHECK: arith.bitcast %[[trunc]] : i4 to f4E2M1FN
+func.func @truncf(%arg0: bf16) {
+  %0 = arith.truncf %arg0 : bf16 to f4E2M1FN
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index dbaa20346a03a..51976434d2be2 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -27,14 +27,21 @@ func.func @entry() {
   %a1 = arith.constant 1.4 : f8E4M3FN
   %a2 = arith.constant 1.4 : f32
   %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32)
-  %c1 = arith.addf %a1, %b1 : f8E4M3FN  // not supported by LLVM
-  %c2 = arith.addf %a2, %b2 : f32       // supported by LLVM
 
-  // CHECK: 3.5
+  // CHECK: 2.2
+  vector.print %b2 : f32
+
+  // CHECK-NEXT: 3.5
+  %c1 = arith.addf %a1, %b1 : f8E4M3FN  // not supported by LLVM
   vector.print %c1 : f8E4M3FN
 
-  // CHECK: 3.6
+  // CHECK-NEXT: 3.6
+  %c2 = arith.addf %a2, %b2 : f32       // supported by LLVM
   vector.print %c2 : f32
 
+  // CHECK-NEXT: 2.25
+  %cvt = arith.truncf %b2 : f32 to f8E4M3FN
+  vector.print %cvt : f8E4M3FN
+
   return
 }



More information about the Mlir-commits mailing list