[Mlir-commits] [mlir] 9dec80b - [MLIR][TOSA] Remove ReluN operator from TOSA dialect

Jacques Pienaar llvmlistbot at llvm.org
Fri Aug 12 16:00:20 PDT 2022


Author: Eric Kunze
Date: 2022-08-12T16:00:11-07:00
New Revision: 9dec80be729fc8a38ff62f3ed87ddbcf13d0b3e8

URL: https://github.com/llvm/llvm-project/commit/9dec80be729fc8a38ff62f3ed87ddbcf13d0b3e8
DIFF: https://github.com/llvm/llvm-project/commit/9dec80be729fc8a38ff62f3ed87ddbcf13d0b3e8.diff

LOG: [MLIR][TOSA] Remove ReluN operator from TOSA dialect

ReluN has been removed from the TOSA specification. It can be replaced
in all instances with Clamp(0, N).

Signed-off-by: Eric Kunze <eric.kunze at arm.com>

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D128683

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Tosa/ops.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index bfc152b25a9c6..9fcd955693f08 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -339,30 +339,6 @@ def Tosa_ClampOp : Tosa_Op<"clamp", [
   let hasCanonicalizer = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// Operator: reluN
-//===----------------------------------------------------------------------===//
-def Tosa_ReluNOp : Tosa_Op<"reluN", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    NoSideEffect]> {
-  let summary = "Computes rectified linear: `max(features, N)`.";
-
-  let description = [{
-     ReLU with a scalar maximum value.
-  }];
-
-  let arguments = (ins
-    Tosa_Tensor:$input,
-    I64Attr:$max_int,
-    F32Attr:$max_fp
-  );
-
-  let results = (outs
-    Tosa_Tensor:$output
-  );
-}
-
 //===----------------------------------------------------------------------===//
 // Operator: sigmoid
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 374c663511599..d1e9fe1ffc8e4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -406,27 +406,6 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
   }
 
-  // tosa::ReluNOp
-  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
-    auto zero =
-        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
-    bool losesInfo = false;
-    APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
-    max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
-                    APFloat::rmNearestTiesToEven, &losesInfo);
-    auto n = rewriter.create<arith::ConstantOp>(
-        loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
-    return clampFloatHelper(loc, args[0], zero, n, rewriter);
-  }
-
-  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
-    auto zero =
-        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
-    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
-                                                  rewriter);
-    return clampIntHelper(loc, args[0], zero, n, rewriter);
-  }
-
   // tosa::SigmoidOp
   if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
     auto one =
@@ -2235,7 +2214,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       PointwiseConverter<tosa::CeilOp>,
       PointwiseConverter<tosa::FloorOp>,
       PointwiseConverter<tosa::ClampOp>,
-      PointwiseConverter<tosa::ReluNOp>,
       PointwiseConverter<tosa::SigmoidOp>,
       IdentityNConverter<tosa::IdentityOp>,
       ReduceConverter<tosa::ReduceAllOp>,

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 90fe70d1a77dc..d944eece7ec90 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -981,7 +981,6 @@ NARY_SHAPE_INFER(tosa::MulOp)
 NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
-NARY_SHAPE_INFER(tosa::ReluNOp)
 NARY_SHAPE_INFER(tosa::RescaleOp)
 NARY_SHAPE_INFER(tosa::ReverseOp)
 NARY_SHAPE_INFER(tosa::RsqrtOp)

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 47efb8e72cb1c..c42c85e6462ab 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -218,17 +218,12 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: arith.maxf
   %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
-  // CHECK: linalg.generic
-  // CHECK: arith.minf
-  // CHECK: arith.maxf
-  %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
-
   // CHECK: linalg.generic
   // CHECK: arith.negf
   // CHECK: exp
   // CHECK: arith.addf
   // CHECK: arith.divf
-  %20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.constant 0.000000e+00
@@ -242,20 +237,20 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: arith.minf
   // CHECK: arith.maxf
   // CHECK: arith.fptosi
-  %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
+  %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.constant 0
   // CHECK: arith.cmpf
-  %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
+  %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: arith.truncf
-  %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
+  %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
 
   // CHECK: linalg.generic
   // CHECK: arith.divf
-  %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %23 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   return
 }
@@ -392,11 +387,6 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: select
   %19 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
-  // CHECK: linalg.generic
-  // CHECK: arith.cmpi
-  // CHECK: select
-  %20 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
-
   // CHECK: linalg.generic
   // CHECK: arith.constant -32768
   // CHECK: arith.constant 32767
@@ -405,27 +395,27 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: arith.cmpi slt
   // CHECK: select
   // CHECK: arith.trunci
-  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
   // CHECK: linalg.generic
   // CHECK: arith.extsi
-  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: arith.constant 0
   // CHECK: arith.cmpi
-  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: arith.sitofp
-  %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.constant 0
   // CHECK: arith.cmpi sgt
   // CHECK: arith.subi
   // CHECK: select
-  %25 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %24 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   return
 }
@@ -474,7 +464,7 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
 // CHECK-LABEL: @test_clamp_f16
 func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.+]]: f16, 
+  // CHECK: ^bb0(%[[ARG1:.+]]: f16,
   // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
   // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
   // CHECK-DAG: %[[MIN:.+]] = arith.minf %[[ARG1]], %[[C0]]

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 0514e7500798d..37b25a1519e1b 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -86,14 +86,6 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
-// -----
-// CHECK-LABEL: relu
-func.func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = "tosa.reluN"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
-}
-
-
 // -----
 // CHECK-LABEL: sigmoid
 func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 62bd0841ed454..3fd70b2510167 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -51,23 +51,20 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
   // CHECK: "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   %7 = "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
 
-  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
-  %8 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
-
   // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32>
-  %9 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor<?xf32>
+  %8 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor<?xf32>
 
   // CHECK: "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %10 = "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  %9 = "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
 
   // CHECK: "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %11 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  %10 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
 
   // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  %11 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
 
   // CHECK: "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<4xi32>
-  %13 = "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<*xi32>
+  %12 = "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<*xi32>
   return
 }
 
@@ -90,17 +87,14 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
   // CHECK: "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
   %4 = "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
 
-  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
-  %5 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
-
   // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32>
-  %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor<?xi32>
+  %5 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor<?xi32>
 
   // CHECK: "tosa.rescale"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi16>
-  %7 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32], shift = [14 : i32, 15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<4xi32>)  -> (tensor<*xi16>)
+  %6 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32], shift = [14 : i32, 15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<4xi32>)  -> (tensor<*xi16>)
 
   // CHECK: "tosa.identity"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
-  %8 = "tosa.identity"(%arg0) : (tensor<4xi32>) -> tensor<?xi32>
+  %7 = "tosa.identity"(%arg0) : (tensor<4xi32>) -> tensor<?xi32>
   return
 }
 


        


More information about the Mlir-commits mailing list