[Mlir-commits] [mlir] Allowing RDV to call `getArgOperandsMutable()` (PR #160415)
Mehdi Amini
llvmlistbot at llvm.org
Wed Sep 24 06:57:03 PDT 2025
================
@@ -754,17 +756,51 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
- (void)f.funcOp.eraseArguments(f.nonLiveArgs);
+ if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
+ // Record only if we actually erased something.
+ if (f.nonLiveArgs.any())
+ erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
+ }
(void)f.funcOp.eraseResults(f.nonLiveRets);
}
// 4. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
- if (o.op->getNumOperands() > 0) {
- LDBG() << "Erasing " << o.nonLive.count()
- << " non-live operands from operation: "
- << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
+ if (auto call = dyn_cast<CallOpInterface>(o.op)) {
+ if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
+ Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym);
+ auto it = erasedFuncArgs.find(callee);
+ if (it != erasedFuncArgs.end()) {
+ const BitVector &deadArgIdxs = it->second;
+ MutableOperandRange args = call.getArgOperandsMutable();
+ // First, erase the call arguments corresponding to erased callee args.
+ for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
+ if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
+ args.erase(i);
+ }
+ // If this operand cleanup entry also has a generic nonLive bitvector,
+ // clear bits for call arguments we already erased above to avoid
+ // double-erasing (which could impact other segments of ops with
+ // AttrSizedOperandSegments).
+ if (o.nonLive.any()) {
+ // Map the argument logical index to the operand number(s) recorded.
+ SmallVector<OpOperand *> callOperands =
+ operandsToOpOperands(call.getArgOperands());
+ for (int argIdx : deadArgIdxs.set_bits()) {
+ if (argIdx < static_cast<int>(callOperands.size())) {
+ unsigned operandNumber = callOperands[argIdx]->getOperandNumber();
+ if (operandNumber < o.nonLive.size())
+ o.nonLive.reset(operandNumber);
+ }
+ }
+ }
----------------
joker-eph wrote:
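I may be missing something, but this seems simpler: compute each operand number directly as the argument range's begin operand index plus the argument index, instead of materializing the list of `OpOperand`s first: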
```suggestion
if (o.nonLive.any()) {
// Map the argument logical index to the operand number(s) recorded.
int operandOffset = call.getArgOperands().getBeginOperandIndex();
for (int argIdx : deadArgIdxs.set_bits()) {
int operandNumber = operandOffset + argIdx;
if (operandNumber < o.nonLive.size())
o.nonLive.reset(operandNumber);
}
}
```
https://github.com/llvm/llvm-project/pull/160415
More information about the Mlir-commits mailing list