[Mlir-commits] [mlir] a315534 - [mlir][tosa] Fix out-of-boundaries iteration for tosa-to-linalg

Rob Suderman llvmlistbot at llvm.org
Tue Jan 3 11:53:16 PST 2023


Author: a.puschin
Date: 2023-01-03T11:52:09-08:00
New Revision: a315534e52fd5c534fadc1e62101543aaf1537a2

URL: https://github.com/llvm/llvm-project/commit/a315534e52fd5c534fadc1e62101543aaf1537a2
DIFF: https://github.com/llvm/llvm-project/commit/a315534e52fd5c534fadc1e62101543aaf1537a2.diff

LOG: [mlir][tosa] Fix out-of-boundaries iteration for tosa-to-linalg

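When two shapes have different numbers of elements, a Reshape operation cannot be used to transform one into the other. This patch adds a verifier to tosa.reshape that rejects such ops when both the input and output shapes are static.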
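
A minimal standalone sketch of the element-count rule the new verifier enforces. It assumes static shapes stored in std::vector<int64_t> instead of MLIR's ShapedType. The helper names numElements and reshapeError are hypothetical, and only the message text mirrors the one emitted in tosa::ReshapeOp::verify().

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

// Hypothetical helper: product of all static dimensions, analogous to
// ShapedType::getNumElements() for a statically shaped type.
// An empty shape (a scalar) has one element.
int64_t numElements(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Hypothetical helper: returns an empty string if a reshape between the two
// static shapes is valid, otherwise a message in the same form as the one
// the verifier emits.
std::string reshapeError(const std::vector<int64_t> &input,
                         const std::vector<int64_t> &output) {
  int64_t inputElementsNum = numElements(input);
  int64_t outputElementsNum = numElements(output);
  if (inputElementsNum == outputElementsNum)
    return "";
  return "Cannot reshape " + std::to_string(inputElementsNum) +
         " elements into " + std::to_string(outputElementsNum);
}
```

For example, reshaping 2x3x4 into 6x4 is accepted (24 elements on both sides), while 2x3x4 into 5x5 is rejected as 24 versus 25 elements. Like the real verifier, this sketch says nothing about dynamic dimensions, which the verifier skips.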

The function findIntermediateShape(...) could make an out-of-bounds operator[] call when this condition occurs. The loop now checks the dimension index before indexing into lhsShape and rhsShape.
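
The loop fixed in findIntermediateShape(...) walks both shapes in lockstep and grows the running product on whichever side is currently smaller. The sketch below is a simplified, hypothetical reconstruction of that pattern. The function name groupsMatch and everything outside the inner loop are assumptions, not the upstream code. It shows where the new bounds checks matter: for 2x3x4 versus 5x5 (24 vs 25 elements), the left index runs past the last dimension before the products can ever match.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified sketch (not the upstream function): checks whether
// two static shapes can be split into consecutive dimension groups with equal
// products. Assumes all dimensions are positive.
bool groupsMatch(const std::vector<int64_t> &lhsShape,
                 const std::vector<int64_t> &rhsShape) {
  size_t currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t lhsSize = lhsShape[currLhsDim];
    int64_t rhsSize = rhsShape[currRhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        // Bounds check in the style of the patch: without it, a mismatched
        // element count makes this read lhsShape[lhsShape.size()].
        if (currLhsDim < lhsShape.size())
          lhsSize *= lhsShape[currLhsDim];
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size())
          rhsSize *= rhsShape[currRhsDim];
      }
    }
    if (lhsSize != rhsSize)
      return false;  // Ran out of dimensions on one side without a match.
    currLhsDim++;
    currRhsDim++;
  }
  // Both shapes must be fully consumed; in this simplified version, trailing
  // size-1 dimensions therefore make it return false.
  return currLhsDim == lhsShape.size() && currRhsDim == rhsShape.size();
}
```

With the two `if` guards removed, groupsMatch({2, 3, 4}, {5, 5}) would read lhsShape[3], one element past the end; with them, it simply returns false.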

The test case I used no longer triggers the error, because its root cause was a problem in the lowering of padded Conv2D operations in the Tosa dialect, which was already fixed in commit 69c984b6.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D140013

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 93178288dfc1b..8054c91ed8064 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1431,6 +1431,7 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let arguments = (ins
     Tosa_Tensor:$input1,

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7cb72f4b97612..c4d8b686726a9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -864,10 +864,14 @@ static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
            currRhsDim < rhsShape.size()) {
       if (lhsSize < rhsSize) {
         currLhsDim++;
-        lhsSize *= lhsShape[currLhsDim];
+        if (currLhsDim < lhsShape.size()) {
+          lhsSize *= lhsShape[currLhsDim];
+        }
       } else {
         currRhsDim++;
-        rhsSize *= rhsShape[currRhsDim];
+        if (currRhsDim < rhsShape.size()) {
+          rhsSize *= rhsShape[currRhsDim];
+        }
       }
     }
     if (lhsSize == rhsSize) {

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1a3ba901b4db3..35c8152551226 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -700,6 +700,21 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
   return success();
 }
 
+mlir::LogicalResult tosa::ReshapeOp::verify() {
+  ShapedType inputType = getInput1().getType().cast<ShapedType>();
+  ShapedType outputType = getType().cast<ShapedType>();
+
+  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
+    int64_t inputElementsNum = inputType.getNumElements();
+    int64_t outputElementsNum = outputType.getNumElements();
+    if (inputElementsNum != outputElementsNum) {
+      return emitOpError() << "Cannot reshape " << inputElementsNum
+                           << " elements into " << outputElementsNum;
+    }
+  }
+  return mlir::success();
+}
+
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,


        

