[Mlir-commits] [mlir] ea7f211 - [mlir] Add polynomial approximation for math::ExpOp

Ahmed Taei llvmlistbot at llvm.org
Mon Mar 1 11:02:26 PST 2021


Author: Ahmed Taei
Date: 2021-03-01T11:02:07-08:00
New Revision: ea7f211b2e6cbb49f178ec9ba085a4958d33cdea

URL: https://github.com/llvm/llvm-project/commit/ea7f211b2e6cbb49f178ec9ba085a4958d33cdea
DIFF: https://github.com/llvm/llvm-project/commit/ea7f211b2e6cbb49f178ec9ba085a4958d33cdea.diff

LOG: [mlir] Add polynomial approximation for math::ExpOp

Similar to fast_exp in https://github.com/boulos/syrah

Differential Revision: https://reviews.llvm.org/D97599

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir
    mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index eba9434bf5ec..669607e2ee09 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -10,7 +10,6 @@
 // that do not rely on any of the library functions.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -20,6 +19,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <limits>
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -28,6 +28,8 @@ using TypePredicate = llvm::function_ref<bool(Type)>;
 
 static bool isF32(Type type) { return type.isF32(); }
 
+static bool isI32(Type type) { return type.isInteger(32); }
+
 // Returns vector width if the element type is matching the predicate (scalars
 // that do match the predicate have width equal to `1`).
 static Optional<int> vectorWidth(Type type, TypePredicate pred) {
@@ -153,6 +155,30 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
   return {normalizedFraction, exponent};
 }
 
+// Computes exp2(arg) for an i32 `arg` by writing `arg + 127` into the f32
+// exponent field. Only arg in [-126, 127] yields a normal (exact) power of 2.
+static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
+  assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
+
+  int width = vectorWidth(arg.getType());
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, width);
+  };
+
+  auto f32Vec = broadcast(builder.getF32Type(), width);
+  // The f32 exponent field starts at bit 23, right above the mantissa.
+  auto exponentBitLocation = bcast(i32Cst(builder, 23));
+  // IEEE-754 single-precision exponent bias.
+  auto bias = bcast(i32Cst(builder, 127));
+
+  Value biasedArg = builder.create<AddIOp>(arg, bias);
+  Value exp2ValueInt =
+      builder.create<ShiftLeftOp>(biasedArg, exponentBitLocation);
+  // Reinterpret the bits as f32: 2^arg with an all-zero mantissa.
+  Value exp2ValueF32 = builder.create<LLVM::BitcastOp>(f32Vec, exp2ValueInt);
+  return exp2ValueF32;
+}
+
 //----------------------------------------------------------------------------//
 // TanhOp approximation.
 //----------------------------------------------------------------------------//
@@ -230,6 +256,11 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
   return success();
 }
 
+#define LN2_VALUE                                                              \
+  0.693147180559945309417232121458176568075500134360255254120680009493393621L
+#define LOG2E_VALUE                                                            \
+  1.442695040888963407359924681001892137426645954152985934135449406931109219L
+
 //----------------------------------------------------------------------------//
 // LogOp approximation.
 //----------------------------------------------------------------------------//
@@ -247,9 +278,6 @@ struct LogApproximation : public OpRewritePattern<math::LogOp> {
 };
 } // namespace
 
-#define LN2_VALUE                                                              \
-  0.693147180559945309417232121458176568075500134360255254120680009493393621L
-
 LogicalResult
 LogApproximation::matchAndRewrite(math::LogOp op,
                                   PatternRewriter &rewriter) const {
@@ -353,9 +381,125 @@ LogApproximation::matchAndRewrite(math::LogOp op,
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Exp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::ExpOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Approximates exp(x) by range reduction: with k = floor(x / ln(2)) and
+// y = x - k * ln(2) in [0, ln(2)], exp(x) = exp(y) * 2^k, where exp(y) is a
+// degree-5 polynomial and 2^k is assembled directly from its f32 bit pattern.
+LogicalResult
+ExpApproximation::matchAndRewrite(math::ExpOp op,
+                                  PatternRewriter &rewriter) const {
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  if (!width.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+
+  // TODO: Consider a common pattern rewriter with all methods below to
+  // write the approximations.
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+  auto fmla = [&](Value a, Value b, Value c) {
+    return builder.create<FmaFOp>(a, b, c);
+  };
+  auto mul = [&](Value a, Value b) -> Value {
+    return builder.create<MulFOp>(a, b);
+  };
+  auto sub = [&](Value a, Value b) -> Value {
+    return builder.create<SubFOp>(a, b);
+  };
+  auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
+
+  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
+  Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
+
+  // Polynomial coefficients c0..c5 of the exp(y) approximation on [0, ln(2)].
+  Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
+  Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
+  Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
+  Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
+  Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
+  Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
+
+  Value x = op.operand();
+
+  // Range reduction: k = floor(x * log2(e)), y = x - k * ln(2).
+  Value xL2Inv = mul(x, cstLog2E);
+  Value kF32 = floor(xL2Inv);
+  Value kLn2 = mul(kF32, cstLn2);
+  Value y = sub(x, kLn2);
+
+  // Use Estrin's evaluation scheme with 3 independent parts:
+  // P(y) = (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
+  Value y2 = mul(y, y);
+  Value y4 = mul(y2, y2);
+
+  Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
+  Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
+  Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
+  Value expY = fmla(q1, y2, q0);
+  expY = fmla(q2, y4, expY);
+
+  auto i32Vec = broadcast(builder.getI32Type(), *width);
+
+  // exp2(k)
+  Value k = builder.create<FPToSIOp>(kF32, i32Vec);
+  Value exp2KValue = exp2I32(builder, k);
+
+  // exp(x) = exp(y) * exp2(k)
+  expY = mul(expY, exp2KValue);
+
+  // Handle NaN, overflow, inf and underflow of exp(x). When k falls outside
+  // [-126, 127], the range where exp2I32(k) is a normal f32, exp(x) becomes:
+  // exp(x) = 0, x == -inf
+  // exp(x) = underflow (min_float), x < -126 * ln(2) ~= -87.34
+  // exp(x) = inf, x >= 128 * ln(2) ~= 88.72
+  // NaN inputs are returned unchanged (fptosi(NaN) gives no meaningful k).
+  Value zerof32Const = bcast(f32Cst(builder, 0));
+  auto constPosInfinity =
+      bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
+  auto constNegInfinity =
+      bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
+  auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
+
+  Value kMaxConst = bcast(i32Cst(builder, 127));
+  Value kMaxNegConst = bcast(i32Cst(builder, -126));
+  Value rightBound = builder.create<CmpIOp>(CmpIPredicate::sle, k, kMaxConst);
+  Value leftBound = builder.create<CmpIOp>(CmpIPredicate::sge, k, kMaxNegConst);
+
+  Value isNegInfinityX =
+      builder.create<CmpFOp>(CmpFPredicate::OEQ, x, constNegInfinity);
+  Value isPositiveX =
+      builder.create<CmpFOp>(CmpFPredicate::OGT, x, zerof32Const);
+  Value isNaN = builder.create<CmpFOp>(CmpFPredicate::UNO, x, x);
+  Value isComputable = builder.create<AndOp>(rightBound, leftBound);
+  Value outOfRange = builder.create<SelectOp>(
+      isPositiveX, constPosInfinity,
+      builder.create<SelectOp>(isNegInfinityX, zerof32Const, underflow));
+  expY = builder.create<SelectOp>(isComputable, expY, outOfRange);
+  expY = builder.create<SelectOp>(isNaN, x, expY);
+
+  rewriter.replaceOp(op, expY);
+
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathPolynomialApproximationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<TanhApproximation, LogApproximation>(ctx);
+  patterns.insert<TanhApproximation, LogApproximation, ExpApproximation>(ctx);
 }

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 82f30c612f31..6d102e5f5c85 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -20,3 +20,17 @@ func @vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   %1 = math.log %0 : vector<8xf32>
   return %1 : vector<8xf32>
 }
+
+// CHECK-LABEL: @exp_scalar
+func @exp_scalar(%arg0: f32) -> f32 {
+  // CHECK-NOT: math.exp
+  %0 = math.exp %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @exp_vector
+func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+  // CHECK-NOT: math.exp
+  %0 = math.exp %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}

diff  --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
index 32f2641cc009..072cc2d4655a 100644
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -71,8 +71,46 @@ func @log() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Exp.
+// -------------------------------------------------------------------------- //
+func @exp() {
+  // CHECK: 2.71828
+  %0 = constant 1.0 : f32
+  %1 = math.exp %0 : f32
+  vector.print %1 : f32
+
+  // CHECK: 0.778802, 2.117, 2.71828, 3.85742
+  %2 = constant dense<[-0.25, 0.75, 1.0, 1.35]> : vector<4xf32>
+  %3 = math.exp %2 : vector<4xf32>
+  vector.print %3 : vector<4xf32>
+
+  // CHECK: 1
+  %zero = constant 0.0 : f32
+  %exp_zero = math.exp %zero : f32
+  vector.print %exp_zero : f32
+
+  // CHECK: 1.17549e-38, 1.38879e-11, 7.20049e+10, inf
+  %special_vec = constant dense<[-89.0, -25.0, 25.0, 89.0]> : vector<4xf32>
+  %exp_special_vec = math.exp %special_vec : vector<4xf32>
+  vector.print %exp_special_vec : vector<4xf32>
+
+  // CHECK: inf
+  %inf = constant 0x7f800000 : f32
+  %exp_inf = math.exp %inf : f32
+  vector.print %exp_inf : f32
+
+  // CHECK: 0
+  %negative_inf = constant 0xff800000 : f32
+  %exp_negative_inf = math.exp %negative_inf : f32
+  vector.print %exp_negative_inf : f32
+
+  return
+}
+
 func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
+  call @exp(): () -> ()
   return
 }


        


More information about the Mlir-commits mailing list