[Mlir-commits] [mlir] f99ccf6 - [mlir] Add math polynomial approximation pass
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Feb 19 12:43:42 PST 2021
Author: Eugene Zhulenev
Date: 2021-02-19T12:43:36-08:00
New Revision: f99ccf6516bdd5def4d3bc311330aec92f5cb99d
URL: https://github.com/llvm/llvm-project/commit/f99ccf6516bdd5def4d3bc311330aec92f5cb99d
DIFF: https://github.com/llvm/llvm-project/commit/f99ccf6516bdd5def4d3bc311330aec92f5cb99d.diff
LOG: [mlir] Add math polynomial approximation pass
This gives a ~30x speedup compared to expanding Tanh into exp operations:
```
name old cpu/op new cpu/op delta
BM_mlir_Tanh_f32/10 253ns ± 3% 55ns ± 7% -78.35% (p=0.000 n=44+41)
BM_mlir_Tanh_f32/100 2.21µs ± 4% 0.14µs ± 8% -93.85% (p=0.000 n=48+49)
BM_mlir_Tanh_f32/1k 22.6µs ± 4% 0.7µs ± 5% -96.68% (p=0.000 n=32+42)
BM_mlir_Tanh_f32/10k 225µs ± 5% 7µs ± 6% -96.88% (p=0.000 n=49+55)
name old time/op new time/op delta
BM_mlir_Tanh_f32/10 259ns ± 1% 56ns ± 2% -78.31% (p=0.000 n=41+39)
BM_mlir_Tanh_f32/100 2.27µs ± 1% 0.14µs ± 5% -93.89% (p=0.000 n=46+49)
BM_mlir_Tanh_f32/1k 22.9µs ± 1% 0.8µs ± 4% -96.67% (p=0.000 n=30+42)
BM_mlir_Tanh_f32/10k 230µs ± 0% 7µs ± 3% -96.88% (p=0.000 n=37+55)
```
This approximation is based on Eigen's generic_fast_tanh function.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D96739
Added:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 82da168c2ccd..c965bab3769b 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -19,6 +19,9 @@ class OwningRewritePatternList;
void populateExpandTanhPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx);
+/// Adds to `patterns` the rewrite patterns that replace math dialect
+/// operations with polynomial approximations built from arithmetic ops
+/// (currently only math.tanh is handled).
+void populateMathPolynomialApproximationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx);
+
} // namespace mlir
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index b85549941a63..23463af00fb9 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRMathTransforms
ExpandTanh.cpp
+ PolynomialApproximation.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
new file mode 100644
index 000000000000..d230334a8019
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -0,0 +1,194 @@
+//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of math operations to fast approximations
+// that do not rely on any of the library functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+// Returns true when `type` is a float scalar or a vector whose elements are
+// floats; other operand types are rejected by the approximation patterns.
+static bool isValidFloatType(Type type) {
+  Type elementType = type;
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    elementType = vectorType.getElementType();
+  return elementType.isa<FloatType>();
+}
+
+//----------------------------------------------------------------------------//
+// A PatternRewriter wrapper that provides a concise API for building expansions
+// for operations on float scalars or vectors.
+//----------------------------------------------------------------------------//
+
+namespace {
+// Thin wrapper around a PatternRewriter that emits float arithmetic for an
+// operand that is either a float scalar or a vector of floats. Every op is
+// created at `loc` through `rewriter`.
+class FloatApproximationBuilder {
+public:
+ // `type` is the operand type: a FloatType, or a VectorType whose element
+ // type is a FloatType (the constructor casts it accordingly).
+ FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter);
+
+ // Creates a constant with the operand's element type; for vector operands
+ // the scalar constant is broadcast to the full vector type.
+ Value constant(double value) const;
+
+ // Absolute value, minimum, maximum, product and quotient of the operands.
+ Value abs(Value a) const;
+ Value min(Value a, Value b) const;
+ Value max(Value a, Value b) const;
+ Value mul(Value a, Value b) const;
+ Value div(Value a, Value b) const;
+
+ // Fused multiply-add operation: a * b + c.
+ Value madd(Value a, Value b, Value c) const;
+
+ // Compares values `a` and `b` with the given `predicate`.
+ Value cmp(CmpFPredicate predicate, Value a, Value b) const;
+
+ // Selects `a` where `predicate` is true and `b` where it is false.
+ Value select(Value predicate, Value a, Value b) const;
+
+private:
+ Location loc;
+ PatternRewriter &rewriter;
+ VectorType vectorType; // can be null for scalar type
+ FloatType elementType;
+};
+} // namespace
+
+// Records the insertion context and splits `type` into an optional vector
+// shape and its float element type.
+FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type,
+                                                     PatternRewriter &rewriter)
+    : loc(loc), rewriter(rewriter), vectorType(type.dyn_cast<VectorType>()) {
+  // A scalar is its own element type; a vector contributes its elements.
+  Type scalarType = vectorType ? vectorType.getElementType() : type;
+  elementType = scalarType.cast<FloatType>();
+}
+
+// Materializes `value` as a constant of the element type, broadcasting it to
+// the vector type when the builder was created for a vector operand.
+Value FloatApproximationBuilder::constant(double value) const {
+  Value scalar =
+      rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(elementType, value));
+  if (!vectorType)
+    return scalar;
+  return rewriter.create<BroadcastOp>(loc, vectorType, scalar);
+}
+
+// Emits the absolute value of `a`.
+Value FloatApproximationBuilder::abs(Value a) const {
+  Value magnitude = rewriter.create<AbsFOp>(loc, a);
+  return magnitude;
+}
+
+// Returns the minimum of `a` and `b`. The unordered predicate ULT is true when
+// `a < b` or when either operand is NaN, so a NaN in `a` is propagated to the
+// result. (With the ordered OLT predicate a NaN `a` lost the compare and was
+// replaced by `b`, e.g. turning tanh(NaN) into a finite clamped value.)
+Value FloatApproximationBuilder::min(Value a, Value b) const {
+  return select(cmp(CmpFPredicate::ULT, a, b), a, b);
+}
+// Returns the maximum of `a` and `b`. UGT is true when `a > b` or when either
+// operand is NaN, so a NaN in `a` is propagated to the result.
+Value FloatApproximationBuilder::max(Value a, Value b) const {
+  return select(cmp(CmpFPredicate::UGT, a, b), a, b);
+}
+// Emits the product `a * b`.
+Value FloatApproximationBuilder::mul(Value a, Value b) const {
+  Value product = rewriter.create<MulFOp>(loc, a, b);
+  return product;
+}
+
+// Emits the quotient `a / b`.
+Value FloatApproximationBuilder::div(Value a, Value b) const {
+  Value quotient = rewriter.create<DivFOp>(loc, a, b);
+  return quotient;
+}
+
+// Emits the fused multiply-add `a * b + c` as a single FmaFOp.
+Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const {
+  Value fused = rewriter.create<FmaFOp>(loc, a, b, c);
+  return fused;
+}
+
+// Emits a float comparison of `a` against `b` using `predicate`.
+Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a,
+                                     Value b) const {
+  Value mask = rewriter.create<CmpFOp>(loc, predicate, a, b);
+  return mask;
+}
+
+// Emits a select that yields `a` where `predicate` is true and `b` otherwise.
+Value FloatApproximationBuilder::select(Value predicate, Value a,
+                                        Value b) const {
+  Value chosen = rewriter.create<SelectOp>(loc, predicate, a, b);
+  return chosen;
+}
+
+//----------------------------------------------------------------------------//
+// TanhOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+// Rewrites math.tanh on float scalars or float vectors into a rational
+// polynomial approximation built only from arithmetic operations.
+struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::TanhOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Replaces `op` with p(x) / q(x), where x is the operand clamped to
+// [minusClamp, plusClamp], p is an odd degree-13 polynomial and q an even
+// degree-6 polynomial. Operands with |operand| < 0.0004 are returned as-is.
+// Fails to match when the operand is not a float scalar or float vector.
+LogicalResult
+TanhApproximation::matchAndRewrite(math::TanhOp op,
+                                   PatternRewriter &rewriter) const {
+  Value operand = op.operand();
+  if (!isValidFloatType(operand.getType()))
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  // NOTE(review): the coefficients come from Eigen's generic_fast_tanh and are
+  // presumably tuned for f32; f64 operands are accepted too but likely get
+  // only about single-precision accuracy. Consider restricting to f32.
+  FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter);
+
+  // Clamp the operand into the [minusClamp, plusClamp] range; outside of it
+  // the approximation is evaluated at the nearest clamp boundary.
+  Value plusClamp = builder.constant(7.90531110763549805);
+  Value minusClamp = builder.constant(-7.90531110763549805);
+  Value x = builder.max(builder.min(operand, plusClamp), minusClamp);
+
+  // Mask for tiny values, for which tanh(operand) is approximated by the
+  // operand itself. A double literal keeps the threshold exact for every
+  // float element type.
+  Value tiny = builder.constant(0.0004);
+  Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny);
+
+  // The monomial coefficients of the numerator polynomial (odd).
+  Value alpha1 = builder.constant(4.89352455891786e-03);
+  Value alpha3 = builder.constant(6.37261928875436e-04);
+  Value alpha5 = builder.constant(1.48572235717979e-05);
+  Value alpha7 = builder.constant(5.12229709037114e-08);
+  Value alpha9 = builder.constant(-8.60467152213735e-11);
+  Value alpha11 = builder.constant(2.00018790482477e-13);
+  Value alpha13 = builder.constant(-2.76076847742355e-16);
+
+  // The monomial coefficients of the denominator polynomial (even).
+  Value beta0 = builder.constant(4.89352518554385e-03);
+  Value beta2 = builder.constant(2.26843463243900e-03);
+  Value beta4 = builder.constant(1.18534705686654e-04);
+  Value beta6 = builder.constant(1.19825839466702e-06);
+
+  // Since the polynomials are odd/even, both are evaluated in x^2.
+  Value x2 = builder.mul(x, x);
+
+  // Evaluate the numerator polynomial p with Horner's scheme, then multiply
+  // by x to make it odd.
+  Value p = builder.madd(x2, alpha13, alpha11);
+  p = builder.madd(x2, p, alpha9);
+  p = builder.madd(x2, p, alpha7);
+  p = builder.madd(x2, p, alpha5);
+  p = builder.madd(x2, p, alpha3);
+  p = builder.madd(x2, p, alpha1);
+  p = builder.mul(x, p);
+
+  // Evaluate the denominator polynomial q with Horner's scheme.
+  Value q = builder.madd(x2, beta6, beta4);
+  q = builder.madd(x2, q, beta2);
+  q = builder.madd(x2, q, beta0);
+
+  // Divide the numerator by the denominator, except for tiny inputs, which
+  // take x directly (clamping leaves such inputs unchanged).
+  Value res = builder.select(tinyMask, x, builder.div(p, q));
+
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+
+// Registers every polynomial approximation pattern defined in this file.
+// Only math.tanh is handled at the moment.
+void mlir::populateMathPolynomialApproximationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<TanhApproximation>(ctx);
+}
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
new file mode 100644
index 000000000000..a02443061502
--- /dev/null
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s
+
+// Scalar f32 operand: after the pass no tanh operation may remain in the
+// function body.
+// CHECK-LABEL: @tanh_scalar
+func @tanh_scalar(%arg0: f32) -> f32 {
+ // CHECK-NOT: tanh
+ %0 = math.tanh %arg0 : f32
+ return %0 : f32
+}
+
+// Vector operand: the approximation must also expand tanh on vector<8xf32>
+// (constants are broadcast to the vector type by the pass).
+// CHECK-LABEL: @tanh_vector
+func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+ // CHECK-NOT: tanh
+ %0 = math.tanh %arg0 : vector<8xf32>
+ return %0 : vector<8xf32>
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 4917616bbd4a..7de188df4aa8 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_library(MLIRTestTransforms
TestLoopUnrolling.cpp
TestNumberOfExecutions.cpp
TestOpaqueLoc.cpp
+ TestPolynomialApproximation.cpp
TestMemRefBoundCheck.cpp
TestMemRefDependenceCheck.cpp
TestMemRefStrideCalculation.cpp
diff --git a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
new file mode 100644
index 000000000000..4f48538bd4bf
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
@@ -0,0 +1,46 @@
+//===- TestPolynomialApproximation.cpp - Test math ops approximations -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for expanding math operations into
+// polynomial approximations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+// Test pass that applies the math polynomial approximation patterns to the
+// function it runs on.
+struct TestMathPolynomialApproximationPass
+    : public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
+  void runOnFunction() override;
+
+  // Declares the vector and math dialects as dependencies so they are loaded
+  // before the pass runs (the expansion emits vector.broadcast for vector
+  // operands).
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect, math::MathDialect>();
+  }
+};
+} // end anonymous namespace
+
+// Collects the approximation patterns and applies them greedily to the
+// current function. The driver's result is explicitly discarded.
+void TestMathPolynomialApproximationPass::runOnFunction() {
+  MLIRContext *ctx = &getContext();
+  OwningRewritePatternList approximations;
+  populateMathPolynomialApproximationPatterns(approximations, ctx);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(approximations));
+}
+
+namespace mlir {
+namespace test {
+// Exposes the test pass on the command line as
+// `-test-math-polynomial-approximation`.
+void registerTestMathPolynomialApproximationPass() {
+  PassRegistration<TestMathPolynomialApproximationPass> registration(
+      "test-math-polynomial-approximation",
+      "Test math polynomial approximations");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
new file mode 100644
index 000000000000..444833789859
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -test-math-polynomial-approximation \
+// RUN: -convert-vector-to-llvm \
+// RUN: -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -e main -entry-point-result=void -O0 \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+
+// Entry point executed by mlir-cpu-runner: expands math.tanh with the
+// polynomial approximation and prints the results for a scalar, a 4-wide
+// vector and an 8-wide vector so FileCheck can compare the printed values.
+func @main() {
+ // ------------------------------------------------------------------------ //
+ // Tanh.
+ // ------------------------------------------------------------------------ //
+
+ // Scalar f32 input.
+ // CHECK: 0.848284
+ %0 = constant 1.25 : f32
+ %1 = math.tanh %0 : f32
+ vector.print %1 : f32
+
+ // vector<4xf32> input.
+ // CHECK: 0.244919, 0.635149, 0.761594, 0.848284
+ %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+ %3 = math.tanh %2 : vector<4xf32>
+ vector.print %3 : vector<4xf32>
+
+ // vector<8xf32> input.
+ // CHECK: 0.099668, 0.197375, 0.291313, 0.379949, 0.462117, 0.53705, 0.604368, 0.664037
+ %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+ %5 = math.tanh %4 : vector<8xf32>
+ vector.print %5 : vector<8xf32>
+
+ return
+}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2ce2e0198202..e03e7e8f8907 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -84,6 +84,7 @@ void registerTestLivenessPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
+void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestNumberOfBlockExecutionsPass();
@@ -157,6 +158,7 @@ void registerTestPasses() {
test::registerTestLoopFusion();
test::registerTestLoopMappingPass();
test::registerTestLoopUnrollingPass();
+ test::registerTestMathPolynomialApproximationPass();
test::registerTestMemRefDependenceCheck();
test::registerTestMemRefStrideCalculation();
test::registerTestNumberOfBlockExecutionsPass();
More information about the Mlir-commits
mailing list