[Mlir-commits] [mlir] [mlir][func]: Introduce ReplaceFuncSignature transform operation (PR #143381)

Aviad Cohen llvmlistbot at llvm.org
Sat Jun 14 05:25:55 PDT 2025


https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/143381

>From 0677cee0e996ce3e7eadb8507d0af8c54498d289 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 8 Jun 2025 11:17:13 +0300
Subject: [PATCH] [mlir][func]: Introduce ReplaceFuncSignature transform
 operation

This transform takes a module and a function name, and replaces the
signature of the function by reordering the arguments and results
according to the interchange arrays. The function is expected to be
defined in the module, and the interchange arrays must match the number
of arguments and results of the function.
---
 .../Func/TransformOps/FuncTransformOps.h      |   2 +-
 .../Func/TransformOps/FuncTransformOps.td     |  36 +++++
 mlir/include/mlir/Dialect/Func/Utils/Utils.h  |  49 +++++++
 mlir/lib/Dialect/Func/CMakeLists.txt          |   1 +
 .../Func/TransformOps/FuncTransformOps.cpp    | 107 +++++++++++++-
 mlir/lib/Dialect/Func/Utils/CMakeLists.txt    |  13 ++
 mlir/lib/Dialect/Func/Utils/Utils.cpp         | 121 ++++++++++++++++
 .../Dialect/Func/func-transform-invalid.mlir  |  87 ++++++++++++
 mlir/test/Dialect/Func/func-transform.mlir    | 132 ++++++++++++++++++
 9 files changed, 546 insertions(+), 2 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Func/Utils/Utils.h
 create mode 100644 mlir/lib/Dialect/Func/Utils/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Func/Utils/Utils.cpp
 create mode 100644 mlir/test/Dialect/Func/func-transform-invalid.mlir

diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h
index 37f0ea0f28552..15c36429263cb 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h
@@ -1,4 +1,4 @@
-//===- FuncTransformOps.h - CF transformation ops --------*- C++ -*-===//
+//===- FuncTransformOps.h - Function transformation ops --------*- C++ -*--===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 306fbf881de61..4062f310c6521 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -98,4 +98,40 @@ def CastAndCallOp : Op<Transform_Dialect,
   let hasVerifier = 1;
 }
 
+def ReplaceFuncSignatureOp
+    : Op<Transform_Dialect, "func.replace_func_signature",
+         [DeclareOpInterfaceMethods<TransformOpInterface>,
+          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+      This transform takes a module and a function name, and replaces the
+      signature of the function by permuting its arguments and results: the
+      i-th argument (result) of the new function is the `args_interchange[i]`-th
+      (`results_interchange[i]`-th) one of the original function. The `module`
+      handle is consumed; handles to the module and new function are returned.
+
+      If the `adjust_func_calls` unit attribute is set, every `func.call` to
+      the function within the module is rewritten to match the new signature;
+      otherwise call sites are left untouched.
+
+      This transform will emit a silenceable failure if:
+       - The target is not a module or does not contain the named function.
+       - The interchange arrays do not match the number of arguments/results.
+       - The interchange arrays contain duplicate or out of bound indices.
+       - The function body does not consist of exactly one block.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$module,
+      SymbolRefAttr:$function_name, DenseI32ArrayAttr:$args_interchange,
+      DenseI32ArrayAttr:$results_interchange, UnitAttr:$adjust_func_calls);
+  let results = (outs TransformHandleTypeInterface:$transformed_module,
+                      TransformHandleTypeInterface:$transformed_function);
+
+  let assemblyFormat = [{
+    $function_name
+    `args_interchange` `=` $args_interchange
+    `results_interchange` `=` $results_interchange
+    `at` $module attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
new file mode 100644
index 0000000000000..f70cd986f26b3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -0,0 +1,49 @@
+//===- Utils.h - General Func transformation utilities ----*- C++ -*-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various transformation utilities for
+// the Func dialect. These are not passes by themselves but are used
+// either by passes, optimization sequences, or in turn by other transformation
+// utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_UTILS_H
+#define MLIR_DIALECT_FUNC_UTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+namespace func {
+
+class FuncOp;
+class CallOp;
+
+/// Replaces `funcOp` with a new function of the same name whose arguments and
+/// results are permuted: new argument (result) `i` is original argument
+/// (result) `newArgsOrder[i]` (`newResultsOrder[i]`). Both orders must be
+/// permutations. The body is cloned and `funcOp` is erased. Returns the new
+/// function, or failure (leaving `funcOp` untouched) if `funcOp` does not
+/// have exactly one block.
+FailureOr<func::FuncOp>
+replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
+                        llvm::ArrayRef<unsigned> newArgsOrder,
+                        llvm::ArrayRef<unsigned> newResultsOrder);
+
+/// Replaces `callOp` with a call to the same callee whose operands and results
+/// are permuted as in `replaceFuncWithNewOrder`. Uses of the old results are
+/// redirected to the new call, `callOp` is erased and the new call returned.
+CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
+                                 llvm::ArrayRef<unsigned> newArgsOrder,
+                                 llvm::ArrayRef<unsigned> newResultsOrder);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_UTILS_H
diff --git a/mlir/lib/Dialect/Func/CMakeLists.txt b/mlir/lib/Dialect/Func/CMakeLists.txt
index ec999ffdb99da..a834aae8fbf81 100644
--- a/mlir/lib/Dialect/Func/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/CMakeLists.txt
@@ -2,3 +2,4 @@ add_subdirectory(Extensions)
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(TransformOps)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 9966d7339e1b4..3adbf092742be 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -1,4 +1,4 @@
-//===- FuncTransformOps.cpp - Implementation of CF transform ops ---===//
+//===- FuncTransformOps.cpp - Implementation of Func transform ops --------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -11,10 +11,12 @@
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -226,6 +228,109 @@ void transform::CastAndCallOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// ReplaceFuncSignatureOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
+                                         transform::TransformResults &results,
+                                         transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitDefiniteFailure() << "requires a single module to operate on";
+
+  auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+  if (!targetModuleOp)
+    return emitSilenceableFailure(getLoc())
+           << "target is expected to be module operation";
+
+  func::FuncOp funcOp =
+      targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+  if (!funcOp)
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName() << "' not found";
+
+  // Each interchange must provide exactly one entry per argument/result.
+  unsigned numArgs = funcOp.getNumArguments();
+  unsigned numResults = funcOp.getNumResults();
+  if (numArgs != getArgsInterchange().size())
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName() << "' has " << numArgs
+           << " arguments, but " << getArgsInterchange().size()
+           << " args interchange were given";
+  if (numResults != getResultsInterchange().size())
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName() << "' has "
+           << numResults << " results, but " << getResultsInterchange().size()
+           << " results interchange were given";
+
+  // Check that the interchanges are unique; with the size and bounds checks
+  // this guarantees both are permutations, as the func utilities expect.
+  SetVector<unsigned> argsInterchange, resultsInterchange;
+  argsInterchange.insert_range(getArgsInterchange());
+  resultsInterchange.insert_range(getResultsInterchange());
+  if (argsInterchange.size() != getArgsInterchange().size())
+    return emitSilenceableFailure(getLoc())
+           << "args interchange must be unique";
+  if (resultsInterchange.size() != getResultsInterchange().size())
+    return emitSilenceableFailure(getLoc())
+           << "results interchange must be unique";
+
+  // Check bounds. Negative entries became large unsigned values and fail here.
+  for (unsigned index : argsInterchange) {
+    if (index >= numArgs) {
+      return emitSilenceableFailure(getLoc())
+             << "args interchange index " << index
+             << " is out of bounds for function with name '"
+             << getFunctionName() << "' with " << numArgs << " arguments";
+    }
+  }
+  for (unsigned index : resultsInterchange) {
+    if (index >= numResults) {
+      return emitSilenceableFailure(getLoc())
+             << "results interchange index " << index
+             << " is out of bounds for function with name '"
+             << getFunctionName() << "' with " << numResults << " results";
+    }
+  }
+
+  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
+      rewriter, funcOp, argsInterchange.getArrayRef(),
+      resultsInterchange.getArrayRef());
+  if (failed(newFuncOpOrFailure))
+    return emitSilenceableFailure(getLoc())
+           << "failed to replace function signature '" << getFunctionName()
+           << "' with new order";
+
+  // Rewrite calls that resolve to the new function. Resolving through the
+  // symbol tables (not by name) skips same-named callees in nested tables.
+  if (getAdjustFuncCalls()) {
+    SymbolTableCollection symTables;
+    SmallVector<func::CallOp> callOps;
+    targetModuleOp.walk([&](func::CallOp callOp) {
+      if (symTables.lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr()) ==
+          newFuncOpOrFailure->getOperation())
+        callOps.push_back(callOp);
+    });
+    for (func::CallOp callOp : callOps)
+      func::replaceCallOpWithNewOrder(rewriter, callOp,
+                                      argsInterchange.getArrayRef(),
+                                      resultsInterchange.getArrayRef());
+  }
+
+  results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
+  results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ReplaceFuncSignatureOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getModuleMutable(), effects); // Input module.
+  transform::producesHandle((*this)->getResults(), effects); // New handles.
+  transform::modifiesPayload(effects); // Functions/calls are rewritten.
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/CMakeLists.txt b/mlir/lib/Dialect/Func/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..e39a8c8c25d03
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Utils/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRFuncUtils
+  Utils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Utils
+
+  # NOTE(review): add MLIRFuncUtils to MLIRFuncTransformOps' LINK_LIBS too.
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRDialect
+  MLIRDialectUtils
+  MLIRIR
+  )
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
new file mode 100644
index 0000000000000..0e9662689ef78
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -0,0 +1,121 @@
+//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for the Func dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+
+FailureOr<func::FuncOp>
+func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
+                              ArrayRef<unsigned> newArgsOrder,
+                              ArrayRef<unsigned> newResultsOrder) {
+  // Only the sizes are checked here; the entries are used as indices, so the
+  // caller must pass permutations (unique, in-bounds indices).
+  assert(funcOp.getNumArguments() == newArgsOrder.size() &&
+         "newArgsOrder must match the number of arguments in the function");
+  assert(funcOp.getNumResults() == newResultsOrder.size() &&
+         "newResultsOrder must match the number of results in the function");
+  // Bail out before creating any IR so that `funcOp` is left untouched.
+  if (!funcOp.getBody().hasOneBlock())
+    return rewriter.notifyMatchFailure(
+        funcOp, "expected function to have exactly one block");
+
+  // New argument (result) `i` is original argument (result) `new*Order[i]`.
+  ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
+  ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
+  SmallVector<Type> newInputTypes, newOutputTypes;
+  SmallVector<Location> locs;
+  for (unsigned int idx : newArgsOrder) {
+    newInputTypes.push_back(origInputTypes[idx]);
+    locs.push_back(funcOp.getArgument(idx).getLoc());
+  }
+  for (unsigned int idx : newResultsOrder)
+    newOutputTypes.push_back(origOutputTypes[idx]);
+  rewriter.setInsertionPoint(funcOp);
+  auto newFuncOp = rewriter.create<func::FuncOp>(
+      funcOp.getLoc(), funcOp.getName(),
+      rewriter.getFunctionType(newInputTypes, newOutputTypes));
+  // Entry block with the permuted argument types and their original locs.
+  Region &newRegion = newFuncOp.getBody();
+  rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
+  newFuncOp.setVisibility(funcOp.getVisibility());
+  newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
+
+  // Map the arguments of the original function to the new function in
+  // the new order and adjust the attributes accordingly.
+  IRMapping operandMapper;
+  SmallVector<DictionaryAttr> argAttrs, resultAttrs;
+  funcOp.getAllArgAttrs(argAttrs);
+  for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
+    operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
+                      newFuncOp.getArgument(i));
+    newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
+  }
+  funcOp.getAllResultAttrs(resultAttrs);
+  for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
+    newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
+
+  // Clone the body; uses of the original arguments are remapped as above.
+  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
+  for (Operation &op : funcOp.getOps())
+    rewriter.clone(op, operandMapper);
+
+  // Rebuild the cloned terminator with operands in the new result order.
+  auto returnOp = cast<func::ReturnOp>(
+      newFuncOp.getFunctionBody().begin()->getTerminator());
+  SmallVector<Value> newReturnValues;
+  for (unsigned int idx : newResultsOrder)
+    newReturnValues.push_back(returnOp.getOperand(idx));
+  rewriter.setInsertionPoint(returnOp);
+  auto newReturnOp =
+      rewriter.create<func::ReturnOp>(returnOp.getLoc(), newReturnValues);
+  newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
+  rewriter.eraseOp(returnOp);
+
+  rewriter.eraseOp(funcOp);
+  return newFuncOp;
+}
+
+func::CallOp
+func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
+                                ArrayRef<unsigned> newArgsOrder,
+                                ArrayRef<unsigned> newResultsOrder) {
+  assert(
+      callOp.getNumOperands() == newArgsOrder.size() &&
+      "newArgsOrder must match the number of operands in the call operation");
+  assert(
+      callOp.getNumResults() == newResultsOrder.size() &&
+      "newResultsOrder must match the number of results in the call operation");
+  SmallVector<Value> newArgsOrderValues;
+  for (unsigned int argIdx : newArgsOrder)
+    newArgsOrderValues.push_back(callOp.getOperand(argIdx));
+  SmallVector<Type> newResultTypes;
+  for (unsigned int resIdx : newResultsOrder)
+    newResultTypes.push_back(callOp.getResult(resIdx).getType());
+
+  // Create the permuted call, keeping `no_inline` and discardable attributes.
+  rewriter.setInsertionPoint(callOp);
+  auto newCallOp = rewriter.create<func::CallOp>(
+      callOp.getLoc(), callOp.getCallee(), newResultTypes, newArgsOrderValues);
+  newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
+  newCallOp->setDiscardableAttrs(callOp->getDiscardableAttrDictionary());
+  // Old result `newResultsOrder[i]` is now produced as new result `i`.
+  for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
+    rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
+                                newCallOp.getResult(newIndex));
+  rewriter.eraseOp(callOp);
+  return newCallOp;
+}
diff --git a/mlir/test/Dialect/Func/func-transform-invalid.mlir b/mlir/test/Dialect/Func/func-transform-invalid.mlir
new file mode 100644
index 0000000000000..e712eee83f36e
--- /dev/null
+++ b/mlir/test/Dialect/Func/func-transform-invalid.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file -verify-diagnostics
+// The named function must exist in the target module.
+module {
+  func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+    %c0 = arith.constant 0 : index
+    %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+    %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+    %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{function with name '@func_not_in_module' not found}}
+    transform.func.replace_func_signature @func_not_in_module args_interchange = [0, 2, 1] results_interchange = [] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+// The args interchange must have exactly one entry per function argument.
+module {
+  func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+    %c0 = arith.constant 0 : index
+    %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+    %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+    %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{function with name '@func_with_reverse_order_no_result_no_calls' has 3 arguments, but 2 args interchange were given}}
+    transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2] results_interchange = [] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+// The results interchange must have exactly one entry per function result.
+module {
+  func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+    %c0 = arith.constant 0 : index
+    %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+    %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+    %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{function with name '@func_with_reverse_order_no_result_no_calls' has 0 results, but 1 results interchange were given}}
+    transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2, 1] results_interchange = [0] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+// The args interchange must not repeat an index.
+module {
+  func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+    %c0 = arith.constant 0 : index
+    %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+    %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+    %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{args interchange must be unique}}
+    transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2, 2] results_interchange = [] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
index 6aab07b0cb38a..36a66aaa95bfb 100644
--- a/mlir/test/Dialect/Func/func-transform.mlir
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -118,3 +118,135 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+// Permute the arguments of a function that has no results and no callers.
+module {
+  // CHECK:           func.func private @func_with_reverse_order_no_result_no_calls(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) {
+  func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+    // CHECK:             %[[C0:.*]] = arith.constant 0 : index
+    %c0 = arith.constant 0 : index
+    // CHECK:             %[[VAL_4:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+    %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+    // CHECK:             %[[VAL_5:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+    %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+    // CHECK:             %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+    %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
+    transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2, 1] results_interchange = [] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+// Permute the arguments and update the call site (adjust_func_calls).
+module {
+  // CHECK:           func.func private @func_with_reverse_order_no_result(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) {
+  func.func private @func_with_reverse_order_no_result(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+    // CHECK:             %[[C0:.*]] = arith.constant 0 : index
+    %c0 = arith.constant 0 : index
+    // CHECK:             %[[VAL_4:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+    %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+    // CHECK:             %[[VAL_5:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+    %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+    // CHECK:             %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+    %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+    return
+  }
+
+  // CHECK:           func.func @func_with_reverse_order_no_result_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) {
+  func.func @func_with_reverse_order_no_result_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+    // CHECK:             call @func_with_reverse_order_no_result(%[[ARG0]], %[[ARG2]], %[[ARG1]]) : (memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> ()
+    call @func_with_reverse_order_no_result(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> ()
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %module = transform.get_parent_op %f#0 : (!transform.any_op) -> !transform.any_op
+    transform.func.replace_func_signature @func_with_reverse_order_no_result args_interchange = [0, 2, 1] results_interchange = [] at %module {adjust_func_calls} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+// Permute both arguments and results, and update the call site.
+module {
+  // CHECK:           func.func private @func_with_reverse_order(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) -> (memref<2xi8, 1>, memref<1xi8, 1>) {
+  func.func private @func_with_reverse_order(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
+    // CHECK:             %[[C0:.*]] = arith.constant 0 : index
+    %c0 = arith.constant 0 : index
+    // CHECK:             %[[RET_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+    %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+    // CHECK:             %[[RET_1:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+    %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+    // CHECK:             %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+    %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+    // CHECK:             return %[[RET_1]], %[[RET_0]] : memref<2xi8, 1>, memref<1xi8, 1>
+    return %view, %view0 : memref<1xi8, 1>, memref<2xi8, 1>
+  }
+
+  // CHECK:           func.func @func_with_reverse_order_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
+  func.func @func_with_reverse_order_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
+    // CHECK:             %[[RET:.*]]:2 = call @func_with_reverse_order(%[[ARG0]], %[[ARG2]], %[[ARG1]]) : (memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> (memref<2xi8, 1>, memref<1xi8, 1>)
+    %0, %1 = call @func_with_reverse_order(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>)
+    // CHECK:             return %[[RET]]#1, %[[RET]]#0 : memref<1xi8, 1>, memref<2xi8, 1>
+    return %0, %1 : memref<1xi8, 1>, memref<2xi8, 1>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %module = transform.get_parent_op %f#0 : (!transform.any_op) -> !transform.any_op
+    transform.func.replace_func_signature @func_with_reverse_order args_interchange = [0, 2, 1] results_interchange = [1, 0] at %module {adjust_func_calls} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+// Argument attributes must move together with their arguments.
+module {
+  // CHECK:           func.func private @func_with_reverse_order_with_attr(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1> {transform.readonly}) -> (memref<2xi8, 1>, memref<1xi8, 1>) {
+  func.func private @func_with_reverse_order_with_attr(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>{transform.readonly}, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
+    // CHECK:             %[[C0:.*]] = arith.constant 0 : index
+    %c0 = arith.constant 0 : index
+    // CHECK:             %[[RET_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+    %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+    // CHECK:             %[[RET_1:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+    %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+    // CHECK:             %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+    %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+    // CHECK:             return %[[RET_1]], %[[RET_0]] : memref<2xi8, 1>, memref<1xi8, 1>
+    return %view, %view0 : memref<1xi8, 1>, memref<2xi8, 1>
+  }
+
+  // CHECK:           func.func @func_with_reverse_order_with_attr_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
+  func.func @func_with_reverse_order_with_attr_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
+    // CHECK:             %[[RET:.*]]:2 = call @func_with_reverse_order_with_attr(%[[ARG0]], %[[ARG2]], %[[ARG1]]) : (memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> (memref<2xi8, 1>, memref<1xi8, 1>)
+    %0, %1 = call @func_with_reverse_order_with_attr(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>)
+    // CHECK:             return %[[RET]]#1, %[[RET]]#0 : memref<1xi8, 1>, memref<2xi8, 1>
+    return %0, %1 : memref<1xi8, 1>, memref<2xi8, 1>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %module = transform.get_parent_op %f#0 : (!transform.any_op) -> !transform.any_op
+    transform.func.replace_func_signature @func_with_reverse_order_with_attr args_interchange = [0, 2, 1] results_interchange = [1, 0] at %module {adjust_func_calls} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list