[Mlir-commits] [mlir] 641124a - [mlir][spirv] Add conversions for Arith's `maxnumf` and `minnumf` (#66696)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 19 12:49:53 PDT 2023


Author: Daniil Dudkin
Date: 2023-09-19T22:49:48+03:00
New Revision: 641124a9b98eef668d8bb3bd3b6da5105a17284e

URL: https://github.com/llvm/llvm-project/commit/641124a9b98eef668d8bb3bd3b6da5105a17284e
DIFF: https://github.com/llvm/llvm-project/commit/641124a9b98eef668d8bb3bd3b6da5105a17284e.diff

LOG: [mlir][spirv] Add conversions for Arith's `maxnumf` and `minnumf` (#66696)

This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

In this commit, we add conversion patterns for the newly introduced
operations `arith.minnumf` and `arith.maxnumf`. When converting to
`spirv.CL`, there is no need to insert additional guards to return the
non-NaN operand when one of the arguments is NaN, because the `CL` ops
already have exactly these semantics. However, for the `GL` ops the result
is undefined when one of the arguments is NaN, so we insert additional
guards to enforce the semantics of Arith's ops.

This patch addresses task 1.5 of the RFC linked above.

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
    mlir/test/Conversion/ArithToSPIRV/fast-math.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a589fb8050f34db..aba6a21deccb0cf 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp, MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin (followed by NaN
+/// guards unless fast-math mode is enabled) or to spirv.CL.fmax/fmin.
+template <typename Op, typename SPIRVOp>
+class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
+  template <typename TargetOp>
+  constexpr bool shouldInsertNanGuards() const {
+    return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
+  }
+
+public:
+  using OpConversionPattern<Op>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = converter->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // arith.maxnumf/minnumf:
+    //   "If one of the arguments is NaN, then the result is the other
+    //   argument."
+    // spirv.GL.FMax/FMin
+    //   "which operand is the result is undefined if one of the operands
+    //   is a NaN."
+    // spirv.CL.fmax/fmin:
+    //   "If one argument is a NaN, Fmin returns the other argument."
+
+    Location loc = op.getLoc();
+    Value spirvOp =
+        rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+    if (!shouldInsertNanGuards<SPIRVOp>() ||
+        converter->getOptions().enableFastMathMode) {
+      rewriter.replaceOp(op, spirvOp);
+      return success();
+    }
+    // GL ops only: explicitly select the non-NaN operand to match Arith.
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+    Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+                                                     adaptor.getRhs(), spirvOp);
+    Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
+                                                     adaptor.getLhs(), select1);
+
+    rewriter.replaceOp(op, select2);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1138,6 +1194,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
 
     MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
+    MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
+    MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
@@ -1145,6 +1203,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
 
     MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
+    MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
+    MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 165877eb554e245..0221e4815a9397d 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
   return
 }
 
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   return %0: f32
 }
 
-// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
+  %0 = arith.minnumf %arg0, %arg1 : f32
+  // CHECK: return %[[MIN]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
   // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
   return %0: vector<2xf32>
 }
 
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
+  %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
+  // CHECK: return %[[MAX]]
+  return %0: vector<2xf32>
+}
+
+
 // CHECK-LABEL: @scalar_srem
 // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
 func.func @scalar_srem(%lhs: i32, %rhs: i32) {
@@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
   return
 }
 
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   return %0: f32
 }
 
-// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
+  // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
+  // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
+  // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]]
+  // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+  %0 = arith.minnumf %arg0, %arg1 : f32
+  // CHECK: return %[[SELECT2]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
   // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
   return %0: vector<2xf32>
 }
 
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
+  // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
+  // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
+  // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]]
+  // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+  %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
+  // CHECK: return %[[SELECT2]]
+  return %0: vector<2xf32>
+}
+
 // Check int vector types.
 // CHECK-LABEL: @int_vector234
 func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {

diff  --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index 9dea7d6623885e4..dbf0361c2ab35bb 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -30,22 +30,40 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
 } {
 
-// CHECK-LABEL: @minf
+// CHECK-LABEL: @minimumf
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
   %0 = arith.minimumf %arg0, %arg1 : f32
   // CHECK: return %[[F]]
   return %0: f32
 }
 
-// CHECK-LABEL: @maxf
+// CHECK-LABEL: @maximumf
 // CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
-func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
   // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
   %0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
   // CHECK: return %[[F]]
   return %0: vector<4xf32>
 }
 
+// CHECK-LABEL: @minnumf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
+  %0 = arith.minnumf %arg0, %arg1 : f32
+  // CHECK: return %[[F]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @maxnumf
+// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
+func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+  // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
+  %0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
+  // CHECK: return %[[F]]
+  return %0: vector<4xf32>
+}
+
 } // end module


        


More information about the Mlir-commits mailing list