[Mlir-commits] [mlir] e6eb94d - [mlir][tosa] Add missing verifier check for `tosa.reshape` (#109301)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 20 02:49:21 PDT 2024
Author: Longsheng Mou
Date: 2024-09-20T10:49:17+01:00
New Revision: e6eb94d3982cc55b16f62a17177dab7da46ae0f0
URL: https://github.com/llvm/llvm-project/commit/e6eb94d3982cc55b16f62a17177dab7da46ae0f0
DIFF: https://github.com/llvm/llvm-project/commit/e6eb94d3982cc55b16f62a17177dab7da46ae0f0.diff
LOG: [mlir][tosa] Add missing verifier check for `tosa.reshape` (#109301)
This PR adds a missing verifier check for `tosa.reshape`, ensuring that
the number of elements in `new_shape` matches the number of elements in
the input tensor. Fixes #108151 and fixes #107969.
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..6dce3d03066c9a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -30,6 +30,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <numeric>
+
using namespace mlir;
using namespace mlir::tosa;
@@ -1015,12 +1017,25 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
<< newShapeDim;
}
- if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
+ if (inputType.hasStaticShape()) {
int64_t inputElementsNum = inputType.getNumElements();
- int64_t outputElementsNum = outputType.getNumElements();
- if (inputElementsNum != outputElementsNum) {
+ if (outputType.hasStaticShape()) {
+ int64_t outputElementsNum = outputType.getNumElements();
+ if (inputElementsNum != outputElementsNum) {
+ return emitOpError() << "cannot reshape " << inputElementsNum
+ << " elements into " << outputElementsNum;
+ }
+ }
+
+ int64_t newShapeElementsNum = std::accumulate(
+ getNewShape().begin(), getNewShape().end(), 1LL,
+ [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
+ bool isStaticNewShape =
+ llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+ if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
+ (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
return emitOpError() << "cannot reshape " << inputElementsNum
- << " elements into " << outputElementsNum;
+ << " elements into " << newShapeElementsNum;
}
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 311fdb1226c523..e5c5b9b3663903 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -360,6 +360,22 @@ func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
// -----
+func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
+ // expected-error@+1 {{'tosa.reshape' op cannot reshape 1 elements into 4}}
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 4>} : (tensor<1xf32>) -> tensor<?x4xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () {
+ // expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 4}}
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 4>} : (tensor<8xf32>) -> tensor<?x4xf32>
+ return
+}
+
+// -----
+
func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>
More information about the Mlir-commits
mailing list