[Mlir-commits] [mlir] [tosa] Fix crash in shape inference for `tosa.transpose` (PR #74367)
Felix Schneider
llvmlistbot at llvm.org
Mon Dec 4 12:49:53 PST 2023
https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/74367
Fixes a crash in `TransposeOp::inferReturnTypeComponents()` when the supplied permutation tensor is rank-0.
Also removes some dead code from the type inference function.
Fixes https://github.com/llvm/llvm-project/issues/74237
>From 80f487f1d9e3b8e1e8ac6cdb1a99340eb761d4cb Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Mon, 4 Dec 2023 21:37:26 +0100
Subject: [PATCH] [tosa] Fix crash in shape inference for `tosa.transpose`
Fixes a crash in `TransposeOp::inferReturnTypeComponents()` when the
supplied permutation tensor is rank-0.
Also removes some dead code from the type inference function.
Fix https://github.com/llvm/llvm-project/issues/74237
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 12 ++++--------
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 11 +++++++++++
2 files changed, 15 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f490cb1baa309..259fb6394669a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -983,6 +983,10 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
ShapeAdaptor inputShape(adaptor.getInput1().getType());
ShapeAdaptor permsShape(adaptor.getPerms().getType());
+ // We cannot infer anything from a rank-0 "permutation" tensor.
+ if (permsShape.hasRank() && permsShape.getRank() == 0)
+ return failure();
+
// If input rank and permutation length is unknown, the output rank is
// unknown.
if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -997,15 +1001,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
return failure();
}
- // Without the input dims we cannot determine the output dim sizes but we
- // can determine the output rank.
SmallVector<int64_t> outputShape;
- if (!inputShape.hasRank()) {
- outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic);
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
- return success();
- }
-
// Rank-0 means no permutations matter.
if (inputShape.getRank() == 0) {
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index f057431a841b5..c240f5334c149 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1310,3 +1310,14 @@ func.func @test_large_constant_permutation() {
return
}
+// -----
+
+// CHECK-LABEL: test_rank0_transpose_perms
+// Fail to infer the shape but not crash.
+func.func @test_rank0_transpose_perms() {
+ %14 = tensor.empty() : tensor<5x27xi64>
+ %cst = tensor.empty() : tensor<i32>
+ // CHECK: tosa.transpose
+ %72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
+ return
+}
More information about the Mlir-commits mailing list