[Mlir-commits] [mlir] [mlir][emitc] arith.cmpf to EmitC conversion (PR #93671)
Simon Camphausen
llvmlistbot at llvm.org
Wed May 29 06:10:03 PDT 2024
================
@@ -40,6 +42,162 @@ class ArithConstantOpConversionPattern
}
};
+class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+  /// Lowers a scalar `arith.cmpf` to EmitC.
+  ///
+  /// Predicates with a constant result become an `emitc.constant`. ORD/UNO
+  /// become pure NaN checks. Every other predicate becomes a plain C
+  /// comparison combined with an ordered (`&&`) or unordered (`||`) NaN check,
+  /// so that NaN operands produce the result `arith.cmpf` requires.
+  /// Fails to match on non-float operands (tensors/vectors of floats).
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only scalar floats are handled. arith.cmpf requires both operands to
+    // share one type, so inspecting the RHS is enough.
+    if (!isa<FloatType>(adaptor.getRhs().getType())) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "cmpf currently only supported on "
+                                         "floats, not tensors/vectors thereof");
+    }
+
+    Location loc = op.getLoc();
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+
+    // Replaces `op` with an i1 constant. Used for the predicates whose result
+    // does not depend on the operands.
+    auto replaceWithConstant = [&](bool value) -> LogicalResult {
+      auto constant = rewriter.create<emitc::ConstantOp>(
+          loc, rewriter.getI1Type(), rewriter.getBoolAttr(value));
+      rewriter.replaceOp(op, constant);
+      return success();
+    };
+
+    // When either operand is NaN, "U*" predicates must yield true and "O*"
+    // predicates must yield false. `unordered` selects which NaN fix-up is
+    // applied to the plain comparison built after the switch.
+    bool unordered = false;
+    // Every case below either assigns `predicate` or returns. The initializer
+    // keeps the variable well-defined if an unlisted enum value ever reaches
+    // the switch, and avoids maybe-uninitialized diagnostics.
+    emitc::CmpPredicate predicate = emitc::CmpPredicate::eq;
+    switch (op.getPredicate()) {
+    case arith::CmpFPredicate::AlwaysFalse:
+      return replaceWithConstant(/*value=*/false);
+    case arith::CmpFPredicate::AlwaysTrue:
+      return replaceWithConstant(/*value=*/true);
+    case arith::CmpFPredicate::ORD: {
+      // Ordered: neither operand is NaN.
+      auto isOrdered = createCheckIsOrdered(rewriter, loc, lhs, rhs);
+      rewriter.replaceOp(op, isOrdered);
+      return success();
+    }
+    case arith::CmpFPredicate::UNO: {
+      // Unordered: at least one operand is NaN.
+      auto isUnordered = createCheckIsUnordered(rewriter, loc, lhs, rhs);
+      rewriter.replaceOp(op, isUnordered);
+      return success();
+    }
+    // Ordered predicates: `unordered` keeps its initial value of false.
+    case arith::CmpFPredicate::OEQ:
+      predicate = emitc::CmpPredicate::eq;
+      break;
+    case arith::CmpFPredicate::OGT:
+      predicate = emitc::CmpPredicate::gt;
+      break;
+    case arith::CmpFPredicate::OGE:
+      predicate = emitc::CmpPredicate::ge;
+      break;
+    case arith::CmpFPredicate::OLT:
+      predicate = emitc::CmpPredicate::lt;
+      break;
+    case arith::CmpFPredicate::OLE:
+      predicate = emitc::CmpPredicate::le;
+      break;
+    case arith::CmpFPredicate::ONE:
+      predicate = emitc::CmpPredicate::ne;
+      break;
+    // Unordered predicates.
+    case arith::CmpFPredicate::UEQ:
+      unordered = true;
+      predicate = emitc::CmpPredicate::eq;
+      break;
+    case arith::CmpFPredicate::UGT:
+      unordered = true;
+      predicate = emitc::CmpPredicate::gt;
+      break;
+    case arith::CmpFPredicate::UGE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::ge;
+      break;
+    case arith::CmpFPredicate::ULT:
+      unordered = true;
+      predicate = emitc::CmpPredicate::lt;
+      break;
+    case arith::CmpFPredicate::ULE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::le;
+      break;
+    case arith::CmpFPredicate::UNE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::ne;
+      break;
+    }
+
+    // Compare the values naively with the C operator first.
+    auto cmpResult = rewriter.create<emitc::CmpOp>(loc, op.getType(),
+                                                   predicate, lhs, rhs);
+
+    // Adjust the naive result for NaN operands.
+    // NOTE(review): for OEQ/OGT/OGE/OLT/OLE (and UNE) the C operator already
+    // yields the required value on NaN, so this fix-up is redundant there. It
+    // is kept so the emitted IR stays uniform.
+    if (unordered) {
+      auto isUnordered = createCheckIsUnordered(rewriter, loc, lhs, rhs);
+      rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(
+          op, op.getType(), isUnordered.getResult(), cmpResult);
+      return success();
+    }
+
+    auto isOrdered = createCheckIsOrdered(rewriter, loc, lhs, rhs);
+    rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
+                                                     isOrdered, cmpResult);
+    return success();
+  }
+
+private:
+  /// Builds an i1 `emitc.cmp ne` of \p operand against itself. Its result is
+  /// true exactly when \p operand is NaN.
+  emitc::CmpOp isNan(ConversionPatternRewriter &rewriter, Location loc,
+                     Value operand) const {
+    // NaN is the only floating-point value that compares unequal to itself.
+    // This relies on the generated C not being built with options that assume
+    // no NaNs (e.g. -ffinite-math-only).
+    auto boolType = rewriter.getI1Type();
+    auto notEqual = emitc::CmpPredicate::ne;
+    return rewriter.create<emitc::CmpOp>(loc, boolType, notEqual, operand,
+                                         operand);
+  }
+
+  /// Builds an i1 `emitc.cmp eq` of \p operand against itself. Its result is
+  /// true exactly when \p operand is not NaN.
+  emitc::CmpOp isNotNan(ConversionPatternRewriter &rewriter, Location loc,
+                        Value operand) const {
+    // Every floating-point value except NaN compares equal to itself. This
+    // relies on the generated C not being built with options that assume no
+    // NaNs (e.g. -ffinite-math-only).
+    auto boolType = rewriter.getI1Type();
+    auto selfEqual = emitc::CmpPredicate::eq;
+    return rewriter.create<emitc::CmpOp>(loc, boolType, selfEqual, operand,
+                                         operand);
+  }
+
+  /// Builds an i1 `emitc.logical_or` of two NaN checks. It is true when
+  /// \p first and \p second are unordered, i.e. at least one of them is NaN.
+  emitc::LogicalOrOp createCheckIsUnordered(ConversionPatternRewriter &rewriter,
+                                            Location loc, Value first,
+                                            Value second) const {
+    // Separate statements pin the creation order: the check on `first` is
+    // emitted before the check on `second`.
+    emitc::CmpOp lhsIsNaN = isNan(rewriter, loc, first);
+    emitc::CmpOp rhsIsNaN = isNan(rewriter, loc, second);
+    auto boolType = rewriter.getI1Type();
+    return rewriter.create<emitc::LogicalOrOp>(loc, boolType, lhsIsNaN,
+                                               rhsIsNaN);
+  }
+
+ /// Return an op that return true (in i1) if the operands \p first and
+ /// \p second are both ordered (i.e., none one of them is NaN).
+ emitc::LogicalAndOp createCheckIsOrdered(ConversionPatternRewriter &rewriter,
+ Location loc, Value first,
+ Value second) const {
+ auto firstIsNaN = isNotNan(rewriter, loc, first);
+ auto secondIsNaN = isNotNan(rewriter, loc, second);
----------------
simon-camp wrote:
nit: rename to IsNotNaN
https://github.com/llvm/llvm-project/pull/93671
More information about the Mlir-commits
mailing list