[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)

Amir Bishara llvmlistbot at llvm.org
Sat Sep 13 11:01:51 PDT 2025


================
@@ -37,12 +38,62 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
   SmallVector<Type> newInputTypes, newOutputTypes;
   SmallVector<Location> locs;
-  for (unsigned int idx : newArgsOrder) {
-    newInputTypes.push_back(origInputTypes[idx]);
-    locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
+
+  // We may have some duplicate arguments in the old function, i.e.
+  // in the mapping `newArgIdxToOldArgIdxs` for some new argument index
+  // there may be multiple old argument indices.
+  unsigned numOfNewArgs = 0;
+  auto maxNewArgIdx = llvm::max_element(oldArgToNewArg);
+  if (maxNewArgIdx != oldArgToNewArg.end())
+    numOfNewArgs = *maxNewArgIdx + 1;
+  llvm::SmallVector<llvm::SmallVector<unsigned>> newArgIdxToOldArgIdxs(
+      numOfNewArgs);
+  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgToNewArg))
+    newArgIdxToOldArgIdxs[newArgIdx].push_back(oldArgIdx);
+
+  for (auto [newArgIdx, oldArgIdxs] : llvm::enumerate(newArgIdxToOldArgIdxs)) {
+    std::ignore = newArgIdx;
+    assert(llvm::all_of(oldArgIdxs,
+                        [&funcOp](unsigned idx) -> bool {
+                          return idx < funcOp.getNumArguments();
+                        }) &&
+           "idx must be less than the number of arguments in the function");
+    assert(!oldArgIdxs.empty() && "oldArgIdxs must not be empty");
+    Type origInputTypeToCheck = origInputTypes[oldArgIdxs.front()];
+    assert(llvm::all_of(oldArgIdxs,
+                        [&](unsigned idx) -> bool {
+                          return origInputTypes[idx] == origInputTypeToCheck;
+                        }) &&
+           "all oldArgIdx must have the same type");
+    newInputTypes.push_back(origInputTypeToCheck);
+    locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
+  }
+
+  unsigned numOfNewRes = 0;
+  auto maxNewResIdx = llvm::max_element(oldResToNewRes);
+  if (maxNewResIdx != oldResToNewRes.end())
+    numOfNewRes = *maxNewResIdx + 1;
+  llvm::SmallVector<llvm::SmallVector<unsigned>> newResToOldResIdxs(
----------------
amirBish wrote:

Sure, changed. Still, could you please elaborate on why you prefer `int` over `unsigned` in this context (given that we're dealing with indices)?

https://github.com/llvm/llvm-project/pull/158266


More information about the Mlir-commits mailing list