[Mlir-commits] [mlir] [mlir][arith] Fix arith maxnumf/minnumf folder (PR #114595)
Clément Fournier
llvmlistbot at llvm.org
Mon Nov 4 08:16:37 PST 2024
https://github.com/oowekyala updated https://github.com/llvm/llvm-project/pull/114595
>From 4d4a8b11be0a22dafec89dcb0d819e517776b0be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 1 Nov 2024 20:13:50 +0100
Subject: [PATCH] Fix arith maxnumf/minnumf folder
---
mlir/include/mlir/IR/Matchers.h | 5 +++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 12 +++++-------
mlir/test/Dialect/Arith/canonicalize.mlir | 22 +++++++++++++++-------
3 files changed, 25 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6fa5a47109d20d..d218206e50f8f1 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() {
}};
}
+/// Matches a constant scalar / vector splat / tensor splat float NaN.
+inline detail::constant_float_predicate_matcher m_NaNFloat() {
+ return {[](const APFloat &value) { return value.isNaN(); }};
+}
+
/// Matches a constant scalar / vector splat / tensor splat float positive
/// infinity.
inline detail::constant_float_predicate_matcher m_PosInfFloat() {
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 254f54d9e459e1..ea74121261cc4e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1014,13 +1014,11 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- // maxnumf(x, -inf) -> x
- if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+ // maxnumf(x, NaN) -> x (not maxnumf(x, -inf) -> x: that is -inf for NaN x)
+ if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
return getLhs();
- return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+ return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
}
//===----------------------------------------------------------------------===//
@@ -1100,8 +1098,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- // minnumf(x, +inf) -> x
- if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+ // minnumf(x, NaN) -> x (not minnumf(x, +inf) -> x: that is +inf for NaN x)
+ if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..84f2b0f113a0c7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1905,31 +1905,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// -----
// CHECK-LABEL: @test_minnumf(
-func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-DAG: %[[INF:.+]] = arith.constant 0x7F800000
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
- // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ // CHECK-NEXT: %[[Y:.+]] = arith.minnumf %arg0, %[[INF]]
+ // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
%c0 = arith.constant 0.0 : f32
%inf = arith.constant 0x7F800000 : f32
+ %nan = arith.constant 0x7FC00000 : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
- return %0, %1, %2 : f32, f32, f32
+ %3 = arith.minnumf %nan, %arg0 : f32
+ return %0, %1, %2, %3 : f32, f32, f32, f32
}
// -----
// CHECK-LABEL: @test_maxnumf(
-func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
- // CHECK-DAG: %[[C0:.+]] = arith.constant
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-DAG: %[[NINF:.+]] = arith.constant 0xFF800000
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
- // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ // CHECK-NEXT: %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]]
+ // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
%c0 = arith.constant 0.0 : f32
%-inf = arith.constant 0xFF800000 : f32
+ %nan = arith.constant 0x7FC00000 : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
- return %0, %1, %2 : f32, f32, f32
+ %3 = arith.maxnumf %nan, %arg0 : f32
+ return %0, %1, %2, %3 : f32, f32, f32, f32
}
// -----
More information about the Mlir-commits
mailing list