[Mlir-commits] [mlir] c089170 - [mlir] Add polynomial approximation for math::Log2

Eugene Zhulenev llvmlistbot at llvm.org
Wed Mar 10 14:49:29 PST 2021


Author: Emilio Cota
Date: 2021-03-10T14:49:22-08:00
New Revision: c0891706bc9faf428dfde7feddd8203efc43e118

URL: https://github.com/llvm/llvm-project/commit/c0891706bc9faf428dfde7feddd8203efc43e118
DIFF: https://github.com/llvm/llvm-project/commit/c0891706bc9faf428dfde7feddd8203efc43e118.diff

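LOG: [mlir] Add polynomial approximation for math::Log2

Adds a rewrite pattern that expands math.log2 on f32 scalars and vectors into a
polynomial approximation. The implementation is shared with the existing
math.log pattern; only the final step differs (the polynomial result is scaled
by log2(e) and the exponent is added). Benchmark results, old vs. new:
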
```
name                     old cpu/op  new cpu/op  delta
BM_mlir_Log2_f32/10       134ns ±15%    45ns ± 4%  -66.39%  (p=0.000 n=20+17)
BM_mlir_Log2_f32/100     1.03µs ±16%  0.12µs ±10%  -88.78%  (p=0.000 n=20+18)
BM_mlir_Log2_f32/1k      10.3µs ±16%   0.7µs ± 5%  -93.24%  (p=0.000 n=20+17)
BM_mlir_Log2_f32/10k      104µs ±15%     7µs ±14%  -93.25%  (p=0.000 n=20+20)
BM_eigen_s_Log2_f32/10   95.3ns ±17%  90.9ns ± 6%     ~     (p=0.228 n=20+18)
BM_eigen_s_Log2_f32/100   907ns ± 3%   911ns ± 6%     ~     (p=0.539 n=16+20)
BM_eigen_s_Log2_f32/1k   9.88µs ± 4%  9.85µs ± 3%     ~     (p=0.790 n=16+17)
BM_eigen_s_Log2_f32/10k   105µs ±10%   110µs ±16%     ~     (p=0.459 n=16+20)
BM_eigen_v_Log2_f32/10   32.5ns ±31%  33.9ns ±14%   +4.31%  (p=0.028 n=17+20)
BM_eigen_v_Log2_f32/100   176ns ± 8%   180ns ± 7%   +2.19%  (p=0.045 n=16+17)
BM_eigen_v_Log2_f32/1k   1.44µs ± 4%  1.50µs ± 9%   +3.91%  (p=0.001 n=16+17)
BM_eigen_v_Log2_f32/10k  14.5µs ±10%  15.0µs ± 8%   +3.92%  (p=0.002 n=16+19)
```

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D98282

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir
    mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 669607e2ee09..c9285bf42c8d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -258,29 +258,30 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
 
 #define LN2_VALUE                                                              \
   0.693147180559945309417232121458176568075500134360255254120680009493393621L
-#define LN2E_VALUE                                                             \
+#define LOG2E_VALUE                                                            \
   1.442695040888963407359924681001892137426645954152985934135449406931109219L
 
 //----------------------------------------------------------------------------//
-// LogOp approximation.
+// LogOp and Log2Op approximation.
 //----------------------------------------------------------------------------//
 
 namespace {
+template <typename Op>
+struct LogApproximationBase : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
 
-// This approximations comes from the Julien Pommier's SSE math library.
-// Link: http://gruntthepeon.free.fr/ssemath
-struct LogApproximation : public OpRewritePattern<math::LogOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(math::LogOp op,
-                                PatternRewriter &rewriter) const final;
+  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
+  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
+                                   bool base2) const;
 };
 } // namespace
 
+// This approximation comes from Julien Pommier's SSE math library.
+// Link: http://gruntthepeon.free.fr/ssemath
+template <typename Op>
 LogicalResult
-LogApproximation::matchAndRewrite(math::LogOp op,
-                                  PatternRewriter &rewriter) const {
+LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
+                                             bool base2) const {
   auto width = vectorWidth(op.operand().getType(), isF32);
   if (!width.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
@@ -356,8 +357,13 @@ LogApproximation::matchAndRewrite(math::LogOp op,
   y0 = builder.create<FmaFOp>(cstNegHalf, x2, y0);
   x = builder.create<AddFOp>(x, y0);
 
-  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
-  x = builder.create<FmaFOp>(e, cstLn2, x);
+  if (base2) {
+    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
+    x = builder.create<FmaFOp>(x, cstLog2e, e);
+  } else {
+    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
+    x = builder.create<FmaFOp>(e, cstLn2, x);
+  }
 
   Value invalidMask =
       builder.create<CmpFOp>(CmpFPredicate::ULT, op.operand(), cstZero);
@@ -381,6 +387,28 @@ LogApproximation::matchAndRewrite(math::LogOp op,
   return success();
 }
 
+namespace {
+struct LogApproximation : public LogApproximationBase<math::LogOp> {
+  using LogApproximationBase::LogApproximationBase;
+
+  LogicalResult matchAndRewrite(math::LogOp op,
+                                PatternRewriter &rewriter) const final {
+    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
+  }
+};
+} // namespace
+
+namespace {
+struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
+  using LogApproximationBase::LogApproximationBase;
+
+  LogicalResult matchAndRewrite(math::Log2Op op,
+                                PatternRewriter &rewriter) const final {
+    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
+  }
+};
+} // namespace
+
 //----------------------------------------------------------------------------//
 // Exp approximation.
 //----------------------------------------------------------------------------//
@@ -424,7 +452,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
   auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
 
   Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
-  Value cstLN2E = bcast(f32Cst(builder, static_cast<float>(LN2E_VALUE)));
+  Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
 
   // Polynomial coefficients.
   Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
@@ -437,7 +465,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
   Value x = op.operand();
 
   // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
-  Value xL2Inv = mul(x, cstLN2E);
+  Value xL2Inv = mul(x, cstLog2E);
   Value kF32 = floor(xL2Inv);
   Value kLn2 = mul(kF32, cstLn2);
   Value y = sub(x, kLn2);
@@ -501,5 +529,6 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
 
 void mlir::populateMathPolynomialApproximationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<TanhApproximation, LogApproximation, ExpApproximation>(ctx);
+  patterns.insert<TanhApproximation, LogApproximation, Log2Approximation,
+                  ExpApproximation>(ctx);
 }

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 6d102e5f5c85..5e3b3098cfac 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -9,7 +9,8 @@ func @scalar(%arg0: f32) -> f32 {
   %0 = math.tanh %arg0 : f32
   // CHECK-NOT: log
   %1 = math.log %0 : f32
-  return %1 : f32
+  %2 = math.log2 %1 : f32
+  return %2 : f32
 }
 
 // CHECK-LABEL: @vector
@@ -18,7 +19,8 @@ func @vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   %0 = math.tanh %arg0 : vector<8xf32>
   // CHECK-NOT: log
   %1 = math.log %0 : vector<8xf32>
-  return %1 : vector<8xf32>
+  %2 = math.log2 %1 : vector<8xf32>
+  return %2 : vector<8xf32>
 }
 
 // CHECK-LABEL: @exp_scalar

diff  --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
index 072cc2d4655a..02fc5241f452 100644
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -71,8 +71,47 @@ func @log() {
   return
 }
 
+func @log2() {
+  // CHECK: 3.81887
+  %0 = constant 14.112233 : f32
+  %1 = math.log2 %0 : f32
+  vector.print %1 : f32
+
+  // CHECK: -2, -0.415037, 0, 0.321928
+  %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+  %3 = math.log2 %2 : vector<4xf32>
+  vector.print %3 : vector<4xf32>
+
+  // CHECK: -3.32193, -2.32193, -1.73697, -1.32193, -1, -0.736966, -0.514573, -0.321928
+  %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+  %5 = math.log2 %4 : vector<8xf32>
+  vector.print %5 : vector<8xf32>
+
+  // CHECK: -inf
+  %zero = constant 0.0 : f32
+  %log_zero = math.log2 %zero : f32
+  vector.print %log_zero : f32
+
+  // CHECK: nan
+  %neg_one = constant -1.0 : f32
+  %log_neg_one = math.log2 %neg_one : f32
+  vector.print %log_neg_one : f32
+
+  // CHECK: inf
+  %inf = constant 0x7f800000 : f32
+  %log_inf = math.log2 %inf : f32
+  vector.print %log_inf : f32
+
+  // CHECK: -inf, nan, inf, 1.58496
+  %special_vec = constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+  %log_special_vec = math.log2 %special_vec : vector<4xf32>
+  vector.print %log_special_vec : vector<4xf32>
+
+  return
+}
+
 // -------------------------------------------------------------------------- //
-// Log.
+// Exp.
 // -------------------------------------------------------------------------- //
 func @exp() {
   // CHECK: 2.71828
@@ -111,6 +150,7 @@ func @exp() {
 func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
+  call @log2(): () -> ()
   call @exp(): () -> ()
   return
 }


        


More information about the Mlir-commits mailing list