[Mlir-commits] [mlir] b122cbe - [mlir][Math] Fix NaN handling in Exp approximation

Adrian Kuegel llvmlistbot at llvm.org
Tue Feb 15 06:18:13 PST 2022


Author: Adrian Kuegel
Date: 2022-02-15T15:17:56+01:00
New Revision: b122cbebec43fc3cca485448337c6890c7f36cbc

URL: https://github.com/llvm/llvm-project/commit/b122cbebec43fc3cca485448337c6890c7f36cbc
DIFF: https://github.com/llvm/llvm-project/commit/b122cbebec43fc3cca485448337c6890c7f36cbc.diff

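LOG: [mlir][Math] Fix NaN handling in Exp approximation

The Exp expansion now detects a NaN input with an unordered self-comparison
(`arith.cmpf uno, x, x`) and, as the outermost select, returns the input
unchanged in that case, so that `math.exp` and `math.expm1` (which is built
on the Exp approximation) both yield NaN for a NaN operand.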

Differential Revision: https://reviews.llvm.org/D119832

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir
    mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 4d9338a457aa3..5d3f629210d42 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -930,6 +930,8 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
 
   Value x = op.getOperand();
 
+  Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
+
   // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
   Value xL2Inv = mul(x, cstLog2E);
   Value kF32 = floor(xL2Inv);
@@ -985,13 +987,15 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
   Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
 
   expY = builder.create<arith::SelectOp>(
-      isNegInfinityX, zerof32Const,
+      isNan, x,
       builder.create<arith::SelectOp>(
-          isPosInfinityX, constPosInfinity,
+          isNegInfinityX, zerof32Const,
           builder.create<arith::SelectOp>(
-              isComputable, expY,
-              builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
-                                              underflow))));
+              isPosInfinityX, constPosInfinity,
+              builder.create<arith::SelectOp>(
+                  isComputable, expY,
+                  builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
+                                                  underflow)))));
 
   rewriter.replaceOp(op, expY);
 

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 424a6d28b77be..457e585e25f9a 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -110,6 +110,7 @@ func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 // CHECK-DAG:           %[[VAL_12:.*]] = arith.constant 1.17549435E-38 : f32
 // CHECK-DAG:           %[[VAL_13:.*]] = arith.constant 127 : i32
 // CHECK-DAG:           %[[VAL_14:.*]] = arith.constant -127 : i32
+// CHECK:           %[[IS_NAN:.*]] = arith.cmpf uno, %[[VAL_0]], %[[VAL_0]] : f32
 // CHECK:           %[[VAL_15:.*]] = arith.mulf %[[VAL_0]], %[[VAL_2]] : f32
 // CHECK:           %[[VAL_16:.*]] = math.floor %[[VAL_15]] : f32
 // CHECK:           %[[VAL_17:.*]] = arith.mulf %[[VAL_16]], %[[VAL_1]] : f32
@@ -136,7 +137,8 @@ func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32
 // CHECK:           %[[VAL_39:.*]] = arith.select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32
 // CHECK:           %[[VAL_40:.*]] = arith.select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32
-// CHECK:           return %[[VAL_40]] : f32
+// CHECK:           %[[VAL_41:.*]] = arith.select %[[IS_NAN]], %[[VAL_0]], %[[VAL_40]] : f32
+// CHECK:           return %[[VAL_41]] : f32
 func @exp_scalar(%arg0: f32) -> f32 {
   %0 = math.exp %arg0 : f32
   return %0 : f32
@@ -146,7 +148,7 @@ func @exp_scalar(%arg0: f32) -> f32 {
 // CHECK-SAME:                     %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32>
 // CHECK-NOT:       exp
-// CHECK-COUNT-3:   select
+// CHECK-COUNT-4:   select
 // CHECK:           %[[VAL_40:.*]] = arith.select
 // CHECK:           return %[[VAL_40]] : vector<8xf32>
 func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
@@ -161,7 +163,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 // CHECK-DAG:           %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK:           %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32
 // CHECK-NOT:       exp
-// CHECK-COUNT-3:   select
+// CHECK-COUNT-4:   select
 // CHECK:           %[[EXP_X:.*]] = arith.select
 // CHECK:           %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
 // CHECK:           %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
@@ -186,7 +188,7 @@ func @expm1_scalar(%arg0: f32) -> f32 {
 // CHECK-SAME:                       %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32>
 // CHECK-NOT:       exp
-// CHECK-COUNT-4:   select
+// CHECK-COUNT-5:   select
 // CHECK-NOT:       log
 // CHECK-COUNT-5:   select
 // CHECK-NOT:       expm1

diff  --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 5a41d56dd42bd..413c04cc8867a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -258,6 +258,11 @@ func @exp() {
   %exp_negative_inf = math.exp %negative_inf : f32
   vector.print %exp_negative_inf : f32
 
+  // CHECK: nan
+  %nan = arith.constant 0x7fc00000 : f32
+  %exp_nan = math.exp %nan : f32
+  vector.print %exp_nan : f32
+
   return
 }
 
@@ -292,6 +297,11 @@ func @expm1() {
   %log_special_vec = math.expm1 %special_vec : vector<3xf32>
   vector.print %log_special_vec : vector<3xf32>
 
+  // CHECK: nan
+  %nan = arith.constant 0x7fc00000 : f32
+  %exp_nan = math.expm1 %nan : f32
+  vector.print %exp_nan : f32
+
   return
 }
 // -------------------------------------------------------------------------- //


        


More information about the Mlir-commits mailing list