[Mlir-commits] [mlir] b6ab04c - [mlir][arith] Fix arith maxnumf/minnumf folder (#114595)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 27 12:06:53 PST 2024


Author: Clément Fournier
Date: 2024-11-27T21:06:49+01:00
New Revision: b6ab04c69c907362dc7ab65eb43a9907c9adcdc1

URL: https://github.com/llvm/llvm-project/commit/b6ab04c69c907362dc7ab65eb43a9907c9adcdc1
DIFF: https://github.com/llvm/llvm-project/commit/b6ab04c69c907362dc7ab65eb43a9907c9adcdc1.diff

LOG: [mlir][arith] Fix arith maxnumf/minnumf folder (#114595)

Fix #114594 
#### Context

[IEEE754-2019](https://ieeexplore.ieee.org/document/8766229) Sec 9.6
defines 2 minimum and 2 maximum operations. They are termed
- `maximum` and `maximumNumber`
- `minimum` and `minimumNumber`

In the arith dialect they are respectively named `maximumf` and
`maxnumf`, `minimumf` and `minnumf` so I use these names.

These operations only differ in how they handle NaN values. For
`maximumf` and `minimumf`, if any operand is NaN, then the result is
NaN, i.e., NaN is propagated. For `maxnumf` and `minnumf`, if one operand
is NaN, then the other operand is returned, i.e., NaN is absorbed. The
following identities hold:
```
maximumf(x, NaN) = maximumf(NaN, x) = NaN
maxnumf(x, NaN) = maxnumf(NaN, x) = x
```
(and same for min).
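
As a rough illustration of these identities outside of MLIR, here is a minimal sketch in standard C++. The helper names `maximumPropagating`, `maxnumAbsorbing` and `checkNaNIdentities` are hypothetical and not part of MLIR or LLVM. The sketch assumes quiet NaNs and ignores the ordering of signed zeros. It relies on `std::fmax`, which returns the non-NaN operand when exactly one operand is NaN.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// NaN-propagating maximum, in the spirit of `maximumf`.
// Hypothetical helper for illustration; signed zeros are not ordered here.
inline double maximumPropagating(double a, double b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<double>::quiet_NaN();
  return a > b ? a : b;
}

// NaN-absorbing maximum, in the spirit of `maxnumf`.
// std::fmax returns the other operand when exactly one operand is NaN.
inline double maxnumAbsorbing(double a, double b) { return std::fmax(a, b); }

// Checks the identities stated above for one sample value of x.
inline void checkNaNIdentities() {
  const double nan = std::numeric_limits<double>::quiet_NaN();
  const double x = 1.5;
  assert(std::isnan(maximumPropagating(x, nan)));
  assert(std::isnan(maximumPropagating(nan, x)));
  assert(maxnumAbsorbing(x, nan) == x);
  assert(maxnumAbsorbing(nan, x) == x);
}
```

Calling `checkNaNIdentities()` from a test driver exercises both behaviours; the min variants can be sketched the same way with `std::fmin`.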

#### Arith folders

In the following I am talking about the folders for the arith
operations. The folders implement the following canonicalizations (`op`
is one of maximumf, maxnumf, minimumf, minnumf):
1. `op(x, x)` folds to `x` 
2. for `op(x, y)`, if `y` folds to the neutral element of the `op`, then
the `op` is folded to `x`.
    1. The neutral element of `maximumf` is -Infty
    2. The neutral element of `minimumf` is +Infty
    3. The neutral element of `maxnumf` and `minnumf` is NaN as shown above.
3. for `op(x, y)`, if both `x` and `y` fold to constants `x'` and `y'`,
then the `op` is folded and the result is calculated with a
corresponding runtime function.
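
The three rules above can be sketched as a toy folder in plain C++. This is only an illustration under stated assumptions: `Operand`, `Folded`, `foldMaxNum` and `checkFoldMaxNum` are hypothetical names, operands are modelled as an integer id plus an optional constant, and `std::fmax` stands in for the runtime function. It is not the MLIR implementation.

```cpp
#include <cassert>
#include <cmath>
#include <optional>

// Hypothetical stand-in for an SSA operand: `id` identifies the value and
// `constant` is set only when the operand folds to a known constant.
struct Operand {
  int id;
  std::optional<double> constant;
};

// Result of a fold attempt: nothing, an existing value, or a new constant.
struct Folded {
  enum class Kind { None, Value, Constant } kind = Kind::None;
  int id = -1;           // meaningful when kind == Kind::Value
  double constant = 0.0; // meaningful when kind == Kind::Constant
};

// Toy version of the three fold rules for a maxnumf-like operation.
inline Folded foldMaxNum(const Operand &lhs, const Operand &rhs) {
  // Rule 1: op(x, x) folds to x.
  if (lhs.id == rhs.id)
    return {Folded::Kind::Value, lhs.id, 0.0};
  // Rule 2: op(x, neutral) folds to x; for maxnumf the neutral element is NaN.
  if (rhs.constant && std::isnan(*rhs.constant))
    return {Folded::Kind::Value, lhs.id, 0.0};
  // Rule 3: both operands are constants, so compute the result directly.
  if (lhs.constant && rhs.constant)
    return {Folded::Kind::Constant, -1, std::fmax(*lhs.constant, *rhs.constant)};
  return {};
}

// Exercises each rule once, plus a case where no rule applies.
inline void checkFoldMaxNum() {
  const Operand x{0, std::nullopt};
  const Operand nan{1, std::nan("")};
  const Operand negInf{2, -HUGE_VAL};
  const Operand three{3, 3.0};
  const Operand five{4, 5.0};
  assert(foldMaxNum(x, x).kind == Folded::Kind::Value);
  assert(foldMaxNum(x, nan).id == 0);
  assert(foldMaxNum(x, negInf).kind == Folded::Kind::None);
  assert(foldMaxNum(three, five).constant == 5.0);
}
```

Note that in this sketch `foldMaxNum(x, negInf)` is deliberately left unfolded, since -Infty is not a neutral element for a NaN-absorbing maximum.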

The folders are properly implemented for `maximumf` and `minimumf`, but
the same implementations were copied for the respective `maxnumf` and
`minnumf` functions. This means the neutral element of the second folder
above is wrong:
- `maxnumf(x, -Infty)` is folded to `x`, but that's wrong, because if
`x` is NaN then -Infty should be the result
- `minnumf(x, +Infty)` is folded to `x`, but same thing, the result
should be +Infty when `x` is NaN.

This is fixed by using `NaN` as neutral element for the `maxnumf` and
`minnumf` ops.[^1]
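
To make the counterexample concrete, the following sketch checks candidate neutral elements against a handful of probe values, using `std::fmax` as an assumed stand-in for the NaN-absorbing maximum. The helpers `sameValue`, `isNeutralForMaxNum` and `checkNeutralElements` are hypothetical names introduced for this illustration only.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Equality that also treats two NaNs as the same value.
inline bool sameValue(double a, double b) {
  return (std::isnan(a) && std::isnan(b)) || a == b;
}

// Returns true if fmax(x, e) yields x for every probe value x,
// i.e. if `e` behaves as a neutral element on these probes.
inline bool isNeutralForMaxNum(double e) {
  const double inf = std::numeric_limits<double>::infinity();
  const double nan = std::numeric_limits<double>::quiet_NaN();
  const double probes[] = {-inf, -1.0, 0.0, 2.5, inf, nan};
  for (double x : probes)
    if (!sameValue(std::fmax(x, e), x))
      return false;
  return true;
}

inline void checkNeutralElements() {
  const double inf = std::numeric_limits<double>::infinity();
  const double nan = std::numeric_limits<double>::quiet_NaN();
  // -inf fails on the NaN probe: fmax(NaN, -inf) is -inf, not NaN.
  assert(!isNeutralForMaxNum(-inf));
  // NaN passes on every probe, including the NaN probe itself.
  assert(isNeutralForMaxNum(nan));
}
```

The same check with `std::fmin` and +Infty would illustrate the `minnumf` case.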

Again because of a copy-paste mistake, the third pattern above is using
`llvm::maximum` instead of `llvm::maxnum` to calculate the result in
case both arguments fold to a constant:
- `maxnumf(NaN, x')` would have been folded to `llvm::maximum(NaN, x')`
which is `NaN`, whereas the result should be `x'`.

The folder for `minnumf` already correctly uses `llvm::minnum`, so this
PR only fixes the one for `maxnumf`.


[^1]: this is by the way already correctly implemented in
[`arith::getIdentityValueAttr`](https://github.com/oowekyala/llvm-project/blob/a821964e0320d1e35514ced149ec10ec06d7131a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L2493-L2498)

Added: 
    

Modified: 
    mlir/include/mlir/IR/Matchers.h
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 226afb9ad25f1a..816ef56e4db8ca 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() {
   }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat float ones.
+inline detail::constant_float_predicate_matcher m_NaNFloat() {
+  return {[](const APFloat &value) { return value.isNaN(); }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat float positive
 /// infinity.
 inline detail::constant_float_predicate_matcher m_PosInfFloat() {

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 16b4e8eb4f022c..74c64761565d66 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1022,13 +1022,11 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
   if (getLhs() == getRhs())
     return getRhs();
 
-  // maxnumf(x, -inf) -> x
-  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+  // maxnumf(x, NaN) -> x
+  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
     return getLhs();
 
-  return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1108,8 +1106,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
   if (getLhs() == getRhs())
     return getRhs();
 
-  // minnumf(x, +inf) -> x
-  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+  // minnumf(x, NaN) -> x
+  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1d4d5fc6f8319a..69df83d42f543e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1917,31 +1917,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
 // -----
 
 // CHECK-LABEL: @test_minnumf(
-func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-DAG:   %[[INF:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
-  // CHECK-NEXT:  return %[[X]], %arg0, %arg0
+  // CHECK-NEXT:  %[[Y:.+]] = arith.minnumf %arg0, %[[INF]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %[[Y]], %arg0
   %c0 = arith.constant 0.0 : f32
   %inf = arith.constant 0x7F800000 : f32
+  %nan = arith.constant 0x7FC00000 : f32
   %0 = arith.minnumf %c0, %arg0 : f32
   %1 = arith.minnumf %arg0, %arg0 : f32
   %2 = arith.minnumf %inf, %arg0 : f32
-  return %0, %1, %2 : f32, f32, f32
+  %3 = arith.minnumf %nan, %arg0 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
 // -----
 
 // CHECK-LABEL: @test_maxnumf(
-func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
-  // CHECK-DAG:   %[[C0:.+]] = arith.constant
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-DAG:   %[[NINF:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
-  // CHECK-NEXT:   return %[[X]], %arg0, %arg0
+  // CHECK-NEXT:  %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %[[Y]], %arg0
   %c0 = arith.constant 0.0 : f32
   %-inf = arith.constant 0xFF800000 : f32
+  %nan = arith.constant 0x7FC00000 : f32
   %0 = arith.maxnumf %c0, %arg0 : f32
   %1 = arith.maxnumf %arg0, %arg0 : f32
   %2 = arith.maxnumf %-inf, %arg0 : f32
-  return %0, %1, %2 : f32, f32, f32
+  %3 = arith.maxnumf %nan, %arg0 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
 // -----


        

