[Mlir-commits] [mlir] [Tosa] Add Tosa Sin and Cos operators (PR #82510)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 21 09:26:56 PST 2024


https://github.com/Jerry-Ge created https://github.com/llvm/llvm-project/pull/82510

- Add Tosa Sin and Cos operators to the MLIR dialect
- Define the new Tosa_FloatTensor type

From 2e1f4d899a72e51db1ee848441495b4d34551931 Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Wed, 21 Feb 2024 09:25:24 -0800
Subject: [PATCH] [Tosa] Add Tosa Sin and Cos operators

- Add Tosa Sin and Cos operators to the MLIR dialect
- Define the new Tosa_FloatTensor type

Signed-off-by: Jerry Ge <jerry.ge at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 54 ++++++++++++++++---
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  2 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |  2 +
 mlir/test/Dialect/Tosa/ops.mlir               | 14 +++++
 4 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 0ee9e713724ea2..48811ff21366a8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -989,6 +989,26 @@ def Tosa_ClzOp : Tosa_ElementwiseOp<"clz", [SameOperandsAndResultElementType]> {
   );
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: cos
+//===----------------------------------------------------------------------===//
+def Tosa_CosOp : Tosa_ElementwiseOp<"cos",
+    [SameOperandsAndResultElementType]> {
+  let summary = "Elementwise cos op";
+
+  let description = [{
+    Elementwise cosine operation for values given in radians.
+  }];
+
+  let arguments = (ins
+    Tosa_FloatTensor:$input
+  );
+
+  let results = (outs
+    Tosa_FloatTensor:$output
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: exp
 //===----------------------------------------------------------------------===//
@@ -1148,6 +1168,26 @@ def Tosa_RsqrtOp : Tosa_ElementwiseOp<"rsqrt",
   );
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: sin
+//===----------------------------------------------------------------------===//
+def Tosa_SinOp : Tosa_ElementwiseOp<"sin",
+    [SameOperandsAndResultElementType]> {
+  let summary = "Elementwise sin op";
+
+  let description = [{
+    Elementwise sine operation for values given in radians.
+  }];
+
+  let arguments = (ins
+    Tosa_FloatTensor:$input
+  );
+
+  let results = (outs
+    Tosa_FloatTensor:$output
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.6
 // Operator Class: Elementwise unary/binary/ternary operators.
@@ -1461,7 +1501,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
@@ -1611,7 +1651,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -1794,9 +1834,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 
     | Mode                     | Input   | Output  |
     |--------------------------|---------|---------|
-    | signed 8 to bool         | int8    | Boolean | 
-    | signed 16 to bool        | int16   | Boolean | 
-    | signed 32 to bool        | int32   | Boolean | 
+    | signed 8 to bool         | int8    | Boolean |
+    | signed 16 to bool        | int16   | Boolean |
+    | signed 32 to bool        | int32   | Boolean |
     | bool to 8                | Boolean | int8    |
     | bool to 16               | Boolean | int16   |
     | bool to 32               | Boolean | int32   |
@@ -1810,8 +1850,8 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
     | float to signed 16       | float   | int16   |
     | signed 8 to float        | int8    | float   |
     | signed 16 to float       | int16   | float   |
-    | float 32 to float 64     | float32 | float64 | 
-    | float 64 to float 32     | float64 | float32 | 
+    | float 32 to float 64     | float32 | float64 |
+    | float 64 to float 32     | float64 | float32 |
   }];
 
   let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index c55ddaafdda76e..5a4d6ff464f19e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -113,6 +113,8 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
 def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
 def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
 
+def Tosa_FloatTensor : TensorOf<[Tosa_Float]>;
+
 // Either ranked or unranked tensor of TOSA supported element types.
 def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
 def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 950ee597b891b5..62d07859e32f6e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1330,6 +1330,7 @@ NARY_SHAPE_INFER(tosa::CastOp)
 NARY_SHAPE_INFER(tosa::CeilOp)
 NARY_SHAPE_INFER(tosa::ClampOp)
 NARY_SHAPE_INFER(tosa::ClzOp)
+NARY_SHAPE_INFER(tosa::CosOp)
 NARY_SHAPE_INFER(tosa::DivOp)
 NARY_SHAPE_INFER(tosa::ExpOp)
 NARY_SHAPE_INFER(tosa::FloorOp)
@@ -1352,6 +1353,7 @@ NARY_SHAPE_INFER(tosa::ReciprocalOp)
 NARY_SHAPE_INFER(tosa::RescaleOp)
 NARY_SHAPE_INFER(tosa::ReverseOp)
 NARY_SHAPE_INFER(tosa::RsqrtOp)
+NARY_SHAPE_INFER(tosa::SinOp)
 NARY_SHAPE_INFER(tosa::SelectOp)
 NARY_SHAPE_INFER(tosa::SubOp)
 NARY_SHAPE_INFER(tosa::TanhOp)
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 3d68464ebf0b30..01b27072a4b646 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -375,6 +375,13 @@ func.func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
   return %0 : tensor<13x21x3xi32>
 }
 
+// -----
+// CHECK-LABEL: cos
+func.func @test_cos(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.cos %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
 // -----
 // CHECK-LABEL: exp
 func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
@@ -424,6 +431,13 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: sin
+func.func @test_sin(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.sin %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
 // -----
 // CHECK-LABEL: select
 func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {



More information about the Mlir-commits mailing list