[Mlir-commits] [mlir] 0bedb66 - [mlir][math] Improved math.atan approximation

Rob Suderman llvmlistbot at llvm.org
Fri Jun 23 17:32:35 PDT 2023


Author: Robert Suderman
Date: 2023-06-23T17:25:34-07:00
New Revision: 0bedb667af1d105a243d5ad08a460f8b218724cb

URL: https://github.com/llvm/llvm-project/commit/0bedb667af1d105a243d5ad08a460f8b218724cb
DIFF: https://github.com/llvm/llvm-project/commit/0bedb667af1d105a243d5ad08a460f8b218724cb.diff

LOG: [mlir][math] Improved math.atan approximation

Use the Cephes rational approximation for `math.atan`. This is a
significant accuracy improvement over the previous low-order polynomial
approximation.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D153656

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir
    mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 070ca0b7170b8..d48a6fae0a30d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include <climits>
+#include <cmath>
 #include <cstddef>
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -171,7 +172,10 @@ static Value floatCst(ImplicitLocOpBuilder &builder, float value,
       builder.getFloatAttr(elementType, value));
 }
 
-static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
-  return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
+// Creates an f32 constant. Takes a double so callers can pass full-precision
+// literals; the value is explicitly rounded to f32 here.
+static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
+  return builder.create<arith::ConstantOp>(
+      builder.getF32FloatAttr(static_cast<float>(value)));
 }
 
@@ -380,35 +384,81 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
 
  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-  auto one = broadcast(builder, f32Cst(builder, 1.0f), shape);
-
-  // Remap the problem over [0.0, 1.0] by looking at the absolute value and the
-  // handling symmetry.
   Value abs = builder.create<math::AbsFOp>(operand);
-  Value reciprocal = builder.create<arith::DivFOp>(one, abs);
-  Value compare =
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
-  Value x = builder.create<arith::SelectOp>(compare, abs, reciprocal);
 
-  // Perform the Taylor series approximation for atan over the range
-  // [-1.0, 1.0].
-  auto n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape);
-  auto n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape);
-  auto n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape);
-  auto n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape);
-
-  Value p = builder.create<math::FmaOp>(x, n1, n2);
-  p = builder.create<math::FmaOp>(x, p, n3);
-  p = builder.create<math::FmaOp>(x, p, n4);
-  p = builder.create<arith::MulFOp>(x, p);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+  auto one = bcast(f32Cst(builder, 1.0));
 
-  // Remap the solution for over [0.0, 1.0] to [0.0, inf]
-  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
-  Value sub = builder.create<arith::SubFOp>(halfPi, p);
-  Value select = builder.create<arith::SelectOp>(compare, p, sub);
+  // Cephes range reduction on |x|. For 0.66 < |x| <= tan(3*pi/8) use
+  //   atan(|x|) = pi/4 + atan((|x| - 1) / (|x| + 1)).
+  auto lowCutoff = bcast(f32Cst(builder, 0.66));
+  Value cmp2 =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, lowCutoff);
+  Value addone = builder.create<arith::AddFOp>(abs, one);
+  Value subone = builder.create<arith::SubFOp>(abs, one);
+  Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
+  Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
+
+  // For |x| > tan(3*pi/8) ~= 2.414 use atan(|x|) = pi/2 - atan(1 / |x|).
+  // When |x| <= 0.66 neither rule applies and |x| is used unchanged.
+  auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
+  Value cmp1 =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
+  xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
+  xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
+
+  Value x = builder.create<arith::DivFOp>(xnum, xden);
+  Value xx = builder.create<arith::MulFOp>(x, x);
+
+  // Cephes rational approximation over the reduced range [-0.21, 0.66]:
+  //   atan(x) ~= x + x * (x^2 * P(x^2) / Q(x^2)),
+  // where P has degree 4 and Q is monic of degree 5.
+  auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
+  auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
+  auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
+  auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
+  auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
+  auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
+  auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
+  auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
+  auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
+  auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));
+
+  // Apply the polynomial approximation for the numerator, x^2 * P(x^2):
+  Value n = p0;
+  n = builder.create<math::FmaOp>(xx, n, p1);
+  n = builder.create<math::FmaOp>(xx, n, p2);
+  n = builder.create<math::FmaOp>(xx, n, p3);
+  n = builder.create<math::FmaOp>(xx, n, p4);
+  n = builder.create<arith::MulFOp>(n, xx);
+
+  // Apply the polynomial approximation for the denominator. Q is monic: its
+  // leading coefficient 1.0 is implicit in the Cephes table (p1evl), so the
+  // Horner chain starts from xx + q0 rather than q0.
+  Value d = builder.create<arith::AddFOp>(xx, q0);
+  d = builder.create<math::FmaOp>(xx, d, q1);
+  d = builder.create<math::FmaOp>(xx, d, q2);
+  d = builder.create<math::FmaOp>(xx, d, q3);
+  d = builder.create<math::FmaOp>(xx, d, q4);
+
+  // Compute approximation of theta:
+  Value ans0 = builder.create<arith::DivFOp>(n, d);
+  ans0 = builder.create<math::FmaOp>(ans0, x, x);
+
+  // Correct for the input mapping's angles. The constants are spelled out
+  // because M_PI_4 / M_PI_2 are POSIX extensions, not standard C++.
+  Value mpi4 = bcast(f32Cst(builder, 0.78539816339744830962)); // pi / 4
+  Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
+  Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
+
+  Value mpi2 = bcast(f32Cst(builder, 1.57079632679489661923)); // pi / 2
+  Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
+  ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
 
   // Correct for signing of the input.
-  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
+  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 3c87ecf72011d..6743998d6162c 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -587,24 +587,51 @@ func.func @rsqrt_vector_2x16xf32(%arg0: vector<2x16xf32>) -> vector<2x16xf32> {
 }
 
 // CHECK-LABEL: @atan_scalar
-// CHECK-DAG:  %[[ONE:.+]] = arith.constant 1.000000e+00
-// CHECK-DAG:  %[[N1:.+]] = arith.constant 0.144182831
-// CHECK-DAG:  %[[N2:.+]] = arith.constant -0.349992335
-// CHECK-DAG:  %[[N3:.+]] = arith.constant -0.0106783099
-// CHECK-DAG:  %[[N4:.+]] = arith.constant 1.00209987
-// CHECK-DAG:  %[[HALF_PI:.+]] = arith.constant 1.57079637
-// CHECK-DAG:  %[[ABS:.+]] = math.absf %arg0
-// CHECK-DAG:  %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
-// CHECK-DAG:  %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
-// CHECK-DAG:  %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]]
-// CHECK-DAG:  %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
-// CHECK-DAG:  %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
-// CHECK-DAG:  %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
-// CHECK-DAG:  %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
-// CHECK-DAG:  %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
-// CHECK-DAG:  %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]]
-// CHECK-DAG:  %[[RES:.+]] = math.copysign %[[EST]], %arg0
-// CHECK:  return %[[RES]]
+// CHECK-SAME:      %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 6.600000e-01 : f32
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2.41421366 : f32
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant -0.875060856 : f32
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant -16.1575375 : f32
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant -75.0085601 : f32
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant -122.886665 : f32
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant -64.8502197 : f32
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 24.8584652 : f32
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 165.027008 : f32
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 432.881073 : f32
+// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 485.390411 : f32
+// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 194.550659 : f32
+// CHECK-DAG:       %[[VAL_14:.*]] = arith.constant 0.785398185 : f32
+// CHECK-DAG:       %[[VAL_15:.*]] = arith.constant 1.57079637 : f32
+// CHECK-DAG:       %[[VAL_16:.*]] = math.absf %[[VAL_0]] : f32
+// CHECK-DAG:       %[[VAL_17:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_2]] : f32
+// CHECK-DAG:       %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_19:.*]] = arith.subf %[[VAL_16]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_16]] : f32
+// CHECK-DAG:       %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_1]] : f32
+// CHECK-DAG:       %[[VAL_22:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_3]] : f32
+// CHECK-DAG:       %[[VAL_23:.*]] = arith.select %[[VAL_22]], %[[VAL_1]], %[[VAL_20]] : f32
+// CHECK-DAG:       %[[VAL_24:.*]] = arith.select %[[VAL_22]], %[[VAL_16]], %[[VAL_21]] : f32
+// CHECK-DAG:       %[[VAL_25:.*]] = arith.divf %[[VAL_23]], %[[VAL_24]] : f32
+// CHECK-DAG:       %[[VAL_26:.*]] = arith.mulf %[[VAL_25]], %[[VAL_25]] : f32
+// CHECK-DAG:       %[[VAL_27:.*]] = math.fma %[[VAL_26]], %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK-DAG:       %[[VAL_28:.*]] = math.fma %[[VAL_26]], %[[VAL_27]], %[[VAL_6]] : f32
+// CHECK-DAG:       %[[VAL_29:.*]] = math.fma %[[VAL_26]], %[[VAL_28]], %[[VAL_7]] : f32
+// CHECK-DAG:       %[[VAL_30:.*]] = math.fma %[[VAL_26]], %[[VAL_29]], %[[VAL_8]] : f32
+// CHECK-DAG:       %[[VAL_31:.*]] = arith.mulf %[[VAL_30]], %[[VAL_26]] : f32
+// CHECK-DAG:       %[[DEN_0:.*]] = arith.addf %[[VAL_26]], %[[VAL_9]] : f32
+// CHECK-DAG:       %[[VAL_32:.*]] = math.fma %[[VAL_26]], %[[DEN_0]], %[[VAL_10]] : f32
+// CHECK-DAG:       %[[VAL_33:.*]] = math.fma %[[VAL_26]], %[[VAL_32]], %[[VAL_11]] : f32
+// CHECK-DAG:       %[[VAL_34:.*]] = math.fma %[[VAL_26]], %[[VAL_33]], %[[VAL_12]] : f32
+// CHECK-DAG:       %[[VAL_35:.*]] = math.fma %[[VAL_26]], %[[VAL_34]], %[[VAL_13]] : f32
+// CHECK-DAG:       %[[VAL_36:.*]] = arith.divf %[[VAL_31]], %[[VAL_35]] : f32
+// CHECK-DAG:       %[[VAL_37:.*]] = math.fma %[[VAL_36]], %[[VAL_25]], %[[VAL_25]] : f32
+// CHECK-DAG:       %[[VAL_38:.*]] = arith.addf %[[VAL_37]], %[[VAL_14]] : f32
+// CHECK-DAG:       %[[VAL_39:.*]] = arith.select %[[VAL_17]], %[[VAL_38]], %[[VAL_37]] : f32
+// CHECK-DAG:       %[[VAL_40:.*]] = arith.subf %[[VAL_15]], %[[VAL_37]] : f32
+// CHECK-DAG:       %[[VAL_41:.*]] = arith.select %[[VAL_22]], %[[VAL_40]], %[[VAL_39]] : f32
+// CHECK-DAG:       %[[VAL_42:.*]] = math.copysign %[[VAL_41]], %[[VAL_0]] : f32
+// CHECK:           return %[[VAL_42]] : f32
 func.func @atan_scalar(%arg0: f32) -> f32 {
   %0 = math.atan %arg0 : f32
   return %0 : f32
@@ -612,59 +639,76 @@ func.func @atan_scalar(%arg0: f32) -> f32 {
 
 
 // CHECK-LABEL: @atan2_scalar
-
-// ATan approximation:
-// CHECK-DAG:  %[[ONE:.+]] = arith.constant 1.000000e+00
-// CHECK-DAG:  %[[N1:.+]] = arith.constant 0.144182831
-// CHECK-DAG:  %[[N2:.+]] = arith.constant -0.349992335
-// CHECK-DAG:  %[[N3:.+]] = arith.constant -0.0106783099
-// CHECK-DAG:  %[[N4:.+]] = arith.constant 1.00209987
-// CHECK-DAG:  %[[HALF_PI:.+]] = arith.constant 1.57079637
-// CHECK-DAG:  %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32
-// CHECK-DAG:  %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32
-// CHECK-DAG:  %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]]
-// CHECK-DAG:  %[[ABS:.+]] = math.absf %[[RATIO]]
-// CHECK-DAG:  %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
-// CHECK-DAG:  %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
-// CHECK-DAG:  %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]]
-// CHECK-DAG:  %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
-// CHECK-DAG:  %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
-// CHECK-DAG:  %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
-// CHECK-DAG:  %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
-// CHECK-DAG:  %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
-// CHECK-DAG:  %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]]
-// CHECK-DAG:  %[[ATAN:.+]] = math.copysign %[[EST]], %[[RATIO]]
-
-// Handle the case of x < 0:
-// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00
-// CHECK-DAG: %[[PI:.+]] = arith.constant 3.14159274
-// CHECK-DAG: %[[ADD_PI:.+]] = arith.addf %[[ATAN]], %[[PI]]
-// CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]]
-// CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]]
-// CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]]
-// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]]
-// CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]]
-
-// Handle PI / 2 edge case:
-// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]]
-// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]]
-// CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
-// CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
-
-// Handle -PI / 2 edge case:
-// CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
-// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]]
-// CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
-// CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
-
-// Handle Nan edgecase:
-// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]]
-// CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
-// CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
-// CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
-// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]]
-// CHECK: return %[[RET]]
-
+// CHECK-SAME:      %[[VAL_0:.*]]: f16,
+// CHECK-SAME:      %[[VAL_1:.*]]: f16)
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 6.600000e-01 : f32
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2.41421366 : f32
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant -0.875060856 : f32
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant -16.1575375 : f32
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant -75.0085601 : f32
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant -122.886665 : f32
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant -64.8502197 : f32
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 24.8584652 : f32
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 165.027008 : f32
+// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 432.881073 : f32
+// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 485.390411 : f32
+// CHECK-DAG:       %[[VAL_14:.*]] = arith.constant 194.550659 : f32
+// CHECK-DAG:       %[[VAL_15:.*]] = arith.constant 0.785398185 : f32
+// CHECK-DAG:       %[[VAL_16:.*]] = arith.constant 1.57079637 : f32
+// CHECK-DAG:       %[[VAL_17:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_18:.*]] = arith.constant 3.14159274 : f32
+// CHECK-DAG:       %[[VAL_19:.*]] = arith.constant -1.57079637 : f32
+// CHECK-DAG:       %[[VAL_20:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK-DAG:       %[[VAL_21:.*]] = arith.extf %[[VAL_0]] : f16 to f32
+// CHECK-DAG:       %[[VAL_22:.*]] = arith.extf %[[VAL_1]] : f16 to f32
+// CHECK-DAG:       %[[VAL_23:.*]] = arith.divf %[[VAL_21]], %[[VAL_22]] : f32
+// CHECK-DAG:       %[[VAL_24:.*]] = math.absf %[[VAL_23]] : f32
+// CHECK-DAG:       %[[VAL_25:.*]] = arith.cmpf ogt, %[[VAL_24]], %[[VAL_3]] : f32
+// CHECK-DAG:       %[[VAL_26:.*]] = arith.addf %[[VAL_24]], %[[VAL_2]] : f32
+// CHECK-DAG:       %[[VAL_27:.*]] = arith.subf %[[VAL_24]], %[[VAL_2]] : f32
+// CHECK-DAG:       %[[VAL_28:.*]] = arith.select %[[VAL_25]], %[[VAL_27]], %[[VAL_24]] : f32
+// CHECK-DAG:       %[[VAL_29:.*]] = arith.select %[[VAL_25]], %[[VAL_26]], %[[VAL_2]] : f32
+// CHECK-DAG:       %[[VAL_30:.*]] = arith.cmpf ogt, %[[VAL_24]], %[[VAL_4]] : f32
+// CHECK-DAG:       %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_2]], %[[VAL_28]] : f32
+// CHECK-DAG:       %[[VAL_32:.*]] = arith.select %[[VAL_30]], %[[VAL_24]], %[[VAL_29]] : f32
+// CHECK-DAG:       %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[VAL_32]] : f32
+// CHECK-DAG:       %[[VAL_34:.*]] = arith.mulf %[[VAL_33]], %[[VAL_33]] : f32
+// CHECK-DAG:       %[[VAL_35:.*]] = math.fma %[[VAL_34]], %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK-DAG:       %[[VAL_36:.*]] = math.fma %[[VAL_34]], %[[VAL_35]], %[[VAL_7]] : f32
+// CHECK-DAG:       %[[VAL_37:.*]] = math.fma %[[VAL_34]], %[[VAL_36]], %[[VAL_8]] : f32
+// CHECK-DAG:       %[[VAL_38:.*]] = math.fma %[[VAL_34]], %[[VAL_37]], %[[VAL_9]] : f32
+// CHECK-DAG:       %[[VAL_39:.*]] = arith.mulf %[[VAL_38]], %[[VAL_34]] : f32
+// CHECK-DAG:       %[[DEN_0:.*]] = arith.addf %[[VAL_34]], %[[VAL_10]] : f32
+// CHECK-DAG:       %[[VAL_40:.*]] = math.fma %[[VAL_34]], %[[DEN_0]], %[[VAL_11]] : f32
+// CHECK-DAG:       %[[VAL_41:.*]] = math.fma %[[VAL_34]], %[[VAL_40]], %[[VAL_12]] : f32
+// CHECK-DAG:       %[[VAL_42:.*]] = math.fma %[[VAL_34]], %[[VAL_41]], %[[VAL_13]] : f32
+// CHECK-DAG:       %[[VAL_43:.*]] = math.fma %[[VAL_34]], %[[VAL_42]], %[[VAL_14]] : f32
+// CHECK-DAG:       %[[VAL_44:.*]] = arith.divf %[[VAL_39]], %[[VAL_43]] : f32
+// CHECK-DAG:       %[[VAL_45:.*]] = math.fma %[[VAL_44]], %[[VAL_33]], %[[VAL_33]] : f32
+// CHECK-DAG:       %[[VAL_46:.*]] = arith.addf %[[VAL_45]], %[[VAL_15]] : f32
+// CHECK-DAG:       %[[VAL_47:.*]] = arith.select %[[VAL_25]], %[[VAL_46]], %[[VAL_45]] : f32
+// CHECK-DAG:       %[[VAL_48:.*]] = arith.subf %[[VAL_16]], %[[VAL_45]] : f32
+// CHECK-DAG:       %[[VAL_49:.*]] = arith.select %[[VAL_30]], %[[VAL_48]], %[[VAL_47]] : f32
+// CHECK-DAG:       %[[VAL_50:.*]] = math.copysign %[[VAL_49]], %[[VAL_23]] : f32
+// CHECK-DAG:       %[[VAL_51:.*]] = arith.addf %[[VAL_50]], %[[VAL_18]] : f32
+// CHECK-DAG:       %[[VAL_52:.*]] = arith.subf %[[VAL_50]], %[[VAL_18]] : f32
+// CHECK-DAG:       %[[VAL_53:.*]] = arith.cmpf ogt, %[[VAL_50]], %[[VAL_17]] : f32
+// CHECK-DAG:       %[[VAL_54:.*]] = arith.select %[[VAL_53]], %[[VAL_52]], %[[VAL_51]] : f32
+// CHECK-DAG:       %[[VAL_55:.*]] = arith.cmpf ogt, %[[VAL_22]], %[[VAL_17]] : f32
+// CHECK-DAG:       %[[VAL_56:.*]] = arith.select %[[VAL_55]], %[[VAL_50]], %[[VAL_54]] : f32
+// CHECK-DAG:       %[[VAL_57:.*]] = arith.cmpf oeq, %[[VAL_22]], %[[VAL_17]] : f32
+// CHECK-DAG:       %[[VAL_58:.*]] = arith.cmpf ogt, %[[VAL_21]], %[[VAL_17]] : f32
+// CHECK-DAG:       %[[VAL_59:.*]] = arith.andi %[[VAL_57]], %[[VAL_58]] : i1
+// CHECK-DAG:       %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_16]], %[[VAL_56]] : f32
+// CHECK-DAG:       %[[VAL_61:.*]] = arith.cmpf olt, %[[VAL_21]], %[[VAL_17]] : f32
+// CHECK-DAG:       %[[VAL_62:.*]] = arith.andi %[[VAL_57]], %[[VAL_61]] : i1
+// CHECK-DAG:       %[[VAL_63:.*]] = arith.select %[[VAL_62]], %[[VAL_19]], %[[VAL_60]] : f32
+// CHECK-DAG:       %[[VAL_64:.*]] = arith.cmpf oeq, %[[VAL_21]], %[[VAL_17]] : f32
+// CHECK-DAG:       %[[VAL_65:.*]] = arith.andi %[[VAL_57]], %[[VAL_64]] : i1
+// CHECK-DAG:       %[[VAL_66:.*]] = arith.select %[[VAL_65]], %[[VAL_20]], %[[VAL_63]] : f32
+// CHECK-DAG:       %[[VAL_67:.*]] = arith.truncf %[[VAL_66]] : f32 to f16
+// CHECK:           return %[[VAL_67]] : f16
 func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
   %0 = math.atan2 %arg0, %arg1 : f16
   return %0 : f16

diff  --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 8f47fa76c4cfd..058ebb28dff27 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -471,19 +471,19 @@ func.func @atan_f32(%a : f32) {
 }
 
 func.func @atan() {
-  // CHECK: -0.785184
+  // CHECK: -0.785398
   %0 = arith.constant -1.0 : f32
   call @atan_f32(%0) : (f32) -> ()
 
-  // CHECK: 0.785184
+  // CHECK: 0.785398
   %1 = arith.constant 1.0 : f32
   call @atan_f32(%1) : (f32) -> ()
 
-  // CHECK: -0.463643
+  // CHECK: -0.463648
   %2 = arith.constant -0.5 : f32
   call @atan_f32(%2) : (f32) -> ()
 
-  // CHECK: 0.463643
+  // CHECK: 0.463648
   %3 = arith.constant 0.5 : f32
   call @atan_f32(%3) : (f32) -> ()
 
@@ -548,7 +548,7 @@ func.func @atan2() {
   // CHECK: -1.10715
   call @atan2_f32(%neg_two, %one) : (f32, f32) -> ()
 
-  // CHECK: 0.463643
+  // CHECK: 0.463648
   call @atan2_f32(%one, %two) : (f32, f32) -> ()
 
   // CHECK: 2.67795
@@ -561,7 +561,7 @@ func.func @atan2() {
   %y11 = arith.constant -1.0 : f32
   call @atan2_f32(%neg_one, %neg_two) : (f32, f32) -> ()
 
-  // CHECK: -0.463643
+  // CHECK: -0.463648
   call @atan2_f32(%neg_one, %two) : (f32, f32) -> ()
 
   return


        


More information about the Mlir-commits mailing list