[Mlir-commits] [mlir] f99ccf6 - [mlir] Add math polynomial approximation pass

Eugene Zhulenev llvmlistbot at llvm.org
Fri Feb 19 12:43:42 PST 2021


Author: Eugene Zhulenev
Date: 2021-02-19T12:43:36-08:00
New Revision: f99ccf6516bdd5def4d3bc311330aec92f5cb99d

URL: https://github.com/llvm/llvm-project/commit/f99ccf6516bdd5def4d3bc311330aec92f5cb99d
DIFF: https://github.com/llvm/llvm-project/commit/f99ccf6516bdd5def4d3bc311330aec92f5cb99d.diff

LOG: [mlir] Add math polynomial approximation pass

This gives a ~30x speedup compared to expanding Tanh into exp operations:

```
name                  old cpu/op  new cpu/op  delta
BM_mlir_Tanh_f32/10    253ns ± 3%    55ns ± 7%  -78.35%  (p=0.000 n=44+41)
BM_mlir_Tanh_f32/100  2.21µs ± 4%  0.14µs ± 8%  -93.85%  (p=0.000 n=48+49)
BM_mlir_Tanh_f32/1k   22.6µs ± 4%   0.7µs ± 5%  -96.68%  (p=0.000 n=32+42)
BM_mlir_Tanh_f32/10k   225µs ± 5%     7µs ± 6%  -96.88%  (p=0.000 n=49+55)

name                  old time/op             new time/op             delta
BM_mlir_Tanh_f32/10    259ns ± 1%               56ns ± 2%  -78.31%        (p=0.000 n=41+39)
BM_mlir_Tanh_f32/100  2.27µs ± 1%             0.14µs ± 5%  -93.89%        (p=0.000 n=46+49)
BM_mlir_Tanh_f32/1k   22.9µs ± 1%              0.8µs ± 4%  -96.67%        (p=0.000 n=30+42)
BM_mlir_Tanh_f32/10k   230µs ± 0%                7µs ± 3%  -96.88%        (p=0.000 n=37+55)
```

This approximation is based on Eigen's `generic_fast_tanh_float` function.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96739

Added: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir
    mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
    mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 82da168c2ccd..c965bab3769b 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -19,6 +19,9 @@ class OwningRewritePatternList;
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 
+/// Populates `patterns` with rewrites that replace math dialect operations on
+/// float scalars or vectors of floats with polynomial approximations that do
+/// not call any library functions.
+void populateMathPolynomialApproximationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_

diff  --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index b85549941a63..23463af00fb9 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRMathTransforms
   ExpandTanh.cpp
+  PolynomialApproximation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms

diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
new file mode 100644
index 000000000000..d230334a8019
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -0,0 +1,194 @@
+//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of math operations to fast approximations
+// that do not rely on any of the library functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Returns true when `type` is a float scalar, or a vector whose element type
+/// is a float; these are the only operand types the approximations accept.
+static bool isValidFloatType(Type type) {
+  auto vecTy = type.dyn_cast<VectorType>();
+  Type scalarTy = vecTy ? vecTy.getElementType() : type;
+  return scalarTy.isa<FloatType>();
+}
+
+//----------------------------------------------------------------------------//
+// A PatternRewriter wrapper that provides a concise API for building
+// expansions for operations on float scalars or vectors.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Thin wrapper around a PatternRewriter that builds float arithmetic for one
+/// operand type, which is either a float scalar or a vector of floats. Every
+/// helper creates its ops through `rewriter` at location `loc`.
+class FloatApproximationBuilder {
+public:
+  /// `type` must be a FloatType, or a VectorType whose element type is a
+  /// FloatType; the constructor casts to these types.
+  FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter);
+
+  // Returns a constant of the builder's type: a scalar constant for a float
+  // type, or that scalar broadcast to the vector type for a vector type.
+  Value constant(double value) const;
+
+  // Absolute value.
+  Value abs(Value a) const;
+  // Minimum / maximum of `a` and `b`, built from a `cmp` + `select` pair.
+  Value min(Value a, Value b) const;
+  Value max(Value a, Value b) const;
+  // Product and quotient of `a` and `b`.
+  Value mul(Value a, Value b) const;
+  Value div(Value a, Value b) const;
+
+  // Fused multiply-add operation: a * b + c.
+  Value madd(Value a, Value b, Value c) const;
+
+  // Compares values `a` and `b` with the given `predicate`.
+  Value cmp(CmpFPredicate predicate, Value a, Value b) const;
+
+  // Selects values from `a` or `b` based on the `predicate`.
+  Value select(Value predicate, Value a, Value b) const;
+
+private:
+  Location loc;
+  PatternRewriter &rewriter;
+  VectorType vectorType; // can be null for scalar type
+  FloatType elementType;
+};
+} // namespace
+
+// Records the insertion context and splits `type` into an optional vector
+// shape plus a float element type; non-float types fail the cast.
+FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type,
+                                                     PatternRewriter &rewriter)
+    : loc(loc), rewriter(rewriter), vectorType(type.dyn_cast<VectorType>()) {
+  Type scalarType = vectorType ? vectorType.getElementType() : type;
+  elementType = scalarType.cast<FloatType>();
+}
+
+// Materializes `value` as a constant of the element type, splatting it into
+// the vector type when the builder works on vectors.
+Value FloatApproximationBuilder::constant(double value) const {
+  auto floatAttr = rewriter.getFloatAttr(elementType, value);
+  Value scalar = rewriter.create<ConstantOp>(loc, floatAttr);
+  if (!vectorType)
+    return scalar;
+  return rewriter.create<BroadcastOp>(loc, vectorType, scalar);
+}
+
+// Emits an AbsFOp computing |a|.
+Value FloatApproximationBuilder::abs(Value a) const {
+  Value magnitude = rewriter.create<AbsFOp>(loc, a);
+  return magnitude;
+}
+
+// Returns the smaller of `a` and `b`. The unordered `ult` predicate is true
+// when either operand is NaN, so in that case `a` is returned. Callers pass
+// the value being clamped as `a`, so a NaN input propagates instead of being
+// replaced by the bound `b`, which an ordered `olt` compare would do. For
+// non-NaN operands the selected value is the same as with `olt`.
+Value FloatApproximationBuilder::min(Value a, Value b) const {
+  return select(cmp(CmpFPredicate::ULT, a, b), a, b);
+}
+// Returns the larger of `a` and `b`. Like `min`, it yields `a` when either
+// operand is NaN, so a NaN input propagates through clamping.
+Value FloatApproximationBuilder::max(Value a, Value b) const {
+  return select(cmp(CmpFPredicate::UGT, a, b), a, b);
+}
+// Emits `a * b`.
+Value FloatApproximationBuilder::mul(Value a, Value b) const {
+  Value product = rewriter.create<MulFOp>(loc, a, b);
+  return product;
+}
+
+// Emits `a / b`.
+Value FloatApproximationBuilder::div(Value a, Value b) const {
+  Value quotient = rewriter.create<DivFOp>(loc, a, b);
+  return quotient;
+}
+
+// Emits a single fused multiply-add computing `a * b + c`.
+Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const {
+  Value fused = rewriter.create<FmaFOp>(loc, a, b, c);
+  return fused;
+}
+
+// Emits a CmpFOp comparing `a` against `b` under `predicate`.
+Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a,
+                                     Value b) const {
+  Value flag = rewriter.create<CmpFOp>(loc, predicate, a, b);
+  return flag;
+}
+
+// Emits a SelectOp yielding `a` where `predicate` holds and `b` elsewhere.
+Value FloatApproximationBuilder::select(Value predicate, Value a,
+                                        Value b) const {
+  Value chosen = rewriter.create<SelectOp>(loc, predicate, a, b);
+  return chosen;
+}
+
+//----------------------------------------------------------------------------//
+// TanhOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrites `math.tanh` into a rational polynomial approximation built from
+/// arithmetic operations; see TanhApproximation::matchAndRewrite.
+struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::TanhOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Replaces `math.tanh` on a float scalar or a vector of floats with the
+/// rational approximation p(x) / q(x). Here p is an odd polynomial of degree
+/// 13 and q is an even polynomial of degree 6, both evaluated on the clamped
+/// operand.
+LogicalResult
+TanhApproximation::matchAndRewrite(math::TanhOp op,
+                                   PatternRewriter &rewriter) const {
+  Value operand = op.operand();
+  if (!isValidFloatType(operand.getType()))
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  // NOTE(review): the constants below follow Eigen's fast tanh approximation,
+  // which presumably targets f32 accuracy, while isValidFloatType also accepts
+  // other float element types (e.g. f16, f64). Confirm whether this pattern
+  // should be restricted to f32 operands.
+  FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter);
+
+  // Clamp the operand into the [minusClamp, plusClamp] range. Both bounds use
+  // the same digits, so minusClamp is exactly -plusClamp.
+  Value plusClamp = builder.constant(7.90531110763549805);
+  Value minusClamp = builder.constant(-7.90531110763549805);
+  Value x = builder.max(builder.min(operand, plusClamp), minusClamp);
+
+  // Mask for tiny inputs, where tanh(x) ~= x. Those elements return the input
+  // (unchanged by the clamp) instead of p / q. The threshold is a double
+  // literal, like every other constant in this function.
+  Value tiny = builder.constant(0.0004);
+  Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny);
+
+  // The monomial coefficients of the numerator polynomial (odd).
+  Value alpha1 = builder.constant(4.89352455891786e-03);
+  Value alpha3 = builder.constant(6.37261928875436e-04);
+  Value alpha5 = builder.constant(1.48572235717979e-05);
+  Value alpha7 = builder.constant(5.12229709037114e-08);
+  Value alpha9 = builder.constant(-8.60467152213735e-11);
+  Value alpha11 = builder.constant(2.00018790482477e-13);
+  Value alpha13 = builder.constant(-2.76076847742355e-16);
+
+  // The monomial coefficients of the denominator polynomial (even).
+  Value beta0 = builder.constant(4.89352518554385e-03);
+  Value beta2 = builder.constant(2.26843463243900e-03);
+  Value beta4 = builder.constant(1.18534705686654e-04);
+  Value beta6 = builder.constant(1.19825839466702e-06);
+
+  // Since the polynomials are odd/even, both are evaluated in terms of x^2.
+  Value x2 = builder.mul(x, x);
+
+  // Evaluate the numerator p(x) = x * P(x^2) with Horner's scheme.
+  Value p = builder.madd(x2, alpha13, alpha11);
+  p = builder.madd(x2, p, alpha9);
+  p = builder.madd(x2, p, alpha7);
+  p = builder.madd(x2, p, alpha5);
+  p = builder.madd(x2, p, alpha3);
+  p = builder.madd(x2, p, alpha1);
+  p = builder.mul(x, p);
+
+  // Evaluate the denominator q(x) = Q(x^2) with Horner's scheme.
+  Value q = builder.madd(x2, beta6, beta4);
+  q = builder.madd(x2, q, beta2);
+  q = builder.madd(x2, q, beta0);
+
+  // Divide the numerator by the denominator, except for tiny inputs.
+  Value res = builder.select(tinyMask, x, builder.div(p, q));
+
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+
+// Public entry point: adds every approximation pattern defined in this file
+// (currently only the math.tanh one) to `patterns`.
+void mlir::populateMathPolynomialApproximationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<TanhApproximation>(/*context=*/ctx);
+}

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
new file mode 100644
index 000000000000..a02443061502
--- /dev/null
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s
+
+// The rewrite must remove math.tanh and emit the rational approximation: a
+// division of the two polynomials followed by the final select that returns
+// the input for tiny values.
+
+// CHECK-LABEL: @tanh_scalar
+func @tanh_scalar(%arg0: f32) -> f32 {
+  // CHECK-NOT: tanh
+  // CHECK: divf
+  // CHECK-NOT: tanh
+  // CHECK: select
+  // CHECK-NOT: tanh
+  // CHECK: return
+  %0 = math.tanh %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @tanh_vector
+func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+  // CHECK-NOT: tanh
+  // CHECK: divf
+  // CHECK-NOT: tanh
+  // CHECK: select
+  // CHECK-NOT: tanh
+  // CHECK: return
+  %0 = math.tanh %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 4917616bbd4a..7de188df4aa8 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_library(MLIRTestTransforms
   TestLoopUnrolling.cpp
   TestNumberOfExecutions.cpp
   TestOpaqueLoc.cpp
+  TestPolynomialApproximation.cpp
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp

diff  --git a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
new file mode 100644
index 000000000000..4f48538bd4bf
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
@@ -0,0 +1,46 @@
+//===- TestPolynomialApproximation.cpp - Test math ops approximations -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for expanding math operations into
+// polynomial approximations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that rewrites math operations in a function into their
+/// polynomial approximations using the greedy pattern rewrite driver.
+struct TestMathPolynomialApproximationPass
+    : public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
+  void runOnFunction() override;
+
+  // The approximation patterns create standard-dialect ops (ConstantOp,
+  // MulFOp, DivFOp, FmaFOp, CmpFOp, SelectOp, AbsFOp) and, for vector
+  // operands, vector.broadcast. Every dialect that may be created must be
+  // declared here so that it is loaded before the pass runs, even if the
+  // input IR does not already use it.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<StandardOpsDialect, vector::VectorDialect,
+                    math::MathDialect>();
+  }
+};
+} // end anonymous namespace
+
+// Collects the approximation patterns and applies them greedily to the
+// current function.
+void TestMathPolynomialApproximationPass::runOnFunction() {
+  MLIRContext *context = &getContext();
+  OwningRewritePatternList approximations;
+  populateMathPolynomialApproximationPatterns(approximations, context);
+  // The driver's LogicalResult is intentionally discarded.
+  (void)applyPatternsAndFoldGreedily(getOperation(),
+                                     std::move(approximations));
+}
+
+namespace mlir {
+namespace test {
+/// Registers the test pass under `-test-math-polynomial-approximation`.
+void registerTestMathPolynomialApproximationPass() {
+  PassRegistration<TestMathPolynomialApproximationPass> registration(
+      "test-math-polynomial-approximation",
+      "Test math polynomial approximations");
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
new file mode 100644
index 000000000000..444833789859
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -0,0 +1,32 @@
+// RUN:   mlir-opt %s -test-math-polynomial-approximation                      \
+// RUN:               -convert-vector-to-llvm                                  \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext  \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext    \
+// RUN: | FileCheck %s
+
+
+// End-to-end test: math.tanh is replaced by the polynomial approximation,
+// lowered to LLVM and executed. Each expected line lists tanh of the
+// corresponding inputs at the printed precision.
+func @main() {
+  // ------------------------------------------------------------------------ //
+  // Tanh.
+  // ------------------------------------------------------------------------ //
+
+  // Scalar f32 input.
+  // CHECK: 0.848284
+  %0 = constant 1.25 : f32
+  %1 = math.tanh %0 : f32
+  vector.print %1 : f32
+
+  // 4-element vector input.
+  // CHECK: 0.244919, 0.635149, 0.761594, 0.848284
+  %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+  %3 = math.tanh %2 : vector<4xf32>
+  vector.print %3 : vector<4xf32>
+
+  // 8-element vector input.
+  // CHECK: 0.099668, 0.197375, 0.291313, 0.379949, 0.462117, 0.53705, 0.604368, 0.664037
+  %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+  %5 = math.tanh %4 : vector<8xf32>
+  vector.print %5 : vector<8xf32>
+
+  return
+}

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2ce2e0198202..e03e7e8f8907 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -84,6 +84,7 @@ void registerTestLivenessPass();
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestNumberOfBlockExecutionsPass();
@@ -157,6 +158,7 @@ void registerTestPasses() {
   test::registerTestLoopFusion();
   test::registerTestLoopMappingPass();
   test::registerTestLoopUnrollingPass();
+  test::registerTestMathPolynomialApproximationPass();
   test::registerTestMemRefDependenceCheck();
   test::registerTestMemRefStrideCalculation();
   test::registerTestNumberOfBlockExecutionsPass();


        


More information about the Mlir-commits mailing list