[Mlir-commits] [mlir] 4d7abe5 - [mlir][arith] Add support for `cmpf` to `ArithToAPFloat` (#169753)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 1 00:12:15 PST 2025


Author: Matthias Springer
Date: 2025-12-01T09:12:11+01:00
New Revision: 4d7abe535512e1076ff7e5fea14afde29615a8ed

URL: https://github.com/llvm/llvm-project/commit/4d7abe535512e1076ff7e5fea14afde29615a8ed
DIFF: https://github.com/llvm/llvm-project/commit/4d7abe535512e1076ff7e5fea14afde29615a8ed.diff

LOG: [mlir][arith] Add support for `cmpf` to `ArithToAPFloat` (#169753)

Add support for `arith.cmpf`.

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
    mlir/lib/ExecutionEngine/APFloatWrappers.cpp
    mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
    mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 81fbdb1611deb..566632bd8707f 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -41,15 +41,17 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
 }
 
 /// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Always returns an int64_t.
+/// function with the given parameter types. Returns an int64_t, unless a
+/// different result type is specified.
 static FailureOr<FuncOp>
 lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
                         StringRef name, TypeRange paramTypes,
-                        SymbolTableCollection *symbolTables = nullptr) {
-  auto i64Type = IntegerType::get(symTable->getContext(), 64);
-
+                        SymbolTableCollection *symbolTables = nullptr,
+                        Type resultType = {}) {
+  if (!resultType)
+    resultType = IntegerType::get(symTable->getContext(), 64);
   std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
-  auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type});
+  auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
   FailureOr<FuncOp> func =
       lookupFnDecl(symTable, funcName, funcT, symbolTables);
   // Failed due to type mismatch.
@@ -308,6 +310,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
   bool isUnsigned;
 };
 
+struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
+  CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(arith::CmpFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Get APFloat function from runtime library.
+    auto i1Type = IntegerType::get(symTable->getContext(), 1);
+    auto i8Type = IntegerType::get(symTable->getContext(), 8);
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn =
+        lookupOrCreateApFloatFn(rewriter, symTable, "compare",
+                                {i32Type, i64Type, i64Type}, nullptr, i8Type);
+    if (failed(fn))
+      return fn;
+
+    // Cast operands to 64-bit integers.
+    rewriter.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto floatTy = cast<FloatType>(op.getLhs().getType());
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    Value lhsBits = arith::ExtUIOp::create(
+        rewriter, loc, i64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
+    Value rhsBits = arith::ExtUIOp::create(
+        rewriter, loc, i64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+
+    // Call APFloat function.
+    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+    Value comparisonResult =
+        func::CallOp::create(rewriter, loc, TypeRange(i8Type),
+                             SymbolRefAttr::get(*fn), params)
+            ->getResult(0);
+
+    // Generate an i1 SSA value that is "true" if the comparison result matches
+    // the given `val`.
+    auto checkResult = [&](llvm::APFloat::cmpResult val) {
+      return arith::CmpIOp::create(
+          rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
+          arith::ConstantOp::create(
+              rewriter, loc, i8Type,
+              rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
+              .getResult());
+    };
+    // Generate an i1 SSA value that is "true" if the comparison result matches
+    // any of the given `vals`.
+    std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkResults =
+        [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
+          Value first = checkResult(vals.front());
+          if (vals.size() == 1)
+            return first;
+          Value rest = checkResults(vals.drop_front());
+          return arith::OrIOp::create(rewriter, loc, first, rest).getResult();
+        };
+
+    // This switch-case statement was taken from arith::applyCmpPredicate.
+    Value result;
+    switch (op.getPredicate()) {
+    case arith::CmpFPredicate::AlwaysFalse:
+      result = arith::ConstantOp::create(rewriter, loc, i1Type,
+                                         rewriter.getIntegerAttr(i1Type, 0))
+                   .getResult();
+      break;
+    case arith::CmpFPredicate::OEQ:
+      result = checkResult(llvm::APFloat::cmpEqual);
+      break;
+    case arith::CmpFPredicate::OGT:
+      result = checkResult(llvm::APFloat::cmpGreaterThan);
+      break;
+    case arith::CmpFPredicate::OGE:
+      result = checkResults(
+          {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::OLT:
+      result = checkResult(llvm::APFloat::cmpLessThan);
+      break;
+    case arith::CmpFPredicate::OLE:
+      result =
+          checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::ONE:
+      // Not cmpUnordered and not cmpEqual.
+      result = checkResults(
+          {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
+      break;
+    case arith::CmpFPredicate::ORD:
+      // Not cmpUnordered.
+      result = checkResults({llvm::APFloat::cmpLessThan,
+                             llvm::APFloat::cmpGreaterThan,
+                             llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UEQ:
+      result =
+          checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UGT:
+      result = checkResults(
+          {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
+      break;
+    case arith::CmpFPredicate::UGE:
+      result = checkResults({llvm::APFloat::cmpUnordered,
+                             llvm::APFloat::cmpGreaterThan,
+                             llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::ULT:
+      result = checkResults(
+          {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
+      break;
+    case arith::CmpFPredicate::ULE:
+      result =
+          checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
+                        llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UNE:
+      // Not cmpEqual.
+      result = checkResults({llvm::APFloat::cmpLessThan,
+                             llvm::APFloat::cmpGreaterThan,
+                             llvm::APFloat::cmpUnordered});
+      break;
+    case arith::CmpFPredicate::UNO:
+      result = checkResult(llvm::APFloat::cmpUnordered);
+      break;
+    case arith::CmpFPredicate::AlwaysTrue:
+      result = arith::ConstantOp::create(rewriter, loc, i1Type,
+                                         rewriter.getIntegerAttr(i1Type, 1))
+                   .getResult();
+      break;
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
 namespace {
 struct ArithToAPFloatConversionPass final
     : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -340,6 +481,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
                                                    /*isUnsigned=*/false);
   patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
                                                    /*isUnsigned=*/true);
+  patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
   LogicalResult result = success();
   ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
     if (diag.getSeverity() == DiagnosticSeverity::Error) {

diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 44980ccd77491..77f7137264888 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
                           llvm::RoundingMode::NearestTiesToEven);
   return result.bitcastToAPInt().getZExtValue();
 }
+
+MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
+                                                          uint64_t a,
+                                                          uint64_t b) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
+  return static_cast<int8_t>(x.compare(y));
+}
 }

diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index d71d81dddcd4f..78ce3640ecc67 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) {
   %0 = arith.uitofp %arg0 : i32 to f4E2M1FN
   return
 }
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8
+// CHECK: %[[c3:.*]] = arith.constant 3 : i8
+// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8
+// CHECK: %[[c0:.*]] = arith.constant 0 : i8
+// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8
+// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1
+func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+  %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
+  return
+}

diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 8046610d479a8..433d058d025cf 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -43,6 +43,10 @@ func.func @entry() {
   %cvt = arith.truncf %b2 : f32 to f8E4M3FN
   vector.print %cvt : f8E4M3FN
 
+  // CHECK-NEXT: 1
+  %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
+  vector.print %cmp1 : i1
+
   // CHECK-NEXT: 1
   // Bit pattern: 01, interpreted as signed integer: 1
   %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2


        


More information about the Mlir-commits mailing list