[Mlir-commits] [mlir] [mlir][tosa] Allow creation of reshape with unranked output (PR #140617)
Luke Hutton
llvmlistbot at llvm.org
Mon May 19 13:55:05 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/140617
This commit allows tosa.reshape to be created with an unranked output type, so that the output shape can be inferred later by the shape inference pass.
>From e4218b4df50286de65f782d0b9bfe865b75b19c5 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 19 May 2025 08:57:13 +0000
Subject: [PATCH] [mlir][tosa] Allow creation of reshape with unranked output
This commit allows tosa.reshape to be created with an unranked output
type, so that the output shape can be inferred later by the shape
inference pass.
Change-Id: I639e68982946eeac6dcbc0d30e6cfa2217592091
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 13 ++++++++-----
mlir/test/Dialect/Tosa/ops.mlir | 8 ++++++++
3 files changed, 17 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..86f9ab94ec152 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1959,7 +1959,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_Tensor:$output
);
list<Availability> availability = [
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b2e471f2bba93..b74b820e11f75 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2012,7 +2012,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return failure();
}
TensorType inputType = getInput1().getType();
- RankedTensorType outputType = getType();
SmallVector<int64_t> shapeValues;
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
@@ -2020,6 +2019,14 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
+ int missingDims = llvm::count(shapeValues, -1);
+ if (missingDims > 1)
+ return emitOpError() << "expected at most one target dimension to be -1";
+
+ const auto outputType = dyn_cast<RankedTensorType>(getType());
+ if (!outputType)
+ return success();
+
if ((int64_t)shapeValues.size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";
@@ -2056,10 +2063,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}
}
- int missingDims = llvm::count(shapeValues, -1);
- if (missingDims > 1)
- return emitOpError() << "expected at most one target dimension to be -1";
-
return mlir::success();
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index f8273190bde40..e727614bd76f9 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -643,6 +643,14 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
return %0 : tensor<1x819xf32>
}
+// -----
+// CHECK-LABEL: reshape_unranked_output
+func.func @test_reshape_unranked_output(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
+ %1 = tosa.const_shape {values = dense<[21, 13, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
// -----
// CHECK-LABEL: reverse
func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
More information about the Mlir-commits mailing list