[Mlir-commits] [mlir] [mlir][tosa] Add verifier for tosa.reverse (PR #70500)
Felix Schneider
llvmlistbot at llvm.org
Fri Oct 27 23:48:36 PDT 2023
https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/70500
>From 116bbc87dedf685c11d1d0b59295b75761fb0aa9 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Fri, 27 Oct 2023 20:57:18 +0200
Subject: [PATCH] [mlir][tosa] Add verifier for tosa.reverse
This patch adds a verifier to tosa.reverse which checks the axis
attribute and the input/output tensor ranks for validity.
As a special case, `axis == 0` is accepted for rank-0 tensors (`rank == 0`).
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 3 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 29 ++++++++++++++++++++
mlir/test/Dialect/Tosa/canonicalize.mlir | 2 +-
mlir/test/Dialect/Tosa/invalid.mlir | 8 ++++++
4 files changed, 40 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c0baf478358c132..81b9e93c2095f57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1593,7 +1593,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
);
let hasFolder = 1;
-
+ let hasVerifier = 1;
+
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c9e64a67302e772..9f619a3531ab615 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1768,6 +1768,35 @@ void IfOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
}
+LogicalResult ReverseOp::verify() {
+  int32_t axis = getAxis();
+  if (axis < 0)
+    return emitOpError("expected non-negative reverse axis");
+
+  // Checks that `axis` addresses a dimension of a tensor with rank `rank`.
+  // Rank-0 tensors are special-cased: axis 0 is accepted for them even though
+  // they have no dimension to index. `messagePrefix` names the tensor
+  // (input or output) in the emitted diagnostic.
+  auto verifyAxisInRange = [&](int64_t rank,
+                               const char *messagePrefix) -> LogicalResult {
+    bool rankZeroWithAxisZero = (rank == 0 && axis == 0);
+    if (axis < rank || rankZeroWithAxisZero)
+      return success();
+    return emitOpError(messagePrefix)
+           << rank << ") to be larger than reverse axis (" << axis << ")";
+  };
+
+  TensorType inputType = getInput().getType();
+  if (inputType.hasRank() &&
+      failed(verifyAxisInRange(inputType.getRank(),
+                               "expect input tensor rank (")))
+    return failure();
+
+  // Nothing more can be checked when the result rank is unknown.
+  TensorType outputType = getOutput().getType();
+  if (!outputType.hasRank())
+    return success();
+
+  int64_t outputRank = outputType.getRank();
+  if (inputType.hasRank() && inputType.getRank() != outputRank)
+    return emitOpError(
+        "expect output tensor rank to be equal to input tensor rank");
+  return verifyAxisInRange(outputRank, "expect output tensor rank (");
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 46a31d6cf3e965e..102c9ed1578cde9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -600,6 +600,6 @@ func.func nested @fold_reduce_rank_zero() {
// CHECK-NOT: tosa.reverse
%0 = tensor.empty() : tensor<i32>
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
- %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+ %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
return
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8a290299b925a7c..8e23a1fde04bc82 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -200,6 +200,14 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// -----
+func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
+  // The result keeps the f32 element type of the input so that only the axis
+  // range check can fail here.
+  // expected-error at +1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
+  return
+}
+
+// -----
+
func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
// expected-error at +1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
More information about the Mlir-commits
mailing list