[Mlir-commits] [mlir] 4d7abe5 - [mlir][arith] Add support for `cmpf` to `ArithToAPFloat` (#169753)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 1 00:12:15 PST 2025
Author: Matthias Springer
Date: 2025-12-01T09:12:11+01:00
New Revision: 4d7abe535512e1076ff7e5fea14afde29615a8ed
URL: https://github.com/llvm/llvm-project/commit/4d7abe535512e1076ff7e5fea14afde29615a8ed
DIFF: https://github.com/llvm/llvm-project/commit/4d7abe535512e1076ff7e5fea14afde29615a8ed.diff
LOG: [mlir][arith] Add support for `cmpf` to `ArithToAPFloat` (#169753)
Add support for `arith.cmpf`.
Added:
Modified:
mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
mlir/lib/ExecutionEngine/APFloatWrappers.cpp
mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 81fbdb1611deb..566632bd8707f 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -41,15 +41,17 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
}
/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Always returns an int64_t.
+/// function with the given parameter types. Returns an int64_t, unless a
+/// different result type is specified.
static FailureOr<FuncOp>
lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
StringRef name, TypeRange paramTypes,
- SymbolTableCollection *symbolTables = nullptr) {
- auto i64Type = IntegerType::get(symTable->getContext(), 64);
-
+ SymbolTableCollection *symbolTables = nullptr,
+ Type resultType = {}) {
+ if (!resultType)
+ resultType = IntegerType::get(symTable->getContext(), 64);
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
- auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type});
+ auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
FailureOr<FuncOp> func =
lookupFnDecl(symTable, funcName, funcT, symbolTables);
// Failed due to type mismatch.
@@ -308,6 +310,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
bool isUnsigned;
};
+/// Lowers a scalar `arith.cmpf` to a call to the APFloat runtime function
+/// `_mlir_apfloat_compare`, which returns an i8 holding an
+/// `llvm::APFloat::cmpResult` code. The predicate is then evaluated by checking
+/// whether that code is one of the codes the predicate accepts.
+struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
+  CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(arith::CmpFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only scalar float operands are supported: `cmpf` also accepts vectors
+    // and tensors, which must not reach the scalar lowering below. Check this
+    // before touching the symbol table so no unused declaration is created.
+    auto floatTy = dyn_cast<FloatType>(op.getLhs().getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op,
+                                         "only scalar FloatTypes supported");
+    // Each operand is handed to the runtime as one 64-bit word.
+    unsigned bitWidth = floatTy.getWidth();
+    if (bitWidth > 64)
+      return rewriter.notifyMatchFailure(
+          op, "float types wider than 64 bits are not supported");
+
+    Location loc = op.getLoc();
+    auto i1Type = rewriter.getIntegerType(1);
+    arith::CmpFPredicate predicate = op.getPredicate();
+
+    // Constant predicates do not depend on the operands; fold them directly
+    // instead of emitting a runtime call whose result would be ignored.
+    if (predicate == arith::CmpFPredicate::AlwaysFalse ||
+        predicate == arith::CmpFPredicate::AlwaysTrue) {
+      rewriter.setInsertionPoint(op);
+      int64_t value = predicate == arith::CmpFPredicate::AlwaysTrue ? 1 : 0;
+      Value constant =
+          arith::ConstantOp::create(rewriter, loc, i1Type,
+                                    rewriter.getIntegerAttr(i1Type, value))
+              .getResult();
+      rewriter.replaceOp(op, constant);
+      return success();
+    }
+
+    // Get APFloat function from runtime library: (i32, i64, i64) -> i8.
+    auto i8Type = rewriter.getIntegerType(8);
+    auto i32Type = rewriter.getIntegerType(32);
+    auto i64Type = rewriter.getIntegerType(64);
+    FailureOr<FuncOp> fn =
+        lookupOrCreateApFloatFn(rewriter, symTable, "compare",
+                                {i32Type, i64Type, i64Type}, nullptr, i8Type);
+    if (failed(fn))
+      return fn;
+
+    // The lookup may have moved the insertion point; go back to the op.
+    rewriter.setInsertionPoint(op);
+
+    // Reinterpret each operand as an integer and widen it to i64. `arith.extui`
+    // requires a strictly wider result, so 64-bit formats skip the extension.
+    auto intWType = rewriter.getIntegerType(bitWidth);
+    auto toI64Bits = [&](Value operand) -> Value {
+      Value bits = arith::BitcastOp::create(rewriter, loc, intWType, operand);
+      if (bitWidth == 64)
+        return bits;
+      return arith::ExtUIOp::create(rewriter, loc, i64Type, bits).getResult();
+    };
+    Value lhsBits = toI64Bits(op.getLhs());
+    Value rhsBits = toI64Bits(op.getRhs());
+
+    // Call APFloat function.
+    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+    Value comparisonResult =
+        func::CallOp::create(rewriter, loc, TypeRange(i8Type),
+                             SymbolRefAttr::get(*fn), params)
+            ->getResult(0);
+
+    // Returns an i1 that is true iff the runtime result equals `expected`.
+    auto checkResult = [&](llvm::APFloat::cmpResult expected) -> Value {
+      Value code =
+          arith::ConstantOp::create(
+              rewriter, loc, i8Type,
+              rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(expected)))
+              .getResult();
+      return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+                                   comparisonResult, code)
+          .getResult();
+    };
+    // Returns an i1 that is true iff the runtime result equals any code in the
+    // non-empty list `accepted`; the individual checks are OR-ed together.
+    auto checkResults =
+        [&](ArrayRef<llvm::APFloat::cmpResult> accepted) -> Value {
+      Value any = checkResult(accepted.front());
+      for (llvm::APFloat::cmpResult code : accepted.drop_front()) {
+        Value next = checkResult(code);
+        any = arith::OrIOp::create(rewriter, loc, any, next).getResult();
+      }
+      return any;
+    };
+
+    // Map each predicate to the cmpResult codes it accepts (mirrors
+    // arith::applyCmpPredicate). Ordered predicates reject cmpUnordered;
+    // unordered predicates accept it.
+    Value result;
+    switch (predicate) {
+    case arith::CmpFPredicate::AlwaysFalse:
+    case arith::CmpFPredicate::AlwaysTrue:
+      llvm_unreachable("constant predicates are folded above");
+    case arith::CmpFPredicate::OEQ:
+      result = checkResult(llvm::APFloat::cmpEqual);
+      break;
+    case arith::CmpFPredicate::OGT:
+      result = checkResult(llvm::APFloat::cmpGreaterThan);
+      break;
+    case arith::CmpFPredicate::OGE:
+      result = checkResults(
+          {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::OLT:
+      result = checkResult(llvm::APFloat::cmpLessThan);
+      break;
+    case arith::CmpFPredicate::OLE:
+      result =
+          checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::ONE:
+      // Neither cmpUnordered nor cmpEqual.
+      result = checkResults(
+          {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
+      break;
+    case arith::CmpFPredicate::ORD:
+      // Anything but cmpUnordered.
+      result = checkResults({llvm::APFloat::cmpLessThan,
+                             llvm::APFloat::cmpGreaterThan,
+                             llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UEQ:
+      result =
+          checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UGT:
+      result = checkResults(
+          {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
+      break;
+    case arith::CmpFPredicate::UGE:
+      result = checkResults({llvm::APFloat::cmpUnordered,
+                             llvm::APFloat::cmpGreaterThan,
+                             llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::ULT:
+      result = checkResults(
+          {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
+      break;
+    case arith::CmpFPredicate::ULE:
+      result =
+          checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
+                        llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UNE:
+      // Anything but cmpEqual.
+      result = checkResults({llvm::APFloat::cmpLessThan,
+                             llvm::APFloat::cmpGreaterThan,
+                             llvm::APFloat::cmpUnordered});
+      break;
+    case arith::CmpFPredicate::UNO:
+      result = checkResult(llvm::APFloat::cmpUnordered);
+      break;
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  /// Symbol table op (the pass's root operation) in which the runtime
+  /// function declaration is looked up or created.
+  SymbolOpInterface symTable;
+};
+
namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -340,6 +481,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
/*isUnsigned=*/false);
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
/*isUnsigned=*/true);
+ patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 44980ccd77491..77f7137264888 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
llvm::RoundingMode::NearestTiesToEven);
return result.bitcastToAPInt().getZExtValue();
}
+
+/// Compares two floats of the APFloat semantics identified by `semantics`,
+/// each given as its raw bit pattern stored in the low bits of a 64-bit word.
+/// Returns the `llvm::APFloat::cmpResult` of `a` vs. `b`, narrowed to int8_t.
+MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
+                                                          uint64_t a,
+                                                          uint64_t b) {
+  // Decode the semantics enum and rebuild both operands from their bits.
+  const auto semEnum = static_cast<llvm::APFloatBase::Semantics>(semantics);
+  const llvm::fltSemantics &fltSem =
+      llvm::APFloatBase::EnumToSemantics(semEnum);
+  const unsigned width = llvm::APFloatBase::semanticsSizeInBits(fltSem);
+  const llvm::APFloat lhs(fltSem, llvm::APInt(width, a));
+  const llvm::APFloat rhs(fltSem, llvm::APInt(width, b));
+  const llvm::APFloat::cmpResult ordering = lhs.compare(rhs);
+  return static_cast<int8_t>(ordering);
+}
}
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index d71d81dddcd4f..78ce3640ecc67 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) {
%0 = arith.uitofp %arg0 : i32 to f4E2M1FN
return
}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8
+// CHECK: %[[c3:.*]] = arith.constant 3 : i8
+// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8
+// CHECK: %[[c0:.*]] = arith.constant 0 : i8
+// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8
+// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1
+// The "ult" predicate holds when the operands are unordered or lhs < rhs, so
+// the lowering ORs two equality checks on the runtime's comparison code.
+func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 8046610d479a8..433d058d025cf 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -43,6 +43,10 @@ func.func @entry() {
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
vector.print %cvt : f8E4M3FN
+ // CHECK-NEXT: 1
+ %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
+ vector.print %cmp1 : i1
+
// CHECK-NEXT: 1
// Bit pattern: 01, interpreted as signed integer: 1
%cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2
More information about the Mlir-commits
mailing list