[Mlir-commits] [mlir] [mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf (PR #93278)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 24 00:49:58 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-arith
Author: donald chen (cxy-1993)
<details>
<summary>Changes</summary>
For maxnumf and minnumf, if one operand is NaN the result is the other (non-NaN) operand, so NaN acts as the neutral element for both operations; this change sets their neutral element to NaN.
---
Full diff: https://github.com/llvm/llvm-project/pull/93278.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+14)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a0b50251c6b67..5797c5681a5fd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
: APFloat::getInf(semantic, /*Negative=*/true);
return builder.getFloatAttr(resultType, identity);
}
+ case AtomicRMWKind::maxnumf: {
+ const llvm::fltSemantics &semantic =
+ llvm::cast<FloatType>(resultType).getFloatSemantics();
+ APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
+ return builder.getFloatAttr(resultType, identity);
+ }
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
return builder.getFloatAttr(resultType, identity);
}
+ case AtomicRMWKind::minnumf: {
+ const llvm::fltSemantics &semantic =
+ llvm::cast<FloatType>(resultType).getFloatSemantics();
+ APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
+ return builder.getFloatAttr(resultType, identity);
+ }
case AtomicRMWKind::mins:
return builder.getIntegerAttr(
resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
+ .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
+ .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
``````````
</details>
https://github.com/llvm/llvm-project/pull/93278
More information about the Mlir-commits
mailing list