[Mlir-commits] [mlir] [mlir][func]-Add deduplicate funcOp arguments transform (PR #158266)
Amir Bishara
llvmlistbot at llvm.org
Sat Sep 13 11:01:51 PDT 2025
================
@@ -330,6 +337,89 @@ void transform::ReplaceFuncSignatureOp::getEffects(
transform::modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// DeduplicateFuncArgsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto payloadOps = state.getPayloadOps(getModule());
+ if (!llvm::hasSingleElement(payloadOps))
+ return emitDefiniteFailure() << "requires a single module to operate on";
+
+ auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+ if (!targetModuleOp)
+ return emitSilenceableFailure(getLoc())
+ << "target is expected to be module operation";
+
+ func::FuncOp funcOp =
+ targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+ if (!funcOp)
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName() << "' is not found";
+
+ SmallVector<func::CallOp> callOps;
+ targetModuleOp.walk([&](func::CallOp callOp) {
+ if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
+ callOps.push_back(callOp);
+ });
+
+ // TODO: Support more than one callOp.
+ if (!llvm::hasSingleElement(callOps))
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName()
+ << "' does not have a single callOp";
+
+ llvm::DenseSet<Value> seenValues;
+ func::CallOp callOp = callOps.front();
+ bool hasDuplicatesOperands =
+ llvm::any_of(callOp.getOperands(), [&seenValues](Value operand) {
+ return !seenValues.insert(operand).second;
+ });
+
+ if (!hasDuplicatesOperands)
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName()
+ << "' does not have duplicate operands";
+
+ llvm::SmallVector<unsigned> oldArgIdxToNewArgIdx(callOp.getNumOperands());
+ llvm::DenseMap<Value, unsigned> valueToNewArgIdx;
+ for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
+ if (!valueToNewArgIdx.count(operand))
+ valueToNewArgIdx[operand] = valueToNewArgIdx.size();
+ // Reduce the duplicate operands and maintain the original order.
+ oldArgIdxToNewArgIdx[operandIdx] = valueToNewArgIdx[operand];
+ }
+
+ llvm::SmallVector<unsigned> oldResIdxToNewResIdx(callOp.getNumResults());
+ for (unsigned resultIdx = 0; resultIdx < callOp.getNumResults(); ++resultIdx)
+ oldResIdxToNewResIdx[resultIdx] = resultIdx;
+
+ FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+ rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
+ if (failed(newFuncOpOrFailure))
+ return emitSilenceableFailure(getLoc())
+ << "failed to deduplicate function arguments '" << getFunctionName()
+ << "'";
+
+ func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgIdxToNewArgIdx,
+ oldResIdxToNewResIdx);
----------------
amirBish wrote:
Oh, could you please share more of your thoughts about this issue? It sounds like an interesting transformation, which could also be used by the transform dialect through a transform op. Maybe I haven't fully understood your point of view.
https://github.com/llvm/llvm-project/pull/158266
More information about the Mlir-commits
mailing list