[Mlir-commits] [mlir] 87de451 - [mlir][Math] Fix NaN handling in ExpM1 approximation.
Adrian Kuegel
llvmlistbot at llvm.org
Tue Feb 15 03:10:26 PST 2022
Author: Adrian Kuegel
Date: 2022-02-15T12:10:12+01:00
New Revision: 87de451bc577e1f68abb9cec7f68a55f133b6897
URL: https://github.com/llvm/llvm-project/commit/87de451bc577e1f68abb9cec7f68a55f133b6897
DIFF: https://github.com/llvm/llvm-project/commit/87de451bc577e1f68abb9cec7f68a55f133b6897.diff
LOG: [mlir][Math] Fix NaN handling in ExpM1 approximation.
Differential Revision: https://reviews.llvm.org/D119822
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 9c8e413b0e55f..4d9338a457aa3 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1033,8 +1033,8 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
Value cstNegOne = bcast(f32Cst(builder, -1.0f));
Value x = op.getOperand();
Value u = builder.create<math::ExpOp>(x);
- Value uEqOne =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
+ Value uEqOneOrNaN =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
@@ -1050,7 +1050,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
uMinusOne, builder.create<arith::DivFOp>(x, logU));
expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
Value approximation = builder.create<arith::SelectOp>(
- uEqOne, x,
+ uEqOneOrNaN, x,
builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
rewriter.replaceOp(op, approximation);
return success();
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index e8c09ba2ca0ac..424a6d28b77be 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -163,7 +163,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-NOT: exp
// CHECK-COUNT-3: select
// CHECK: %[[EXP_X:.*]] = arith.select
-// CHECK: %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32
+// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
// CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
// CHECK: %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32
// CHECK-NOT: log
@@ -174,7 +174,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32
// CHECK: %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32
// CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32
-// CHECK: %[[VAL_109:.*]] = arith.select %[[VAL_58]], %[[X]], %[[VAL_108]] : f32
+// CHECK: %[[VAL_109:.*]] = arith.select %[[IS_ONE_OR_NAN]], %[[X]], %[[VAL_108]] : f32
// CHECK: return %[[VAL_109]] : f32
// CHECK: }
func @expm1_scalar(%arg0: f32) -> f32 {
More information about the Mlir-commits mailing list