[Mlir-commits] [mlir] 8a57bc0 - [mlir][tosa] Add verifiers to ReduceOps, fix shape inference crash (#69843)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 24 10:53:43 PDT 2023


Author: Felix Schneider
Date: 2023-10-24T19:53:38+02:00
New Revision: 8a57bc092850d6a5fa43a7f74a04352bb135f56e

URL: https://github.com/llvm/llvm-project/commit/8a57bc092850d6a5fa43a7f74a04352bb135f56e
DIFF: https://github.com/llvm/llvm-project/commit/8a57bc092850d6a5fa43a7f74a04352bb135f56e.diff

LOG: [mlir][tosa] Add verifiers to ReduceOps, fix shape inference crash (#69843)

This patch adds verifiers to the `tosa.reduce_*` ops that check, among other
things, that the supplied `axis` argument is compatible with the shapes of the
input and output tensors. As a special case, `axis == 0` is accepted for rank-0
tensors. This patch also adds a check to `ReduceInferReturnTypes()` so that the
shape inference pass no longer crashes when `axis` is greater than or equal to
the input rank.

Fix https://github.com/llvm/llvm-project/issues/68187

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5cc97469d14c314..901384eae50176b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1271,6 +1271,7 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1304,6 +1305,7 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1337,6 +1339,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1371,6 +1374,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1405,6 +1409,7 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1436,8 +1441,10 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
   let results = (outs
     Tosa_Tensor:$output
   );
-  let hasFolder = 1;
 
+  let hasFolder = 1;
+  let hasVerifier = 1;
+  
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e03904a1611fc42..078e50e857fbb2a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,14 +1109,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  if (!operandShape.hasRank() || operandShape.getRank() == 0) {
+  int64_t axisVal = axis.getValue().getSExtValue();
+  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
   SmallVector<int64_t> outputShape;
   operandShape.getDims(outputShape);
-  int64_t axisVal = axis.getValue().getSExtValue();
   outputShape[axisVal] = 1;
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
@@ -1155,6 +1155,63 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
 #undef COMPATIBLE_RETURN_TYPES
 
+template <typename T>
+static LogicalResult verifyReduceOp(T op) {
+  // All TOSA reduce Ops have input, output and axis.
+  TensorType inputType = op.getInput().getType();
+  TensorType outputType = op.getOutput().getType();
+  int32_t reduceAxis = op.getAxis();
+
+  if (reduceAxis < 0) {
+    op.emitOpError("reduce axis must not be negative");
+    return failure();
+  }
+  if (inputType.hasRank()) {
+    int64_t inputRank = inputType.getRank();
+    // We allow for a special case where the input/output shape has rank 0 and
+    // axis is also 0.
+    if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
+      op.emitOpError("expect input tensor rank (")
+          << inputRank << ") to be larger than reduce axis (" << reduceAxis
+          << ")";
+      return failure();
+    }
+  }
+  if (outputType.hasRank()) {
+    int64_t outputRank = outputType.getRank();
+    if (inputType.hasRank() && outputRank != inputType.getRank()) {
+      op.emitOpError(
+          "expect output tensor rank to be equal to input tensor rank");
+      return failure();
+    }
+    if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
+      op.emitOpError("expect output tensor rank (")
+          << outputRank << ") to be larger than reduce axis (" << reduceAxis
+          << ")";
+      return failure();
+    }
+    // We can only verify the reduced dimension size to be 1 if this is not the
+    // special case of output rank == 0.
+    if (outputRank != 0) {
+      auto outputShape = outputType.getShape();
+      if (!outputType.isDynamicDim(reduceAxis) &&
+          outputShape[reduceAxis] != 1) {
+        op.emitOpError("expect reduced dimension size to be 1, got ")
+            << outputShape[reduceAxis];
+        return failure();
+      }
+    }
+  }
+  return success();
+}
+
+LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
+LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
+
 static LogicalResult NAryInferReturnTypes(
     const ValueShapeRange &operands,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index dddf15fffbb7aec..46a31d6cf3e965e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -599,7 +599,7 @@ func.func nested @fold_reduce_rank_zero() {
   // CHECK-NOT: tosa.reduce_min
   // CHECK-NOT: tosa.reverse
   %0 = tensor.empty() : tensor<i32>
-  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
   return
 }

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9233662e88db902..7a6b507566eb25d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -128,14 +128,69 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
 // -----
 
 func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
+  // expected-error@+1 {{'tosa.reduce_prod' op expect reduced dimension size to be 1, got 3}}
   %0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
   return
 }
 
 // -----
 
+func.func @test_reduce_all_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_all' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_any_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_any' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_any %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_max_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_max' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_min_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_min' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_prod_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_prod' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)}}
+  %0 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
+  %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  return
+}
+
+// -----
+
 func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}


        


More information about the Mlir-commits mailing list