[Mlir-commits] [mlir] 710dc72 - [mlir][math] Modified the 'math.exp' lowering for higher precision

Rob Suderman llvmlistbot at llvm.org
Fri Jun 23 12:27:24 PDT 2023


Author: Robert Suderman
Date: 2023-06-23T12:25:18-07:00
New Revision: 710dc7282a013830695fa7aa1559f1c96cb43b51

URL: https://github.com/llvm/llvm-project/commit/710dc7282a013830695fa7aa1559f1c96cb43b51
DIFF: https://github.com/llvm/llvm-project/commit/710dc7282a013830695fa7aa1559f1c96cb43b51.diff

LOG: [mlir][math] Modified the 'math.exp' lowering for higher precision

The existing lowering had lower precision in certain use cases, e.g.
tanh. The improved version should demonstrate an overall higher level of precision.

Reviewed By: cota, jpienaar

Differential Revision: https://reviews.llvm.org/D153592

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir
    mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index a3efc6ef41a95..070ca0b7170b8 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -895,6 +895,31 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
 
 namespace {
 
+Value clampWithNormals(ImplicitLocOpBuilder &builder,
+                       const llvm::ArrayRef<int64_t> shape, Value value,
+                       float lowerBound, float upperBound) {
+  assert(!std::isnan(lowerBound));
+  assert(!std::isnan(upperBound));
+
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  auto selectCmp = [&builder](auto pred, Value value, Value bound) {
+    return builder.create<arith::SelectOp>(
+        builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
+  };
+
+  // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
+  // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
+  // arith::{Max,Min}FOp.
+  value = selectCmp(arith::CmpFPredicate::UGE, value,
+                    bcast(f32Cst(builder, lowerBound)));
+  value = selectCmp(arith::CmpFPredicate::ULE, value,
+                    bcast(f32Cst(builder, upperBound)));
+  return value;
+}
+
 struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -902,122 +927,146 @@ struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
   LogicalResult matchAndRewrite(math::ExpOp op,
                                 PatternRewriter &rewriter) const final;
 };
-} // namespace
 
-// Approximate exp(x) using its reduced range exp(y) where y is in the range
-// [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
-// = exp(y) * 2^k. exp(y).
 LogicalResult
 ExpApproximation::matchAndRewrite(math::ExpOp op,
                                   PatternRewriter &rewriter) const {
-  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+  auto shape = vectorShape(op.getOperand().getType());
+  auto elementTy = getElementTypeOrSelf(op.getType());
+  if (!elementTy.isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
-
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
 
-  // TODO: Consider a common pattern rewriter with all methods below to
-  // write the approximations.
+  auto add = [&](Value a, Value b) -> Value {
+    return builder.create<arith::AddFOp>(a, b);
+  };
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
   };
+  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
   auto fmla = [&](Value a, Value b, Value c) {
     return builder.create<math::FmaOp>(a, b, c);
   };
   auto mul = [&](Value a, Value b) -> Value {
     return builder.create<arith::MulFOp>(a, b);
   };
-  auto sub = [&](Value a, Value b) -> Value {
-    return builder.create<arith::SubFOp>(a, b);
-  };
-  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
-
-  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
-  Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
-
-  // Polynomial coefficients.
-  Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
-  Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
-  Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
-  Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
-  Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
-  Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
 
+  // Polynomial approximation from Cephes.
+  //
+  // To compute e^x, we re-express it as
+  //
+  //   e^x = e^(a + b)
+  //       = e^(a + n log(2))
+  //       = e^a * 2^n.
+  //
+  // We choose n = round(x / log(2)), restricting the value of `a` to
+  // (-log(2)/2, log(2)/2).  We then use a polynomial to compute e^a. The
+  // relative error between our approximation and the true value of e^a is less
+  // than 2^-22.5 for all values of `a` within this range.
+
+  // Restrict input to a small range, including some values that evaluate to
+  // +/- inf.  Note that for our lower bound, we choose log(2^-126) instead of
+  // log(2^-149), the smallest denormal. We do so because this routine always
+  // flushes denormal floating points to 0. Therefore, we only need to worry
+  // about exponentiating down to the smallest representable normal floating
+  // point, which is 2^-126.
+
+  // Constants.
+  Value cst_half = bcast(f32Cst(builder, 0.5f));
+  Value cst_one = bcast(f32Cst(builder, 1.0f));
+
+  // 1/log(2)
+  Value cst_log2ef = bcast(f32Cst(builder, 1.44269504088896341f));
+
+  Value cst_exp_c1 = bcast(f32Cst(builder, -0.693359375f));
+  Value cst_exp_c2 = bcast(f32Cst(builder, 2.12194440e-4f));
+  Value cst_exp_p0 = bcast(f32Cst(builder, 1.9875691500E-4f));
+  Value cst_exp_p1 = bcast(f32Cst(builder, 1.3981999507E-3f));
+  Value cst_exp_p2 = bcast(f32Cst(builder, 8.3334519073E-3f));
+  Value cst_exp_p3 = bcast(f32Cst(builder, 4.1665795894E-2f));
+  Value cst_exp_p4 = bcast(f32Cst(builder, 1.6666665459E-1f));
+  Value cst_exp_p5 = bcast(f32Cst(builder, 5.0000001201E-1f));
+
+  // Our computations below aren't particularly sensitive to the exact choices
+  // here, so we choose values a bit larger/smaller than
+  //
+  //   log(F32_MAX) = 88.723...
+  //   log(2^-126) = -87.337...
   Value x = op.getOperand();
+  x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
+  Value n = floor(fmla(x, cst_log2ef, cst_half));
 
-  Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
-
-  // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
-  Value xL2Inv = mul(x, cstLog2E);
-  Value kF32 = floor(xL2Inv);
-  Value kLn2 = mul(kF32, cstLn2);
-  Value y = sub(x, kLn2);
-
-  // Use Estrin's evaluation scheme with 3 independent parts:
-  // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
-  Value y2 = mul(y, y);
-  Value y4 = mul(y2, y2);
-
-  Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
-  Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
-  Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
-  Value expY = fmla(q1, y2, q0);
-  expY = fmla(q2, y4, expY);
-
-  auto i32Vec = broadcast(builder.getI32Type(), shape);
-
-  // exp2(k)
-  Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
-  Value exp2KValue = exp2I32(builder, k);
-
-  // exp(x) = exp(y) * exp2(k)
-  expY = mul(expY, exp2KValue);
-
-  // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
-  // partitioned as the following:
-  // exp(x) = 0, x <= -inf
-  // exp(x) = underflow (min_float), x <= -88
-  // exp(x) = inf (min_float), x >= 88
-  // Note: |k| = 127 is the value where the 8-bits exponent saturates.
-  Value zerof32Const = bcast(f32Cst(builder, 0));
-  auto constPosInfinity =
-      bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
-  auto constNegIfinity =
-      bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
-  auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
-
-  Value kMaxConst = bcast(i32Cst(builder, 127));
-  Value kMaxNegConst = bcast(i32Cst(builder, -127));
-  Value rightBound =
-      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
-  Value leftBound =
-      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);
-
-  Value isNegInfinityX = builder.create<arith::CmpFOp>(
-      arith::CmpFPredicate::OEQ, x, constNegIfinity);
-  Value isPosInfinityX = builder.create<arith::CmpFOp>(
-      arith::CmpFPredicate::OEQ, x, constPosInfinity);
-  Value isPostiveX =
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
-  Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
-
-  expY = builder.create<arith::SelectOp>(
-      isNan, x,
-      builder.create<arith::SelectOp>(
-          isNegInfinityX, zerof32Const,
-          builder.create<arith::SelectOp>(
-              isPosInfinityX, constPosInfinity,
-              builder.create<arith::SelectOp>(
-                  isComputable, expY,
-                  builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
-                                                  underflow)))));
-
-  rewriter.replaceOp(op, expY);
-
-  return success();
+  // When we eventually do the multiplication in e^a * 2^n, we need to handle
+  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
+  // (so e^a * 2^n != inf).  There's a similar problem for n < -126, the
+  // smallest fp32 exponent.
+  //
+  // A straightforward solution would be to detect n out of range and split it
+  // up, doing
+  //
+  //   e^a * 2^n = e^a * 2^(n1 + n2)
+  //             = (2^n1 * e^a) * 2^n2.
+  //
+  // But it turns out this approach is quite slow, probably because it
+  // manipulates subnormal values.
+  //
+  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
+  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
+  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
+  // this value of `a` is outside our previously specified range, e^a will still
+  // only have a relative error of approximately 2^-16 at worst. In practice
+  // this seems to work well enough; it passes our exhaustive tests, breaking
+  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
+  // we should return inf).
+  //
+  // In the case where n' = -127, the original input value of x is so small that
+  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
+  // normal floating point, and since we flush denormals, we simply return 0. We
+  // do this in a branchless way by observing that our code for constructing 2^n
+  // produces 0 if n = -127.
+  //
+  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
+  //
+  //    n' = -127 implies n <= -127
+  //              implies round(x / log(2)) <= -127
+  //              implies x/log(2) < -126.5
+  //              implies x < -126.5 * log(2)
+  //              implies e^x < e^(-126.5 * log(2))
+  //              implies e^x < 2^-126.5 < 2^-126
+  //
+  //    This proves that n' = -127 implies e^x < 2^-126.
+  n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
+
+  // Computes x = x - n' * log(2), the value for `a`
+  x = fmla(cst_exp_c1, n, x);
+  x = fmla(cst_exp_c2, n, x);
+
+  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
+  Value z = fmla(x, cst_exp_p0, cst_exp_p1);
+  z = fmla(z, x, cst_exp_p2);
+  z = fmla(z, x, cst_exp_p3);
+  z = fmla(z, x, cst_exp_p4);
+  z = fmla(z, x, cst_exp_p5);
+  z = fmla(z, mul(x, x), x);
+  z = add(cst_one, z);
+
+  // Convert n' to an i32.  This is safe because we clamped it above.
+  auto i32_vec = broadcast(builder.getI32Type(), shape);
+  Value n_i32 = builder.create<arith::FPToSIOp>(i32_vec, n);
+
+  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
+  Value pow2 = exp2I32(builder, n_i32);
+
+  // Return z * 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
+  Value ret = mul(z, pow2);
+
+  rewriter.replaceOp(op, ret);
+  return mlir::success();
 }
 
+} // namespace
+
 //----------------------------------------------------------------------------//
 // ExpM1 approximation.
 //----------------------------------------------------------------------------//

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index b87d4b5ecdbc6..3c87ecf72011d 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -96,48 +96,47 @@ func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 
 // CHECK-LABEL:   func @exp_scalar(
 // CHECK-SAME:                     %[[VAL_0:.*]]: f32) -> f32 {
-// CHECK-DAG:           %[[VAL_1:.*]] = arith.constant 0.693147182 : f32
-// CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 1.44269502 : f32
-// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0.499705136 : f32
-// CHECK-DAG:           %[[VAL_5:.*]] = arith.constant 0.168738902 : f32
-// CHECK-DAG:           %[[VAL_6:.*]] = arith.constant 0.0366896503 : f32
-// CHECK-DAG:           %[[VAL_7:.*]] = arith.constant 1.314350e-02 : f32
-// CHECK-DAG:           %[[VAL_8:.*]] = arith.constant 23 : i32
-// CHECK-DAG:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:           %[[VAL_10:.*]] = arith.constant 0x7F800000 : f32
-// CHECK-DAG:           %[[VAL_11:.*]] = arith.constant 0xFF800000 : f32
-// CHECK-DAG:           %[[VAL_12:.*]] = arith.constant 1.17549435E-38 : f32
-// CHECK-DAG:           %[[VAL_13:.*]] = arith.constant 127 : i32
-// CHECK-DAG:           %[[VAL_14:.*]] = arith.constant -127 : i32
-// CHECK:           %[[IS_NAN:.*]] = arith.cmpf uno, %[[VAL_0]], %[[VAL_0]] : f32
-// CHECK:           %[[VAL_15:.*]] = arith.mulf %[[VAL_0]], %[[VAL_2]] : f32
-// CHECK:           %[[VAL_16:.*]] = math.floor %[[VAL_15]] : f32
-// CHECK:           %[[VAL_17:.*]] = arith.mulf %[[VAL_16]], %[[VAL_1]] : f32
-// CHECK:           %[[VAL_18:.*]] = arith.subf %[[VAL_0]], %[[VAL_17]] : f32
-// CHECK:           %[[VAL_19:.*]] = arith.mulf %[[VAL_18]], %[[VAL_18]] : f32
-// CHECK:           %[[VAL_20:.*]] = arith.mulf %[[VAL_19]], %[[VAL_19]] : f32
-// CHECK:           %[[VAL_21:.*]] = math.fma %[[VAL_3]], %[[VAL_18]], %[[VAL_3]] : f32
-// CHECK:           %[[VAL_22:.*]] = math.fma %[[VAL_5]], %[[VAL_18]], %[[VAL_4]] : f32
-// CHECK:           %[[VAL_23:.*]] = math.fma %[[VAL_7]], %[[VAL_18]], %[[VAL_6]] : f32
-// CHECK:           %[[VAL_24:.*]] = math.fma %[[VAL_22]], %[[VAL_19]], %[[VAL_21]] : f32
-// CHECK:           %[[VAL_25:.*]] = math.fma %[[VAL_23]], %[[VAL_20]], %[[VAL_24]] : f32
-// CHECK:           %[[VAL_26:.*]] = arith.fptosi %[[VAL_16]] : f32 to i32
-// CHECK:           %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_13]] : i32
-// CHECK:           %[[VAL_28:.*]] = arith.shli %[[VAL_27]], %[[VAL_8]] : i32
-// CHECK:           %[[VAL_29:.*]] = arith.bitcast %[[VAL_28]] : i32 to f32
-// CHECK:           %[[VAL_30:.*]] = arith.mulf %[[VAL_25]], %[[VAL_29]] : f32
-// CHECK:           %[[VAL_31:.*]] = arith.cmpi sle, %[[VAL_26]], %[[VAL_13]] : i32
-// CHECK:           %[[VAL_32:.*]] = arith.cmpi sge, %[[VAL_26]], %[[VAL_14]] : i32
-// CHECK:           %[[VAL_33:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_11]] : f32
-// CHECK:           %[[VAL_34:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_10]] : f32
-// CHECK:           %[[VAL_35:.*]] = arith.cmpf ogt, %[[VAL_0]], %[[VAL_9]] : f32
-// CHECK:           %[[VAL_36:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
-// CHECK:           %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_10]], %[[VAL_12]] : f32
-// CHECK:           %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32
-// CHECK:           %[[VAL_39:.*]] = arith.select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32
-// CHECK:           %[[VAL_40:.*]] = arith.select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32
-// CHECK:           %[[VAL_41:.*]] = arith.select %[[IS_NAN]], %[[VAL_0]], %[[VAL_40]] : f32
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1.44269502 : f32
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant -0.693359375 : f32
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2.12194442E-4 : f32
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1.98756912E-4 : f32
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0.00139819994 : f32
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0.00833345205 : f32
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.0416657962 : f32
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0.166666657 : f32
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant -8.780000e+01 : f32
+// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 8.880000e+01 : f32
+// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant -1.270000e+02 : f32
+// CHECK-DAG:       %[[VAL_14:.*]] = arith.constant 1.270000e+02 : f32
+// CHECK-DAG:       %[[VAL_15:.*]] = arith.constant 23 : i32
+// CHECK-DAG:       %[[VAL_16:.*]] = arith.constant 127 : i32
+// CHECK-DAG:       %[[VAL_17:.*]] = arith.cmpf uge, %[[VAL_0]], %[[VAL_11]] : f32
+// CHECK-DAG:       %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_0]], %[[VAL_11]] : f32
+// CHECK-DAG:       %[[VAL_19:.*]] = arith.cmpf ule, %[[VAL_18]], %[[VAL_12]] : f32
+// CHECK-DAG:       %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_18]], %[[VAL_12]] : f32
+// CHECK-DAG:       %[[VAL_21:.*]] = math.fma %[[VAL_20]], %[[VAL_3]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_22:.*]] = math.floor %[[VAL_21]] : f32
+// CHECK-DAG:       %[[VAL_23:.*]] = arith.cmpf uge, %[[VAL_22]], %[[VAL_13]] : f32
+// CHECK-DAG:       %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[VAL_22]], %[[VAL_13]] : f32
+// CHECK-DAG:       %[[VAL_25:.*]] = arith.cmpf ule, %[[VAL_24]], %[[VAL_14]] : f32
+// CHECK-DAG:       %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_14]] : f32
+// CHECK-DAG:       %[[VAL_27:.*]] = math.fma %[[VAL_4]], %[[VAL_26]], %[[VAL_20]] : f32
+// CHECK-DAG:       %[[VAL_28:.*]] = math.fma %[[VAL_5]], %[[VAL_26]], %[[VAL_27]] : f32
+// CHECK-DAG:       %[[VAL_29:.*]] = math.fma %[[VAL_28]], %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK-DAG:       %[[VAL_30:.*]] = math.fma %[[VAL_29]], %[[VAL_28]], %[[VAL_8]] : f32
+// CHECK-DAG:       %[[VAL_31:.*]] = math.fma %[[VAL_30]], %[[VAL_28]], %[[VAL_9]] : f32
+// CHECK-DAG:       %[[VAL_32:.*]] = math.fma %[[VAL_31]], %[[VAL_28]], %[[VAL_10]] : f32
+// CHECK-DAG:       %[[VAL_33:.*]] = math.fma %[[VAL_32]], %[[VAL_28]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_34:.*]] = arith.mulf %[[VAL_28]], %[[VAL_28]] : f32
+// CHECK-DAG:       %[[VAL_35:.*]] = math.fma %[[VAL_33]], %[[VAL_34]], %[[VAL_28]] : f32
+// CHECK-DAG:       %[[VAL_36:.*]] = arith.addf %[[VAL_35]], %[[VAL_2]] : f32
+// CHECK-DAG:       %[[VAL_37:.*]] = arith.fptosi %[[VAL_26]] : f32 to i32
+// CHECK-DAG:       %[[VAL_38:.*]] = arith.addi %[[VAL_37]], %[[VAL_16]] : i32
+// CHECK-DAG:       %[[VAL_39:.*]] = arith.shli %[[VAL_38]], %[[VAL_15]] : i32
+// CHECK-DAG:       %[[VAL_40:.*]] = arith.bitcast %[[VAL_39]] : i32 to f32
+// CHECK-DAG:       %[[VAL_41:.*]] = arith.mulf %[[VAL_36]], %[[VAL_40]] : f32
 // CHECK:           return %[[VAL_41]] : f32
 func.func @exp_scalar(%arg0: f32) -> f32 {
   %0 = math.exp %arg0 : f32
@@ -146,11 +145,7 @@ func.func @exp_scalar(%arg0: f32) -> f32 {
 
 // CHECK-LABEL:   func @exp_vector(
 // CHECK-SAME:                     %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32>
-// CHECK-NOT:       exp
-// CHECK-COUNT-4:   select
-// CHECK:           %[[VAL_40:.*]] = arith.select
-// CHECK:           return %[[VAL_40]] : vector<8xf32>
+// CHECK-NOT:   math.exp
 func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   %0 = math.exp %arg0 : vector<8xf32>
   return %0 : vector<8xf32>
@@ -158,26 +153,114 @@ func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 
 // CHECK-LABEL:   func @expm1_scalar(
 // CHECK-SAME:                       %[[X:.*]]: f32) -> f32 {
-// CHECK-DAG:           %[[CST_MINUSONE:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK-DAG:           %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32
-// CHECK-DAG:           %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK:           %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32
-// CHECK-NOT:       exp
-// CHECK-COUNT-4:   select
-// CHECK:           %[[EXP_X:.*]] = arith.select
-// CHECK:           %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
-// CHECK:           %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
-// CHECK:           %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32
-// CHECK-NOT:       log
-// CHECK-COUNT-5:   select
-// CHECK:           %[[LOG_U:.*]] = arith.select
-// CHECK:           %[[VAL_104:.*]] = arith.cmpf oeq, %[[LOG_U]], %[[EXP_X]] : f32
-// CHECK:           %[[VAL_105:.*]] = arith.divf %[[X]], %[[LOG_U]] : f32
-// CHECK:           %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32
-// CHECK:           %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32
-// CHECK:           %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32
-// CHECK:           %[[VAL_109:.*]] = arith.select %[[IS_ONE_OR_NAN]], %[[X]], %[[VAL_108]] : f32
-// CHECK:           return %[[VAL_109]] : f32
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1.44269502 : f32
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant -0.693359375 : f32
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 2.12194442E-4 : f32
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1.98756912E-4 : f32
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0.00139819994 : f32
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.00833345205 : f32
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0.0416657962 : f32
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 0.166666657 : f32
+// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant -8.780000e+01 : f32
+// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 8.880000e+01 : f32
+// CHECK-DAG:       %[[VAL_14:.*]] = arith.constant -1.270000e+02 : f32
+// CHECK-DAG:       %[[VAL_15:.*]] = arith.constant 1.270000e+02 : f32
+// CHECK-DAG:       %[[VAL_16:.*]] = arith.constant 23 : i32
+// CHECK-DAG:       %[[VAL_17:.*]] = arith.constant 127 : i32
+// CHECK-DAG:       %[[VAL_18:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_19:.*]] = arith.constant -5.000000e-01 : f32
+// CHECK-DAG:       %[[VAL_20:.*]] = arith.constant 1.17549435E-38 : f32
+// CHECK-DAG:       %[[VAL_21:.*]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG:       %[[VAL_22:.*]] = arith.constant 0x7F800000 : f32
+// CHECK-DAG:       %[[VAL_23:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK-DAG:       %[[VAL_24:.*]] = arith.constant 0.707106769 : f32
+// CHECK-DAG:       %[[VAL_25:.*]] = arith.constant 0.0703768358 : f32
+// CHECK-DAG:       %[[VAL_26:.*]] = arith.constant -0.115146101 : f32
+// CHECK-DAG:       %[[VAL_27:.*]] = arith.constant 0.116769984 : f32
+// CHECK-DAG:       %[[VAL_28:.*]] = arith.constant -0.12420141 : f32
+// CHECK-DAG:       %[[VAL_29:.*]] = arith.constant 0.142493233 : f32
+// CHECK-DAG:       %[[VAL_30:.*]] = arith.constant -0.166680574 : f32
+// CHECK-DAG:       %[[VAL_31:.*]] = arith.constant 0.200007141 : f32
+// CHECK-DAG:       %[[VAL_32:.*]] = arith.constant -0.24999994 : f32
+// CHECK-DAG:       %[[VAL_33:.*]] = arith.constant 0.333333313 : f32
+// CHECK-DAG:       %[[VAL_34:.*]] = arith.constant 1.260000e+02 : f32
+// CHECK-DAG:       %[[VAL_35:.*]] = arith.constant -2139095041 : i32
+// CHECK-DAG:       %[[VAL_36:.*]] = arith.constant 1056964608 : i32
+// CHECK-DAG:       %[[VAL_37:.*]] = arith.constant 0.693147182 : f32
+// CHECK-DAG:       %[[VAL_38:.*]] = arith.cmpf uge, %[[X]], %[[VAL_12]] : f32
+// CHECK-DAG:       %[[VAL_39:.*]] = arith.select %[[VAL_38]], %[[X]], %[[VAL_12]] : f32
+// CHECK-DAG:       %[[VAL_40:.*]] = arith.cmpf ule, %[[VAL_39]], %[[VAL_13]] : f32
+// CHECK-DAG:       %[[VAL_41:.*]] = arith.select %[[VAL_40]], %[[VAL_39]], %[[VAL_13]] : f32
+// CHECK-DAG:       %[[VAL_42:.*]] = math.fma %[[VAL_41]], %[[VAL_4]], %[[VAL_3]] : f32
+// CHECK-DAG:       %[[VAL_43:.*]] = math.floor %[[VAL_42]] : f32
+// CHECK-DAG:       %[[VAL_44:.*]] = arith.cmpf uge, %[[VAL_43]], %[[VAL_14]] : f32
+// CHECK-DAG:       %[[VAL_45:.*]] = arith.select %[[VAL_44]], %[[VAL_43]], %[[VAL_14]] : f32
+// CHECK-DAG:       %[[VAL_46:.*]] = arith.cmpf ule, %[[VAL_45]], %[[VAL_15]] : f32
+// CHECK-DAG:       %[[VAL_47:.*]] = arith.select %[[VAL_46]], %[[VAL_45]], %[[VAL_15]] : f32
+// CHECK-DAG:       %[[VAL_48:.*]] = math.fma %[[VAL_5]], %[[VAL_47]], %[[VAL_41]] : f32
+// CHECK-DAG:       %[[VAL_49:.*]] = math.fma %[[VAL_6]], %[[VAL_47]], %[[VAL_48]] : f32
+// CHECK-DAG:       %[[VAL_50:.*]] = math.fma %[[VAL_49]], %[[VAL_7]], %[[VAL_8]] : f32
+// CHECK-DAG:       %[[VAL_51:.*]] = math.fma %[[VAL_50]], %[[VAL_49]], %[[VAL_9]] : f32
+// CHECK-DAG:       %[[VAL_52:.*]] = math.fma %[[VAL_51]], %[[VAL_49]], %[[VAL_10]] : f32
+// CHECK-DAG:       %[[VAL_53:.*]] = math.fma %[[VAL_52]], %[[VAL_49]], %[[VAL_11]] : f32
+// CHECK-DAG:       %[[VAL_54:.*]] = math.fma %[[VAL_53]], %[[VAL_49]], %[[VAL_3]] : f32
+// CHECK-DAG:       %[[VAL_55:.*]] = arith.mulf %[[VAL_49]], %[[VAL_49]] : f32
+// CHECK-DAG:       %[[VAL_56:.*]] = math.fma %[[VAL_54]], %[[VAL_55]], %[[VAL_49]] : f32
+// CHECK-DAG:       %[[VAL_57:.*]] = arith.addf %[[VAL_56]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_58:.*]] = arith.fptosi %[[VAL_47]] : f32 to i32
+// CHECK-DAG:       %[[VAL_59:.*]] = arith.addi %[[VAL_58]], %[[VAL_17]] : i32
+// CHECK-DAG:       %[[VAL_60:.*]] = arith.shli %[[VAL_59]], %[[VAL_16]] : i32
+// CHECK-DAG:       %[[VAL_61:.*]] = arith.bitcast %[[VAL_60]] : i32 to f32
+// CHECK-DAG:       %[[VAL_62:.*]] = arith.mulf %[[VAL_57]], %[[VAL_61]] : f32
+// CHECK-DAG:       %[[VAL_63:.*]] = arith.cmpf ueq, %[[VAL_62]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_64:.*]] = arith.subf %[[VAL_62]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_65:.*]] = arith.cmpf oeq, %[[VAL_64]], %[[VAL_2]] : f32
+// CHECK-DAG:       %[[VAL_66:.*]] = arith.cmpf ugt, %[[VAL_62]], %[[VAL_20]] : f32
+// CHECK-DAG:       %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_62]], %[[VAL_20]] : f32
+// CHECK-DAG:       %[[VAL_68:.*]] = arith.bitcast %[[VAL_67]] : f32 to i32
+// CHECK-DAG:       %[[VAL_69:.*]] = arith.andi %[[VAL_68]], %[[VAL_35]] : i32
+// CHECK-DAG:       %[[VAL_70:.*]] = arith.ori %[[VAL_69]], %[[VAL_36]] : i32
+// CHECK-DAG:       %[[VAL_71:.*]] = arith.bitcast %[[VAL_70]] : i32 to f32
+// CHECK-DAG:       %[[VAL_72:.*]] = arith.bitcast %[[VAL_67]] : f32 to i32
+// CHECK-DAG:       %[[VAL_73:.*]] = arith.shrui %[[VAL_72]], %[[VAL_16]] : i32
+// CHECK-DAG:       %[[VAL_74:.*]] = arith.sitofp %[[VAL_73]] : i32 to f32
+// CHECK-DAG:       %[[VAL_75:.*]] = arith.subf %[[VAL_74]], %[[VAL_34]] : f32
+// CHECK-DAG:       %[[VAL_76:.*]] = arith.cmpf olt, %[[VAL_71]], %[[VAL_24]] : f32
+// CHECK-DAG:       %[[VAL_77:.*]] = arith.select %[[VAL_76]], %[[VAL_71]], %[[VAL_18]] : f32
+// CHECK-DAG:       %[[VAL_78:.*]] = arith.subf %[[VAL_71]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_79:.*]] = arith.select %[[VAL_76]], %[[VAL_1]], %[[VAL_18]] : f32
+// CHECK-DAG:       %[[VAL_80:.*]] = arith.subf %[[VAL_75]], %[[VAL_79]] : f32
+// CHECK-DAG:       %[[VAL_81:.*]] = arith.addf %[[VAL_78]], %[[VAL_77]] : f32
+// CHECK-DAG:       %[[VAL_82:.*]] = arith.mulf %[[VAL_81]], %[[VAL_81]] : f32
+// CHECK-DAG:       %[[VAL_83:.*]] = arith.mulf %[[VAL_82]], %[[VAL_81]] : f32
+// CHECK-DAG:       %[[VAL_84:.*]] = math.fma %[[VAL_25]], %[[VAL_81]], %[[VAL_26]] : f32
+// CHECK-DAG:       %[[VAL_85:.*]] = math.fma %[[VAL_28]], %[[VAL_81]], %[[VAL_29]] : f32
+// CHECK-DAG:       %[[VAL_86:.*]] = math.fma %[[VAL_31]], %[[VAL_81]], %[[VAL_32]] : f32
+// CHECK-DAG:       %[[VAL_87:.*]] = math.fma %[[VAL_84]], %[[VAL_81]], %[[VAL_27]] : f32
+// CHECK-DAG:       %[[VAL_88:.*]] = math.fma %[[VAL_85]], %[[VAL_81]], %[[VAL_30]] : f32
+// CHECK-DAG:       %[[VAL_89:.*]] = math.fma %[[VAL_86]], %[[VAL_81]], %[[VAL_33]] : f32
+// CHECK-DAG:       %[[VAL_90:.*]] = math.fma %[[VAL_87]], %[[VAL_83]], %[[VAL_88]] : f32
+// CHECK-DAG:       %[[VAL_91:.*]] = math.fma %[[VAL_90]], %[[VAL_83]], %[[VAL_89]] : f32
+// CHECK-DAG:       %[[VAL_92:.*]] = arith.mulf %[[VAL_91]], %[[VAL_83]] : f32
+// CHECK-DAG:       %[[VAL_93:.*]] = math.fma %[[VAL_19]], %[[VAL_82]], %[[VAL_92]] : f32
+// CHECK-DAG:       %[[VAL_94:.*]] = arith.addf %[[VAL_81]], %[[VAL_93]] : f32
+// CHECK-DAG:       %[[VAL_95:.*]] = math.fma %[[VAL_80]], %[[VAL_37]], %[[VAL_94]] : f32
+// CHECK-DAG:       %[[VAL_96:.*]] = arith.cmpf ult, %[[VAL_62]], %[[VAL_18]] : f32
+// CHECK-DAG:       %[[VAL_97:.*]] = arith.cmpf oeq, %[[VAL_62]], %[[VAL_18]] : f32
+// CHECK-DAG:       %[[VAL_98:.*]] = arith.cmpf oeq, %[[VAL_62]], %[[VAL_22]] : f32
+// CHECK-DAG:       %[[VAL_99:.*]] = arith.select %[[VAL_98]], %[[VAL_22]], %[[VAL_95]] : f32
+// CHECK-DAG:       %[[VAL_100:.*]] = arith.select %[[VAL_96]], %[[VAL_23]], %[[VAL_99]] : f32
+// CHECK-DAG:       %[[VAL_101:.*]] = arith.select %[[VAL_97]], %[[VAL_21]], %[[VAL_100]] : f32
+// CHECK-DAG:       %[[VAL_102:.*]] = arith.cmpf oeq, %[[VAL_101]], %[[VAL_62]] : f32
+// CHECK-DAG:       %[[VAL_103:.*]] = arith.divf %[[X]], %[[VAL_101]] : f32
+// CHECK-DAG:       %[[VAL_104:.*]] = arith.mulf %[[VAL_64]], %[[VAL_103]] : f32
+// CHECK-DAG:       %[[VAL_105:.*]] = arith.select %[[VAL_102]], %[[VAL_62]], %[[VAL_104]] : f32
+// CHECK-DAG:       %[[VAL_106:.*]] = arith.select %[[VAL_65]], %[[VAL_2]], %[[VAL_105]] : f32
+// CHECK-DAG:       %[[VAL_107:.*]] = arith.select %[[VAL_63]], %[[X]], %[[VAL_106]] : f32
+// CHECK-DAG:       return %[[VAL_107]] : f32
 // CHECK:         }
 func.func @expm1_scalar(%arg0: f32) -> f32 {
   %0 = math.expm1 %arg0 : f32
@@ -186,16 +269,9 @@ func.func @expm1_scalar(%arg0: f32) -> f32 {
 
 // CHECK-LABEL:   func @expm1_vector(
 // CHECK-SAME:                       %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32>
 // CHECK-NOT:       exp
-// CHECK-COUNT-5:   select
 // CHECK-NOT:       log
-// CHECK-COUNT-5:   select
 // CHECK-NOT:       expm1
-// CHECK-COUNT-3:   select
-// CHECK:           %[[VAL_115:.*]] = arith.select
-// CHECK:           return %[[VAL_115]] : vector<8x8xf32>
-// CHECK:         }
 func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
   %0 = math.expm1 %arg0 : vector<8x8xf32>
   return %0 : vector<8x8xf32>

diff  --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 130147b01d0a7..8f47fa76c4cfd 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -293,7 +293,7 @@ func.func @exp() {
   %f0 = arith.constant 1.0 : f32
   call @exp_f32(%f0) : (f32) -> ()
 
-  // CHECK: 0.778802, 2.117, 2.71828, 3.85742
+  // CHECK: 0.778801, 2.117, 2.71828, 3.85743
   %v1 = arith.constant dense<[-0.25, 0.75, 1.0, 1.35]> : vector<4xf32>
   call @exp_4xf32(%v1) : (vector<4xf32>) -> ()
 
@@ -301,7 +301,7 @@ func.func @exp() {
   %zero = arith.constant 0.0 : f32
   call @exp_f32(%zero) : (f32) -> ()
 
-  // CHECK: 1.17549e-38, 1.38879e-11, 7.20049e+10, inf
+  // CHECK: 0, 1.38879e-11, 7.20049e+10, inf
   %special_vec = arith.constant dense<[-89.0, -25.0, 25.0, 89.0]> : vector<4xf32>
   call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
@@ -349,7 +349,7 @@ func.func @expm1() {
   %f0 = arith.constant 1.0e-10 : f32
   call @expm1_f32(%f0) : (f32) -> ()
 
-  // CHECK: -0.00995016, 0.0100502, 0.648721, 6.38905
+  // CHECK: -0.00995017, 0.0100502, 0.648721, 6.38906
   %v1 = arith.constant dense<[-0.01, 0.01, 0.5, 2.0]> : vector<4xf32>
   call @expm1_4xf32(%v1) : (vector<4xf32>) -> ()
 
@@ -701,5 +701,3 @@ func.func @main() {
   call @ceilf() : () -> ()
   return
 }
-
-


        


More information about the Mlir-commits mailing list