[Mlir-commits] [mlir] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps (PR #69843)
Felix Schneider
llvmlistbot at llvm.org
Mon Oct 23 14:18:06 PDT 2023
https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/69843
>From ec1495d5373f436f1f8c230699bdf315035f49e7 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 21 Oct 2023 13:09:40 +0200
Subject: [PATCH 1/4] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps
The `tosa.reduce_*` ops take an `axis` attribute that selects the
dimension along which the reduction takes place. Shape inference
crashes when the input tensor's rank is too low for the given axis to
exist (i.e. `axis >= rank`).
Fixes https://github.com/llvm/llvm-project/issues/68187
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e03904a1611fc42..0f616db31c06a5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1117,7 +1117,8 @@ static LogicalResult ReduceInferReturnTypes(
SmallVector<int64_t> outputShape;
operandShape.getDims(outputShape);
int64_t axisVal = axis.getValue().getSExtValue();
- outputShape[axisVal] = 1;
+ if (axisVal < operandShape.getRank())
+ outputShape[axisVal] = 1;
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
>From 9c430215ac1e2fbd969b0e3b0e79de3ff46ff6ad Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 21 Oct 2023 15:33:19 +0200
Subject: [PATCH 2/4] Rebase; move the axis range check into the early-return guard
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0f616db31c06a5f..5292465477b1094 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,16 +1109,15 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- if (!operandShape.hasRank() || operandShape.getRank() == 0) {
+ int64_t axisVal = axis.getValue().getSExtValue();
+ if (!operandShape.hasRank() || axisVal < 0 || operandShape.getRank() <= axisVal) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
SmallVector<int64_t> outputShape;
operandShape.getDims(outputShape);
- int64_t axisVal = axis.getValue().getSExtValue();
- if (axisVal < operandShape.getRank())
- outputShape[axisVal] = 1;
+ outputShape[axisVal] = 1;
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
>From c68f17041d9bb9837b813dba49945d43f1396ec2 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Mon, 23 Oct 2023 22:32:06 +0200
Subject: [PATCH 3/4] Add verifiers and tests for invalid reduce ops
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 9 +++-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 40 +++++++++++++++
mlir/test/Dialect/Tosa/canonicalize.mlir | 10 ----
mlir/test/Dialect/Tosa/invalid.mlir | 51 +++++++++++++++++++-
4 files changed, 97 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5cc97469d14c314..901384eae50176b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1271,6 +1271,7 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1304,6 +1305,7 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1337,6 +1339,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1371,6 +1374,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1405,6 +1409,7 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1436,8 +1441,10 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
let results = (outs
Tosa_Tensor:$output
);
- let hasFolder = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
/// Method used by InferTypeOpInterface.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5292465477b1094..39bb2f8092be4e6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1155,6 +1155,46 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES
+template <typename T> static LogicalResult verifyReduceOp(T op) {
+ // All TOSA reduce Ops have input, output and axis.
+ TensorType inputType = op.getInput().getType();
+ TensorType outputType = op.getOutput().getType();
+ int32_t reduceAxis = op.getAxis();
+
+ if (reduceAxis < 0) {
+ op.emitOpError("reduce axis must not be negative");
+ return failure();
+ }
+ if (inputType.hasRank() && reduceAxis >= inputType.getRank()) {
+ op.emitOpError("expect input tensor rank (")
+ << inputType.getRank() << ") to be larger than reduce axis ("
+ << reduceAxis << ")";
+ return failure();
+ }
+ if (outputType.hasRank()) {
+ if (reduceAxis >= outputType.getRank()) {
+ op.emitOpError("expect output tensor rank (")
+ << outputType.getRank() << ") to be larger than reduce axis ("
+ << reduceAxis << ")";
+ return failure();
+ }
+ auto outputShape = outputType.getShape();
+ if (!outputType.isDynamicDim(reduceAxis) && outputShape[reduceAxis] != 1) {
+ op.emitOpError("expect reduced dimension size to be 1, got ")
+ << outputShape[reduceAxis];
+ return failure();
+ }
+ }
+ return success();
+}
+
+LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
+
static LogicalResult NAryInferReturnTypes(
const ValueShapeRange &operands,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index dddf15fffbb7aec..1e4d661d15fdff3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -593,13 +593,3 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
}
// -----
-
-// CHECK-LABEL: @fold_reduce_rank_zero
-func.func nested @fold_reduce_rank_zero() {
- // CHECK-NOT: tosa.reduce_min
- // CHECK-NOT: tosa.reverse
- %0 = tensor.empty() : tensor<i32>
- %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
- %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
- return
-}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9233662e88db902..332ea2df4a91bb3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -128,14 +128,61 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
- // expected-error@+2 {{failed to infer returned types}}
- // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
+ // expected-error@+1 {{'tosa.reduce_prod' op expect reduced dimension size to be 1, got 3}}
%0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
return
}
// -----
+func.func @test_reduce_all_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_all' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_any_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_any' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_any %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_max_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_max' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_min_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_min' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_prod_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_prod' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+ %0 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+ return
+}
+
+// -----
+
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
>From f4d8865c0e74d5d83646494b50cc98ddc1e6a156 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Mon, 23 Oct 2023 23:16:11 +0200
Subject: [PATCH 4/4] Add special case for rank-0 reduce inputs
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 18 ++++++++++++------
mlir/test/Dialect/Tosa/canonicalize.mlir | 10 ++++++++++
2 files changed, 22 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 39bb2f8092be4e6..2a6fc2862e30696 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1155,7 +1155,8 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES
-template <typename T> static LogicalResult verifyReduceOp(T op) {
+template <typename T>
+static LogicalResult verifyReduceOp(T op) {
// All TOSA reduce Ops have input, output and axis.
TensorType inputType = op.getInput().getType();
TensorType outputType = op.getOutput().getType();
@@ -1165,11 +1166,16 @@ template <typename T> static LogicalResult verifyReduceOp(T op) {
op.emitOpError("reduce axis must not be negative");
return failure();
}
- if (inputType.hasRank() && reduceAxis >= inputType.getRank()) {
- op.emitOpError("expect input tensor rank (")
- << inputType.getRank() << ") to be larger than reduce axis ("
- << reduceAxis << ")";
- return failure();
+ if (inputType.hasRank()) {
+ int64_t inputRank = inputType.getRank();
+ // Special case: accept a rank-0 input when the axis is also 0. The output
+ // checks below still require the output rank to be larger than the axis.
+ if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
+ op.emitOpError("expect input tensor rank (")
+ << inputRank << ") to be larger than reduce axis ("
+ << reduceAxis << ")";
+ return failure();
+ }
}
if (outputType.hasRank()) {
if (reduceAxis >= outputType.getRank()) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 1e4d661d15fdff3..dddf15fffbb7aec 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -593,3 +593,13 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
}
// -----
+
+// CHECK-LABEL: @fold_reduce_rank_zero
+func.func nested @fold_reduce_rank_zero() {
+ // CHECK-NOT: tosa.reduce_min
+ // CHECK-NOT: tosa.reverse
+ %0 = tensor.empty() : tensor<i32>
+ %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+ %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+ return
+}
More information about the Mlir-commits
mailing list