[Mlir-commits] [mlir] 2f9f9af - [mlir] Add polynomial approximation for atan and atan2
Rob Suderman
llvmlistbot at llvm.org
Fri Jan 21 12:31:28 PST 2022
Author: Rob Suderman
Date: 2022-01-21T12:22:58-08:00
New Revision: 2f9f9afa4e1281b4ac7c8ad36860a4e35e6f5070
URL: https://github.com/llvm/llvm-project/commit/2f9f9afa4e1281b4ac7c8ad36860a4e35e6f5070
DIFF: https://github.com/llvm/llvm-project/commit/2f9f9afa4e1281b4ac7c8ad36860a4e35e6f5070.diff
LOG: [mlir] Add polynomial approximation for atan and atan2
Implement a Taylor series approximation for atan and add an atan2 lowering
that uses atan's approximation. This includes tests for edge cases and tests
for each quadrant.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D115682
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 9931e89647bcf..7d04ae7e3d34f 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -278,6 +278,133 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
}
} // namespace
+//----------------------------------------------------------------------------//
+// AtanOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::AtanOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+AtanApproximation::matchAndRewrite(math::AtanOp op,
+ PatternRewriter &rewriter) const {
+ auto operand = op.getOperand();
+ if (!getElementTypeOrSelf(operand).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto one = broadcast(builder, f32Cst(builder, 1.0f), shape);
+
+ // Remap the problem onto [0.0, 1.0]: take the absolute value and, when it is
+ // not less than its reciprocal, evaluate on the reciprocal (undone below).
+ Value abs = builder.create<math::AbsOp>(operand);
+ Value reciprocal = builder.create<arith::DivFOp>(one, abs);
+ Value compare =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
+ Value x = builder.create<SelectOp>(compare, abs, reciprocal);
+
+ // Evaluate the polynomial approximation of atan on [0.0, 1.0] in Horner form:
+ // p(x) = x * (n4 + x * (n3 + x * (n2 + x * n1))).
+ auto n1 = broadcast(builder, f32Cst(builder, 0.14418283), shape);
+ auto n2 = broadcast(builder, f32Cst(builder, -0.34999234), shape);
+ auto n3 = broadcast(builder, f32Cst(builder, -0.01067831), shape);
+ auto n4 = broadcast(builder, f32Cst(builder, 1.00209986), shape);
+
+ Value p = builder.create<math::FmaOp>(x, n1, n2);
+ p = builder.create<math::FmaOp>(x, p, n3);
+ p = builder.create<math::FmaOp>(x, p, n4);
+ p = builder.create<arith::MulFOp>(x, p);
+
+ // Undo the reciprocal remapping using atan(a) = pi/2 - atan(1/a) for a > 0.
+ auto half_pi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
+ Value sub = builder.create<arith::SubFOp>(half_pi, p);
+ Value select = builder.create<SelectOp>(compare, p, sub);
+
+ // Restore the sign of the input, since the estimate was computed on |input|.
+ rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
+ return success();
+}
+
+//----------------------------------------------------------------------------//
+// Atan2Op approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::Atan2Op op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+Atan2Approximation::matchAndRewrite(math::Atan2Op op,
+ PatternRewriter &rewriter) const {
+ auto y = op.getOperand(0);
+ auto x = op.getOperand(1);
+ if (!getElementTypeOrSelf(x).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ ArrayRef<int64_t> shape = vectorShape(op.getResult());
+
+ // Compute atan(y / x); this is used unchanged as the result when x > 0.
+ auto div = builder.create<arith::DivFOp>(y, x);
+ auto atan = builder.create<math::AtanOp>(div);
+
+ // Rotate by 180 degrees: atan - pi when atan > 0, atan + pi otherwise.
+ auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
+ auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
+ auto add_pi = builder.create<arith::AddFOp>(atan, pi);
+ auto sub_pi = builder.create<arith::SubFOp>(atan, pi);
+ auto atan_gt =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
+ auto flipped_atan = builder.create<SelectOp>(atan_gt, sub_pi, add_pi);
+
+ // Use atan directly when x > 0, otherwise use the 180 degree flip.
+ auto x_gt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
+ Value result = builder.create<SelectOp>(x_gt, atan, flipped_atan);
+
+ // Handle x = 0, y > 0: the result is pi / 2.
+ Value x_zero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
+ Value y_gt =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
+ Value is_half_pi = builder.create<arith::AndIOp>(x_zero, y_gt);
+ auto half_pi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
+ result = builder.create<SelectOp>(is_half_pi, half_pi, result);
+
+ // Handle x = 0, y < 0: the result is -pi / 2.
+ Value y_lt =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
+ Value is_negative_half_pi_pi = builder.create<arith::AndIOp>(x_zero, y_lt);
+ auto negative_half_pi_pi =
+ broadcast(builder, f32Cst(builder, -1.57079632679), shape);
+ result = builder.create<SelectOp>(is_negative_half_pi_pi, negative_half_pi_pi,
+ result);
+
+ // Handle x = 0, y = 0: the result is a NaN built from bits 0x7fc00000.
+ Value y_zero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
+ Value is_nan = builder.create<arith::AndIOp>(x_zero, y_zero);
+ Value cst_nan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
+ result = builder.create<SelectOp>(is_nan, cst_nan, result);
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//
@@ -1074,9 +1201,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
- patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
- Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
- ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>,
+ patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
+ LogApproximation, Log2Approximation, Log1pApproximation,
+ ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
+ SinAndCosApproximation<true, math::SinOp>,
SinAndCosApproximation<false, math::CosOp>>(
patterns.getContext());
if (options.enableAvx2)
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index f388b84d83c8e..a40cc9c6f037a 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -507,3 +507,85 @@ func @rsqrt_vector_2x16xf32(%arg0: vector<2x16xf32>) -> vector<2x16xf32> {
%0 = math.rsqrt %arg0 : vector<2x16xf32>
return %0 : vector<2x16xf32>
}
+
+// CHECK-LABEL: @atan_scalar
+// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00
+// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831
+// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335
+// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
+// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
+// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
+// CHECK-DAG: %[[ABS:.+]] = math.abs %arg0
+// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
+// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
+// CHECK-DAG: %[[SEL:.+]] = select %[[CMP]], %[[ABS]], %[[DIV]]
+// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
+// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
+// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
+// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
+// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
+// CHECK-DAG: %[[EST:.+]] = select %[[CMP]], %[[P3]], %[[SUB]]
+// CHECK-DAG: %[[RES:.+]] = math.copysign %[[EST]], %arg0
+// CHECK: return %[[RES]]
+func @atan_scalar(%arg0: f32) -> f32 {
+ %0 = math.atan %arg0 : f32
+ return %0 : f32
+}
+
+
+// CHECK-LABEL: @atan2_scalar
+
+// ATan approximation:
+// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00
+// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831
+// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335
+// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
+// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
+// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
+// CHECK-DAG: %[[RATIO:.+]] = arith.divf %arg0, %arg1
+// CHECK-DAG: %[[ABS:.+]] = math.abs %[[RATIO]]
+// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
+// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
+// CHECK-DAG: %[[SEL:.+]] = select %[[CMP]], %[[ABS]], %[[DIV]]
+// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
+// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
+// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
+// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
+// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
+// CHECK-DAG: %[[EST:.+]] = select %[[CMP]], %[[P3]], %[[SUB]]
+// CHECK-DAG: %[[ATAN:.+]] = math.copysign %[[EST]], %[[RATIO]]
+
+// Handle the case of x < 0:
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00
+// CHECK-DAG: %[[PI:.+]] = arith.constant 3.14159274
+// CHECK-DAG: %[[ADD_PI:.+]] = arith.addf %[[ATAN]], %[[PI]]
+// CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]]
+// CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]]
+// CHECK-DAG: %[[ATAN_ADJUST:.+]] = select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]]
+// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]]
+// CHECK-DAG: %[[ATAN_EST:.+]] = select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]]
+
+// Handle PI / 2 edge case:
+// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]]
+// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
+// CHECK-DAG: %[[EDGE1:.+]] = select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
+
+// Handle -PI / 2 edge case:
+// CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
+// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
+// CHECK-DAG: %[[EDGE2:.+]] = select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
+
+// Handle NaN edge case:
+// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
+// CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
+// CHECK-DAG: %[[EDGE3:.+]] = select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
+// CHECK: return %[[EDGE3]]
+
+func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 {
+ %0 = math.atan2 %arg0, %arg1 : f32
+ return %0 : f32
+}
+
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b3c41057fa302..5a41d56dd42bd 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -371,6 +371,122 @@ func @cos() {
return
}
+// -------------------------------------------------------------------------- //
+// Atan.
+// -------------------------------------------------------------------------- //
+
+func @atan() {
+ // CHECK: -0.785184
+ %0 = arith.constant -1.0 : f32
+ %atan_0 = math.atan %0 : f32
+ vector.print %atan_0 : f32
+
+ // CHECK: 0.785184
+ %1 = arith.constant 1.0 : f32
+ %atan_1 = math.atan %1 : f32
+ vector.print %atan_1 : f32
+
+ // CHECK: -0.463643
+ %2 = arith.constant -0.5 : f32
+ %atan_2 = math.atan %2 : f32
+ vector.print %atan_2 : f32
+
+ // CHECK: 0.463643
+ %3 = arith.constant 0.5 : f32
+ %atan_3 = math.atan %3 : f32
+ vector.print %atan_3 : f32
+
+ // CHECK: 0
+ %4 = arith.constant 0.0 : f32
+ %atan_4 = math.atan %4 : f32
+ vector.print %atan_4 : f32
+
+ // CHECK: -1.10715
+ %5 = arith.constant -2.0 : f32
+ %atan_5 = math.atan %5 : f32
+ vector.print %atan_5 : f32
+
+ // CHECK: 1.10715
+ %6 = arith.constant 2.0 : f32
+ %atan_6 = math.atan %6 : f32
+ vector.print %atan_6 : f32
+
+ return
+}
+
+
+// -------------------------------------------------------------------------- //
+// Atan2.
+// -------------------------------------------------------------------------- //
+
+func @atan2() {
+ %zero = arith.constant 0.0 : f32
+ %one = arith.constant 1.0 : f32
+ %two = arith.constant 2.0 : f32
+ %neg_one = arith.constant -1.0 : f32
+ %neg_two = arith.constant -2.0 : f32
+
+ // CHECK: 0
+ %atan2_0 = math.atan2 %zero, %one : f32
+ vector.print %atan2_0 : f32
+
+ // CHECK: 1.5708
+ %atan2_1 = math.atan2 %one, %zero : f32
+ vector.print %atan2_1 : f32
+
+ // CHECK: 3.14159
+ %atan2_2 = math.atan2 %zero, %neg_one : f32
+ vector.print %atan2_2 : f32
+
+ // CHECK: -1.5708
+ %atan2_3 = math.atan2 %neg_one, %zero : f32
+ vector.print %atan2_3 : f32
+
+ // CHECK: nan
+ %atan2_4 = math.atan2 %zero, %zero : f32
+ vector.print %atan2_4 : f32
+
+ // CHECK: 1.10715
+ %atan2_5 = math.atan2 %two, %one : f32
+ vector.print %atan2_5 : f32
+
+ // CHECK: 2.03444
+ %x6 = arith.constant -1.0 : f32
+ %y6 = arith.constant 2.0 : f32
+ %atan2_6 = math.atan2 %two, %neg_one : f32
+ vector.print %atan2_6 : f32
+
+ // CHECK: -2.03444
+ %atan2_7 = math.atan2 %neg_two, %neg_one : f32
+ vector.print %atan2_7 : f32
+
+ // CHECK: -1.10715
+ %atan2_8 = math.atan2 %neg_two, %one : f32
+ vector.print %atan2_8 : f32
+
+ // CHECK: 0.463643
+ %atan2_9 = math.atan2 %one, %two : f32
+ vector.print %atan2_9 : f32
+
+ // CHECK: 2.67795
+ %x10 = arith.constant -2.0 : f32
+ %y10 = arith.constant 1.0 : f32
+ %atan2_10 = math.atan2 %one, %neg_two : f32
+ vector.print %atan2_10 : f32
+
+ // CHECK: -2.67795
+ %x11 = arith.constant -2.0 : f32
+ %y11 = arith.constant -1.0 : f32
+ %atan2_11 = math.atan2 %neg_one, %neg_two : f32
+ vector.print %atan2_11 : f32
+
+ // CHECK: -0.463643
+ %atan2_12 = math.atan2 %neg_one, %two : f32
+ vector.print %atan2_12 : f32
+
+ return
+}
+
func @main() {
call @tanh(): () -> ()
@@ -382,5 +498,7 @@ func @main() {
call @expm1(): () -> ()
call @sin(): () -> ()
call @cos(): () -> ()
+ call @atan() : () -> ()
+ call @atan2() : () -> ()
return
}
More information about the Mlir-commits
mailing list