[Mlir-commits] [mlir] [mlir][arith] Add LLVM lowering to `maxnum`, `minnum` ops (PR #66431)
Daniil Dudkin
llvmlistbot at llvm.org
Thu Sep 14 14:03:31 PDT 2023
https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/66431:
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
This commit addresses task 1.4 of the RFC by lowering the `maxnumf` and `minnumf` ops to the corresponding LLVM intrinsics (`llvm.maxnum` and `llvm.minnum`).
Please **note**: this PR is part of a stack of patches and depends on #66429.
>From d11c9a2c186dc71037d8fa82c1673f6ea9d0d420 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Tue, 12 Sep 2023 22:40:39 +0300
Subject: [PATCH 1/2] [mlir][arith] Introduce `minnumf` and `maxnumf`
operations
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
Here we introduce new operations for floating-point numbers: `minnumf` and `maxnumf`.
These operations have different semantics than the `minimumf` and `maximumf` ops.
They follow the semantics of the eponymous LLVM intrinsics (`llvm.minnum` and `llvm.maxnum`), which differ
from `minimumf`/`maximumf` in their handling of positive and negative zeros and of NaNs.
This patch addresses task 1.3 of the RFC.
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 55 +++++++++++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 45 +++++++++++++--
mlir/test/Dialect/Arith/canonicalize.mlir | 38 +++++++++++--
mlir/test/Dialect/Arith/ops.mlir | 18 ++++--
4 files changed, 142 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 07708cf2d78a964..58e5385bf3ff268 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -857,6 +857,34 @@ def Arith_MaximumFOp : Arith_FloatBinaryOp<"maximumf", [Commutative]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
+ let summary = "floating-point maximum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.maxnumf` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the maximum of the two arguments (cf. the `llvm.maxnum` intrinsic).
+ If the arguments are -0.0 and +0.0, then the result is either of them.
+ If one argument is NaN, the result is the other; if both are NaN, it is NaN.
+ Signed zeros and NaNs are handled differently than by `arith.maximumf`.
+
+ Example:
+
+ ```mlir
+ // Scalar floating-point maximum.
+ %a = arith.maxnumf %b, %c : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
@@ -901,6 +929,33 @@ def Arith_MinimumFOp : Arith_FloatBinaryOp<"minimumf", [Commutative]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
+ let summary = "floating-point minimum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.minnumf` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the minimum of the two arguments (cf. the `llvm.minnum` intrinsic).
+ If the arguments are -0.0 and +0.0, then the result is either of them.
+ If one argument is NaN, the result is the other; if both are NaN, it is NaN.
+
+ Example:
+
+ ```mlir
+ // Scalar floating-point minimum.
+ %a = arith.minnumf %b, %c : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1e34ac598860f52..d39c5b6051122e4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -927,11 +927,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
- // maxf(x,x) -> x
+ // maximumf(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
- // maxf(x, -inf) -> x
+ // maximumf(x, -inf) -> x
if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
return getLhs();
@@ -940,6 +940,25 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
+ // maxnumf(x,x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+ // maxnumf(x, -inf) -> x. NOTE(review): if x is NaN at runtime, maxnum
+ // yields -inf rather than x, so this fold may not be NaN-safe; confirm.
+ if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+ return getLhs();
+ // Otherwise constant-fold with llvm::maxnum (not llvm::maximum) so the
+ // result follows this op's NaN semantics, as MinNumFOp does with minnum.
+ return constFoldBinaryOp<FloatAttr>(
+ adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); });
+}
+
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
@@ -995,11 +1014,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
- // minf(x,x) -> x
+ // minimumf(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
- // minf(x, +inf) -> x
+ // minimumf(x, +inf) -> x
if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
return getLhs();
@@ -1008,6 +1027,24 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
[](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
}
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
+ // minnumf(x,x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+ // minnumf(x, +inf) -> x. NOTE(review): if x is NaN at runtime, minnum
+ // yields +inf rather than x, so this fold may not be NaN-safe; confirm.
+ if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+ return getLhs();
+ // Otherwise try to constant-fold with llvm::minnum (the op's NaN semantics).
+ return constFoldBinaryOp<FloatAttr>(
+ adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
+}
+
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5c93be887107bb6..84096354e6afe33 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1635,8 +1635,8 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
-// CHECK-LABEL: @test_minf(
-func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_minimumf(
+func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
@@ -1650,8 +1650,8 @@ func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
// -----
-// CHECK-LABEL: @test_maxf(
-func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_maximumf(
+func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
@@ -1665,6 +1665,36 @@ func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
// -----
+// CHECK-LABEL: @test_minnumf(
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
+ // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ %pos_inf = arith.constant 0x7F800000 : f32
+ %zero = arith.constant 0.0 : f32
+ %min_zero = arith.minnumf %zero, %arg0 : f32
+ %min_self = arith.minnumf %arg0, %arg0 : f32
+ %min_inf = arith.minnumf %pos_inf, %arg0 : f32
+ return %min_zero, %min_self, %min_inf : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_maxnumf(
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
+ // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ %c0 = arith.constant 0.0 : f32
+ %neg_inf = arith.constant 0xFF800000 : f32
+ %0 = arith.maxnumf %c0, %arg0 : f32
+ %1 = arith.maxnumf %arg0, %arg0 : f32
+ %2 = arith.maxnumf %neg_inf, %arg0 : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
// CHECK-LABEL: @test_addf(
func.func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2.0
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 5b5618bb03676bf..88cc0072c7c5704 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1071,9 +1071,12 @@ func.func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
%f1: f32, %f2: f32,
%i1: i32, %i2: i32) {
- %max_vector = arith.maximumf %v1, %v2 : vector<4xf32>
- %max_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
- %max_float = arith.maximumf %f1, %f2 : f32
+ %maximum_vector = arith.maximumf %v1, %v2 : vector<4xf32>
+ %maximum_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
+ %maximum_float = arith.maximumf %f1, %f2 : f32
+ %maxnum_vector = arith.maxnumf %v1, %v2 : vector<4xf32>
+ %maxnum_scalable_vector = arith.maxnumf %sv1, %sv2 : vector<[4]xf32>
+ %maxnum_float = arith.maxnumf %f1, %f2 : f32
%max_signed = arith.maxsi %i1, %i2 : i32
%max_unsigned = arith.maxui %i1, %i2 : i32
return
@@ -1084,9 +1087,12 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
%f1: f32, %f2: f32,
%i1: i32, %i2: i32) {
- %min_vector = arith.minimumf %v1, %v2 : vector<4xf32>
- %min_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
- %min_float = arith.minimumf %f1, %f2 : f32
+ %minimum_vector = arith.minimumf %v1, %v2 : vector<4xf32>
+ %minimum_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
+ %minimum_float = arith.minimumf %f1, %f2 : f32
+ %minnum_vector = arith.minnumf %v1, %v2 : vector<4xf32>
+ %minnum_scalable_vector = arith.minnumf %sv1, %sv2 : vector<[4]xf32>
+ %minnum_float = arith.minnumf %f1, %f2 : f32
%min_signed = arith.minsi %i1, %i2 : i32
%min_unsigned = arith.minui %i1, %i2 : i32
return
>From e2188fe2cb085dd0ff8374d9ed235d55a970898a Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Fri, 15 Sep 2023 00:01:30 +0300
Subject: [PATCH 2/2] [mlir][arith] Add LLVM lowering to `maxnum`, `minnum` ops
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
This commit addresses task 1.4 of the RFC by lowering the `maxnumf` and `minnumf` ops to the corresponding LLVM intrinsics (`llvm.maxnum` and `llvm.minnum`).
Please **note**: this PR is part of a stack of patches and depends on #66429.
---
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 8 ++++++++
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir | 4 ++++
2 files changed, 12 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index a695441fd8dd750..337f2dbcbe4edf5 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -57,6 +57,9 @@ using FPToUIOpLowering =
using MaximumFOpLowering =
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
arith::AttrConvertFastMathToLLVM>;
+// arith.maxnumf -> LLVM::MaxNumOp, with arith::AttrConvertFastMathToLLVM.
+using MaxNumFOpLowering = VectorConvertToLLVMPattern<
+ arith::MaxNumFOp, LLVM::MaxNumOp, arith::AttrConvertFastMathToLLVM>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
@@ -64,6 +67,9 @@ using MaxUIOpLowering =
using MinimumFOpLowering =
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
arith::AttrConvertFastMathToLLVM>;
+// arith.minnumf -> LLVM::MinNumOp, with arith::AttrConvertFastMathToLLVM.
+using MinNumFOpLowering = VectorConvertToLLVMPattern<
+ arith::MinNumFOp, LLVM::MinNumOp, arith::AttrConvertFastMathToLLVM>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
@@ -496,9 +502,11 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
IndexCastOpSILowering,
IndexCastOpUILowering,
MaximumFOpLowering,
+ MaxNumFOpLowering,
MaxSIOpLowering,
MaxUIOpLowering,
MinimumFOpLowering,
+ MinNumFOpLowering,
MinSIOpLowering,
MinUIOpLowering,
MulFOpLowering,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 5855f7b3b9904fd..6f614b113788c7e 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -526,6 +526,10 @@ func.func @minmaxf(%arg0 : f32, %arg1 : f32) -> f32 {
%0 = arith.minimumf %arg0, %arg1 : f32
// CHECK: = llvm.intr.maximum(%arg0, %arg1) : (f32, f32) -> f32
%1 = arith.maximumf %arg0, %arg1 : f32
+ // CHECK: = llvm.intr.minnum(%arg0, %arg1) : (f32, f32) -> f32
+ %2 = arith.minnumf %arg0, %arg1 : f32
+ // CHECK: = llvm.intr.maxnum(%arg0, %arg1) : (f32, f32) -> f32
+ %3 = arith.maxnumf %arg0, %arg1 : f32
return %0 : f32
}
More information about the Mlir-commits
mailing list