[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)

Amir Bishara llvmlistbot at llvm.org
Sat Sep 13 22:33:13 PDT 2025


https://github.com/amirBish updated https://github.com/llvm/llvm-project/pull/158266

>From bfc15d5b8c92232d2645f16dd49823bcca41a9a3 Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Fri, 12 Sep 2025 12:55:54 +0300
Subject: [PATCH] [mlir][func]-Add deduplicate funcOp arguments transform

This PR adds a new transform operation which removes the
duplicate arguments from the function operation based on
the callOp of this function.

To keep the implementation simple for now, the transform fails
when the function whose duplicate arguments are being eliminated
has more than one callOp.

This pull request also adapts the utils under the func dialect
so that they can be reused by this transformOp.
---
 .../Func/TransformOps/FuncTransformOps.td     |  26 ++
 mlir/include/mlir/Dialect/Func/Utils/Utils.h  |  42 +++-
 .../Func/TransformOps/FuncTransformOps.cpp    |  63 ++++-
 mlir/lib/Dialect/Func/Utils/Utils.cpp         | 232 ++++++++++++++----
 .../Dialect/Func/func-transform-invalid.mlir  |  89 +++++++
 mlir/test/Dialect/Func/func-transform.mlir    |  62 +++++
 6 files changed, 447 insertions(+), 67 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 4062f310c6521..b64b3fcdb275b 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -134,4 +134,30 @@ def ReplaceFuncSignatureOp
   }];
 }
 
+def DeduplicateFuncArgsOp
+    : Op<Transform_Dialect, "func.deduplicate_func_args",
+         [DeclareOpInterfaceMethods<TransformOpInterface>,
+          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+      This transform takes a module and a function name, and deduplicates
+      the arguments of the function. The function is expected to be defined in
+      the module.
+
+      This transform will emit a silenceable failure if:
+       - The function with the given name does not exist in the module.
+       - The function does not have duplicate arguments.
+       - The function does not have a single call.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$module,
+      SymbolRefAttr:$function_name);
+  let results = (outs TransformHandleTypeInterface:$transformed_module,
+                      TransformHandleTypeInterface:$transformed_function);
+
+  let assemblyFormat = [{
+    $function_name
+    `at` $module attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 2e8b6723a0e53..ffebf9b29cc1e 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -18,32 +18,50 @@
 
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/ArrayRef.h"
+#include <string>
 
 namespace mlir {
 
+class ModuleOp;
+
 namespace func {
 
 class FuncOp;
 class CallOp;
 
 /// Creates a new function operation with the same name as the original
-/// function operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
+/// function operation, but with the arguments mapped according to
+/// the `oldArgIdxToNewArgIdx` and `oldResIdxToNewResIdx`.
 /// The `funcOp` operation must have exactly one block.
 /// Returns the new function operation or failure if `funcOp` doesn't
 /// have exactly one block.
-FailureOr<FuncOp>
-replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
-                        llvm::ArrayRef<unsigned> newArgsOrder,
-                        llvm::ArrayRef<unsigned> newResultsOrder);
+/// Note: the method asserts that `oldArgIdxToNewArgIdx` and
+/// `oldResIdxToNewResIdx` cover all of the function's arguments and results.
+mlir::FailureOr<mlir::func::FuncOp> replaceFuncWithNewMapping(
+    mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
+    ArrayRef<int> oldArgIdxToNewArgIdx, ArrayRef<int> oldResIdxToNewResIdx);
 /// Creates a new call operation with the values as the original
-/// call operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
-CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
-                                 llvm::ArrayRef<unsigned> newArgsOrder,
-                                 llvm::ArrayRef<unsigned> newResultsOrder);
+/// call operation, but with the arguments mapped according to
+/// the `oldArgIdxToNewArgIdx` and `oldResIdxToNewResIdx`.
+/// Note: the method asserts that `oldArgIdxToNewArgIdx` and
+/// `oldResIdxToNewResIdx` cover all of the call's operands and results.
+mlir::func::CallOp replaceCallOpWithNewMapping(
+    mlir::RewriterBase &rewriter, mlir::func::CallOp callOp,
+    ArrayRef<int> oldArgIdxToNewArgIdx, ArrayRef<int> oldResIdxToNewResIdx);
+
+/// This utility function examines all call operations within the given
+/// `moduleOp` that target the specified `funcOp`. It identifies duplicate
+/// operands in the call operations, creates mappings to deduplicate them, and
+/// then applies the transformation to both the function and its call sites. For
+/// now, it only supports one call operation for the function operation. The
+/// function returns a pair containing the new funcOp and the new callOp. Note:
+/// after the transformation, the original funcOp and callOp will be erased. The
+/// `errorMessage` will be set to the error message if the transformation fails.
+mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
+deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
+                        mlir::ModuleOp moduleOp, std::string &errorMessage);
 
 } // namespace func
 } // namespace mlir
 
-#endif // MLIR_DIALECT_FUNC_UTILS_H
+#endif // MLIR_DIALECT_FUNC_UTILS_H
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 935d3e5ac331b..ee8e8333effc2 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 
@@ -296,9 +297,16 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
     }
   }
 
-  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
-      rewriter, funcOp, argsInterchange.getArrayRef(),
-      resultsInterchange.getArrayRef());
+  llvm::SmallVector<int> oldArgToNewArg(argsInterchange.size());
+  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(argsInterchange))
+    oldArgToNewArg[oldArgIdx] = newArgIdx;
+
+  llvm::SmallVector<int> oldResToNewRes(resultsInterchange.size());
+  for (auto [newResIdx, oldResIdx] : llvm::enumerate(resultsInterchange))
+    oldResToNewRes[oldResIdx] = newResIdx;
+
+  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+      rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
   if (failed(newFuncOpOrFailure))
     return emitSilenceableFailure(getLoc())
            << "failed to replace function signature '" << getFunctionName()
@@ -312,9 +320,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
     });
 
     for (func::CallOp callOp : callOps)
-      func::replaceCallOpWithNewOrder(rewriter, callOp,
-                                      argsInterchange.getArrayRef(),
-                                      resultsInterchange.getArrayRef());
+      func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg,
+                                        oldResToNewRes);
   }
 
   results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
@@ -330,6 +337,50 @@ void transform::ReplaceFuncSignatureOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// DeduplicateFuncArgsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitDefiniteFailure() << "requires a single module to operate on";
+
+  auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+  if (!targetModuleOp)
+    return emitSilenceableFailure(getLoc())
+           << "target is expected to be module operation";
+
+  func::FuncOp funcOp =
+      targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+  if (!funcOp)
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName() << "' is not found";
+
+  std::string errorMessage;
+  auto transformationResult = func::deduplicateArgsOfFuncOp(
+      rewriter, funcOp, targetModuleOp, errorMessage);
+  if (failed(transformationResult))
+    return emitSilenceableFailure(getLoc()) << errorMessage;
+
+  auto [newFuncOp, newCallOp] = *transformationResult;
+
+  results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
+  results.set(cast<OpResult>(getTransformedFunction()), {newFuncOp});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::DeduplicateFuncArgsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getModuleMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index f781ed2d591b4..f5d9766cf12ef 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -14,35 +14,98 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 
+/// This method creates an inverse mapping of the provided `oldIdxToNewIdx`.
+/// Given an array where `oldIdxToNewIdx[i] = j` means old index `i` maps
+/// to new index `j`,
+/// this method returns a vector where `result[j]` contains all old indices
+/// that map to new index `j`.
+///
+/// Example:
+/// ```
+/// oldIdxToNewIdx = [0, 1, 2, 2, 3]
+/// getInverseMapping(oldIdxToNewIdx) = [[0], [1], [2, 3], [4]]
+/// ```
+///
+static llvm::SmallVector<llvm::SmallVector<int>>
+getInverseMapping(ArrayRef<int> oldIdxToNewIdx) {
+  int numOfNewIdxs = 0;
+  if (!oldIdxToNewIdx.empty())
+    numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx);
+  llvm::SmallVector<llvm::SmallVector<int>> newToOldIdxs(numOfNewIdxs);
+  for (auto [oldIdx, newIdx] : llvm::enumerate(oldIdxToNewIdx))
+    newToOldIdxs[newIdx].push_back(oldIdx);
+  return newToOldIdxs;
+}
+
+/// This method returns a new vector of elements that are mapped from the
+/// `origElements` based on the `newIdxToOldIdxs` mapping. This function assumes
+/// that the `newIdxToOldIdxs` mapping is valid, i.e. for each new index, there
+/// is at least one old index that maps to it. It also assumes that all old
+/// indices mapping to the same new index hold equal elements in `origElements`.
+template <typename Element>
+static SmallVector<Element> getMappedElements(
+    ArrayRef<Element> origElements,
+    const llvm::SmallVector<llvm::SmallVector<int>> &newIdxToOldIdxs) {
+  SmallVector<Element> newElements;
+  for (const auto &oldIdxs : newIdxToOldIdxs) {
+    assert(llvm::all_of(oldIdxs,
+                        [&origElements](int idx) -> bool {
+                          return idx >= 0 &&
+                                 static_cast<size_t>(idx) < origElements.size();
+                        }) &&
+           "idx must be less than the number of elements in the original "
+           "elements");
+    assert(!oldIdxs.empty() && "oldIdx must not be empty");
+    Element origTypeToCheck = origElements[oldIdxs.front()];
+    assert(llvm::all_of(oldIdxs,
+                        [&](int idx) -> bool {
+                          return origElements[idx] == origTypeToCheck;
+                        }) &&
+           "all oldIdxs must be equal");
+    newElements.push_back(origTypeToCheck);
+  }
+  return newElements;
+}
+
 FailureOr<func::FuncOp>
-func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
-                              ArrayRef<unsigned> newArgsOrder,
-                              ArrayRef<unsigned> newResultsOrder) {
+func::replaceFuncWithNewMapping(RewriterBase &rewriter, func::FuncOp funcOp,
+                                ArrayRef<int> oldArgIdxToNewArgIdx,
+                                ArrayRef<int> oldResIdxToNewResIdx) {
   // Generate an empty new function operation with the same name as the
   // original.
-  assert(funcOp.getNumArguments() == newArgsOrder.size() &&
-         "newArgsOrder must match the number of arguments in the function");
-  assert(funcOp.getNumResults() == newResultsOrder.size() &&
-         "newResultsOrder must match the number of results in the function");
+  assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
+         "oldArgIdxToNewArgIdx must match the number of arguments in the "
+         "function");
+  assert(
+      funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
+      "oldResIdxToNewResIdx must match the number of results in the function");
 
   if (!funcOp.getBody().hasOneBlock())
     return rewriter.notifyMatchFailure(
         funcOp, "expected function to have exactly one block");
 
-  ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
-  ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
-  SmallVector<Type> newInputTypes, newOutputTypes;
+  // We may have some duplicate arguments in the old function, i.e.
+  // in the mapping `newArgIdxToOldArgIdxs` for some new argument index
+  // there may be multiple old argument indices.
+  llvm::SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
+      getInverseMapping(oldArgIdxToNewArgIdx);
+  SmallVector<Type> newInputTypes = getMappedElements(
+      funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);
+
   SmallVector<Location> locs;
-  for (unsigned int idx : newArgsOrder) {
-    newInputTypes.push_back(origInputTypes[idx]);
-    locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
-  }
-  for (unsigned int idx : newResultsOrder)
-    newOutputTypes.push_back(origOutputTypes[idx]);
+  for (const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
+    locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
+
+  llvm::SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
+      getInverseMapping(oldResIdxToNewResIdx);
+  SmallVector<Type> newOutputTypes = getMappedElements(
+      funcOp.getFunctionType().getResults(), newResToOldResIdxs);
+
   rewriter.setInsertionPoint(funcOp);
   auto newFuncOp = func::FuncOp::create(
       rewriter, funcOp.getLoc(), funcOp.getName(),
@@ -51,21 +114,21 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   Region &newRegion = newFuncOp.getBody();
   rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
   newFuncOp.setVisibility(funcOp.getVisibility());
-  newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
 
   // Map the arguments of the original function to the new function in
   // the new order and adjust the attributes accordingly.
   IRMapping operandMapper;
   SmallVector<DictionaryAttr> argAttrs, resultAttrs;
   funcOp.getAllArgAttrs(argAttrs);
-  for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
-    operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
-                      newFuncOp.getArgument(i));
-    newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
-  }
+  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgIdxToNewArgIdx))
+    operandMapper.map(funcOp.getArgument(oldArgIdx),
+                      newFuncOp.getArgument(newArgIdx));
+  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs))
+    newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
+
   funcOp.getAllResultAttrs(resultAttrs);
-  for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
-    newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
+  for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs))
+    newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
 
   // Clone the operations from the original function to the new function.
   rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
@@ -76,12 +139,11 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   auto returnOp = cast<func::ReturnOp>(
       newFuncOp.getFunctionBody().begin()->getTerminator());
   SmallVector<Value> newReturnValues;
-  for (unsigned int idx : newResultsOrder)
-    newReturnValues.push_back(returnOp.getOperand(idx));
+  for (const auto &oldResIdxs : newResToOldResIdxs)
+    newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));
+
   rewriter.setInsertionPoint(returnOp);
-  auto newReturnOp =
-      func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
-  newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
+  func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
   rewriter.eraseOp(returnOp);
 
   rewriter.eraseOp(funcOp);
@@ -90,33 +152,105 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
 }
 
 func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
-                                ArrayRef<unsigned> newArgsOrder,
-                                ArrayRef<unsigned> newResultsOrder) {
-  assert(
-      callOp.getNumOperands() == newArgsOrder.size() &&
-      "newArgsOrder must match the number of operands in the call operation");
-  assert(
-      callOp.getNumResults() == newResultsOrder.size() &&
-      "newResultsOrder must match the number of results in the call operation");
-  SmallVector<Value> newArgsOrderValues;
-  for (unsigned int argIdx : newArgsOrder)
-    newArgsOrderValues.push_back(callOp.getOperand(argIdx));
-  SmallVector<Type> newResultTypes;
-  for (unsigned int resIdx : newResultsOrder)
-    newResultTypes.push_back(callOp.getResult(resIdx).getType());
+func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
+                                  ArrayRef<int> oldArgIdxToNewArgIdx,
+                                  ArrayRef<int> oldResIdxToNewResIdx) {
+  assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
+         "oldArgIdxToNewArgIdx must match the number of operands in the call "
+         "operation");
+  assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
+         "oldResIdxToNewResIdx must match the number of results in the call "
+         "operation");
+
+  SmallVector<Value> origOperands = callOp.getOperands();
+  SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
+      getInverseMapping(oldArgIdxToNewArgIdx);
+  SmallVector<Value> newOperandsValues =
+      getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);
+  SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
+      getInverseMapping(oldResIdxToNewResIdx);
+  SmallVector<Type> origResultTypes = llvm::to_vector(callOp.getResultTypes());
+  SmallVector<Type> newResultTypes =
+      getMappedElements<Type>(origResultTypes, newResToOldResIdxs);
 
   // Replace the kernel call operation with a new one that has the
-  // reordered arguments.
+  // mapped arguments.
   rewriter.setInsertionPoint(callOp);
   auto newCallOp =
       func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
-                           newResultTypes, newArgsOrderValues);
+                           newResultTypes, newOperandsValues);
   newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
-  for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
-    rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
-                                newCallOp.getResult(newIndex));
+  for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx))
+    rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
+                                newCallOp.getResult(newResIdx));
   rewriter.eraseOp(callOp);
 
   return newCallOp;
 }
+
+FailureOr<std::pair<func::FuncOp, func::CallOp>>
+func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
+                              ModuleOp moduleOp, std::string &errorMessage) {
+  SmallVector<func::CallOp> callOps;
+  auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
+    if (callOp.getCallee() == funcOp.getSymName()) {
+      if (!callOps.empty())
+        // Only support one callOp for now
+        return WalkResult::interrupt();
+      callOps.push_back(callOp);
+    }
+    return WalkResult::advance();
+  });
+
+  if (traversalResult.wasInterrupted()) {
+    errorMessage = "function with name '" + funcOp.getSymName().str() +
+                   "' has more than one callOp";
+    return failure();
+  }
+
+  if (callOps.empty()) {
+    errorMessage = "function with name '" + funcOp.getSymName().str() +
+                   "' does not have any callOp";
+    return failure();
+  }
+
+  func::CallOp callOp = callOps.front();
+
+  // Create mapping for arguments (deduplicate operands)
+  SmallVector<int> oldArgIdxToNewArgIdx(callOp.getNumOperands());
+  llvm::DenseMap<Value, int> valueToNewArgIdx;
+  for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
+    auto [iterator, inserted] = valueToNewArgIdx.insert(
+        {operand, static_cast<int>(valueToNewArgIdx.size())});
+    // Reduce the duplicate operands and maintain the original order.
+    oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
+  }
+
+  bool hasDuplicateOperands =
+      valueToNewArgIdx.size() != callOp.getNumOperands();
+  if (!hasDuplicateOperands) {
+    errorMessage = "function with name '" + funcOp.getSymName().str() +
+                   "' does not have duplicate operands";
+    return failure();
+  }
+
+  // Create identity mapping for results (no deduplication needed)
+  SmallVector<int> oldResIdxToNewResIdx(callOp.getNumResults());
+  for (int resultIdx = 0; resultIdx < static_cast<int>(callOp.getNumResults());
+       ++resultIdx)
+    oldResIdxToNewResIdx[resultIdx] = resultIdx;
+
+  // Apply the transformation to create new function and call operations
+  FailureOr<func::FuncOp> newFuncOpOrFailure = replaceFuncWithNewMapping(
+      rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
+  if (failed(newFuncOpOrFailure)) {
+    errorMessage = "failed to replace function signature with name '" +
+                   funcOp.getSymName().str() + "' with new order";
+    return failure();
+  }
+
+  func::CallOp newCallOp = replaceCallOpWithNewMapping(
+      rewriter, callOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
+
+  return std::make_pair(*newFuncOpOrFailure, newCallOp);
+}
diff --git a/mlir/test/Dialect/Func/func-transform-invalid.mlir b/mlir/test/Dialect/Func/func-transform-invalid.mlir
index e712eee83f36e..941e6444054e1 100644
--- a/mlir/test/Dialect/Func/func-transform-invalid.mlir
+++ b/mlir/test/Dialect/Func/func-transform-invalid.mlir
@@ -85,3 +85,92 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func private @func_with_no_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  return
+}
+
+func.func @func_with_no_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  call @func_with_no_duplicate_args(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name 'func_with_no_duplicate_args' does not have duplicate operands}}
+    transform.func.deduplicate_func_args @func_with_no_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_not_found(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@non_existent_func' is not found}}
+    transform.func.deduplicate_func_args @non_existent_func at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_with_multiple_calls(%arg0: memref<1xi8, 1>, %arg1: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+func.func @func_with_multiple_calls_caller1(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  call @func_with_multiple_calls(%arg0, %arg0) : (memref<1xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+func.func @func_with_multiple_calls_caller2(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  call @func_with_multiple_calls(%arg0, %arg0) : (memref<1xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name 'func_with_multiple_calls' has more than one callOp}}
+    transform.func.deduplicate_func_args @func_with_multiple_calls at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_with_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+func.func @some_other_func() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name 'func_with_no_calls' does not have any callOp}}
+    transform.func.deduplicate_func_args @func_with_no_calls at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
index 36a66aaa95bfb..8a71511e3ed5b 100644
--- a/mlir/test/Dialect/Func/func-transform.mlir
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -250,3 +250,65 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK:           func.func private @func_with_duplicate_args(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>) {
+func.func private @func_with_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  // CHECK:             %[[VAL_3:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0:.*]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[VAL_4:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             %[[VAL_5:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+// CHECK:           func.func @func_with_duplicate_args_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>) {
+func.func @func_with_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  // CHECK:             call @func_with_duplicate_args(%[[ARG0]], %[[ARG1]]) : (memref<1xi8, 1>, memref<2xi8, 1>) -> ()
+  call @func_with_duplicate_args(%arg0, %arg1, %arg0) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.func.deduplicate_func_args @func_with_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK:           func.func private @func_with_complex_duplicate_args(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+func.func private @func_with_complex_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<1xi8, 1>, %arg3: memref<3xi8, 1>, %arg4: memref<2xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  // CHECK:             %[[RET_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0:.*]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[RET_1:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             %[[RET_2:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view2 = memref.view %arg2[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[RET_3:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+  %view3 = memref.view %arg3[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  // CHECK:             %[[RET_4:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view4 = memref.view %arg4[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             return %[[RET_0]], %[[RET_1]], %[[RET_2]], %[[RET_3]], %[[RET_4]] : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+  return %view0, %view1, %view2, %view3, %view4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+}
+
+// CHECK:           func.func @func_with_complex_duplicate_args_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+func.func @func_with_complex_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+  // CHECK:             %[[RET:.*]]:5 = call @func_with_complex_duplicate_args(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>)
+  %0:5 = call @func_with_complex_duplicate_args(%arg0, %arg1, %arg0, %arg2, %arg1) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>)
+  // CHECK:             return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2, %[[RET]]#3, %[[RET]]#4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+  return %0#0, %0#1, %0#2, %0#3, %0#4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.func.deduplicate_func_args @func_with_complex_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list