[Mlir-commits] [mlir] 1fef1f9 - [MLIR][TOSA] add tosa erf operator
Eric Kunze
llvmlistbot at llvm.org
Fri May 19 14:51:41 PDT 2023
Author: Manupa Karunaratne
Date: 2023-05-19T14:50:14-07:00
New Revision: 1fef1f97dbf227d14865e022ceb566803fb65a0c
URL: https://github.com/llvm/llvm-project/commit/1fef1f97dbf227d14865e022ceb566803fb65a0c
DIFF: https://github.com/llvm/llvm-project/commit/1fef1f97dbf227d14865e022ceb566803fb65a0c.diff
LOG: [MLIR][TOSA] add tosa erf operator
This commit adds the TOSA erf operator and its lowering
to the math dialect (math.erf).
Reviewed By: eric-k256, jpienaar
Differential Revision: https://reviews.llvm.org/D150357
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b9fa2f80601e7..842ac746eb664 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -455,6 +455,32 @@ def Tosa_TanhOp : Tosa_Op<"tanh", [
);
}
+
+//===----------------------------------------------------------------------===//
+// Operator: erf
+//===----------------------------------------------------------------------===//
+def Tosa_ErfOp : Tosa_Op<"erf", [
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["inferReturnTypeComponents"]>,
+ Pure]> {
+ let summary = "Computes the Gauss error function of the input";
+
+ let description = [{
+ Gauss error function: $ \operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} \,dt $
+ For quantized integer data types, use the TABLE operator instead, with an
+ erf_table of 513 entries, each of 16-bit/8-bit precision, covering the
+ input range -4.0 to +4.0 in steps of 1/64.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor:$input
+ );
+
+ let results = (outs
+ Tosa_Tensor:$output
+ );
+}
+
//===----------------------------------------------------------------------===//
// TOSA Spec Section 2.4
// Operator Class: Elementwise unary/binary/ternary operators.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6aa075158e7fd..2faf7f1a625ea 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -299,6 +299,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
+ // tosa::ErfOp (float only; integer/quantized types are expected to use tosa.table)
+ if (isa<tosa::ErfOp>(op) && isa<FloatType>(elementTy))
+ return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
+
// tosa::GreaterOp
if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
@@ -2044,6 +2048,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
PointwiseConverter<tosa::ExpOp>,
PointwiseConverter<tosa::AbsOp>,
PointwiseConverter<tosa::TanhOp>,
+ PointwiseConverter<tosa::ErfOp>, // NOTE(review): presumably uses the float-only math.erf body built in createLinalgBodyCalculationForElementwiseOp — confirm
PointwiseConverter<tosa::BitwiseAndOp>,
PointwiseConverter<tosa::BitwiseOrOp>,
PointwiseConverter<tosa::BitwiseNotOp>,
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1040d4c159659..d2c732c778f10 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1058,6 +1058,7 @@ NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
+NARY_SHAPE_INFER(tosa::ErfOp) // erf is elementwise; the inferred result shape follows the operand
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef PRED_SHAPE_INFER
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3e654ab9c56b0..65d56ad7ad588 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -258,6 +258,10 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: arith.divf
%23 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ // CHECK: linalg.generic
+ // CHECK: math.erf
+ %24 = "tosa.erf"(%0) : (tensor<1xf32>) -> tensor<1xf32> // expected to lower to math.erf inside a linalg.generic
+
return
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index bf3cf3d7084bc..72f020336ff05 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -114,6 +114,13 @@ func.func @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: erf
+func.func @test_erf(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.erf"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // static-shaped f32 operand; result type matches the operand
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: add
func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 56b088784cbf4..5bbb6e1fb4a6d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -65,6 +65,9 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
// CHECK: "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<4xi32>
%12 = "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<*xi32>
+
+ // CHECK: "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+ %13 = "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> // shape inference should refine tensor<*xf32> to tensor<4xf32>
return
}
More information about the Mlir-commits
mailing list