[Mlir-commits] [mlir] [mlir][spirv] Add conversions for Arith's `maxnumf` and `minnumf` (PR #66696)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Sep 18 14:54:22 PDT 2023
================
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
}
};
+//===----------------------------------------------------------------------===//
+// MinNumFOp, MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
+/// spirv.CL.fmax/fmin.
+template <typename Op, typename SPIRVOp>
+class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
+ template <typename TargetOp>
+ constexpr bool shouldInsertNanGuards() const {
+ return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
+ }
+
+public:
+ using OpConversionPattern<Op>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+ Type dstType = converter->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // arith.maxnumf/minnumf:
+ // "If one of the arguments is NaN, then the result is the other
+ // argument."
+ // spirv.GL.FMax/FMin
+ // "which operand is the result is undefined if one of the operands
+ // is a NaN."
+ // spirv.CL.fmax/fmin:
+ // "If one argument is a NaN, Fmin returns the other argument."
+
+ Location loc = op.getLoc();
+ Value spirvOp =
+ rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+ if (!shouldInsertNanGuards<SPIRVOp>() ||
+ converter->getOptions().enableFastMathMode) {
+ rewriter.replaceOp(op, spirvOp);
+ return success();
+ }
+
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+ Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+ adaptor.getRhs(), spirvOp);
----------------
kuhar wrote:
Ah, no: the difference is that `arith.maxnumf`/`arith.minnumf` guarantee a NaN result when both inputs are NaN, whereas with `spirv.GL.FMax`/`spirv.GL.FMin` the result would be undefined. So the current lowering, which inserts explicit NaN guards, is the way to go.
https://github.com/llvm/llvm-project/pull/66696
More information about the Mlir-commits
mailing list