[Mlir-commits] [mlir] e86defd - [mlir][tosa] Add type checking traits to the appropriate ops
Robert Suderman
llvmlistbot at llvm.org
Thu May 25 17:03:25 PDT 2023
Author: TatWai Chong
Date: 2023-05-25T23:51:30Z
New Revision: e86defd588e79de60d19a101a43b0c8c86dff37b
URL: https://github.com/llvm/llvm-project/commit/e86defd588e79de60d19a101a43b0c8c86dff37b
DIFF: https://github.com/llvm/llvm-project/commit/e86defd588e79de60d19a101a43b0c8c86dff37b.diff
LOG: [mlir][tosa] Add type checking traits to the appropriate ops
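Add the traits `SameOperandsAndResultElementType` and
`SameOperandsElementType` so that the verifier rejects ops that
must have matching input and output element types, instead of
accepting invalid TOSA IR with mixed data types such as:
"tosa.add"(%0, %1) : (tensor<nxbf16>, tensor<nxf32>) -> tensor<nxf32>
`SameOperandsElementType` is used for the comparison ops (equal,
greater, greater_equal), which return truth values rather than
values of the input type. When operand types differ, insert a
tosa.cast beforehand.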
Change-Id: Ie866b84e371e3b571ec04f7abb090c216dd39c33
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D150472
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 41fbdcaea0956..64471076d25e9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -215,4 +215,18 @@ class Tosa_Op<string mnemonic, list<Trait> traits = []> :
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
}
+class Tosa_ElemWiseUnaryOp<string mnemonic, list<Trait> traits = []> :
+ Tosa_Op<mnemonic, !listconcat(traits, [
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["inferReturnTypeComponents"]>,
+ Pure, SameOperandsAndResultElementType])> {
+}
+
+class Tosa_ElemWiseBinaryOp<string mnemonic, list<Trait> traits = []> :
+ Tosa_Op<mnemonic, !listconcat(traits, [
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["inferReturnTypeComponents"]>,
+ ResultsBroadcastableShape, Pure, SameOperandsAndResultElementType])> {
+}
+
#endif // TOSA_OP_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 1f60353d0a3b7..613b9e325bb03 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -375,10 +375,7 @@ def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
//===----------------------------------------------------------------------===//
// Operator: clamp
//===----------------------------------------------------------------------===//
-def Tosa_ClampOp : Tosa_Op<"clamp", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_ClampOp : Tosa_ElemWiseUnaryOp<"clamp"> {
let summary = "Computes clamp(features, min, max).";
let description = [{
@@ -407,10 +404,7 @@ def Tosa_ClampOp : Tosa_Op<"clamp", [
//===----------------------------------------------------------------------===//
// Operator: sigmoid
//===----------------------------------------------------------------------===//
-def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_SigmoidOp : Tosa_ElemWiseUnaryOp<"sigmoid"> {
let summary = "Computes elementwise sigmoid of input.";
let description = [{
@@ -433,10 +427,7 @@ def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [
//===----------------------------------------------------------------------===//
// Operator: tanh
//===----------------------------------------------------------------------===//
-def Tosa_TanhOp : Tosa_Op<"tanh", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_TanhOp : Tosa_ElemWiseUnaryOp<"tanh"> {
let summary = "Computes elementwise hyperbolic tangent of input";
let description = [{
@@ -490,10 +481,7 @@ def Tosa_ErfOp : Tosa_Op<"erf", [
//===----------------------------------------------------------------------===//
// Operator: add
//===----------------------------------------------------------------------===//
-def Tosa_AddOp : Tosa_Op<"add", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> {
let summary = "Elementwise addition operator";
let description = [{
@@ -516,10 +504,7 @@ def Tosa_AddOp : Tosa_Op<"add", [
//===----------------------------------------------------------------------===//
// Operator: arithmetic_right_shift
//===----------------------------------------------------------------------===//
-def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+def Tosa_ArithmeticRightShiftOp : Tosa_ElemWiseBinaryOp<"arithmetic_right_shift"> {
let summary = "Elementwise Arithmetic Right Shift";
let description = [{
@@ -541,10 +526,7 @@ def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [
//===----------------------------------------------------------------------===//
// Operator: bitwise_and
//===----------------------------------------------------------------------===//
-def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_BitwiseAndOp : Tosa_ElemWiseBinaryOp<"bitwise_and", [Commutative]> {
let summary = "Bitwise AND operator";
let description = [{
@@ -565,10 +547,7 @@ def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [
//===----------------------------------------------------------------------===//
// Operator: bitwise_or
//===----------------------------------------------------------------------===//
-def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_BitwiseOrOp : Tosa_ElemWiseBinaryOp<"bitwise_or", [Commutative]> {
let summary = "Bitwise OR operator";
let description = [{
@@ -589,10 +568,7 @@ def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [
//===----------------------------------------------------------------------===//
// Operator: bitwise_xor
//===----------------------------------------------------------------------===//
-def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_BitwiseXorOp : Tosa_ElemWiseBinaryOp<"bitwise_xor", [Commutative]> {
let summary = "Bitwise XOR operator";
let description = [{
@@ -613,10 +589,7 @@ def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [
//===----------------------------------------------------------------------===//
// Operator: div
//===----------------------------------------------------------------------===//
-def Tosa_DivOp : Tosa_Op<"div", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+def Tosa_DivOp : Tosa_ElemWiseBinaryOp<"div"> {
let summary = "Integer divide operator";
let description = [{
@@ -639,10 +612,7 @@ def Tosa_DivOp : Tosa_Op<"div", [
//===----------------------------------------------------------------------===//
// Operator: logical_and
//===----------------------------------------------------------------------===//
-def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Commutative, Pure]> {
+def Tosa_LogicalAndOp : Tosa_ElemWiseBinaryOp<"logical_and", [Commutative]> {
let summary = "Returns the truth value of x AND y element-wise.";
let description = [{
@@ -663,10 +633,7 @@ def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [
//===----------------------------------------------------------------------===//
// Operator: logical_left_shift
//===----------------------------------------------------------------------===//
-def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+def Tosa_LogicalLeftShiftOp : Tosa_ElemWiseBinaryOp<"logical_left_shift"> {
let summary = "Elementwise Logical Left Shift";
let description = [{
@@ -687,10 +654,7 @@ def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [
//===----------------------------------------------------------------------===//
// Operator: logical_right_shift
//===----------------------------------------------------------------------===//
-def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+def Tosa_LogicalRightShiftOp : Tosa_ElemWiseBinaryOp<"logical_right_shift"> {
let summary = "Elementwise Logical Right Shift";
let description = [{
@@ -711,10 +675,7 @@ def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [
//===----------------------------------------------------------------------===//
// Operator: logical_or
//===----------------------------------------------------------------------===//
-def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Commutative, Pure]> {
+def Tosa_LogicalOrOp : Tosa_ElemWiseBinaryOp<"logical_or", [Commutative]> {
let summary = "Returns the truth value of x OR y element-wise.";
let description = [{
@@ -735,10 +696,7 @@ def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [
//===----------------------------------------------------------------------===//
// Operator: logical_xor
//===----------------------------------------------------------------------===//
-def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Commutative, Pure]> {
+def Tosa_LogicalXorOp : Tosa_ElemWiseBinaryOp<"logical_xor", [Commutative]> {
let summary = "Returns the truth value of x XOR y element-wise.";
let description = [{
@@ -759,10 +717,7 @@ def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [
//===----------------------------------------------------------------------===//
// Operator: maximum
//===----------------------------------------------------------------------===//
-def Tosa_MaximumOp : Tosa_Op<"maximum", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> {
let summary = "Elementwise Maximum";
let description = [{
@@ -783,10 +738,7 @@ def Tosa_MaximumOp : Tosa_Op<"maximum", [
//===----------------------------------------------------------------------===//
// Operator: minimum
//===----------------------------------------------------------------------===//
-def Tosa_MinimumOp : Tosa_Op<"minimum", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> {
let summary = "Elementwise Minimum";
let description = [{
@@ -807,15 +759,13 @@ def Tosa_MinimumOp : Tosa_Op<"minimum", [
//===----------------------------------------------------------------------===//
// Operator: mul
//===----------------------------------------------------------------------===//
-def Tosa_MulOp : Tosa_Op<"mul", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> {
let summary = "Multiplication operator";
let description = [{
Elementwise multiplication (Hadamard product) of input1 and input2.
Axis of size 1 will be broadcast, as necessary.
+ i8/i16 input type can be promoted to i32 result type.
}];
let arguments = (ins
@@ -834,10 +784,7 @@ def Tosa_MulOp : Tosa_Op<"mul", [
//===----------------------------------------------------------------------===//
// Operator: pow
//===----------------------------------------------------------------------===//
-def Tosa_PowOp : Tosa_Op<"pow", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> {
let summary = "Computes the power of one value to another.";
let description = [{
@@ -858,10 +805,7 @@ def Tosa_PowOp : Tosa_Op<"pow", [
//===----------------------------------------------------------------------===//
// Operator: sub
//===----------------------------------------------------------------------===//
-def Tosa_SubOp : Tosa_Op<"sub", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
let summary = "Elementwise subtraction operator";
let description = [{
@@ -927,10 +871,7 @@ def Tosa_TableOp : Tosa_Op<"table", [
//===----------------------------------------------------------------------===//
// Operator: abs
//===----------------------------------------------------------------------===//
-def Tosa_AbsOp : Tosa_Op<"abs", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_AbsOp : Tosa_ElemWiseUnaryOp<"abs"> {
let summary = "Elementwise abs op";
let description = [{
@@ -951,10 +892,7 @@ def Tosa_AbsOp : Tosa_Op<"abs", [
//===----------------------------------------------------------------------===//
// Operator: bitwise_not
//===----------------------------------------------------------------------===//
-def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+def Tosa_BitwiseNotOp : Tosa_ElemWiseUnaryOp<"bitwise_not"> {
let summary = "Bitwise NOT operator";
let description = [{
@@ -973,10 +911,7 @@ def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [
//===----------------------------------------------------------------------===//
// Operator: ceil
//===----------------------------------------------------------------------===//
-def Tosa_CeilOp : Tosa_Op<"ceil", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_CeilOp : Tosa_ElemWiseUnaryOp<"ceil"> {
let summary = "Elementwise ceil op";
let description = [{
@@ -995,10 +930,7 @@ def Tosa_CeilOp : Tosa_Op<"ceil", [
//===----------------------------------------------------------------------===//
// Operator: clz
//===----------------------------------------------------------------------===//
-def Tosa_ClzOp : Tosa_Op<"clz", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_ClzOp : Tosa_ElemWiseUnaryOp<"clz"> {
let summary = "Elementwise count leading zero op";
let description = [{
@@ -1017,10 +949,7 @@ def Tosa_ClzOp : Tosa_Op<"clz", [
//===----------------------------------------------------------------------===//
// Operator: exp
//===----------------------------------------------------------------------===//
-def Tosa_ExpOp : Tosa_Op<"exp", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_ExpOp : Tosa_ElemWiseUnaryOp<"exp"> {
let summary = "Elementwise exp op";
let description = [{
@@ -1041,10 +970,7 @@ def Tosa_ExpOp : Tosa_Op<"exp", [
//===----------------------------------------------------------------------===//
// Operator: floor
//===----------------------------------------------------------------------===//
-def Tosa_FloorOp : Tosa_Op<"floor", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_FloorOp : Tosa_ElemWiseUnaryOp<"floor"> {
let summary = "Elementwise floor op";
let description = [{
@@ -1063,10 +989,7 @@ def Tosa_FloorOp : Tosa_Op<"floor", [
//===----------------------------------------------------------------------===//
// Operator: log
//===----------------------------------------------------------------------===//
-def Tosa_LogOp : Tosa_Op<"log", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_LogOp : Tosa_ElemWiseUnaryOp<"log"> {
let summary = "Elementwise log op";
let description = [{
@@ -1087,10 +1010,7 @@ def Tosa_LogOp : Tosa_Op<"log", [
//===----------------------------------------------------------------------===//
// Operator: logical_not
//===----------------------------------------------------------------------===//
-def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure, SameOperandsAndResultType]> {
+def Tosa_LogicalNotOp : Tosa_ElemWiseUnaryOp<"logical_not"> {
let summary = "Returns the truth value of NOT x element-wise.";
let description = [{
@@ -1109,10 +1029,7 @@ def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [
//===----------------------------------------------------------------------===//
// Operator: negate
//===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_Op<"negate", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_NegateOp : Tosa_ElemWiseUnaryOp<"negate"> {
let summary = "Elementwise negate op";
let description = [{
@@ -1136,10 +1053,7 @@ def Tosa_NegateOp : Tosa_Op<"negate", [
//===----------------------------------------------------------------------===//
// Operator: reciprocal
//===----------------------------------------------------------------------===//
-def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_ReciprocalOp : Tosa_ElemWiseUnaryOp<"reciprocal"> {
let summary = "Elementwise reciprocal op";
let description = [{
@@ -1159,10 +1073,7 @@ def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [
//===----------------------------------------------------------------------===//
// Operator: rsqrt
//===----------------------------------------------------------------------===//
-def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_RsqrtOp : Tosa_ElemWiseUnaryOp<"rsqrt"> {
let summary = "Elementwise 1/sqrt op";
let description = [{
@@ -1219,7 +1130,7 @@ def Tosa_SelectOp : Tosa_Op<"select", [
// Operator: equal
//===----------------------------------------------------------------------===//
def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
- Commutative, Pure]> {
+ Commutative, Pure, SameOperandsElementType]> {
let summary = "Returns the truth value of (x == y) element-wise.";
let description = [{
@@ -1250,7 +1161,7 @@ def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
def Tosa_GreaterOp : Tosa_Op<"greater", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+ ResultsBroadcastableShape, Pure, SameOperandsElementType]> {
let summary = "Returns the truth value of (x > y) element-wise.";
let description = [{
@@ -1275,7 +1186,7 @@ def Tosa_GreaterOp : Tosa_Op<"greater", [
def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Pure]> {
+ ResultsBroadcastableShape, Pure, SameOperandsElementType]> {
let summary = "Returns the truth value of (x >= y) element-wise.";
let description = [{
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 65d56ad7ad588..9e5615e5c33f9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -282,10 +282,8 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
// CHECK-LABEL: @test_simple_i16
func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
// CHECK: linalg.generic
- // CHECK: arith.extsi
- // CHECK: arith.extsi
// CHECK: arith.muli
- %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+ %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi16>
return
}
More information about the Mlir-commits mailing list