[Mlir-commits] [mlir] [mlir][arith] Fix arith maxnumf/minnumf folder (PR #114595)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 1 12:23:47 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-core
Author: Clément Fournier (oowekyala)
<details>
<summary>Changes</summary>
Fix #114594
---
Full diff: https://github.com/llvm/llvm-project/pull/114595.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/Matchers.h (+5)
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+8-7)
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+15-7)
``````````diff
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6fa5a47109d20d..d218206e50f8f1 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() {
}};
}
+/// Matches a constant scalar / vector splat / tensor splat float NaN.
+inline detail::constant_float_predicate_matcher m_NaNFloat() {
+ return {[](const APFloat &value) { return value.isNaN(); }};
+}
+
/// Matches a constant scalar / vector splat / tensor splat float positive
/// infinity.
inline detail::constant_float_predicate_matcher m_PosInfFloat() {
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 254f54d9e459e1..7734911e1e01a7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1014,13 +1014,14 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- // maxnumf(x, -inf) -> x
- if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+ // maxnumf(x, NaN) -> x
+ if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
return getLhs();
- return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+ return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) {
+ return llvm::maximumnum(a, b);
+ });
}
//===----------------------------------------------------------------------===//
@@ -1100,8 +1101,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- // minnumf(x, +inf) -> x
- if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+ // minnumf(x, NaN) -> x
+ if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..84f2b0f113a0c7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1905,31 +1905,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// -----
// CHECK-LABEL: @test_minnumf(
-func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-DAG: %[[INF:.+]] = arith.constant
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
- // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ // CHECK-NEXT: %[[Y:.+]] = arith.minnumf %arg0, %[[INF]]
+ // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
%c0 = arith.constant 0.0 : f32
%inf = arith.constant 0x7F800000 : f32
+ %nan = arith.constant 0x7FC00000 : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
- return %0, %1, %2 : f32, f32, f32
+ %3 = arith.minnumf %nan, %arg0 : f32
+ return %0, %1, %2, %3 : f32, f32, f32, f32
}
// -----
// CHECK-LABEL: @test_maxnumf(
-func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
- // CHECK-DAG: %[[C0:.+]] = arith.constant
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-DAG: %[[NINF:.+]] = arith.constant
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
- // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ // CHECK-NEXT: %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]]
+ // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
%c0 = arith.constant 0.0 : f32
%-inf = arith.constant 0xFF800000 : f32
+ %nan = arith.constant 0x7FC00000 : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
- return %0, %1, %2 : f32, f32, f32
+ %3 = arith.maxnumf %nan, %arg0 : f32
+ return %0, %1, %2, %3 : f32, f32, f32, f32
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/114595
More information about the Mlir-commits
mailing list