[Mlir-commits] [mlir] [mlir][arith] Introduce `minnumf` and `maxnumf` operations (PR #66429)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 14 13:49:24 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
Here we introduce new operations for floating-point numbers: `minnumf` and `maxnumf`.
These operations have different semantics than the `minimumf` and `maximumf` ops.
They follow the semantics of the eponymous LLVM intrinsics, which differ
in the handling of positive and negative zeros and NaNs.
This patch addresses the 1.3 task from the RFC.
--
Full diff: https://github.com/llvm/llvm-project/pull/66429.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+55)
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+41-4)
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+34-4)
- (modified) mlir/test/Dialect/Arith/ops.mlir (+12-6)
<pre>
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 07708cf2d78a964..58e5385bf3ff268 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -857,6 +857,34 @@ def Arith_MaximumFOp : Arith_FloatBinaryOp<"maximumf", [Commutative]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
+ let summary = "floating-point maximum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.maxnumf` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the maximum of the two arguments.
+ If the arguments are -0.0 and +0.0, then the result is either of them.
+ If one of the arguments is NaN, then the result is the other argument.
+
+ Example:
+
+ ```mlir
+ // Scalar floating-point maximum.
+ %a = arith.maxnumf %b, %c : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
+
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
@@ -901,6 +929,33 @@ def Arith_MinimumFOp : Arith_FloatBinaryOp<"minimumf", [Commutative]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
+ let summary = "floating-point minimum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.minnumf` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the minimum of the two arguments.
+ If the arguments are -0.0 and +0.0, then the result is either of them.
+ If one of the arguments is NaN, then the result is the other argument.
+
+ Example:
+
+ ```mlir
+ // Scalar floating-point minimum.
+ %a = arith.minnumf %b, %c : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1e34ac598860f52..d39c5b6051122e4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -927,11 +927,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
- // maxf(x,x) -> x
+ // maximumf(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
- // maxf(x, -inf) -> x
+ // maximumf(x, -inf) -> x
if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
return getLhs();
@@ -940,6 +940,25 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
+ // maxnumf(x,x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+
+ // maxnumf(x, -inf) -> x (NOTE(review): wrong if x is NaN; maxnum gives -inf)
+ if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+ return getLhs();
+
+ return constFoldBinaryOp<FloatAttr>(
+ adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); });
+}
+
+
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
@@ -995,11 +1014,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
- // minf(x,x) -> x
+ // minimumf(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
- // minf(x, +inf) -> x
+ // minimumf(x, +inf) -> x
if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
return getLhs();
@@ -1008,6 +1027,24 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
[](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
}
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
+ // minnumf(x,x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+
+ // minnumf(x, +inf) -> x (NOTE(review): wrong if x is NaN; minnum gives +inf)
+ if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+ return getLhs();
+
+ return constFoldBinaryOp<FloatAttr>(
+ adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
+}
+
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5c93be887107bb6..84096354e6afe33 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1635,8 +1635,8 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
-// CHECK-LABEL: @test_minf(
-func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_minimumf(
+func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
@@ -1650,8 +1650,8 @@ func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
// -----
-// CHECK-LABEL: @test_maxf(
-func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_maximumf(
+func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
@@ -1665,6 +1665,36 @@ func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
// -----
+// CHECK-LABEL: @test_minnumf(
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
+ // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ %c0 = arith.constant 0.0 : f32
+ %inf = arith.constant 0x7F800000 : f32
+ %0 = arith.minnumf %c0, %arg0 : f32
+ %1 = arith.minnumf %arg0, %arg0 : f32
+ %2 = arith.minnumf %inf, %arg0 : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_maxnumf(
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant
+ // CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
+ // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ %c0 = arith.constant 0.0 : f32
+ %-inf = arith.constant 0xFF800000 : f32
+ %0 = arith.maxnumf %c0, %arg0 : f32
+ %1 = arith.maxnumf %arg0, %arg0 : f32
+ %2 = arith.maxnumf %-inf, %arg0 : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
// CHECK-LABEL: @test_addf(
func.func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2.0
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 5b5618bb03676bf..88cc0072c7c5704 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1071,9 +1071,12 @@ func.func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
%f1: f32, %f2: f32,
%i1: i32, %i2: i32) {
- %max_vector = arith.maximumf %v1, %v2 : vector<4xf32>
- %max_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
- %max_float = arith.maximumf %f1, %f2 : f32
+ %maximum_vector = arith.maximumf %v1, %v2 : vector<4xf32>
+ %maximum_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
+ %maximum_float = arith.maximumf %f1, %f2 : f32
+ %maxnum_vector = arith.maxnumf %v1, %v2 : vector<4xf32>
+ %maxnum_scalable_vector = arith.maxnumf %sv1, %sv2 : vector<[4]xf32>
+ %maxnum_float = arith.maxnumf %f1, %f2 : f32
%max_signed = arith.maxsi %i1, %i2 : i32
%max_unsigned = arith.maxui %i1, %i2 : i32
return
@@ -1084,9 +1087,12 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
%f1: f32, %f2: f32,
%i1: i32, %i2: i32) {
- %min_vector = arith.minimumf %v1, %v2 : vector<4xf32>
- %min_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
- %min_float = arith.minimumf %f1, %f2 : f32
+ %minimum_vector = arith.minimumf %v1, %v2 : vector<4xf32>
+ %minimum_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
+ %minimum_float = arith.minimumf %f1, %f2 : f32
+ %minnum_vector = arith.minnumf %v1, %v2 : vector<4xf32>
+ %minnum_scalable_vector = arith.minnumf %sv1, %sv2 : vector<[4]xf32>
+ %minnum_float = arith.minnumf %f1, %f2 : f32
%min_signed = arith.minsi %i1, %i2 : i32
%min_unsigned = arith.minui %i1, %i2 : i32
return
</pre>
</details>
https://github.com/llvm/llvm-project/pull/66429
More information about the Mlir-commits
mailing list