[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)

Amir Bishara llvmlistbot at llvm.org
Fri Sep 12 07:37:50 PDT 2025


https://github.com/amirBish updated https://github.com/llvm/llvm-project/pull/158266

>From ebac68f350ed4096da8f9ee16bdfb867c1175721 Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Fri, 12 Sep 2025 12:55:54 +0300
Subject: [PATCH] [mlir][func]-Add deduplicate funcOp arguments transform

This PR adds a new transform operation that removes duplicate
arguments from a function operation based on its call site:
operands passed more than once to the callOp are merged into a
single function argument.

To keep the implementation simple for now, the transform fails
unless the function whose duplicate arguments we want to
eliminate has exactly one callOp.

This pull request also adapts the utilities in the func dialect
so that they can be reused by this transform op.
---
 .../Func/TransformOps/FuncTransformOps.td     |  26 +++
 mlir/include/mlir/Dialect/Func/Utils/Utils.h  |  29 +--
 .../Func/TransformOps/FuncTransformOps.cpp    | 102 +++++++++-
 mlir/lib/Dialect/Func/Utils/Utils.cpp         | 180 ++++++++++++++----
 .../Dialect/Func/func-transform-invalid.mlir  |  89 +++++++++
 mlir/test/Dialect/Func/func-transform.mlir    |  62 ++++++
 6 files changed, 435 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 4062f310c6521..b64b3fcdb275b 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -134,4 +134,30 @@ def ReplaceFuncSignatureOp
   }];
 }
 
+def DeduplicateFuncArgsOp
+    : Op<Transform_Dialect, "func.deduplicate_func_args",
+         [DeclareOpInterfaceMethods<TransformOpInterface>,
+          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+      This transform takes a module and a function name, and deduplicates
+      the arguments of the function, which is expected to be defined in the
+      module, by merging operands that its single call passes more than once.
+
+      This transform will emit a silenceable failure if:
+       - The function with the given name does not exist in the module.
+       - The call to the function does not pass any operand more than once.
+       - The function does not have exactly one call in the module.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$module,
+      SymbolRefAttr:$function_name);
+  let results = (outs TransformHandleTypeInterface:$transformed_module,
+                      TransformHandleTypeInterface:$transformed_function);
+
+  let assemblyFormat = [{
+    $function_name
+    `at` $module attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 2e8b6723a0e53..a75e3b93956dd 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -27,21 +27,28 @@ class FuncOp;
 class CallOp;
 
 /// Creates a new function operation with the same name as the original
-/// function operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
+/// function operation, but with the arguments mapped according to
+/// the `oldArgToNewArg` and `oldResToNewRes`.
 /// The `funcOp` operation must have exactly one block.
 /// Returns the new function operation or failure if `funcOp` doesn't
 /// have exactly one block.
-FailureOr<FuncOp>
-replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
-                        llvm::ArrayRef<unsigned> newArgsOrder,
-                        llvm::ArrayRef<unsigned> newResultsOrder);
+/// Note: the method asserts that `oldArgToNewArg` and `oldResToNewRes`
+/// cover all of the function arguments and results.
+mlir::FailureOr<mlir::func::FuncOp>
+replaceFuncWithNewMapping(mlir::RewriterBase &rewriter,
+                          mlir::func::FuncOp funcOp,
+                          llvm::ArrayRef<unsigned> oldArgToNewArg,
+                          llvm::ArrayRef<unsigned> oldResToNewRes);
 /// Creates a new call operation with the values as the original
-/// call operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
-CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
-                                 llvm::ArrayRef<unsigned> newArgsOrder,
-                                 llvm::ArrayRef<unsigned> newResultsOrder);
+/// call operation, but with the arguments mapped according to
+/// the `oldArgToNewArg` and `oldResToNewRes`.
+/// Note: the method asserts that `oldArgToNewArg` and `oldResToNewRes`
+/// cover all of the call operation operands and results.
+mlir::func::CallOp
+replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter,
+                            mlir::func::CallOp callOp,
+                            llvm::ArrayRef<unsigned> oldArgToNewArg,
+                            llvm::ArrayRef<unsigned> oldResToNewRes);
 
 } // namespace func
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 935d3e5ac331b..1568473550b75 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 
@@ -296,9 +297,16 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
     }
   }
 
-  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
-      rewriter, funcOp, argsInterchange.getArrayRef(),
-      resultsInterchange.getArrayRef());
+  llvm::SmallVector<unsigned> oldArgToNewArg(argsInterchange.size());
+  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(argsInterchange))
+    oldArgToNewArg[oldArgIdx] = newArgIdx;
+
+  llvm::SmallVector<unsigned> oldResToNewRes(resultsInterchange.size());
+  for (auto [newResIdx, oldResIdx] : llvm::enumerate(resultsInterchange))
+    oldResToNewRes[oldResIdx] = newResIdx;
+
+  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+      rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
   if (failed(newFuncOpOrFailure))
     return emitSilenceableFailure(getLoc())
            << "failed to replace function signature '" << getFunctionName()
@@ -312,9 +320,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
     });
 
     for (func::CallOp callOp : callOps)
-      func::replaceCallOpWithNewOrder(rewriter, callOp,
-                                      argsInterchange.getArrayRef(),
-                                      resultsInterchange.getArrayRef());
+      func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg,
+                                        oldResToNewRes);
   }
 
   results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
@@ -330,6 +337,89 @@ void transform::ReplaceFuncSignatureOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// DeduplicateFuncArgsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitDefiniteFailure() << "requires a single module to operate on";
+
+  auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+  if (!targetModuleOp)
+    return emitSilenceableFailure(getLoc())
+           << "target is expected to be module operation";
+
+  func::FuncOp funcOp =
+      targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+  if (!funcOp)
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName() << "' is not found";
+
+  SmallVector<func::CallOp> callOps;
+  targetModuleOp.walk([&](func::CallOp callOp) {
+    if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
+      callOps.push_back(callOp);
+  });
+
+  // TODO: Support more than one callOp.
+  if (!llvm::hasSingleElement(callOps))
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName()
+           << "' does not have a single callOp";
+
+  func::CallOp callOp = callOps.front();
+
+  // Map every old operand index to a new argument index. The first
+  // occurrence of a value claims the next free index and later duplicates
+  // reuse it, which reduces the duplicate operands while maintaining the
+  // original order.
+  llvm::SmallVector<unsigned> oldArgIdxToNewArgIdx(callOp.getNumOperands());
+  llvm::DenseMap<Value, unsigned> valueToNewArgIdx;
+  for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
+    unsigned nextNewArgIdx = valueToNewArgIdx.size();
+    auto it = valueToNewArgIdx.try_emplace(operand, nextNewArgIdx).first;
+    oldArgIdxToNewArgIdx[operandIdx] = it->second;
+  }
+
+  // Every operand got its own index, so there is nothing to deduplicate.
+  // Bail out before the payload is modified.
+  if (valueToNewArgIdx.size() == callOp.getNumOperands())
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName()
+           << "' does not have duplicate operands";
+
+  llvm::SmallVector<unsigned> oldResIdxToNewResIdx(callOp.getNumResults());
+  for (unsigned resultIdx = 0; resultIdx < callOp.getNumResults(); ++resultIdx)
+    oldResIdxToNewResIdx[resultIdx] = resultIdx;
+
+  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+      rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
+  if (failed(newFuncOpOrFailure))
+    return emitSilenceableFailure(getLoc())
+           << "failed to deduplicate function arguments '" << getFunctionName()
+           << "'";
+
+  func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgIdxToNewArgIdx,
+                                    oldResIdxToNewResIdx);
+
+  results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
+  results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::DeduplicateFuncArgsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getModuleMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index f781ed2d591b4..d6960ce84b7cc 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -14,20 +14,21 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 
 FailureOr<func::FuncOp>
-func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
-                              ArrayRef<unsigned> newArgsOrder,
-                              ArrayRef<unsigned> newResultsOrder) {
+func::replaceFuncWithNewMapping(RewriterBase &rewriter, func::FuncOp funcOp,
+                                llvm::ArrayRef<unsigned> oldArgToNewArg,
+                                llvm::ArrayRef<unsigned> oldResToNewRes) {
   // Generate an empty new function operation with the same name as the
   // original.
-  assert(funcOp.getNumArguments() == newArgsOrder.size() &&
-         "newArgsOrder must match the number of arguments in the function");
-  assert(funcOp.getNumResults() == newResultsOrder.size() &&
-         "newResultsOrder must match the number of results in the function");
+  assert(funcOp.getNumArguments() == oldArgToNewArg.size() &&
+         "oldArgToNewArg must match the number of arguments in the function");
+  assert(funcOp.getNumResults() == oldResToNewRes.size() &&
+         "oldResToNewRes must match the number of results in the function");
 
   if (!funcOp.getBody().hasOneBlock())
     return rewriter.notifyMatchFailure(
@@ -37,12 +38,62 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
   SmallVector<Type> newInputTypes, newOutputTypes;
   SmallVector<Location> locs;
-  for (unsigned int idx : newArgsOrder) {
-    newInputTypes.push_back(origInputTypes[idx]);
-    locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
+
+  // We may have some duplicate arguments in the old function, i.e.
+  // in the mapping `newArgIdxToOldArgIdxs` for some new argument index
+  // there may be multiple old argument indices.
+  unsigned numOfNewArgs = 0;
+  auto maxNewArgIdx = llvm::max_element(oldArgToNewArg);
+  if (maxNewArgIdx != oldArgToNewArg.end())
+    numOfNewArgs = *maxNewArgIdx + 1;
+  llvm::SmallVector<llvm::SmallVector<unsigned>> newArgIdxToOldArgIdxs(
+      numOfNewArgs);
+  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgToNewArg))
+    newArgIdxToOldArgIdxs[newArgIdx].push_back(oldArgIdx);
+
+  for (auto [newArgIdx, oldArgIdxs] : llvm::enumerate(newArgIdxToOldArgIdxs)) {
+    std::ignore = newArgIdx;
+    assert(llvm::all_of(oldArgIdxs,
+                        [&funcOp](unsigned idx) -> bool {
+                          return idx < funcOp.getNumArguments();
+                        }) &&
+           "idx must be less than the number of arguments in the function");
+    assert(!oldArgIdxs.empty() && "oldArgIdxs must not be empty");
+    Type origInputTypeToCheck = origInputTypes[oldArgIdxs.front()];
+    assert(llvm::all_of(oldArgIdxs,
+                        [&](unsigned idx) -> bool {
+                          return origInputTypes[idx] == origInputTypeToCheck;
+                        }) &&
+           "all oldArgIdx must have the same type");
+    newInputTypes.push_back(origInputTypeToCheck);
+    locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
+  }
+
+  unsigned numOfNewRes = 0;
+  auto maxNewResIdx = llvm::max_element(oldResToNewRes);
+  if (maxNewResIdx != oldResToNewRes.end())
+    numOfNewRes = *maxNewResIdx + 1;
+  llvm::SmallVector<llvm::SmallVector<unsigned>> newResToOldResIdxs(
+      numOfNewRes);
+  for (auto [oldResIdx, newResIdx] : llvm::enumerate(oldResToNewRes))
+    newResToOldResIdxs[newResIdx].push_back(oldResIdx);
+
+  for (ArrayRef<unsigned> oldResIdxs : newResToOldResIdxs) {
+    assert(!oldResIdxs.empty() && "oldResIdxs must not be empty");
+    assert(llvm::all_of(oldResIdxs,
+                        [&funcOp](unsigned idx) -> bool {
+                          return idx < funcOp.getNumResults();
+                        }) &&
+           "idx must be less than the number of results in the function");
+    Type origOutputTypeToCheck = origOutputTypes[oldResIdxs.front()];
+    assert(llvm::all_of(oldResIdxs,
+                        [&](unsigned idx) -> bool {
+                          return origOutputTypes[idx] == origOutputTypeToCheck;
+                        }) &&
+           "all oldResIdx must have the same type");
+    newOutputTypes.push_back(origOutputTypeToCheck);
   }
-  for (unsigned int idx : newResultsOrder)
-    newOutputTypes.push_back(origOutputTypes[idx]);
+
   rewriter.setInsertionPoint(funcOp);
   auto newFuncOp = func::FuncOp::create(
       rewriter, funcOp.getLoc(), funcOp.getName(),
@@ -58,14 +109,15 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   IRMapping operandMapper;
   SmallVector<DictionaryAttr> argAttrs, resultAttrs;
   funcOp.getAllArgAttrs(argAttrs);
-  for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
-    operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
-                      newFuncOp.getArgument(i));
-    newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
-  }
+  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgToNewArg))
+    operandMapper.map(funcOp.getArgument(oldArgIdx),
+                      newFuncOp.getArgument(newArgIdx));
+  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs))
+    newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
+
   funcOp.getAllResultAttrs(resultAttrs);
-  for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
-    newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
+  for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs))
+    newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
 
   // Clone the operations from the original function to the new function.
   rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
@@ -76,8 +128,10 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   auto returnOp = cast<func::ReturnOp>(
       newFuncOp.getFunctionBody().begin()->getTerminator());
   SmallVector<Value> newReturnValues;
-  for (unsigned int idx : newResultsOrder)
-    newReturnValues.push_back(returnOp.getOperand(idx));
+  for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs)) {
+    std::ignore = newResIdx;
+    newReturnValues.push_back(returnOp.getOperand(oldResIdx.front()));
+  }
   rewriter.setInsertionPoint(returnOp);
   auto newReturnOp =
       func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
@@ -90,32 +144,86 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
 }
 
 func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
-                                ArrayRef<unsigned> newArgsOrder,
-                                ArrayRef<unsigned> newResultsOrder) {
+func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
+                                  llvm::ArrayRef<unsigned> oldArgToNewArg,
+                                  llvm::ArrayRef<unsigned> oldResToNewRes) {
   assert(
-      callOp.getNumOperands() == newArgsOrder.size() &&
-      "newArgsOrder must match the number of operands in the call operation");
+      callOp.getNumOperands() == oldArgToNewArg.size() &&
+      "oldArgToNewArg must match the number of operands in the call operation");
   assert(
-      callOp.getNumResults() == newResultsOrder.size() &&
-      "newResultsOrder must match the number of results in the call operation");
+      callOp.getNumResults() == oldResToNewRes.size() &&
+      "oldResToNewRes must match the number of results in the call operation");
+
+  // Inverse mapping from new arguments to old arguments.
+  unsigned numOfNewArgs = 0;
+  auto maxNewArgIdx = llvm::max_element(oldArgToNewArg);
+  if (maxNewArgIdx != oldArgToNewArg.end())
+    numOfNewArgs = *maxNewArgIdx + 1;
+  llvm::SmallVector<llvm::SmallVector<unsigned>> newArgIdxToOldArgIdxs(
+      numOfNewArgs);
+  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgToNewArg))
+    newArgIdxToOldArgIdxs[newArgIdx].push_back(oldArgIdx);
+
   SmallVector<Value> newArgsOrderValues;
-  for (unsigned int argIdx : newArgsOrder)
-    newArgsOrderValues.push_back(callOp.getOperand(argIdx));
+  for (const auto &[newArgIdx, oldArgIdxs] :
+       llvm::enumerate(newArgIdxToOldArgIdxs)) {
+    std::ignore = newArgIdx;
+    assert(
+        llvm::all_of(oldArgIdxs,
+                     [&callOp](unsigned idx) -> bool {
+                       return idx < callOp.getNumOperands();
+                     }) &&
+        "idx must be less than the number of operands in the call operation");
+    assert(!oldArgIdxs.empty() && "oldArgIdx must not be empty");
+    Value origOperandToCheck = callOp.getOperand(oldArgIdxs.front());
+    assert(llvm::all_of(oldArgIdxs,
+                        [&](unsigned idx) -> bool {
+                          return callOp.getOperand(idx).getType() ==
+                                 origOperandToCheck.getType();
+                        }) &&
+           "all oldArgIdx must have the same type");
+    newArgsOrderValues.push_back(origOperandToCheck);
+  }
+
+  unsigned numOfNewRes = 0;
+  auto maxNewResIdx = llvm::max_element(oldResToNewRes);
+  if (maxNewResIdx != oldResToNewRes.end())
+    numOfNewRes = *maxNewResIdx + 1;
+  llvm::SmallVector<llvm::SmallVector<unsigned>> newResIdxToOldResIdxs(
+      numOfNewRes);
+  for (auto [oldResIdx, newResIdx] : llvm::enumerate(oldResToNewRes))
+    newResIdxToOldResIdxs[newResIdx].push_back(oldResIdx);
+
   SmallVector<Type> newResultTypes;
-  for (unsigned int resIdx : newResultsOrder)
-    newResultTypes.push_back(callOp.getResult(resIdx).getType());
+  for (auto [newResIdx, oldResIdxs] : llvm::enumerate(newResIdxToOldResIdxs)) {
+    std::ignore = newResIdx;
+    assert(llvm::all_of(oldResIdxs,
+                        [&callOp](unsigned idx) -> bool {
+                          return idx < callOp.getNumResults();
+                        }) &&
+           "idx must be less than the number of results in the call operation");
+    assert(!oldResIdxs.empty() && "oldResIdx must not be empty");
+    Value origResultToCheck = callOp.getResult(oldResIdxs.front());
+    assert(llvm::all_of(oldResIdxs,
+                        [&](unsigned idx) -> bool {
+                          return callOp.getResult(idx).getType() ==
+                                 origResultToCheck.getType();
+                        }) &&
+           "all oldResIdx must have the same type");
+    newResultTypes.push_back(origResultToCheck.getType());
+  }
 
   // Replace the kernel call operation with a new one that has the
-  // reordered arguments.
+  // mapped arguments.
   rewriter.setInsertionPoint(callOp);
   auto newCallOp =
       func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
                            newResultTypes, newArgsOrderValues);
+  newCallOp->setDiscardableAttrs(callOp->getDiscardableAttrDictionary());
   newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
-  for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
-    rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
-                                newCallOp.getResult(newIndex));
+  for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResToNewRes))
+    rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
+                                newCallOp.getResult(newResIdx));
   rewriter.eraseOp(callOp);
 
   return newCallOp;
diff --git a/mlir/test/Dialect/Func/func-transform-invalid.mlir b/mlir/test/Dialect/Func/func-transform-invalid.mlir
index e712eee83f36e..77ed9307e4387 100644
--- a/mlir/test/Dialect/Func/func-transform-invalid.mlir
+++ b/mlir/test/Dialect/Func/func-transform-invalid.mlir
@@ -85,3 +85,92 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func private @func_with_no_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  return
+}
+
+func.func @func_with_no_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  call @func_with_no_duplicate_args(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@func_with_no_duplicate_args' does not have duplicate operands}}
+    transform.func.deduplicate_func_args @func_with_no_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_not_found(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@non_existent_func' is not found}}
+    transform.func.deduplicate_func_args @non_existent_func at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_with_multiple_calls(%arg0: memref<1xi8, 1>, %arg1: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+func.func @func_with_multiple_calls_caller1(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  call @func_with_multiple_calls(%arg0, %arg0) : (memref<1xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+func.func @func_with_multiple_calls_caller2(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  call @func_with_multiple_calls(%arg0, %arg0) : (memref<1xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@func_with_multiple_calls' does not have a single callOp}}
+    transform.func.deduplicate_func_args @func_with_multiple_calls at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_with_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+func.func @some_other_func() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@func_with_no_calls' does not have a single callOp}}
+    transform.func.deduplicate_func_args @func_with_no_calls at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
index 36a66aaa95bfb..8a71511e3ed5b 100644
--- a/mlir/test/Dialect/Func/func-transform.mlir
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -250,3 +250,65 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK:           func.func private @func_with_duplicate_args(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>) {
+func.func private @func_with_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  // CHECK:             %[[VAL_3:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0:.*]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[VAL_4:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             %[[VAL_5:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+// CHECK:           func.func @func_with_duplicate_args_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>) {
+func.func @func_with_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  // CHECK:             call @func_with_duplicate_args(%[[ARG0]], %[[ARG1]]) : (memref<1xi8, 1>, memref<2xi8, 1>) -> ()
+  call @func_with_duplicate_args(%arg0, %arg1, %arg0) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.func.deduplicate_func_args @func_with_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK:           func.func private @func_with_complex_duplicate_args(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+func.func private @func_with_complex_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<1xi8, 1>, %arg3: memref<3xi8, 1>, %arg4: memref<2xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  // CHECK:             %[[RET_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0:.*]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[RET_1:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             %[[RET_2:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view2 = memref.view %arg2[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[RET_3:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+  %view3 = memref.view %arg3[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  // CHECK:             %[[RET_4:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view4 = memref.view %arg4[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             return %[[RET_0]], %[[RET_1]], %[[RET_2]], %[[RET_3]], %[[RET_4]] : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+  return %view0, %view1, %view2, %view3, %view4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+}
+
+// CHECK:           func.func @func_with_complex_duplicate_args_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+func.func @func_with_complex_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+  // CHECK:             %[[RET:.*]]:5 = call @func_with_complex_duplicate_args(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>)
+  %0:5 = call @func_with_complex_duplicate_args(%arg0, %arg1, %arg0, %arg2, %arg1) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>)
+  // CHECK:             return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2, %[[RET]]#3, %[[RET]]#4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+  return %0#0, %0#1, %0#2, %0#3, %0#4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.func.deduplicate_func_args @func_with_complex_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list