[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)
Amir Bishara
llvmlistbot at llvm.org
Sun Sep 14 13:20:38 PDT 2025
================
@@ -90,33 +152,105 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
}
func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
- ArrayRef<unsigned> newArgsOrder,
- ArrayRef<unsigned> newResultsOrder) {
- assert(
- callOp.getNumOperands() == newArgsOrder.size() &&
- "newArgsOrder must match the number of operands in the call operation");
- assert(
- callOp.getNumResults() == newResultsOrder.size() &&
- "newResultsOrder must match the number of results in the call operation");
- SmallVector<Value> newArgsOrderValues;
- for (unsigned int argIdx : newArgsOrder)
- newArgsOrderValues.push_back(callOp.getOperand(argIdx));
- SmallVector<Type> newResultTypes;
- for (unsigned int resIdx : newResultsOrder)
- newResultTypes.push_back(callOp.getResult(resIdx).getType());
+func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
+ ArrayRef<int> oldArgIdxToNewArgIdx,
+ ArrayRef<int> oldResIdxToNewResIdx) { // Both maps send an old index to a new index and must have one entry per operand/result of callOp (asserted below).
+ assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
+ "oldArgIdxToNewArgIdx must match the number of operands in the call "
+ "operation");
+ assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
+ "oldResIdxToNewResIdx must match the number of results in the call "
+ "operation");
+ // Invert the old->new maps (new index -> list of old indices), then gather operands/result types per new index. NOTE(review): confirm which old entry getMappedElements picks when several share one new index.
+ SmallVector<Value> origOperands = callOp.getOperands();
+ SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
+ getInverseMapping(oldArgIdxToNewArgIdx);
+ SmallVector<Value> newOperandsValues =
+ getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);
+ SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
+ getInverseMapping(oldResIdxToNewResIdx);
+ SmallVector<Type> origResultTypes = llvm::to_vector(callOp.getResultTypes());
+ SmallVector<Type> newResultTypes =
+ getMappedElements<Type>(origResultTypes, newResToOldResIdxs);
// Replace the kernel call operation with a new one that has the
- // reordered arguments.
+ // mapped operands and result types.
rewriter.setInsertionPoint(callOp);
auto newCallOp =
func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
- newResultTypes, newArgsOrderValues);
+ newResultTypes, newOperandsValues);
newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
- for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
- rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
- newCallOp.getResult(newIndex));
+ for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx)) // Redirect every old result's uses to the new result it maps to; several old results may share one new result.
+ rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
+ newCallOp.getResult(newResIdx));
rewriter.eraseOp(callOp);
return newCallOp;
}
+
+FailureOr<std::pair<func::FuncOp, func::CallOp>>
+func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
+ ModuleOp moduleOp, std::string &errorMessage) {
+ SmallVector<func::CallOp> callOps;
+ auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
+ if (callOp.getCallee() == funcOp.getSymName()) {
+ if (!callOps.empty())
+ // Only support one callOp for now
+ return WalkResult::interrupt();
+ callOps.push_back(callOp);
+ }
+ return WalkResult::advance();
+ });
+
+ if (traversalResult.wasInterrupted()) {
+ errorMessage = "function with name '" + funcOp.getSymName().str() +
+ "' has more than one callOp";
+ return failure();
+ }
+
+ if (callOps.empty()) {
+ errorMessage = "function with name '" + funcOp.getSymName().str() +
+ "' does not have any callOp";
+ return failure();
+ }
+
+ func::CallOp callOp = callOps.front();
+
+ // Create mapping for arguments (deduplicate operands)
+ SmallVector<int> oldArgIdxToNewArgIdx(callOp.getNumOperands());
+ llvm::DenseMap<Value, int> valueToNewArgIdx;
+ for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
+ auto [iterator, inserted] = valueToNewArgIdx.insert(
+ {operand, static_cast<int>(valueToNewArgIdx.size())});
+ // Reduce the duplicate operands and maintain the original order.
+ oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
+ }
+
+ bool hasDuplicateOperands =
+ valueToNewArgIdx.size() != callOp.getNumOperands();
+ if (!hasDuplicateOperands) {
+ errorMessage = "function with name '" + funcOp.getSymName().str() +
+ "' does not have duplicate operands";
+ return failure();
----------------
amirBish wrote:
Replaced with `LDBG()`.
https://github.com/llvm/llvm-project/pull/158266
More information about the Mlir-commits
mailing list