[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)

Mehdi Amini llvmlistbot at llvm.org
Fri Sep 12 09:33:31 PDT 2025


================
@@ -90,32 +144,86 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
 }
 
 func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
-                                ArrayRef<unsigned> newArgsOrder,
-                                ArrayRef<unsigned> newResultsOrder) {
+func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
+                                  llvm::ArrayRef<unsigned> oldArgToNewArg,
+                                  llvm::ArrayRef<unsigned> oldResToNewRes) {
   assert(
-      callOp.getNumOperands() == newArgsOrder.size() &&
-      "newArgsOrder must match the number of operands in the call operation");
+      callOp.getNumOperands() == oldArgToNewArg.size() &&
+      "oldArgToNewArg must match the number of operands in the call operation");
   assert(
-      callOp.getNumResults() == newResultsOrder.size() &&
-      "newResultsOrder must match the number of results in the call operation");
+      callOp.getNumResults() == oldResToNewRes.size() &&
+      "oldResToNewRes must match the number of results in the call operation");
+
+  // Inverse mapping from new arguments to old arguments.
+  unsigned numOfNewArgs = 0;
+  auto maxNewArgIdx = llvm::max_element(oldArgToNewArg);
+  if (maxNewArgIdx != oldArgToNewArg.end())
+    numOfNewArgs = *maxNewArgIdx + 1;
+  llvm::SmallVector<llvm::SmallVector<unsigned>> newArgIdxToOldArgIdxs(
+      numOfNewArgs);
+  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgToNewArg))
+    newArgIdxToOldArgIdxs[newArgIdx].push_back(oldArgIdx);
+
   SmallVector<Value> newArgsOrderValues;
-  for (unsigned int argIdx : newArgsOrder)
-    newArgsOrderValues.push_back(callOp.getOperand(argIdx));
+  for (const auto &[newArgIdx, oldArgIdxs] :
+       llvm::enumerate(newArgIdxToOldArgIdxs)) {
+    std::ignore = newArgIdx;
+    assert(
+        llvm::all_of(oldArgIdxs,
+                     [&callOp](unsigned idx) -> bool {
+                       return idx < callOp.getNumOperands();
+                     }) &&
+        "idx must be less than the number of operands in the call operation");
+    assert(!oldArgIdxs.empty() && "oldArgIdx must not be empty");
+    Value origOperandToCheck = callOp.getOperand(oldArgIdxs.front());
+    assert(llvm::all_of(oldArgIdxs,
+                        [&](unsigned idx) -> bool {
+                          return callOp.getOperand(idx).getType() ==
+                                 origOperandToCheck.getType();
+                        }) &&
+           "all oldArgIdx must have the same type");
+    newArgsOrderValues.push_back(origOperandToCheck);
+  }
+
+  unsigned numOfNewRes = 0;
+  auto maxNewResIdx = llvm::max_element(oldResToNewRes);
+  if (maxNewResIdx != oldResToNewRes.end())
+    numOfNewRes = *maxNewResIdx + 1;
+  llvm::SmallVector<llvm::SmallVector<unsigned>> newResIdxToOldResIdxs(
+      numOfNewRes);
+  for (auto [oldResIdx, newResIdx] : llvm::enumerate(oldResToNewRes))
+    newResIdxToOldResIdxs[newResIdx].push_back(oldResIdx);
+
   SmallVector<Type> newResultTypes;
-  for (unsigned int resIdx : newResultsOrder)
-    newResultTypes.push_back(callOp.getResult(resIdx).getType());
+  for (auto [newResIdx, oldResIdxs] : llvm::enumerate(newResIdxToOldResIdxs)) {
+    std::ignore = newResIdx;
+    assert(llvm::all_of(oldResIdxs,
+                        [&callOp](unsigned idx) -> bool {
+                          return idx < callOp.getNumResults();
+                        }) &&
+           "idx must be less than the number of results in the call operation");
+    assert(!oldResIdxs.empty() && "oldResIdx must not be empty");
+    Value origResultToCheck = callOp.getResult(oldResIdxs.front());
+    assert(llvm::all_of(oldResIdxs,
+                        [&](unsigned idx) -> bool {
+                          return callOp.getResult(idx).getType() ==
+                                 origResultToCheck.getType();
+                        }) &&
+           "all oldResIdx must have the same type");
+    newResultTypes.push_back(origResultToCheck.getType());
+  }
----------------
joker-eph wrote:

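The process for results seems oddly similar to the process for operands (building the inverse new-to-old index mapping, then validating and picking one representative per new index); can this be refactored into a shared helper?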

https://github.com/llvm/llvm-project/pull/158266


More information about the Mlir-commits mailing list