[Mlir-commits] [mlir] [mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf (PR #93278)
donald chen
llvmlistbot at llvm.org
Fri May 24 00:49:29 PDT 2024
https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/93278
For maxnumf and minnumf, when one operand is NaN the result is the other (non-NaN) operand, so NaN acts as their neutral element and is used as such.
>From fb0aa3fde4c6e526df26bb0e172ef15639b6a25d Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Fri, 24 May 2024 07:45:57 +0000
Subject: [PATCH] [mlir][arith] Add neutral element support to
arith.maxnumf/arith.minnumf
For maxnumf and minnumf, when one operand is NaN the result is the other
(non-NaN) operand, so NaN acts as their neutral element and is used as such.
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a0b50251c6b67..5797c5681a5fd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
: APFloat::getInf(semantic, /*Negative=*/true);
return builder.getFloatAttr(resultType, identity);
}
+ case AtomicRMWKind::maxnumf: {
+ const llvm::fltSemantics &semantic =
+ llvm::cast<FloatType>(resultType).getFloatSemantics();
+ APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
+ return builder.getFloatAttr(resultType, identity);
+ }
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
return builder.getFloatAttr(resultType, identity);
}
+ case AtomicRMWKind::minnumf: {
+ const llvm::fltSemantics &semantic =
+ llvm::cast<FloatType>(resultType).getFloatSemantics();
+ APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
+ return builder.getFloatAttr(resultType, identity);
+ }
case AtomicRMWKind::mins:
return builder.getIntegerAttr(
resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
+ .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
+ .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
More information about the Mlir-commits
mailing list