[Mlir-commits] [mlir] [mlir][tosa] Add missing verifier check for `tosa.reshape` (PR #109301)
Longsheng Mou
llvmlistbot at llvm.org
Thu Sep 19 08:45:31 PDT 2024
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/109301
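This PR adds a missing verifier check for `tosa.reshape`. When the input tensor has a static shape, the verifier now ensures that the number of elements implied by `new_shape` matches the number of elements in the input tensor; if `new_shape` contains a `-1` placeholder, the product of the remaining (positive) dimensions must not exceed the input element count. Fixes #108151 and fixes #107969.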
>From f8435ca1dc3409a84c38206c780f619d8c076b2d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Thu, 19 Sep 2024 23:37:19 +0800
Subject: [PATCH] [mlir][tosa] Add missing verifier check for `tosa.reshape`
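This PR adds a missing verifier check for `tosa.reshape`,
ensuring that, for inputs with a static shape, the number of
elements in `new_shape` matches the number of elements in the
input tensor (or, when `new_shape` contains a -1 placeholder,
that the product of the remaining dimensions does not exceed it).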
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 23 +++++++++++++++++++----
mlir/test/Dialect/Tosa/invalid.mlir | 16 ++++++++++++++++
2 files changed, 35 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..6dce3d03066c9a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -30,6 +30,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <numeric>
+
using namespace mlir;
using namespace mlir::tosa;
@@ -1015,12 +1017,25 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
<< newShapeDim;
}
- if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
+ if (inputType.hasStaticShape()) {
int64_t inputElementsNum = inputType.getNumElements();
- int64_t outputElementsNum = outputType.getNumElements();
- if (inputElementsNum != outputElementsNum) {
+ if (outputType.hasStaticShape()) {
+ int64_t outputElementsNum = outputType.getNumElements();
+ if (inputElementsNum != outputElementsNum) {
+ return emitOpError() << "cannot reshape " << inputElementsNum
+ << " elements into " << outputElementsNum;
+ }
+ }
+
+ int64_t newShapeElementsNum = std::accumulate(
+ getNewShape().begin(), getNewShape().end(), 1LL,
+ [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
+ bool isStaticNewShape =
+ llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+ if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
+ (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
return emitOpError() << "cannot reshape " << inputElementsNum
- << " elements into " << outputElementsNum;
+ << " elements into " << newShapeElementsNum;
}
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 311fdb1226c523..e5c5b9b3663903 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -360,6 +360,22 @@ func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
// -----
+func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot reshape 1 elements into 4}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 4>} : (tensor<1xf32>) -> tensor<?x4xf32>
+  return
+}
+
+// -----
+
+func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 4}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 4>} : (tensor<8xf32>) -> tensor<?x4xf32>
+  return
+}
+
+// -----
+
func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
// expected-error at +1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>
More information about the Mlir-commits
mailing list