[Mlir-commits] [mlir] 3500e11 - [mlir][tosa] Add InferTensorType interface to tosa reduce operations
Aviad Cohen
llvmlistbot at llvm.org
Tue Apr 4 21:25:17 PDT 2023
Author: Aviad Cohen
Date: 2023-04-05T07:25:10+03:00
New Revision: 3500e11065d616f4653ea8ba8c979b29c69a00d7
URL: https://github.com/llvm/llvm-project/commit/3500e11065d616f4653ea8ba8c979b29c69a00d7
DIFF: https://github.com/llvm/llvm-project/commit/3500e11065d616f4653ea8ba8c979b29c69a00d7.diff
LOG: [mlir][tosa] Add InferTensorType interface to tosa reduce operations
When this interface is used, a call to inferReturnTypeComponents()
is generated on creation and verification of the op.
Reviewed By: jpienaar, eric-k256
Differential Revision: https://reviews.llvm.org/D147407
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 043098f65a9e..287e62465251 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1262,9 +1262,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
// Operator: reduce_all
//===----------------------------------------------------------------------===//
def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce All operator";
let description = [{
@@ -1281,15 +1279,19 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_any
//===----------------------------------------------------------------------===//
def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Any operator";
let description = [{
@@ -1306,15 +1308,19 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_max
//===----------------------------------------------------------------------===//
def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Max operator";
let description = [{
@@ -1331,15 +1337,19 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_min
//===----------------------------------------------------------------------===//
def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Min operator";
let description = [{
@@ -1356,15 +1366,19 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_prod
//===----------------------------------------------------------------------===//
def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Prod operator";
let description = [{
@@ -1381,15 +1395,19 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_sum
//===----------------------------------------------------------------------===//
def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Sum operator";
let description = [{
@@ -1406,6 +1424,12 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 13a43516f8a2..b22bd6590f37 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -422,14 +422,6 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
return success();
}
-bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
- if (l.size() != r.size() || l.size() != 1)
- return false;
- if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))
- return false;
- return succeeded(verifyCompatibleShape(l[0], r[0]));
-}
-
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -913,10 +905,10 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
}
static LogicalResult ReduceInferReturnTypes(
- ShapeAdaptor operandShape, IntegerAttr axis,
+ ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
if (!operandShape.hasRank()) {
- inferredReturnShapes.push_back(ShapedTypeComponents());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
@@ -924,20 +916,32 @@ static LogicalResult ReduceInferReturnTypes(
operandShape.getDims(outputShape);
int64_t axisVal = axis.getValue().getSExtValue();
outputShape[axisVal] = 1;
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
+#define COMPATIBLE_RETURN_TYPES(OP) \
+ bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
+ if (l.size() != r.size() || l.size() != 1) \
+ return false; \
+ if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
+ return false; \
+ return succeeded(verifyCompatibleShape(l[0], r[0])); \
+ }
+
#define REDUCE_SHAPE_INFER(OP) \
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::std::optional<Location> location, \
ValueShapeRange operands, DictionaryAttr attributes, \
RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
- return ReduceInferReturnTypes(operands.getShape(0), \
+ Type inputType = \
+ operands.getType()[0].cast<TensorType>().getElementType(); \
+ return ReduceInferReturnTypes(operands.getShape(0), inputType, \
attributes.get("axis").cast<IntegerAttr>(), \
inferredReturnShapes); \
- }
+ } \
+ COMPATIBLE_RETURN_TYPES(OP)
REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
@@ -946,6 +950,8 @@ REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
+COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
+#undef COMPATIBLE_RETURN_TYPES
static LogicalResult NAryInferReturnTypes(
const ValueShapeRange &operands,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 5a120eed3a8f..c05a1c4577b7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -96,3 +96,35 @@ func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: ten
%2 = "tosa.fully_connected"(%1, %0, %arg1) : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
return %2 : tensor<273x2xf32>
}
+
+// -----
+
+func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
+ %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_max' op inferred type(s) 'tensor<2x3x4x1xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x1xi32>'}}
+ %0 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_min' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x1x4x5xi32>'}}
+ %0 = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
+ %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
+ return
+}
More information about the Mlir-commits
mailing list