[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)
Amir Bishara
llvmlistbot at llvm.org
Sat Sep 13 11:01:50 PDT 2025
================
@@ -90,32 +144,86 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
}
func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
- ArrayRef<unsigned> newArgsOrder,
- ArrayRef<unsigned> newResultsOrder) {
+func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
+ llvm::ArrayRef<unsigned> oldArgToNewArg,
+ llvm::ArrayRef<unsigned> oldResToNewRes) {
assert(
- callOp.getNumOperands() == newArgsOrder.size() &&
- "newArgsOrder must match the number of operands in the call operation");
+ callOp.getNumOperands() == oldArgToNewArg.size() &&
+ "oldArgToNewArg must match the number of operands in the call operation");
assert(
- callOp.getNumResults() == newResultsOrder.size() &&
- "newResultsOrder must match the number of results in the call operation");
+ callOp.getNumResults() == oldResToNewRes.size() &&
+ "oldResToNewRes must match the number of results in the call operation");
+
+ // Inverse mapping from new arguments to old arguments.
+ unsigned numOfNewArgs = 0;
+ auto maxNewArgIdx = llvm::max_element(oldArgToNewArg);
+ if (maxNewArgIdx != oldArgToNewArg.end())
+ numOfNewArgs = *maxNewArgIdx + 1;
+ llvm::SmallVector<llvm::SmallVector<unsigned>> newArgIdxToOldArgIdxs(
+ numOfNewArgs);
+ for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgToNewArg))
+ newArgIdxToOldArgIdxs[newArgIdx].push_back(oldArgIdx);
+
SmallVector<Value> newArgsOrderValues;
- for (unsigned int argIdx : newArgsOrder)
- newArgsOrderValues.push_back(callOp.getOperand(argIdx));
+ for (const auto &[newArgIdx, oldArgIdxs] :
+ llvm::enumerate(newArgIdxToOldArgIdxs)) {
+ std::ignore = newArgIdx;
+ assert(
+ llvm::all_of(oldArgIdxs,
+ [&callOp](unsigned idx) -> bool {
+ return idx < callOp.getNumOperands();
+ }) &&
+ "idx must be less than the number of operands in the call operation");
+ assert(!oldArgIdxs.empty() && "oldArgIdx must not be empty");
+ Value origOperandToCheck = callOp.getOperand(oldArgIdxs.front());
+ assert(llvm::all_of(oldArgIdxs,
+ [&](unsigned idx) -> bool {
+ return callOp.getOperand(idx).getType() ==
+ origOperandToCheck.getType();
+ }) &&
+ "all oldArgIdx must have the same type");
+ newArgsOrderValues.push_back(origOperandToCheck);
+ }
+
+ unsigned numOfNewRes = 0;
+ auto maxNewResIdx = llvm::max_element(oldResToNewRes);
+ if (maxNewResIdx != oldResToNewRes.end())
+ numOfNewRes = *maxNewResIdx + 1;
+ llvm::SmallVector<llvm::SmallVector<unsigned>> newResIdxToOldResIdxs(
+ numOfNewRes);
+ for (auto [oldResIdx, newResIdx] : llvm::enumerate(oldResToNewRes))
+ newResIdxToOldResIdxs[newResIdx].push_back(oldResIdx);
+
SmallVector<Type> newResultTypes;
- for (unsigned int resIdx : newResultsOrder)
- newResultTypes.push_back(callOp.getResult(resIdx).getType());
+ for (auto [newResIdx, oldResIdxs] : llvm::enumerate(newResIdxToOldResIdxs)) {
+ std::ignore = newResIdx;
+ assert(llvm::all_of(oldResIdxs,
+ [&callOp](unsigned idx) -> bool {
+ return idx < callOp.getNumResults();
+ }) &&
+ "idx must be less than the number of results in the call operation");
+ assert(!oldResIdxs.empty() && "oldResIdx must not be empty");
+ Value origResultToCheck = callOp.getResult(oldResIdxs.front());
+ assert(llvm::all_of(oldResIdxs,
+ [&](unsigned idx) -> bool {
+ return callOp.getResult(idx).getType() ==
+ origResultToCheck.getType();
+ }) &&
+ "all oldResIdx must have the same type");
+ newResultTypes.push_back(origResultToCheck.getType());
+ }
----------------
amirBish wrote:
I have added two static methods: one that creates the inverse mapping and one that collects the new mapped elements based on it. This removes the duplicated code.
https://github.com/llvm/llvm-project/pull/158266
More information about the Mlir-commits
mailing list