[Mlir-commits] [mlir] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps (PR #69843)

Felix Schneider llvmlistbot at llvm.org
Mon Oct 23 14:18:06 PDT 2023


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/69843

>From ec1495d5373f436f1f8c230699bdf315035f49e7 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 21 Oct 2023 13:09:40 +0200
Subject: [PATCH 1/4] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps

The `tosa.reduce_*` ops take an `axis` attribute that determines along
which dimension the reduction takes place. A crash can occur during
shape inference when the input tensor's rank is too low for the given
axis to exist.

Fix https://github.com/llvm/llvm-project/issues/68187
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e03904a1611fc42..0f616db31c06a5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1117,7 +1117,8 @@ static LogicalResult ReduceInferReturnTypes(
   SmallVector<int64_t> outputShape;
   operandShape.getDims(outputShape);
   int64_t axisVal = axis.getValue().getSExtValue();
-  outputShape[axisVal] = 1;
+  if (axisVal < operandShape.getRank())
+    outputShape[axisVal] = 1;
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }

>From 9c430215ac1e2fbd969b0e3b0e79de3ff46ff6ad Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 21 Oct 2023 15:33:19 +0200
Subject: [PATCH 2/4] Hoist the reduce axis bound check into the early return

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0f616db31c06a5f..5292465477b1094 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,16 +1109,17 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  if (!operandShape.hasRank() || operandShape.getRank() == 0) {
+  int64_t axisVal = axis.getValue().getSExtValue();
+  // An axis outside [0, rank) names no dimension; keep the input type.
+  if (!operandShape.hasRank() || axisVal < 0 ||
+      axisVal >= operandShape.getRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
   SmallVector<int64_t> outputShape;
   operandShape.getDims(outputShape);
-  int64_t axisVal = axis.getValue().getSExtValue();
-  if (axisVal < operandShape.getRank())
-    outputShape[axisVal] = 1;
+  outputShape[axisVal] = 1;
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }

>From c68f17041d9bb9837b813dba49945d43f1396ec2 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Mon, 23 Oct 2023 22:32:06 +0200
Subject: [PATCH 3/4] Add verifiers, tests for invalid reduce ops

---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  9 +++-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 40 +++++++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir     | 10 ----
 mlir/test/Dialect/Tosa/invalid.mlir          | 51 +++++++++++++++++++-
 4 files changed, 97 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5cc97469d14c314..901384eae50176b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1271,6 +1271,7 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1304,6 +1305,7 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1337,6 +1339,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1371,6 +1374,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1405,6 +1409,7 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1436,8 +1441,10 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
   let results = (outs
     Tosa_Tensor:$output
   );
-  let hasFolder = 1;
 
+  let hasFolder = 1;
+  let hasVerifier = 1;
+  
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5292465477b1094..39bb2f8092be4e6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1155,6 +1155,46 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
 #undef COMPATIBLE_RETURN_TYPES
 
+template <typename T> static LogicalResult verifyReduceOp(T op) {
+  // All TOSA reduce Ops have input, output and axis.
+  TensorType inputType = op.getInput().getType();
+  TensorType outputType = op.getOutput().getType();
+  int32_t reduceAxis = op.getAxis();
+
+  if (reduceAxis < 0) {
+    op.emitOpError("reduce axis must not be negative");
+    return failure();
+  }
+  if (inputType.hasRank() && reduceAxis >= inputType.getRank()) {
+    op.emitOpError("expect input tensor rank (")
+        << inputType.getRank() << ") to be larger than reduce axis ("
+        << reduceAxis << ")";
+    return failure();
+  }
+  if (outputType.hasRank()) {
+    if (reduceAxis >= outputType.getRank()) {
+      op.emitOpError("expect output tensor rank (")
+          << outputType.getRank() << ") to be larger than reduce axis ("
+          << reduceAxis << ")";
+      return failure();
+    }
+    auto outputShape = outputType.getShape();
+    if (!outputType.isDynamicDim(reduceAxis) && outputShape[reduceAxis] != 1) {
+      op.emitOpError("expect reduced dimension size to be 1, got ")
+          << outputShape[reduceAxis];
+      return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
+
 static LogicalResult NAryInferReturnTypes(
     const ValueShapeRange &operands,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index dddf15fffbb7aec..1e4d661d15fdff3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -593,13 +593,3 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 }
 
 // -----
-
-// CHECK-LABEL: @fold_reduce_rank_zero
-func.func nested @fold_reduce_rank_zero() {
-  // CHECK-NOT: tosa.reduce_min
-  // CHECK-NOT: tosa.reverse
-  %0 = tensor.empty() : tensor<i32>
-  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
-  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
-  return
-}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9233662e88db902..332ea2df4a91bb3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -128,14 +128,61 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
 // -----
 
 func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
-  // expected-error at +2 {{failed to infer returned types}}
-  // expected-error at +1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
+  // expected-error at +1 {{'tosa.reduce_prod' op expect reduced dimension size to be 1, got 3}}
   %0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
   return
 }
 
 // -----
 
+func.func @test_reduce_all_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error at +1 {{'tosa.reduce_all' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_any_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error at +1 {{'tosa.reduce_any' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_any %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_max_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error at +1 {{'tosa.reduce_max' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_min_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error at +1 {{'tosa.reduce_min' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_prod_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error at +1 {{'tosa.reduce_prod' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error at +1 {{'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
 func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
   // expected-error at +2 {{failed to infer returned types}}
   // expected-error at +1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}

>From f4d8865c0e74d5d83646494b50cc98ddc1e6a156 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Mon, 23 Oct 2023 23:16:11 +0200
Subject: [PATCH 4/4] add special case for rank 0 reduce

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp     | 36 +++++++++++++-------
 mlir/test/Dialect/Tosa/canonicalize.mlir | 10 ++++++
 2 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 39bb2f8092be4e6..2a6fc2862e30696 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1155,7 +1155,8 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
 #undef COMPATIBLE_RETURN_TYPES
 
-template <typename T> static LogicalResult verifyReduceOp(T op) {
+template <typename T>
+static LogicalResult verifyReduceOp(T op) {
   // All TOSA reduce Ops have input, output and axis.
   TensorType inputType = op.getInput().getType();
   TensorType outputType = op.getOutput().getType();
@@ -1165,24 +1166,35 @@ template <typename T> static LogicalResult verifyReduceOp(T op) {
     op.emitOpError("reduce axis must not be negative");
     return failure();
   }
-  if (inputType.hasRank() && reduceAxis >= inputType.getRank()) {
-    op.emitOpError("expect input tensor rank (")
-        << inputType.getRank() << ") to be larger than reduce axis ("
-        << reduceAxis << ")";
-    return failure();
+  if (inputType.hasRank()) {
+    int64_t inputRank = inputType.getRank();
+    // We allow for a special case where the input shape has rank 0 and axis is
+    // also 0.
+    if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
+      op.emitOpError("expect input tensor rank (")
+          << inputType.getRank() << ") to be larger than reduce axis ("
+          << reduceAxis << ")";
+      return failure();
+    }
   }
   if (outputType.hasRank()) {
-    if (reduceAxis >= outputType.getRank()) {
+    int64_t outputRank = outputType.getRank();
+    // The same special case applies to the output: rank 0 with axis 0.
+    if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
       op.emitOpError("expect output tensor rank (")
           << outputType.getRank() << ") to be larger than reduce axis ("
           << reduceAxis << ")";
       return failure();
     }
-    auto outputShape = outputType.getShape();
-    if (!outputType.isDynamicDim(reduceAxis) && outputShape[reduceAxis] != 1) {
-      op.emitOpError("expect reduced dimension size to be 1, got ")
-          << outputShape[reduceAxis];
-      return failure();
+    // A rank-0 output has no reduced dimension whose size could be checked.
+    if (outputRank != 0) {
+      auto outputShape = outputType.getShape();
+      if (!outputType.isDynamicDim(reduceAxis) &&
+          outputShape[reduceAxis] != 1) {
+        op.emitOpError("expect reduced dimension size to be 1, got ")
+            << outputShape[reduceAxis];
+        return failure();
+      }
     }
   }
   return success();
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 1e4d661d15fdff3..dddf15fffbb7aec 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -593,3 +593,13 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 }
 
 // -----
+
+// CHECK-LABEL: @fold_reduce_rank_zero
+func.func nested @fold_reduce_rank_zero() {
+  // CHECK-NOT: tosa.reduce_min
+  // CHECK-NOT: tosa.reverse
+  %0 = tensor.empty() : tensor<i32>
+  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  return
+}



More information about the Mlir-commits mailing list