[Mlir-commits] [mlir] [mlir][arith] Add LLVM lowering to `maxnum`, `minnum` ops (PR #66431)

Daniil Dudkin llvmlistbot at llvm.org
Thu Sep 14 14:03:31 PDT 2023


https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/66431:

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
    
The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics.
    
Please **note**: this PR is part of a stack of patches and depends on #66429.

>From d11c9a2c186dc71037d8fa82c1673f6ea9d0d420 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Tue, 12 Sep 2023 22:40:39 +0300
Subject: [PATCH 1/2] [mlir][arith] Introduce `minnumf` and `maxnumf`
 operations

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Here we introduce new operations for floating-point numbers: `minnum` and `maxnum`.
These operations have different semantics than `minumumf` and `maximumf` ops.
They follow the eponymous LLVM intrinsics semantics, which differs
in the handling positive and negative zeros and NaNs.

This patch addresses the 1.3 task from the RFC.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 55 +++++++++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 45 +++++++++++++--
 mlir/test/Dialect/Arith/canonicalize.mlir     | 38 +++++++++++--
 mlir/test/Dialect/Arith/ops.mlir              | 18 ++++--
 4 files changed, 142 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 07708cf2d78a964..58e5385bf3ff268 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -857,6 +857,34 @@ def Arith_MaximumFOp : Arith_FloatBinaryOp<"maximumf", [Commutative]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
+  let summary = "floating-point maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arith.maxnumf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the maximum of the two arguments.
+    If the arguments are -0.0 and +0.0, then the result is either of them.
+    If one of the arguments is NaN, then the result is the other argument.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point maximum.
+    %a = arith.maxnumf %b, %c : f64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
+
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -901,6 +929,33 @@ def Arith_MinimumFOp : Arith_FloatBinaryOp<"minimumf", [Commutative]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
+  let summary = "floating-point minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arith.minnumf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the minimum of the two arguments.
+    If the arguments are -0.0 and +0.0, then the result is either of them.
+    If one of the arguments is NaN, then the result is the other argument.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point minimum.
+    %a = arith.minnumf %b, %c : f64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1e34ac598860f52..d39c5b6051122e4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -927,11 +927,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
-  // maxf(x,x) -> x
+  // maximumf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
 
-  // maxf(x, -inf) -> x
+  // maximumf(x, -inf) -> x
   if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
     return getLhs();
 
@@ -940,6 +940,25 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
 }
 
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
+  // maxnumf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // maxnumf(x, -inf) -> x
+  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+    return getLhs();
+
+  return constFoldBinaryOp<FloatAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+}
+
+
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -995,11 +1014,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
-  // minf(x,x) -> x
+  // minimumf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
 
-  // minf(x, +inf) -> x
+  // minimumf(x, +inf) -> x
   if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
     return getLhs();
 
@@ -1008,6 +1027,24 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
 }
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
+  // minnumf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // minnumf(x, +inf) -> x
+  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+    return getLhs();
+
+  return constFoldBinaryOp<FloatAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
+}
+
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5c93be887107bb6..84096354e6afe33 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1635,8 +1635,8 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
-// CHECK-LABEL: @test_minf(
-func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_minimumf(
+func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
   // CHECK-NEXT:  %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
@@ -1650,8 +1650,8 @@ func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
 
 // -----
 
-// CHECK-LABEL: @test_maxf(
-func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_maximumf(
+func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
@@ -1665,6 +1665,36 @@ func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
 
 // -----
 
+// CHECK-LABEL: @test_minnumf(
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
+  // CHECK-NEXT:  return %[[X]], %arg0, %arg0
+  %c0 = arith.constant 0.0 : f32
+  %inf = arith.constant 0x7F800000 : f32
+  %0 = arith.minnumf %c0, %arg0 : f32
+  %1 = arith.minnumf %arg0, %arg0 : f32
+  %2 = arith.minnumf %inf, %arg0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_maxnumf(
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant
+  // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %arg0
+  %c0 = arith.constant 0.0 : f32
+  %-inf = arith.constant 0xFF800000 : f32
+  %0 = arith.maxnumf %c0, %arg0 : f32
+  %1 = arith.maxnumf %arg0, %arg0 : f32
+  %2 = arith.maxnumf %-inf, %arg0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
 // CHECK-LABEL: @test_addf(
 func.func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
   // CHECK-DAG:   %[[C2:.+]] = arith.constant 2.0
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 5b5618bb03676bf..88cc0072c7c5704 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1071,9 +1071,12 @@ func.func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
                %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
                %f1: f32, %f2: f32,
                %i1: i32, %i2: i32) {
-  %max_vector = arith.maximumf %v1, %v2 : vector<4xf32>
-  %max_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
-  %max_float = arith.maximumf %f1, %f2 : f32
+  %maximum_vector = arith.maximumf %v1, %v2 : vector<4xf32>
+  %maximum_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
+  %maximum_float = arith.maximumf %f1, %f2 : f32
+  %maxnum_vector = arith.maxnumf %v1, %v2 : vector<4xf32>
+  %maxnum_scalable_vector = arith.maxnumf %sv1, %sv2 : vector<[4]xf32>
+  %maxnum_float = arith.maxnumf %f1, %f2 : f32
   %max_signed = arith.maxsi %i1, %i2 : i32
   %max_unsigned = arith.maxui %i1, %i2 : i32
   return
@@ -1084,9 +1087,12 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
                %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
                %f1: f32, %f2: f32,
                %i1: i32, %i2: i32) {
-  %min_vector = arith.minimumf %v1, %v2 : vector<4xf32>
-  %min_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
-  %min_float = arith.minimumf %f1, %f2 : f32
+  %minimum_vector = arith.minimumf %v1, %v2 : vector<4xf32>
+  %minimum_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
+  %minimum_float = arith.minimumf %f1, %f2 : f32
+  %minnum_vector = arith.minnumf %v1, %v2 : vector<4xf32>
+  %minnum_scalable_vector = arith.minnumf %sv1, %sv2 : vector<[4]xf32>
+  %minnum_float = arith.minnumf %f1, %f2 : f32
   %min_signed = arith.minsi %i1, %i2 : i32
   %min_unsigned = arith.minui %i1, %i2 : i32
   return

>From e2188fe2cb085dd0ff8374d9ed235d55a970898a Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Fri, 15 Sep 2023 00:01:30 +0300
Subject: [PATCH 2/2] [mlir][arith] Add LLVM lowering to `maxnum`, `minnum` ops

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics.

Please **note**: this PR is part of a stack of patches and depends on #66429.
---
 mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp     | 8 ++++++++
 mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir | 4 ++++
 2 files changed, 12 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index a695441fd8dd750..337f2dbcbe4edf5 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -57,6 +57,9 @@ using FPToUIOpLowering =
 using MaximumFOpLowering =
     VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
                                arith::AttrConvertFastMathToLLVM>;
+using MaxNumFOpLowering =
+    VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using MaxSIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
 using MaxUIOpLowering =
@@ -64,6 +67,9 @@ using MaxUIOpLowering =
 using MinimumFOpLowering =
     VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
                                arith::AttrConvertFastMathToLLVM>;
+using MinNumFOpLowering =
+    VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using MinSIOpLowering =
     VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
 using MinUIOpLowering =
@@ -496,9 +502,11 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     IndexCastOpSILowering,
     IndexCastOpUILowering,
     MaximumFOpLowering,
+    MaxNumFOpLowering,
     MaxSIOpLowering,
     MaxUIOpLowering,
     MinimumFOpLowering,
+    MinNumFOpLowering,
     MinSIOpLowering,
     MinUIOpLowering,
     MulFOpLowering,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 5855f7b3b9904fd..6f614b113788c7e 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -526,6 +526,10 @@ func.func @minmaxf(%arg0 : f32, %arg1 : f32) -> f32 {
   %0 = arith.minimumf %arg0, %arg1 : f32
   // CHECK: = llvm.intr.maximum(%arg0, %arg1) : (f32, f32) -> f32
   %1 = arith.maximumf %arg0, %arg1 : f32
+  // CHECK: = llvm.intr.minnum(%arg0, %arg1) : (f32, f32) -> f32
+  %2 = arith.minnumf %arg0, %arg1 : f32
+  // CHECK: = llvm.intr.maxnum(%arg0, %arg1) : (f32, f32) -> f32
+  %3 = arith.maxnumf %arg0, %arg1 : f32
   return %0 : f32
 }
 



More information about the Mlir-commits mailing list