[Mlir-commits] [mlir] [mlir][spirv] Add conversions for Arith's `maxnumf` and `minnumf` (PR #66696)

Jakub Kuderski llvmlistbot at llvm.org
Mon Sep 18 14:54:22 PDT 2023


================
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp, MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
+/// spirv.CL.fmax/fmin.
+template <typename Op, typename SPIRVOp>
+class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
+  template <typename TargetOp>
+  constexpr bool shouldInsertNanGuards() const {
+    return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
+  }
+
+public:
+  using OpConversionPattern<Op>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = converter->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // arith.maxnumf/minnumf:
+    //   "If one of the arguments is NaN, then the result is the other
+    //   argument."
+    // spirv.GL.FMax/FMin
+    //   "which operand is the result is undefined if one of the operands
+    //   is a NaN."
+    // spirv.CL.fmax/fmin:
+    //   "If one argument is a NaN, Fmin returns the other argument."
+
+    Location loc = op.getLoc();
+    Value spirvOp =
+        rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+    if (!shouldInsertNanGuards<SPIRVOp>() ||
+        converter->getOptions().enableFastMathMode) {
+      rewriter.replaceOp(op, spirvOp);
+      return success();
+    }
+
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+    Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+                                                     adaptor.getRhs(), spirvOp);
----------------
kuhar wrote:

Ah no, the difference is that `arith.maxnumf`/`arith.minnumf` guarantee a NaN output when both inputs are NaN, whereas with `spirv.GL.FMax`/`spirv.GL.FMin` we would get an undefined result. So the current lowering is the way to go.

https://github.com/llvm/llvm-project/pull/66696


More information about the Mlir-commits mailing list