[Mlir-commits] [mlir] Allowing RDV to call `getArgOperandsMutable()` (PR #160415)
Francisco Geiman Thiesen
llvmlistbot at llvm.org
Wed Sep 24 13:56:50 PDT 2025
https://github.com/FranciscoThiesen updated https://github.com/llvm/llvm-project/pull/160415
>From 19a9c64d49906845faed6e81effed356a0020ec1 Mon Sep 17 00:00:00 2001
From: Francisco Geiman Thiesen <franciscoge at microsoft.com>
Date: Fri, 12 Sep 2025 22:19:46 +0000
Subject: [PATCH 1/2] Make RDV erase call-site arguments segment-aware and add a
small repro case for a call op with the AttrSizedOperandSegments trait
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 64 +++++++++++++++----
.../remove-dead-values-call-segments.mlir | 23 +++++++
mlir/test/lib/Dialect/Test/TestDialect.cpp | 44 +++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 43 +++++++++++++
4 files changed, 160 insertions(+), 14 deletions(-)
create mode 100644 mlir/test/Transforms/remove-dead-values-call-segments.mlir
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0e84b6dd17f29..0655adaad5f5f 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -306,19 +306,17 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
nonLiveSet.insert(arg);
}
- // Do (2).
+ // Do (2). (Skip creating generic operand cleanup entries for call ops.
+ // Call arguments will be removed in the call-site specific segment-aware
+ // cleanup, avoiding generic eraseOperands bitvector mechanics.)
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- // The number of operands in the call op may not match the number of
- // arguments in the func op.
- BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
- SmallVector<OpOperand *> callOpOperands =
- operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
- for (int index : nonLiveArgs.set_bits())
- nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
- cl.operands.push_back({callOp, nonLiveCallOperands});
+ // Push an empty operand cleanup entry so that call-site specific logic in
+ // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
+ // intentionally all false to avoid generic erasure.
+ cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)});
}
// Do (3).
@@ -746,6 +744,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
// 3. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
+ // Record which function arguments were erased so we can shrink call-site
+ // argument segments for CallOpInterface operations (e.g. ops using
+ // AttrSizedOperandSegments) in the next phase.
+ DenseMap<Operation *, BitVector> erasedFuncArgs;
for (auto &f : list.functions) {
LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
@@ -754,17 +756,51 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
- (void)f.funcOp.eraseArguments(f.nonLiveArgs);
+ if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
+ // Record only if we actually erased something.
+ if (f.nonLiveArgs.any())
+ erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
+ }
(void)f.funcOp.eraseResults(f.nonLiveRets);
}
// 4. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
- if (o.op->getNumOperands() > 0) {
- LDBG() << "Erasing " << o.nonLive.count()
- << " non-live operands from operation: "
- << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
+ if (auto call = dyn_cast<CallOpInterface>(o.op)) {
+ if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
+ Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym);
+ auto it = erasedFuncArgs.find(callee);
+ if (it != erasedFuncArgs.end()) {
+ const BitVector &deadArgIdxs = it->second;
+ MutableOperandRange args = call.getArgOperandsMutable();
+ // First, erase the call arguments corresponding to erased callee args.
+ for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
+ if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
+ args.erase(i);
+ }
+ // If this operand cleanup entry also has a generic nonLive bitvector,
+ // clear bits for call arguments we already erased above to avoid
+ // double-erasing (which could impact other segments of ops with
+ // AttrSizedOperandSegments).
+ if (o.nonLive.any()) {
+ // Map the argument logical index to the operand number(s) recorded.
+ SmallVector<OpOperand *> callOperands =
+ operandsToOpOperands(call.getArgOperands());
+ for (int argIdx : deadArgIdxs.set_bits()) {
+ if (argIdx < static_cast<int>(callOperands.size())) {
+ unsigned operandNumber = callOperands[argIdx]->getOperandNumber();
+ if (operandNumber < o.nonLive.size())
+ o.nonLive.reset(operandNumber);
+ }
+ }
+ }
+ }
+ }
+ }
+ // Only perform generic operand erasure for non-call ops; for call ops we
+ // already handled argument removals via the segment-aware path above.
+ if (!isa<CallOpInterface>(o.op) && o.nonLive.any()) {
o.op->eraseOperands(o.nonLive);
}
}
diff --git a/mlir/test/Transforms/remove-dead-values-call-segments.mlir b/mlir/test/Transforms/remove-dead-values-call-segments.mlir
new file mode 100644
index 0000000000000..fed9cabbd2ee8
--- /dev/null
+++ b/mlir/test/Transforms/remove-dead-values-call-segments.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN
+
+// -----
+// Private callee: %y is never used and %x only feeds an unused arith.addi,
+// so both arguments are dead. RDV erases both callee arguments and shrinks
+// the call site's `args` segment from 2 operands to 0, updating
+// operandSegmentSizes to match.
+
+module {
+ func.func private @callee(%x: i32, %y: i32) {
+ %u = arith.addi %x, %x : i32 // %u is unused, so %x is dead too; %y is unused
+ return
+ }
+
+ func.func @caller(%a: i32, %b: i32) {
+ // The args segment initially has 2 operands; prefix and suffix are empty.
+ "test.call_with_segments"(%a, %b) { callee = @callee,
+ operandSegmentSizes = array<i32: 0, 2, 0> } : (i32, i32) -> ()
+ return
+ }
+}
+
+// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> ()
+// ^ args shrank from 2 -> 0
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 987e8f3654ce8..5016ab6b94cdb 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add(&dialectCanonicalizationPattern);
}
+
+//===----------------------------------------------------------------------===//
+// TestCallWithSegmentsOp
+//===----------------------------------------------------------------------===//
+// The op `test.call_with_segments` models a call-like operation whose operands
+// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
+// Only the middle segment represents the actual call arguments. The op uses
+// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
+// the generated `operandSegmentSizes` attribute. We provide custom helpers to
+// expose the logical call arguments as both a read-only range and a mutable
+// range bound to the proper segment so that insertion/erasure updates the
+// attribute automatically.
+
+// Position of the `args` entry in operandSegmentSizes, whose entries are
+// ordered [prefix, args, suffix].
+static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;
+
+Operation::operand_range CallWithSegmentsOp::getArgOperands() {
+  // Only the middle `args` group holds call arguments. The ODS-generated
+  // getter for that group already offsets past `prefix` using the current
+  // operandSegmentSizes, so simply forward to it.
+  return getArgs();
+}
+
+MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
+  // Returns the `args` group as a range bound to its operandSegmentSizes
+  // entry, so inserting or erasing through it also updates that entry.
+  Operation *op = getOperation();
+  auto segName =
+      CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName());
+  auto sizesAttr = op->getAttrOfType<DenseI32ArrayAttr>(segName);
+  assert(sizesAttr && "missing operandSegmentSizes attribute on op");
+
+  // Derive both the start and the length from the same segment index that
+  // the binding below uses, instead of hard-coding positions 0 and 1: the
+  // `args` group begins after all preceding groups.
+  constexpr unsigned argsIdx = kTestCallWithSegmentsArgsSegIndex;
+  ArrayRef<int32_t> sizes = sizesAttr.asArrayRef();
+  assert(sizes.size() > argsIdx && "no size entry for the args segment");
+  unsigned start = 0;
+  for (int32_t size : sizes.take_front(argsIdx))
+    start += static_cast<unsigned>(size);
+  unsigned len = static_cast<unsigned>(sizes[argsIdx]);
+
+  MutableOperandRange::OperandSegment binding{
+      argsIdx, NamedAttribute(segName, sizesAttr)};
+  return MutableOperandRange(op, start, len, {binding});
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5564264ed8b0b..a459385129909 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3745,4 +3745,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
}];
}
+def CallWithSegmentsOp : TEST_Op<"call_with_segments",
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<CallOpInterface>]> {
+  let summary = "test call op with segmented args";
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<AnyType>:$prefix, // non-arg segment (e.g., 'in')
+    Variadic<AnyType>:$args,   // <-- the call *arguments* segment
+    Variadic<AnyType>:$suffix  // non-arg segment (e.g., 'out')
+  );
+  let results = (outs);
+  let assemblyFormat = [{
+    $callee `(` $prefix `:` type($prefix) `)`
+    `(` $args `:` type($args) `)`
+    `(` $suffix `:` type($suffix) `)` attr-dict
+  }];
+
+  // Stub ArgAndResultAttrsOpInterface accessors (required via
+  // CallOpInterface). This test op stores no per-argument/result attributes,
+  // so the getters return null and the setters are no-ops.
+  let extraClassDeclaration = [{
+    ::mlir::ArrayAttr getArgAttrsAttr() { return {}; }
+    ::mlir::ArrayAttr getResAttrsAttr() { return {}; }
+    void setArgAttrsAttr(::mlir::ArrayAttr) {}
+    void setResAttrsAttr(::mlir::ArrayAttr) {}
+    ::mlir::Attribute removeArgAttrsAttr() { return {}; }
+    ::mlir::Attribute removeResAttrsAttr() { return {}; }
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
+      if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"))
+        return ::mlir::CallInterfaceCallable(sym);
+      return ::mlir::CallInterfaceCallable();
+    }
+    void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      // Only direct (symbol) calls are modeled. `callee` is a required
+      // attribute, so replace it rather than removing it; removal would
+      // leave an op that fails verification. The cast asserts on a Value.
+      (*this)->setAttr("callee", ::llvm::cast<::mlir::SymbolRefAttr>(callee));
+    }
+  }];
+}
+
+
#endif // TEST_OPS
>From 78ae4f1905e567dd66aebb488864e29194e69e84 Mon Sep 17 00:00:00 2001
From: Francisco Geiman Thiesen <franciscogthiesen at gmail.com>
Date: Wed, 24 Sep 2025 14:56:42 -0600
Subject: [PATCH 2/2] Update mlir/lib/Transforms/RemoveDeadValues.cpp with
joker-eph suggestion
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0655adaad5f5f..03c02859366b2 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -785,14 +785,11 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
// AttrSizedOperandSegments).
if (o.nonLive.any()) {
// Map the argument logical index to the operand number(s) recorded.
- SmallVector<OpOperand *> callOperands =
- operandsToOpOperands(call.getArgOperands());
+ int operandOffset = call.getArgOperands().getBeginOperandIndex();
for (int argIdx : deadArgIdxs.set_bits()) {
- if (argIdx < static_cast<int>(callOperands.size())) {
- unsigned operandNumber = callOperands[argIdx]->getOperandNumber();
- if (operandNumber < o.nonLive.size())
- o.nonLive.reset(operandNumber);
- }
+ int operandNumber = operandOffset + argIdx;
+ if (operandNumber < o.nonLive.size())
+ o.nonLive.reset(operandNumber);
}
}
}
More information about the Mlir-commits
mailing list