[Mlir-commits] [mlir] 22a4930 - [mlir][tosa] Allow creation of reshape with unranked output (#140617)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 21 01:59:39 PDT 2025


Author: Luke Hutton
Date: 2025-05-21T09:59:36+01:00
New Revision: 22a493089ad009c7fd444fb2022c9174e681e222

URL: https://github.com/llvm/llvm-project/commit/22a493089ad009c7fd444fb2022c9174e681e222
DIFF: https://github.com/llvm/llvm-project/commit/22a493089ad009c7fd444fb2022c9174e681e222.diff

LOG: [mlir][tosa] Allow creation of reshape with unranked output (#140617)

This commit allows tosa.reshape to be created with an unranked output,
so that the output shape can be inferred later by the shape inference pass.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..86f9ab94ec152 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1959,7 +1959,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_Tensor:$output
   );
 
   list<Availability> availability = [

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 371c6dc27b428..3ee5a85a21dca 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2064,7 +2064,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     return failure();
   }
   TensorType inputType = getInput1().getType();
-  RankedTensorType outputType = getType();
 
   SmallVector<int64_t> shapeValues;
   if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
@@ -2072,6 +2071,14 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     return mlir::success();
   }
 
+  int missingDims = llvm::count(shapeValues, -1);
+  if (missingDims > 1)
+    return emitOpError() << "expected at most one target dimension to be -1";
+
+  const auto outputType = dyn_cast<RankedTensorType>(getType());
+  if (!outputType)
+    return success();
+
   if ((int64_t)shapeValues.size() != outputType.getRank())
     return emitOpError() << "new shape does not match result rank";
 
@@ -2108,10 +2115,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     }
   }
 
-  int missingDims = llvm::count(shapeValues, -1);
-  if (missingDims > 1)
-    return emitOpError() << "expected at most one target dimension to be -1";
-
   return mlir::success();
 }
 

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e327ed900f45f..7aea1c06698e8 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -643,6 +643,14 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
   return %0 : tensor<1x819xf32>
 }
 
+// -----
+// CHECK-LABEL: reshape_unranked_output
+func.func @test_reshape_unranked_output(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
+  %1 = tosa.const_shape {values = dense<[21, 13, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // -----
 // CHECK-LABEL: reverse
 func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {


        


More information about the Mlir-commits mailing list