[Mlir-commits] [mlir] c8b5d30 - [mlir][Transforms] Add missing check in tosa::transpose::verify() (#102099)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 12 00:52:42 PDT 2024


Author: DarshanRamakant
Date: 2024-08-12T15:52:39+08:00
New Revision: c8b5d30f707757a4fe4d9d0bb01f762665f6942f

URL: https://github.com/llvm/llvm-project/commit/c8b5d30f707757a4fe4d9d0bb01f762665f6942f
DIFF: https://github.com/llvm/llvm-project/commit/c8b5d30f707757a4fe4d9d0bb01f762665f6942f.diff

LOG: [mlir][Transforms] Add missing check in tosa::transpose::verify() (#102099)

tosa::TransposeOp::verify() should make sure
that the permutation indices are within the rank of
the input tensor. Otherwise a cryptic array
out-of-bounds assertion fires later. Fixes #99513.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 7849782e5442bd..99218f491ddef4 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -202,6 +202,9 @@ SmallVector<T> applyPermutation(ArrayRef<T> input,
                                 ArrayRef<int64_t> permutation) {
   assert(input.size() == permutation.size() &&
          "expected input rank to equal permutation rank");
+  assert(
+      llvm::all_of(permutation, [&](size_t s) { return s < input.size(); }) &&
+      "permutation must be within input bounds");
   auto permutationRange = llvm::map_range(
       llvm::seq<unsigned>(0, input.size()),
       [&](int64_t idx) -> T { return input[permutation[idx]]; });

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4337787e4aeadd..39ea7a5b61f5ec 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1116,6 +1116,12 @@ LogicalResult tosa::TransposeOp::verify() {
            "Unexpectedly found permutation tensor without rank");
     if (!isPermutationVector(constantPerms))
       return emitOpError() << "expected valid permutation tensor";
+
+    if (inputType.hasRank() && !llvm::all_of(constantPerms, [&](int64_t s) {
+          return s < inputType.getRank();
+        })) {
+      return emitOpError() << "permutation must be within input bounds";
+    }
   }
   return success();
 }

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cb38d4d81ca2ee..e1fcf056480083 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -413,3 +413,38 @@ func.func @test_tile_invalid_multiples() {
   %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_invalid_constant_permutation
+func.func @test_invalid_constant_permutation() {
+  // expected-error at +3 {{permutation must be within input bounds}}
+  %0 = tensor.empty() : tensor<3x4x5xi32>
+  %1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
+  %2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_rank_size_constant_permutation
+func.func @test_rank_size_constant_permutation() {
+  // expected-error at +4 {{permutation must be within input bounds}}
+  %0 = arith.constant 6 : index
+  %1 = arith.constant dense<[0, 2]> : tensor<2xi32>
+  %2 = tensor.empty(%0) : tensor<?x27xi64>
+  %3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_large_constant_permutation
+func.func @test_large_constant_permutation() {
+  // expected-error at +4 {{permutation must be within input bounds}}
+  %0 = arith.constant 6 : index
+  %1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
+  %2 = tensor.empty(%0) : tensor<?x27xi64>
+  %3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+  return
+}

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 5f8afa57bc7478..3224f88968f3d2 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1373,30 +1373,4 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
   // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
   %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
   return %1 : tensor<?x16x16x16xf32>
-}
-
-// -----
-
-// CHECK-LABEL: test_rank_size_constant_permutation
-func.func @test_rank_size_constant_permutation() {
-  %c6 = arith.constant 6 : index
-  %cst_26 = arith.constant dense<[0, 2]> : tensor<2xi32>
-  %14 = tensor.empty(%c6) : tensor<?x27xi64>
-  // Fail to infer the shape but not crash.
-  // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
-  %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_large_constant_permutation
-func.func @test_large_constant_permutation() {
-  %c6 = arith.constant 6 : index
-  %cst_26 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
-  %14 = tensor.empty(%c6) : tensor<?x27xi64>
-  // Fail to infer the shape but not crash.
-  // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
-  %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
-  return
-}
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list