[Mlir-commits] [mlir] 8a57bc0 - [mlir][tosa] Add verifiers to ReduceOps, fix shape inference crash (#69843)
llvmlistbot at llvm.org
Tue Oct 24 10:53:43 PDT 2023
Author: Felix Schneider
Date: 2023-10-24T19:53:38+02:00
New Revision: 8a57bc092850d6a5fa43a7f74a04352bb135f56e
URL: https://github.com/llvm/llvm-project/commit/8a57bc092850d6a5fa43a7f74a04352bb135f56e
DIFF: https://github.com/llvm/llvm-project/commit/8a57bc092850d6a5fa43a7f74a04352bb135f56e.diff
LOG: [mlir][tosa] Add verifiers to ReduceOps, fix shape inference crash (#69843)
This patch adds verifiers to `tosa.reduce_*` ops that check, among other things,
that the supplied `axis` argument is compatible with the input/output tensors'
shapes. As a special case, `axis == 0` is still allowed when the tensor rank is 0.
This patch also adds a check to `ReduceInferReturnTypes()` so that the shape
inference pass no longer crashes on an invalid `axis` argument.
Fixes https://github.com/llvm/llvm-project/issues/68187
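For example, the rank-0 special case stays valid, while an out-of-range axis is now
rejected by the verifier instead of crashing shape inference (a sketch mirroring the
added tests; %arg0 : tensor<i32> and %arg1 : tensor<2x3x4xf32> are placeholder values):
  // Still valid: rank-0 input with axis == 0.
  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
  // Now a verifier error rather than a crash:
  // 'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)
  %1 = tosa.reduce_sum %arg1 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>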
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5cc97469d14c314..901384eae50176b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1271,6 +1271,7 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1304,6 +1305,7 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1337,6 +1339,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1371,6 +1374,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1405,6 +1409,7 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1436,8 +1441,10 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
let results = (outs
Tosa_Tensor:$output
);
- let hasFolder = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
/// Method used by InferTypeOpInterface.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e03904a1611fc42..078e50e857fbb2a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,14 +1109,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- if (!operandShape.hasRank() || operandShape.getRank() == 0) {
+ int64_t axisVal = axis.getValue().getSExtValue();
+ if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
SmallVector<int64_t> outputShape;
operandShape.getDims(outputShape);
- int64_t axisVal = axis.getValue().getSExtValue();
outputShape[axisVal] = 1;
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
@@ -1155,6 +1155,63 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES
+template <typename T>
+static LogicalResult verifyReduceOp(T op) {
+ // All TOSA reduce Ops have input, output and axis.
+ TensorType inputType = op.getInput().getType();
+ TensorType outputType = op.getOutput().getType();
+ int32_t reduceAxis = op.getAxis();
+
+ if (reduceAxis < 0) {
+ op.emitOpError("reduce axis must not be negative");
+ return failure();
+ }
+ if (inputType.hasRank()) {
+ int64_t inputRank = inputType.getRank();
+ // We allow for a special case where the input/output shape has rank 0 and
+ // axis is also 0.
+ if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
+ op.emitOpError("expect input tensor rank (")
+ << inputRank << ") to be larger than reduce axis (" << reduceAxis
+ << ")";
+ return failure();
+ }
+ }
+ if (outputType.hasRank()) {
+ int64_t outputRank = outputType.getRank();
+ if (inputType.hasRank() && outputRank != inputType.getRank()) {
+ op.emitOpError(
+ "expect output tensor rank to be equal to input tensor rank");
+ return failure();
+ }
+ if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
+ op.emitOpError("expect output tensor rank (")
+ << outputRank << ") to be larger than reduce axis (" << reduceAxis
+ << ")";
+ return failure();
+ }
+ // We can only verify the reduced dimension size to be 1 if this is not the
+ // special case of output rank == 0.
+ if (outputRank != 0) {
+ auto outputShape = outputType.getShape();
+ if (!outputType.isDynamicDim(reduceAxis) &&
+ outputShape[reduceAxis] != 1) {
+ op.emitOpError("expect reduced dimension size to be 1, got ")
+ << outputShape[reduceAxis];
+ return failure();
+ }
+ }
+ }
+ return success();
+}
+
+LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
+
static LogicalResult NAryInferReturnTypes(
const ValueShapeRange &operands,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index dddf15fffbb7aec..46a31d6cf3e965e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -599,7 +599,7 @@ func.func nested @fold_reduce_rank_zero() {
// CHECK-NOT: tosa.reduce_min
// CHECK-NOT: tosa.reverse
%0 = tensor.empty() : tensor<i32>
- %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+ %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9233662e88db902..7a6b507566eb25d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -128,14 +128,69 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
- // expected-error@+2 {{failed to infer returned types}}
- // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
+ // expected-error@+1 {{'tosa.reduce_prod' op expect reduced dimension size to be 1, got 3}}
%0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
return
}
// -----
+func.func @test_reduce_all_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_all' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_any_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_any' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_any %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_max_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_max' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_min_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_min' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_prod_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_prod' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
+ %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+ return
+}
+
+// -----
+
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}