[Mlir-commits] [mlir] [mlir][transform] Check for invalidated iterators on payload values (PR #66472)

Matthias Springer llvmlistbot at llvm.org
Fri Sep 15 01:57:02 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/66472

Same as #66369 but for payload values. (#66369 added checks only for payload operations.)

It was necessary to change the signature of `getPayloadValues` to return an iterator range instead of an `ArrayRef<Value>`. This now mirrors `getPayloadOps`, which already returns an iterator range over payload operations.

Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted.

>From 7c2a3d741089099825bb287dfebb787d08427597 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 15 Sep 2023 10:54:26 +0200
Subject: [PATCH] [mlir][transform] Check for invalidated iterators on payload
 values

Same as #66369 but for payload values. (#66369 added checks only for payload operations.)

It was necessary to change the signature of `getPayloadValues` to return an iterator range instead of an `ArrayRef<Value>`. This now mirrors `getPayloadOps`, which already returns an iterator range over payload operations.

Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted.
---
 .../Dialect/Transform/IR/MatchInterfaces.h    |  6 +-
 .../Transform/IR/TransformInterfaces.h        | 70 +++++++++++++++----
 .../Linalg/TransformOps/LinalgMatchOps.cpp    |  2 +-
 .../Transform/IR/TransformInterfaces.cpp      | 42 +++++------
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  4 +-
 .../TestTransformDialectExtension.cpp         |  7 +-
 6 files changed, 87 insertions(+), 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index 1a6afc58fef2704..c8888f294f6ca1d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -95,15 +95,15 @@ class SingleValueMatcherOpTrait
                                     TransformResults &results,
                                     TransformState &state) {
     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
-    ValueRange payload = state.getPayloadValues(operandHandle);
-    if (payload.size() != 1) {
+    auto payload = state.getPayloadValues(operandHandle);
+    if (!llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
              << "SingleValueMatchOpTrait requires the value handle to point to "
                 "a single payload value";
     }
 
     return cast<OpTy>(this->getOperation())
-        .matchValue(payload[0], results, state);
+        .matchValue(*payload.begin(), results, state);
   }
 
   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 86af59142b77d9c..31a93b05cf7a153 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -170,7 +170,7 @@ class TransformState {
   /// should be emitted when the value is used.
   using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// Debug only: A timestamp is associated with each transform IR value, so
   /// that invalid iterator usage can be detected more reliably.
   using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
@@ -185,7 +185,7 @@ class TransformState {
     ValueMapping values;
     ValueMapping reverseValues;
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
     TransformIRTimestampMapping timestamps;
     void incrementTimestamp(Value value) { ++timestamps[value]; }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -220,7 +220,7 @@ class TransformState {
   auto getPayloadOps(Value value) const {
     ArrayRef<Operation *> view = getPayloadOpsView(value);
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
     // Memorize the current timestamp and make sure that it has not changed
     // when incrementing or dereferencing the iterator returned by this
     // function. The timestamp is incremented when the "direct" mapping is
@@ -231,7 +231,7 @@ class TransformState {
     // When ops are replaced/erased, they are replaced with nullptr (until
     // the data structure is compacted). Do not enumerate these ops.
     return llvm::make_filter_range(view, [=](Operation *op) {
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
       bool sameTimestamp =
           currentTimestamp == this->getMapping(value).timestamps.lookup(value);
       assert(sameTimestamp && "iterator was invalidated during iteration");
@@ -244,9 +244,29 @@ class TransformState {
   /// corresponds to.
   ArrayRef<Attribute> getParams(Value value) const;
 
-  /// Returns the list of payload IR values that the given transform IR value
-  /// corresponds to.
-  ArrayRef<Value> getPayloadValues(Value handleValue) const;
+  /// Returns an iterator that enumerates all payload IR values that the given
+  /// transform IR value corresponds to.
+  auto getPayloadValues(Value handleValue) const {
+    ArrayRef<Value> view = getPayloadValuesView(handleValue);
+
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+    // Memorize the current timestamp and make sure that it has not changed
+    // when incrementing or dereferencing the iterator returned by this
+    // function. The timestamp is incremented when the "values" mapping is
+    // resized; this would invalidate the iterator returned by this function.
+    int64_t currentTimestamp =
+        getMapping(handleValue).timestamps.lookup(handleValue);
+    return llvm::make_filter_range(view, [=](Value v) {
+      bool sameTimestamp =
+          currentTimestamp ==
+          this->getMapping(handleValue).timestamps.lookup(handleValue);
+      assert(sameTimestamp && "iterator was invalidated during iteration");
+      return true;
+    });
+#else
+    return llvm::make_range(view.begin(), view.end());
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  }
 
   /// Populates `handles` with all handles pointing to the given Payload IR op.
   /// Returns success if such handles exist, failure otherwise.
@@ -501,12 +521,15 @@ class TransformState {
   LogicalResult updateStateFromResults(const TransformResults &results,
                                        ResultRange opResults);
 
-  /// Returns a list of all ops that the given transform IR value corresponds to
-  /// at the time when this function is called. In case an op was erased, the
-  /// returned list contains nullptr. This function is helpful for
-  /// transformations that apply to a particular handle.
+  /// Returns a list of all ops that the given transform IR value corresponds
+  /// to. In case an op was erased, the returned list contains nullptr. This
+  /// function is helpful for transformations that apply to a particular handle.
   ArrayRef<Operation *> getPayloadOpsView(Value value) const;
 
+  /// Returns a list of payload IR values that the given transform IR value
+  /// corresponds to.
+  ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
+
   /// Sets the payload IR ops associated with the given transform IR value
   /// (handle). A payload op may be associated multiple handles as long as
   /// at most one of them gets consumed by further transformations.
@@ -774,7 +797,8 @@ class TransformResults {
   /// corresponds to the given list of payload IR ops. Each result must be set
   /// by the transformation exactly once in case of transformation succeeding.
   /// The value must have a type implementing TransformHandleTypeInterface.
-  template <typename Range> void set(OpResult value, Range &&ops) {
+  template <typename Range>
+  void set(OpResult value, Range &&ops) {
     int64_t position = value.getResultNumber();
     assert(position < static_cast<int64_t>(operations.size()) &&
            "setting results for a non-existent handle");
@@ -805,7 +829,27 @@ class TransformResults {
   /// set by the transformation exactly once in case of transformation
   /// succeeding. The value must have a type implementing
   /// TransformValueHandleTypeInterface.
-  void setValues(OpResult handle, ValueRange values);
+  template <typename Range>
+  void setValues(OpResult handle, Range &&values) {
+    int64_t position = handle.getResultNumber();
+    assert(position < static_cast<int64_t>(this->values.size()) &&
+           "setting values for a non-existent handle");
+    assert(this->values[position].data() == nullptr && "values already set");
+    assert(operations[position].data() == nullptr &&
+           "another kind of results already set");
+    assert(params[position].data() == nullptr &&
+           "another kind of results already set");
+    this->values.replace(position, std::forward<Range>(values));
+  }
+
+  /// Indicates that the result of the transform IR op at the given position
+  /// corresponds to the given range of payload IR values. Each result must be
+  /// set by the transformation exactly once in case of transformation
+  /// succeeding. The value must have a type implementing
+  /// TransformValueHandleTypeInterface.
+  void setValues(OpResult handle, std::initializer_list<Value> values) {
+    setValues(handle, ArrayRef<Value>(values));
+  }
 
   /// Indicates that the result of the transform IR op at the given position
   /// corresponds to the given range of mapped values. All mapped values are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 7b8bf6fc5d8f5a4..fb021ed76242e90 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -728,7 +728,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
 
   Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
   if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
-    results.setValues(cast<OpResult>(getResult()), result);
+    results.setValues(cast<OpResult>(getResult()), {result});
     return DiagnosedSilenceableFailure::success();
   }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 9cac178d3c2b869..fd2cf8816ae2162 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -75,7 +75,7 @@ ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
 }
 
 ArrayRef<Value>
-transform::TransformState::getPayloadValues(Value handleValue) const {
+transform::TransformState::getPayloadValuesView(Value handleValue) const {
   const ValueMapping &mapping = getMapping(handleValue).values;
   auto iter = mapping.find(handleValue);
   assert(iter != mapping.end() && "cannot find mapping for value handle "
@@ -310,7 +310,7 @@ void transform::TransformState::forgetMapping(Value opHandle,
   for (Operation *op : mappings.direct[opHandle])
     dropMappingEntry(mappings.reverse, op, opHandle);
   mappings.direct.erase(opHandle);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
   // Payload IR is removed from the mapping. This invalidates the respective
   // iterators.
   mappings.incrementTimestamp(opHandle);
@@ -322,6 +322,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
     for (Value resultHandle : resultHandles) {
       Mappings &localMappings = getMapping(resultHandle);
       dropMappingEntry(localMappings.values, resultHandle, opResult);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      mappings.incrementTimestamp(resultHandle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
       dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
     }
   }
@@ -333,6 +338,11 @@ void transform::TransformState::forgetValueMapping(
   for (Value payloadValue : mappings.reverseValues[valueHandle])
     dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
   mappings.values.erase(valueHandle);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+  // Payload IR is removed from the mapping. This invalidates the respective
+  // iterators.
+  mappings.incrementTimestamp(valueHandle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   for (Operation *payloadOp : payloadOperations) {
     SmallVector<Value> opHandles;
@@ -342,7 +352,7 @@ void transform::TransformState::forgetValueMapping(
       dropMappingEntry(localMappings.direct, opHandle, payloadOp);
       dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
       // Payload IR is removed from the mapping. This invalidates the respective
       // iterators.
       localMappings.incrementTimestamp(opHandle);
@@ -439,6 +449,11 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) {
     // between the handles and the IR objects
     if (!replacement) {
       dropMappingEntry(mappings.values, handle, value);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      mappings.incrementTimestamp(handle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     } else {
       auto it = mappings.values.find(handle);
       if (it == mappings.values.end())
@@ -647,7 +662,7 @@ void transform::TransformState::recordValueHandleInvalidation(
     OpOperand &valueHandle,
     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
   // Invalidate other handles to the same value.
-  for (Value payloadValue : getPayloadValues(valueHandle.get())) {
+  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
     SmallVector<Value> otherValueHandles;
     (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
     for (Value otherHandle : otherValueHandles) {
@@ -785,7 +800,7 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
 void transform::TransformState::compactOpHandles() {
   for (Value handle : opHandlesToCompact) {
     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
     if (llvm::find(mappings.direct[handle], nullptr) !=
         mappings.direct[handle].end())
       // Payload IR is removed from the mapping. This invalidates the respective
@@ -846,7 +861,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
         FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
         DiagnosedSilenceableFailure check =
             checkRepeatedConsumptionInOperand<Value>(
-                getPayloadValues(operand.get()), transform,
+                getPayloadValuesView(operand.get()), transform,
                 operand.getOperandNumber());
         if (!check.succeeded()) {
           FULL_LDBG("----FAILED\n");
@@ -912,7 +927,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
       continue;
     }
     if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
-      for (Value payloadValue : getPayloadValues(operand)) {
+      for (Value payloadValue : getPayloadValuesView(operand)) {
         if (llvm::isa<OpResult>(payloadValue)) {
           origAssociatedOps.push_back(payloadValue.getDefiningOp());
           continue;
@@ -1170,19 +1185,6 @@ void transform::TransformResults::setParams(
   this->params.replace(position, params);
 }
 
-void transform::TransformResults::setValues(OpResult handle,
-                                            ValueRange values) {
-  int64_t position = handle.getResultNumber();
-  assert(position < static_cast<int64_t>(this->values.size()) &&
-         "setting values for a non-existent handle");
-  assert(this->values[position].data() == nullptr && "values already set");
-  assert(operations[position].data() == nullptr &&
-         "another kind of results already set");
-  assert(params[position].data() == nullptr &&
-         "another kind of results already set");
-  this->values.replace(position, values);
-}
-
 void transform::TransformResults::setMappedValues(
     OpResult handle, ArrayRef<MappedValue> values) {
   DiagnosedSilenceableFailure diag = dispatchMappedValues(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index de3cd1b28e435bc..cd4f628f1459ab7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1378,9 +1378,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
   SmallVector<Attribute> params;
-  ArrayRef<Value> values = state.getPayloadValues(getValue());
-  params.reserve(values.size());
-  for (Value value : values) {
+  for (Value value : state.getPayloadValues(getValue())) {
     Type type = value.getType();
     if (getElemental()) {
       if (auto shaped = dyn_cast<ShapedType>(type)) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 59f045de3246f6b..e8c25aca237251a 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -136,7 +136,7 @@ DiagnosedSilenceableFailure
 mlir::test::TestProduceValueHandleToSelfOperand::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  results.setValues(llvm::cast<OpResult>(getOut()), getIn());
+  results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -265,8 +265,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
 DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  ArrayRef<Value> values = state.getPayloadValues(getIn());
-  for (Value value : values) {
+  for (Value value : state.getPayloadValues(getIn())) {
     std::string note;
     llvm::raw_string_ostream os(note);
     if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
@@ -712,7 +711,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
 DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  results.setValues(llvm::cast<OpResult>(getOut()), Value());
+  results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
   return DiagnosedSilenceableFailure::success();
 }
 



More information about the Mlir-commits mailing list