[Mlir-commits] [mlir] a879a1b - [mlir][tosa] Add tosa.reciprocal and tosa.sigmoid lowerings

Rob Suderman llvmlistbot at llvm.org
Wed Mar 31 14:21:58 PDT 2021


Author: natashaknk
Date: 2021-03-31T14:21:03-07:00
New Revision: a879a1b034943318f2a8fa52c12bd142df5ebd51

URL: https://github.com/llvm/llvm-project/commit/a879a1b034943318f2a8fa52c12bd142df5ebd51
DIFF: https://github.com/llvm/llvm-project/commit/a879a1b034943318f2a8fa52c12bd142df5ebd51.diff

LOG: [mlir][tosa] Add tosa.reciprocal and tosa.sigmoid lowerings

Lower the reciprocal and sigmoid elementwise operations to the linalg dialect.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D99676

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 4dbd8879d6a0a..dbc7654818b3c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -115,6 +115,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
   }
 
+  // tosa::ReciprocalOp
+  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
+    auto one =
+        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
+    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
+  }
+
   if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
     Value a = args[0];
     Value b = args[1];
@@ -325,6 +332,16 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                      rewriter);
   }
 
+  // tosa::SigmoidOp
+  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
+    auto one =
+        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
+    auto negate = rewriter.create<mlir::NegFOp>(loc, resultTypes, args[0]);
+    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
+    auto added = rewriter.create<mlir::AddFOp>(loc, resultTypes, exp, one);
+    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, added);
+  }
+
   // tosa::CastOp
   if (isa<tosa::CastOp>(op)) {
     Type srcTy = elementTy;
@@ -1382,11 +1399,11 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<
       PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
-      PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
-      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::RsqrtOp>,
-      PointwiseConverter<tosa::LogOp>, PointwiseConverter<tosa::ExpOp>,
-      PointwiseConverter<tosa::AbsOp>, PointwiseConverter<tosa::TanhOp>,
-      PointwiseConverter<tosa::BitwiseAndOp>,
+      PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::ReciprocalOp>,
+      PointwiseConverter<tosa::NegateOp>, PointwiseConverter<tosa::PowOp>,
+      PointwiseConverter<tosa::RsqrtOp>, PointwiseConverter<tosa::LogOp>,
+      PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
+      PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
       PointwiseConverter<tosa::BitwiseOrOp>,
       PointwiseConverter<tosa::BitwiseNotOp>,
       PointwiseConverter<tosa::BitwiseXorOp>,
@@ -1401,11 +1418,11 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
       PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
       PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
-      IdentityNConverter<tosa::IdentityOp>,
+      PointwiseConverter<tosa::SigmoidOp>, IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter, PadConverter,
-      ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
-      TransposeConverter, MatMulConverter, FullyConnectedConverter>(
-        patterns->getContext());
+      ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter,
+      PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter,
+      TileConverter, TransposeConverter, MatMulConverter,
+      FullyConnectedConverter>(patterns->getContext());
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 8dc968193829b..83ec9222cf3c0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -180,22 +180,33 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: select
   %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
+  // CHECK: linalg.generic
+  // CHECK: negf
+  // CHECK: exp
+  // CHECK: addf
+  // CHECK: divf
+  %19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
   // CHECK: linalg.generic
   // CHECK: fptosi
-  %19 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
+  %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpf
-  %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
+  %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: fptrunc
-  %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
+  %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
 
   // CHECK: linalg.generic
   // CHECK: yield
-  %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: divf
+  %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   return
 }


        


More information about the Mlir-commits mailing list