[Mlir-commits] [llvm] [mlir] Allow RDV to call `getArgOperandsMutable()` (PR #160415)

Francisco Geiman Thiesen llvmlistbot at llvm.org
Thu Sep 25 15:46:11 PDT 2025


https://github.com/FranciscoThiesen updated https://github.com/llvm/llvm-project/pull/160415

>From 19a9c64d49906845faed6e81effed356a0020ec1 Mon Sep 17 00:00:00 2001
From: Francisco Geiman Thiesen <franciscoge at microsoft.com>
Date: Fri, 12 Sep 2025 22:19:46 +0000
Subject: [PATCH 1/6] Add changes to RDV plus a small repro case for a dialect
 with a call op and the AttrSizedOperandSegments trait

---
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 64 +++++++++++++++----
 .../remove-dead-values-call-segments.mlir     | 23 +++++++
 mlir/test/lib/Dialect/Test/TestDialect.cpp    | 44 +++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 43 +++++++++++++
 4 files changed, 160 insertions(+), 14 deletions(-)
 create mode 100644 mlir/test/Transforms/remove-dead-values-call-segments.mlir

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0e84b6dd17f29..0655adaad5f5f 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -306,19 +306,17 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
       nonLiveSet.insert(arg);
     }
 
-  // Do (2).
+  // Do (2). (Skip creating generic operand cleanup entries for call ops.
+  // Call arguments will be removed in the call-site specific segment-aware
+  // cleanup, avoiding generic eraseOperands bitvector mechanics.)
   SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    // The number of operands in the call op may not match the number of
-    // arguments in the func op.
-    BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
-    SmallVector<OpOperand *> callOpOperands =
-        operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
-    for (int index : nonLiveArgs.set_bits())
-      nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
-    cl.operands.push_back({callOp, nonLiveCallOperands});
+    // Push an empty operand cleanup entry so that call-site specific logic in
+    // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
+    // intentionally all false to avoid generic erasure.
+    cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)});
   }
 
   // Do (3).
@@ -746,6 +744,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
 
   // 3. Functions
   LDBG() << "Cleaning up " << list.functions.size() << " functions";
+  // Record which function arguments were erased so we can shrink call-site
+  // argument segments for CallOpInterface operations (e.g. ops using
+  // AttrSizedOperandSegments) in the next phase.
+  DenseMap<Operation *, BitVector> erasedFuncArgs;
   for (auto &f : list.functions) {
     LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
     LDBG() << "  Erasing " << f.nonLiveArgs.count() << " non-live arguments";
@@ -754,17 +756,51 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     // Some functions may not allow erasing arguments or results. These calls
     // return failure in such cases without modifying the function, so it's okay
     // to proceed.
-    (void)f.funcOp.eraseArguments(f.nonLiveArgs);
+    if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
+      // Record only if we actually erased something.
+      if (f.nonLiveArgs.any())
+        erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
+    }
     (void)f.funcOp.eraseResults(f.nonLiveRets);
   }
 
   // 4. Operands
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperationToCleanup &o : list.operands) {
-    if (o.op->getNumOperands() > 0) {
-      LDBG() << "Erasing " << o.nonLive.count()
-             << " non-live operands from operation: "
-             << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
+    if (auto call = dyn_cast<CallOpInterface>(o.op)) {
+      if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
+        Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym);
+        auto it = erasedFuncArgs.find(callee);
+        if (it != erasedFuncArgs.end()) {
+          const BitVector &deadArgIdxs = it->second;
+          MutableOperandRange args = call.getArgOperandsMutable();
+          // First, erase the call arguments corresponding to erased callee args.
+          for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
+            if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
+              args.erase(i);
+          }
+          // If this operand cleanup entry also has a generic nonLive bitvector,
+          // clear bits for call arguments we already erased above to avoid
+          // double-erasing (which could impact other segments of ops with
+          // AttrSizedOperandSegments).
+          if (o.nonLive.any()) {
+            // Map the argument logical index to the operand number(s) recorded.
+            SmallVector<OpOperand *> callOperands =
+                operandsToOpOperands(call.getArgOperands());
+            for (int argIdx : deadArgIdxs.set_bits()) {
+              if (argIdx < static_cast<int>(callOperands.size())) {
+                unsigned operandNumber = callOperands[argIdx]->getOperandNumber();
+                if (operandNumber < o.nonLive.size())
+                  o.nonLive.reset(operandNumber);
+              }
+            }
+          }
+        }
+      }
+    }
+    // Only perform generic operand erasure for non-call ops; for call ops we
+    // already handled argument removals via the segment-aware path above.
+    if (!isa<CallOpInterface>(o.op) && o.nonLive.any()) {
       o.op->eraseOperands(o.nonLive);
     }
   }
diff --git a/mlir/test/Transforms/remove-dead-values-call-segments.mlir b/mlir/test/Transforms/remove-dead-values-call-segments.mlir
new file mode 100644
index 0000000000000..fed9cabbd2ee8
--- /dev/null
+++ b/mlir/test/Transforms/remove-dead-values-call-segments.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN
+
+// -----
+// Private callee: both args become dead after internal DCE; RDV drops callee
+// args and shrinks the *args* segment on the call-site to zero; sizes kept in
+// sync.
+
+module {
+  func.func private @callee(%x: i32, %y: i32) {
+    %u = arith.addi %x, %x : i32   // %y is dead
+    return
+  }
+
+  func.func @caller(%a: i32, %b: i32) {
+    // args segment initially has 2 operands.
+    "test.call_with_segments"(%a, %b) { callee = @callee,
+      operandSegmentSizes = array<i32: 0, 2, 0> } : (i32, i32) -> ()
+    return
+  }
+}
+
+// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> ()
+//       ^ args shrank from 2 -> 0
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 987e8f3654ce8..5016ab6b94cdb 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add(&dialectCanonicalizationPattern);
 }
+
+//===----------------------------------------------------------------------===//
+// TestCallWithSegmentsOp
+//===----------------------------------------------------------------------===//
+// The op `test.call_with_segments` models a call-like operation whose operands
+// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
+// Only the middle segment represents the actual call arguments. The op uses
+// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
+// the generated `operandSegmentSizes` attribute. We provide custom helpers to
+// expose the logical call arguments as both a read-only range and a mutable
+// range bound to the proper segment so that insertion/erasure updates the
+// attribute automatically.
+
+// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix].
+static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;
+
+Operation::operand_range CallWithSegmentsOp::getArgOperands() {
+  // Leverage generated getters for segment sizes: slice between prefix and
+  // suffix using current operand list.
+  return getOperation()->getOperands().slice(getPrefix().size(),
+                                             getArgs().size());
+}
+
+MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
+  Operation *op = getOperation();
+
+  // Obtain the canonical segment size attribute name for this op.
+  auto segName =
+      CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName());
+  auto sizesAttr = op->getAttrOfType<DenseI32ArrayAttr>(segName);
+  assert(sizesAttr && "missing operandSegmentSizes attribute on op");
+
+  // Compute the start and length of the args segment from the prefix size and
+  // args size stored in the attribute.
+  auto sizes = sizesAttr.asArrayRef();
+  unsigned start = static_cast<unsigned>(sizes[0]); // prefix size
+  unsigned len = static_cast<unsigned>(sizes[1]);    // args size
+
+  NamedAttribute segNamed(segName, sizesAttr);
+  MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex,
+                                              segNamed};
+
+  return MutableOperandRange(op, start, len, {binding});
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5564264ed8b0b..a459385129909 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3745,4 +3745,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
   }];
 }
 
+def CallWithSegmentsOp : TEST_Op<"call_with_segments",
+    [AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<CallOpInterface>]> {
+  let summary = "test call op with segmented args";
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<AnyType>:$prefix,   // non-arg segment (e.g., 'in')
+    Variadic<AnyType>:$args,     // <-- the call *arguments* segment
+    Variadic<AnyType>:$suffix    // non-arg segment (e.g., 'out')
+  );
+  let results = (outs);
+  let assemblyFormat = [{
+    $callee `(` $prefix `:` type($prefix) `)`
+            `(` $args `:` type($args) `)`
+            `(` $suffix `:` type($suffix) `)` attr-dict
+  }];
+
+  // Provide stub implementations for the ArgAndResultAttrsOpInterface.
+  let extraClassDeclaration = [{
+    ::mlir::ArrayAttr getArgAttrsAttr() { return {}; }
+    ::mlir::ArrayAttr getResAttrsAttr() { return {}; }
+    void setArgAttrsAttr(::mlir::ArrayAttr) {}
+    void setResAttrsAttr(::mlir::ArrayAttr) {}
+    ::mlir::Attribute removeArgAttrsAttr() { return {}; }
+    ::mlir::Attribute removeResAttrsAttr() { return {}; }
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
+      if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"))
+        return ::mlir::CallInterfaceCallable(sym);
+      return ::mlir::CallInterfaceCallable();
+    }
+    void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      if (auto sym = callee.dyn_cast<::mlir::SymbolRefAttr>())
+        (*this)->setAttr("callee", sym);
+      else
+        (*this)->removeAttr("callee");
+    }
+  }];
+}
+
+
 #endif // TEST_OPS

>From 78ae4f1905e567dd66aebb488864e29194e69e84 Mon Sep 17 00:00:00 2001
From: Francisco Geiman Thiesen <franciscogthiesen at gmail.com>
Date: Wed, 24 Sep 2025 14:56:42 -0600
Subject: [PATCH 2/6] Update mlir/lib/Transforms/RemoveDeadValues.cpp with
 joker-eph's suggestion

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0655adaad5f5f..03c02859366b2 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -785,14 +785,11 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
           // AttrSizedOperandSegments).
           if (o.nonLive.any()) {
             // Map the argument logical index to the operand number(s) recorded.
-            SmallVector<OpOperand *> callOperands =
-                operandsToOpOperands(call.getArgOperands());
+            int operandOffset = call.getArgOperands().getBeginOperandIndex();
             for (int argIdx : deadArgIdxs.set_bits()) {
-              if (argIdx < static_cast<int>(callOperands.size())) {
-                unsigned operandNumber = callOperands[argIdx]->getOperandNumber();
-                if (operandNumber < o.nonLive.size())
-                  o.nonLive.reset(operandNumber);
-              }
+              int operandNumber = operandOffset + argIdx;
+              if (operandNumber < o.nonLive.size())
+                o.nonLive.reset(operandNumber);
             }
           }
         }

>From e3d5d543d9c0580acee8081897054018687ac5ae Mon Sep 17 00:00:00 2001
From: Francisco Geiman Thiesen <franciscogthiesen at gmail.com>
Date: Wed, 24 Sep 2025 15:49:27 -0600
Subject: [PATCH 3/6] Avoid the expensive symbol lookup

---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 03c02859366b2..18d75a93195a2 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -88,6 +88,7 @@ struct FunctionToCleanUp {
 struct OperationToCleanup {
   Operation *op;
   BitVector nonLive;
+  Operation *callee = nullptr; // Optional: For CallOpInterface ops, stores the callee function
 };
 
 struct BlockArgsToCleanup {
@@ -316,7 +317,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
     // Push an empty operand cleanup entry so that call-site specific logic in
     // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
     // intentionally all false to avoid generic erasure.
-    cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)});
+    // Store the funcOp as the callee to avoid expensive symbol lookup later.
+    cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false), funcOp.getOperation()});
   }
 
   // Do (3).
@@ -768,9 +770,9 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperationToCleanup &o : list.operands) {
     if (auto call = dyn_cast<CallOpInterface>(o.op)) {
-      if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
-        Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym);
-        auto it = erasedFuncArgs.find(callee);
+      // Use the stored callee reference if available, avoiding expensive symbol lookup
+      if (o.callee) {
+        auto it = erasedFuncArgs.find(o.callee);
         if (it != erasedFuncArgs.end()) {
           const BitVector &deadArgIdxs = it->second;
           MutableOperandRange args = call.getArgOperandsMutable();
@@ -788,7 +790,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
             int operandOffset = call.getArgOperands().getBeginOperandIndex();
             for (int argIdx : deadArgIdxs.set_bits()) {
               int operandNumber = operandOffset + argIdx;
-              if (operandNumber < o.nonLive.size())
+              if (operandNumber < static_cast<int>(o.nonLive.size()))
                 o.nonLive.reset(operandNumber);
             }
           }

>From 6ea2caca51d16d1e296de53751b39870fa30c611 Mon Sep 17 00:00:00 2001
From: Francisco Geiman Thiesen <franciscogthiesen at gmail.com>
Date: Wed, 24 Sep 2025 16:26:13 -0600
Subject: [PATCH 4/6] Apply clang-format

---
 mlir/lib/Transforms/RemoveDeadValues.cpp   | 12 ++++++++----
 mlir/test/lib/Dialect/Test/TestDialect.cpp |  2 +-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 18d75a93195a2..01b5522572769 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -88,7 +88,8 @@ struct FunctionToCleanUp {
 struct OperationToCleanup {
   Operation *op;
   BitVector nonLive;
-  Operation *callee = nullptr; // Optional: For CallOpInterface ops, stores the callee function
+  Operation *callee =
+      nullptr; // Optional: For CallOpInterface ops, stores the callee function
 };
 
 struct BlockArgsToCleanup {
@@ -318,7 +319,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
     // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
     // intentionally all false to avoid generic erasure.
     // Store the funcOp as the callee to avoid expensive symbol lookup later.
-    cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false), funcOp.getOperation()});
+    cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false),
+                           funcOp.getOperation()});
   }
 
   // Do (3).
@@ -770,13 +772,15 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperationToCleanup &o : list.operands) {
     if (auto call = dyn_cast<CallOpInterface>(o.op)) {
-      // Use the stored callee reference if available, avoiding expensive symbol lookup
+      // Use the stored callee reference if available, avoiding expensive symbol
+      // lookup
       if (o.callee) {
         auto it = erasedFuncArgs.find(o.callee);
         if (it != erasedFuncArgs.end()) {
           const BitVector &deadArgIdxs = it->second;
           MutableOperandRange args = call.getArgOperandsMutable();
-          // First, erase the call arguments corresponding to erased callee args.
+          // First, erase the call arguments corresponding to erased callee
+          // args.
           for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
             if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
               args.erase(i);
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 5016ab6b94cdb..21d75f58b0a3a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -467,7 +467,7 @@ MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
   // args size stored in the attribute.
   auto sizes = sizesAttr.asArrayRef();
   unsigned start = static_cast<unsigned>(sizes[0]); // prefix size
-  unsigned len = static_cast<unsigned>(sizes[1]);    // args size
+  unsigned len = static_cast<unsigned>(sizes[1]);   // args size
 
   NamedAttribute segNamed(segName, sizesAttr);
   MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex,

>From a71642f30851f62c71d71ffd4ce777ea0fd3b08c Mon Sep 17 00:00:00 2001
From: Francisco Geiman Thiesen <franciscogthiesen at gmail.com>
Date: Wed, 24 Sep 2025 19:55:50 -0700
Subject: [PATCH 5/6] Make the assumption explicit

---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 61 +++++++++++++-----------
 1 file changed, 32 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 01b5522572769..3f4cb7e22fa6e 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -771,39 +771,42 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
   // 4. Operands
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperationToCleanup &o : list.operands) {
-    if (auto call = dyn_cast<CallOpInterface>(o.op)) {
-      // Use the stored callee reference if available, avoiding expensive symbol
-      // lookup
-      if (o.callee) {
-        auto it = erasedFuncArgs.find(o.callee);
-        if (it != erasedFuncArgs.end()) {
-          const BitVector &deadArgIdxs = it->second;
-          MutableOperandRange args = call.getArgOperandsMutable();
-          // First, erase the call arguments corresponding to erased callee
-          // args.
-          for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
-            if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
-              args.erase(i);
-          }
-          // If this operand cleanup entry also has a generic nonLive bitvector,
-          // clear bits for call arguments we already erased above to avoid
-          // double-erasing (which could impact other segments of ops with
-          // AttrSizedOperandSegments).
-          if (o.nonLive.any()) {
-            // Map the argument logical index to the operand number(s) recorded.
-            int operandOffset = call.getArgOperands().getBeginOperandIndex();
-            for (int argIdx : deadArgIdxs.set_bits()) {
-              int operandNumber = operandOffset + argIdx;
-              if (operandNumber < static_cast<int>(o.nonLive.size()))
-                o.nonLive.reset(operandNumber);
-            }
+    // Handle call-specific cleanup only when we have a cached callee reference.
+    // This avoids expensive symbol lookup and is defensive against future changes.
+    bool handledAsCall = false;
+    if (o.callee && isa<CallOpInterface>(o.op)) {
+      auto call = cast<CallOpInterface>(o.op);
+      auto it = erasedFuncArgs.find(o.callee);
+      if (it != erasedFuncArgs.end()) {
+        const BitVector &deadArgIdxs = it->second;
+        MutableOperandRange args = call.getArgOperandsMutable();
+        // First, erase the call arguments corresponding to erased callee
+        // args.
+        for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
+          if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
+            args.erase(i);
+        }
+        // If this operand cleanup entry also has a generic nonLive bitvector,
+        // clear bits for call arguments we already erased above to avoid
+        // double-erasing (which could impact other segments of ops with
+        // AttrSizedOperandSegments).
+        if (o.nonLive.any()) {
+          // Map the argument logical index to the operand number(s) recorded.
+          int operandOffset = call.getArgOperands().getBeginOperandIndex();
+          for (int argIdx : deadArgIdxs.set_bits()) {
+            int operandNumber = operandOffset + argIdx;
+            if (operandNumber < static_cast<int>(o.nonLive.size()))
+              o.nonLive.reset(operandNumber);
           }
         }
+        handledAsCall = true;
       }
     }
-    // Only perform generic operand erasure for non-call ops; for call ops we
-    // already handled argument removals via the segment-aware path above.
-    if (!isa<CallOpInterface>(o.op) && o.nonLive.any()) {
+    // Perform generic operand erasure for:
+    // - Non-call operations
+    // - Call operations without cached callee (where handledAsCall is false)
+    // But skip call operations that were already handled via segment-aware path
+    if (!handledAsCall && o.nonLive.any()) {
       o.op->eraseOperands(o.nonLive);
     }
   }

>From d2a3dbede2e18e58183621df5e39584dc1bf7e2c Mon Sep 17 00:00:00 2001
From: Francisco Geiman Thiesen <franciscogthiesen at gmail.com>
Date: Wed, 24 Sep 2025 21:26:41 -0700
Subject: [PATCH 6/6] Add bidirectional iterator support to
 const_set_bits_iterator_impl

---
 llvm/include/llvm/ADT/BitVector.h        | 32 +++++++++++++++++++++---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 11 ++++----
 2 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/ADT/BitVector.h b/llvm/include/llvm/ADT/BitVector.h
index 72da2343fae13..a6e2a397c661e 100644
--- a/llvm/include/llvm/ADT/BitVector.h
+++ b/llvm/include/llvm/ADT/BitVector.h
@@ -40,12 +40,25 @@ template <typename BitVectorT> class const_set_bits_iterator_impl {
     Current = Parent.find_next(Current);
   }
 
+  void retreat() {
+    // For bidirectional iteration to work with reverse_iterator,
+    // we need to handle the case where Current might be at end (-1)
+    // or at a position where we need to find the previous set bit.
+    if (Current == -1) {
+      // If we're at the end, go to the last set bit
+      Current = Parent.find_last();
+    } else {
+      // Otherwise find the previous set bit before Current
+      Current = Parent.find_prev(Current);
+    }
+  }
+
 public:
-  using iterator_category = std::forward_iterator_tag;
+  using iterator_category = std::bidirectional_iterator_tag;
   using difference_type   = std::ptrdiff_t;
-  using value_type        = int;
-  using pointer           = value_type*;
-  using reference         = value_type&;
+  using value_type        = unsigned;
+  using pointer           = const value_type*;
+  using reference         = value_type;
 
   const_set_bits_iterator_impl(const BitVectorT &Parent, int Current)
       : Parent(Parent), Current(Current) {}
@@ -64,6 +77,17 @@ template <typename BitVectorT> class const_set_bits_iterator_impl {
     return *this;
   }
 
+  const_set_bits_iterator_impl operator--(int) {
+    auto Prev = *this;
+    retreat();
+    return Prev;
+  }
+
+  const_set_bits_iterator_impl &operator--() {
+    retreat();
+    return *this;
+  }
+
   unsigned operator*() const { return Current; }
 
   bool operator==(const const_set_bits_iterator_impl &Other) const {
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 3f4cb7e22fa6e..fe74f3d4632a6 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -772,7 +772,8 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperationToCleanup &o : list.operands) {
     // Handle call-specific cleanup only when we have a cached callee reference.
-    // This avoids expensive symbol lookup and is defensive against future changes.
+    // This avoids expensive symbol lookup and is defensive against future
+    // changes.
     bool handledAsCall = false;
     if (o.callee && isa<CallOpInterface>(o.op)) {
       auto call = cast<CallOpInterface>(o.op);
@@ -781,11 +782,9 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
         const BitVector &deadArgIdxs = it->second;
         MutableOperandRange args = call.getArgOperandsMutable();
         // First, erase the call arguments corresponding to erased callee
-        // args.
-        for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
-          if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
-            args.erase(i);
-        }
+        // args. We iterate backwards to preserve indices.
+        for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
+          args.erase(argIdx);
         // If this operand cleanup entry also has a generic nonLive bitvector,
         // clear bits for call arguments we already erased above to avoid
         // double-erasing (which could impact other segments of ops with



More information about the Mlir-commits mailing list