[Mlir-commits] [mlir] [mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf (PR #93278)

donald chen llvmlistbot at llvm.org
Fri May 24 00:49:29 PDT 2024


https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/93278

For maxnumf and minnumf, when one operand is NaN the result is the other operand, so NaN is their neutral element.

>From fb0aa3fde4c6e526df26bb0e172ef15639b6a25d Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Fri, 24 May 2024 07:45:57 +0000
Subject: [PATCH] [mlir][arith] Add neutral element support to
 arith.maxnumf/arith.minnumf

For maxnumf and minnumf, when one operand is NaN the result is the other
operand, so NaN is their neutral element.
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a0b50251c6b67..5797c5681a5fd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                            : APFloat::getInf(semantic, /*Negative=*/true);
     return builder.getFloatAttr(resultType, identity);
   }
+  case AtomicRMWKind::maxnumf: {
+    const llvm::fltSemantics &semantic =
+        llvm::cast<FloatType>(resultType).getFloatSemantics();
+    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
+    return builder.getFloatAttr(resultType, identity);
+  }
   case AtomicRMWKind::addf:
   case AtomicRMWKind::addi:
   case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
 
     return builder.getFloatAttr(resultType, identity);
   }
+  case AtomicRMWKind::minnumf: {
+    const llvm::fltSemantics &semantic =
+        llvm::cast<FloatType>(resultType).getFloatSemantics();
+    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
+    return builder.getFloatAttr(resultType, identity);
+  }
   case AtomicRMWKind::mins:
     return builder.getIntegerAttr(
         resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
           .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
           .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
           .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
+          .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
+          .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
           // Integer operations.
           .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
           .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })



More information about the Mlir-commits mailing list