[Mlir-commits] [mlir] 4779348 - [mlir][tosa] Fix not to crash with large permutation indexes (#69857)
llvmlistbot at llvm.org
Sun Oct 29 03:14:15 PDT 2023
Author: Kai Sasaki
Date: 2023-10-29T19:14:10+09:00
New Revision: 47793481a4dac8b23fbdface9679835ea551a86c
URL: https://github.com/llvm/llvm-project/commit/47793481a4dac8b23fbdface9679835ea551a86c
DIFF: https://github.com/llvm/llvm-project/commit/47793481a4dac8b23fbdface9679835ea551a86c.diff
LOG: [mlir][tosa] Fix not to crash with large permutation indexes (#69857)
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 9f619a3531ab615..4ec6714a7e02a8b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1016,6 +1016,18 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
attr.getType().getRank() == 1) {
ShapeAdaptor permShape = attr;
+ // Constant permutation must be the same length as the input rank.
+ if (inputShape.getRank() != permShape.getRank())
+ return emitOptionalError(location,
+ "constant permutation must be the same length"
+ " as the input rank");
+
+ // Constant permutation values must be within the input rank.
+ for (int i = 0, e = inputShape.getRank(); i < e; i++) {
+ if (inputShape.getRank() <= permShape.getDimSize(i))
+ return failure();
+ }
+
outputShape.reserve(inputShape.getRank());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ce4defcf4a6e65..7af66ae1dbc90f0 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1272,3 +1272,29 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
%1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
return %1 : tensor<?x16x16x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_rank_size_constant_permutation
+func.func @test_rank_size_constant_permutation() {
+ %c6 = arith.constant 6 : index
+ %cst_26 = arith.constant dense<[0, 2]> : tensor<2xi32>
+ %14 = tensor.empty(%c6) : tensor<?x27xi64>
+ // Fail to infer the shape but not crash.
+ // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+ %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_large_constant_permutation
+func.func @test_large_constant_permutation() {
+ %c6 = arith.constant 6 : index
+ %cst_26 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
+ %14 = tensor.empty(%c6) : tensor<?x27xi64>
+ // Fail to infer the shape but not crash.
+ // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+ %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+ return
+}
More information about the Mlir-commits mailing list