[Mlir-commits] [mlir] [mlir][arith] Fix arith maxnumf/minnumf folder (PR #114595)

Clément Fournier llvmlistbot at llvm.org
Mon Nov 4 08:16:37 PST 2024


https://github.com/oowekyala updated https://github.com/llvm/llvm-project/pull/114595

>From 4d4a8b11be0a22dafec89dcb0d819e517776b0be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 1 Nov 2024 20:13:50 +0100
Subject: [PATCH] Fix arith maxnumf/minnumf folder

---
 mlir/include/mlir/IR/Matchers.h           |  5 +++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp    | 12 +++++-------
 mlir/test/Dialect/Arith/canonicalize.mlir | 22 +++++++++++++++-------
 3 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6fa5a47109d20d..d218206e50f8f1 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() {
   }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat float NaN.
+inline detail::constant_float_predicate_matcher m_NaNFloat() {
+  return {[](const APFloat &value) { return value.isNaN(); }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat float positive
 /// infinity.
 inline detail::constant_float_predicate_matcher m_PosInfFloat() {
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 254f54d9e459e1..ea74121261cc4e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1014,13 +1014,11 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
   if (getLhs() == getRhs())
     return getRhs();
 
-  // maxnumf(x, -inf) -> x
-  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+  // maxnumf(x, NaN) -> x
+  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
     return getLhs();
 
-  return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1100,8 +1098,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
   if (getLhs() == getRhs())
     return getRhs();
 
-  // minnumf(x, +inf) -> x
-  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+  // minnumf(x, NaN) -> x
+  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..84f2b0f113a0c7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1905,31 +1905,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
 // -----
 
 // CHECK-LABEL: @test_minnumf(
-func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-DAG:   %[[INF:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
-  // CHECK-NEXT:  return %[[X]], %arg0, %arg0
+  // CHECK-NEXT:  %[[Y:.+]] = arith.minnumf %arg0, %[[INF]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %[[Y]], %arg0
   %c0 = arith.constant 0.0 : f32
   %inf = arith.constant 0x7F800000 : f32
+  %nan = arith.constant 0x7FC00000 : f32
   %0 = arith.minnumf %c0, %arg0 : f32
   %1 = arith.minnumf %arg0, %arg0 : f32
   %2 = arith.minnumf %inf, %arg0 : f32
-  return %0, %1, %2 : f32, f32, f32
+  %3 = arith.minnumf %nan, %arg0 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
 // -----
 
 // CHECK-LABEL: @test_maxnumf(
-func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
-  // CHECK-DAG:   %[[C0:.+]] = arith.constant
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-DAG:   %[[NINF:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
-  // CHECK-NEXT:   return %[[X]], %arg0, %arg0
+  // CHECK-NEXT:  %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %[[Y]], %arg0
   %c0 = arith.constant 0.0 : f32
   %-inf = arith.constant 0xFF800000 : f32
+  %nan = arith.constant 0x7FC00000 : f32
   %0 = arith.maxnumf %c0, %arg0 : f32
   %1 = arith.maxnumf %arg0, %arg0 : f32
   %2 = arith.maxnumf %-inf, %arg0 : f32
-  return %0, %1, %2 : f32, f32, f32
+  %3 = arith.maxnumf %nan, %arg0 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
 // -----



More information about the Mlir-commits mailing list