[Mlir-commits] [mlir] [mlir][tosa] Add pass to assign static input shape to TOSA functions (PR #171156)

Luke Hutton llvmlistbot at llvm.org
Wed Dec 10 07:02:49 PST 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/171156

>From 2448915b1efd241291581b71eb34c98ae2fff7b5 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Sat, 6 Dec 2025 10:46:00 +0000
Subject: [PATCH 1/3] [mlir][tosa] Add pass to assign static input shape to
 TOSA functions

This commit introduces the `--tosa-experimental-input-shape` pass, which
allows a user to convert dynamically shaped input arguments of TOSA functions
to user-defined static shapes. Here is a simple example:
```bash
func.func @test(%arg0: tensor<2x?xi32>, %arg1: tensor<?x256xf32>, %arg2: tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
    %0 = tosa.add %arg0, %arg0 : (tensor<2x?xi32>, tensor<2x?xi32>) -> tensor<2x?xi32>
    %1 = tosa.reciprocal %arg1 : (tensor<?x256xf32>) -> tensor<?x256xf32>
    %2 = tosa.sub %arg2, %arg2 : (tensor<?x9xf32>, tensor<?x9xf32>) -> tensor<?x9xf32>
    return %0, %1, %2 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
}

$ mlir-opt --tosa-experimental-input-shape="args=arg0:2x16,arg2:64x9" test.mlir
func.func @test(%arg0: tensor<2x16xi32>, %arg1: tensor<?x256xf32>, %arg2: tensor<64x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
    %0 = tosa.add %arg0, %arg0 : (tensor<2x16xi32>, tensor<2x16xi32>) -> tensor<2x?xi32>
    %1 = tosa.reciprocal %arg1 : (tensor<?x256xf32>) -> tensor<?x256xf32>
    %2 = tosa.sub %arg2, %arg2 : (tensor<64x9xf32>, tensor<64x9xf32>) -> tensor<?x9xf32>
    return %0, %1, %2 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
}
```

When used in conjunction with the `--tosa-infer-shapes` pass, it can
propagate the static shape information through simple TOSA functions
(those that don't include TOSA shape operations). Continuing from the
example above, with its output saved as test2.mlir:
```bash
$ mlir-opt --tosa-infer-shapes test2.mlir
func.func @test(%arg0: tensor<2x16xi32>, %arg1: tensor<?x256xf32>, %arg2: tensor<64x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
    %0 = tosa.add %arg0, %arg0 : (tensor<2x16xi32>, tensor<2x16xi32>) -> tensor<2x16xi32>
    %cast = tensor.cast %0 : tensor<2x16xi32> to tensor<2x?xi32>
    %1 = tosa.reciprocal %arg1 : (tensor<?x256xf32>) -> tensor<?x256xf32>
    %2 = tosa.sub %arg2, %arg2 : (tensor<64x9xf32>, tensor<64x9xf32>) -> tensor<64x9xf32>
    %cast_0 = tensor.cast %2 : tensor<64x9xf32> to tensor<?x9xf32>
    return %cast, %1, %cast_0 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
}
```
Note: tosa-infer-shapes currently doesn't have an option to update the
function signature, so the newly static results are cast back to the
original dynamic result types with `tensor.cast`.

Co-authored-by: Kaushik Varadharajan <kaushik.varadharajan at arm.com>
Change-Id: Ie8ab1383d3f7388f0e06dc90dfc197ee7c481af6
---
 .../mlir/Dialect/Tosa/Transforms/Passes.h     |   2 +
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  22 +++
 .../Dialect/Tosa/Transforms/CMakeLists.txt    |   1 +
 .../Tosa/Transforms/TosaInputShape.cpp        | 184 ++++++++++++++++++
 mlir/test/Dialect/Tosa/tosa-input-shape.mlir  |  70 +++++++
 5 files changed, 279 insertions(+)
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
 create mode 100644 mlir/test/Dialect/Tosa/tosa-input-shape.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index ba99d2f1d2727..e33d7c698856c 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -41,6 +41,8 @@ void populateTosaConstantReduction(MLIRContext *ctx,
 void populateTosaTypeConversion(TypeConverter &converter);
 
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+/// Creates the `tosa-experimental-input-shape` pass. Each entry of `args` is
+/// a shape description such as "arg0:2x16": the argument name, a colon, and
+/// the dimensions separated by 'x'.
+std::unique_ptr<Pass>
+createTosaInputShapePass(std::vector<std::string> args = {});
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 12f520297b702..58ea47db731a6 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -198,4 +198,26 @@ def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
   ];
 }
 
+def TosaInputShape : Pass<"tosa-experimental-input-shape", "func::FuncOp"> {
+  let summary = "Override dynamic function arguments to specified static shapes.";
+  let description = [{
+    Pass that overrides the dynamic input shapes of function arguments to
+    specified static shapes. If a specified static shape conflicts with the
+    static dimensions in an original input shape, an error is reported.
+
+    The pass only rewrites the argument types (in the function signature and
+    in the entry block) and refreshes the signature's result types from the
+    returned values. Result types of operations inside the function body are
+    not re-inferred; run `--tosa-infer-shapes` afterwards to propagate the
+    new static shapes through the body.
+
+    Example:
+
+    ```
+    mlir-opt --tosa-experimental-input-shape="args=arg0:2x16,arg2:64x9" input.mlir
+    ```
+  }];
+
+  let constructor = "tosa::createTosaInputShapePass()";
+  let dependentDialects = [
+    "func::FuncDialect",
+    "tensor::TensorDialect",
+    "tosa::TosaDialect",
+  ];
+  let options = [
+    ListOption<"args", "args", "std::string",
+          "Comma-separated list of shape descriptions. Each description contains the "
+          "argument name, a colon, and a shape with dimensions separated by x "
+          "(e.g. arg0:5x5,arg3:2x64)">,
+  ];
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 091b481d6394b..cf1e6ab55872c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaProfileCompliance.cpp
   TosaValidation.cpp
   TosaNarrowI64ToI32.cpp
+  TosaInputShape.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
new file mode 100644
index 0000000000000..1665f33f0d478
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
@@ -0,0 +1,184 @@
+//===- TosaInputShape.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pass that overrides the dynamic input shapes of function arguments to
+// specified static shapes. If a specified static shape conflicts with the
+// static dimensions in an original input shape, an error is reported.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAINPUTSHAPE
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// An argument index together with the static shape requested for it.
+using IdxAndShape = std::pair<size_t, SmallVector<int64_t>>;
+
+/// Parse one shape description of the form `arg<N>:<d0>x<d1>x...` (for
+/// example "arg3:2x64") into the argument index N and its dimensions.
+/// Every dimension must be a strictly positive integer. An empty shape after
+/// the colon (e.g. "arg0:") yields an empty (rank-0) shape. On malformed
+/// input an error is emitted at `loc` and failure is returned.
+FailureOr<IdxAndShape> parseInputShape(Location loc, StringRef input) {
+  if (!input.consume_front("arg")) {
+    emitError(loc) << "expected prefix 'arg' at the start of " << input;
+    return failure();
+  }
+
+  const size_t colonPos = input.find(':');
+  if (colonPos == StringRef::npos) {
+    emitError(loc) << "expected ':' after argument index in '" << input << "'";
+    return failure();
+  }
+
+  const StringRef indexStr = input.take_front(colonPos);
+  const StringRef shapeStr = input.drop_front(colonPos + 1);
+
+  // `index` is unsigned, so getAsInteger already rejects negative values.
+  size_t index;
+  if (indexStr.getAsInteger(10, index)) {
+    emitError(loc) << "invalid argument index, got " << indexStr;
+    return failure();
+  }
+
+  SmallVector<int64_t> shape;
+  if (shapeStr.empty())
+    return IdxAndShape{index, std::move(shape)};
+
+  // Keep empty pieces so that stray separators ("2x", "2xx3", "x2") are
+  // diagnosed instead of being silently skipped or rejected without a
+  // message.
+  SmallVector<StringRef> dimStrs;
+  shapeStr.split(dimStrs, 'x', /*MaxSplit=*/-1, /*KeepEmpty=*/true);
+  shape.reserve(dimStrs.size());
+  for (StringRef dimStr : dimStrs) {
+    int64_t dimVal;
+    if (dimStr.getAsInteger(10, dimVal) || dimVal <= 0) {
+      emitError(loc) << "invalid dimension '" << dimStr << "' in shape '"
+                     << shapeStr << "', expected a positive integer";
+      return failure();
+    }
+    shape.push_back(dimVal);
+  }
+
+  return IdxAndShape{index, std::move(shape)};
+}
+
+// Parse the shape descriptions given through the `args` pass option, one
+// entry at a time with parseInputShape. Returns the parsed (index, shape)
+// pairs in option order, or failure as soon as one entry does not parse
+// (parseInputShape is responsible for emitting the diagnostic at `loc`).
+// For example:
+//   "args=arg0:5x10,arg8:3x9" => {{0, {5, 10}}, {8, {3, 9}}}
+//   "args=arg0"               => failure (missing ':')
+FailureOr<SmallVector<IdxAndShape>>
+parseInputShapes(Location loc, const std::vector<std::string> &args) {
+  SmallVector<IdxAndShape> inputShapes;
+  inputShapes.reserve(args.size());
+  for (const std::string &arg : args) {
+    FailureOr<IdxAndShape> maybeInputShape = parseInputShape(loc, arg);
+    if (failed(maybeInputShape))
+      return failure();
+    inputShapes.push_back(std::move(*maybeInputShape));
+  }
+  return inputShapes;
+}
+
+/// Function pass that overrides the types of selected dynamically shaped
+/// arguments with the static shapes given in the `args` option.
+///
+/// Only argument types are changed (signature and entry block arguments);
+/// ops in the body are not re-inferred. The signature's result types are
+/// refreshed from the `func.return` terminating the last block, so that a
+/// function returning a rewritten argument directly still verifies.
+struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
+public:
+  TosaInputShape() = default;
+
+  explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
+    this->args = args;
+  }
+
+  /// Parse the `args` option once, before any function is processed. The
+  /// option applies to every function the pass visits, so parse errors are
+  /// reported once (without a location, since they concern the option rather
+  /// than the IR) and fail the pipeline instead of being silently ignored.
+  LogicalResult initialize(MLIRContext *context) override {
+    FailureOr<SmallVector<IdxAndShape>> parsed =
+        parseInputShapes(UnknownLoc::get(context), args);
+    if (failed(parsed))
+      return failure();
+    argsParsed = std::move(*parsed);
+    return success();
+  }
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    const FunctionType oldFunctionType = func.getFunctionType();
+    const ArrayRef<Type> oldInputTypes = oldFunctionType.getInputs();
+
+    // Validate every request before mutating anything, so that a bad request
+    // cannot leave the signature and the entry block out of sync.
+    SmallVector<Type> newInputs(oldInputTypes.begin(), oldInputTypes.end());
+    for (const auto &[argIdx, shape] : argsParsed) {
+      FailureOr<Type> newType =
+          getUpdatedTensorType(func, argIdx, oldInputTypes, shape);
+      if (failed(newType))
+        return signalPassFailure();
+      newInputs[argIdx] = *newType;
+    }
+
+    // A declaration has no body, so only its signature can be updated.
+    if (func.isExternal()) {
+      func.setFunctionType(
+          oldFunctionType.clone(newInputs, oldFunctionType.getResults()));
+      return;
+    }
+
+    // Update argument types in the entry block.
+    Block &entryBlock = func.getBody().front();
+    for (const IdxAndShape &request : argsParsed)
+      entryBlock.getArgument(request.first).setType(newInputs[request.first]);
+
+    // Re-derive the result types from the returned values, which may now be
+    // rewritten arguments; otherwise keep the original result types.
+    SmallVector<Type> newResults;
+    Operation *terminator = func.getBody().back().getTerminator();
+    if (auto returnOp = dyn_cast_or_null<func::ReturnOp>(terminator)) {
+      const auto types = returnOp.getOperandTypes();
+      newResults.assign(types.begin(), types.end());
+    } else {
+      const ArrayRef<Type> types = oldFunctionType.getResults();
+      newResults.assign(types.begin(), types.end());
+    }
+    func.setFunctionType(oldFunctionType.clone(newInputs, newResults));
+  }
+
+private:
+  /// Returns the type of argument `argIdx` (taken from `argTypes`) with its
+  /// shape replaced by `requestedShape`. Emits an error on `func` and fails
+  /// if the index is out of range, the argument is not a tensor, or the
+  /// requested shape conflicts with the argument's static dimensions.
+  static FailureOr<Type> getUpdatedTensorType(func::FuncOp func, size_t argIdx,
+                                              ArrayRef<Type> argTypes,
+                                              ArrayRef<int64_t> requestedShape) {
+    const size_t numInputs = argTypes.size();
+    if (argIdx >= numInputs)
+      return func.emitError()
+             << "provided arg index " << argIdx
+             << " is larger than number of inputs " << numInputs << ".";
+
+    auto tensorType = dyn_cast<TensorType>(argTypes[argIdx]);
+    if (!tensorType)
+      return func.emitError()
+             << "expected tensor type, got " << argTypes[argIdx];
+
+    // getShape() asserts on unranked tensors. An unranked tensor has no
+    // static dimensions to conflict with, so any requested shape is accepted.
+    if (tensorType.hasRank() &&
+        failed(verifyCompatibleShape(tensorType.getShape(), requestedShape)))
+      return func.emitError()
+             << "arg" << argIdx
+             << " has incompatible shape with requested input shape ("
+             << requestedShape << "), got " << tensorType;
+    return tensorType.cloneWith(requestedShape, tensorType.getElementType());
+  }
+
+  /// The `args` option parsed by initialize().
+  SmallVector<IdxAndShape> argsParsed;
+};
+
+} // namespace
+
+/// Creates the pass, seeding its `args` option with the given shape
+/// descriptions (e.g. "arg0:2x16"). The vector is taken by value and moved
+/// into the pass to avoid an extra copy.
+std::unique_ptr<Pass>
+mlir::tosa::createTosaInputShapePass(std::vector<std::string> args) {
+  return std::make_unique<TosaInputShape>(std::move(args));
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-input-shape.mlir b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
new file mode 100644
index 0000000000000..520a4575d3d60
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-experimental-input-shape="args=arg0:2x16,arg2:64x9" %s | FileCheck %s
+
+// Arguments returned directly: the requested static shapes must show up both
+// on the function arguments and on the return operand types.
+// CHECK-LABEL: test_empty_func
+func.func @test_empty_func(
+        // CHECK: %arg0: tensor<2x16xi32>
+        %arg0: tensor<2x?xi32>,
+        // CHECK: %arg1: tensor<?x256xf32>
+        %arg1: tensor<?x256xf32>,
+        // CHECK: %arg2: tensor<64x9xf32>
+        %arg2: tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+    // CHECK: %arg0, %arg1, %arg2 : tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>
+    return %arg0, %arg1, %arg2 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
+}
+
+// -----
+
+// The rewritten arguments flow into TOSA ops: the operand types of those ops
+// change, while their result types are not re-inferred by this pass.
+// CHECK-LABEL: test_func_with_ops
+func.func @test_func_with_ops(
+        // CHECK: %arg0: tensor<2x16xi32>
+        %arg0: tensor<2x?xi32>,
+        // CHECK: %arg1: tensor<?x256xf32>
+        %arg1: tensor<?x256xf32>,
+        // CHECK: %arg2: tensor<64x9xf32>
+        %arg2: tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+    // CHECK: %[[ADD:.*]] = tosa.add %arg0, %arg0 : (tensor<2x16xi32>, tensor<2x16xi32>)
+    %0 = tosa.add %arg0, %arg0 : (tensor<2x?xi32>, tensor<2x?xi32>) -> tensor<2x?xi32>
+    // CHECK: %[[RECIP:.*]] =  tosa.reciprocal %arg1 : (tensor<?x256xf32>)
+    %1 = tosa.reciprocal %arg1 : (tensor<?x256xf32>) -> tensor<?x256xf32>
+    // CHECK: %[[SUB:.*]] = tosa.sub %arg2, %arg2 : (tensor<64x9xf32>, tensor<64x9xf32>)
+    %2 = tosa.sub %arg2, %arg2 : (tensor<?x9xf32>, tensor<?x9xf32>) -> tensor<?x9xf32>
+    return %0, %1, %2 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
+}
+
+// -----
+
+// Rewritten arguments forwarded into tosa.cond_if: the cond_if operand types
+// change, while its region block arguments and results keep their original
+// dynamic types (only the function's entry block is updated).
+// CHECK-LABEL: test_controlflow
+func.func @test_controlflow(
+        // CHECK: %arg0: tensor<2x16xi32>
+        %arg0: tensor<2x?xi32>,
+        // CHECK: %arg1: tensor<?x256xf32>
+        %arg1: tensor<?x256xf32>,
+        // CHECK: %arg2: tensor<64x9xf32>
+        %arg2: tensor<?x9xf32>,
+        // CHECK: %arg3: tensor<i1>
+        %arg3: tensor<i1>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+    // CHECK: %[[IF:.*]]:3 = tosa.cond_if %arg3 (%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2) : tensor<i1> (tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+    %0:3 = tosa.cond_if %arg3 (%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2) : tensor<i1> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+        ^bb0(%arg4: tensor<2x?xi32>, %arg5: tensor<?x256xf32>, %arg6: tensor<?x9xf32>):
+            tosa.yield %arg4, %arg5, %arg6 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
+    } else {
+        ^bb0(%arg4: tensor<2x?xi32>, %arg5: tensor<?x256xf32>, %arg6: tensor<?x9xf32>):
+            tosa.yield %arg4, %arg5, %arg6 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
+    }
+    // CHECK: return %[[IF]]#0, %[[IF]]#1, %[[IF]]#2 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
+    return %0#0, %0#1, %0#2 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
+}
+
+// -----
+
+// The RUN line requests shapes for arg0 and arg2, but this function has a
+// single argument, so the arg2 request is rejected.
+func.func @test_wrong_number_input_args(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
+    // expected-error at -1 {{provided arg index 2 is larger than number of inputs 1}}
+    return %arg0 : tensor<2x?xf32>
+}
+
+// -----
+
+// The RUN line requests arg0 as 2x16, which conflicts with its static leading
+// dimension of 1.
+func.func @test_incompatible_input_shape(%arg0: tensor<1x?xf32>, %arg1: tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+    // expected-error at -1 {{arg0 has incompatible shape with requested input shape (2, 16), got 'tensor<1x?xf32>'}}
+    return %arg0 : tensor<1x?xf32>
+}

>From 645d15d00a14e8c75a5ce6482897e5dd7386a247 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 9 Dec 2025 17:46:30 +0000
Subject: [PATCH 2/3] address review comments

Change-Id: Ibd0d1d986655e23001e0d2c42773d6ffcc9123c5
---
 .../lib/Dialect/Tosa/Transforms/TosaInputShape.cpp | 14 ++++++--------
 mlir/test/Dialect/Tosa/tosa-input-shape.mlir       |  2 ++
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
index 1665f33f0d478..e733fb8d378f5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
@@ -134,8 +134,8 @@ struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
 
     // Update argument shapes in the entry block
     Block &entryBlock = func.getBody().front();
+    const SmallVector<Type> argTypes(entryBlock.getArgumentTypes());
     for (const auto &[argIdx, shape] : argsParsed) {
-      SmallVector<Type> argTypes(entryBlock.getArgumentTypes());
       FailureOr<Type> newTensorType =
           getUpdatedTensorType(argIdx, argTypes, shape);
       if (failed(newTensorType))
@@ -145,11 +145,9 @@ struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
     }
 
     // Get new func argument types
-    FunctionType oldFunctionType = func.getFunctionType();
-    ArrayRef<Type> oldInputTypes = oldFunctionType.getInputs();
-    const size_t numInputs = oldInputTypes.size();
+    const FunctionType oldFunctionType = func.getFunctionType();
+    const ArrayRef<Type> oldInputTypes = oldFunctionType.getInputs();
     SmallVector<Type> newInputs(oldInputTypes.begin(), oldInputTypes.end());
-    newInputs.reserve(numInputs);
     for (const auto &[argIdx, shape] : argsParsed) {
       FailureOr<Type> newTensorType =
           getUpdatedTensorType(argIdx, oldInputTypes, shape);
@@ -161,13 +159,13 @@ struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
 
     // Update function signature
     Block &lastBlock = func.getBody().back();
-    Operation *terminator = lastBlock.getTerminator();
+    const Operation *terminator = lastBlock.getTerminator();
     SmallVector<Type> newResults;
     if (auto returnOp = dyn_cast_or_null<func::ReturnOp>(terminator)) {
-      auto types = returnOp.getOperandTypes();
+      const auto types = returnOp.getOperandTypes();
       newResults.assign(types.begin(), types.end());
     } else {
-      auto types = oldFunctionType.getResults();
+      const auto types = oldFunctionType.getResults();
       newResults.assign(types.begin(), types.end());
     }
     const FunctionType newFunctionType =
diff --git a/mlir/test/Dialect/Tosa/tosa-input-shape.mlir b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
index 520a4575d3d60..ee8e7aac609d5 100644
--- a/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
@@ -45,9 +45,11 @@ func.func @test_controlflow(
         %arg3: tensor<i1>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
     // CHECK: %[[IF:.*]]:3 = tosa.cond_if %arg3 (%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2) : tensor<i1> (tensor<2x16xi32>, tensor<?x256xf32>, tensor<64x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
     %0:3 = tosa.cond_if %arg3 (%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2) : tensor<i1> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>) {
+        // CHECK: ^bb0(%arg4: tensor<2x?xi32>, %arg5: tensor<?x256xf32>, %arg6: tensor<?x9xf32>):
         ^bb0(%arg4: tensor<2x?xi32>, %arg5: tensor<?x256xf32>, %arg6: tensor<?x9xf32>):
             tosa.yield %arg4, %arg5, %arg6 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
     } else {
+        // CHECK: ^bb0(%arg4: tensor<2x?xi32>, %arg5: tensor<?x256xf32>, %arg6: tensor<?x9xf32>):
         ^bb0(%arg4: tensor<2x?xi32>, %arg5: tensor<?x256xf32>, %arg6: tensor<?x9xf32>):
             tosa.yield %arg4, %arg5, %arg6 : tensor<2x?xi32>, tensor<?x256xf32>, tensor<?x9xf32>
     }

>From 0ff796232f4132d51501023649e188088d8d34c3 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 10 Dec 2025 14:55:36 +0000
Subject: [PATCH 3/3] remove tosa/func from "dependentDialects"

Change-Id: I8c40de763485ed8bb86c78ebc2fbbb836f7e895c
---
 mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 58ea47db731a6..cac90c4473a8e 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -208,9 +208,7 @@ def TosaInputShape : Pass<"tosa-experimental-input-shape", "func::FuncOp"> {
 
   let constructor = "tosa::createTosaInputShapePass()";
   let dependentDialects = [
-    "func::FuncDialect",
-    "tensor::TensorDialect",
-    "tosa::TosaDialect",
+    "tensor::TensorDialect"
   ];
   let options = [
     ListOption<"args", "args", "std::string",



More information about the Mlir-commits mailing list