[Mlir-commits] [mlir] f1b9221 - [MLIR][Math] Add erf to math dialect

Sean Silva llvmlistbot at llvm.org
Mon Oct 25 11:30:45 PDT 2021


Author: Boian Petkantchin
Date: 2021-10-25T18:30:17Z
New Revision: f1b922188ead5ca492c8d8edd47921b013a22ae0

URL: https://github.com/llvm/llvm-project/commit/f1b922188ead5ca492c8d8edd47921b013a22ae0
DIFF: https://github.com/llvm/llvm-project/commit/f1b922188ead5ca492c8d8edd47921b013a22ae0.diff

LOG: [MLIR][Math] Add erf to math dialect

Add math.erf lowering to libm call.
Add math.erf polynomial approximation.

Reviewed By: silvas, ezhulenev

Differential Revision: https://reviews.llvm.org/D112200

Added: 
    mlir/include/mlir/Dialect/Math/Transforms/Approximation.h

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
    mlir/test/Dialect/Math/ops.mlir
    mlir/test/Dialect/Math/polynomial-approximation.mlir
    mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
    mlir/utils/vim/syntax/mlir.vim

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 57482b55768f5..75e0b290e8457 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -285,6 +285,39 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ErfOp
+//===----------------------------------------------------------------------===//
+
+def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
+  let summary = "error function of the specified value";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.erf` ssa-use `:` type
+    ```
+
+    The `erf` operation computes the error function. It takes one operand
+    and returns one result of the same type. This type may be a float scalar
+    type, a vector whose element type is float, or a tensor of floats. It has
+    no standard attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar error function value.
+    %a = math.erf %b : f64
+
+    // SIMD vector element-wise error function value.
+    %f = math.erf %g : vector<4xf32>
+
+    // Tensor element-wise error function value.
+    %x = math.erf %y : tensor<4x?xf32>
+    ```
+  }];
+}
+
 
 //===----------------------------------------------------------------------===//
 // ExpOp

diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
new file mode 100644
index 0000000000000..ae64f4d434a09
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
@@ -0,0 +1,29 @@
+//===- Approximation.h - Math dialect -----------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MATH_TRANSFORMS_APPROXIMATION_H_
+#define MLIR_DIALECT_MATH_TRANSFORMS_APPROXIMATION_H_
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace math {
+
+/// Rewrites f32 math.erf into a piecewise rational polynomial approximation.
+struct ErfPolynomialApproximation : public OpRewritePattern<math::ErfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::ErfOp op,
+                                PatternRewriter &rewriter) const final;
+};
+
+} // namespace math
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MATH_TRANSFORMS_APPROXIMATION_H_

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index f52f1cc2bb9cf..9447bed434efc 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -116,6 +116,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
                VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
                                                   "atan2f", "atan2", benefit);
+  patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
+                                                "erf", benefit);
   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
                                                   "expm1f", "expm1", benefit);
   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",

diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 3761b48c569e9..306a1d117e1f7 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -21,9 +22,12 @@
 #include "mlir/Transforms/Bufferize.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ArrayRef.h"
 #include <climits>
+#include <cstddef>
 
 using namespace mlir;
+using namespace mlir::math;
 using namespace mlir::vector;
 
 using TypePredicate = llvm::function_ref<bool(Type)>;
@@ -183,6 +187,24 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
   return exp2ValueF32;
 }
 
+// Evaluates coeffs[0] + coeffs[1] * x + ... + coeffs[n-1] * x^(n-1) using
+// Horner's scheme with math.fma ops; an empty list yields a broadcast f32 0.
+static Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
+                                       llvm::ArrayRef<Value> coeffs, Value x) {
+  if (coeffs.empty()) {
+    auto width = vectorWidth(x.getType(), isF32);
+    assert(width.hasValue() && "expected an f32 scalar or vector operand");
+    return broadcast(builder, f32Cst(builder, 0.0f), *width);
+  }
+  if (coeffs.size() == 1)
+    return coeffs[0];
+  Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
+                                          coeffs[coeffs.size() - 2]);
+  for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i)
+    res = builder.create<math::FmaOp>(x, res, coeffs[i]);
+  return res;
+}
+
 //----------------------------------------------------------------------------//
 // TanhOp approximation.
 //----------------------------------------------------------------------------//
@@ -465,6 +487,122 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Erf approximation.
+//----------------------------------------------------------------------------//
+
+// Approximates erf(x) for f32 scalars and vectors. Using erf(x) = -erf(-x),
+// it evaluates erf(|x|) ~= a + P(|x|) / Q(|x|), where P and Q are polynomials
+// of degree 4 whose coefficients, like the offset a, are chosen per interval
+// of |x|: [0, 0.8), [0.8, 2) and [2, 3.75). For |x| >= 3.75 (including inf)
+// the magnitude is 1. The reported approximation error is ~2.5e-07.
+// Boost's minimax tool that utilizes the Remez method was used to find the
+// coefficients.
+LogicalResult
+ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
+                                            PatternRewriter &rewriter) const {
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  if (!width.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+
+  const int intervalsCount = 3;
+  const int polyDegree = 4;
+
+  Value zero = bcast(f32Cst(builder, 0));
+  Value one = bcast(f32Cst(builder, 1));
+  Value pp[intervalsCount][polyDegree + 1];
+  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00));
+  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00));
+  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01));
+  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01));
+  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02));
+  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00));
+  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00));
+  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01));
+  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01));
+  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02));
+  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03));
+  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03));
+  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03));
+  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04));
+  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05));
+
+  Value qq[intervalsCount][polyDegree + 1];
+  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00));
+  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01));
+  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01));
+  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01));
+  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02));
+  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00));
+  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01));
+  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01));
+  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02));
+  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02));
+  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00));
+  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00));
+  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00));
+  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01));
+  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02));
+  // Additive offset a and exclusive upper bound on |x| for each interval.
+  Value offsets[intervalsCount];
+  offsets[0] = bcast(f32Cst(builder, 0));
+  offsets[1] = bcast(f32Cst(builder, 0));
+  offsets[2] = bcast(f32Cst(builder, 1));
+
+  Value bounds[intervalsCount];
+  bounds[0] = bcast(f32Cst(builder, 0.8));
+  bounds[1] = bcast(f32Cst(builder, 2));
+  bounds[2] = bcast(f32Cst(builder, 3.75));
+  // Work on |x|; the sign is reapplied at the end because erf is odd.
+  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
+                                                      op.operand(), zero);
+  Value negArg = builder.create<arith::NegFOp>(op.operand());
+  Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand());
+  // Start from the coefficients and offset of the first interval.
+  Value offset = offsets[0];
+  Value p[polyDegree + 1];
+  Value q[polyDegree + 1];
+  for (int i = 0; i <= polyDegree; ++i) {
+    p[i] = pp[0][i];
+    q[i] = qq[0][i];
+  }
+  // Switch to interval j + 1 unless |x| < bounds[j]; bounds are ascending.
+  // TODO: maybe use vector stacking to reduce the number of selects.
+  Value isLessThanBound[intervalsCount];
+  for (int j = 0; j < intervalsCount - 1; ++j) {
+    isLessThanBound[j] =
+        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
+    for (int i = 0; i <= polyDegree; ++i) {
+      p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]);
+      q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]);
+    }
+    offset =
+        builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]);
+  }
+  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
+      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
+  // a + P(|x|) / Q(|x|), clamped to 1 when |x| >= 3.75; ULT lets NaN through.
+  Value pPoly = makePolynomialCalculation(builder, p, x);
+  Value qPoly = makePolynomialCalculation(builder, q, x);
+  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
+  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
+  formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1],
+                                     formula, one);
+
+  // erf is an odd function: erf(x) = -erf(-x).
+  Value negFormula = builder.create<arith::NegFOp>(formula);
+  Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula);
+
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 // Exp approximation.
 //----------------------------------------------------------------------------//
@@ -848,8 +986,8 @@ void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options) {
   patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
-               Log1pApproximation, ExpApproximation, ExpM1Approximation,
-               SinAndCosApproximation<true, math::SinOp>,
+               Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
+               ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>,
                SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
   if (options.enableAvx2)

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index be14562d3d4d0..51c5d4a235322 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
 
+// CHECK-DAG: @erf(f64) -> f64
+// CHECK-DAG: @erff(f32) -> f32
 // CHECK-DAG: @expm1(f64) -> f64
 // CHECK-DAG: @expm1f(f32) -> f32
 // CHECK-DAG: @atan2(f64, f64) -> f64
@@ -32,6 +34,18 @@ func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
   return %float_result, %double_result : f32, f64
 }
 
+// CHECK-LABEL: func @erf_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @erf_caller(%float: f32, %double: f64) -> (f32, f64)  {
+  // f32 and f64 operands lower to libm's erff and erf respectively.
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @erff(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.erf %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @erf(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.erf %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
 // CHECK-LABEL: func @expm1_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64

diff  --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 186fcff050e99..67ea78f4e797c 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -50,6 +50,18 @@ func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @erf(
+// CHECK-SAME:            %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // math.erf is accepted on scalar, vector and tensor float operands.
+  // CHECK: %{{.*}} = math.erf %[[F]] : f32
+  %0 = math.erf %f : f32
+  // CHECK: %{{.*}} = math.erf %[[V]] : vector<4xf32>
+  %1 = math.erf %v : vector<4xf32>
+  // CHECK: %{{.*}} = math.erf %[[T]] : tensor<4x4x?xf32>
+  %2 = math.erf %t : tensor<4x4x?xf32>
+  return
+}
+
 // CHECK-LABEL: func @exp(
 // CHECK-SAME:            %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
 func @exp(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 9ba7c47bc5d47..f7e7c215f7aaf 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -5,6 +5,95 @@
 // Check that all math functions lowered to approximations built from
 // standard operations (add, mul, fma, shift, etc...).
 
+// CHECK-LABEL: func @erf_scalar(
+// CHECK-SAME:    %[[val_arg0:.*]]: f32) -> f32 {
+// CHECK-DAG:     %[[val_cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:     %[[val_cst_0:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:     %[[val_cst_1:.*]] = arith.constant 1.12837911 : f32
+// CHECK-DAG:     %[[val_cst_2:.*]] = arith.constant -0.523018539 : f32
+// CHECK-DAG:     %[[val_cst_3:.*]] = arith.constant 0.209741712 : f32
+// CHECK-DAG:     %[[val_cst_4:.*]] = arith.constant 0.0258146804 : f32
+// CHECK-DAG:     %[[val_cst_5:.*]] = arith.constant 1.12750685 : f32
+// CHECK-DAG:     %[[val_cst_6:.*]] = arith.constant -0.364721417 : f32
+// CHECK-DAG:     %[[val_cst_7:.*]] = arith.constant 0.118407398 : f32
+// CHECK-DAG:     %[[val_cst_8:.*]] = arith.constant 0.0370645523 : f32
+// CHECK-DAG:     %[[val_cst_9:.*]] = arith.constant -0.00330093061 : f32
+// CHECK-DAG:     %[[val_cst_10:.*]] = arith.constant 0.00351961935 : f32
+// CHECK-DAG:     %[[val_cst_11:.*]] = arith.constant -0.00141373626 : f32
+// CHECK-DAG:     %[[val_cst_12:.*]] = arith.constant 2.53447099E-4 : f32
+// CHECK-DAG:     %[[val_cst_13:.*]] = arith.constant -1.71048032E-5 : f32
+// CHECK-DAG:     %[[val_cst_14:.*]] = arith.constant -0.463513821 : f32
+// CHECK-DAG:     %[[val_cst_15:.*]] = arith.constant 0.519230127 : f32
+// CHECK-DAG:     %[[val_cst_16:.*]] = arith.constant -0.131808966 : f32
+// CHECK-DAG:     %[[val_cst_17:.*]] = arith.constant 0.0739796459 : f32
+// CHECK-DAG:     %[[val_cst_18:.*]] = arith.constant -3.276070e-01 : f32
+// CHECK-DAG:     %[[val_cst_19:.*]] = arith.constant 0.448369086 : f32
+// CHECK-DAG:     %[[val_cst_20:.*]] = arith.constant -0.0883462652 : f32
+// CHECK-DAG:     %[[val_cst_21:.*]] = arith.constant 0.0572442785 : f32
+// CHECK-DAG:     %[[val_cst_22:.*]] = arith.constant -2.0606916 : f32
+// CHECK-DAG:     %[[val_cst_23:.*]] = arith.constant 1.62705934 : f32
+// CHECK-DAG:     %[[val_cst_24:.*]] = arith.constant -0.583389878 : f32
+// CHECK-DAG:     %[[val_cst_25:.*]] = arith.constant 0.0821908935 : f32
+// CHECK-DAG:     %[[val_cst_26:.*]] = arith.constant 8.000000e-01 : f32
+// CHECK-DAG:     %[[val_cst_27:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG:     %[[val_cst_28:.*]] = arith.constant 3.750000e+00 : f32
+// CHECK:         %[[val_0:.*]] = arith.cmpf olt, %[[val_arg0]], %[[val_cst]] : f32
+// CHECK:         %[[val_1:.*]] = arith.negf %[[val_arg0]] : f32
+// CHECK:         %[[val_2:.*]] = select %[[val_0]], %[[val_1]], %[[val_arg0]] : f32
+// CHECK:         %[[val_3:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_26]] : f32
+// CHECK:         %[[val_4:.*]] = select %[[val_3]], %[[val_cst_1]], %[[val_cst_5]] : f32
+// CHECK:         %[[val_5:.*]] = select %[[val_3]], %[[val_cst_14]], %[[val_cst_18]] : f32
+// CHECK:         %[[val_6:.*]] = select %[[val_3]], %[[val_cst_2]], %[[val_cst_6]] : f32
+// CHECK:         %[[val_7:.*]] = select %[[val_3]], %[[val_cst_15]], %[[val_cst_19]] : f32
+// CHECK:         %[[val_8:.*]] = select %[[val_3]], %[[val_cst_3]], %[[val_cst_7]] : f32
+// CHECK:         %[[val_9:.*]] = select %[[val_3]], %[[val_cst_16]], %[[val_cst_20]] : f32
+// CHECK:         %[[val_10:.*]] = select %[[val_3]], %[[val_cst_4]], %[[val_cst_8]] : f32
+// CHECK:         %[[val_11:.*]] = select %[[val_3]], %[[val_cst_17]], %[[val_cst_21]] : f32
+// CHECK:         %[[val_12:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_27]] : f32
+// CHECK:         %[[val_13:.*]] = select %[[val_12]], %[[val_cst]], %[[val_cst_9]] : f32
+// CHECK:         %[[val_14:.*]] = select %[[val_12]], %[[val_4]], %[[val_cst_10]] : f32
+// CHECK:         %[[val_15:.*]] = select %[[val_12]], %[[val_5]], %[[val_cst_22]] : f32
+// CHECK:         %[[val_16:.*]] = select %[[val_12]], %[[val_6]], %[[val_cst_11]] : f32
+// CHECK:         %[[val_17:.*]] = select %[[val_12]], %[[val_7]], %[[val_cst_23]] : f32
+// CHECK:         %[[val_18:.*]] = select %[[val_12]], %[[val_8]], %[[val_cst_12]] : f32
+// CHECK:         %[[val_19:.*]] = select %[[val_12]], %[[val_9]], %[[val_cst_24]] : f32
+// CHECK:         %[[val_20:.*]] = select %[[val_12]], %[[val_10]], %[[val_cst_13]] : f32
+// CHECK:         %[[val_21:.*]] = select %[[val_12]], %[[val_11]], %[[val_cst_25]] : f32
+// CHECK:         %[[val_22:.*]] = select %[[val_12]], %[[val_cst]], %[[val_cst_0]] : f32
+// CHECK:         %[[val_23:.*]] = arith.cmpf ult, %[[val_2]], %[[val_cst_28]] : f32
+// CHECK:         %[[val_24:.*]] = math.fma %[[val_2]], %[[val_20]], %[[val_18]] : f32
+// CHECK:         %[[val_25:.*]] = math.fma %[[val_2]], %[[val_24]], %[[val_16]] : f32
+// CHECK:         %[[val_26:.*]] = math.fma %[[val_2]], %[[val_25]], %[[val_14]] : f32
+// CHECK:         %[[val_27:.*]] = math.fma %[[val_2]], %[[val_26]], %[[val_13]] : f32
+// CHECK:         %[[val_28:.*]] = math.fma %[[val_2]], %[[val_21]], %[[val_19]] : f32
+// CHECK:         %[[val_29:.*]] = math.fma %[[val_2]], %[[val_28]], %[[val_17]] : f32
+// CHECK:         %[[val_30:.*]] = math.fma %[[val_2]], %[[val_29]], %[[val_15]] : f32
+// CHECK:         %[[val_31:.*]] = math.fma %[[val_2]], %[[val_30]], %[[val_cst_0]] : f32
+// CHECK:         %[[val_32:.*]] = arith.divf %[[val_27]], %[[val_31]] : f32
+// CHECK:         %[[val_33:.*]] = arith.addf %[[val_22]], %[[val_32]] : f32
+// CHECK:         %[[val_34:.*]] = select %[[val_23]], %[[val_33]], %[[val_cst_0]] : f32
+// CHECK:         %[[val_35:.*]] = arith.negf %[[val_34]] : f32
+// CHECK:         %[[val_36:.*]] = select %[[val_0]], %[[val_35]], %[[val_34]] : f32
+// CHECK:         return %[[val_36]] : f32
+// CHECK:       }
+func @erf_scalar(%arg0: f32) -> f32 {
+  // Expands into compares, selects, fma-based polynomials and a division.
+  %0 = math.erf %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func @erf_vector(
+// CHECK-SAME:                     %[[arg0:.*]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK:           %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
+// CHECK-NOT:       erf
+// CHECK-COUNT-20:  select
+// CHECK:           %[[res:.*]] = select
+// CHECK:           return %[[res]] : vector<8xf32>
+// CHECK:         }
+func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+  // Same expansion as the scalar case, using vector<8xf32> constants.
+  %0 = math.erf %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 // CHECK-LABEL:   func @exp_scalar(
 // CHECK-SAME:                     %[[VAL_0:.*]]: f32) -> f32 {
 // CHECK-DAG:           %[[VAL_1:.*]] = arith.constant 0.693147182 : f32

diff  --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
index c0f319eb9d45c..b3c41057fa302 100644
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -152,6 +152,78 @@ func @log1p() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Erf: interval bounds (0.8, 2, 3.75), tiny/huge values, +/-inf and NaN.
+// -------------------------------------------------------------------------- //
+func @erf() {
+  // CHECK: -0.000274406
+  %val1 = arith.constant -2.431864e-4 : f32
+  %erfVal1 = math.erf %val1 : f32
+  vector.print %erfVal1 : f32
+
+  // CHECK: 0.742095
+  %val2 = arith.constant 0.79999 : f32
+  %erfVal2 = math.erf %val2 : f32
+  vector.print %erfVal2 : f32
+
+  // CHECK: 0.742101
+  %val3 = arith.constant 0.8 : f32
+  %erfVal3 = math.erf %val3 : f32
+  vector.print %erfVal3 : f32
+
+  // CHECK: 0.995322
+  %val4 = arith.constant 1.99999 : f32
+  %erfVal4 = math.erf %val4 : f32
+  vector.print %erfVal4 : f32
+
+  // CHECK: 0.995322
+  %val5 = arith.constant 2.0 : f32
+  %erfVal5 = math.erf %val5 : f32
+  vector.print %erfVal5 : f32
+
+  // CHECK: 1
+  %val6 = arith.constant 3.74999 : f32
+  %erfVal6 = math.erf %val6 : f32
+  vector.print %erfVal6 : f32
+
+  // CHECK: 1
+  %val7 = arith.constant 3.75 : f32
+  %erfVal7 = math.erf %val7 : f32
+  vector.print %erfVal7 : f32
+
+  // CHECK: -1
+  %negativeInf = arith.constant 0xff800000 : f32
+  %erfNegativeInf = math.erf %negativeInf : f32
+  vector.print %erfNegativeInf : f32
+
+  // CHECK: -1, -1, -0.913759, -0.731446
+  %vecVals1 = arith.constant dense<[-3.4028235e+38, -4.54318, -1.2130899, -7.8234202e-01]> : vector<4xf32>
+  %erfVecVals1 = math.erf %vecVals1 : vector<4xf32>
+  vector.print %erfVecVals1 : vector<4xf32>
+
+  // CHECK: -1.3264e-38, 0, 1.3264e-38, 0.121319
+  %vecVals2 = arith.constant dense<[-1.1754944e-38, 0.0, 1.1754944e-38, 1.0793410e-01]> : vector<4xf32>
+  %erfVecVals2 = math.erf %vecVals2 : vector<4xf32>
+  vector.print %erfVecVals2 : vector<4xf32>
+
+  // CHECK: 0.919477, 0.999069, 1, 1
+  %vecVals3 = arith.constant dense<[1.23578, 2.34093, 3.82342, 3.4028235e+38]> : vector<4xf32>
+  %erfVecVals3 = math.erf %vecVals3 : vector<4xf32>
+  vector.print %erfVecVals3 : vector<4xf32>
+
+  // CHECK: 1
+  %inf = arith.constant 0x7f800000 : f32
+  %erfInf = math.erf %inf : f32
+  vector.print %erfInf : f32
+
+  // CHECK: nan
+  %nan = arith.constant 0x7fc00000 : f32
+  %erfNan = math.erf %nan : f32
+  vector.print %erfNan : f32
+
+  return
+}
+
 // -------------------------------------------------------------------------- //
 // Exp.
 // -------------------------------------------------------------------------- //
@@ -305,6 +377,7 @@ func @main() {
   call @log(): () -> ()
   call @log2(): () -> ()
   call @log1p(): () -> ()
+  call @erf(): () -> ()
   call @exp(): () -> ()
   call @expm1(): () -> ()
   call @sin(): () -> ()

diff  --git a/mlir/utils/vim/syntax/mlir.vim b/mlir/utils/vim/syntax/mlir.vim
index 6d2e953e2cab1..8caa06e1a843c 100644
--- a/mlir/utils/vim/syntax/mlir.vim
+++ b/mlir/utils/vim/syntax/mlir.vim
@@ -43,6 +43,9 @@ syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp
 syn keyword mlirOps splat store select sqrt subf subi subview tanh
 syn keyword mlirOps view
 
+" Math dialect ops.
+syn match mlirOps /\<math\.erf\>/
+
 " Affine ops.
 syn match mlirOps /\<affine\.apply\>/
 syn match mlirOps /\<affine\.dma_start\>/


        


More information about the Mlir-commits mailing list