[Mlir-commits] [mlir] [mlir][tosa] Add verifier for tosa.reverse (PR #70500)

Felix Schneider llvmlistbot at llvm.org
Fri Oct 27 23:48:36 PDT 2023


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/70500

>From 116bbc87dedf685c11d1d0b59295b75761fb0aa9 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Fri, 27 Oct 2023 20:57:18 +0200
Subject: [PATCH] [mlir][tosa] Add verifier for tosa.reverse

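This patch adds a verifier to tosa.reverse which checks that the axis
attribute is non-negative and, for ranked tensors, smaller than the
rank of the input and output tensors, and that the input and output
ranks are equal.
As a special case, `axis == 0` is accepted when the rank is 0.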
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  3 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 29 ++++++++++++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir     |  2 +-
 mlir/test/Dialect/Tosa/invalid.mlir          |  8 ++++++
 4 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c0baf478358c132..81b9e93c2095f57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1593,7 +1593,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   );
 
   let hasFolder = 1;
-
+  let hasVerifier = 1;
+  
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c9e64a67302e772..9f619a3531ab615 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1768,6 +1768,35 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+LogicalResult ReverseOp::verify() {
+  TensorType inputType = getInput().getType();
+  TensorType outputType = getOutput().getType();
+  int32_t reverseAxis = getAxis();
+
+  if (reverseAxis < 0)
+    return emitOpError("expected non-negative reverse axis");
+  if (inputType.hasRank()) {
+    int64_t inputRank = inputType.getRank();
+    // We allow for a special case where the input/output shape has rank 0 and
+    // axis is also 0.
+    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
+      return emitOpError("expect input tensor rank (")
+             << inputRank << ") to be larger than reverse axis (" << reverseAxis
+             << ")";
+  }
+  if (outputType.hasRank()) {
+    int64_t outputRank = outputType.getRank();
+    if (inputType.hasRank() && outputRank != inputType.getRank())
+      return emitOpError(
+          "expect output tensor rank to be equal to input tensor rank");
+    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
+      return emitOpError("expect output tensor rank (")
+             << outputRank << ") to be larger than reverse axis ("
+             << reverseAxis << ")";
+  }
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 46a31d6cf3e965e..102c9ed1578cde9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -600,6 +600,6 @@ func.func nested @fold_reduce_rank_zero() {
   // CHECK-NOT: tosa.reverse
   %0 = tensor.empty() : tensor<i32>
   %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8a290299b925a7c..8e23a1fde04bc82 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -200,6 +200,14 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
 
 // -----
 
+func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+  return
+}
+
+// -----
+
 func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
   // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>



More information about the Mlir-commits mailing list