[Mlir-commits] [mlir] [mlir][emitc] arith.cmpf to EmitC conversion (PR #93671)
Marius Brehler
llvmlistbot at llvm.org
Mon Jun 3 04:57:56 PDT 2024
================
@@ -40,6 +42,160 @@ class ArithConstantOpConversionPattern
}
};
+class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ /// Lowers a scalar `arith.cmpf` to EmitC operations.
+ ///
+ /// - AlwaysFalse / AlwaysTrue become an i1 `emitc.constant`.
+ /// - ORD / UNO become the corresponding NaN check on the two operands.
+ /// - Every other predicate becomes an `emitc.cmp`, combined with an
+ ///   explicit ordered check (logical AND, O* predicates) or unordered check
+ ///   (logical OR, U* predicates).
+ ///
+ /// Returns failure without rewriting when the operand type is not a scalar
+ /// FloatType (e.g. tensors or vectors of floats); returns success after
+ /// replacing `op` otherwise.
+ LogicalResult
+ matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Only the rhs type is inspected; presumably sufficient because both cmpf
+ // operands share one type -- TODO confirm against the op definition.
+ if (!isa<FloatType>(adaptor.getRhs().getType())) {
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "cmpf currently only supported on "
+ "floats, not tensors/vectors thereof");
+ }
+
+ // Selects how the naive comparison is adjusted after the switch:
+ // true -> OR with the unordered check, false -> AND with the ordered check.
+ bool unordered = false;
+ // Left unset only by the cases that return early below; every other
+ // predicate assigns it before the comparison is built.
+ emitc::CmpPredicate predicate;
+ switch (op.getPredicate()) {
+ case arith::CmpFPredicate::AlwaysFalse: {
+ // Result does not depend on the operands: fold to a constant false.
+ auto constant = rewriter.create<emitc::ConstantOp>(
+ op.getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(/*value=*/false));
+ rewriter.replaceOp(op, constant);
+ return success();
+ }
+ // Ordered predicates: the final result is forced to false whenever either
+ // operand is NaN (see the AND with the ordered check after the switch).
+ case arith::CmpFPredicate::OEQ:
+ unordered = false;
+ predicate = emitc::CmpPredicate::eq;
+ break;
+ case arith::CmpFPredicate::OGT:
+ unordered = false;
+ predicate = emitc::CmpPredicate::gt;
+ break;
+ case arith::CmpFPredicate::OGE:
+ unordered = false;
+ predicate = emitc::CmpPredicate::ge;
+ break;
+ case arith::CmpFPredicate::OLT:
+ unordered = false;
+ predicate = emitc::CmpPredicate::lt;
+ break;
+ case arith::CmpFPredicate::OLE:
+ unordered = false;
+ predicate = emitc::CmpPredicate::le;
+ break;
+ case arith::CmpFPredicate::ONE:
+ unordered = false;
+ predicate = emitc::CmpPredicate::ne;
+ break;
+ case arith::CmpFPredicate::ORD: {
+ // ORD: true iff neither operand is NaN; no value comparison is needed.
+ auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
+ adaptor.getRhs());
+ rewriter.replaceOp(op, cmp);
+ return success();
+ }
+ // Unordered predicates: the final result is forced to true whenever either
+ // operand is NaN (see the OR with the unordered check after the switch).
+ case arith::CmpFPredicate::UEQ:
+ unordered = true;
+ predicate = emitc::CmpPredicate::eq;
+ break;
+ case arith::CmpFPredicate::UGT:
+ unordered = true;
+ predicate = emitc::CmpPredicate::gt;
+ break;
+ case arith::CmpFPredicate::UGE:
+ unordered = true;
+ predicate = emitc::CmpPredicate::ge;
+ break;
+ case arith::CmpFPredicate::ULT:
+ unordered = true;
+ predicate = emitc::CmpPredicate::lt;
+ break;
+ case arith::CmpFPredicate::ULE:
+ unordered = true;
+ predicate = emitc::CmpPredicate::le;
+ break;
+ case arith::CmpFPredicate::UNE:
+ unordered = true;
+ predicate = emitc::CmpPredicate::ne;
+ break;
+ case arith::CmpFPredicate::UNO: {
+ // UNO: true iff at least one operand is NaN; no value comparison is needed.
+ auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
+ adaptor.getRhs());
+ rewriter.replaceOp(op, cmp);
+ return success();
+ }
+ case arith::CmpFPredicate::AlwaysTrue: {
+ // Result does not depend on the operands: fold to a constant true.
+ auto constant = rewriter.create<emitc::ConstantOp>(
+ op.getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(/*value=*/true));
+ rewriter.replaceOp(op, constant);
+ return success();
+ }
+ }
+
+ // Compare the values naively with the plain C operator; the ordered or
+ // unordered NaN handling is layered on top of this result below. Note the
+ // emitc.cmp is created before the NaN-check ops, and that creation order is
+ // visible in the emitted IR.
+ auto cmpResult =
+ rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
+ adaptor.getLhs(), adaptor.getRhs());
+
+ // Adjust the results for unordered/ordered semantics.
+ // U* predicates: isUnordered || cmpResult.
+ if (unordered) {
+ auto isUnordered = createCheckIsUnordered(
+ rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
+ rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
+ isUnordered, cmpResult);
+ return success();
+ }
+
+ // O* predicates: isOrdered && cmpResult.
+ // NOTE(review): under IEEE-754 C semantics ==, <, >, <=, >= already yield
+ // false when an operand is NaN, so this AND appears to be strictly needed
+ // only for ONE (`!=` yields true on NaN) -- confirm before simplifying.
+ auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
+ adaptor.getLhs(), adaptor.getRhs());
+ rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
+ isOrdered, cmpResult);
+ return success();
+ }
+
+private:
+ /// Return a value that is true iff \p operand is NaN.
----------------
marbre wrote:
```suggestion
/// Return a value that is true if \p operand is NaN.
```
https://github.com/llvm/llvm-project/pull/93671
More information about the Mlir-commits
mailing list