[Mlir-commits] [mlir] [mlir][tosa] TosaInputShape supports functions with multiple arguments (PR #145075)

Luke Hutton llvmlistbot at llvm.org
Tue Jun 24 02:27:40 PDT 2025


================
@@ -0,0 +1,177 @@
+//===- TosaInputShape.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Change input shape of function argument to specified shape.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAINPUTSHAPE
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+std::pair<std::vector<std::pair<size_t, std::vector<int64_t>>>, std::string>
+parse_input_shapes(std::vector<std::string> args) {
+  /**
+   * This function returns two values: a vector of parsed arguments, and an
+   * optional error message. Each parsed argument consists of its argument
+   * number and its shape. For example:
+   * "args=arg0:5x10,arg8:3x9" => {{{0, {5, 10}}, {8, {3, 9}}}, ""}
+   * "args=arg0:" => {{}, "error message"}
+   */
+
+  std::vector<std::pair<size_t, std::vector<int64_t>>> shapes;
+
+  for (std::string arg : args) {
+    if (arg.substr(0, 3) != "arg") {
+      return {{}, "Arguments must start with 'arg'"};
+    }
+
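+    // Parse the decimal index that follows the "arg" prefix; strtoul leaves
+    // endptr at the first unparsed character, which must be the ':' separator.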
+    char *endptr;
+    size_t argnum = std::strtoul(&arg[3], &endptr, /*base=*/10);
+    if (*endptr != ':') {
+      return {{}, "Invalid argument name"};
+    }
+    std::string shape_str = endptr + 1;
+
+    std::vector<int64_t> curr;
+    // Reject an empty shape such as "arg0:", as promised in the comment above.
+    if (shape_str.empty()) {
+      return {{}, "Invalid input shape description"};
+    }
+    while (!shape_str.empty()) {
+      size_t dim = std::strtoul(shape_str.data(), &endptr, /*base=*/10);
+      if ((*endptr != '\0' && *endptr != 'x') || shape_str == endptr) {
+        return {{}, "Invalid input shape description"};
+      }
+      curr.push_back(dim);
+      if (*endptr == '\0') {
+        break;
+      }
+      shape_str = endptr + 1;
+    }
+    shapes.push_back({argnum, curr});
+  }
+  return {shapes, ""};
+}
+
+/// Pass that changes function input shapes to the specified static shapes.
+struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
+public:
+  TosaInputShape() = default;
+  explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
+    this->args = args;
+  }
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    auto [args_parsed, args_parse_err] = parse_input_shapes(args);
+
+    if (!args_parse_err.empty()) {
+      func.emitError() << args_parse_err;
+      return;
+    }
+
+    for (auto &block : func.getBody()) {
+
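+      // First, rewrite each listed block argument to its requested static
+      // shape.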
+      for (auto [argnum, shape] : args_parsed) {
+        if (argnum >= block.getNumArguments()) {
+          func.emitError() << "arg" << argnum << " doesn't exist.";
+          return;
+        }
+        BlockArgument block_arg = block.getArgument(argnum);
+        Type arg_type = block_arg.getType();
+        TensorType tensor_type = dyn_cast<TensorType>(arg_type);
+        if (!tensor_type) {
+          func.emitError() << "arg" << argnum << " must be a tensor type.";
+          return;
+        }
+        if (failed(
+                mlir::verifyCompatibleShape(tensor_type.getShape(), shape))) {
+          func->emitError()
+              << "arg" << argnum << " has incompatible shape with input shape.";
+          return;
+        }
+        SmallVector<int64_t> new_shape(shape.begin(), shape.end());
+        auto new_tensor_type =
+            tensor_type.cloneWith(new_shape, tensor_type.getElementType());
+        block_arg.setType(new_tensor_type);
+      }
+
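+      // Next, update the type of the enclosing func.func to match the new
+      // argument types; the guard below ensures this runs only once per block.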
+      bool found_func_op = false;
+
+      for (Operation &op : block) {
+        // Update result shape for func.func
+        func::FuncOp funcOp = mlir::dyn_cast<func::FuncOp>(op.getParentOp());
+        if (funcOp && !found_func_op) {
+          FunctionType old_function_type = funcOp.getFunctionType();
+          std::vector<Type> inputs = old_function_type.getInputs();
+
+          for (auto [argnum, shape] : args_parsed) {
+            if (argnum >= inputs.size()) {
+              func.emitError() << "arg" << argnum << " doesn't exist.";
+              return;
+            }
+            auto tensor_type = cast<TensorType>(inputs[argnum]);
+
+            if (failed(mlir::verifyCompatibleShape(tensor_type.getShape(),
+                                                   shape))) {
+              funcOp->emitError()
+                  << "arg" << argnum
+                  << " has incompatible shape with input shape.";
+              return;
+            }
+            SmallVector<int64_t> new_shape(shape.begin(), shape.end());
+            auto new_tensor_type =
+                tensor_type.cloneWith(new_shape, tensor_type.getElementType());
+            inputs[argnum] = new_tensor_type;
+          }
+
+          FunctionType new_function_type = old_function_type.clone(
+              TypeRange{ArrayRef(inputs)},
+              TypeRange{old_function_type.getResults()});
+          funcOp.setFunctionType(new_function_type);
+          found_func_op = true;
+        }
+        // Update result shape of func.return
+        func::ReturnOp returnOp = mlir::dyn_cast<func::ReturnOp>(op);
+        if (returnOp) {
+          func::FuncOp funcOp = dyn_cast<func::FuncOp>(op.getParentOp());
+          if (funcOp) {
----------------
lhutton1 wrote:

nit: could we remove this additional layer of nesting e.g. `if (returnOp && funcOp)`?
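
Untested sketch of what I had in mind; the `funcOp` declared earlier in this loop body would need a different name here, so this uses `parentFunc`:

```cpp
func::ReturnOp returnOp = mlir::dyn_cast<func::ReturnOp>(op);
func::FuncOp parentFunc =
    returnOp ? dyn_cast<func::FuncOp>(op.getParentOp()) : nullptr;
if (returnOp && parentFunc) {
  // ... existing body, with one level of nesting removed ...
}
```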

https://github.com/llvm/llvm-project/pull/145075

