[Mlir-commits] [mlir] ea7f211 - [mlir] Add polynomial approximation for math::ExpOp
Ahmed Taei
llvmlistbot at llvm.org
Mon Mar 1 11:02:26 PST 2021
Author: Ahmed Taei
Date: 2021-03-01T11:02:07-08:00
New Revision: ea7f211b2e6cbb49f178ec9ba085a4958d33cdea
URL: https://github.com/llvm/llvm-project/commit/ea7f211b2e6cbb49f178ec9ba085a4958d33cdea
DIFF: https://github.com/llvm/llvm-project/commit/ea7f211b2e6cbb49f178ec9ba085a4958d33cdea.diff
LOG: [mlir] Add polynomial approximation for math::ExpOp
Similar to fast_exp in https://github.com/boulos/syrah
Differential Revision: https://reviews.llvm.org/D97599
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index eba9434bf5ec..669607e2ee09 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -10,7 +10,6 @@
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//
-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
@@ -20,6 +19,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <limits.h>
using namespace mlir;
using namespace mlir::vector;
@@ -28,6 +28,8 @@ using TypePredicate = llvm::function_ref<bool(Type)>;
static bool isF32(Type type) { return type.isF32(); }
+static bool isI32(Type type) { return type.isInteger(32); }
+
// Returns vector width if the element type is matching the predicate (scalars
// that do match the predicate have width equal to `1`).
static Optional<int> vectorWidth(Type type, TypePredicate pred) {
@@ -153,6 +155,30 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
return {normalizedFraction, exponent};
}
+// Computes exp2 for an i32 argument, returning 2^arg as an f32 value.
+static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
+ assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
+
+ int width = vectorWidth(arg.getType());
+
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, width);
+ };
+
+ auto f32Vec = broadcast(builder.getF32Type(), width);
+ // The f32 exponent field starts at bit 23 (right above the mantissa bits).
+ auto exponetBitLocation = bcast(i32Cst(builder, 23));
+ // IEEE-754 f32 exponent bias; added so the stored exponent encodes 2^arg.
+ auto bias = bcast(i32Cst(builder, 127));
+ // NOTE: no range check is done here; callers must keep arg + 127 in range.
+ Value biasedArg = builder.create<AddIOp>(arg, bias);
+ Value exp2ValueInt =
+ builder.create<ShiftLeftOp>(biasedArg, exponetBitLocation);
+ Value exp2ValueF32 = builder.create<LLVM::BitcastOp>(f32Vec, exp2ValueInt);
+
+ return exp2ValueF32;
+}
+
//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//
@@ -230,6 +256,11 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
return success();
}
+#define LN2_VALUE \
+ 0.693147180559945309417232121458176568075500134360255254120680009493393621L
+#define LN2E_VALUE \
+ 1.442695040888963407359924681001892137426645954152985934135449406931109219L
+
//----------------------------------------------------------------------------//
// LogOp approximation.
//----------------------------------------------------------------------------//
@@ -247,9 +278,6 @@ struct LogApproximation : public OpRewritePattern<math::LogOp> {
};
} // namespace
-#define LN2_VALUE \
- 0.693147180559945309417232121458176568075500134360255254120680009493393621L
-
LogicalResult
LogApproximation::matchAndRewrite(math::LogOp op,
PatternRewriter &rewriter) const {
@@ -353,9 +381,125 @@ LogApproximation::matchAndRewrite(math::LogOp op,
return success();
}
+//----------------------------------------------------------------------------//
+// ExpOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+// Rewrite pattern that replaces math::ExpOp with a polynomial approximation.
+struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::ExpOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Approximates exp(x) via range reduction: with k = floor(x / ln(2)) and
+// y = x - k * ln(2), y lies in [0, ln(2)) and exp(x) = exp(y) * 2^k, where
+// exp(y) is evaluated with a degree-5 polynomial. Only f32 types are matched.
+LogicalResult
+ExpApproximation::matchAndRewrite(math::ExpOp op,
+ PatternRewriter &rewriter) const {
+ auto width = vectorWidth(op.operand().getType(), isF32);
+ if (!width.hasValue())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+
+ // TODO: Consider a common pattern rewriter with all methods below to
+ // write the approximations.
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, *width);
+ };
+ auto fmla = [&](Value a, Value b, Value c) {
+ return builder.create<FmaFOp>(a, b, c);
+ };
+ auto mul = [&](Value a, Value b) -> Value {
+ return builder.create<MulFOp>(a, b);
+ };
+ auto sub = [&](Value a, Value b) -> Value {
+ return builder.create<SubFOp>(a, b);
+ };
+ auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
+ // LN2E_VALUE is log2(e) = 1 / ln(2), so x * cstLN2E computes x / ln(2).
+ Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
+ Value cstLN2E = bcast(f32Cst(builder, static_cast<float>(LN2E_VALUE)));
+
+ // Polynomial coefficients c0..c5 of P(y) ~= exp(y) on the reduced range.
+ Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
+ Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
+ Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
+ Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
+ Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
+ Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
+
+ Value x = op.operand();
+
+ // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
+ Value xL2Inv = mul(x, cstLN2E);
+ Value kF32 = floor(xL2Inv);
+ Value kLn2 = mul(kF32, cstLn2);
+ Value y = sub(x, kLn2);
+
+ // Use Estrin's evaluation scheme with 3 independent parts:
+ // P(y) = (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
+ Value y2 = mul(y, y);
+ Value y4 = mul(y2, y2);
+
+ Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
+ Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
+ Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
+ Value expY = fmla(q1, y2, q0);
+ expY = fmla(q2, y4, expY);
+
+ auto i32Vec = broadcast(builder.getI32Type(), *width);
+
+ // exp2(k)
+ Value k = builder.create<FPToSIOp>(kF32, i32Vec);
+ Value exp2KValue = exp2I32(builder, k);
+
+ // exp(x) = exp(y) * exp2(k)
+ expY = mul(expY, exp2KValue);
+
+ // Handle overflow and underflow of exp(x). exp2I32 only yields a normalized
+ // f32 for k in [-126, 127] (biased exponent in [1, 254]); outside of it:
+ // exp(x) = 0, x == -inf
+ // exp(x) = smallest normalized f32, k < -126 (x <~ -87.3)
+ // exp(x) = +inf, k > 127 (x >~ 88.7)
+ // Note: k == -127 gives a biased exponent of 0, i.e. a scale factor of 0.0.
+ Value zerof32Const = bcast(f32Cst(builder, 0));
+ auto constPosInfinity =
+ bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
+ auto constNegInfinity =
+ bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
+ auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
+
+ Value kMaxConst = bcast(i32Cst(builder, 127));
+ Value kMinConst = bcast(i32Cst(builder, -126));
+ Value rightBound = builder.create<CmpIOp>(CmpIPredicate::sle, k, kMaxConst);
+ Value leftBound = builder.create<CmpIOp>(CmpIPredicate::sge, k, kMinConst);
+
+ Value isNegInfinityX =
+ builder.create<CmpFOp>(CmpFPredicate::OEQ, x, constNegInfinity);
+ Value isPositiveX =
+ builder.create<CmpFOp>(CmpFPredicate::OGT, x, zerof32Const);
+ Value isComputable = builder.create<AndOp>(rightBound, leftBound);
+
+ expY = builder.create<SelectOp>(
+ isComputable, expY,
+ builder.create<SelectOp>(
+ isPositiveX, constPosInfinity,
+ builder.create<SelectOp>(isNegInfinityX, zerof32Const, underflow)));
+
+ rewriter.replaceOp(op, expY);
+
+ return success();
+}
+
//----------------------------------------------------------------------------//
void mlir::populateMathPolynomialApproximationPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<TanhApproximation, LogApproximation>(ctx);
+ patterns.insert<TanhApproximation, LogApproximation, ExpApproximation>(ctx);
}
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 82f30c612f31..6d102e5f5c85 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -20,3 +20,16 @@ func @vector(%arg0: vector<8xf32>) -> vector<8xf32> {
%1 = math.log %0 : vector<8xf32>
return %1 : vector<8xf32>
}
+
+// CHECK-LABEL: @exp_scalar
+func @exp_scalar(%arg0: f32) -> f32 {
+ %0 = math.exp %arg0 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @exp_vector
+func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+ // CHECK-NOT: math.exp
+ %0 = math.exp %arg0 : vector<8xf32>
+ return %0 : vector<8xf32>
+}
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
index 32f2641cc009..072cc2d4655a 100644
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -71,8 +71,46 @@ func @log() {
return
}
+// -------------------------------------------------------------------------- //
+// Exp.
+// -------------------------------------------------------------------------- //
+func @exp() {
+ // CHECK: 2.71828
+ %0 = constant 1.0 : f32
+ %1 = math.exp %0 : f32
+ vector.print %1 : f32
+
+ // CHECK: 0.778802, 2.117, 2.71828, 3.85742
+ %2 = constant dense<[-0.25, 0.75, 1.0, 1.35]> : vector<4xf32>
+ %3 = math.exp %2 : vector<4xf32>
+ vector.print %3 : vector<4xf32>
+
+ // CHECK: 1
+ %zero = constant 0.0 : f32
+ %exp_zero = math.exp %zero : f32
+ vector.print %exp_zero : f32
+
+ // CHECK: 1.17549e-38, 1.38879e-11, 7.20049e+10, inf
+ %special_vec = constant dense<[-89.0, -25.0, 25.0, 89.0]> : vector<4xf32>
+ %exp_special_vec = math.exp %special_vec : vector<4xf32>
+ vector.print %exp_special_vec : vector<4xf32>
+
+ // CHECK: inf
+ %inf = constant 0x7f800000 : f32
+ %exp_inf = math.exp %inf : f32
+ vector.print %exp_inf : f32
+
+ // CHECK: 0
+ %negative_inf = constant 0xff800000 : f32
+ %exp_negative_inf = math.exp %negative_inf : f32
+ vector.print %exp_negative_inf : f32
+
+ return
+}
+
func @main() {
call @tanh(): () -> ()
call @log(): () -> ()
+ call @exp(): () -> ()
return
}
More information about the Mlir-commits
mailing list