[Mlir-commits] [mlir] 7563eb6 - [tosa] Fix crash in shape inference for `tosa.transpose` (#74367)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 5 10:46:46 PST 2023


Author: Felix Schneider
Date: 2023-12-05T19:46:42+01:00
New Revision: 7563eb64102c3bee9b0ab581309d170891fa0565

URL: https://github.com/llvm/llvm-project/commit/7563eb64102c3bee9b0ab581309d170891fa0565
DIFF: https://github.com/llvm/llvm-project/commit/7563eb64102c3bee9b0ab581309d170891fa0565.diff

LOG: [tosa] Fix crash in shape inference for `tosa.transpose` (#74367)

Fixes a crash in `TransposeOp::inferReturnTypeComponents()` when the
supplied permutation tensor is rank-0.
Also removes an unreachable branch for unranked inputs from the type inference function; that case is already handled by an earlier check.

Fix https://github.com/llvm/llvm-project/issues/74237

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f490cb1baa309..259fb6394669a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -983,6 +983,10 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   ShapeAdaptor permsShape(adaptor.getPerms().getType());
 
+  // We cannot infer anything from a rank-0 "permutation" tensor.
+  if (permsShape.hasRank() && permsShape.getRank() == 0)
+    return failure();
+
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
   if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -997,15 +1001,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     return failure();
   }
 
-  // Without the input dims we cannot determine the output dim sizes but we
-  // can determine the output rank.
   SmallVector<int64_t> outputShape;
-  if (!inputShape.hasRank()) {
-    outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
-    return success();
-  }
-
   // Rank-0 means no permutations matter.
   if (inputShape.getRank() == 0) {
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index f057431a841b5..c240f5334c149 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1310,3 +1310,14 @@ func.func @test_large_constant_permutation() {
   return
 }
 
+// -----
+
+// CHECK-LABEL: test_rank0_transpose_perms
+// Fail to infer the shape but not crash.
+func.func @test_rank0_transpose_perms() {
+  %14 = tensor.empty() : tensor<5x27xi64>
+  %cst = tensor.empty() : tensor<i32>
+  // CHECK: tosa.transpose
+  %72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
+  return
+}


        


More information about the Mlir-commits mailing list