[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments	transform (PR #158266)
    Amir Bishara 
    llvmlistbot at llvm.org
       
    Sat Sep 13 11:01:34 PDT 2025
    
    
  
https://github.com/amirBish updated https://github.com/llvm/llvm-project/pull/158266
>From ffce3075a1ab474d2a3cde2e42437723b40be65f Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Fri, 12 Sep 2025 12:55:54 +0300
Subject: [PATCH] [mlir][func]-Add deduplicate funcOp arguments transform
This PR adds a new transform operation which removes the
duplicate arguments from the function operation based on
the callOp of this function.
To keep the implementation simple for now, the transform
fails when the function whose arguments are being deduplicated
does not have exactly one callOp.
This pull request also adapts the utils under the func dialect
so that they can be reused by this transformOp.
---
 .../Func/TransformOps/FuncTransformOps.td     |  26 +++
 mlir/include/mlir/Dialect/Func/Utils/Utils.h  |  25 +--
 .../Func/TransformOps/FuncTransformOps.cpp    |  99 ++++++++++-
 mlir/lib/Dialect/Func/Utils/Utils.cpp         | 165 ++++++++++++------
 .../Dialect/Func/func-transform-invalid.mlir  |  89 ++++++++++
 mlir/test/Dialect/Func/func-transform.mlir    |  62 +++++++
 6 files changed, 400 insertions(+), 66 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 4062f310c6521..b64b3fcdb275b 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -134,4 +134,30 @@ def ReplaceFuncSignatureOp
   }];
 }
 
+def DeduplicateFuncArgsOp
+    : Op<Transform_Dialect, "func.deduplicate_func_args",
+         [DeclareOpInterfaceMethods<TransformOpInterface>,
+          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+      This transform takes a module and a function name, and merges the
+      arguments of the function that receive the same value at its single call
+      site. The function is expected to be defined in the module.
+
+      This transform will emit a silenceable failure if:
+       - The function with the given name does not exist in the module.
+       - The call to the function does not have duplicate operands.
+       - The function does not have a single call.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$module,
+      SymbolRefAttr:$function_name);
+  let results = (outs TransformHandleTypeInterface:$transformed_module,
+                      TransformHandleTypeInterface:$transformed_function);
+
+  let assemblyFormat = [{
+    $function_name
+    `at` $module attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 2e8b6723a0e53..16d53a8ced9d5 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -27,21 +27,24 @@ class FuncOp;
 class CallOp;
 
 /// Creates a new function operation with the same name as the original
-/// function operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
+/// function operation, but with the arguments and results mapped according to
+/// the `oldArgIdxToNewArgIdx` and `oldResIdxToNewResIdx`.
 /// The `funcOp` operation must have exactly one block.
 /// Returns the new function operation or failure if `funcOp` doesn't
 /// have exactly one block.
-FailureOr<FuncOp>
-replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
-                        llvm::ArrayRef<unsigned> newArgsOrder,
-                        llvm::ArrayRef<unsigned> newResultsOrder);
+/// Note: the method asserts that `oldArgIdxToNewArgIdx` and
+/// `oldResIdxToNewResIdx` cover all of the function's arguments and results.
+mlir::FailureOr<mlir::func::FuncOp> replaceFuncWithNewMapping(
+    mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
+    ArrayRef<int> oldArgIdxToNewArgIdx, ArrayRef<int> oldResIdxToNewResIdx);
 /// Creates a new call operation with the values as the original
-/// call operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
-CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
-                                 llvm::ArrayRef<unsigned> newArgsOrder,
-                                 llvm::ArrayRef<unsigned> newResultsOrder);
+/// call operation, but with the operands and results mapped according to
+/// the `oldArgIdxToNewArgIdx` and `oldResIdxToNewResIdx`.
+/// Note: the method asserts that `oldArgIdxToNewArgIdx` and
+/// `oldResIdxToNewResIdx` cover all of the call's operands and results.
+mlir::func::CallOp replaceCallOpWithNewMapping(
+    mlir::RewriterBase &rewriter, mlir::func::CallOp callOp,
+    ArrayRef<int> oldArgIdxToNewArgIdx, ArrayRef<int> oldResIdxToNewResIdx);
 
 } // namespace func
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 935d3e5ac331b..86334c383d5b4 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 
@@ -296,9 +297,16 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
     }
   }
 
-  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
-      rewriter, funcOp, argsInterchange.getArrayRef(),
-      resultsInterchange.getArrayRef());
+  llvm::SmallVector<int> oldArgToNewArg(argsInterchange.size());
+  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(argsInterchange))
+    oldArgToNewArg[oldArgIdx] = newArgIdx;
+
+  llvm::SmallVector<int> oldResToNewRes(resultsInterchange.size());
+  for (auto [newResIdx, oldResIdx] : llvm::enumerate(resultsInterchange))
+    oldResToNewRes[oldResIdx] = newResIdx;
+
+  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+      rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
   if (failed(newFuncOpOrFailure))
     return emitSilenceableFailure(getLoc())
            << "failed to replace function signature '" << getFunctionName()
@@ -312,9 +320,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
     });
 
     for (func::CallOp callOp : callOps)
-      func::replaceCallOpWithNewOrder(rewriter, callOp,
-                                      argsInterchange.getArrayRef(),
-                                      resultsInterchange.getArrayRef());
+      func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg,
+                                        oldResToNewRes);
   }
 
   results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
@@ -330,6 +337,86 @@ void transform::ReplaceFuncSignatureOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// DeduplicateFuncArgsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitDefiniteFailure() << "requires a single module to operate on";
+
+  auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+  if (!targetModuleOp)
+    return emitSilenceableFailure(getLoc())
+           << "target is expected to be module operation";
+
+  func::FuncOp funcOp =
+      targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+  if (!funcOp)
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName() << "' is not found";
+
+  SmallVector<func::CallOp> callOps;
+  targetModuleOp.walk([&](func::CallOp callOp) {
+    if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
+      callOps.push_back(callOp);
+  });
+
+  // TODO: Support more than one callOp.
+  if (!llvm::hasSingleElement(callOps))
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName()
+           << "' does not have a single callOp";
+
+  func::CallOp callOp = callOps.front();
+
+  llvm::SmallVector<int> oldArgIdxToNewArgIdx(callOp.getNumOperands());
+  llvm::DenseMap<Value, unsigned> valueToNewArgIdx;
+  for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
+    auto [iterator, inserted] =
+        valueToNewArgIdx.insert({operand, valueToNewArgIdx.size()});
+    // Reduce the duplicate operands and maintain the original order.
+    oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
+  }
+
+  bool hasDuplicatesOperands =
+      valueToNewArgIdx.size() != callOp.getNumOperands();
+  if (!hasDuplicatesOperands)
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName()
+           << "' does not have duplicate operands";
+
+  llvm::SmallVector<int> oldResIdxToNewResIdx(callOp.getNumResults());
+  for (unsigned resultIdx = 0; resultIdx < callOp.getNumResults(); ++resultIdx)
+    oldResIdxToNewResIdx[resultIdx] = resultIdx;
+
+  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+      rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
+  if (failed(newFuncOpOrFailure))
+    return emitSilenceableFailure(getLoc())
+           << "failed to deduplicate function arguments '" << getFunctionName()
+           << "'";
+
+  func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgIdxToNewArgIdx,
+                                    oldResIdxToNewResIdx);
+
+  results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
+  results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::DeduplicateFuncArgsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getModuleMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index f781ed2d591b4..9841039487102 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -14,35 +14,98 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 
+/// This method creates an inverse mapping of the provided `oldIdxToNewIdx`.
+/// Given an array where `oldIdxToNewIdx[i] = j` means old index `i` maps
+/// to new index `j`,
+/// this method returns a vector where `result[j]` contains all old indices
+/// that map to new index `j`.
+///
+/// Example:
+/// ```
+/// oldIdxToNewIdx = [0, 1, 2, 2, 3]
+/// getInverseMapping(oldIdxToNewIdx) = [[0], [1], [2, 3], [4]]
+/// ```
+///
+static llvm::SmallVector<llvm::SmallVector<int>>
+getInverseMapping(ArrayRef<int> oldIdxToNewIdx) {
+  int numOfNewIdxs = 0;
+  if (!oldIdxToNewIdx.empty())
+    numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx);
+  llvm::SmallVector<llvm::SmallVector<int>> newToOldIdxs(numOfNewIdxs);
+  for (auto [oldIdx, newIdx] : llvm::enumerate(oldIdxToNewIdx))
+    newToOldIdxs[newIdx].push_back(oldIdx);
+  return newToOldIdxs;
+}
+
+/// This method returns a new vector of elements that are mapped from the
+/// `origElements` based on the `newIdxToOldIdxs` mapping. This function assumes
+/// that the `newIdxToOldIdxs` mapping is valid, i.e. for each new index, there
+/// is at least one old index that maps to it. It also assumes that all old
+/// indices mapping to the same new index hold equal elements in `origElements`.
+template <typename Element>
+static SmallVector<Element> getMappedElements(
+    ArrayRef<Element> origElements,
+    const llvm::SmallVector<llvm::SmallVector<int>> &newIdxToOldIdxs) {
+  SmallVector<Element> newElements;
+  for (const auto &oldIdxs : newIdxToOldIdxs) {
+    assert(llvm::all_of(oldIdxs,
+                        [&origElements](int idx) -> bool {
+                          return idx >= 0 &&
+                                 static_cast<size_t>(idx) < origElements.size();
+                        }) &&
+           "idx must be less than the number of elements in the original "
+           "elements");
+    assert(!oldIdxs.empty() && "oldIdx must not be empty");
+    Element origTypeToCheck = origElements[oldIdxs.front()];
+    assert(llvm::all_of(oldIdxs,
+                        [&](int idx) -> bool {
+                          return origElements[idx] == origTypeToCheck;
+                        }) &&
+           "all oldIdxs must be equal");
+    newElements.push_back(origTypeToCheck);
+  }
+  return newElements;
+}
+
 FailureOr<func::FuncOp>
-func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
-                              ArrayRef<unsigned> newArgsOrder,
-                              ArrayRef<unsigned> newResultsOrder) {
+func::replaceFuncWithNewMapping(RewriterBase &rewriter, func::FuncOp funcOp,
+                                ArrayRef<int> oldArgIdxToNewArgIdx,
+                                ArrayRef<int> oldResIdxToNewResIdx) {
   // Generate an empty new function operation with the same name as the
   // original.
-  assert(funcOp.getNumArguments() == newArgsOrder.size() &&
-         "newArgsOrder must match the number of arguments in the function");
-  assert(funcOp.getNumResults() == newResultsOrder.size() &&
-         "newResultsOrder must match the number of results in the function");
+  assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
+         "oldArgIdxToNewArgIdx must match the number of arguments in the "
+         "function");
+  assert(
+      funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
+      "oldResIdxToNewResIdx must match the number of results in the function");
 
   if (!funcOp.getBody().hasOneBlock())
     return rewriter.notifyMatchFailure(
         funcOp, "expected function to have exactly one block");
 
-  ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
-  ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
-  SmallVector<Type> newInputTypes, newOutputTypes;
+  // We may have some duplicate arguments in the old function, i.e.
+  // in the mapping `newArgIdxToOldArgIdxs` for some new argument index
+  // there may be multiple old argument indices.
+  llvm::SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
+      getInverseMapping(oldArgIdxToNewArgIdx);
+  SmallVector<Type> newInputTypes = getMappedElements(
+      funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);
+
   SmallVector<Location> locs;
-  for (unsigned int idx : newArgsOrder) {
-    newInputTypes.push_back(origInputTypes[idx]);
-    locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
-  }
-  for (unsigned int idx : newResultsOrder)
-    newOutputTypes.push_back(origOutputTypes[idx]);
+  for (const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
+    locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
+
+  llvm::SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
+      getInverseMapping(oldResIdxToNewResIdx);
+  SmallVector<Type> newOutputTypes = getMappedElements(
+      funcOp.getFunctionType().getResults(), newResToOldResIdxs);
+
   rewriter.setInsertionPoint(funcOp);
   auto newFuncOp = func::FuncOp::create(
       rewriter, funcOp.getLoc(), funcOp.getName(),
@@ -51,21 +114,21 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   Region &newRegion = newFuncOp.getBody();
   rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
   newFuncOp.setVisibility(funcOp.getVisibility());
-  newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
 
   // Map the arguments of the original function to the new function in
   // the new order and adjust the attributes accordingly.
   IRMapping operandMapper;
   SmallVector<DictionaryAttr> argAttrs, resultAttrs;
   funcOp.getAllArgAttrs(argAttrs);
-  for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
-    operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
-                      newFuncOp.getArgument(i));
-    newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
-  }
+  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgIdxToNewArgIdx))
+    operandMapper.map(funcOp.getArgument(oldArgIdx),
+                      newFuncOp.getArgument(newArgIdx));
+  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs))
+    newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
+
   funcOp.getAllResultAttrs(resultAttrs);
-  for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
-    newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
+  for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs))
+    newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
 
   // Clone the operations from the original function to the new function.
   rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
@@ -76,12 +139,11 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   auto returnOp = cast<func::ReturnOp>(
       newFuncOp.getFunctionBody().begin()->getTerminator());
   SmallVector<Value> newReturnValues;
-  for (unsigned int idx : newResultsOrder)
-    newReturnValues.push_back(returnOp.getOperand(idx));
+  for (const auto &oldResIdxs : newResToOldResIdxs)
+    newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));
+
   rewriter.setInsertionPoint(returnOp);
-  auto newReturnOp =
-      func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
-  newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
+  func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
   rewriter.eraseOp(returnOp);
 
   rewriter.eraseOp(funcOp);
@@ -90,32 +152,37 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
 }
 
 func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
-                                ArrayRef<unsigned> newArgsOrder,
-                                ArrayRef<unsigned> newResultsOrder) {
-  assert(
-      callOp.getNumOperands() == newArgsOrder.size() &&
-      "newArgsOrder must match the number of operands in the call operation");
-  assert(
-      callOp.getNumResults() == newResultsOrder.size() &&
-      "newResultsOrder must match the number of results in the call operation");
-  SmallVector<Value> newArgsOrderValues;
-  for (unsigned int argIdx : newArgsOrder)
-    newArgsOrderValues.push_back(callOp.getOperand(argIdx));
-  SmallVector<Type> newResultTypes;
-  for (unsigned int resIdx : newResultsOrder)
-    newResultTypes.push_back(callOp.getResult(resIdx).getType());
+func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
+                                  ArrayRef<int> oldArgIdxToNewArgIdx,
+                                  ArrayRef<int> oldResIdxToNewResIdx) {
+  assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
+         "oldArgIdxToNewArgIdx must match the number of operands in the call "
+         "operation");
+  assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
+         "oldResIdxToNewResIdx must match the number of results in the call "
+         "operation");
+
+  SmallVector<Value> origOperands = callOp.getOperands();
+  SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
+      getInverseMapping(oldArgIdxToNewArgIdx);
+  SmallVector<Value> newOperandsValues =
+      getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);
+  SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
+      getInverseMapping(oldResIdxToNewResIdx);
+  SmallVector<Type> origResultTypes = llvm::to_vector(callOp.getResultTypes());
+  SmallVector<Type> newResultTypes =
+      getMappedElements<Type>(origResultTypes, newResToOldResIdxs);
 
   // Replace the kernel call operation with a new one that has the
-  // reordered arguments.
+  // mapped arguments.
   rewriter.setInsertionPoint(callOp);
   auto newCallOp =
       func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
-                           newResultTypes, newArgsOrderValues);
+                           newResultTypes, newOperandsValues);
   newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
-  for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
-    rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
-                                newCallOp.getResult(newIndex));
+  for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx))
+    rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
+                                newCallOp.getResult(newResIdx));
   rewriter.eraseOp(callOp);
 
   return newCallOp;
diff --git a/mlir/test/Dialect/Func/func-transform-invalid.mlir b/mlir/test/Dialect/Func/func-transform-invalid.mlir
index e712eee83f36e..77ed9307e4387 100644
--- a/mlir/test/Dialect/Func/func-transform-invalid.mlir
+++ b/mlir/test/Dialect/Func/func-transform-invalid.mlir
@@ -85,3 +85,92 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func private @func_with_no_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  return
+}
+
+func.func @func_with_no_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  call @func_with_no_duplicate_args(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@func_with_no_duplicate_args' does not have duplicate operands}}
+    transform.func.deduplicate_func_args @func_with_no_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_not_found(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@non_existent_func' is not found}}
+    transform.func.deduplicate_func_args @non_existent_func at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_with_multiple_calls(%arg0: memref<1xi8, 1>, %arg1: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+func.func @func_with_multiple_calls_caller1(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  call @func_with_multiple_calls(%arg0, %arg0) : (memref<1xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+func.func @func_with_multiple_calls_caller2(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  call @func_with_multiple_calls(%arg0, %arg0) : (memref<1xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@func_with_multiple_calls' does not have a single callOp}}
+    transform.func.deduplicate_func_args @func_with_multiple_calls at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @func_with_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+func.func @some_other_func() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // expected-error @+1 {{function with name '@func_with_no_calls' does not have a single callOp}}
+    transform.func.deduplicate_func_args @func_with_no_calls at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
index 36a66aaa95bfb..8a71511e3ed5b 100644
--- a/mlir/test/Dialect/Func/func-transform.mlir
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -250,3 +250,65 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK:           func.func private @func_with_duplicate_args(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>) {
+func.func private @func_with_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<1xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  // CHECK:             %[[VAL_3:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0:.*]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[VAL_4:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             %[[VAL_5:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view1 = memref.view %arg2[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  return
+}
+
+// CHECK:           func.func @func_with_duplicate_args_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>) {
+func.func @func_with_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) {
+  // CHECK:             call @func_with_duplicate_args(%[[ARG0]], %[[ARG1]]) : (memref<1xi8, 1>, memref<2xi8, 1>) -> ()
+  call @func_with_duplicate_args(%arg0, %arg1, %arg0) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.func.deduplicate_func_args @func_with_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK:           func.func private @func_with_complex_duplicate_args(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+func.func private @func_with_complex_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<1xi8, 1>, %arg3: memref<3xi8, 1>, %arg4: memref<2xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+  %c0 = arith.constant 0 : index
+  // CHECK:             %[[RET_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0:.*]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view0 = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[RET_1:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view1 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             %[[RET_2:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+  %view2 = memref.view %arg2[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+  // CHECK:             %[[RET_3:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+  %view3 = memref.view %arg3[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+  // CHECK:             %[[RET_4:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+  %view4 = memref.view %arg4[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+  // CHECK:             return %[[RET_0]], %[[RET_1]], %[[RET_2]], %[[RET_3]], %[[RET_4]] : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+  return %view0, %view1, %view2, %view3, %view4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+}
+
+// CHECK:           func.func @func_with_complex_duplicate_args_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+func.func @func_with_complex_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) {
+  // CHECK:             %[[RET:.*]]:5 = call @func_with_complex_duplicate_args(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>)
+  %0:5 = call @func_with_complex_duplicate_args(%arg0, %arg1, %arg0, %arg2, %arg1) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>)
+  // CHECK:             return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2, %[[RET]]#3, %[[RET]]#4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+  return %0#0, %0#1, %0#2, %0#3, %0#4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.func.deduplicate_func_args @func_with_complex_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
    
    
More information about the Mlir-commits
mailing list