[Mlir-commits] [mlir] Change 758703 (PR #145075)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 20 10:11:55 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Won Jong Jeon (wonjeon)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/145075.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (+2) 
- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+21) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp (+175) 
- (added) mlir/test/Dialect/Tosa/tosa-input-shape.mlir (+15) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 306e4b1f218e7..33fde665e8108 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -42,6 +42,8 @@ void populateTosaConstantReduction(MLIRContext *ctx,
 void populateTosaTypeConversion(TypeConverter &converter);
 
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+/// Creates the experimental TOSA input-shape pass. Each element of `args` is a
+/// shape description such as "arg0:5x5" (see the `args` option in Passes.td).
+std::unique_ptr<Pass> createTosaInputShapePass(
+    std::vector<std::string> args = std::vector<std::string>());
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index d005a4cc6859c..dd5c11f6aac7e 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -127,4 +127,25 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
   }];
 }
 
+// Experimental pass: rewrites the types of selected function arguments to
+// user-provided static shapes.
+def TosaInputShape : Pass<"experimental-tosa-input-shape", "func::FuncOp"> {
+  let summary = "Override dynamic input shapes of function arguments to specified static shapes.";
+  let description = [{
+    Pass that overrides the dynamic input shapes of function arguments to specified static shapes.
+    It is an error if a specified static shape conflicts with the rank or the static dimensions
+    of an original input shape.
+
+    The same overrides are applied to every `func.func` the pass runs on. After the
+    arguments are updated, the function type is refreshed: its inputs match the new
+    argument types and its results match the operand types of the `func.return` in the body.
+  }];
+
+  let constructor = "tosa::createTosaInputShapePass()";
+  let dependentDialects = [
+    "func::FuncDialect",
+    "tensor::TensorDialect",
+    "tosa::TosaDialect",
+  ];
+  let options = [
+    ListOption<"args", "args", "std::string",
+          "Comma-separated list of shape descriptions. Each description contains "
+          "`arg` followed by the zero-based argument index, a colon, and a shape "
+          "with dimensions separated by x (e.g. arg0:5x5,arg3:2x64)">,
+  ];
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index bbf079faea3d0..d9458886c0f95 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaTypeConverters.cpp
   TosaProfileCompliance.cpp
   TosaValidation.cpp
+  TosaInputShape.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
new file mode 100644
index 0000000000000..12b6445893f5a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp
@@ -0,0 +1,175 @@
+//===- TosaInputShape.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Change input shape of function argument to specified shape.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAINPUTSHAPE
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Parses the `args` pass option into (argument index, static shape) pairs.
+///
+/// Each entry must have the form `arg<N>:<d0>x<d1>x...` where `<N>` and every
+/// `<di>` are non-empty runs of decimal digits. Signs, whitespace, empty
+/// dimensions, a trailing `x`, and values that overflow int64_t are rejected.
+///
+/// Returns the parsed entries and an empty string on success. On the first
+/// malformed entry, returns an empty vector and an error message. Examples:
+///   {"arg0:5x10", "arg8:3x9"} => {{{0, {5, 10}}, {8, {3, 9}}}, ""}
+///   {"arg0:"}                 => {{}, "Invalid input shape description"}
+std::pair<std::vector<std::pair<size_t, std::vector<int64_t>>>, std::string>
+parse_input_shapes(const std::vector<std::string> &args) {
+  // Parses a non-empty run of decimal digits of `str` starting at `pos` into
+  // `value`, advancing `pos` past them. Returns false if there are no digits
+  // or the value does not fit in int64_t. Unlike strtoul, this never accepts
+  // leading whitespace or a sign, and never wraps silently.
+  auto parseNumber = [](const std::string &str, size_t &pos, int64_t &value) {
+    const size_t start = pos;
+    value = 0;
+    while (pos < str.size() && str[pos] >= '0' && str[pos] <= '9') {
+      const int64_t digit = str[pos] - '0';
+      if (value > (std::numeric_limits<int64_t>::max() - digit) / 10)
+        return false;
+      value = value * 10 + digit;
+      ++pos;
+    }
+    return pos != start;
+  };
+
+  std::vector<std::pair<size_t, std::vector<int64_t>>> shapes;
+  shapes.reserve(args.size());
+
+  for (const std::string &arg : args) {
+    if (arg.compare(0, 3, "arg") != 0)
+      return {{}, "Arguments must start with 'arg'"};
+
+    size_t pos = 3;
+    int64_t argnum = 0;
+    if (!parseNumber(arg, pos, argnum) || pos >= arg.size() || arg[pos] != ':')
+      return {{}, "Invalid argument name"};
+    ++pos; // Skip ':'.
+
+    // At least one dimension is required; dimensions are separated by 'x' and
+    // the description must end right after the last dimension.
+    std::vector<int64_t> shape;
+    while (true) {
+      int64_t dim = 0;
+      if (!parseNumber(arg, pos, dim))
+        return {{}, "Invalid input shape description"};
+      shape.push_back(dim);
+      if (pos == arg.size())
+        break;
+      if (arg[pos] != 'x')
+        return {{}, "Invalid input shape description"};
+      ++pos; // Skip 'x'.
+    }
+    shapes.emplace_back(static_cast<size_t>(argnum), std::move(shape));
+  }
+  return {std::move(shapes), ""};
+}
+
+/// Pass that overrides the types of selected function arguments with the
+/// static shapes given in the `args` option, then refreshes the function type
+/// so that it matches the entry block arguments and the `func.return`
+/// operands.
+///
+/// All requests are validated before the IR is touched, so a bad request
+/// reports an error, fails the pass and leaves the function unmodified.
+struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
+  TosaInputShape() = default;
+  explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
+    this->args = args;
+  }
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    auto [argsParsed, argsParseErr] = parse_input_shapes(args);
+    if (!argsParseErr.empty()) {
+      func.emitError() << argsParseErr;
+      return signalPassFailure();
+    }
+
+    // Nothing requested, or a declaration without a body to rewrite.
+    Region &body = func.getBody();
+    if (argsParsed.empty() || body.empty())
+      return;
+
+    // Only the entry block's arguments are the function arguments; arguments
+    // of other blocks must keep their types.
+    Block &entry = body.front();
+
+    // Compute the new argument types on a copy. Requests are applied in order,
+    // so a repeated argN is checked against the shape set by the earlier one.
+    SmallVector<Type> argTypes = llvm::to_vector(entry.getArgumentTypes());
+    for (const auto &[argnum, shape] : argsParsed) {
+      if (argnum >= argTypes.size()) {
+        func.emitError() << "arg" << argnum << " doesn't exist.";
+        return signalPassFailure();
+      }
+      auto tensorType = dyn_cast<TensorType>(argTypes[argnum]);
+      if (!tensorType) {
+        func.emitError() << "arg" << argnum << " is not a tensor.";
+        return signalPassFailure();
+      }
+      // An unranked tensor accepts any shape and its shape cannot be queried;
+      // a ranked one must agree in rank and in every static dimension.
+      if (tensorType.hasRank() &&
+          failed(verifyCompatibleShape(tensorType.getShape(), shape))) {
+        func.emitError() << "arg" << argnum
+                         << " has incompatible shape with input shape.";
+        return signalPassFailure();
+      }
+      argTypes[argnum] = tensorType.cloneWith(ArrayRef<int64_t>(shape),
+                                              tensorType.getElementType());
+    }
+
+    // Every request is valid: retype the entry block arguments.
+    for (unsigned i = 0, e = entry.getNumArguments(); i < e; ++i)
+      entry.getArgument(i).setType(argTypes[i]);
+
+    // Refresh the function type. Results follow the operand types of the
+    // func.return ops (the last one found wins if there are several), so a
+    // directly returned argument propagates its new type to the signature.
+    FunctionType oldType = func.getFunctionType();
+    SmallVector<Type> resultTypes = llvm::to_vector(oldType.getResults());
+    for (Block &block : body) {
+      if (block.empty())
+        continue;
+      if (auto returnOp = dyn_cast<func::ReturnOp>(block.back()))
+        resultTypes = llvm::to_vector(returnOp.getOperandTypes());
+    }
+    func.setFunctionType(oldType.clone(argTypes, resultTypes));
+  }
+};
+
+} // namespace
+
+/// Creates the TosaInputShape pass configured with the given shape
+/// descriptions (e.g. {"arg0:5x5", "arg3:2x64"}).
+std::unique_ptr<Pass>
+mlir::tosa::createTosaInputShapePass(std::vector<std::string> args) {
+  // `args` is a by-value sink parameter; move it rather than copying it again.
+  return std::make_unique<TosaInputShape>(std::move(args));
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-input-shape.mlir b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
new file mode 100644
index 0000000000000..2a784aa3d33cb
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt --split-input-file --experimental-tosa-input-shape="args=arg0:2x16,arg3:64x9" %s | FileCheck %s
+
+// arg0 and arg3 are overridden with the static shapes given on the RUN line.
+// arg1 and arg2 are not listed and must keep their original types. Because
+// the function returns arg0 and arg3 directly, the result types in the
+// signature and on func.return must follow the new argument types.
+func.func @test_input_shape(
+    // CHECK: %arg0: tensor<2x16xi32>
+    %arg0: tensor<2x?xi32>,
+    // CHECK: %arg1: tensor<?x256xf32>
+    %arg1: tensor<?x256xf32>,
+    // CHECK: %arg2: tensor<2x?xi32>
+    %arg2: tensor<2x?xi32>,
+    // CHECK: %arg3: tensor<64x9xf32>
+    %arg3: tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x9xf32>) {
+
+    // CHECK: %arg0, %arg3 : tensor<2x16xi32>, tensor<64x9xf32>
+    return %arg0, %arg3 : tensor<2x?xi32>, tensor<?x9xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/145075


More information about the Mlir-commits mailing list