[Mlir-commits] [mlir] [mlir][tosa] Allow unranked input/output tensors in resize ops (PR #141608)
Luke Hutton
llvmlistbot at llvm.org
Tue May 27 07:09:34 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/141608
This commit allows the input and output tensors of the resize op to be unranked, since their shapes may only be determined later, during shape inference.
From eee205f29f8b6298253242d5cfd069fa498cb3d2 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 27 May 2025 13:03:42 +0000
Subject: [PATCH] [mlir][tosa] Allow unranked input/output tensors in resize
ops
This commit allows the input and output tensors of the resize op to be
unranked, since their shapes may only be determined later, during shape
inference.
Change-Id: Ib53b6fa16e73779e3b9c40f8463cc89afc04226a
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 20 ++++++++++----------
mlir/test/Dialect/Tosa/ops.mlir | 20 ++++++++++++++++++++
2 files changed, 30 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3ee5a85a21dca..4620da57a5b27 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2496,16 +2496,6 @@ LogicalResult tosa::ResizeOp::verify() {
const RankedTensorType outputType =
llvm::dyn_cast<RankedTensorType>(output.getType());
- if (!inputType)
- return emitOpError("expect a ranked input tensor");
- if (!outputType)
- return emitOpError("expect a ranked output tensor");
-
- const int64_t oh = outputType.getDimSize(1);
- const int64_t ow = outputType.getDimSize(2);
- const int64_t ih = inputType.getDimSize(1);
- const int64_t iw = inputType.getDimSize(2);
-
SmallVector<int64_t> scaleValues;
SmallVector<int64_t> offsetValues;
SmallVector<int64_t> borderValues;
@@ -2531,6 +2521,16 @@ LogicalResult tosa::ResizeOp::verify() {
const int64_t borderY = borderValues[0];
const int64_t borderX = borderValues[1];
+ if (!inputType)
+ return success();
+ if (!outputType)
+ return success();
+
+ const int64_t oh = outputType.getDimSize(1);
+ const int64_t ow = outputType.getDimSize(2);
+ const int64_t ih = inputType.getDimSize(1);
+ const int64_t iw = inputType.getDimSize(2);
+
// Don't check with input height that could be broadcast (ih != 1)
// since Linalg, a consumer of TOSA, expects broadcasting support
// in resize to be available. Taking the cautious approach for now,
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 5ec506a45b3ad..767fa833dedd4 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -743,6 +743,26 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
return %1 : tensor<1x64x64x8xf32>
}
+// -----
+// CHECK-LABEL: resize_unranked_output
+func.func @test_resize_unranked_output(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // unranked result: ResizeOp::verify() returns success before its ranked H/W checks
+ %scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+ return %1 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: resize_unranked_input
+func.func @test_resize_unranked_input(%arg0: tensor<*xf32>) -> tensor<1x64x64x8xf32> { // unranked operand: ResizeOp::verify() returns success before its ranked H/W checks
+ %scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
+ return %1 : tensor<1x64x64x8xf32>
+}
+
// -----
// CHECK-LABEL: cast
func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
More information about the Mlir-commits
mailing list