[Mlir-commits] [mlir] 880b8f4 - Add f16 type support in math.erf op.

Prashant Kumar llvmlistbot at llvm.org
Fri Oct 14 05:58:12 PDT 2022


Author: Prashant Kumar
Date: 2022-10-14T12:57:58Z
New Revision: 880b8f4e04419d5f9b506cd6e66b76180acecbb4

URL: https://github.com/llvm/llvm-project/commit/880b8f4e04419d5f9b506cd6e66b76180acecbb4
DIFF: https://github.com/llvm/llvm-project/commit/880b8f4e04419d5f9b506cd6e66b76180acecbb4.diff

LOG: Add f16 type support in math.erf op.

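The polynomial approximation of math.erf (and the shared polynomial helper it uses) previously accepted only f32 operands; this change extends it to f16.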

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D135770

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index ec5546476a0de..abffadffb2251 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -165,6 +165,16 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
 // Helper functions to create constants.
 //----------------------------------------------------------------------------//
 
+/// Creates an `arith.constant` holding `value` as a FloatAttr of type
+/// `elementType`; only f16 and f32 element types are accepted.
+static Value floatCst(ImplicitLocOpBuilder &builder, float value,
+                      Type elementType) {
+  assert((elementType.isF16() || elementType.isF32()) &&
+         "elementType must be f16 or f32 type.");
+  return builder.create<arith::ConstantOp>(
+      builder.getFloatAttr(elementType, value));
+}
+
 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
   return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
 }
@@ -270,11 +278,13 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
 namespace {
 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
                                 llvm::ArrayRef<Value> coeffs, Value x) {
-  assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type");
+  Type elementType = getElementTypeOrSelf(x);
+  assert((elementType.isF32() || elementType.isF16()) &&
+         "x must be f32 or f16 type");
   ArrayRef<int64_t> shape = vectorShape(x);
 
   if (coeffs.empty())
-    return broadcast(builder, f32Cst(builder, 0.0f), shape);
+    return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
 
   if (coeffs.size() == 1)
     return coeffs[0];
@@ -771,10 +781,13 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
 LogicalResult
 ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                             PatternRewriter &rewriter) const {
-  if (!getElementTypeOrSelf(op.getOperand()).isF32())
-    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+  Value operand = op.getOperand();
+  Type elementType = getElementTypeOrSelf(operand);
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  if (!elementType.isF16() && !elementType.isF32())
+    return rewriter.notifyMatchFailure(
+        op, "only f32 and f16 type is supported.");
+  ArrayRef<int64_t> shape = vectorShape(operand);
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -784,57 +797,60 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
   const int intervalsCount = 3;
   const int polyDegree = 4;
 
-  Value zero = bcast(f32Cst(builder, 0));
-  Value one = bcast(f32Cst(builder, 1));
+  // Broadcast a scalar float constant materialized in the operand's type.
+  auto bcastCst = [&](float value) -> Value {
+    return bcast(floatCst(builder, value, elementType));
+  };
+  Value zero = bcastCst(0);
+  Value one = bcastCst(1);
   Value pp[intervalsCount][polyDegree + 1];
-  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
-  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
-  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
-  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
-  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
-  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
-  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
-  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
-  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
-  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
-  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
-  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
-  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
-  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
-  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));
+  pp[0][0] = bcastCst(+0.00000000000000000e+00f);
+  pp[0][1] = bcastCst(+1.12837916222975858e+00f);
+  pp[0][2] = bcastCst(-5.23018562988006470e-01f);
+  pp[0][3] = bcastCst(+2.09741709609267072e-01f);
+  pp[0][4] = bcastCst(+2.58146801602987875e-02f);
+  pp[1][0] = bcastCst(+0.00000000000000000e+00f);
+  pp[1][1] = bcastCst(+1.12750687816789140e+00f);
+  pp[1][2] = bcastCst(-3.64721408487825775e-01f);
+  pp[1][3] = bcastCst(+1.18407396425136952e-01f);
+  pp[1][4] = bcastCst(+3.70645533056476558e-02f);
+  pp[2][0] = bcastCst(-3.30093071049483172e-03f);
+  pp[2][1] = bcastCst(+3.51961938357697011e-03f);
+  pp[2][2] = bcastCst(-1.41373622814988039e-03f);
+  pp[2][3] = bcastCst(+2.53447094961941348e-04f);
+  pp[2][4] = bcastCst(-1.71048029455037401e-05f);
 
   Value qq[intervalsCount][polyDegree + 1];
-  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
-  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
-  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
-  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
-  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
-  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
-  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
-  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
-  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
-  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
-  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
-  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
-  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
-  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
-  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));
+  qq[0][0] = bcastCst(+1.000000000000000000e+00f);
+  qq[0][1] = bcastCst(-4.635138185962547255e-01f);
+  qq[0][2] = bcastCst(+5.192301327279782447e-01f);
+  qq[0][3] = bcastCst(-1.318089722204810087e-01f);
+  qq[0][4] = bcastCst(+7.397964654672315005e-02f);
+  qq[1][0] = bcastCst(+1.00000000000000000e+00f);
+  qq[1][1] = bcastCst(-3.27607011824493086e-01f);
+  qq[1][2] = bcastCst(+4.48369090658821977e-01f);
+  qq[1][3] = bcastCst(-8.83462621207857930e-02f);
+  qq[1][4] = bcastCst(+5.72442770283176093e-02f);
+  qq[2][0] = bcastCst(+1.00000000000000000e+00f);
+  qq[2][1] = bcastCst(-2.06069165953913769e+00f);
+  qq[2][2] = bcastCst(+1.62705939945477759e+00f);
+  qq[2][3] = bcastCst(-5.83389859211130017e-01f);
+  qq[2][4] = bcastCst(+8.21908939856640930e-02f);
 
   Value offsets[intervalsCount];
-  offsets[0] = bcast(f32Cst(builder, 0.0f));
-  offsets[1] = bcast(f32Cst(builder, 0.0f));
-  offsets[2] = bcast(f32Cst(builder, 1.0f));
+  offsets[0] = bcastCst(0.0f);
+  offsets[1] = bcastCst(0.0f);
+  offsets[2] = bcastCst(1.0f);
 
   Value bounds[intervalsCount];
-  bounds[0] = bcast(f32Cst(builder, 0.8f));
-  bounds[1] = bcast(f32Cst(builder, 2.0f));
-  bounds[2] = bcast(f32Cst(builder, 3.75f));
-
-  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
-                                                      op.getOperand(), zero);
-  Value negArg = builder.create<arith::NegFOp>(op.getOperand());
-  Value x =
-      builder.create<arith::SelectOp>(isNegativeArg, negArg, op.getOperand());
+  bounds[0] = bcastCst(0.8f);
+  bounds[1] = bcastCst(2.0f);
+  bounds[2] = bcastCst(3.75f);
+
+  Value isNegativeArg =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+  Value negArg = builder.create<arith::NegFOp>(operand);
+  Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
 
   Value offset = offsets[0];
   Value p[polyDegree + 1];


        


More information about the Mlir-commits mailing list