[Mlir-commits] [mlir] [mlir][tosa] Add verifier for tosa.reverse (PR #70500)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 27 12:30:02 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Felix Schneider (ubfx)

<details>
<summary>Changes</summary>

This patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where `axis == 0 && rank == 0`.

---
Full diff: https://github.com/llvm/llvm-project/pull/70500.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+29) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c0baf478358c132..81b9e93c2095f57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1593,7 +1593,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   );
 
   let hasFolder = 1;
-
+  let hasVerifier = 1;
+  
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c9e64a67302e772..9f619a3531ab615 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1768,6 +1768,35 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+LogicalResult ReverseOp::verify() {
+  TensorType inputType = getInput().getType();
+  TensorType outputType = getOutput().getType();
+  int32_t reverseAxis = getAxis();
+
+  if (reverseAxis < 0)
+    return emitOpError("expected non-negative reverse axis");
+  if (inputType.hasRank()) {
+    int64_t inputRank = inputType.getRank();
+    // We allow for a special case where the input/output shape has rank 0 and
+    // axis is also 0.
+    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
+      return emitOpError("expect input tensor rank (")
+             << inputRank << ") to be larger than reverse axis (" << reverseAxis
+             << ")";
+  }
+  if (outputType.hasRank()) {
+    int64_t outputRank = outputType.getRank();
+    if (inputType.hasRank() && outputRank != inputType.getRank())
+      return emitOpError(
+          "expect output tensor rank to be equal to input tensor rank");
+    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
+      return emitOpError("expect output tensor rank (")
+             << outputRank << ") to be larger than reverse axis ("
+             << reverseAxis << ")";
+  }
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 46a31d6cf3e965e..102c9ed1578cde9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -600,6 +600,6 @@ func.func nested @fold_reduce_rank_zero() {
   // CHECK-NOT: tosa.reverse
   %0 = tensor.empty() : tensor<i32>
   %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8a290299b925a7c..8e23a1fde04bc82 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -200,6 +200,14 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
 
 // -----
 
+func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+  return
+}
+
+// -----
+
 func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
   // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/70500


More information about the Mlir-commits mailing list