[Mlir-commits] [mlir] [mlir][emitc] arith.cmpf to EmitC conversion (PR #93671)

Marius Brehler llvmlistbot at llvm.org
Mon Jun 3 04:57:57 PDT 2024


================
@@ -40,6 +42,160 @@ class ArithConstantOpConversionPattern
   }
 };
 
+class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Rewrites arith.cmpf on scalar floats into emitc.cmp combined with an ordered/unordered check; AlwaysFalse/AlwaysTrue become i1 constants.
+    if (!isa<FloatType>(adaptor.getRhs().getType())) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "cmpf currently only supported on "
+                                         "floats, not tensors/vectors thereof");
+    }
+
+    bool unordered = false; // true for the U* predicates handled after the switch
+    emitc::CmpPredicate predicate; // assigned by every case that breaks out of the switch
+    switch (op.getPredicate()) {
+    case arith::CmpFPredicate::AlwaysFalse: {
+      auto constant = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rewriter.getI1Type(),
+          rewriter.getBoolAttr(/*value=*/false));
+      rewriter.replaceOp(op, constant);
+      return success();
+    }
+    case arith::CmpFPredicate::OEQ:
+      unordered = false;
+      predicate = emitc::CmpPredicate::eq;
+      break;
+    case arith::CmpFPredicate::OGT:
+      unordered = false;
+      predicate = emitc::CmpPredicate::gt;
+      break;
+    case arith::CmpFPredicate::OGE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::ge;
+      break;
+    case arith::CmpFPredicate::OLT:
+      unordered = false;
+      predicate = emitc::CmpPredicate::lt;
+      break;
+    case arith::CmpFPredicate::OLE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::le;
+      break;
+    case arith::CmpFPredicate::ONE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::ne;
+      break;
+    case arith::CmpFPredicate::ORD: {
+      // Ordered, i.e. neither operand is NaN: the ordered check alone is the
+      // result, no further comparison is emitted.
+      auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
+                                      adaptor.getRhs());
+      rewriter.replaceOp(op, cmp);
+      return success();
+    }
+    case arith::CmpFPredicate::UEQ:
+      unordered = true;
+      predicate = emitc::CmpPredicate::eq;
+      break;
+    case arith::CmpFPredicate::UGT:
+      unordered = true;
+      predicate = emitc::CmpPredicate::gt;
+      break;
+    case arith::CmpFPredicate::UGE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::ge;
+      break;
+    case arith::CmpFPredicate::ULT:
+      unordered = true;
+      predicate = emitc::CmpPredicate::lt;
+      break;
+    case arith::CmpFPredicate::ULE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::le;
+      break;
+    case arith::CmpFPredicate::UNE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::ne;
+      break;
+    case arith::CmpFPredicate::UNO: {
+      // Unordered, i.e. at least one operand is NaN: the unordered check alone
+      // is the result, no further comparison is emitted.
+      auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
+                                        adaptor.getRhs());
+      rewriter.replaceOp(op, cmp);
+      return success();
+    }
+    case arith::CmpFPredicate::AlwaysTrue: {
+      auto constant = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rewriter.getI1Type(),
+          rewriter.getBoolAttr(/*value=*/true));
+      rewriter.replaceOp(op, constant);
+      return success();
+    }
+    }
+
+    // Compare the values naively, i.e. emit the plain emitc.cmp for the
+    // selected predicate without regard to ordered/unordered semantics yet.
+    auto cmpResult =
+        rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
+                                      adaptor.getLhs(), adaptor.getRhs());
+
+    // Adjust the result for unordered/ordered semantics. Unordered predicates
+    // yield `isUnordered || cmpResult`.
+    if (unordered) {
+      auto isUnordered = createCheckIsUnordered(
+          rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
+      rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
+                                                      isUnordered, cmpResult);
+      return success();
+    }
+    // Ordered predicates yield `isOrdered && cmpResult`.
+    auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
+                                          adaptor.getLhs(), adaptor.getRhs());
+    rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
+                                                     isOrdered, cmpResult);
+    return success();
+  }
+
+private:
+  /// Build an i1 value that is true exactly when \p operand is NaN.
+  Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
+              Value operand) const {
+    // NaN is the only value for which `x != x` holds, so compare the operand
+    // against itself.
+    auto boolType = rewriter.getI1Type();
+    auto selfCompare = rewriter.create<emitc::CmpOp>(
+        loc, boolType, emitc::CmpPredicate::ne, operand, operand);
+    return selfCompare;
+  }
+
+  /// Build an i1 value that is true exactly when \p operand is not NaN.
+  Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
+                 Value operand) const {
+    // Every value other than NaN satisfies `x == x`, so compare the operand
+    // against itself.
+    auto boolType = rewriter.getI1Type();
+    auto selfCompare = rewriter.create<emitc::CmpOp>(
+        loc, boolType, emitc::CmpPredicate::eq, operand, operand);
+    return selfCompare;
+  }
+
+  /// Return a value that is true iff the operands \p first and \p second are
----------------
marbre wrote:

```suggestion
  /// Return a value that is true if the operands \p first and \p second are
```

https://github.com/llvm/llvm-project/pull/93671


More information about the Mlir-commits mailing list