[Mlir-commits] [mlir] 0e37ef0 - [mlir][transform] Use TrackingListener-aware iterator for getPayloadOps

Matthias Springer llvmlistbot at llvm.org
Mon May 15 01:38:01 PDT 2023


Author: Matthias Springer
Date: 2023-05-15T10:31:24+02:00
New Revision: 0e37ef08d4d8cf98941937874702825578fcb9c2

URL: https://github.com/llvm/llvm-project/commit/0e37ef08d4d8cf98941937874702825578fcb9c2
DIFF: https://github.com/llvm/llvm-project/commit/0e37ef08d4d8cf98941937874702825578fcb9c2.diff

LOG: [mlir][transform] Use TrackingListener-aware iterator for getPayloadOps

Instead of returning an `ArrayRef<Operation *>`, return an iterator that skips ops that were erased/replaced while iterating over the payload ops.

This fixes an issue in conjunction with TrackingListener, where a tracked op was erased during the iteration. Elements must not be removed from an array while iterating over it, because doing so invalidates the iterator.

When ops are erased/removed via `replacePayloadOp`, they are not immediately removed from the mappings data structure. Instead, their entries are set to `nullptr`, and these `nullptr` entries are not enumerated by `getPayloadOps`. At the end of each transformation, the `nullptr` entries are removed from the mappings data structure.

Differential Revision: https://reviews.llvm.org/D149847

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
    mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index 8f8adffae50e3..c0d7545622e69 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -53,15 +53,15 @@ class SingleOpMatcherOpTrait
   DiagnosedSilenceableFailure apply(TransformResults &results,
                                     TransformState &state) {
     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
-    ArrayRef<Operation *> payload = state.getPayloadOps(operandHandle);
-    if (payload.size() != 1) {
+    auto payload = state.getPayloadOps(operandHandle);
+    if (!llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
              << "SingleOpMatchOpTrait requires the operand handle to point to "
                 "a single payload op";
     }
 
     return cast<OpTy>(this->getOperation())
-        .matchOperation(payload[0], results, state);
+        .matchOperation(*payload.begin(), results, state);
   }
 
   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 39a86d3828786..6730552c9c53a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -22,6 +22,36 @@ namespace transform {
 
 class TransformOpInterface;
 class TransformResults;
+class TransformState;
+
+using Param = Attribute;
+using MappedValue = llvm::PointerUnion<Operation *, Param, Value>;
+
+namespace detail {
+/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
+/// to either the list of operations associated with its operand or the root of
+/// the payload IR, depending on what is available in the context.
+LogicalResult
+mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
+                                             Operation *op, Region &region);
+
+/// Verification hook for PossibleTopLevelTransformOpTrait.
+LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
+
+/// Verification hook for TransformOpInterface.
+LogicalResult verifyTransformOpInterface(Operation *op);
+
+/// Populates `mappings` with mapped values associated with the given transform
+/// IR values in the given `state`.
+void prepareValueMappings(
+    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
+    ValueRange values, const transform::TransformState &state);
+
+/// Populates `results` with payload associations that match exactly those of
+/// the operands to `block`'s terminator.
+void forwardTerminatorOperands(Block *block, transform::TransformState &state,
+                               transform::TransformResults &results);
+} // namespace detail
 
 /// Options controlling the application of transform operations by the
 /// TransformState.
@@ -46,9 +76,6 @@ class TransformOptions {
   bool expensiveChecksEnabled = true;
 };
 
-using Param = Attribute;
-using MappedValue = llvm::PointerUnion<Operation *, Param, Value>;
-
 /// Entry point to the Transform dialect infrastructure. Applies the
 /// transformation specified by `transform` to payload IR contained in
 /// `payloadRoot`. The `transform` operation may contain other operations that
@@ -140,9 +167,16 @@ class TransformState {
     return topLevelMappedValues[position];
   }
 
-  /// Returns the list of ops that the given transform IR value corresponds to.
-  /// This is helpful for transformations that apply to a particular handle.
-  ArrayRef<Operation *> getPayloadOps(Value value) const;
+  /// Returns an iterator that enumerates all ops that the given transform IR
+  /// value corresponds to. Ops may be erased while iterating; erased ops are
+  /// not enumerated. This function is helpful for transformations that apply to
+  /// a particular handle.
+  auto getPayloadOps(Value value) const {
+    // When ops are replaced/erased, they are replaced with nullptr (until
+    // the data structure is compacted). Do not enumerate these ops.
+    return llvm::make_filter_range(getPayloadOpsView(value),
+                                   [](Operation *op) { return op != nullptr; });
+  }
 
   /// Returns the list of parameters that the given transform IR value
   /// corresponds to.
@@ -407,6 +441,12 @@ class TransformState {
   LogicalResult updateStateFromResults(const TransformResults &results,
                                        ResultRange opResults);
 
+  /// Returns a list of all ops that the given transform IR value corresponds to
+  /// at the time when this function is called. In case an op was erased, the
+  /// returned list contains nullptr. This function is helpful for
+  /// transformations that apply to a particular handle.
+  ArrayRef<Operation *> getPayloadOpsView(Value value) const;
+
   /// Sets the payload IR ops associated with the given transform IR value
   /// (handle). A payload op may be associated multiple handles as long as
   /// at most one of them gets consumed by further transformations.
@@ -540,10 +580,19 @@ class TransformState {
   LogicalResult
   checkAndRecordHandleInvalidation(TransformOpInterface transform);
 
+  /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
+  void compactOpHandles();
+
   /// The mappings between transform IR values and payload IR ops, aggregated by
   /// the region in which the transform IR values are defined.
   llvm::SmallDenseMap<Region *, Mappings> mappings;
 
+  /// Op handles may be temporarily mapped to nullptr to avoid invalidating
+  /// payload op iterators. This set contains all op handles with nullptrs.
+  /// These handles are "compacted" (i.e., nullptrs removed) at the end of each
+  /// transform.
+  DenseSet<Value> opHandlesToCompact;
+
   /// Extensions attached to the TransformState, identified by the TypeID of
   /// their type. Only one extension of any given type is allowed.
   DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
@@ -595,7 +644,25 @@ class TransformResults {
   /// corresponds to the given list of payload IR ops. Each result must be set
   /// by the transformation exactly once in case of transformation succeeding.
   /// The value must have a type implementing TransformHandleTypeInterface.
-  void set(OpResult value, ArrayRef<Operation *> ops);
+  template <typename Range> void set(OpResult value, Range &&ops) {
+    int64_t position = value.getResultNumber();
+    assert(position < static_cast<int64_t>(operations.size()) &&
+           "setting results for a non-existent handle");
+    assert(operations[position].data() == nullptr && "results already set");
+    assert(params[position].data() == nullptr &&
+           "another kind of results already set");
+    assert(values[position].data() == nullptr &&
+           "another kind of results already set");
+    operations.replace(position, std::forward<Range>(ops));
+  }
+
+  /// Indicates that the result of the transform IR op at the given position
+  /// corresponds to the given list of payload IR ops. Each result must be set
+  /// by the transformation exactly once in case of transformation succeeding.
+  /// The value must have a type implementing TransformHandleTypeInterface.
+  void set(OpResult value, std::initializer_list<Operation *> ops) {
+    set(value, ArrayRef<Operation *>(ops));
+  }
 
   /// Indicates that the result of the transform IR op at the given position
   /// corresponds to the given list of parameters. Each result must be set by
@@ -682,32 +749,6 @@ TransformState::make_isolated_region_scope(Region &region) {
   return RegionScope(*this, region, RegionScope::Isolated());
 }
 
-namespace detail {
-/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
-/// to either the list of operations associated with its operand or the root of
-/// the payload IR, depending on what is available in the context.
-LogicalResult
-mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
-                                             Operation *op, Region &region);
-
-/// Verification hook for PossibleTopLevelTransformOpTrait.
-LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
-
-/// Verification hook for TransformOpInterface.
-LogicalResult verifyTransformOpInterface(Operation *op);
-
-/// Populates `mappings` with mapped values associated with the given transform
-/// IR values in the given `state`.
-void prepareValueMappings(
-    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
-    ValueRange values, const transform::TransformState &state);
-
-/// Populates `results` with payload associations that match exactly those of
-/// the operands to `block`'s terminator.
-void forwardTerminatorOperands(Block *block, transform::TransformState &state,
-                               transform::TransformResults &results);
-} // namespace detail
-
 /// This trait is supposed to be attached to Transform dialect operations that
 /// can be standalone top-level transforms. Such operations typically contain
 /// other Transform dialect operations that can be executed following some
@@ -1069,9 +1110,9 @@ void setApplyToOneResults(Operation *transformOp,
 ///   - a concrete Op class, in which case a check is performed whether
 ///   `targets` contains operations of the same class and a silenceable failure
 ///   is reported if it does not.
-template <typename TransformOpTy>
+template <typename TransformOpTy, typename Range>
 DiagnosedSilenceableFailure
-applyTransformToEach(TransformOpTy transformOp, ArrayRef<Operation *> targets,
+applyTransformToEach(TransformOpTy transformOp, Range &&targets,
                      SmallVectorImpl<ApplyToEachResultList> &results,
                      TransformState &state) {
   using OpTy = typename llvm::function_traits<
@@ -1133,14 +1174,13 @@ template <typename OpTy>
 mlir::DiagnosedSilenceableFailure
 mlir::transform::TransformEachOpTrait<OpTy>::apply(
     TransformResults &transformResults, TransformState &state) {
-  ArrayRef<Operation *> targets =
-      state.getPayloadOps(this->getOperation()->getOperand(0));
+  auto targets = state.getPayloadOps(this->getOperation()->getOperand(0));
 
   // Step 1. Handle the corner case where no target is specified.
   // This is typically the case when the matcher fails to apply and we need to
   // propagate gracefully.
   // In this case, we fill all results with an empty vector.
-  if (targets.empty()) {
+  if (std::empty(targets)) {
     SmallVector<Operation *> emptyPayload;
     SmallVector<Attribute> emptyParams;
     for (OpResult r : this->getOperation()->getResults()) {
@@ -1157,7 +1197,6 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
   // Step 2. Call applyToOne on each target and record newly produced ops in its
   // corresponding results entry.
   SmallVector<ApplyToEachResultList, 1> results;
-  results.reserve(targets.size());
   DiagnosedSilenceableFailure result = detail::applyTransformToEach(
       cast<OpTy>(this->getOperation()), targets, results, state);
 

diff  --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 728da70b85731..9a952e6c9dc58 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -75,8 +75,7 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
   for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
                                         getUpperBounds())) {
     Value handle = std::get<0>(it);
-    ArrayRef<Operation *> boundedValueOps = state.getPayloadOps(handle);
-    for (Operation *op : boundedValueOps) {
+    for (Operation *op : state.getPayloadOps(handle)) {
       if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
         auto diag =
             emitDefiniteFailure()
@@ -104,8 +103,8 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
   }
 
   // Transform all targets.
-  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
-  for (Operation *target : targets) {
+  SmallVector<Operation *> targets;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
     if (!isa<AffineMinOp, AffineMaxOp>(target)) {
       auto diag = emitDefiniteFailure()
                   << "target must be affine.min or affine.max";
@@ -118,6 +117,7 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
       diag.attachNote(target->getLoc()) << "target/constrained op";
       return diag;
     }
+    targets.push_back(target);
   }
   SmallVector<Operation *> transformed;
   RewritePatternSet patterns(getContext());

diff  --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 5e36b55cff840..6dd32b815afce 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -41,7 +41,7 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
     options.setFunctionBoundaryTypeConversion(
         *getFunctionBoundaryTypeConversion());
 
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  auto payloadOps = state.getPayloadOps(getTarget());
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
@@ -80,8 +80,7 @@ transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults,
   OneShotBufferizationOptions options;
   options.allowReturnAllocs = true;
 
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
-  for (Operation *target : payloadOps) {
+  for (Operation *target : state.getPayloadOps(getTarget())) {
     OneShotAnalysisState state(target, options);
     if (failed(analyzeOp(target, state)))
       return mlir::emitSilenceableFailure(target->getLoc())

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 02909bb69977f..0d9d533da3664 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -618,7 +618,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
   }
   Operation *firstUser = *result.getUsers().begin();
   if (getAny()) {
-    results.set(cast<OpResult>(getResult()), firstUser);
+    results.set(cast<OpResult>(getResult()), {firstUser});
     return DiagnosedSilenceableFailure::success();
   }
   if (getSingle()) {
@@ -626,7 +626,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
       return emitSilenceableError()
              << "more than one result user with single user requested";
     }
-    results.set(cast<OpResult>(getResult()), firstUser);
+    results.set(cast<OpResult>(getResult()), {firstUser});
     return DiagnosedSilenceableFailure::success();
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ea8d285cf52b7..afef59990afc1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -92,17 +92,17 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
       result.push_back(ofr);
       continue;
     }
-    ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
-    if (payloadOps.size() != 1) {
+    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+    if (!llvm::hasSingleElement(payloadOps)) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "handle must be mapped to exactly one payload op";
       diag.attachNote(ofr.get<Value>().getLoc())
-          << "mapped to " << payloadOps.size() << " payload ops";
+          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
       return diag;
     }
 
-    Operation *op = payloadOps[0];
+    Operation *op = *payloadOps.begin();
     if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
@@ -125,8 +125,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
 static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, Value packedHandle) {
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
-  for (Operation *op : payloadOps) {
+  for (Operation *op : state.getPayloadOps(packedHandle)) {
     if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
@@ -208,16 +207,14 @@ transform::DecomposeOp::applyToOne(LinalgOp target,
 
 /// Apply a tiling transformation to all payload ops and store both the
 /// tiled operation as well as the created tile loops.
+template <typename Range>
 static LogicalResult applyTilingToAll(
-    RewriterBase &rewriter, Operation *transformOp,
-    ArrayRef<Operation *> payloadOps, unsigned numLoops,
-    transform::TransformResults &transformResults,
+    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
+    unsigned numLoops, transform::TransformResults &transformResults,
     function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
         applyFn) {
   SmallVector<Operation *> tiledLinalgOps;
   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
-  for (unsigned int i = 0; i < numLoops; ++i)
-    loopOps[i].reserve(payloadOps.size());
 
   for (Operation *target : payloadOps) {
     auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
@@ -578,19 +575,19 @@ DiagnosedSilenceableFailure
 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
                                        transform::TransformState &state) {
   SmallVector<Operation *> fusedOps;
-  ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
+  auto producerOps = state.getPayloadOps(getProducerOp());
   // If nothing to fuse, propagate success.
-  if (producerOps.empty()) {
+  if (std::empty(producerOps)) {
     results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
     return DiagnosedSilenceableFailure::success();
   }
-  ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
-  if (containingOps.size() != 1) {
+  auto containingOps = state.getPayloadOps(getContainingOp());
+  if (!llvm::hasSingleElement(containingOps)) {
     return emitDefiniteFailure()
            << "requires exactly one containing_op handle (got "
-           << containingOps.size() << ")";
+           << llvm::range_size(containingOps) << ")";
   }
-  Operation *containingOp = containingOps.front();
+  Operation *containingOp = *containingOps.begin();
 
   // Helper function to find the next producer that should be fused. Take any
   // producer that has a use inside the containing op.
@@ -810,8 +807,8 @@ transform::MatchOp::apply(transform::TransformResults &results,
     strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
                 getOps()->getAsValueRange<StringAttr>().end());
 
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
-  if (payloadOps.size() != 1) {
+  auto payloadOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payloadOps)) {
     return emitDefiniteFailure("requires exactly one target handle");
   }
 
@@ -857,7 +854,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
     return;
   };
 
-  payloadOps.front()->walk(matchFun);
+  (*payloadOps.begin())->walk(matchFun);
   results.set(cast<OpResult>(getResult()), res);
   return DiagnosedSilenceableFailure::success();
 }
@@ -996,18 +993,19 @@ SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
 DiagnosedSilenceableFailure
 transform::PackOp::apply(transform::TransformResults &transformResults,
                          transform::TransformState &state) {
-  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+  auto targetOps = state.getPayloadOps(getTarget());
   // If nothing to pack, propagate success.
-  if (targetOps.empty()) {
-    transformResults.set(cast<OpResult>(getPackedOp()), {});
+  if (std::empty(targetOps)) {
+    transformResults.set(cast<OpResult>(getPackedOp()),
+                         ArrayRef<Operation *>({}));
     return DiagnosedSilenceableFailure::success();
   }
   // Fail on multi-op handles.
-  auto linalgOp = dyn_cast<LinalgOp>(targetOps.front());
-  if (targetOps.size() != 1 || !linalgOp) {
+  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
+  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
     return emitSilenceableError()
            << "requires target to map to exactly 1 LinalgOp (got "
-           << targetOps.size() << ")";
+           << llvm::range_size(targetOps) << ")";
   }
   // Fail on mismatched number of pack sizes.
   if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
@@ -1030,7 +1028,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
     return emitDefiniteFailure("data tiling failed");
 
   transformResults.set(cast<OpResult>(getPackedOp()),
-                       maybeResult->packedLinalgOp.getOperation());
+                       {maybeResult->packedLinalgOp.getOperation()});
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -1204,16 +1202,10 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
 DiagnosedSilenceableFailure
 PackGreedilyOp::apply(transform::TransformResults &transformResults,
                       transform::TransformState &state) {
-  ArrayRef<Operation *> targetOpsView = state.getPayloadOps(getTarget());
-  // Store payload ops into a separate SmallVector because the TrackingListener
-  // removes erased ops from the transform state.
-  SmallVector<Operation *> targetOps(targetOpsView.begin(),
-                                     targetOpsView.end());
-
   SmallVector<Operation *> results;
   TrackingListener listener(state, *this);
   IRRewriter rewriter(getContext(), &listener);
-  for (Operation *op : targetOps) {
+  for (Operation *op : state.getPayloadOps(getTarget())) {
     auto linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       continue;
@@ -1310,11 +1302,10 @@ bool isValidPackingPermutation(
 DiagnosedSilenceableFailure
 transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
                                   transform::TransformState &state) {
-  ArrayRef<Operation *> packOrUnpackOps =
-      state.getPayloadOps(getTargetPackOrUnPackOp());
-  ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
+  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
+  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
   // Step 1. If nothing to pack, propagate success.
-  if (packOrUnpackOps.empty()) {
+  if (std::empty(packOrUnpackOps)) {
     transformResults.set(cast<OpResult>(getPackedOp()), {});
     transformResults.set(cast<OpResult>(getPackOp()), {});
     transformResults.set(cast<OpResult>(getUnPackOp()), {});
@@ -1323,21 +1314,23 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
 
   // Step 2. Bunch of runtime sanity check and error messages.
   // Step 2.1. Fail on multi-op handles.
-  if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
-    return emitSilenceableError() << "requires target to map to exactly 1 "
-                                     "packing op and 1 packed op ("
-                                  << "got " << packOrUnpackOps.size() << " and "
-                                  << linalgOps.size() << ")";
+  if (!llvm::hasSingleElement(packOrUnpackOps) ||
+      !llvm::hasSingleElement(linalgOps)) {
+    return emitSilenceableError()
+           << "requires target to map to exactly 1 "
+              "packing op and 1 packed op ("
+           << "got " << llvm::range_size(packOrUnpackOps) << " and "
+           << llvm::range_size(linalgOps) << ")";
   }
 
   // Step 2.2. Fail on wrong type.
-  auto packOp = dyn_cast<tensor::PackOp>(packOrUnpackOps.front());
-  auto unPackOp = dyn_cast<tensor::UnPackOp>(packOrUnpackOps.front());
+  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
+  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
   if ((!packOp && !unPackOp)) {
     return emitSilenceableError() << "requires target to map to a "
                                      "tensor.pack or tensor.unpack";
   }
-  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(linalgOps.front());
+  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
   if (!linalgOpTarget)
     return emitSilenceableError() << "requires a LinalgOp target";
 
@@ -1520,16 +1513,17 @@ LogicalResult transform::PadOp::verify() {
 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
     transform::TransformResults &transformResults,
     transform::TransformState &state) {
-  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
-  ArrayRef<Operation *> loopOps = state.getPayloadOps(getLoop());
-  if (targetOps.size() != 1 || loopOps.size() != 1) {
+  auto targetOps = state.getPayloadOps(getTarget());
+  auto loopOps = state.getPayloadOps(getLoop());
+  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
     return emitDefiniteFailure()
            << "requires exactly one target and one loop handle (got "
-           << targetOps.size() << " and " << loopOps.size() << ")";
+           << llvm::range_size(targetOps) << " and "
+           << llvm::range_size(loopOps) << ")";
   }
 
-  auto padOp = dyn_cast_or_null<tensor::PadOp>(targetOps.front());
-  auto loopOp = dyn_cast_or_null<scf::ForOp>(loopOps.front());
+  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
+  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
   if (!padOp || !loopOp)
     return emitDefiniteFailure() << "requires exactly 2 non-null handles";
 
@@ -1543,13 +1537,13 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
 
   if (result->clonedLoopIvs.empty()) {
     transformResults.set(cast<OpResult>(getPackingLoop()),
-                         result->hoistedPadOp.getOperation());
+                         {result->hoistedPadOp.getOperation()});
     return DiagnosedSilenceableFailure::success();
   }
   auto outerPackedLoop =
       scf::getForInductionVarOwner(result->clonedLoopIvs.front());
   transformResults.set(cast<OpResult>(getPackingLoop()),
-                       outerPackedLoop.getOperation());
+                       {outerPackedLoop.getOperation()});
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -1679,7 +1673,7 @@ transform::PromoteOp::applyToOne(LinalgOp target,
 DiagnosedSilenceableFailure
 transform::ReplaceOp::apply(TransformResults &transformResults,
                             TransformState &state) {
-  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+  auto payload = state.getPayloadOps(getTarget());
 
   // Check for invalid targets.
   for (Operation *target : payload) {
@@ -1814,7 +1808,8 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
                                            TransformState &state) {
   // Collect the dynamic split points if provided.
-  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+  SmallVector<Operation *> payload =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
   TrackingListener listener(state, *this);
   IRRewriter rewriter(getContext(), &listener);
   SmallVector<OpFoldResult> splitPoints;
@@ -2199,8 +2194,9 @@ transform::TileOp::apply(TransformResults &transformResults,
                          TransformState &state) {
   ArrayRef<int64_t> tileSizes = getStaticSizes();
 
-  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
-  SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
   SmallVector<SmallVector<int64_t>> paramSizes;
   dynamicSizeProducers.reserve(getDynamicSizes().size());
   paramSizes.reserve(getDynamicSizes().size());
@@ -2226,7 +2222,8 @@ transform::TileOp::apply(TransformResults &transformResults,
       continue;
     }
     paramSizes.push_back({});
-    dynamicSizeProducers.push_back(state.getPayloadOps(transformValue));
+    dynamicSizeProducers.push_back(
+        llvm::to_vector(state.getPayloadOps(transformValue)));
 
     if (dynamicSizeProducers.back().size() != targets.size()) {
       DiagnosedSilenceableFailure diag =
@@ -2536,10 +2533,6 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults,
   TrackingListener listener(state, *this);
   IRRewriter rewriter(getContext(), &listener);
   auto transformOp = cast<TransformOpInterface>(getOperation());
-  ArrayRef<Operation *> targetsView = state.getPayloadOps(getTarget());
-  // Store payload ops into a separate SmallVector because the TrackingListener
-  // removes erased ops from the transform state.
-  SmallVector<Operation *> targets(targetsView.begin(), targetsView.end());
 
   // Result payload ops.
   SmallVector<Operation *> tileOps;
@@ -2564,7 +2557,7 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults,
   if (!status.succeeded())
     return status;
 
-  for (Operation *target : targets) {
+  for (Operation *target : state.getPayloadOps(getTarget())) {
     linalg::ForallTilingResult tilingResult;
     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
@@ -2652,12 +2645,13 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
                                  TransformState &state) {
   ArrayRef<int64_t> tileSizes = getStaticSizes();
 
-  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
-  SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
   dynamicSizeProducers.reserve(getDynamicSizes().size());
   for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
     dynamicSizeProducers.push_back(
-        state.getPayloadOps(dynamicSizeProducerHandle));
+        llvm::to_vector(state.getPayloadOps(dynamicSizeProducerHandle)));
 
     if (dynamicSizeProducers.back().size() != targets.size()) {
       DiagnosedSilenceableFailure diag =
@@ -2884,8 +2878,8 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
     mlir::transform::TransformState &state) {
   TrackingListener listener(state, *this);
   IRRewriter rewriter(getContext(), &listener);
-  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
-  if (targets.empty())
+  auto targets = state.getPayloadOps(getTarget());
+  if (std::empty(targets))
     return DiagnosedSilenceableFailure::success();
 
   SmallVector<int64_t> vectorSizes;
@@ -2896,16 +2890,16 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
       continue;
     }
 
-    ArrayRef<Operation *> szPayloads = state.getPayloadOps(sz.get<Value>());
-    if (szPayloads.size() != 1) {
+    auto szPayloads = state.getPayloadOps(sz.get<Value>());
+    if (!llvm::hasSingleElement(szPayloads)) {
       auto diag = this->emitOpError(
           "requires vector size handle that is mapped to 1 payload op");
       diag.attachNote(sz.get<Value>().getLoc())
-          << "mapped to " << szPayloads.size() << " payload ops";
+          << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
       return DiagnosedSilenceableFailure::definiteFailure();
     }
 
-    Operation *szPayloadOp = szPayloads[0];
+    Operation *szPayloadOp = *szPayloads.begin();
     if (szPayloadOp->getNumResults() != 1 ||
         !szPayloadOp->getResult(0).getType().isIndex()) {
       auto diag = this->emitOpError(

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index ae2472db4f862..c523af936e235 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -35,9 +35,8 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
     transform::TransformResults &transformResults,
     transform::TransformState &state) {
   SmallVector<Operation *> results;
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
   IRRewriter rewriter(getContext());
-  for (auto *op : payloadOps) {
+  for (Operation *op : state.getPayloadOps(getTarget())) {
     bool canApplyMultiBuffer = true;
     auto target = cast<memref::AllocOp>(op);
     LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4f3afd8e695d7..bad1d74fb473c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -53,7 +53,7 @@ transform::TransformState::TransformState(
 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
 
 ArrayRef<Operation *>
-transform::TransformState::getPayloadOps(Value value) const {
+transform::TransformState::getPayloadOpsView(Value value) const {
   const TransformOpMapping &operationMapping = getMapping(value).direct;
   auto iter = operationMapping.find(value);
   assert(
@@ -357,18 +357,8 @@ transform::TransformState::replacePayloadOp(Operation *op,
 
   // TODO: consider invalidating the handles to nested objects here.
 
-  // If replacing with null, that is erasing the mapping, drop the mapping
-  // between the handles and the IR objects and return.
-  if (!replacement) {
-    for (Value handle : opHandles) {
-      Mappings &mappings = getMapping(handle);
-      dropMappingEntry(mappings.direct, handle, op);
-    }
-    return success();
-  }
-
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  if (options.getExpensiveChecksEnabled()) {
+  if (replacement && options.getExpensiveChecksEnabled()) {
     auto insertion = cachedNames.insert({replacement, replacement->getName()});
     if (!insertion.second) {
       assert(insertion.first->second == replacement->getName() &&
@@ -377,8 +367,15 @@ transform::TransformState::replacePayloadOp(Operation *op,
   }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
-  // Otherwise, replace the pointed-to object of all handles while preserving
-  // their relative order. First, replace the mapped operation if present.
+  // Replace the pointed-to object of all handles with the replacement object.
+  // In case a payload op was erased (replacement object is nullptr), a nullptr
+  // is stored in the mapping. These nullptrs are removed after each transform.
+  // Furthermore, nullptrs are not enumerated by payload op iterators. The
+  // relative order of ops is preserved.
+  //
+  // Removing an op from the mapping would be problematic because removing an
+  // element from an array invalidates iterators; merely changing the value of
+  // elements does not.
   for (Value handle : opHandles) {
     Mappings &mappings = getMapping(handle);
     auto it = mappings.direct.find(handle);
@@ -391,7 +388,12 @@ transform::TransformState::replacePayloadOp(Operation *op,
       if (mapped == op)
         mapped = replacement;
     }
-    mappings.reverse[replacement].push_back(handle);
+
+    if (replacement) {
+      mappings.reverse[replacement].push_back(handle);
+    } else {
+      opHandlesToCompact.insert(handle);
+    }
   }
 
   return success();
@@ -645,7 +647,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
       FULL_LDBG("----found consume effect -> SKIP\n");
       if (llvm::isa<TransformHandleTypeInterface>(target.get().getType())) {
         FULL_LDBG("----recordOpHandleInvalidation\n");
-        ArrayRef<Operation *> payloadOps = getPayloadOps(target.get());
+        ArrayRef<Operation *> payloadOps = getPayloadOpsView(target.get());
         recordOpHandleInvalidation(target, payloadOps);
       } else if (llvm::isa<TransformValueHandleTypeInterface>(
                      target.get().getType())) {
@@ -686,6 +688,14 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
   return DiagnosedSilenceableFailure::success();
 }
 
+void transform::TransformState::compactOpHandles() {
+  for (Value handle : opHandlesToCompact) {
+    Mappings &mappings = getMapping(handle);
+    llvm::erase_value(mappings.direct[handle], nullptr);
+  }
+  opHandlesToCompact.clear();
+}
+
 DiagnosedSilenceableFailure
 transform::TransformState::applyTransform(TransformOpInterface transform) {
   LLVM_DEBUG({
@@ -721,7 +731,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
         FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
         DiagnosedSilenceableFailure check =
             checkRepeatedConsumptionInOperand<Operation *>(
-                getPayloadOps(operand.get()), transform,
+                getPayloadOpsView(operand.get()), transform,
                 operand.getOperandNumber());
         if (!check.succeeded()) {
           FULL_LDBG("----FAILED\n");
@@ -835,6 +845,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   // proceed on a best effort basis.
   transform::TransformResults results(transform->getNumResults());
   DiagnosedSilenceableFailure result(transform.apply(results, *this));
+  compactOpHandles();
   if (result.isDefiniteFailure())
     return result;
 
@@ -988,19 +999,6 @@ transform::TransformResults::TransformResults(unsigned numSegments) {
   values.appendEmptyRows(numSegments);
 }
 
-void transform::TransformResults::set(OpResult value,
-                                      ArrayRef<Operation *> ops) {
-  int64_t position = value.getResultNumber();
-  assert(position < static_cast<int64_t>(operations.size()) &&
-         "setting results for a non-existent handle");
-  assert(operations[position].data() == nullptr && "results already set");
-  assert(params[position].data() == nullptr &&
-         "another kind of results already set");
-  assert(values[position].data() == nullptr &&
-         "another kind of results already set");
-  operations.replace(position, ops);
-}
-
 void transform::TransformResults::setParams(
     OpResult value, ArrayRef<transform::TransformState::Param> params) {
   int64_t position = value.getResultNumber();

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index ecd5d2a915ab6..ad001707ddd64 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -767,10 +767,9 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
 DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformResults &results,
                             transform::TransformState &state) {
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
   SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
 
-  for (Operation *op : payloadOps) {
+  for (Operation *op : state.getPayloadOps(getTarget())) {
     auto scope = state.make_region_scope(getBody());
     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
       return DiagnosedSilenceableFailure::definiteFailure();
@@ -785,8 +784,7 @@ transform::ForeachOp::apply(transform::TransformResults &results,
 
     // Append yielded payload ops to result list (if any).
     for (unsigned i = 0; i < getNumResults(); ++i) {
-      ArrayRef<Operation *> yieldedOps =
-          state.getPayloadOps(getYieldOp().getOperand(i));
+      auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
       resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
     }
   }
@@ -882,16 +880,16 @@ DiagnosedSilenceableFailure
 transform::GetConsumersOfResult::apply(transform::TransformResults &results,
                                        transform::TransformState &state) {
   int64_t resultNumber = getResultNumber();
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
-  if (payloadOps.empty()) {
-    results.set(llvm::cast<OpResult>(getResult()), {});
+  auto payloadOps = state.getPayloadOps(getTarget());
+  if (std::empty(payloadOps)) {
+    results.set(cast<OpResult>(getResult()), {});
     return DiagnosedSilenceableFailure::success();
   }
-  if (payloadOps.size() != 1)
+  if (!llvm::hasSingleElement(payloadOps))
     return emitDefiniteFailure()
            << "handle must be mapped to exactly one payload op";
 
-  Operation *target = payloadOps.front();
+  Operation *target = *payloadOps.begin();
   if (target->getNumResults() <= resultNumber)
     return emitDefiniteFailure() << "result number overflow";
   results.set(llvm::cast<OpResult>(getResult()),
@@ -1483,7 +1481,7 @@ void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
 DiagnosedSilenceableFailure
 transform::SplitHandleOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
-  int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
+  int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
   auto produceNumOpsError = [&]() {
     return emitSilenceableError()
            << getHandle() << " expected to contain " << this->getNumResults()
@@ -1573,11 +1571,12 @@ void transform::PDLMatchOp::getEffects(
 DiagnosedSilenceableFailure
 transform::ReplicateOp::apply(transform::TransformResults &results,
                               transform::TransformState &state) {
-  unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
+  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
   for (const auto &en : llvm::enumerate(getHandles())) {
     Value handle = en.value();
-    if (llvm::isa<TransformHandleTypeInterface>(handle.getType())) {
-      ArrayRef<Operation *> current = state.getPayloadOps(handle);
+    if (isa<TransformHandleTypeInterface>(handle.getType())) {
+      SmallVector<Operation *> current =
+          llvm::to_vector(state.getPayloadOps(handle));
       SmallVector<Operation *> payload;
       payload.reserve(numRepetitions * current.size());
       for (unsigned i = 0; i < numRepetitions; ++i)
@@ -2011,8 +2010,7 @@ transform::PrintOp::apply(transform::TransformResults &results,
   }
 
   llvm::outs() << "]]]\n";
-  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
-  for (Operation *target : targets)
+  for (Operation *target : state.getPayloadOps(getTarget()))
     llvm::outs() << *target << "\n";
 
   return DiagnosedSilenceableFailure::success();

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 8ceb72d8d46b3..fb89d57a9332f 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
 
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
@@ -1598,3 +1598,26 @@ module attributes { transform.with_named_sequence } {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_tracked_rewrite() {
+//  CHECK-NEXT:   "test.update_mapping"() {original_op = "test.replace_me"}
+//  CHECK-NEXT:   "test.drop_mapping"() {original_op = "test.replace_me"}
+//  CHECK-NEXT:   "test.update_mapping"() {original_op = "test.replace_me"}
+//  CHECK-NEXT: }
+func.func @test_tracked_rewrite() {
+  %0 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
+  %1 = "test.replace_me"() {replacement = "test.drop_mapping"} : () -> (i1)
+  %2 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["test.replace_me"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  // expected-remark @below {{2 iterations}}
+  transform.test_tracked_rewrite %0 : (!pdl.operation) -> ()
+  // One replacement op (test.drop_mapping) is dropped from the mapping.
+  // expected-remark @below {{2}}
+  test_print_number_of_associated_payload_ir_ops %0
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 1d1bbc3a57084..50a4c92da9aa4 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -109,10 +110,10 @@ DiagnosedSilenceableFailure
 mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   if (getOperation()->getNumOperands() != 0) {
-    results.set(llvm::cast<OpResult>(getResult()),
-                getOperation()->getOperand(0).getDefiningOp());
+    results.set(cast<OpResult>(getResult()),
+                {getOperation()->getOperand(0).getDefiningOp()});
   } else {
-    results.set(llvm::cast<OpResult>(getResult()), getOperation());
+    results.set(cast<OpResult>(getResult()), {getOperation()});
   }
   return DiagnosedSilenceableFailure::success();
 }
@@ -191,12 +192,13 @@ void mlir::test::TestConsumeOperand::getEffects(
 
 DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
     transform::TransformResults &results, transform::TransformState &state) {
-  ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
-  assert(payload.size() == 1 && "expected a single target op");
-  if (payload[0]->getName().getStringRef() != getOpKind()) {
+  auto payload = state.getPayloadOps(getOperand());
+  assert(llvm::hasSingleElement(payload) && "expected a single target op");
+  if ((*payload.begin())->getName().getStringRef() != getOpKind()) {
     return emitSilenceableError()
            << "op expected the operand to be associated a payload op of kind "
-           << getOpKind() << " got " << payload[0]->getName().getStringRef();
+           << getOpKind() << " got "
+           << (*payload.begin())->getName().getStringRef();
   }
 
   emitRemark() << "succeeded";
@@ -230,7 +232,7 @@ void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
 
 DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
-  ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
+  auto payload = state.getPayloadOps(getOperand());
   for (Operation *op : payload)
     op->emitRemark() << getMessage();
 
@@ -313,11 +315,11 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
   if (!extension)
     return emitDefiniteFailure("TestTransformStateExtension missing");
 
-  if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
-                                      getOperation())))
+  if (failed(extension->updateMapping(
+          *state.getPayloadOps(getOperand()).begin(), getOperation())))
     return DiagnosedSilenceableFailure::definiteFailure();
   if (getNumResults() > 0)
-    results.set(llvm::cast<OpResult>(getResult(0)), getOperation());
+    results.set(cast<OpResult>(getResult(0)), {getOperation()});
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -337,7 +339,7 @@ DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
 DiagnosedSilenceableFailure
 mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results,
                                            transform::TransformState &state) {
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  auto payloadOps = state.getPayloadOps(getTarget());
   auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
   results.set(llvm::cast<OpResult>(getResult()), reversedOps);
   return DiagnosedSilenceableFailure::success();
@@ -431,7 +433,7 @@ mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   if (!getHandle())
     emitRemark() << 0;
-  emitRemark() << state.getPayloadOps(getHandle()).size();
+  emitRemark() << llvm::range_size(state.getPayloadOps(getHandle()));
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -599,11 +601,11 @@ mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
   } else if (getFirstResultIsNull()) {
     results.push_back(nullptr);
   } else {
-    results.push_back(state.getPayloadOps(getIn()).front());
+    results.push_back(*state.getPayloadOps(getIn()).begin());
   }
 
   if (getSecondResultIsHandle()) {
-    results.push_back(state.getPayloadOps(getIn()).front());
+    results.push_back(*state.getPayloadOps(getIn()).begin());
   } else {
     results.push_back(builder.getI64IntegerAttr(42));
   }
@@ -667,6 +669,70 @@ DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestTrackedRewriteOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getIn(), effects);
+  transform::modifiesPayload(effects);
+}
+
+namespace {
+/// A TrackingListener for test cases. When the replacement op is
+/// "test.update_mapping", it is considered as a replacement op in the transform
+/// state mapping. Otherwise, it is not and the original op is simply removed
+/// from the mapping.
+class TestTrackingListener : public transform::TrackingListener {
+  using transform::TrackingListener::TrackingListener;
+
+protected:
+  Operation *findReplacementOp(Operation *op,
+                               ValueRange newValues) const override {
+    if (newValues.size() != 1)
+      return nullptr;
+    Operation *replacement = newValues[0].getDefiningOp();
+    if (!replacement)
+      return nullptr;
+    if (replacement->getName().getStringRef() != "test.update_mapping")
+      return nullptr;
+    return replacement;
+  }
+};
+} // namespace
+
+DiagnosedSilenceableFailure
+mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  TestTrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
+  int64_t numIterations = 0;
+
+  // `getPayloadOps` returns an iterator that skips ops that are erased in the
+  // loop body. Replacement ops are not enumerated.
+  for (Operation *op : state.getPayloadOps(getIn())) {
+    ++numIterations;
+    rewriter.setInsertionPointToEnd(op->getBlock());
+
+    // Erase all payload ops. The outer loop should have only one iteration.
+    for (Operation *op : state.getPayloadOps(getIn())) {
+      if (op->getName().getStringRef() != "test.replace_me")
+        continue;
+      auto replacementName = op->getAttrOfType<StringAttr>("replacement");
+      if (!replacementName)
+        continue;
+      SmallVector<NamedAttribute> attributes;
+      attributes.emplace_back(rewriter.getStringAttr("original_op"),
+                              op->getName().getIdentifier());
+      OperationState opState(op->getLoc(), replacementName,
+                             /*operands=*/ValueRange(),
+                             /*types=*/op->getResultTypes(), attributes);
+      Operation *newOp = rewriter.create(opState);
+      rewriter.replaceOp(op, newOp->getResults());
+    }
+  }
+
+  emitRemark() << numIterations << " iterations";
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index e5a23ede64088..c77f6ea320b68 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -449,4 +449,14 @@ def TestRequiredMemoryEffectsOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestTrackedRewriteOp
+  : Op<Transform_Dialect, "test_tracked_rewrite",
+      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins TransformHandleTypeInterface:$in);
+  let results = (outs);
+  let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD


        


More information about the Mlir-commits mailing list