[Mlir-commits] [mlir] [mlir][spirv] Add conversions for Arith's `maxnumf` and `minnumf` (PR #66696)
llvmlistbot at llvm.org
Mon Sep 18 13:57:09 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
In this commit, we add conversion patterns for the newly introduced operations `arith.minnumf` and `arith.maxnumf`. When converting to `spirv.CL`, no additional guards are needed to propagate the non-NaN value when one of the arguments is NaN, because the `CL` ops already have exactly these semantics. However, the result of the `GL` ops is undefined when one of the arguments is NaN, so we insert additional guards (`spirv.IsNan` + `spirv.Select`) to enforce the semantics of the Arith ops. These guards are omitted when fast-math mode is enabled.
This patch addresses task 1.5 of the RFC mentioned above.
---
Full diff: https://github.com/llvm/llvm-project/pull/66696.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+60)
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+53-8)
- (modified) mlir/test/Conversion/ArithToSPIRV/fast-math.mlir (+22-4)
``````````diff
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a589fb8050f34db..aba6a21deccb0cf 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
}
};
+//===----------------------------------------------------------------------===//
+// MinNumFOp, MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
+/// spirv.CL.fmax/fmin.
+///
+/// arith.maxnumf/minnumf return the other operand when exactly one operand is
+/// NaN. spirv.CL.fmax/fmin already behave that way, so for the CL ops the
+/// target op is the whole lowering. For spirv.GL.FMax/FMin the result is
+/// undefined in that case, so the pattern adds spirv.IsNan + spirv.Select
+/// guards around the target op. The guards are skipped when the type
+/// converter's `enableFastMathMode` option is set.
+///
+/// The emission order (target op, IsNan(lhs), IsNan(rhs), Select, Select) is
+/// matched verbatim by the FileCheck tests, so keep it stable.
+template <typename Op, typename SPIRVOp>
+class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
+ /// Returns true only for spirv.GL.FMax/FMin, the target ops whose result is
+ /// undefined for NaN operands and therefore need explicit NaN guards.
+ template <typename TargetOp>
+ constexpr bool shouldInsertNanGuards() const {
+ return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
+ }
+
+public:
+ using OpConversionPattern<Op>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+ // Fail the match if the result type has no SPIR-V equivalent.
+ Type dstType = converter->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // arith.maxnumf/minnumf:
+ // "If one of the arguments is NaN, then the result is the other
+ // argument."
+ // spirv.GL.FMax/FMin:
+ // "which operand is the result is undefined if one of the operands
+ // is a NaN."
+ // spirv.CL.fmax/fmin:
+ // "If one argument is a NaN, Fmin returns the other argument."
+ // (fmax behaves the same way.)
+
+ // Emit the target op first; it is either the final result or the
+ // fallback value of the NaN guards below.
+ Location loc = op.getLoc();
+ Value spirvOp =
+ rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+ // CL ops already match arith semantics. In fast-math mode the guards are
+ // intentionally omitted. Either way, the bare op is the result.
+ if (!shouldInsertNanGuards<SPIRVOp>() ||
+ converter->getOptions().enableFastMathMode) {
+ rewriter.replaceOp(op, spirvOp);
+ return success();
+ }
+
+ // GL ops: replace the undefined NaN cases with arith semantics.
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+ // lhs NaN -> rhs; then rhs NaN -> lhs. If both are NaN, the outer select
+ // yields lhs, which is NaN.
+ Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+ adaptor.getRhs(), spirvOp);
+ Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
+ adaptor.getLhs(), select1);
+
+ rewriter.replaceOp(op, select2);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1138,6 +1194,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
+ MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
+ MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
@@ -1145,6 +1203,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
+ MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
+ MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 165877eb554e245..0221e4815a9397d 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
return
}
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
return %0: f32
}
-// CHECK-LABEL: @float32_maxf_scalar
+// arith.minnumf is expected to lower to a bare spirv.CL.fmin with no NaN
+// guards: OpenCL fmin already returns the non-NaN operand when exactly one
+// operand is NaN.
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+ // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
+ %0 = arith.minnumf %arg0, %arg1 : f32
+ // CHECK: return %[[MIN]]
+ return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
return %0: vector<2xf32>
}
+// arith.maxnumf is expected to lower to a bare spirv.CL.fmax with no NaN
+// guards. Despite the _scalar suffix this exercises vector<2xf32>, following
+// the naming of the neighbouring @float32_maximumf_scalar test.
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+ // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
+ %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
+ // CHECK: return %[[MAX]]
+ return %0: vector<2xf32>
+}
+
+
// CHECK-LABEL: @scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @scalar_srem(%lhs: i32, %rhs: i32) {
@@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
return
}
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
return %0: f32
}
-// CHECK-LABEL: @float32_maxf_scalar
+// spirv.GL.FMin leaves the result undefined for NaN operands, so the lowering
+// wraps it in guards: the first select yields RHS when LHS is NaN, and the
+// second yields LHS when RHS is NaN.
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+ // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
+ // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
+ // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
+ // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]]
+ // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+ %0 = arith.minnumf %arg0, %arg1 : f32
+ // CHECK: return %[[SELECT2]]
+ return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
return %0: vector<2xf32>
}
+// Vector variant of the GL NaN-guard lowering: spirv.GL.FMax followed by
+// element-wise IsNan/Select guards that pick the non-NaN operand.
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+ // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
+ // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
+ // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
+ // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]]
+ // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+ %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
+ // CHECK: return %[[SELECT2]]
+ return %0: vector<2xf32>
+}
+
// Check int vector types.
// CHECK-LABEL: @int_vector234
func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index 9dea7d6623885e4..dbf0361c2ab35bb 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -30,22 +30,40 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
} {
+// In this fast-math test file, arith.minimumf/maximumf are expected to lower
+// to the bare spirv.GL op with no IsNan/Select guards.
-// CHECK-LABEL: @minf
+// CHECK-LABEL: @minimumf
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
%0 = arith.minimumf %arg0, %arg1 : f32
// CHECK: return %[[F]]
return %0: f32
}
-// CHECK-LABEL: @maxf
+// CHECK-LABEL: @maximumf
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
-func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
%0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xf32>
}
+// In this fast-math test file, arith.minnumf/maxnumf are expected to lower to
+// the bare spirv.GL op: MinNumMaxNumFOpPattern skips its NaN guards when
+// enableFastMathMode is set.
+// CHECK-LABEL: @minnumf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
+ // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
+ %0 = arith.minnumf %arg0, %arg1 : f32
+ // CHECK: return %[[F]]
+ return %0: f32
+}
+
+// CHECK-LABEL: @maxnumf
+// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
+func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+ // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
+ %0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
+ // CHECK: return %[[F]]
+ return %0: vector<4xf32>
+}
+
} // end module
``````````
</details>
https://github.com/llvm/llvm-project/pull/66696
More information about the Mlir-commits mailing list