[Mlir-commits] [mlir] [mlir][spirv] Add conversions for Arith's `maxnumf` and `minnumf` (PR #66696)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Sep 18 14:54:21 PDT 2023
================
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
}
};
+//===----------------------------------------------------------------------===//
+// MinNumFOp, MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
+/// spirv.CL.fmax/fmin.
+template <typename Op, typename SPIRVOp>
+class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
+ template <typename TargetOp>
+ constexpr bool shouldInsertNanGuards() const {
+ return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
+ }
+
+public:
+ using OpConversionPattern<Op>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+ Type dstType = converter->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // arith.maxnumf/minnumf:
+ // "If one of the arguments is NaN, then the result is the other
+ // argument."
+ // spirv.GL.FMax/FMin:
+ // "which operand is the result is undefined if one of the operands
+ // is a NaN."
+ // spirv.CL.fmax/fmin:
+ // "If one argument is a NaN, Fmin returns the other argument."
+
+ Location loc = op.getLoc();
+ Value spirvOp =
+ rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+ if (!shouldInsertNanGuards<SPIRVOp>() ||
+ converter->getOptions().enableFastMathMode) {
+ rewriter.replaceOp(op, spirvOp);
+ return success();
+ }
+
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+ Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+ adaptor.getRhs(), spirvOp);
----------------
kuhar wrote:
Could we lower it to something like:
```
%lhsIsNan = spirv.IsNan %lhs
%x = spirv.Select %lhsIsNan, %rhs, %lhs
%rhsIsNan = spirv.IsNan %rhs
%y = spirv.Select %rhsIsNan, %lhs, %rhs
%res = spirv.GL.FMax %x, %y
```
I don't know which of the two lowerings should be preferred; I just wanted to explore the alternatives.
https://github.com/llvm/llvm-project/pull/66696
More information about the Mlir-commits mailing list