[Mlir-commits] [mlir] 880b8f4 - Add f16 type support in math.erf op.
Prashant Kumar
llvmlistbot at llvm.org
Fri Oct 14 05:58:12 PDT 2022
Author: Prashant Kumar
Date: 2022-10-14T12:57:58Z
New Revision: 880b8f4e04419d5f9b506cd6e66b76180acecbb4
URL: https://github.com/llvm/llvm-project/commit/880b8f4e04419d5f9b506cd6e66b76180acecbb4
DIFF: https://github.com/llvm/llvm-project/commit/880b8f4e04419d5f9b506cd6e66b76180acecbb4.diff
LOG: Add f16 type support in math.erf op.
f16 type support was missing in the math.erf op.
Reviewed By: ezhulenev
Differential Revision: https://reviews.llvm.org/D135770
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index ec5546476a0de..abffadffb2251 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -165,6 +165,14 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Helper functions to create constants.
//----------------------------------------------------------------------------//
+// Creates an arith::ConstantOp holding `value` as a float attribute of type
+// `elementType`. Only f16 and f32 element types are accepted (asserted).
+static Value floatCst(ImplicitLocOpBuilder &builder, float value,
+                      Type elementType) {
+  // Parenthesize the type check: `&&` binds tighter than `||`, so without the
+  // grouping the message string would attach only to the second operand and
+  // compilers warn about mixing `||` and `&&`.
+  assert((elementType.isF16() || elementType.isF32()) &&
+         "x must be f16 or f32 type.");
+  return builder.create<arith::ConstantOp>(
+      builder.getFloatAttr(elementType, value));
+}
+
static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
}
@@ -270,11 +278,13 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
namespace {
Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
llvm::ArrayRef<Value> coeffs, Value x) {
- assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type");
+ Type elementType = getElementTypeOrSelf(x);
+ assert(elementType.isF32() ||
+ elementType.isF16() && "x must be f32 or f16 type");
ArrayRef<int64_t> shape = vectorShape(x);
if (coeffs.empty())
- return broadcast(builder, f32Cst(builder, 0.0f), shape);
+ return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
if (coeffs.size() == 1)
return coeffs[0];
@@ -771,10 +781,13 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
PatternRewriter &rewriter) const {
- if (!getElementTypeOrSelf(op.getOperand()).isF32())
- return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ Value operand = op.getOperand();
+ Type elementType = getElementTypeOrSelf(operand);
- ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+ if (!(elementType.isF32() || elementType.isF16()))
+ return rewriter.notifyMatchFailure(op,
+ "only f32 and f16 type is supported.");
+ ArrayRef<int64_t> shape = vectorShape(operand);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -784,57 +797,56 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
const int intervalsCount = 3;
const int polyDegree = 4;
- Value zero = bcast(f32Cst(builder, 0));
- Value one = bcast(f32Cst(builder, 1));
+ Value zero = bcast(floatCst(builder, 0, elementType));
+ Value one = bcast(floatCst(builder, 1, elementType));
Value pp[intervalsCount][polyDegree + 1];
- pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
- pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
- pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
- pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
- pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
- pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
- pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
- pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
- pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
- pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
- pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
- pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
- pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
- pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
- pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));
+ pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
+ pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
+ pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
+ pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
+ pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
+ pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
+ pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
+ pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
+ pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
+ pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
+ pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
+ pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
+ pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
+ pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
+ pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));
Value qq[intervalsCount][polyDegree + 1];
- qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
- qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
- qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
- qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
- qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
- qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
- qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
- qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
- qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
- qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
- qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
- qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
- qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
- qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
- qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));
+ qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
+ qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
+ qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
+ qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
+ qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
+ qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
+ qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
+ qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
+ qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
+ qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
+ qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
+ qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
+ qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
+ qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
+ qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));
Value offsets[intervalsCount];
- offsets[0] = bcast(f32Cst(builder, 0.0f));
- offsets[1] = bcast(f32Cst(builder, 0.0f));
- offsets[2] = bcast(f32Cst(builder, 1.0f));
+ offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
+ offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
+ offsets[2] = bcast(floatCst(builder, 1.0f, elementType));
Value bounds[intervalsCount];
- bounds[0] = bcast(f32Cst(builder, 0.8f));
- bounds[1] = bcast(f32Cst(builder, 2.0f));
- bounds[2] = bcast(f32Cst(builder, 3.75f));
-
- Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
- op.getOperand(), zero);
- Value negArg = builder.create<arith::NegFOp>(op.getOperand());
- Value x =
- builder.create<arith::SelectOp>(isNegativeArg, negArg, op.getOperand());
+ bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
+ bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
+ bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
+
+ Value isNegativeArg =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+ Value negArg = builder.create<arith::NegFOp>(operand);
+ Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
Value offset = offsets[0];
Value p[polyDegree + 1];
More information about the Mlir-commits mailing list