[Mlir-commits] [mlir] 0bedb66 - [mlir][math] Improved math.atan approximation
Rob Suderman
llvmlistbot at llvm.org
Fri Jun 23 17:32:35 PDT 2023
Author: Robert Suderman
Date: 2023-06-23T17:25:34-07:00
New Revision: 0bedb667af1d105a243d5ad08a460f8b218724cb
URL: https://github.com/llvm/llvm-project/commit/0bedb667af1d105a243d5ad08a460f8b218724cb
DIFF: https://github.com/llvm/llvm-project/commit/0bedb667af1d105a243d5ad08a460f8b218724cb.diff
LOG: [mlir][math] Improved math.atan approximation
Used the Cephes numerical approximation for `math.atan`. This is a
significant accuracy improvement over the previous Taylor series
approximation.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D153656
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 070ca0b7170b8..d48a6fae0a30d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include <climits>
+#include <cmath>
#include <cstddef>
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -171,7 +172,7 @@ static Value floatCst(ImplicitLocOpBuilder &builder, float value,
builder.getFloatAttr(elementType, value));
}
-static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
+static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
}
@@ -380,35 +381,76 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
- auto one = broadcast(builder, f32Cst(builder, 1.0f), shape);
-
- // Remap the problem over [0.0, 1.0] by looking at the absolute value and the
- // handling symmetry.
Value abs = builder.create<math::AbsFOp>(operand);
- Value reciprocal = builder.create<arith::DivFOp>(one, abs);
- Value compare =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
- Value x = builder.create<arith::SelectOp>(compare, abs, reciprocal);
- // Perform the Taylor series approximation for atan over the range
- // [-1.0, 1.0].
- auto n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape);
- auto n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape);
- auto n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape);
- auto n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape);
-
- Value p = builder.create<math::FmaOp>(x, n1, n2);
- p = builder.create<math::FmaOp>(x, p, n3);
- p = builder.create<math::FmaOp>(x, p, n4);
- p = builder.create<arith::MulFOp>(x, p);
+ auto one = broadcast(builder, f32Cst(builder, 1.0), shape);
- // Remap the solution for over [0.0, 1.0] to [0.0, inf]
- auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
- Value sub = builder.create<arith::SubFOp>(halfPi, p);
- Value select = builder.create<arith::SelectOp>(compare, p, sub);
+ // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
+ auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
+ Value cmp2 =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
+ Value addone = builder.create<arith::AddFOp>(abs, one);
+ Value subone = builder.create<arith::SubFOp>(abs, one);
+ Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
+ Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
+
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+
+ // Break into the <= 0.66 or > 2.41 we do x or 1/x:
+ auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
+ Value cmp1 =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
+ xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
+ xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
+
+ Value x = builder.create<arith::DivFOp>(xnum, xden);
+ Value xx = builder.create<arith::MulFOp>(x, x);
+
+ // Perform the Taylor series approximation for atan over the range
+ // [0.0, 0.66].
+ auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
+ auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
+ auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
+ auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
+ auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
+ auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
+ auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
+ auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
+ auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
+ auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));
+
+ // Apply the polynomial approximation for the numerator:
+ Value n = p0;
+ n = builder.create<math::FmaOp>(xx, n, p1);
+ n = builder.create<math::FmaOp>(xx, n, p2);
+ n = builder.create<math::FmaOp>(xx, n, p3);
+ n = builder.create<math::FmaOp>(xx, n, p4);
+ n = builder.create<arith::MulFOp>(n, xx);
+
+ // Apply the polynomial approximation for the denominator:
+ Value d = q0;
+ d = builder.create<math::FmaOp>(xx, d, q1);
+ d = builder.create<math::FmaOp>(xx, d, q2);
+ d = builder.create<math::FmaOp>(xx, d, q3);
+ d = builder.create<math::FmaOp>(xx, d, q4);
+
+ // Compute approximation of theta:
+ Value ans0 = builder.create<arith::DivFOp>(n, d);
+ ans0 = builder.create<math::FmaOp>(ans0, x, x);
+
+ // Correct for the input mapping's angles:
+ Value mpi4 = bcast(f32Cst(builder, M_PI_4));
+ Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
+ Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
+
+ Value mpi2 = bcast(f32Cst(builder, M_PI_2));
+ Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
+ ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
// Correct for signing of the input.
- rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
+ rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
return success();
}
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 3c87ecf72011d..6743998d6162c 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -587,24 +587,50 @@ func.func @rsqrt_vector_2x16xf32(%arg0: vector<2x16xf32>) -> vector<2x16xf32> {
}
// CHECK-LABEL: @atan_scalar
-// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00
-// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831
-// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335
-// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
-// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
-// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
-// CHECK-DAG: %[[ABS:.+]] = math.absf %arg0
-// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
-// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
-// CHECK-DAG: %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]]
-// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
-// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
-// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
-// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
-// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
-// CHECK-DAG: %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]]
-// CHECK-DAG: %[[RES:.+]] = math.copysign %[[EST]], %arg0
-// CHECK: return %[[RES]]
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6.600000e-01 : f32
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2.41421366 : f32
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant -0.875060856 : f32
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant -16.1575375 : f32
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant -75.0085601 : f32
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant -122.886665 : f32
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant -64.8502197 : f32
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 24.8584652 : f32
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 165.027008 : f32
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 432.881073 : f32
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 485.390411 : f32
+// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 194.550659 : f32
+// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 0.785398185 : f32
+// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 1.57079637 : f32
+// CHECK-DAG: %[[VAL_16:.*]] = math.absf %[[VAL_0]] : f32
+// CHECK-DAG: %[[VAL_17:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_2]] : f32
+// CHECK-DAG: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_1]] : f32
+// CHECK-DAG: %[[VAL_19:.*]] = arith.subf %[[VAL_16]], %[[VAL_1]] : f32
+// CHECK-DAG: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_16]] : f32
+// CHECK-DAG: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_1]] : f32
+// CHECK-DAG: %[[VAL_22:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_3]] : f32
+// CHECK-DAG: %[[VAL_23:.*]] = arith.select %[[VAL_22]], %[[VAL_1]], %[[VAL_20]] : f32
+// CHECK-DAG: %[[VAL_24:.*]] = arith.select %[[VAL_22]], %[[VAL_16]], %[[VAL_21]] : f32
+// CHECK-DAG: %[[VAL_25:.*]] = arith.divf %[[VAL_23]], %[[VAL_24]] : f32
+// CHECK-DAG: %[[VAL_26:.*]] = arith.mulf %[[VAL_25]], %[[VAL_25]] : f32
+// CHECK-DAG: %[[VAL_27:.*]] = math.fma %[[VAL_26]], %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK-DAG: %[[VAL_28:.*]] = math.fma %[[VAL_26]], %[[VAL_27]], %[[VAL_6]] : f32
+// CHECK-DAG: %[[VAL_29:.*]] = math.fma %[[VAL_26]], %[[VAL_28]], %[[VAL_7]] : f32
+// CHECK-DAG: %[[VAL_30:.*]] = math.fma %[[VAL_26]], %[[VAL_29]], %[[VAL_8]] : f32
+// CHECK-DAG: %[[VAL_31:.*]] = arith.mulf %[[VAL_30]], %[[VAL_26]] : f32
+// CHECK-DAG: %[[VAL_32:.*]] = math.fma %[[VAL_26]], %[[VAL_9]], %[[VAL_10]] : f32
+// CHECK-DAG: %[[VAL_33:.*]] = math.fma %[[VAL_26]], %[[VAL_32]], %[[VAL_11]] : f32
+// CHECK-DAG: %[[VAL_34:.*]] = math.fma %[[VAL_26]], %[[VAL_33]], %[[VAL_12]] : f32
+// CHECK-DAG: %[[VAL_35:.*]] = math.fma %[[VAL_26]], %[[VAL_34]], %[[VAL_13]] : f32
+// CHECK-DAG: %[[VAL_36:.*]] = arith.divf %[[VAL_31]], %[[VAL_35]] : f32
+// CHECK-DAG: %[[VAL_37:.*]] = math.fma %[[VAL_36]], %[[VAL_25]], %[[VAL_25]] : f32
+// CHECK-DAG: %[[VAL_38:.*]] = arith.addf %[[VAL_37]], %[[VAL_14]] : f32
+// CHECK-DAG: %[[VAL_39:.*]] = arith.select %[[VAL_17]], %[[VAL_38]], %[[VAL_37]] : f32
+// CHECK-DAG: %[[VAL_40:.*]] = arith.subf %[[VAL_15]], %[[VAL_37]] : f32
+// CHECK-DAG: %[[VAL_41:.*]] = arith.select %[[VAL_22]], %[[VAL_40]], %[[VAL_39]] : f32
+// CHECK-DAG: %[[VAL_42:.*]] = math.copysign %[[VAL_41]], %[[VAL_0]] : f32
+// CHECK: return %[[VAL_42]] : f3
func.func @atan_scalar(%arg0: f32) -> f32 {
%0 = math.atan %arg0 : f32
return %0 : f32
@@ -612,59 +638,75 @@ func.func @atan_scalar(%arg0: f32) -> f32 {
// CHECK-LABEL: @atan2_scalar
-
-// ATan approximation:
-// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00
-// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831
-// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335
-// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
-// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
-// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
-// CHECK-DAG: %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32
-// CHECK-DAG: %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32
-// CHECK-DAG: %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]]
-// CHECK-DAG: %[[ABS:.+]] = math.absf %[[RATIO]]
-// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
-// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
-// CHECK-DAG: %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]]
-// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
-// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
-// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
-// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
-// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
-// CHECK-DAG: %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]]
-// CHECK-DAG: %[[ATAN:.+]] = math.copysign %[[EST]], %[[RATIO]]
-
-// Handle the case of x < 0:
-// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00
-// CHECK-DAG: %[[PI:.+]] = arith.constant 3.14159274
-// CHECK-DAG: %[[ADD_PI:.+]] = arith.addf %[[ATAN]], %[[PI]]
-// CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]]
-// CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]]
-// CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]]
-// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]]
-// CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]]
-
-// Handle PI / 2 edge case:
-// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]]
-// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]]
-// CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
-// CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
-
-// Handle -PI / 2 edge case:
-// CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
-// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]]
-// CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
-// CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
-
-// Handle Nan edgecase:
-// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]]
-// CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
-// CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
-// CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
-// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]]
-// CHECK: return %[[RET]]
-
+// CHECK-SAME: %[[VAL_0:.*]]: f16,
+// CHECK-SAME: %[[VAL_1:.*]]: f16)
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6.600000e-01 : f32
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2.41421366 : f32
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant -0.875060856 : f32
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant -16.1575375 : f32
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant -75.0085601 : f32
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant -122.886665 : f32
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant -64.8502197 : f32
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 24.8584652 : f32
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 165.027008 : f32
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 432.881073 : f32
+// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 485.390411 : f32
+// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 194.550659 : f32
+// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 0.785398185 : f32
+// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 1.57079637 : f32
+// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 3.14159274 : f32
+// CHECK-DAG: %[[VAL_19:.*]] = arith.constant -1.57079637 : f32
+// CHECK-DAG: %[[VAL_20:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK-DAG: %[[VAL_21:.*]] = arith.extf %[[VAL_0]] : f16 to f32
+// CHECK-DAG: %[[VAL_22:.*]] = arith.extf %[[VAL_1]] : f16 to f32
+// CHECK-DAG: %[[VAL_23:.*]] = arith.divf %[[VAL_21]], %[[VAL_22]] : f32
+// CHECK-DAG: %[[VAL_24:.*]] = math.absf %[[VAL_23]] : f32
+// CHECK-DAG: %[[VAL_25:.*]] = arith.cmpf ogt, %[[VAL_24]], %[[VAL_3]] : f32
+// CHECK-DAG: %[[VAL_26:.*]] = arith.addf %[[VAL_24]], %[[VAL_2]] : f32
+// CHECK-DAG: %[[VAL_27:.*]] = arith.subf %[[VAL_24]], %[[VAL_2]] : f32
+// CHECK-DAG: %[[VAL_28:.*]] = arith.select %[[VAL_25]], %[[VAL_27]], %[[VAL_24]] : f32
+// CHECK-DAG: %[[VAL_29:.*]] = arith.select %[[VAL_25]], %[[VAL_26]], %[[VAL_2]] : f32
+// CHECK-DAG: %[[VAL_30:.*]] = arith.cmpf ogt, %[[VAL_24]], %[[VAL_4]] : f32
+// CHECK-DAG: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_2]], %[[VAL_28]] : f32
+// CHECK-DAG: %[[VAL_32:.*]] = arith.select %[[VAL_30]], %[[VAL_24]], %[[VAL_29]] : f32
+// CHECK-DAG: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[VAL_32]] : f32
+// CHECK-DAG: %[[VAL_34:.*]] = arith.mulf %[[VAL_33]], %[[VAL_33]] : f32
+// CHECK-DAG: %[[VAL_35:.*]] = math.fma %[[VAL_34]], %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK-DAG: %[[VAL_36:.*]] = math.fma %[[VAL_34]], %[[VAL_35]], %[[VAL_7]] : f32
+// CHECK-DAG: %[[VAL_37:.*]] = math.fma %[[VAL_34]], %[[VAL_36]], %[[VAL_8]] : f32
+// CHECK-DAG: %[[VAL_38:.*]] = math.fma %[[VAL_34]], %[[VAL_37]], %[[VAL_9]] : f32
+// CHECK-DAG: %[[VAL_39:.*]] = arith.mulf %[[VAL_38]], %[[VAL_34]] : f32
+// CHECK-DAG: %[[VAL_40:.*]] = math.fma %[[VAL_34]], %[[VAL_10]], %[[VAL_11]] : f32
+// CHECK-DAG: %[[VAL_41:.*]] = math.fma %[[VAL_34]], %[[VAL_40]], %[[VAL_12]] : f32
+// CHECK-DAG: %[[VAL_42:.*]] = math.fma %[[VAL_34]], %[[VAL_41]], %[[VAL_13]] : f32
+// CHECK-DAG: %[[VAL_43:.*]] = math.fma %[[VAL_34]], %[[VAL_42]], %[[VAL_14]] : f32
+// CHECK-DAG: %[[VAL_44:.*]] = arith.divf %[[VAL_39]], %[[VAL_43]] : f32
+// CHECK-DAG: %[[VAL_45:.*]] = math.fma %[[VAL_44]], %[[VAL_33]], %[[VAL_33]] : f32
+// CHECK-DAG: %[[VAL_46:.*]] = arith.addf %[[VAL_45]], %[[VAL_15]] : f32
+// CHECK-DAG: %[[VAL_47:.*]] = arith.select %[[VAL_25]], %[[VAL_46]], %[[VAL_45]] : f32
+// CHECK-DAG: %[[VAL_48:.*]] = arith.subf %[[VAL_16]], %[[VAL_45]] : f32
+// CHECK-DAG: %[[VAL_49:.*]] = arith.select %[[VAL_30]], %[[VAL_48]], %[[VAL_47]] : f32
+// CHECK-DAG: %[[VAL_50:.*]] = math.copysign %[[VAL_49]], %[[VAL_23]] : f32
+// CHECK-DAG: %[[VAL_51:.*]] = arith.addf %[[VAL_50]], %[[VAL_18]] : f32
+// CHECK-DAG: %[[VAL_52:.*]] = arith.subf %[[VAL_50]], %[[VAL_18]] : f32
+// CHECK-DAG: %[[VAL_53:.*]] = arith.cmpf ogt, %[[VAL_50]], %[[VAL_17]] : f32
+// CHECK-DAG: %[[VAL_54:.*]] = arith.select %[[VAL_53]], %[[VAL_52]], %[[VAL_51]] : f32
+// CHECK-DAG: %[[VAL_55:.*]] = arith.cmpf ogt, %[[VAL_22]], %[[VAL_17]] : f32
+// CHECK-DAG: %[[VAL_56:.*]] = arith.select %[[VAL_55]], %[[VAL_50]], %[[VAL_54]] : f32
+// CHECK-DAG: %[[VAL_57:.*]] = arith.cmpf oeq, %[[VAL_22]], %[[VAL_17]] : f32
+// CHECK-DAG: %[[VAL_58:.*]] = arith.cmpf ogt, %[[VAL_21]], %[[VAL_17]] : f32
+// CHECK-DAG: %[[VAL_59:.*]] = arith.andi %[[VAL_57]], %[[VAL_58]] : i1
+// CHECK-DAG: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_16]], %[[VAL_56]] : f32
+// CHECK-DAG: %[[VAL_61:.*]] = arith.cmpf olt, %[[VAL_21]], %[[VAL_17]] : f32
+// CHECK-DAG: %[[VAL_62:.*]] = arith.andi %[[VAL_57]], %[[VAL_61]] : i1
+// CHECK-DAG: %[[VAL_63:.*]] = arith.select %[[VAL_62]], %[[VAL_19]], %[[VAL_60]] : f32
+// CHECK-DAG: %[[VAL_64:.*]] = arith.cmpf oeq, %[[VAL_21]], %[[VAL_17]] : f32
+// CHECK-DAG: %[[VAL_65:.*]] = arith.andi %[[VAL_57]], %[[VAL_64]] : i1
+// CHECK-DAG: %[[VAL_66:.*]] = arith.select %[[VAL_65]], %[[VAL_20]], %[[VAL_63]] : f32
+// CHECK-DAG: %[[VAL_67:.*]] = arith.truncf %[[VAL_66]] : f32 to f16
+// CHECK: return %[[VAL_67]] : f1
func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
%0 = math.atan2 %arg0, %arg1 : f16
return %0 : f16
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 8f47fa76c4cfd..058ebb28dff27 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -471,19 +471,19 @@ func.func @atan_f32(%a : f32) {
}
func.func @atan() {
- // CHECK: -0.785184
+ // CHECK: -0.785398
%0 = arith.constant -1.0 : f32
call @atan_f32(%0) : (f32) -> ()
- // CHECK: 0.785184
+ // CHECK: 0.785398
%1 = arith.constant 1.0 : f32
call @atan_f32(%1) : (f32) -> ()
- // CHECK: -0.463643
+ // CHECK: -0.463648
%2 = arith.constant -0.5 : f32
call @atan_f32(%2) : (f32) -> ()
- // CHECK: 0.463643
+ // CHECK: 0.463648
%3 = arith.constant 0.5 : f32
call @atan_f32(%3) : (f32) -> ()
@@ -548,7 +548,7 @@ func.func @atan2() {
// CHECK: -1.10715
call @atan2_f32(%neg_two, %one) : (f32, f32) -> ()
- // CHECK: 0.463643
+ // CHECK: 0.463648
call @atan2_f32(%one, %two) : (f32, f32) -> ()
// CHECK: 2.67795
@@ -561,7 +561,7 @@ func.func @atan2() {
%y11 = arith.constant -1.0 : f32
call @atan2_f32(%neg_one, %neg_two) : (f32, f32) -> ()
- // CHECK: -0.463643
+ // CHECK: -0.463648
call @atan2_f32(%neg_one, %two) : (f32, f32) -> ()
return
More information about the Mlir-commits
mailing list