[Mlir-commits] [mlir] Change 758703 (PR #145075)
Won Jong Jeon
llvmlistbot at llvm.org
Fri Jun 20 10:15:19 PDT 2025
https://github.com/wonjeon updated https://github.com/llvm/llvm-project/pull/145075
>From 110e1f49ea880fc37386d4748717d4873000daa2 Mon Sep 17 00:00:00 2001
From: Kaushik Varadharajan <kaushik.varadharajan at arm.com>
Date: Thu, 20 Jun 2024 23:16:45 +0000
Subject: [PATCH] [mlir][tosa] TosaInputShape supports functions with multiple
arguments
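The new command-line syntax is
--experimental-tosa-input-shape="args=arg0:5x5,arg8:2x9", i.e. a
comma-separated list of argN:<dims> entries, where <dims> are the static
dimensions of function argument N separated by 'x'.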
Signed-off-by: Kaushik Varadharajan <kaushik.varadharajan at arm.com>
Change-Id: I393d51a89a9017212437bda40a0100c881198777
---
.../mlir/Dialect/Tosa/Transforms/Passes.h | 2 +
.../mlir/Dialect/Tosa/Transforms/Passes.td | 21 +++
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 +
.../Tosa/Transforms/TosaInputShape.cpp | 175 ++++++++++++++++++
mlir/test/Dialect/Tosa/tosa-input-shape.mlir | 15 ++
5 files changed, 214 insertions(+)
create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
create mode 100644 mlir/test/Dialect/Tosa/tosa-input-shape.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 306e4b1f218e7..33fde665e8108 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -42,6 +42,8 @@ void populateTosaConstantReduction(MLIRContext *ctx,
void populateTosaTypeConversion(TypeConverter &converter);
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+/// Creates the experimental-tosa-input-shape pass. Each entry of `args` is a
+/// shape description such as "arg0:5x5": the argument name, a colon, and the
+/// static dimensions separated by 'x'.
+std::unique_ptr<Pass>
+createTosaInputShapePass(std::vector<std::string> args = {});
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index d005a4cc6859c..dd5c11f6aac7e 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -127,4 +127,25 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
}];
}
+// NOTE(review): this pass is anchored on func::FuncOp, so the same `args` list
+// is applied to every function it runs on; an `argN` entry beyond a function's
+// argument count makes the pass emit an error for that function.
+def TosaInputShape : Pass<"experimental-tosa-input-shape", "func::FuncOp"> {
+ let summary = "Override dynamic input shapes of function arguments to specified static shapes.";
+ let description = [{
+ Pass that overrides the dynamic input shapes of function arguments to specified static shapes.
+ It is an error if a specified static shape conflicts with the static dimensions in an original input shape.
+ }];
+
+ let constructor = "tosa::createTosaInputShapePass()";
+ // NOTE(review): the pass appears to only rewrite types and create no ops;
+ // confirm these dialects (tensor in particular) really need to be loaded.
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tensor::TensorDialect",
+ "tosa::TosaDialect",
+ ];
+ // `args` entries have the form argN:D0xD1x..., e.g. "arg0:5x5".
+ let options = [
+ ListOption<"args", "args", "std::string",
+ "Comma-separated list of shape descriptions. Each description contains the "
+ "argument name, a colon, and a shape with dimensions separated by x "
+ "(e.g. arg0:5x5,arg3:2x64)">,
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index bbf079faea3d0..d9458886c0f95 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaTypeConverters.cpp
TosaProfileCompliance.cpp
TosaValidation.cpp
+ TosaInputShape.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
new file mode 100644
index 0000000000000..12b6445893f5a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
@@ -0,0 +1,175 @@
+//===- TosaInputShape.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Change input shape of function argument to specified shape.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include <cstdint>
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAINPUTSHAPE
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
/// Parses the entries of the `args` pass option into
/// (argument index, static shape) pairs.
///
/// Each entry must have the form `arg<N>:<D0>x<D1>x...`, where `<N>` and every
/// `<Di>` are non-empty runs of decimal digits. For example:
///   {"arg0:5x10", "arg8:3x9"} => {{{0, {5, 10}}, {8, {3, 9}}}, ""}
///   {"arg0:"}                 => {{}, "Invalid input shape description"}
///
/// Returns the parsed entries and an empty string on success, or an empty
/// vector and a non-empty error message for the first malformed entry.
std::pair<std::vector<std::pair<size_t, std::vector<int64_t>>>, std::string>
parse_input_shapes(const std::vector<std::string> &args) {
  // Parses the run of decimal digits starting at `pos` into `value`, advancing
  // `pos` past it. Fails on an empty run or when the value exceeds `limit`.
  // std::strtoul is deliberately not used: it skips leading whitespace,
  // accepts a sign ("-1" wraps around) and saturates silently on overflow.
  auto parseNumber = [](const std::string &text, size_t &pos, uint64_t limit,
                        uint64_t &value) -> bool {
    const size_t start = pos;
    value = 0;
    while (pos < text.size() && text[pos] >= '0' && text[pos] <= '9') {
      const uint64_t digit = static_cast<uint64_t>(text[pos] - '0');
      if (value > (limit - digit) / 10)
        return false; // value * 10 + digit would exceed `limit`
      value = value * 10 + digit;
      ++pos;
    }
    return pos != start;
  };

  std::vector<std::pair<size_t, std::vector<int64_t>>> shapes;
  shapes.reserve(args.size());

  for (const std::string &arg : args) {
    if (arg.compare(0, 3, "arg") != 0)
      return {{}, "Arguments must start with 'arg'"};

    // Argument index: at least one digit, immediately followed by ':'.
    size_t pos = 3;
    uint64_t argnum = 0;
    if (!parseNumber(arg, pos, std::numeric_limits<size_t>::max(), argnum) ||
        pos >= arg.size() || arg[pos] != ':')
      return {{}, "Invalid argument name"};
    ++pos; // skip ':'

    // Shape: one or more dimensions separated by single 'x' characters. An
    // empty shape or a dangling 'x' is rejected.
    std::vector<int64_t> dims;
    while (true) {
      uint64_t dim = 0;
      if (!parseNumber(arg, pos, std::numeric_limits<int64_t>::max(), dim))
        return {{}, "Invalid input shape description"};
      dims.push_back(static_cast<int64_t>(dim));
      if (pos == arg.size())
        break;
      if (arg[pos] != 'x')
        return {{}, "Invalid input shape description"};
      ++pos; // skip 'x'
    }
    shapes.emplace_back(static_cast<size_t>(argnum), std::move(dims));
  }
  return {std::move(shapes), ""};
}
+
+/// Pass that change function input shapes to specified static input shapes
+struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
+public:
+ TosaInputShape() = default;
+ explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
+ this->args = args;
+ }
+ void runOnOperation() override {
+ func::FuncOp func = getOperation();
+ auto [args_parsed, args_parse_err] = parse_input_shapes(args);
+
+ if (!args_parse_err.empty()) {
+ func.emitError() << args_parse_err;
+ return;
+ }
+
+ for (auto &block : func.getBody()) {
+
+ for (auto [argnum, shape] : args_parsed) {
+ if (argnum >= block.getNumArguments()) {
+ func.emitError() << "arg" << argnum << " doesn't exist.";
+ return;
+ }
+ BlockArgument block_arg = block.getArgument(argnum);
+ Type arg_type = block_arg.getType();
+ TensorType tensor_type = cast<TensorType>(arg_type);
+ if (failed(mlir::verifyCompatibleShape(tensor_type.getShape(), shape))) {
+ func->emitError()
+ << "arg" << argnum << " has incompatible shape with input shape.";
+ return;
+ }
+ SmallVector<int64_t> new_shape(shape.begin(), shape.end());
+ auto new_tensor_type =
+ tensor_type.cloneWith(new_shape, tensor_type.getElementType());
+ block_arg.setType(new_tensor_type);
+ }
+
+ bool found_func_op = false;
+
+ for (Operation &op : block) {
+ // Update result shape for func.func
+ func::FuncOp funcOp = mlir::dyn_cast<func::FuncOp>(op.getParentOp());
+ if (funcOp && !found_func_op) {
+ FunctionType old_function_type = funcOp.getFunctionType();
+ std::vector<Type> inputs = old_function_type.getInputs();
+
+ for (auto [argnum, shape] : args_parsed) {
+ if ((size_t)argnum >= inputs.size()) {
+ func.emitError() << "arg" << argnum << " doesn't exist.";
+ return;
+ }
+ auto tensor_type = cast<TensorType>(inputs[argnum]);
+
+ if (failed(mlir::verifyCompatibleShape(tensor_type.getShape(), shape))) {
+ funcOp->emitError()
+ << "arg" << argnum
+ << " has incompatible shape with input shape.";
+ return;
+ }
+ SmallVector<int64_t> new_shape(shape.begin(), shape.end());
+ auto new_tensor_type =
+ tensor_type.cloneWith(new_shape, tensor_type.getElementType());
+ inputs[argnum] = cast<Type>(new_tensor_type);
+ }
+
+ FunctionType new_function_type = old_function_type.clone(
+ TypeRange{ArrayRef(inputs)},
+ TypeRange{old_function_type.getResults()});
+ funcOp.setFunctionType(new_function_type);
+ found_func_op = true;
+ }
+ // Update result shape of func.return
+ func::ReturnOp returnOp = mlir::dyn_cast<func::ReturnOp>(op);
+ if (returnOp) {
+ func::FuncOp funcOp = dyn_cast<func::FuncOp>(op.getParentOp());
+ if (funcOp) {
+ FunctionType old_function_type = funcOp.getFunctionType();
+ FunctionType new_function_type = old_function_type.clone(
+ TypeRange{old_function_type.getInputs()},
+ returnOp.getOperandTypes());
+ funcOp.setFunctionType(new_function_type);
+ }
+ }
+ }
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Pass>
+mlir::tosa::createTosaInputShapePass(std::vector<std::string> args) {
+ return std::make_unique<TosaInputShape>(args);
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-input-shape.mlir b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
new file mode 100644
index 0000000000000..2a784aa3d33cb
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt --split-input-file --experimental-tosa-input-shape="args=arg0:2x16,arg3:64x9" %s | FileCheck %s
+
+// Only arg0 and arg3 are listed in `args`, so arg1 and arg2 keep their
+// original types. The return forwards %arg0 and %arg3 directly, so its operand
+// types reflect the overridden shapes.
+// NOTE(review): the rewritten function result types are not verified by any
+// check line here.
+func.func @test_input_shape(
+ // CHECK: %arg0: tensor<2x16xi32>
+ %arg0: tensor<2x?xi32>,
+ // CHECK: %arg1: tensor<?x256xf32>
+ %arg1: tensor<?x256xf32>,
+ // CHECK: %arg2: tensor<2x?xi32>
+ %arg2: tensor<2x?xi32>,
+ // CHECK: %arg3: tensor<64x9xf32>
+ %arg3: tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x9xf32>) {
+
+ // CHECK: %arg0, %arg3 : tensor<2x16xi32>, tensor<64x9xf32>
+ return %arg0, %arg3 : tensor<2x?xi32>, tensor<?x9xf32>
+}
More information about the Mlir-commits
mailing list