[Mlir-commits] [mlir] [mlir][tosa] Enhance TosaInferShapes pass for simple shape inference (PR #178418)

Luke Hutton llvmlistbot at llvm.org
Wed Feb 25 06:37:41 PST 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/178418

>From cf31801fa98b70d033e14bea39329e59a57b0aeb Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 27 Jan 2026 17:44:04 +0000
Subject: [PATCH 1/5] Move propagateShapesIn* functions to private methods

This allows pass parameters to be accessed without plumbing
them through multiple layers of function calls.

Change-Id: I45a64739b00cc30ed1aed0698fe6c75a6321a9e7
---
 .../Tosa/Transforms/TosaInferShapes.cpp       | 348 +++++++++---------
 1 file changed, 175 insertions(+), 173 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 4d347c02ee16d..a62b2278c4362 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -128,179 +128,6 @@ class TypeModificationState {
   llvm::SmallVector<std::pair<Value, Type>> oldTypes;
 };
 
-void propagateShapesInRegion(Region &region, TypeModificationState &state);
-
-void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
-  IfOp ifOp = dyn_cast<IfOp>(op);
-  if (!ifOp)
-    return;
-
-  for (auto &region : op.getRegions()) {
-    Block &frontBlock = region.front();
-    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
-      return;
-
-    for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
-      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
-      auto blockArg = frontBlock.getArgument(i - 1);
-      auto oldType = cast<ShapedType>(blockArg.getType());
-
-      if (inferredTy.hasRank()) {
-        Type newType = oldType.clone(inferredTy.getShape());
-        state.setType(blockArg, newType);
-      }
-    }
-
-    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
-      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
-          ifOp.getOperand(i + 1).getType());
-      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
-          frontBlock.getArgument(i).getType());
-      ValueKnowledge joinedKnowledge =
-          ValueKnowledge::join(operandKnowledge, blockKnowledge);
-      if (!joinedKnowledge)
-        continue;
-      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
-    }
-
-    propagateShapesInRegion(region, state);
-  }
-}
-
-void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
-  WhileOp whileOp = dyn_cast<WhileOp>(op);
-  if (!whileOp)
-    return;
-
-  // Determine what the expected argument types are to the cond/body blocks.
-  // The expected arguments should be compatible with ever iteration of the
-  // loop body / condition for tosa.while.
-  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
-
-  bool hasNewTypes = true;
-  while (hasNewTypes) {
-    TypeModificationState localState;
-
-    // Set types on the block args.
-    Region &bodyRegion = op.getRegion(1);
-    Block &block = bodyRegion.front();
-    for (int i = 0, s = argTypes.size(); i < s; i++) {
-      localState.setType(block.getArgument(i), argTypes[i]);
-    }
-
-    // Propagate to the end.
-    propagateShapesInRegion(bodyRegion, localState);
-
-    // Find all the tosa yield types and verify there is a single one.
-    llvm::SmallVector<YieldOp> yieldOps;
-    for (auto &block : bodyRegion)
-      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
-        yieldOps.push_back(yieldOp);
-
-    assert(yieldOps.size() == 1 && "missing or non-unique yield op");
-    // Using the new tosa.yield operand types, infer the new subtypes.
-    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
-    for (auto ty : argTypes) {
-      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
-    }
-
-    for (auto yieldOp : yieldOps) {
-      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
-        auto newKnowledge =
-            ValueKnowledge::getKnowledgeFromType(it.value().getType());
-        yieldTypeInfo[it.index()] =
-            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
-      }
-    }
-
-    // This should never happen.
-    if (yieldTypeInfo.size() != argTypes.size()) {
-      op.emitWarning("has a tosa.yield with the incorrect number of operands");
-      return;
-    }
-
-    // Determine the new block args and see if any changed.
-    hasNewTypes = false;
-    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
-      Type newType = yieldTypeInfo[i].getType();
-      hasNewTypes |= (newType != argTypes[i]);
-      argTypes[i] = newType;
-    }
-
-    // Roll back all changes made during the speculative part of the algorithm.
-    localState.rollBack();
-  }
-
-  // We now set the block arguments according to the most recent shape
-  // inference results. This gives us the block arg types for the next
-  // iteration.
-  for (auto &region : op.getRegions()) {
-    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
-      state.setType(region.front().getArgument(i), argTypes[i]);
-    }
-
-    propagateShapesInRegion(region, state);
-  }
-}
-
-void propagateShapesInRegion(Region &region, TypeModificationState &state) {
-  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
-
-  for (auto &block : region) {
-    for (Operation &op : block) {
-      if (op.getDialect() != tosaDialect)
-        continue;
-
-      propagateShapesToTosaIf(op, state);
-      propagateShapesToTosaWhile(op, state);
-
-      InferShapedTypeOpInterface shapeInterface =
-          dyn_cast<InferShapedTypeOpInterface>(op);
-      if (!shapeInterface)
-        continue;
-
-      SmallVector<ShapedTypeComponents> returnedShapes;
-
-      if (shapeInterface
-              .inferReturnTypeComponents(
-                  op.getContext(), op.getLoc(), op.getOperands(),
-                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
-                  op.getRegions(), returnedShapes)
-              .succeeded()) {
-        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
-          Value result = std::get<0>(it);
-          ShapedTypeComponents predictedShape = std::get<1>(it);
-
-          // Determine the knowledge based on the output type.
-          // TODO: should also query WIP type probably
-          Type resultTy = result.getType();
-          auto currentKnowledge =
-              ValueKnowledge::getKnowledgeFromType(resultTy);
-
-          // Compute the knowledge based on the inferred type.
-          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
-          inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
-          inferredKnowledge.hasRank = predictedShape.hasRank();
-          if (predictedShape.hasRank()) {
-            for (auto dim : predictedShape.getDims()) {
-              inferredKnowledge.sizes.push_back(dim);
-            }
-          }
-
-          // Compute the new type based on the joined version.
-          auto newKnowledge =
-              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
-          if (!newKnowledge)
-            continue;
-
-          // Set new type
-          state.setType(result, newKnowledge.getType());
-        }
-      }
-    }
-  }
-}
-
 /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
 /// and all nested regions
 void validateSameOperandsAndResultRankTrait(Region &region) {
@@ -341,5 +168,180 @@ struct TosaInferShapes
 
     validateSameOperandsAndResultRankTrait(func.getBody());
   }
+
+private:
+  void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
+    IfOp ifOp = dyn_cast<IfOp>(op);
+    if (!ifOp)
+      return;
+
+    for (auto &region : op.getRegions()) {
+      Block &frontBlock = region.front();
+      if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+        return;
+
+      for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
+        auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
+        auto blockArg = frontBlock.getArgument(i - 1);
+        auto oldType = cast<ShapedType>(blockArg.getType());
+
+        if (inferredTy.hasRank()) {
+          Type newType = oldType.clone(inferredTy.getShape());
+          state.setType(blockArg, newType);
+        }
+      }
+
+      for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+        ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+            ifOp.getOperand(i + 1).getType());
+        ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+            frontBlock.getArgument(i).getType());
+        ValueKnowledge joinedKnowledge =
+            ValueKnowledge::join(operandKnowledge, blockKnowledge);
+        if (!joinedKnowledge)
+          continue;
+        state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
+    WhileOp whileOp = dyn_cast<WhileOp>(op);
+    if (!whileOp)
+      return;
+
+    // Determine what the expected argument types are to the cond/body blocks.
+    // The expected arguments should be compatible with ever iteration of the
+    // loop body / condition for tosa.while.
+    SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
+
+    bool hasNewTypes = true;
+    while (hasNewTypes) {
+      TypeModificationState localState;
+
+      // Set types on the block args.
+      Region &bodyRegion = op.getRegion(1);
+      Block &block = bodyRegion.front();
+      for (int i = 0, s = argTypes.size(); i < s; i++) {
+        localState.setType(block.getArgument(i), argTypes[i]);
+      }
+
+      // Propagate to the end.
+      propagateShapesInRegion(bodyRegion, localState);
+
+      // Find all the tosa yield types and verify there is a single one.
+      llvm::SmallVector<YieldOp> yieldOps;
+      for (auto &block : bodyRegion)
+        if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
+          yieldOps.push_back(yieldOp);
+
+      assert(yieldOps.size() == 1 && "missing or non-unique yield op");
+      // Using the new tosa.yield operand types, infer the new subtypes.
+      llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+      for (auto ty : argTypes) {
+        yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+      }
+
+      for (auto yieldOp : yieldOps) {
+        for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
+          auto newKnowledge =
+              ValueKnowledge::getKnowledgeFromType(it.value().getType());
+          yieldTypeInfo[it.index()] =
+              ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+        }
+      }
+
+      // This should never happen.
+      if (yieldTypeInfo.size() != argTypes.size()) {
+        op.emitWarning(
+            "has a tosa.yield with the incorrect number of operands");
+        return;
+      }
+
+      // Determine the new block args and see if any changed.
+      hasNewTypes = false;
+      for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+        Type newType = yieldTypeInfo[i].getType();
+        hasNewTypes |= (newType != argTypes[i]);
+        argTypes[i] = newType;
+      }
+
+      // Roll back all changes made during the speculative part of the
+      // algorithm.
+      localState.rollBack();
+    }
+
+    // We now set the block arguments according to the most recent shape
+    // inference results. This gives us the block arg types for the next
+    // iteration.
+    for (auto &region : op.getRegions()) {
+      for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
+        state.setType(region.front().getArgument(i), argTypes[i]);
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+    Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
+
+    for (auto &block : region) {
+      for (Operation &op : block) {
+        if (op.getDialect() != tosaDialect)
+          continue;
+
+        propagateShapesToTosaIf(op, state);
+        propagateShapesToTosaWhile(op, state);
+
+        InferShapedTypeOpInterface shapeInterface =
+            dyn_cast<InferShapedTypeOpInterface>(op);
+        if (!shapeInterface)
+          continue;
+
+        SmallVector<ShapedTypeComponents> returnedShapes;
+
+        if (shapeInterface
+                .inferReturnTypeComponents(
+                    op.getContext(), op.getLoc(), op.getOperands(),
+                    op.getDiscardableAttrDictionary(),
+                    op.getPropertiesStorage(), op.getRegions(), returnedShapes)
+                .succeeded()) {
+          for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
+            Value result = std::get<0>(it);
+            ShapedTypeComponents predictedShape = std::get<1>(it);
+
+            // Determine the knowledge based on the output type.
+            // TODO: should also query WIP type probably
+            Type resultTy = result.getType();
+            auto currentKnowledge =
+                ValueKnowledge::getKnowledgeFromType(resultTy);
+
+            // Compute the knowledge based on the inferred type.
+            auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+            inferredKnowledge.dtype =
+                cast<ShapedType>(resultTy).getElementType();
+            inferredKnowledge.hasRank = predictedShape.hasRank();
+            if (predictedShape.hasRank()) {
+              for (auto dim : predictedShape.getDims()) {
+                inferredKnowledge.sizes.push_back(dim);
+              }
+            }
+
+            // Compute the new type based on the joined version.
+            auto newKnowledge =
+                ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+            if (!newKnowledge)
+              continue;
+
+            // Set new type
+            state.setType(result, newKnowledge.getType());
+          }
+        }
+      }
+    }
+  }
 };
 } // namespace

>From d4e55edb751dc65f1973361b91158d142f9b3c1c Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 27 Jan 2026 18:28:27 +0000
Subject: [PATCH 2/5] [mlir][tosa] Enhance TosaInferShapes pass for simple
 shape inference

This commit enhances the TosaInferShapes pass with two new options:
- fold-shape-expressions
- convert-function-boundaries

The "fold-shape-expressions" option enables greedily folding the
newly added TOSA shape operations when possible. Folding these
operations directly within TosaInferShapes is useful since it allows
shapes of later operations to be inferred in a single pass.

The "convert-function-boundaries" option updates the return types of a
function to the newly inferred output shapes. This avoids the need for
additional tensor.cast operations at function boundaries. This option
is particularly useful for resolving a dynamic function to a fully
static one.

When both of these options are used in conjunction with the
"tosa-input-shapes" pass option, it's possible to resolve a dynamic
function to static in a single pass.

Change-Id: I29dffd9910cd7ef6ed9fda7d38ec24accb48e598
---
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  10 ++
 .../Tosa/Transforms/TosaInferShapes.cpp       |  57 ++++++++-
 ...a-infer-shapes-fold-shape-expressions.mlir |  59 +++++++++
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 121 ++++++++++++------
 4 files changed, 208 insertions(+), 39 deletions(-)
 create mode 100644 mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 03be41d684f3f..5979ce4962e55 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -42,6 +42,16 @@ def TosaInferShapesPass : Pass<"tosa-infer-shapes", "func::FuncOp"> {
     "tensor::TensorDialect",
     "tosa::TosaDialect",
   ];
+
+  let options = [
+    Option<"foldShapeExpressions", "fold-shape-expressions", "bool",
+           /*default=*/"false",
+           "Fold TOSA shape operations when they have known input values">,
+    Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool",
+           /*default=*/"false",
+           "If enabled, the pass will convert function I/O types as well. Otherwise casts will "
+           "be inserted at the I/O boundaries.">
+  ];
 }
 
 def TosaMakeBroadcastablePass
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index a62b2278c4362..37644ee8c03f8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 namespace mlir {
 namespace tosa {
@@ -160,6 +161,13 @@ void validateSameOperandsAndResultRankTrait(Region &region) {
 struct TosaInferShapes
     : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
 public:
+  explicit TosaInferShapes() = default;
+  explicit TosaInferShapes(const TosaInferShapesPassOptions &options)
+      : TosaInferShapes() {
+    this->foldShapeExpressions = options.foldShapeExpressions;
+    this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+  }
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     TypeModificationState state;
@@ -167,6 +175,9 @@ struct TosaInferShapes
     state.commit();
 
     validateSameOperandsAndResultRankTrait(func.getBody());
+
+    if (convertFunctionBoundaries)
+      convertFunctionReturnTypes(func);
   }
 
 private:
@@ -286,16 +297,25 @@ struct TosaInferShapes
   }
 
   void propagateShapesInRegion(Region &region, TypeModificationState &state) {
-    Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
+    MLIRContext *ctx = region.getContext();
+    Dialect *tosaDialect = ctx->getLoadedDialect<TosaDialect>();
+    OperationFolder folder(ctx);
 
     for (auto &block : region) {
-      for (Operation &op : block) {
+      for (auto it = block.begin(); it != block.end();) {
+        Operation &op = *it++;
         if (op.getDialect() != tosaDialect)
           continue;
 
         propagateShapesToTosaIf(op, state);
         propagateShapesToTosaWhile(op, state);
 
+        if (foldShapeExpressions &&
+            op.hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
+          (void)folder.tryToFold(&op);
+          continue;
+        }
+
         InferShapedTypeOpInterface shapeInterface =
             dyn_cast<InferShapedTypeOpInterface>(op);
         if (!shapeInterface)
@@ -343,5 +363,38 @@ struct TosaInferShapes
       }
     }
   }
+
+  void convertFunctionReturnTypes(func::FuncOp func) {
+    IRRewriter rewriter(func.getContext());
+    SmallVector<Type> newReturnTypes;
+
+    // Rewrite func.return ops, removing dead tensor.cast ops if possible
+    func.walk([&rewriter, &newReturnTypes](func::ReturnOp ret) {
+      SmallVector<Value> newReturnValues;
+      OperandRange returnOperands = ret.getOperands();
+      newReturnValues.reserve(returnOperands.size());
+      newReturnTypes.reserve(returnOperands.size());
+
+      for (const Value &v : returnOperands) {
+        Value newReturnValue = v;
+        if (auto castOp = v.getDefiningOp<tensor::CastOp>()) {
+          newReturnValue = castOp.getSource();
+          if (castOp->use_empty())
+            rewriter.eraseOp(castOp);
+        }
+        newReturnValues.push_back(newReturnValue);
+        newReturnTypes.push_back(newReturnValue.getType());
+      }
+
+      rewriter.setInsertionPoint(ret);
+      rewriter.replaceOpWithNewOp<func::ReturnOp>(ret, newReturnValues);
+    });
+
+    // Update function return types with newly inferred types
+    const FunctionType oldType = func.getFunctionType();
+    const FunctionType newType = FunctionType::get(
+        func.getContext(), oldType.getInputs(), newReturnTypes);
+    func.setType(newType);
+  }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
new file mode 100644
index 0000000000000..1d8cc863ebe3e
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,DEFAULT
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,FUNCBOUND
+
+// -----
+
+// CHECK-LABEL: test_simple_shape_expression
+func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
+  // CHECK-NOT: tosa.dim
+  // CHECK-NOT: tosa.add_shape
+  // CHECK: %[[SHAPE:.+]] = tosa.const_shape {values = dense<84> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<84xi32>
+  // CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[SHAPE]] : (tensor<84xi32>, !tosa.shape<1>) -> tensor<7056xi32>
+  // DEFAULT: %[[CAST:.+]] = tensor.cast %[[TILE]] : tensor<7056xi32> to tensor<?xi32>
+  // DEFAULT: return %[[CAST]] : tensor<?xi32>
+  // FUNCBOUND: return %[[TILE]] : tensor<7056xi32>
+  %a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
+  %b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+  %d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  %e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
+  %f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  return %f : tensor<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_with_shape_expressions
+func.func @test_cond_if_with_shape_expressions(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: %[[CONST_SHAPE:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> {
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK-NOT: tosa.dim
+      %0 = tosa.dim %arg3 {axis = 0 : i32} : (tensor<?xf32>) -> !tosa.shape<1>
+      // CHECK: %[[RESHAPE:.*]] = tosa.reshape %arg3, %[[CONST_SHAPE]] : (tensor<3xf32>, !tosa.shape<1>) -> tensor<3xf32>
+      %1 = tosa.reshape %arg3, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
+      // CHECK: tosa.yield %[[RESHAPE]] : tensor<3xf32>
+      tosa.yield %1 : tensor<?xf32>
+  } else {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK: tosa.yield %arg4 : tensor<3xf32>
+      tosa.yield %arg4 : tensor<?xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_no_fold_shape_expression
+func.func @test_no_fold_shape_expression(%arg0: tensor<1x?x3xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: tosa.dim
+  %0 = tosa.dim %arg0 {axis = 1: i32} : (tensor<1x?x3xf32>) -> !tosa.shape<1>
+  // CHECK: tosa.tile
+  %1 = tosa.tile %arg1, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
+  // CHECK: return %{{.*}} : tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 9213fdd636813..8e7fa584ff46a 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1,9 +1,12 @@
-// RUN: mlir-opt --split-input-file --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s --allow-unused-prefixes --check-prefixes=CHECK,DEFAULT
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries" --allow-unregistered-dialect %s | FileCheck %s --allow-unused-prefixes --check-prefixes=CHECK,FUNCBOUND
 
 // CHECK-LABEL: @test_return
 func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
-  // CHECK: [[LOG:%.+]] = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32>
+  // CHECK: %[[LOG:.+]] = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
+  // DEFAULT: %[[CAST:.+]] = tensor.cast %[[LOG]] : tensor<4xf32> to tensor<*xf32>
+  // DEFAULT: return %[[CAST]] : tensor<*xf32>
+  // FUNCBOUND: return %[[LOG]] : tensor<4xf32>
   %0 = tosa.log %arg0 : (tensor<4xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -12,13 +15,13 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
 
 // CHECK-LABEL: @test_multiple
 func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<1xf32>) -> tensor<*xf32> {
-  // CHECK: [[ADD:%.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  // CHECK: %[[ADD:.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: [[LOG:%.+]] = tosa.log %0 : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %[[LOG:.+]] = tosa.log %[[ADD]] : (tensor<4xf32>) -> tensor<4xf32>
   %1 = tosa.log %0 : (tensor<*xf32>) -> tensor<*xf32>
 
-  // CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  // CHECK: %[[SUB:.+]] = tosa.sub %[[ADD]], %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
@@ -364,12 +367,15 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten
 
 // CHECK-LABEL: @test_accepts_unranked_scalar_tensor
 func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> {
-  // CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
+  // CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // CHECK-DAG: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
   %0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32>
-  // CHECK: %[[SHAPE:.*]] = tosa.const_shape
   %1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
-  // CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
+  // CHECK: %[[PAD:.*]] = tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
   %2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32>
+  // DEFAULT: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<1x3x3xf32> to tensor<*xf32>
+  // DEFAULT: return %[[CAST]] : tensor<*xf32>
+  // FUNCBOUND: return %[[PAD]] : tensor<1x3x3xf32>
   return %2 : tensor<*xf32>
 }
 
@@ -406,18 +412,16 @@ func.func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>)
 
 // CHECK-LABEL: @test_static_reshape
 func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
-  // CHECK: %[[CONST3:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE3:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE1]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE2]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE3]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
   %3 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // CHECK: tosa.reshape %arg0, %[[CONST3]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
   %0 = tosa.reshape %arg0, %3 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
-
-  // CHECK: %[[CONST4:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // CHECK: tosa.reshape %arg0, %[[CONST4]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
   %4 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
   %1 = tosa.reshape %arg0, %4 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
-
-  // CHECK: %[[CONST5:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // CHECK: tosa.reshape %arg0, %[[CONST5]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
   %5 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %2 = tosa.reshape %arg0, %5 : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
 
@@ -428,19 +432,17 @@ func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
 
 // CHECK-LABEL: @test_dynamic_reshape
 func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
-  // CHECK: %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE3:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE1]] : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE2]] : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  // CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE3]] : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
   %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // CHECK: %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
   %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
-
-  // CHECK: %2 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
   %2 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // CHECK: %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
   %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
-
-  // CHECK: %4 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %4 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // CHECK: %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
   %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<?x?xi32>
 
   return
@@ -563,9 +565,9 @@ func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
 
 // CHECK-LABEL: @test_slice
 func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
-  // CHECK: %0 = tosa.const_shape  {values = dense<1> : tensor<1xindex>}
-  // CHECK: %1 = tosa.const_shape  {values = dense<2> : tensor<1xindex>}
-  // CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
+  // CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK: %[[SLICE:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
   %0 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
   %1 = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
   %2= tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xi32>
@@ -576,8 +578,8 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_size_minus_one
 func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
-  // CHECK: %[[START:.+]] = tosa.const_shape
-  // CHECK: %[[SIZE:.+]] = tosa.const_shape
+  // CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<-1> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[0, 1, -1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<?x8x8x8xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x7x?x?xi32>
   // this checks following
   //  dim 0: size=-1, input dim=? => inferred output dim is ?
@@ -594,9 +596,9 @@ func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_dynamic
 func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
-  // CHECK: %0 = tosa.const_shape  {values = dense<[1, 0, 0]> : tensor<3xindex>}
-  // CHECK: %1 = tosa.const_shape  {values = dense<[7, -1, 1]> : tensor<3xindex>}
-  // CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
+  // CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: %[[SLICE:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
   %0 = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
   %1 = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
   %2= tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<?x?x?xf32>
@@ -1146,7 +1148,7 @@ func.func @resize_negative_output_dim(%arg0: tensor<1x3x1x1xi8>) {
   %scale = tosa.const_shape { values = dense<[1, 3, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
   %offset = tosa.const_shape { values = dense<[6, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
   %border = tosa.const_shape { values = dense<[-15, 0]> : tensor<2xindex> } : () -> !tosa.shape<2>
-  // expected-error at +1 {{calculated output height and width must be non-negative, got height = -5, width = 0}}
+  // expected-error at below {{calculated output height and width must be non-negative, got height = -5, width = 0}}
   %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x3x1x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xi8>
   return
 }
@@ -1214,6 +1216,25 @@ func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : t
 
 // -----
 
+func.func @if_test_propagate_dynamic(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: tosa.cond_if
+  // CHECK: -> tensor<3xf32>
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK: tosa.yield %arg3 : tensor<3xf32>
+      tosa.yield %arg3 : tensor<?xf32>
+  } else {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK: tosa.yield %arg4 : tensor<3xf32>
+      tosa.yield %arg4 : tensor<?xf32>
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @while_test
 func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
   // CHECK:      tosa.add
@@ -1249,7 +1270,9 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
     tosa.yield %3 : tensor<*xi32>
   }
 
-  // CHECK:      tensor.cast
+  // DEFAULT: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
+  // DEFAULT: return %[[CAST]] : tensor<*xi32>
+  // FUNCBOUND: return %{{.*}} : tensor<i32>
   return %1 : tensor<*xi32>
 }
 
@@ -1323,7 +1346,9 @@ func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
     "use"(%3) : (tensor<*xi32>) -> ()
     tosa.yield %3 : tensor<*xi32>
   }
-  // CHECK: tensor.cast
+  // DEFAULT: %[[CAST:.+]] = tensor.cast
+  // DEFAULT: return %[[CAST]] : tensor<*xi32>
+  // FUNCBOUND: return %{{.*}} : tensor<i32>
   return %1 : tensor<*xi32>
 }
 
@@ -1379,7 +1404,9 @@ func.func @while_dont_crash_nested(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
     tosa.yield %1 : tensor<*xi32>
   }
 
-  // CHECK: tensor.cast
+  // DEFAULT: %[[CAST:.+]] = tensor.cast
+  // DEFAULT: return %[[CAST]] : tensor<*xi32>
+  // FUNCBOUND: return %{{.*}} : tensor<i32>
   return %1 : tensor<*xi32>
 }
 
@@ -1752,3 +1779,23 @@ func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi
   %0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_simple_shape_expression
+func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
+  // CHECK: %[[DIM1:.+]] = tosa.dim
+  // CHECK: %[[DIM2:.+]] = tosa.dim
+  // CHECK: %[[ADD_SHAPE:.+]] = tosa.add_shape
+  // CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[ADD_SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  // CHECK: %[[DIM3:.+]] = tosa.dim %[[RESHAPE]]
+  // CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[DIM3]] : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  // CHECK: return %[[TILE]] : tensor<?xi32>
+  %a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
+  %b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+  %d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  %e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
+  %f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  return %f : tensor<?xi32>
+}

>From a1c03f2baa21ff1ca0ef896b123471385587d478 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 3 Feb 2026 09:26:14 +0000
Subject: [PATCH 3/5] Address review comments and remove dead casts correctly

Change-Id: Ic37435a0fae3d439f4355d62ac9d17c02fcba974
---
 .../Tosa/Transforms/TosaInferShapes.cpp       | 18 ++++-
 ...a-infer-shapes-fold-shape-expressions.mlir |  1 +
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 79 +++++++++++++++++++
 3 files changed, 95 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 37644ee8c03f8..e7e83abc14926 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -302,6 +302,9 @@ struct TosaInferShapes
     OperationFolder folder(ctx);
 
     for (auto &block : region) {
+      // The loop body may erase operations, so we need to be careful
+      // when iterating. Fetch the next operation before the current
+      // operation is modified.
       for (auto it = block.begin(); it != block.end();) {
         Operation &op = *it++;
         if (op.getDialect() != tosaDialect)
@@ -371,16 +374,17 @@ struct TosaInferShapes
     // Rewrite func.return ops, removing dead tensor.cast ops if possible
     func.walk([&rewriter, &newReturnTypes](func::ReturnOp ret) {
       SmallVector<Value> newReturnValues;
+      SmallVector<Value> maybeDeadCasts;
       OperandRange returnOperands = ret.getOperands();
       newReturnValues.reserve(returnOperands.size());
-      newReturnTypes.reserve(returnOperands.size());
+      maybeDeadCasts.reserve(returnOperands.size());
+      newReturnTypes.reserve(newReturnTypes.size() + returnOperands.size());
 
       for (const Value &v : returnOperands) {
         Value newReturnValue = v;
         if (auto castOp = v.getDefiningOp<tensor::CastOp>()) {
           newReturnValue = castOp.getSource();
-          if (castOp->use_empty())
-            rewriter.eraseOp(castOp);
+          maybeDeadCasts.push_back(castOp);
         }
         newReturnValues.push_back(newReturnValue);
         newReturnTypes.push_back(newReturnValue.getType());
@@ -388,6 +392,14 @@ struct TosaInferShapes
 
       rewriter.setInsertionPoint(ret);
       rewriter.replaceOpWithNewOp<func::ReturnOp>(ret, newReturnValues);
+
+      if (!maybeDeadCasts.empty()) {
+        llvm::for_each(maybeDeadCasts, [&](Value castVal) {
+          if (castVal.use_empty()) {
+            rewriter.eraseOp(castVal.getDefiningOp());
+          }
+        });
+      }
     });
 
     // Update function return types with newly inferred types
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
index 1d8cc863ebe3e..73fb7e9cbca4d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
@@ -12,6 +12,7 @@ func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<8
   // CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[SHAPE]] : (tensor<84xi32>, !tosa.shape<1>) -> tensor<7056xi32>
   // DEFAULT: %[[CAST:.+]] = tensor.cast %[[TILE]] : tensor<7056xi32> to tensor<?xi32>
   // DEFAULT: return %[[CAST]] : tensor<?xi32>
+  // FUNCBOUND-NOT: tensor.cast
   // FUNCBOUND: return %[[TILE]] : tensor<7056xi32>
   %a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
   %b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 8e7fa584ff46a..1e52f78df55e0 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1799,3 +1799,82 @@ func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<8
   %f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
   return %f : tensor<?xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_static
+func.func @test_conv2d_block_scaled_static(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<1x4x4x8xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_scales
+func.func @test_conv2d_block_scaled_dynamic_scales(%arg0: tensor<?x4x4x64xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<?x1x1x64xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<?x4x4x?xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x4x4x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<?x1x1x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_data
+func.func @test_conv2d_block_scaled_dynamic_data(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<1x4x4x8xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_unranked
+func.func @test_conv2d_block_scaled_dynamic_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<?x?x?x?xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_dwconv2d_bias_broadcast
+func.func @test_dwconv2d_bias_broadcast(%input: tensor<2x8x9x?xf32>, %weight: tensor<3x3x?x?xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
+  // CHECK: -> tensor<2x6x7x?xf32>
+  %0 = tosa.depthwise_conv2d %input, %weight, %bias, %input_zp, %weight_zp
+       { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
+       : (tensor<2x8x9x?xf32>, tensor<3x3x?x?xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tconv2d_bias_broadcast
+func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: tensor<?x3x3x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
+  // CHECK: -> tensor<2x8x9x?xf32>
+  %0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp
+       { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
+       : (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+    return
+  }
+
+// -----
+
+// CHECK-LABEL: test_avg_pool2d_unranked_input
+func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi32>) {
+  // CHECK: -> tensor<?x?x?x?xi32>
+  %0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+  return
+}

>From aa73ec798b2ad00b3f8072d96e4d2b790743fd82 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 23 Feb 2026 13:25:35 +0000
Subject: [PATCH 4/5] Dead code elimination on folded expressions

Following review comments, remove dead constants that are left
behind after folding shape expressions.

Change-Id: Ib13bfd40c1977c0249c2c7575a8c164b644f2eed
---
 mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp     | 9 ++++++++-
 .../Tosa/tosa-infer-shapes-fold-shape-expressions.mlir   | 2 ++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index e7e83abc14926..4e321645a751d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -315,7 +315,14 @@ struct TosaInferShapes
 
         if (foldShapeExpressions &&
             op.hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
-          (void)folder.tryToFold(&op);
+          const LogicalResult foldResult = folder.tryToFold(&op);
+          if (succeeded(foldResult))
+            // Clean up any dead const_shape ops as a result of folding.
+            llvm::for_each(op.getOperands(), [](Value operand) {
+              Operation *definingOp = operand.getDefiningOp();
+              if (definingOp && isOpTriviallyDead(definingOp))
+                definingOp->erase();
+            });
           continue;
         }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
index 73fb7e9cbca4d..067f0efba9c86 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
@@ -8,6 +8,8 @@ func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<8
   // CHECK-NOT: tosa.dim
   // CHECK-NOT: tosa.add_shape
   // CHECK: %[[SHAPE:.+]] = tosa.const_shape {values = dense<84> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-NOT: tosa.const_shape {values = dense<4> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-NOT: tosa.const_shape {values = dense<80> : tensor<1xindex>} : () -> !tosa.shape<1>
   // CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<84xi32>
   // CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[SHAPE]] : (tensor<84xi32>, !tosa.shape<1>) -> tensor<7056xi32>
   // DEFAULT: %[[CAST:.+]] = tensor.cast %[[TILE]] : tensor<7056xi32> to tensor<?xi32>

>From 7c6d02ef575f54f7caaead68fb56388d9b8357c9 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 25 Feb 2026 14:27:18 +0000
Subject: [PATCH 5/5] Fix dead const_shape elimination

Move the cleanup so that it runs after shape inference has completed.

Change-Id: If398d0209a764ee9376e52600fa24a8c6cdcf262
---
 .../Tosa/Transforms/TosaInferShapes.cpp       | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 4e321645a751d..837ebca572aae 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -174,6 +175,15 @@ struct TosaInferShapes
     propagateShapesInRegion(func.getBody(), state);
     state.commit();
 
+    if (foldShapeExpressions) {
+      // Folding shape expressions may leave dead tosa.const_shape operations
+      func.walk<WalkOrder::PostOrder, ReverseIterator>(
+          [](tosa::ConstShapeOp op) {
+            if (isOpTriviallyDead(op))
+              op->erase();
+          });
+    }
+
     validateSameOperandsAndResultRankTrait(func.getBody());
 
     if (convertFunctionBoundaries)
@@ -315,14 +325,7 @@ struct TosaInferShapes
 
         if (foldShapeExpressions &&
             op.hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
-          const LogicalResult foldResult = folder.tryToFold(&op);
-          if (succeeded(foldResult))
-            // Clean up any dead const_shape ops as a result of folding.
-            llvm::for_each(op.getOperands(), [](Value operand) {
-              Operation *definingOp = operand.getDefiningOp();
-              if (definingOp && isOpTriviallyDead(definingOp))
-                definingOp->erase();
-            });
+          (void)folder.tryToFold(&op);
           continue;
         }
 



More information about the Mlir-commits mailing list