[Mlir-commits] [mlir] [mlir][emitc] arith.cmpf to EmitC conversion (PR #93671)

Tina Jung llvmlistbot at llvm.org
Wed May 29 08:01:57 PDT 2024


================
@@ -40,6 +42,162 @@ class ArithConstantOpConversionPattern
   }
 };
 
+class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Lowers a scalar `arith.cmpf` to EmitC operations.
+  ///
+  /// Only operands whose (converted) type is a `FloatType` are handled; any
+  /// other type (e.g. tensors/vectors of floats) is reported as a match
+  /// failure. Per predicate:
+  ///   - AlwaysFalse / AlwaysTrue: an i1 `emitc::ConstantOp` (false / true);
+  ///     the operands are not inspected.
+  ///   - ORD / UNO: the value produced by `createCheckIsOrdered` /
+  ///     `createCheckIsUnordered` on the two operands (helpers not visible in
+  ///     this excerpt; see the inline comments on those cases).
+  ///   - OEQ, OGT, OGE, OLT, OLE, ONE: an `emitc::CmpOp` with the matching
+  ///     predicate, combined via `emitc::LogicalAndOp` with the ordered check.
+  ///   - UEQ, UGT, UGE, ULT, ULE, UNE: an `emitc::CmpOp` with the matching
+  ///     predicate, combined via `emitc::LogicalOrOp` with the unordered check.
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Only the RHS type is checked; this presumably relies on both cmpf
+    // operands sharing one type — NOTE(review): confirm.
+    if (!isa<FloatType>(adaptor.getRhs().getType())) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "cmpf currently only supported on "
+                                         "floats, not tensors/vectors thereof");
+    }
+
+    // Ordered and unordered variants map to the same emitc predicate (e.g.
+    // OEQ and UEQ both use `eq`); `unordered` records which NaN handling to
+    // apply after the switch. Every case below either returns early or
+    // assigns both `unordered` and `predicate` before breaking.
+    bool unordered = false;
+    emitc::CmpPredicate predicate;
+    switch (op.getPredicate()) {
+    case arith::CmpFPredicate::AlwaysFalse: {
+      auto constant = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rewriter.getI1Type(),
+          rewriter.getBoolAttr(/*value=*/false));
+      rewriter.replaceOp(op, constant);
+      return success();
+    }
+    case arith::CmpFPredicate::OEQ:
+      unordered = false;
+      predicate = emitc::CmpPredicate::eq;
+      break;
+    case arith::CmpFPredicate::OGT:
+      unordered = false;
+      predicate = emitc::CmpPredicate::gt;
+      break;
+    case arith::CmpFPredicate::OGE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::ge;
+      break;
+    case arith::CmpFPredicate::OLT:
+      unordered = false;
+      predicate = emitc::CmpPredicate::lt;
+      break;
+    case arith::CmpFPredicate::OLE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::le;
+      break;
+    case arith::CmpFPredicate::ONE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::ne;
+      break;
+    case arith::CmpFPredicate::ORD: {
+      // ordered, i.e. none of the operands is NaN
+      auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
+                                      adaptor.getRhs());
+      rewriter.replaceOp(op, cmp);
+      return success();
+    }
+    case arith::CmpFPredicate::UEQ:
+      unordered = true;
+      predicate = emitc::CmpPredicate::eq;
+      break;
+    case arith::CmpFPredicate::UGT:
+      unordered = true;
+      predicate = emitc::CmpPredicate::gt;
+      break;
+    case arith::CmpFPredicate::UGE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::ge;
+      break;
+    case arith::CmpFPredicate::ULT:
+      unordered = true;
+      predicate = emitc::CmpPredicate::lt;
+      break;
+    case arith::CmpFPredicate::ULE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::le;
+      break;
+    case arith::CmpFPredicate::UNE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::ne;
+      break;
+    case arith::CmpFPredicate::UNO: {
+      // unordered, i.e. either operand is nan
+      auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
+                                        adaptor.getRhs());
+      rewriter.replaceOp(op, cmp);
+      return success();
+    }
+    case arith::CmpFPredicate::AlwaysTrue: {
+      auto constant = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rewriter.getI1Type(),
+          rewriter.getBoolAttr(/*value=*/true));
+      rewriter.replaceOp(op, constant);
+      return success();
+    }
+    }
+
+    // Compare the values naively
+    // (the emitc::CmpOp itself carries no ordered/unordered distinction).
+    auto cmpResult =
+        rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
+                                      adaptor.getLhs(), adaptor.getRhs());
+
+    // Adjust the results for unordered/ordered semantics
+    if (unordered) {
+      // Unordered predicate: true if either operand is NaN, otherwise the
+      // result of the plain comparison.
+      auto isUnordered = createCheckIsUnordered(
+          rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
+      // NOTE(review): this branch passes `isUnordered.getResult()` while the
+      // ordered branch passes `isOrdered` directly; consider making them
+      // consistent.
+      rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(
+          op, op.getType(), isUnordered.getResult(), cmpResult);
+      return success();
+    }
+
+    // Ordered predicate: true only if neither operand is NaN and the plain
+    // comparison holds.
+    auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
+                                          adaptor.getLhs(), adaptor.getRhs());
+    rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
+                                                     isOrdered, cmpResult);
+    return success();
+  }
+
+private:
+  /// Return an operation that returns true (in i1) when \p operand is NaN.
+  emitc::CmpOp isNan(ConversionPatternRewriter &rewriter, Location loc,
----------------
TinaAMD wrote:

Adapted here: https://github.com/llvm/llvm-project/pull/93671/commits/ffd688698559735c180c1e33d909c89aa43db17e

https://github.com/llvm/llvm-project/pull/93671


More information about the Mlir-commits mailing list