[Mlir-commits] [mlir] 1fef1f9 - [MLIR][TOSA] add tosa erf operator

Eric Kunze llvmlistbot at llvm.org
Fri May 19 14:51:41 PDT 2023


Author: Manupa Karunaratne
Date: 2023-05-19T14:50:14-07:00
New Revision: 1fef1f97dbf227d14865e022ceb566803fb65a0c

URL: https://github.com/llvm/llvm-project/commit/1fef1f97dbf227d14865e022ceb566803fb65a0c
DIFF: https://github.com/llvm/llvm-project/commit/1fef1f97dbf227d14865e022ceb566803fb65a0c.diff

LOG: [MLIR][TOSA] add tosa erf operator

This commit adds the TOSA erf operator and its lowering
to the MLIR math dialect's erf op (math.erf).

Reviewed By: eric-k256, jpienaar

Differential Revision: https://reviews.llvm.org/D150357

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Tosa/ops.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b9fa2f80601e7..842ac746eb664 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -455,6 +455,32 @@ def Tosa_TanhOp : Tosa_Op<"tanh", [
   );
 }
 
+
+//===----------------------------------------------------------------------===//
+// Operator: erf
+//===----------------------------------------------------------------------===//
+def Tosa_ErfOp : Tosa_Op<"erf", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Computes the Gauss error function of input";
+
+  let description = [{
+    Gauss error function: $ erf(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} \,dt $
+    For quantized integer data types, use the TABLE operator instead, with an
+    erf_table of 513 entries, each of 16-bit/8-bit precision, covering the
+    input range -4.0 to +4.0 in steps of 1/64.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor:$input
+  );
+
+  let results = (outs
+    Tosa_Tensor:$output
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.4
 // Operator Class: Elementwise unary/binary/ternary operators.

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6aa075158e7fd..2faf7f1a625ea 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -299,6 +299,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
 
+  // tosa::ErfOp: float only (quantized integer erf uses tosa.table)
+  if (isa<tosa::ErfOp>(op) && isa<FloatType>(elementTy))
+    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
+
   // tosa::GreaterOp
   if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
@@ -2044,6 +2048,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       PointwiseConverter<tosa::ExpOp>,
       PointwiseConverter<tosa::AbsOp>,
       PointwiseConverter<tosa::TanhOp>,
+      PointwiseConverter<tosa::ErfOp>,
       PointwiseConverter<tosa::BitwiseAndOp>,
       PointwiseConverter<tosa::BitwiseOrOp>,
       PointwiseConverter<tosa::BitwiseNotOp>,

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1040d4c159659..d2c732c778f10 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1058,6 +1058,7 @@ NARY_SHAPE_INFER(tosa::RsqrtOp)
 NARY_SHAPE_INFER(tosa::SelectOp)
 NARY_SHAPE_INFER(tosa::SubOp)
 NARY_SHAPE_INFER(tosa::TanhOp)
+NARY_SHAPE_INFER(tosa::ErfOp)
 NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3e654ab9c56b0..65d56ad7ad588 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -258,6 +258,10 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: arith.divf
   %23 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
+  // CHECK: linalg.generic
+  // CHECK: math.erf
+  %24 = "tosa.erf"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
   return
 }
 

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index bf3cf3d7084bc..72f020336ff05 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -114,6 +114,13 @@ func.func @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: erf
+func.func @test_erf(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %erf = "tosa.erf"(%input) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %erf : tensor<13x21x3xf32>
+}
+
 // -----
 // CHECK-LABEL: add
 func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 56b088784cbf4..5bbb6e1fb4a6d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -65,6 +65,9 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
 
   // CHECK: "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<4xi32>
   %12 = "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %13 = "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
   return
 }
 


        


More information about the Mlir-commits mailing list