[Mlir-commits] [mlir] b122cbe - [mlir][Math] Fix NaN handling in Exp approximation
Adrian Kuegel
llvmlistbot at llvm.org
Tue Feb 15 06:18:13 PST 2022
Author: Adrian Kuegel
Date: 2022-02-15T15:17:56+01:00
New Revision: b122cbebec43fc3cca485448337c6890c7f36cbc
URL: https://github.com/llvm/llvm-project/commit/b122cbebec43fc3cca485448337c6890c7f36cbc
DIFF: https://github.com/llvm/llvm-project/commit/b122cbebec43fc3cca485448337c6890c7f36cbc.diff
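LOG: [mlir][Math] Fix NaN handling in Exp approximation. The approximation now returns its input unchanged when the input is NaN (detected with an unordered self-comparison, `cmpf uno x, x`), so both math.exp(NaN) and math.expm1(NaN) yield NaN.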
Differential Revision: https://reviews.llvm.org/D119832
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 4d9338a457aa3..5d3f629210d42 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -930,6 +930,8 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
Value x = op.getOperand();
+ Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
+
// Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
Value xL2Inv = mul(x, cstLog2E);
Value kF32 = floor(xL2Inv);
@@ -985,13 +987,15 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
expY = builder.create<arith::SelectOp>(
- isNegInfinityX, zerof32Const,
+ isNan, x,
builder.create<arith::SelectOp>(
- isPosInfinityX, constPosInfinity,
+ isNegInfinityX, zerof32Const,
builder.create<arith::SelectOp>(
- isComputable, expY,
- builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
- underflow))));
+ isPosInfinityX, constPosInfinity,
+ builder.create<arith::SelectOp>(
+ isComputable, expY,
+ builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
+ underflow)))));
rewriter.replaceOp(op, expY);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 424a6d28b77be..457e585e25f9a 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -110,6 +110,7 @@ func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 1.17549435E-38 : f32
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 127 : i32
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant -127 : i32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[VAL_0]], %[[VAL_0]] : f32
// CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_0]], %[[VAL_2]] : f32
// CHECK: %[[VAL_16:.*]] = math.floor %[[VAL_15]] : f32
// CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_16]], %[[VAL_1]] : f32
@@ -136,7 +137,8 @@ func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32
// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32
// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32
-// CHECK: return %[[VAL_40]] : f32
+// CHECK: %[[VAL_41:.*]] = arith.select %[[IS_NAN]], %[[VAL_0]], %[[VAL_40]] : f32
+// CHECK: return %[[VAL_41]] : f32
func @exp_scalar(%arg0: f32) -> f32 {
%0 = math.exp %arg0 : f32
return %0 : f32
@@ -146,7 +148,7 @@ func @exp_scalar(%arg0: f32) -> f32 {
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32>
// CHECK-NOT: exp
-// CHECK-COUNT-3: select
+// CHECK-COUNT-4: select
// CHECK: %[[VAL_40:.*]] = arith.select
// CHECK: return %[[VAL_40]] : vector<8xf32>
func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
@@ -161,7 +163,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-DAG: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32
// CHECK-NOT: exp
-// CHECK-COUNT-3: select
+// CHECK-COUNT-4: select
// CHECK: %[[EXP_X:.*]] = arith.select
// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
// CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
@@ -186,7 +188,7 @@ func @expm1_scalar(%arg0: f32) -> f32 {
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32>
// CHECK-NOT: exp
-// CHECK-COUNT-4: select
+// CHECK-COUNT-5: select
// CHECK-NOT: log
// CHECK-COUNT-5: select
// CHECK-NOT: expm1
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 5a41d56dd42bd..413c04cc8867a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -258,6 +258,11 @@ func @exp() {
%exp_negative_inf = math.exp %negative_inf : f32
vector.print %exp_negative_inf : f32
+ // CHECK: nan
+ %nan = arith.constant 0x7fc00000 : f32
+ %exp_nan = math.exp %nan : f32
+ vector.print %exp_nan : f32
+
return
}
@@ -292,6 +297,11 @@ func @expm1() {
%log_special_vec = math.expm1 %special_vec : vector<3xf32>
vector.print %log_special_vec : vector<3xf32>
+ // CHECK: nan
+ %nan = arith.constant 0x7fc00000 : f32
+ %exp_nan = math.expm1 %nan : f32
+ vector.print %exp_nan : f32
+
return
}
// -------------------------------------------------------------------------- //
More information about the Mlir-commits mailing list