[Mlir-commits] [mlir] [mlir][arith] Introduce `minnumf` and `maxnumf` operations (PR #66429)

Daniil Dudkin llvmlistbot at llvm.org
Thu Sep 14 13:48:21 PDT 2023


https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/66429:

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Here we introduce new operations for floating-point numbers: `minnumf` and `maxnumf`.
These operations have different semantics than the `minimumf` and `maximumf` ops.
They follow the semantics of the eponymous LLVM intrinsics, which differ
in the handling of positive and negative zeros and NaNs.

This patch addresses the 1.3 task from the RFC.


>From d11c9a2c186dc71037d8fa82c1673f6ea9d0d420 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Tue, 12 Sep 2023 22:40:39 +0300
Subject: [PATCH] [mlir][arith] Introduce `minnumf` and `maxnumf` operations

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Here we introduce new operations for floating-point numbers: `minnumf` and `maxnumf`.
These operations have different semantics than the `minimumf` and `maximumf` ops.
They follow the semantics of the eponymous LLVM intrinsics, which differ
in the handling of positive and negative zeros and NaNs.

This patch addresses the 1.3 task from the RFC.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 55 +++++++++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 45 +++++++++++++--
 mlir/test/Dialect/Arith/canonicalize.mlir     | 38 +++++++++++--
 mlir/test/Dialect/Arith/ops.mlir              | 18 ++++--
 4 files changed, 142 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 07708cf2d78a964..58e5385bf3ff268 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -857,6 +857,34 @@ def Arith_MaximumFOp : Arith_FloatBinaryOp<"maximumf", [Commutative]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
+  let summary = "floating-point maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arith.maxnumf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the maximum of the two arguments.
+    If the arguments are -0.0 and +0.0, then the result is either of them.
+    If one of the arguments is NaN, then the result is the other argument.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point maximum.
+    %a = arith.maxnumf %b, %c : f64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
+
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -901,6 +929,33 @@ def Arith_MinimumFOp : Arith_FloatBinaryOp<"minimumf", [Commutative]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
+  let summary = "floating-point minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arith.minnumf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the minimum of the two arguments.
+    If the arguments are -0.0 and +0.0, then the result is either of them.
+    If one of the arguments is NaN, then the result is the other argument.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point minimum.
+    %a = arith.minnumf %b, %c : f64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1e34ac598860f52..d39c5b6051122e4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -927,11 +927,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
-  // maxf(x,x) -> x
+  // maximumf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
 
-  // maxf(x, -inf) -> x
+  // maximumf(x, -inf) -> x
   if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
     return getLhs();
 
@@ -940,6 +940,25 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
 }
 
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
+  // maxnumf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // maxnumf(x, -inf) -> x; NOTE(review): x = NaN should give -inf, confirm.
+  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+    return getLhs();
+
+  return constFoldBinaryOp<FloatAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); });
+}
+
+
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -995,11 +1014,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
-  // minf(x,x) -> x
+  // minimumf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
 
-  // minf(x, +inf) -> x
+  // minimumf(x, +inf) -> x
   if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
     return getLhs();
 
@@ -1008,6 +1027,24 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
 }
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
+  // minnumf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // minnumf(x, +inf) -> x; NOTE(review): x = NaN should give +inf, confirm.
+  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+    return getLhs();
+
+  return constFoldBinaryOp<FloatAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
+}
+
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5c93be887107bb6..84096354e6afe33 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1635,8 +1635,8 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
-// CHECK-LABEL: @test_minf(
-func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_minimumf(
+func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
   // CHECK-NEXT:  %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
@@ -1650,8 +1650,8 @@ func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
 
 // -----
 
-// CHECK-LABEL: @test_maxf(
-func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_maximumf(
+func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
@@ -1665,6 +1665,36 @@ func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
 
 // -----
 
+// CHECK-LABEL: @test_minnumf(
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
+  // CHECK-NEXT:  return %[[X]], %arg0, %arg0
+  %c0 = arith.constant 0.0 : f32
+  %inf = arith.constant 0x7F800000 : f32
+  %0 = arith.minnumf %c0, %arg0 : f32
+  %1 = arith.minnumf %arg0, %arg0 : f32
+  %2 = arith.minnumf %inf, %arg0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_maxnumf(
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant
+  // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %arg0
+  %c0 = arith.constant 0.0 : f32
+  %-inf = arith.constant 0xFF800000 : f32
+  %0 = arith.maxnumf %c0, %arg0 : f32
+  %1 = arith.maxnumf %arg0, %arg0 : f32
+  %2 = arith.maxnumf %-inf, %arg0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
 // CHECK-LABEL: @test_addf(
 func.func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
   // CHECK-DAG:   %[[C2:.+]] = arith.constant 2.0
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 5b5618bb03676bf..88cc0072c7c5704 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1071,9 +1071,12 @@ func.func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
                %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
                %f1: f32, %f2: f32,
                %i1: i32, %i2: i32) {
-  %max_vector = arith.maximumf %v1, %v2 : vector<4xf32>
-  %max_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
-  %max_float = arith.maximumf %f1, %f2 : f32
+  %maximum_vector = arith.maximumf %v1, %v2 : vector<4xf32>
+  %maximum_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
+  %maximum_float = arith.maximumf %f1, %f2 : f32
+  %maxnum_vector = arith.maxnumf %v1, %v2 : vector<4xf32>
+  %maxnum_scalable_vector = arith.maxnumf %sv1, %sv2 : vector<[4]xf32>
+  %maxnum_float = arith.maxnumf %f1, %f2 : f32
   %max_signed = arith.maxsi %i1, %i2 : i32
   %max_unsigned = arith.maxui %i1, %i2 : i32
   return
@@ -1084,9 +1087,12 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
                %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
                %f1: f32, %f2: f32,
                %i1: i32, %i2: i32) {
-  %min_vector = arith.minimumf %v1, %v2 : vector<4xf32>
-  %min_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
-  %min_float = arith.minimumf %f1, %f2 : f32
+  %minimum_vector = arith.minimumf %v1, %v2 : vector<4xf32>
+  %minimum_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
+  %minimum_float = arith.minimumf %f1, %f2 : f32
+  %minnum_vector = arith.minnumf %v1, %v2 : vector<4xf32>
+  %minnum_scalable_vector = arith.minnumf %sv1, %sv2 : vector<[4]xf32>
+  %minnum_float = arith.minnumf %f1, %f2 : f32
   %min_signed = arith.minsi %i1, %i2 : i32
   %min_unsigned = arith.minui %i1, %i2 : i32
   return



More information about the Mlir-commits mailing list