[Mlir-commits] [mlir] 87de451 - [mlir][Math] Fix NaN handling in ExpM1 approximation.

Adrian Kuegel llvmlistbot at llvm.org
Tue Feb 15 03:10:26 PST 2022


Author: Adrian Kuegel
Date: 2022-02-15T12:10:12+01:00
New Revision: 87de451bc577e1f68abb9cec7f68a55f133b6897

URL: https://github.com/llvm/llvm-project/commit/87de451bc577e1f68abb9cec7f68a55f133b6897
DIFF: https://github.com/llvm/llvm-project/commit/87de451bc577e1f68abb9cec7f68a55f133b6897.diff

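LOG: [mlir][Math] Fix NaN handling in ExpM1 approximation.

The comparison of exp(x) against 1.0 now uses the unordered predicate UEQ instead of OEQ, so it is also true when exp(x) is NaN. In that case the approximation returns x directly instead of the log-based correction.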

Differential Revision: https://reviews.llvm.org/D119822

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 9c8e413b0e55f..4d9338a457aa3 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1033,8 +1033,8 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
   Value cstNegOne = bcast(f32Cst(builder, -1.0f));
   Value x = op.getOperand();
   Value u = builder.create<math::ExpOp>(x);
-  Value uEqOne =
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
+  Value uEqOneOrNaN =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
   Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
   Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
       arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
@@ -1050,7 +1050,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
       uMinusOne, builder.create<arith::DivFOp>(x, logU));
   expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
   Value approximation = builder.create<arith::SelectOp>(
-      uEqOne, x,
+      uEqOneOrNaN, x,
       builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
   rewriter.replaceOp(op, approximation);
   return success();

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index e8c09ba2ca0ac..424a6d28b77be 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -163,7 +163,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 // CHECK-NOT:       exp
 // CHECK-COUNT-3:   select
 // CHECK:           %[[EXP_X:.*]] = arith.select
-// CHECK:           %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32
+// CHECK:           %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
 // CHECK:           %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
 // CHECK:           %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32
 // CHECK-NOT:       log
@@ -174,7 +174,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32
 // CHECK:           %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32
 // CHECK:           %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32
-// CHECK:           %[[VAL_109:.*]] = arith.select %[[VAL_58]], %[[X]], %[[VAL_108]] : f32
+// CHECK:           %[[VAL_109:.*]] = arith.select %[[IS_ONE_OR_NAN]], %[[X]], %[[VAL_108]] : f32
 // CHECK:           return %[[VAL_109]] : f32
 // CHECK:         }
 func @expm1_scalar(%arg0: f32) -> f32 {


        


More information about the Mlir-commits mailing list