[Mlir-commits] [mlir] [mlir][tosa] Enhance TosaInferShapes pass for simple shape inference (PR #178418)

Luke Hutton llvmlistbot at llvm.org
Tue Feb 3 01:28:46 PST 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/178418

>From d136c90c0392140527ff684e66a9842de0065f51 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 27 Jan 2026 17:44:04 +0000
Subject: [PATCH 1/3] Move propagateShapesIn* functions to private methods

This allows pass parameters to be accessed without plumbing
them through multiple layers of function calls.

Change-Id: I45a64739b00cc30ed1aed0698fe6c75a6321a9e7
---
 .../Tosa/Transforms/TosaInferShapes.cpp       | 348 +++++++++---------
 1 file changed, 175 insertions(+), 173 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 4d347c02ee16d..a62b2278c4362 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -128,179 +128,6 @@ class TypeModificationState {
   llvm::SmallVector<std::pair<Value, Type>> oldTypes;
 };
 
-void propagateShapesInRegion(Region &region, TypeModificationState &state);
-
-void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
-  IfOp ifOp = dyn_cast<IfOp>(op);
-  if (!ifOp)
-    return;
-
-  for (auto &region : op.getRegions()) {
-    Block &frontBlock = region.front();
-    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
-      return;
-
-    for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
-      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
-      auto blockArg = frontBlock.getArgument(i - 1);
-      auto oldType = cast<ShapedType>(blockArg.getType());
-
-      if (inferredTy.hasRank()) {
-        Type newType = oldType.clone(inferredTy.getShape());
-        state.setType(blockArg, newType);
-      }
-    }
-
-    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
-      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
-          ifOp.getOperand(i + 1).getType());
-      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
-          frontBlock.getArgument(i).getType());
-      ValueKnowledge joinedKnowledge =
-          ValueKnowledge::join(operandKnowledge, blockKnowledge);
-      if (!joinedKnowledge)
-        continue;
-      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
-    }
-
-    propagateShapesInRegion(region, state);
-  }
-}
-
-void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
-  WhileOp whileOp = dyn_cast<WhileOp>(op);
-  if (!whileOp)
-    return;
-
-  // Determine what the expected argument types are to the cond/body blocks.
-  // The expected arguments should be compatible with ever iteration of the
-  // loop body / condition for tosa.while.
-  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
-
-  bool hasNewTypes = true;
-  while (hasNewTypes) {
-    TypeModificationState localState;
-
-    // Set types on the block args.
-    Region &bodyRegion = op.getRegion(1);
-    Block &block = bodyRegion.front();
-    for (int i = 0, s = argTypes.size(); i < s; i++) {
-      localState.setType(block.getArgument(i), argTypes[i]);
-    }
-
-    // Propagate to the end.
-    propagateShapesInRegion(bodyRegion, localState);
-
-    // Find all the tosa yield types and verify there is a single one.
-    llvm::SmallVector<YieldOp> yieldOps;
-    for (auto &block : bodyRegion)
-      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
-        yieldOps.push_back(yieldOp);
-
-    assert(yieldOps.size() == 1 && "missing or non-unique yield op");
-    // Using the new tosa.yield operand types, infer the new subtypes.
-    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
-    for (auto ty : argTypes) {
-      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
-    }
-
-    for (auto yieldOp : yieldOps) {
-      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
-        auto newKnowledge =
-            ValueKnowledge::getKnowledgeFromType(it.value().getType());
-        yieldTypeInfo[it.index()] =
-            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
-      }
-    }
-
-    // This should never happen.
-    if (yieldTypeInfo.size() != argTypes.size()) {
-      op.emitWarning("has a tosa.yield with the incorrect number of operands");
-      return;
-    }
-
-    // Determine the new block args and see if any changed.
-    hasNewTypes = false;
-    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
-      Type newType = yieldTypeInfo[i].getType();
-      hasNewTypes |= (newType != argTypes[i]);
-      argTypes[i] = newType;
-    }
-
-    // Roll back all changes made during the speculative part of the algorithm.
-    localState.rollBack();
-  }
-
-  // We now set the block arguments according to the most recent shape
-  // inference results. This gives us the block arg types for the next
-  // iteration.
-  for (auto &region : op.getRegions()) {
-    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
-      state.setType(region.front().getArgument(i), argTypes[i]);
-    }
-
-    propagateShapesInRegion(region, state);
-  }
-}
-
-void propagateShapesInRegion(Region &region, TypeModificationState &state) {
-  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
-
-  for (auto &block : region) {
-    for (Operation &op : block) {
-      if (op.getDialect() != tosaDialect)
-        continue;
-
-      propagateShapesToTosaIf(op, state);
-      propagateShapesToTosaWhile(op, state);
-
-      InferShapedTypeOpInterface shapeInterface =
-          dyn_cast<InferShapedTypeOpInterface>(op);
-      if (!shapeInterface)
-        continue;
-
-      SmallVector<ShapedTypeComponents> returnedShapes;
-
-      if (shapeInterface
-              .inferReturnTypeComponents(
-                  op.getContext(), op.getLoc(), op.getOperands(),
-                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
-                  op.getRegions(), returnedShapes)
-              .succeeded()) {
-        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
-          Value result = std::get<0>(it);
-          ShapedTypeComponents predictedShape = std::get<1>(it);
-
-          // Determine the knowledge based on the output type.
-          // TODO: should also query WIP type probably
-          Type resultTy = result.getType();
-          auto currentKnowledge =
-              ValueKnowledge::getKnowledgeFromType(resultTy);
-
-          // Compute the knowledge based on the inferred type.
-          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
-          inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
-          inferredKnowledge.hasRank = predictedShape.hasRank();
-          if (predictedShape.hasRank()) {
-            for (auto dim : predictedShape.getDims()) {
-              inferredKnowledge.sizes.push_back(dim);
-            }
-          }
-
-          // Compute the new type based on the joined version.
-          auto newKnowledge =
-              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
-          if (!newKnowledge)
-            continue;
-
-          // Set new type
-          state.setType(result, newKnowledge.getType());
-        }
-      }
-    }
-  }
-}
-
 /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
 /// and all nested regions
 void validateSameOperandsAndResultRankTrait(Region &region) {
@@ -341,5 +168,180 @@ struct TosaInferShapes
 
     validateSameOperandsAndResultRankTrait(func.getBody());
   }
+
+private:
+  void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
+    IfOp ifOp = dyn_cast<IfOp>(op);
+    if (!ifOp)
+      return;
+
+    for (auto &region : op.getRegions()) {
+      Block &frontBlock = region.front();
+      if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+        return;
+
+      for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
+        auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
+        auto blockArg = frontBlock.getArgument(i - 1);
+        auto oldType = cast<ShapedType>(blockArg.getType());
+
+        if (inferredTy.hasRank()) {
+          Type newType = oldType.clone(inferredTy.getShape());
+          state.setType(blockArg, newType);
+        }
+      }
+
+      for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+        ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+            ifOp.getOperand(i + 1).getType());
+        ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+            frontBlock.getArgument(i).getType());
+        ValueKnowledge joinedKnowledge =
+            ValueKnowledge::join(operandKnowledge, blockKnowledge);
+        if (!joinedKnowledge)
+          continue;
+        state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
+    WhileOp whileOp = dyn_cast<WhileOp>(op);
+    if (!whileOp)
+      return;
+
+    // Determine what the expected argument types are to the cond/body blocks.
+    // The expected arguments should be compatible with every iteration of the
+    // loop body / condition for tosa.while.
+    SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
+
+    bool hasNewTypes = true;
+    while (hasNewTypes) {
+      TypeModificationState localState;
+
+      // Set types on the block args.
+      Region &bodyRegion = op.getRegion(1);
+      Block &block = bodyRegion.front();
+      for (int i = 0, s = argTypes.size(); i < s; i++) {
+        localState.setType(block.getArgument(i), argTypes[i]);
+      }
+
+      // Propagate to the end.
+      propagateShapesInRegion(bodyRegion, localState);
+
+      // Find all the tosa yield types and verify there is a single one.
+      llvm::SmallVector<YieldOp> yieldOps;
+      for (auto &block : bodyRegion)
+        if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
+          yieldOps.push_back(yieldOp);
+
+      assert(yieldOps.size() == 1 && "missing or non-unique yield op");
+      // Using the new tosa.yield operand types, infer the new subtypes.
+      llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+      for (auto ty : argTypes) {
+        yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+      }
+
+      for (auto yieldOp : yieldOps) {
+        for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
+          auto newKnowledge =
+              ValueKnowledge::getKnowledgeFromType(it.value().getType());
+          yieldTypeInfo[it.index()] =
+              ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+        }
+      }
+
+      // This should never happen.
+      if (yieldTypeInfo.size() != argTypes.size()) {
+        op.emitWarning(
+            "has a tosa.yield with the incorrect number of operands");
+        return;
+      }
+
+      // Determine the new block args and see if any changed.
+      hasNewTypes = false;
+      for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+        Type newType = yieldTypeInfo[i].getType();
+        hasNewTypes |= (newType != argTypes[i]);
+        argTypes[i] = newType;
+      }
+
+      // Roll back all changes made during the speculative part of the
+      // algorithm.
+      localState.rollBack();
+    }
+
+    // The loop above reached a fixpoint: set the block arguments of both the
+    // cond and body regions to the resulting types and propagate shapes
+    // through them using the real (non-speculative) state.
+    for (auto &region : op.getRegions()) {
+      for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
+        state.setType(region.front().getArgument(i), argTypes[i]);
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+    Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
+
+    for (auto &block : region) {
+      for (Operation &op : block) {
+        if (op.getDialect() != tosaDialect)
+          continue;
+
+        propagateShapesToTosaIf(op, state);
+        propagateShapesToTosaWhile(op, state);
+
+        InferShapedTypeOpInterface shapeInterface =
+            dyn_cast<InferShapedTypeOpInterface>(op);
+        if (!shapeInterface)
+          continue;
+
+        SmallVector<ShapedTypeComponents> returnedShapes;
+
+        if (shapeInterface
+                .inferReturnTypeComponents(
+                    op.getContext(), op.getLoc(), op.getOperands(),
+                    op.getDiscardableAttrDictionary(),
+                    op.getPropertiesStorage(), op.getRegions(), returnedShapes)
+                .succeeded()) {
+          for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
+            Value result = std::get<0>(it);
+            ShapedTypeComponents predictedShape = std::get<1>(it);
+
+            // Determine the knowledge based on the output type.
+            // TODO: should also query WIP type probably
+            Type resultTy = result.getType();
+            auto currentKnowledge =
+                ValueKnowledge::getKnowledgeFromType(resultTy);
+
+            // Compute the knowledge based on the inferred type.
+            auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+            inferredKnowledge.dtype =
+                cast<ShapedType>(resultTy).getElementType();
+            inferredKnowledge.hasRank = predictedShape.hasRank();
+            if (predictedShape.hasRank()) {
+              for (auto dim : predictedShape.getDims()) {
+                inferredKnowledge.sizes.push_back(dim);
+              }
+            }
+
+            // Compute the new type based on the joined version.
+            auto newKnowledge =
+                ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+            if (!newKnowledge)
+              continue;
+
+            // Set new type
+            state.setType(result, newKnowledge.getType());
+          }
+        }
+      }
+    }
+  }
 };
 } // namespace

>From 659d03102e1d0b6e161932745bc5364ab8ebad00 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 27 Jan 2026 18:28:27 +0000
Subject: [PATCH 2/3] [mlir][tosa] Enhance TosaInferShapes pass for simple
 shape inference

This commit enhances the TosaInferShapes pass with two new options:
- fold-shape-expressions
- convert-function-boundaries

The "fold-shape-expressions" option enables greedily folding the
newly added TOSA shape operations when possible. Folding these
operations directly within TosaInferShapes is useful since it allows
shapes of later operations to be inferred in a single pass.

The "convert-function-boundaries" option updates the return types of a
function to the newly inferred output shapes. This avoids the need for
additional tensor.cast operations at function boundaries. This option
is particularly useful when wanting to resolve a dynamic function to
fully static.

When both of these options are used in conjunction with the
"tosa-input-shapes" pass option, it's possible to resolve a dynamic
function to static in a single pass.

Change-Id: I29dffd9910cd7ef6ed9fda7d38ec24accb48e598
---
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  10 +
 .../Tosa/Transforms/TosaInferShapes.cpp       |  57 ++++-
 ...a-infer-shapes-fold-shape-expressions.mlir |  59 ++++++
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 200 ++++++++----------
 4 files changed, 208 insertions(+), 118 deletions(-)
 create mode 100644 mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 03be41d684f3f..5979ce4962e55 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -42,6 +42,16 @@ def TosaInferShapesPass : Pass<"tosa-infer-shapes", "func::FuncOp"> {
     "tensor::TensorDialect",
     "tosa::TosaDialect",
   ];
+
+  let options = [
+    Option<"foldShapeExpressions", "fold-shape-expressions", "bool",
+           /*default=*/"false",
+           "Fold TOSA shape operations when they have known input values">,
+    Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool",
+           /*default=*/"false",
+           "If enabled, update function return types to the inferred types. "
+           "Otherwise tensor.cast ops are inserted at the function boundaries.">
+  ];
 }
 
 def TosaMakeBroadcastablePass
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index a62b2278c4362..37644ee8c03f8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 namespace mlir {
 namespace tosa {
@@ -160,6 +161,13 @@ void validateSameOperandsAndResultRankTrait(Region &region) {
 struct TosaInferShapes
     : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
 public:
+  explicit TosaInferShapes() = default;
+  explicit TosaInferShapes(const TosaInferShapesPassOptions &options)
+      : TosaInferShapes() {
+    this->foldShapeExpressions = options.foldShapeExpressions;
+    this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+  }
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     TypeModificationState state;
@@ -167,6 +175,9 @@ struct TosaInferShapes
     state.commit();
 
     validateSameOperandsAndResultRankTrait(func.getBody());
+
+    if (convertFunctionBoundaries)
+      convertFunctionReturnTypes(func);
   }
 
 private:
@@ -286,16 +297,30 @@ struct TosaInferShapes
   }
 
   void propagateShapesInRegion(Region &region, TypeModificationState &state) {
-    Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
+    MLIRContext *ctx = region.getContext();
+    Dialect *tosaDialect = ctx->getLoadedDialect<TosaDialect>();
+    OperationFolder folder(ctx);
 
     for (auto &block : region) {
-      for (Operation &op : block) {
+      // Advance the iterator before visiting `op`, as folding may erase it.
+      for (auto it = block.begin(); it != block.end();) {
+        Operation &op = *it++;
         if (op.getDialect() != tosaDialect)
           continue;
 
         propagateShapesToTosaIf(op, state);
         propagateShapesToTosaWhile(op, state);
 
+        // Folds are not undone by TypeModificationState::rollBack(). Since
+        // tosa.while_loop bodies are first visited with speculative block
+        // argument types, conservatively never fold inside a while loop.
+        if (foldShapeExpressions &&
+            op.hasTrait<OpTrait::tosa::TosaShapeOperator>() &&
+            !op.getParentOfType<WhileOp>()) {
+          (void)folder.tryToFold(&op);
+          continue;
+        }
+
         InferShapedTypeOpInterface shapeInterface =
             dyn_cast<InferShapedTypeOpInterface>(op);
         if (!shapeInterface)
@@ -343,5 +363,57 @@ struct TosaInferShapes
       }
     }
   }
+
+  /// Update the function's result types to the types of the values it
+  /// returns, looking through tensor.cast ops (e.g. those inserted at the
+  /// function boundary when inferred types are committed). Casts left
+  /// without uses are erased. The function is left unchanged when its return
+  /// ops would disagree on the new result types.
+  void convertFunctionReturnTypes(func::FuncOp func) {
+    SmallVector<func::ReturnOp> returnOps;
+    func.walk([&](func::ReturnOp ret) { returnOps.push_back(ret); });
+    if (returnOps.empty())
+      return;
+
+    auto lookThroughCast = [](Value v) -> Value {
+      if (auto castOp = v.getDefiningOp<tensor::CastOp>())
+        return castOp.getSource();
+      return v;
+    };
+
+    // Every return must yield the same types, otherwise no single function
+    // type can describe them all.
+    SmallVector<Type> newReturnTypes;
+    for (Value v : returnOps.front().getOperands())
+      newReturnTypes.push_back(lookThroughCast(v).getType());
+    for (func::ReturnOp ret : returnOps)
+      for (auto [v, ty] : llvm::zip_equal(ret.getOperands(), newReturnTypes))
+        if (lookThroughCast(v).getType() != ty)
+          return;
+
+    // Rewrite func.return ops first, then erase casts that became dead.
+    IRRewriter rewriter(func.getContext());
+    SmallVector<Operation *> castOps;
+    for (func::ReturnOp ret : returnOps) {
+      SmallVector<Value> newReturnValues;
+      for (Value v : ret.getOperands()) {
+        Value source = lookThroughCast(v);
+        if (source != v && !llvm::is_contained(castOps, v.getDefiningOp()))
+          castOps.push_back(v.getDefiningOp());
+        newReturnValues.push_back(source);
+      }
+      rewriter.setInsertionPoint(ret);
+      rewriter.replaceOpWithNewOp<func::ReturnOp>(ret, newReturnValues);
+    }
+    for (Operation *castOp : castOps)
+      if (castOp->use_empty())
+        rewriter.eraseOp(castOp);
+
+    // Update function return types with newly inferred types
+    const FunctionType oldType = func.getFunctionType();
+    const FunctionType newType = FunctionType::get(
+        func.getContext(), oldType.getInputs(), newReturnTypes);
+    func.setType(newType);
+  }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
new file mode 100644
index 0000000000000..1d8cc863ebe3e
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,DEFAULT
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,FUNCBOUND
+
+// -----
+
+// CHECK-LABEL: test_simple_shape_expression
+func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
+  // CHECK-NOT: tosa.dim
+  // CHECK-NOT: tosa.add_shape
+  // CHECK: %[[SHAPE:.+]] = tosa.const_shape {values = dense<84> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<84xi32>
+  // CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[SHAPE]] : (tensor<84xi32>, !tosa.shape<1>) -> tensor<7056xi32>
+  // DEFAULT: %[[CAST:.+]] = tensor.cast %[[TILE]] : tensor<7056xi32> to tensor<?xi32>
+  // DEFAULT: return %[[CAST]] : tensor<?xi32>
+  // FUNCBOUND: return %[[TILE]] : tensor<7056xi32>
+  %a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
+  %b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+  %d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  %e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
+  %f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  return %f : tensor<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_with_shape_expressions
+func.func @test_cond_if_with_shape_expressions(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: %[[CONST_SHAPE:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> {
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK-NOT: tosa.dim
+      %0 = tosa.dim %arg3 {axis = 0 : i32} : (tensor<?xf32>) -> !tosa.shape<1>
+      // CHECK: %[[RESHAPE:.*]] = tosa.reshape %arg3, %[[CONST_SHAPE]] : (tensor<3xf32>, !tosa.shape<1>) -> tensor<3xf32>
+      %1 = tosa.reshape %arg3, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
+      // CHECK: tosa.yield %[[RESHAPE]] : tensor<3xf32>
+      tosa.yield %1 : tensor<?xf32>
+  } else {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK: tosa.yield %arg4 : tensor<3xf32>
+      tosa.yield %arg4 : tensor<?xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_no_fold_shape_expression
+func.func @test_no_fold_shape_expression(%arg0: tensor<1x?x3xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: tosa.dim
+  %0 = tosa.dim %arg0 {axis = 1: i32} : (tensor<1x?x3xf32>) -> !tosa.shape<1>
+  // CHECK: tosa.tile
+  %1 = tosa.tile %arg1, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
+  // CHECK: return %{{.*}} : tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index bc5f41b1af304..707cf36f7c7b6 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1,9 +1,12 @@
-// RUN: mlir-opt --split-input-file --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s --allow-unused-prefixes --check-prefixes=CHECK,DEFAULT
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries" --allow-unregistered-dialect %s | FileCheck %s --allow-unused-prefixes --check-prefixes=CHECK,FUNCBOUND
 
 // CHECK-LABEL: @test_return
 func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
-  // CHECK: [[LOG:%.+]] = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32>
+  // CHECK: %[[LOG:.+]] = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
+  // DEFAULT: %[[CAST:.+]] = tensor.cast %[[LOG]] : tensor<4xf32> to tensor<*xf32>
+  // DEFAULT: return %[[CAST]] : tensor<*xf32>
+  // FUNCBOUND: return %[[LOG]] : tensor<4xf32>
   %0 = tosa.log %arg0 : (tensor<4xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -12,13 +15,13 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
 
 // CHECK-LABEL: @test_multiple
 func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<1xf32>) -> tensor<*xf32> {
-  // CHECK: [[ADD:%.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  // CHECK: %[[ADD:.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: [[LOG:%.+]] = tosa.log %0 : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %[[LOG:.+]] = tosa.log %[[ADD]] : (tensor<4xf32>) -> tensor<4xf32>
   %1 = tosa.log %0 : (tensor<*xf32>) -> tensor<*xf32>
 
-  // CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  // CHECK: %[[SUB:.+]] = tosa.sub %[[ADD]], %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -346,12 +349,15 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten
 
 // CHECK-LABEL: @test_accepts_unranked_scalar_tensor
 func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> {
-  // CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
+  // CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // CHECK-DAG: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
   %0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32>
-  // CHECK: %[[SHAPE:.*]] = tosa.const_shape
   %1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
-  // CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
+  // CHECK: %[[PAD:.*]] = tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
   %2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32>
+  // DEFAULT: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<1x3x3xf32> to tensor<*xf32>
+  // DEFAULT: return %[[CAST]] : tensor<*xf32>
+  // FUNCBOUND: return %[[PAD]] : tensor<1x3x3xf32>
   return %2 : tensor<*xf32>
 }
 
@@ -388,18 +394,16 @@ func.func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>)
 
 // CHECK-LABEL: @test_static_reshape
 func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
-  // CHECK: %[[CONST3:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE3:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE1]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE2]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE3]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
   %3 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // CHECK: tosa.reshape %arg0, %[[CONST3]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
   %0 = tosa.reshape %arg0, %3 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
-
-  // CHECK: %[[CONST4:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // CHECK: tosa.reshape %arg0, %[[CONST4]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
   %4 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
   %1 = tosa.reshape %arg0, %4 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
-
-  // CHECK: %[[CONST5:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // CHECK: tosa.reshape %arg0, %[[CONST5]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
   %5 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %2 = tosa.reshape %arg0, %5 : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
 
@@ -410,19 +414,17 @@ func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
 
 // CHECK-LABEL: @test_dynamic_reshape
 func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
-  // CHECK: %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE3:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE1]] : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE2]] : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE3]] : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
   %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // CHECK: %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
   %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
-
-  // CHECK: %2 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
   %2 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // CHECK: %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
   %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
-
-  // CHECK: %4 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %4 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // CHECK: %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
   %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<?x?xi32>
 
   return
@@ -545,9 +547,9 @@ func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
 
 // CHECK-LABEL: @test_slice
 func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
-  // CHECK: %0 = tosa.const_shape  {values = dense<1> : tensor<1xindex>}
-  // CHECK: %1 = tosa.const_shape  {values = dense<2> : tensor<1xindex>}
-  // CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
+  // CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK: %[[SLICE:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
   %0 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
   %1 = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
   %2= tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xi32>
@@ -558,8 +560,8 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_size_minus_one
 func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
-  // CHECK: %[[START:.+]] = tosa.const_shape
-  // CHECK: %[[SIZE:.+]] = tosa.const_shape
+  // CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<-1> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[0, 1, -1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<?x8x8x8xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x7x?x?xi32>
   // this checks following
   //  dim 0: size=-1, input dim=? => inferred output dim is ?
@@ -576,8 +578,8 @@ func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_size_out_of_bound
 func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
-  // CHECK: %[[START:.+]] = tosa.const_shape
-  // CHECK: %[[SIZE:.+]] = tosa.const_shape
+  // CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<[0, -2, 9, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x?x?x4xi32>
   // this checks following
   //  dim 0: size=0 => inferred output dim is ?
@@ -594,8 +596,8 @@ func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_start_out_of_bound
 func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
-  // CHECK: %[[START:.+]] = tosa.const_shape
-  // CHECK: %[[SIZE:.+]] = tosa.const_shape
+  // CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<[1, 1, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[-1, 8, 6, 8000000]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x?x?x4xi32>
   // this checks following
   //  dim 0: start=-1 => inferred output dim is ?
@@ -612,9 +614,9 @@ func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_dynamic
 func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
-  // CHECK: %0 = tosa.const_shape  {values = dense<[1, 0, 0]> : tensor<3xindex>}
-  // CHECK: %1 = tosa.const_shape  {values = dense<[7, -1, 1]> : tensor<3xindex>}
-  // CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
+  // CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: %[[SLICE:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
   %0 = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
   %1 = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
   %2= tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<?x?x?xf32>
@@ -1164,7 +1166,7 @@ func.func @resize_negative_output_dim(%arg0: tensor<1x3x1x1xi8>) {
   %scale = tosa.const_shape { values = dense<[1, 3, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
   %offset = tosa.const_shape { values = dense<[6, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
   %border = tosa.const_shape { values = dense<[-15, 0]> : tensor<2xindex> } : () -> !tosa.shape<2>
-  // expected-error@+1 {{calculated output height and width must be non-negative, got height = -5, width = 0}}
+  // expected-error@below {{calculated output height and width must be non-negative, got height = -5, width = 0}}
   %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x3x1x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xi8>
   return
 }
@@ -1232,6 +1234,25 @@ func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : t
 
 // -----
 
+func.func @if_test_propagate_dynamic(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: tosa.cond_if
+  // CHECK: -> tensor<3xf32>
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK: tosa.yield %arg3 : tensor<3xf32>
+      tosa.yield %arg3 : tensor<?xf32>
+  } else {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK: tosa.yield %arg4 : tensor<3xf32>
+      tosa.yield %arg4 : tensor<?xf32>
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @while_test
 func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
   // CHECK:      tosa.add
@@ -1267,7 +1288,9 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
     tosa.yield %3 : tensor<*xi32>
   }
 
-  // CHECK:      tensor.cast
+  // DEFAULT: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
+  // DEFAULT: return %[[CAST]] : tensor<*xi32>
+  // FUNCBOUND: return %{{.*}} : tensor<i32>
   return %1 : tensor<*xi32>
 }
 
@@ -1341,7 +1364,9 @@ func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
     "use"(%3) : (tensor<*xi32>) -> ()
     tosa.yield %3 : tensor<*xi32>
   }
-  // CHECK: tensor.cast
+  // DEFAULT: %[[CAST:.+]] = tensor.cast
+  // DEFAULT: return %[[CAST]] : tensor<*xi32>
+  // FUNCBOUND: return %{{.*}} : tensor<i32>
   return %1 : tensor<*xi32>
 }
 
@@ -1397,7 +1422,9 @@ func.func @while_dont_crash_nested(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
     tosa.yield %1 : tensor<*xi32>
   }
 
-  // CHECK: tensor.cast
+  // DEFAULT: %[[CAST:.+]] = tensor.cast
+  // DEFAULT: return %[[CAST]] : tensor<*xi32>
+  // FUNCBOUND: return %{{.*}} : tensor<i32>
   return %1 : tensor<*xi32>
 }
 
@@ -1676,79 +1703,20 @@ func.func @test_cast_to_block_scaled_dynamic_scales(%arg0: tensor<4x?xf32>) -> (
 
 // -----
 
-// CHECK-LABEL: test_conv2d_block_scaled_static
-func.func @test_conv2d_block_scaled_static(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
-  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // CHECK: -> tensor<1x4x4x8xf32>
-  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-
-// -----
-
-// CHECK-LABEL: test_conv2d_block_scaled_dynamic_scales
-func.func @test_conv2d_block_scaled_dynamic_scales(%arg0: tensor<?x4x4x64xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<?x1x1x64xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
-  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // CHECK: -> tensor<?x4x4x?xf32>
-  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x4x4x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<?x1x1x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-
-// -----
-
-// CHECK-LABEL: test_conv2d_block_scaled_dynamic_data
-func.func @test_conv2d_block_scaled_dynamic_data(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
-  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // CHECK: -> tensor<1x4x4x8xf32>
-  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-
-// -----
-
-// CHECK-LABEL: test_conv2d_block_scaled_dynamic_unranked
-func.func @test_conv2d_block_scaled_dynamic_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
-  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // CHECK: -> tensor<?x?x?x?xf32>
-  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-
-// -----
-
-// CHECK-LABEL: test_dwconv2d_bias_broadcast
-func.func @test_dwconv2d_bias_broadcast(%input: tensor<2x8x9x?xf32>, %weight: tensor<3x3x?x?xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
-  // CHECK: -> tensor<2x6x7x?xf32>
-  %0 = tosa.depthwise_conv2d %input, %weight, %bias, %input_zp, %weight_zp
-       { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
-       : (tensor<2x8x9x?xf32>, tensor<3x3x?x?xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_tconv2d_bias_broadcast
-func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: tensor<?x3x3x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
-  // CHECK: -> tensor<2x8x9x?xf32>
-  %0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp
-       { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
-       : (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
-    return
-  }
-
-// -----
-
-// CHECK-LABEL: test_avg_pool2d_unranked_input
-func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi32>) {
-  // CHECK: -> tensor<?x?x?x?xi32>
-  %0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
-  return
+// CHECK-LABEL: test_simple_shape_expression
+func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
+  // CHECK: %[[DIM1:.+]] = tosa.dim
+  // CHECK: %[[DIM2:.+]] = tosa.dim
+  // CHECK: %[[ADD_SHAPE:.+]] = tosa.add_shape
+  // CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[ADD_SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  // CHECK: %[[DIM3:.+]] = tosa.dim %[[RESHAPE]]
+  // CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[DIM3]] : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  // CHECK: return %[[TILE]] : tensor<?xi32>
+  %a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
+  %b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+  %d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  %e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
+  %f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  return %f : tensor<?xi32>
 }

>From ef3472e90a8152e59cbd50068f271da7eb70df1f Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 3 Feb 2026 09:26:14 +0000
Subject: [PATCH 3/3] Address review comments and remove dead casts correctly

Change-Id: Ic37435a0fae3d439f4355d62ac9d17c02fcba974
---
 .../Tosa/Transforms/TosaInferShapes.cpp       | 21 +++++-
 ...a-infer-shapes-fold-shape-expressions.mlir |  1 +
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 79 +++++++++++++++++++
 3 files changed, 98 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 37644ee8c03f8..60c9aa20ac77d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -371,16 +371,17 @@ struct TosaInferShapes
     // Rewrite func.return ops, removing dead tensor.cast ops if possible
     func.walk([&rewriter, &newReturnTypes](func::ReturnOp ret) {
       SmallVector<Value> newReturnValues;
+      SmallVector<Value> maybeDeadCasts;
       OperandRange returnOperands = ret.getOperands();
       newReturnValues.reserve(returnOperands.size());
-      newReturnTypes.reserve(returnOperands.size());
+      maybeDeadCasts.reserve(returnOperands.size());
+      newReturnTypes.reserve(newReturnTypes.size() + returnOperands.size());
 
       for (const Value &v : returnOperands) {
         Value newReturnValue = v;
         if (auto castOp = v.getDefiningOp<tensor::CastOp>()) {
           newReturnValue = castOp.getSource();
-          if (castOp->use_empty())
-            rewriter.eraseOp(castOp);
+          maybeDeadCasts.push_back(castOp);
         }
         newReturnValues.push_back(newReturnValue);
         newReturnTypes.push_back(newReturnValue.getType());
@@ -388,6 +389,20 @@ struct TosaInferShapes
 
       rewriter.setInsertionPoint(ret);
       rewriter.replaceOpWithNewOp<func::ReturnOp>(ret, newReturnValues);
+
+      // A single cast can feed several return operands (e.g. `return %c, %c`),
+      // so it may appear more than once in maybeDeadCasts. Decide which casts
+      // are dead while every op is still alive, de-duplicating by defining op,
+      // and only then erase them: querying a Value whose defining op has
+      // already been erased would be a use-after-free.
+      SmallVector<Operation *> deadCasts;
+      for (Value castVal : maybeDeadCasts) {
+        Operation *deadCast = castVal.getDefiningOp();
+        if (castVal.use_empty() && !llvm::is_contained(deadCasts, deadCast))
+          deadCasts.push_back(deadCast);
+      }
+      for (Operation *deadCast : deadCasts)
+        rewriter.eraseOp(deadCast);
     });
 
     // Update function return types with newly inferred types
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
index 1d8cc863ebe3e..73fb7e9cbca4d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
@@ -12,6 +12,7 @@ func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<8
   // CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[SHAPE]] : (tensor<84xi32>, !tosa.shape<1>) -> tensor<7056xi32>
   // DEFAULT: %[[CAST:.+]] = tensor.cast %[[TILE]] : tensor<7056xi32> to tensor<?xi32>
   // DEFAULT: return %[[CAST]] : tensor<?xi32>
+  // FUNCBOUND-NOT: tensor.cast
   // FUNCBOUND: return %[[TILE]] : tensor<7056xi32>
   %a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
   %b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 707cf36f7c7b6..61085afc3b6ea 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1720,3 +1720,82 @@ func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<8
   %f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
   return %f : tensor<?xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_static
+func.func @test_conv2d_block_scaled_static(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<1x4x4x8xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_scales
+func.func @test_conv2d_block_scaled_dynamic_scales(%arg0: tensor<?x4x4x64xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<?x1x1x64xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<?x4x4x?xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x4x4x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<?x1x1x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_data
+func.func @test_conv2d_block_scaled_dynamic_data(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<1x4x4x8xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_unranked
+func.func @test_conv2d_block_scaled_dynamic_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<?x?x?x?xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_dwconv2d_bias_broadcast
+func.func @test_dwconv2d_bias_broadcast(%input: tensor<2x8x9x?xf32>, %weight: tensor<3x3x?x?xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
+  // CHECK: -> tensor<2x6x7x?xf32>
+  %0 = tosa.depthwise_conv2d %input, %weight, %bias, %input_zp, %weight_zp
+       { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
+       : (tensor<2x8x9x?xf32>, tensor<3x3x?x?xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tconv2d_bias_broadcast
+func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: tensor<?x3x3x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
+  // CHECK: -> tensor<2x8x9x?xf32>
+  %0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp
+       { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
+       : (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+    return
+  }
+
+// -----
+
+// CHECK-LABEL: test_avg_pool2d_unranked_input
+func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi32>) {
+  // CHECK: -> tensor<?x?x?x?xi32>
+  %0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+  return
+}



More information about the Mlir-commits mailing list