[Mlir-commits] [mlir] [mlir][tosa] Add missing verifier check for `tosa.reshape` (PR #109301)

Longsheng Mou llvmlistbot at llvm.org
Thu Sep 19 08:45:31 PDT 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/109301

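This PR adds a missing verifier check for `tosa.reshape`: when the input tensor has a static shape, the number of elements specified by `new_shape` must match the number of input elements (if `new_shape` contains a `-1` placeholder, the product of its remaining dimensions must not exceed the number of input elements). Fixes #108151 and fixes #107969.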

From f8435ca1dc3409a84c38206c780f619d8c076b2d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Thu, 19 Sep 2024 23:37:19 +0800
Subject: [PATCH] [mlir][tosa] Add missing verifier check for `tosa.reshape`

This PR adds a missing verifier check for `tosa.reshape`,
ensuring that the number of elements in `new_shape` matches
the number of elements in the input tensor.
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 23 +++++++++++++++++++----
 mlir/test/Dialect/Tosa/invalid.mlir  | 16 ++++++++++++++++
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..6dce3d03066c9a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -30,6 +30,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace mlir::tosa;
 
@@ -1015,12 +1017,25 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
                            << newShapeDim;
   }
 
-  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
+  if (inputType.hasStaticShape()) {
     int64_t inputElementsNum = inputType.getNumElements();
-    int64_t outputElementsNum = outputType.getNumElements();
-    if (inputElementsNum != outputElementsNum) {
+    if (outputType.hasStaticShape()) {
+      int64_t outputElementsNum = outputType.getNumElements();
+      if (inputElementsNum != outputElementsNum) {
+        return emitOpError() << "cannot reshape " << inputElementsNum
+                             << " elements into " << outputElementsNum;
+      }
+    }
+
+    int64_t newShapeElementsNum = std::accumulate(
+        getNewShape().begin(), getNewShape().end(), 1LL,
+        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
+    bool isStaticNewShape =
+        llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
+        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
       return emitOpError() << "cannot reshape " << inputElementsNum
-                           << " elements into " << outputElementsNum;
+                           << " elements into " << newShapeElementsNum;
     }
   }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 311fdb1226c523..e5c5b9b3663903 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -360,6 +360,22 @@ func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
 
 // -----
 
+func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot reshape 1 elements into 4}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 4>} : (tensor<1xf32>) -> tensor<?x4xf32>
+  return
+}
+
+// -----
+
+func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 4}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 4>} : (tensor<8xf32>) -> tensor<?x4xf32>
+  return
+}
+
+// -----
+
 func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
   // expected-error at +1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>



More information about the Mlir-commits mailing list