[Mlir-commits] [mlir] [mlir][tosa] Fix crash with large permutation indexes (PR #69857)

Kai Sasaki llvmlistbot at llvm.org
Mon Oct 23 23:43:56 PDT 2023


https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/69857

>From 3cca18b97b8e4ee0cfd9dff064571777aca61374 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Fri, 20 Oct 2023 15:30:16 +0900
Subject: [PATCH 1/2] [mlir][tosa] Fix crash with large permutation
 indexes in tosa.transpose

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 10 +++++++
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 26 +++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ff34183f9a030a8..537817db4abc588 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -966,6 +966,16 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
       attr.getType().getRank() == 1) {
     ShapeAdaptor permShape = attr;
+    // Constant permutation must be the same length as the input rank.
+    if (inputShape.getRank() != permShape.getRank())
+      return failure();
+
+    // Constant permutation values must be within the input rank.
+    for (int i = 0; i < inputShape.getRank(); i++) {
+      if (inputShape.getRank() <= permShape.getDimSize(i))
+        return failure();
+    }
+
     outputShape.reserve(inputShape.getRank());
     for (int i = 0, s = inputShape.getRank(); i < s; i++) {
       outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ce4defcf4a6e65..2dc46a9a3b9cb73 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1272,3 +1272,29 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
   %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
   return %1 : tensor<?x16x16x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_rank_size_constant_permutation
+func.func @test_rank_size_constant_permutation() {
+  %c6 = arith.constant 6 : index
+  %cst_26 = arith.constant dense<[0, 2]> : tensor<2xi32>
+  %14 = tensor.empty(%c6) : tensor<?x27xi64>
+  // Fail to infer the shape but not crash.
+  // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+  %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_large_constant_permutation
+func.func @test_large_constant_permutation() {
+  %c6 = arith.constant 6 : index
+  %cst_26 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
+  %14 = tensor.empty(%c6) : tensor<?x27xi64>
+  // Fail to infer the shape but not crash.
+  // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+  %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+  return
+}
\ No newline at end of file

>From a5b00fef448231dc688d7bb3aee3a225046ebaba Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Tue, 24 Oct 2023 15:43:16 +0900
Subject: [PATCH 2/2] Post-review follow-up

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 6 ++++--
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 537817db4abc588..a9197ab5f45cb0f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -968,10 +968,12 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     ShapeAdaptor permShape = attr;
     // Constant permutation must be the same length as the input rank.
     if (inputShape.getRank() != permShape.getRank())
-      return failure();
+      return emitOptionalError(location,
+                               "Constant permutation must be the same length"
+                               " as the input rank");
 
     // Constant permutation values must be within the input rank.
-    for (int i = 0; i < inputShape.getRank(); i++) {
+    for (int i = 0, e = inputShape.getRank(); i < e; i++) {
       if (inputShape.getRank() <= permShape.getDimSize(i))
         return failure();
     }
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 2dc46a9a3b9cb73..7af66ae1dbc90f0 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1297,4 +1297,4 @@ func.func @test_large_constant_permutation() {
   // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
   %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
   return
-}
\ No newline at end of file
+}



More information about the Mlir-commits mailing list