[Mlir-commits] [mlir] [mlir][spirv] Add conversions for Arith's `maxnumf` and `minnumf` (PR #66696)

Daniil Dudkin llvmlistbot at llvm.org
Mon Sep 18 13:55:48 PDT 2023


https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/66696

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

In this commit, we add conversion patterns for the newly introduced operations `arith.minnumf` and `arith.maxnumf`. When converting to `spirv.CL`, no additional guards are needed to propagate the non-NaN value when one of the arguments is NaN, because the `CL` ops already return the other argument in that case. However, the `GL` ops leave it undefined which operand is returned when one of the arguments is NaN, so we insert additional guards to enforce the semantics of the Arith ops (unless fast-math mode is enabled).

This patch addresses task 1.5 of the RFC mentioned above.


>From 1cbe674701e64bede0e796ec06d043f2686ee821 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Mon, 18 Sep 2023 23:55:05 +0300
Subject: [PATCH] [mlir][spirv] Add conversions for Arith's `maxnumf` and
 `minnumf`

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

In this commit, we add conversion patterns for the newly introduced operations `arith.minnumf` and `arith.maxnumf`. When converting to `spirv.CL`, no additional guards are needed to propagate the non-NaN value when one of the arguments is NaN, because the `CL` ops already return the other argument in that case. However, the `GL` ops leave it undefined which operand is returned when one of the arguments is NaN, so we insert additional guards to enforce the semantics of the Arith ops (unless fast-math mode is enabled).

This patch addresses task 1.5 of the RFC mentioned above.
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 60 ++++++++++++++++++
 .../ArithToSPIRV/arith-to-spirv.mlir          | 61 ++++++++++++++++---
 .../Conversion/ArithToSPIRV/fast-math.mlir    | 26 ++++++--
 3 files changed, 135 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a589fb8050f34db..aba6a21deccb0cf 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp, MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or spirv.CL.fmax/fmin;
+/// for GL targets, NaN guards are added unless fast-math mode is enabled.
+template <typename Op, typename SPIRVOp>
+class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
+  template <typename TargetOp>
+  constexpr bool shouldInsertNanGuards() const {
+    return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
+  }
+
+public:
+  using OpConversionPattern<Op>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = converter->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // arith.maxnumf/minnumf:
+    //   "If one of the arguments is NaN, then the result is the other
+    //   argument."
+    // spirv.GL.FMax/FMin:
+    //   "which operand is the result is undefined if one of the operands
+    //   is a NaN."
+    // spirv.CL.fmax/fmin:
+    //   "If one argument is a NaN, Fmin returns the other argument."
+    // Hence only the GL ops need explicit guards to match arith semantics.
+    Location loc = op.getLoc();
+    Value spirvOp =
+        rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+    // CL targets, or fast-math mode: no guards, use the op's result as is.
+    if (!shouldInsertNanGuards<SPIRVOp>() ||
+        converter->getOptions().enableFastMathMode) {
+      rewriter.replaceOp(op, spirvOp);
+      return success();
+    }
+    // GL targets: replace the FMax/FMin result whenever an operand is NaN.
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+    // lhs NaN -> rhs; rhs NaN -> lhs (both NaN yields lhs, which is NaN).
+    Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+                                                     adaptor.getRhs(), spirvOp);
+    Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
+                                                     adaptor.getLhs(), select1);
+
+    rewriter.replaceOp(op, select2);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1138,6 +1194,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
 
     MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
+    MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
+    MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
@@ -1145,6 +1203,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
 
     MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
+    MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
+    MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 165877eb554e245..0221e4815a9397d 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
   return
 }
 
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   return %0: f32
 }
 
-// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[MIN:.+]] = spirv.CL.fmin %[[LHS]], %[[RHS]] : f32
+  %0 = arith.minnumf %arg0, %arg1 : f32 // CL fmin needs no NaN guards
+  // CHECK: return %[[MIN]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
   // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
   return %0: vector<2xf32>
 }
 
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[MAX:.+]] = spirv.CL.fmax %[[LHS]], %[[RHS]] : vector<2xf32>
+  %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32> // CL fmax needs no NaN guards
+  // CHECK: return %[[MAX]]
+  return %0: vector<2xf32>
+}
+
+
 // CHECK-LABEL: @scalar_srem
 // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
 func.func @scalar_srem(%lhs: i32, %rhs: i32) {
@@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
   return
 }
 
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   return %0: f32
 }
 
-// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[MIN:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]] : f32
+  // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
+  // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
+  // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]]
+  // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+  %0 = arith.minnumf %arg0, %arg1 : f32 // GL FMin needs explicit NaN guards
+  // CHECK: return %[[SELECT2]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
   // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
   return %0: vector<2xf32>
 }
 
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[MAX:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]] : vector<2xf32>
+  // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
+  // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
+  // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]]
+  // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+  %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32> // GL FMax needs explicit NaN guards
+  // CHECK: return %[[SELECT2]]
+  return %0: vector<2xf32>
+}
+
 // Check int vector types.
 // CHECK-LABEL: @int_vector234
 func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index 9dea7d6623885e4..dbf0361c2ab35bb 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -30,22 +30,40 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
 } {
 
-// CHECK-LABEL: @minf
+// CHECK-LABEL: @minimumf
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
   %0 = arith.minimumf %arg0, %arg1 : f32
   // CHECK: return %[[F]]
   return %0: f32
 }
 
-// CHECK-LABEL: @maxf
+// CHECK-LABEL: @maximumf
 // CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
-func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
   // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
   %0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
   // CHECK: return %[[F]]
   return %0: vector<4xf32>
 }
 
+// CHECK-LABEL: @minnumf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
+  %0 = arith.minnumf %arg0, %arg1 : f32 // fast-math: bare GL op, no NaN guards
+  // CHECK: return %[[F]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @maxnumf
+// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
+func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+  // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
+  %0 = arith.maxnumf %arg0, %arg1 : vector<4xf32> // fast-math: bare GL op, no NaN guards
+  // CHECK: return %[[F]]
+  return %0: vector<4xf32>
+}
+
 } // end module



More information about the Mlir-commits mailing list