[Mlir-commits] [mlir] 2f9f9af - [mlir] Add polynomial approximation for atan and atan2

Rob Suderman llvmlistbot at llvm.org
Fri Jan 21 12:31:28 PST 2022


Author: Rob Suderman
Date: 2022-01-21T12:22:58-08:00
New Revision: 2f9f9afa4e1281b4ac7c8ad36860a4e35e6f5070

URL: https://github.com/llvm/llvm-project/commit/2f9f9afa4e1281b4ac7c8ad36860a4e35e6f5070
DIFF: https://github.com/llvm/llvm-project/commit/2f9f9afa4e1281b4ac7c8ad36860a4e35e6f5070.diff

LOG: [mlir] Add polynomial approximation for atan and atan2

Implement a polynomial approximation for atan and add an atan2 lowering
that uses atan's approximation. This includes tests for edge cases and tests
for each quadrant.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D115682

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir
    mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 9931e89647bcf..7d04ae7e3d34f 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -278,6 +278,133 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
 }
 } // namespace
 
+//----------------------------------------------------------------------------//
+// AtanOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AtanOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Approximates f32 `math.atan` with a polynomial evaluated on [0, 1]; other
+// inputs use atan(t) = pi/2 - atan(1/t) for t >= 1 and atan(-t) = -atan(t).
+LogicalResult
+AtanApproximation::matchAndRewrite(math::AtanOp op,
+                                   PatternRewriter &rewriter) const {
+  Value operand = op.getOperand();
+  if (!getElementTypeOrSelf(operand).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ArrayRef<int64_t> shape = vectorShape(operand);
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Value one = broadcast(builder, f32Cst(builder, 1.0f), shape);
+
+  // Remap onto [0.0, 1.0]: work on |t| and use 1 / |t| whenever |t| >= 1.
+  Value abs = builder.create<math::AbsOp>(operand);
+  Value reciprocal = builder.create<arith::DivFOp>(one, abs);
+  Value compare =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
+  Value x = builder.create<SelectOp>(compare, abs, reciprocal);
+
+  // Evaluate x * (n4 + x * (n3 + x * (n2 + x * n1))) with Horner's scheme.
+  // This is a fitted polynomial, not a truncated Taylor series.
+  Value n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape);
+  Value n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape);
+  Value n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape);
+  Value n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape);
+
+  Value p = builder.create<math::FmaOp>(x, n1, n2);
+  p = builder.create<math::FmaOp>(x, p, n3);
+  p = builder.create<math::FmaOp>(x, p, n4);
+  p = builder.create<arith::MulFOp>(x, p);
+
+  // Undo the remap: for |t| >= 1, atan(|t|) = pi/2 - atan(1/|t|).
+  Value halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
+  Value sub = builder.create<arith::SubFOp>(halfPi, p);
+  Value select = builder.create<SelectOp>(compare, p, sub);
+
+  // atan is odd, so restoring the sign of the input finishes the job.
+  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// Atan2Op approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::Atan2Op op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Lowers f32 `math.atan2(y, x)` via `math.atan(y / x)` plus quadrant fixups.
+LogicalResult
+Atan2Approximation::matchAndRewrite(math::Atan2Op op,
+                                    PatternRewriter &rewriter) const {
+  Value y = op.getOperand(0);
+  Value x = op.getOperand(1);
+  if (!getElementTypeOrSelf(x).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  ArrayRef<int64_t> shape = vectorShape(op.getResult());
+
+  // atan(y / x) is already the answer whenever x > 0.
+  Value div = builder.create<arith::DivFOp>(y, x);
+  Value atan = builder.create<math::AtanOp>(div);
+
+  // For x < 0, rotate by 180 degrees: subtract pi if atan(y / x) > 0, else
+  // add pi. NOTE: keying off atan's sign rather than y's sign means
+  // atan2(-0.0, x) for x < 0 yields +pi instead of -pi.
+  Value zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
+  Value pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
+  Value addPi = builder.create<arith::AddFOp>(atan, pi);
+  Value subPi = builder.create<arith::SubFOp>(atan, pi);
+  Value atanGt =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
+  Value flippedAtan = builder.create<SelectOp>(atanGt, subPi, addPi);
+
+  // Use atan(y / x) directly when x > 0 and the rotated value otherwise.
+  Value xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
+  Value result = builder.create<SelectOp>(xGt, atan, flippedAtan);
+
+  // x == 0, y > 0: the result is pi / 2.
+  Value xZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
+  Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
+  Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
+  Value halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
+  result = builder.create<SelectOp>(isHalfPi, halfPi, result);
+
+  // x == 0, y < 0: the result is -pi / 2.
+  Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
+  Value isNegativeHalfPi = builder.create<arith::AndIOp>(xZero, yLt);
+  Value negativeHalfPi =
+      broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
+  result = builder.create<SelectOp>(isNegativeHalfPi, negativeHalfPi, result);
+
+  // x == 0, y == 0: produce a quiet NaN (C's atan2 returns +/-0 or +/-pi).
+  Value yZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
+  Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
+  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
+  result = builder.create<SelectOp>(isNan, cstNan, result);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 // TanhOp approximation.
 //----------------------------------------------------------------------------//
@@ -1074,9 +1201,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options) {
-  patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
-               Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
-               ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>,
+  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
+               LogApproximation, Log2Approximation, Log1pApproximation,
+               ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
+               SinAndCosApproximation<true, math::SinOp>,
                SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
   if (options.enableAvx2)

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index f388b84d83c8e..a40cc9c6f037a 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -507,3 +507,85 @@ func @rsqrt_vector_2x16xf32(%arg0: vector<2x16xf32>) -> vector<2x16xf32> {
   %0 = math.rsqrt %arg0 : vector<2x16xf32>
   return %0 : vector<2x16xf32>
 }
+
+// CHECK-LABEL: @atan_scalar
+// CHECK-DAG:  %[[ONE:.+]] = arith.constant 1.000000e+00
+// CHECK-DAG:  %[[N1:.+]] = arith.constant 0.144182831
+// CHECK-DAG:  %[[N2:.+]] = arith.constant -0.349992335
+// CHECK-DAG:  %[[N3:.+]] = arith.constant -0.0106783099
+// CHECK-DAG:  %[[N4:.+]] = arith.constant 1.00209987
+// CHECK-DAG:  %[[HALF_PI:.+]] = arith.constant 1.57079637
+// CHECK-DAG:  %[[ABS:.+]] = math.abs %arg0
+// CHECK-DAG:  %[[DIV:.+]] = arith.divf %[[ONE]], %[[ABS]]
+// CHECK-DAG:  %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
+// CHECK-DAG:  %[[SEL:.+]] = select %[[CMP]], %[[ABS]], %[[DIV]]
+// CHECK-DAG:  %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
+// CHECK-DAG:  %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
+// CHECK-DAG:  %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
+// CHECK-DAG:  %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
+// CHECK-DAG:  %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
+// CHECK-DAG:  %[[EST:.+]] = select %[[CMP]], %[[P3]], %[[SUB]]
+// CHECK-DAG:  %[[RES:.+]] = math.copysign %[[EST]], %arg0
+// CHECK:  return %[[RES]]
+func @atan_scalar(%arg0: f32) -> f32 {
+  %0 = math.atan %arg0 : f32
+  return %0 : f32
+}
+
+
+// CHECK-LABEL: @atan2_scalar
+
+// ATan approximation:
+// CHECK-DAG:  %[[ONE:.+]] = arith.constant 1.000000e+00
+// CHECK-DAG:  %[[N1:.+]] = arith.constant 0.144182831
+// CHECK-DAG:  %[[N2:.+]] = arith.constant -0.349992335
+// CHECK-DAG:  %[[N3:.+]] = arith.constant -0.0106783099
+// CHECK-DAG:  %[[N4:.+]] = arith.constant 1.00209987
+// CHECK-DAG:  %[[HALF_PI:.+]] = arith.constant 1.57079637
+// CHECK-DAG:  %[[RATIO:.+]] = arith.divf %arg0, %arg1
+// CHECK-DAG:  %[[ABS:.+]] = math.abs %[[RATIO]]
+// CHECK-DAG:  %[[DIV:.+]] = arith.divf %[[ONE]], %[[ABS]]
+// CHECK-DAG:  %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
+// CHECK-DAG:  %[[SEL:.+]] = select %[[CMP]], %[[ABS]], %[[DIV]]
+// CHECK-DAG:  %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
+// CHECK-DAG:  %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
+// CHECK-DAG:  %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
+// CHECK-DAG:  %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
+// CHECK-DAG:  %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
+// CHECK-DAG:  %[[EST:.+]] = select %[[CMP]], %[[P3]], %[[SUB]]
+// CHECK-DAG:  %[[ATAN:.+]] = math.copysign %[[EST]], %[[RATIO]]
+
+// Handle the case of x <= 0 by rotating atan(y / x) by +/- pi:
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00
+// CHECK-DAG: %[[PI:.+]] = arith.constant 3.14159274
+// CHECK-DAG: %[[ADD_PI:.+]] = arith.addf %[[ATAN]], %[[PI]]
+// CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]]
+// CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]]
+// CHECK-DAG: %[[ATAN_ADJUST:.+]] = select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]]
+// CHECK-DAG: %[[X_POS:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]]
+// CHECK-DAG: %[[ATAN_EST:.+]] = select %[[X_POS]], %[[ATAN]], %[[ATAN_ADJUST]]
+
+// Handle PI / 2 edge case:
+// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]]
+// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
+// CHECK-DAG: %[[EDGE1:.+]] = select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
+
+// Handle -PI / 2 edge case:
+// CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
+// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
+// CHECK-DAG: %[[EDGE2:.+]] = select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
+
+// Handle the NaN edge case (x == 0 and y == 0):
+// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
+// CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
+// CHECK-DAG: %[[EDGE3:.+]] = select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
+// CHECK: return %[[EDGE3]]
+
+func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = math.atan2 %arg0, %arg1 : f32
+  return %0 : f32
+}
+

diff  --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b3c41057fa302..5a41d56dd42bd 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -371,6 +371,122 @@ func @cos() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Atan.
+// -------------------------------------------------------------------------- //
+
+func @atan() {
+  // Expected values are the approximation's output; exact atan(1) = 0.785398.
+  // CHECK: -0.785184
+  %neg_one = arith.constant -1.0 : f32
+  %atan_neg_one = math.atan %neg_one : f32
+  vector.print %atan_neg_one : f32
+
+  // CHECK: 0.785184
+  %one = arith.constant 1.0 : f32
+  %atan_one = math.atan %one : f32
+  vector.print %atan_one : f32
+
+  // CHECK: -0.463643
+  %neg_half = arith.constant -0.5 : f32
+  %atan_neg_half = math.atan %neg_half : f32
+  vector.print %atan_neg_half : f32
+
+  // CHECK: 0.463643
+  %half = arith.constant 0.5 : f32
+  %atan_half = math.atan %half : f32
+  vector.print %atan_half : f32
+
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  %atan_zero = math.atan %zero : f32
+  vector.print %atan_zero : f32
+
+  // CHECK: -1.10715
+  %neg_two = arith.constant -2.0 : f32
+  %atan_neg_two = math.atan %neg_two : f32
+  vector.print %atan_neg_two : f32
+
+  // CHECK: 1.10715
+  %two = arith.constant 2.0 : f32
+  %atan_two = math.atan %two : f32
+  vector.print %atan_two : f32
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// Atan2.
+// -------------------------------------------------------------------------- //
+
+func @atan2() {
+  // Operands are (y, x). The cases cover the axes, the origin, and every
+  // quadrant both with |y| > |x| and with |y| < |x|.
+  %zero = arith.constant 0.0 : f32
+  %one = arith.constant 1.0 : f32
+  %two = arith.constant 2.0 : f32
+  %neg_one = arith.constant -1.0 : f32
+  %neg_two = arith.constant -2.0 : f32
+
+  // Points on the axes.
+  // CHECK: 0
+  %atan2_0 = math.atan2 %zero, %one : f32
+  vector.print %atan2_0 : f32
+
+  // CHECK: 1.5708
+  %atan2_1 = math.atan2 %one, %zero : f32
+  vector.print %atan2_1 : f32
+
+  // CHECK: 3.14159
+  %atan2_2 = math.atan2 %zero, %neg_one : f32
+  vector.print %atan2_2 : f32
+
+  // CHECK: -1.5708
+  %atan2_3 = math.atan2 %neg_one, %zero : f32
+  vector.print %atan2_3 : f32
+
+  // The origin is lowered to NaN.
+  // CHECK: nan
+  %atan2_4 = math.atan2 %zero, %zero : f32
+  vector.print %atan2_4 : f32
+
+  // |y| > |x|: atan(y / x) goes through the 1 / |t| branch.
+  // CHECK: 1.10715
+  %atan2_5 = math.atan2 %two, %one : f32
+  vector.print %atan2_5 : f32
+
+  // CHECK: 2.03444
+  %atan2_6 = math.atan2 %two, %neg_one : f32
+  vector.print %atan2_6 : f32
+
+  // CHECK: -2.03444
+  %atan2_7 = math.atan2 %neg_two, %neg_one : f32
+  vector.print %atan2_7 : f32
+
+  // CHECK: -1.10715
+  %atan2_8 = math.atan2 %neg_two, %one : f32
+  vector.print %atan2_8 : f32
+
+  // |y| < |x|: atan(y / x) is evaluated on [0, 1] directly.
+  // CHECK: 0.463643
+  %atan2_9 = math.atan2 %one, %two : f32
+  vector.print %atan2_9 : f32
+
+  // CHECK: 2.67795
+  %atan2_10 = math.atan2 %one, %neg_two : f32
+  vector.print %atan2_10 : f32
+
+  // CHECK: -2.67795
+  %atan2_11 = math.atan2 %neg_one, %neg_two : f32
+  vector.print %atan2_11 : f32
+
+  // CHECK: -0.463643
+  %atan2_12 = math.atan2 %neg_one, %two : f32
+  vector.print %atan2_12 : f32
+
+  return
+}
+
 
 func @main() {
   call @tanh(): () -> ()
@@ -382,5 +498,7 @@ func @main() {
   call @expm1(): () -> ()
   call @sin(): () -> ()
   call @cos(): () -> ()
+  call @atan() : () -> ()
+  call @atan2() : () -> ()
   return
 }


        


More information about the Mlir-commits mailing list