[Mlir-commits] [mlir] [mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf (PR #93278)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 24 00:49:58 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

For maxnumf and minnumf, when one operand is NaN the result is the other operand, so NaN acts as their neutral element and is returned as the identity value.

---
Full diff: https://github.com/llvm/llvm-project/pull/93278.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+14) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a0b50251c6b67..5797c5681a5fd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                            : APFloat::getInf(semantic, /*Negative=*/true);
     return builder.getFloatAttr(resultType, identity);
   }
+  case AtomicRMWKind::maxnumf: {
+    const llvm::fltSemantics &semantic =
+        llvm::cast<FloatType>(resultType).getFloatSemantics();
+    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
+    return builder.getFloatAttr(resultType, identity);
+  }
   case AtomicRMWKind::addf:
   case AtomicRMWKind::addi:
   case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
 
     return builder.getFloatAttr(resultType, identity);
   }
+  case AtomicRMWKind::minnumf: {
+    const llvm::fltSemantics &semantic =
+        llvm::cast<FloatType>(resultType).getFloatSemantics();
+    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
+    return builder.getFloatAttr(resultType, identity);
+  }
   case AtomicRMWKind::mins:
     return builder.getIntegerAttr(
         resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
           .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
           .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
           .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
+          .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
+          .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
           // Integer operations.
           .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
           .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })

``````````

</details>


https://github.com/llvm/llvm-project/pull/93278


More information about the Mlir-commits mailing list