[Mlir-commits] [mlir] [mlir][transform] Check for invalidated iterators on payload values (PR #66472)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 01:58:05 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg
            
<details>
<summary>Changes</summary>
Same as #66369 but for payload values. (#66369 added checks only for payload operations.)

It was necessary to change the signature of `getPayloadValues` to return an iterator range instead of an `ArrayRef<Value>`. This makes it consistent with `getPayloadOps`, which already does the same for payload operations.

Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted.
--
Full diff: https://github.com/llvm/llvm-project/pull/66472.diff

6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h (+3-3) 
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+57-13) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+22-20) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+1-3) 
- (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+3-4) 


<pre>
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index 1a6afc58fef2704..c8888f294f6ca1d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -95,15 +95,15 @@ class SingleValueMatcherOpTrait
                                     TransformResults &amp;results,
                                     TransformState &amp;state) {
     Value operandHandle = cast&lt;OpTy&gt;(this-&gt;getOperation()).getOperandHandle();
-    ValueRange payload = state.getPayloadValues(operandHandle);
-    if (payload.size() != 1) {
+    auto payload = state.getPayloadValues(operandHandle);
+    if (!llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this-&gt;getOperation()-&gt;getLoc())
              &lt;&lt; &quot;SingleValueMatchOpTrait requires the value handle to point to &quot;
                 &quot;a single payload value&quot;;
     }
 
     return cast&lt;OpTy&gt;(this-&gt;getOperation())
-        .matchValue(payload[0], results, state);
+        .matchValue(*payload.begin(), results, state);
   }
 
   void getEffects(SmallVectorImpl&lt;MemoryEffects::EffectInstance&gt; &amp;effects) {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 86af59142b77d9c..31a93b05cf7a153 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -170,7 +170,7 @@ class TransformState {
   /// should be emitted when the value is used.
   using InvalidatedHandleMap = DenseMap&lt;Value, std::function&lt;void(Location)&gt;&gt;;
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// Debug only: A timestamp is associated with each transform IR value, so
   /// that invalid iterator usage can be detected more reliably.
   using TransformIRTimestampMapping = DenseMap&lt;Value, int64_t&gt;;
@@ -185,7 +185,7 @@ class TransformState {
     ValueMapping values;
     ValueMapping reverseValues;
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
     TransformIRTimestampMapping timestamps;
     void incrementTimestamp(Value value) { ++timestamps[value]; }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -220,7 +220,7 @@ class TransformState {
   auto getPayloadOps(Value value) const {
     ArrayRef&lt;Operation *&gt; view = getPayloadOpsView(value);
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
     // Memorize the current timestamp and make sure that it has not changed
     // when incrementing or dereferencing the iterator returned by this
     // function. The timestamp is incremented when the &quot;direct&quot; mapping is
@@ -231,7 +231,7 @@ class TransformState {
     // When ops are replaced/erased, they are replaced with nullptr (until
     // the data structure is compacted). Do not enumerate these ops.
     return llvm::make_filter_range(view, [=](Operation *op) {
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
       bool sameTimestamp =
           currentTimestamp == this-&gt;getMapping(value).timestamps.lookup(value);
       assert(sameTimestamp &amp;&amp; &quot;iterator was invalidated during iteration&quot;);
@@ -244,9 +244,29 @@ class TransformState {
   /// corresponds to.
   ArrayRef&lt;Attribute&gt; getParams(Value value) const;
 
-  /// Returns the list of payload IR values that the given transform IR value
-  /// corresponds to.
-  ArrayRef&lt;Value&gt; getPayloadValues(Value handleValue) const;
+  /// Returns an iterator that enumerates all payload IR values that the given
+  /// transform IR value corresponds to.
+  auto getPayloadValues(Value handleValue) const {
+    ArrayRef&lt;Value&gt; view = getPayloadValuesView(handleValue);
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    // Memorize the current timestamp and make sure that it has not changed
+    // when incrementing or dereferencing the iterator returned by this
+    // function. The timestamp is incremented when the &quot;values&quot; mapping is
+    // resized; this would invalidate the iterator returned by this function.
+    int64_t currentTimestamp =
+        getMapping(handleValue).timestamps.lookup(handleValue);
+    return llvm::make_filter_range(view, [=](Value v) {
+      bool sameTimestamp =
+          currentTimestamp ==
+          this-&gt;getMapping(handleValue).timestamps.lookup(handleValue);
+      assert(sameTimestamp &amp;&amp; &quot;iterator was invalidated during iteration&quot;);
+      return true;
+    });
+#else
+    return llvm::make_range(view.begin(), view.end());
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  }
 
   /// Populates `handles` with all handles pointing to the given Payload IR op.
   /// Returns success if such handles exist, failure otherwise.
@@ -501,12 +521,15 @@ class TransformState {
   LogicalResult updateStateFromResults(const TransformResults &amp;results,
                                        ResultRange opResults);
 
-  /// Returns a list of all ops that the given transform IR value corresponds to
-  /// at the time when this function is called. In case an op was erased, the
-  /// returned list contains nullptr. This function is helpful for
-  /// transformations that apply to a particular handle.
+  /// Returns a list of all ops that the given transform IR value corresponds
+  /// to. In case an op was erased, the returned list contains nullptr. This
+  /// function is helpful for transformations that apply to a particular handle.
   ArrayRef&lt;Operation *&gt; getPayloadOpsView(Value value) const;
 
+  /// Returns a list of payload IR values that the given transform IR value
+  /// corresponds to.
+  ArrayRef&lt;Value&gt; getPayloadValuesView(Value handleValue) const;
+
   /// Sets the payload IR ops associated with the given transform IR value
   /// (handle). A payload op may be associated multiple handles as long as
   /// at most one of them gets consumed by further transformations.
@@ -774,7 +797,8 @@ class TransformResults {
   /// corresponds to the given list of payload IR ops. Each result must be set
   /// by the transformation exactly once in case of transformation succeeding.
   /// The value must have a type implementing TransformHandleTypeInterface.
-  template &lt;typename Range&gt; void set(OpResult value, Range &amp;&amp;ops) {
+  template &lt;typename Range&gt;
+  void set(OpResult value, Range &amp;&amp;ops) {
     int64_t position = value.getResultNumber();
     assert(position &lt; static_cast&lt;int64_t&gt;(operations.size()) &amp;&amp;
            &quot;setting results for a non-existent handle&quot;);
@@ -805,7 +829,27 @@ class TransformResults {
   /// set by the transformation exactly once in case of transformation
   /// succeeding. The value must have a type implementing
   /// TransformValueHandleTypeInterface.
-  void setValues(OpResult handle, ValueRange values);
+  template &lt;typename Range&gt;
+  void setValues(OpResult handle, Range &amp;&amp;values) {
+    int64_t position = handle.getResultNumber();
+    assert(position &lt; static_cast&lt;int64_t&gt;(this-&gt;values.size()) &amp;&amp;
+           &quot;setting values for a non-existent handle&quot;);
+    assert(this-&gt;values[position].data() == nullptr &amp;&amp; &quot;values already set&quot;);
+    assert(operations[position].data() == nullptr &amp;&amp;
+           &quot;another kind of results already set&quot;);
+    assert(params[position].data() == nullptr &amp;&amp;
+           &quot;another kind of results already set&quot;);
+    this-&gt;values.replace(position, std::forward&lt;Range&gt;(values));
+  }
+
+  /// Indicates that the result of the transform IR op at the given position
+  /// corresponds to the given range of payload IR values. Each result must be
+  /// set by the transformation exactly once in case of transformation
+  /// succeeding. The value must have a type implementing
+  /// TransformValueHandleTypeInterface.
+  void setValues(OpResult handle, std::initializer_list&lt;Value&gt; values) {
+    setValues(handle, ArrayRef&lt;Value&gt;(values));
+  }
 
   /// Indicates that the result of the transform IR op at the given position
   /// corresponds to the given range of mapped values. All mapped values are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 7b8bf6fc5d8f5a4..fb021ed76242e90 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -728,7 +728,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
 
   Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
   if (isa&lt;TransformValueHandleTypeInterface&gt;(getResult().getType())) {
-    results.setValues(cast&lt;OpResult&gt;(getResult()), result);
+    results.setValues(cast&lt;OpResult&gt;(getResult()), {result});
     return DiagnosedSilenceableFailure::success();
   }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 9cac178d3c2b869..fd2cf8816ae2162 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -75,7 +75,7 @@ ArrayRef&lt;Attribute&gt; transform::TransformState::getParams(Value value) const {
 }
 
 ArrayRef&lt;Value&gt;
-transform::TransformState::getPayloadValues(Value handleValue) const {
+transform::TransformState::getPayloadValuesView(Value handleValue) const {
   const ValueMapping &amp;mapping = getMapping(handleValue).values;
   auto iter = mapping.find(handleValue);
   assert(iter != mapping.end() &amp;&amp; &quot;cannot find mapping for value handle &quot;
@@ -310,7 +310,7 @@ void transform::TransformState::forgetMapping(Value opHandle,
   for (Operation *op : mappings.direct[opHandle])
     dropMappingEntry(mappings.reverse, op, opHandle);
   mappings.direct.erase(opHandle);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
   // Payload IR is removed from the mapping. This invalidates the respective
   // iterators.
   mappings.incrementTimestamp(opHandle);
@@ -322,6 +322,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
     for (Value resultHandle : resultHandles) {
       Mappings &amp;localMappings = getMapping(resultHandle);
       dropMappingEntry(localMappings.values, resultHandle, opResult);
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      localMappings.incrementTimestamp(resultHandle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
       dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
     }
   }
@@ -333,6 +338,11 @@ void transform::TransformState::forgetValueMapping(
   for (Value payloadValue : mappings.reverseValues[valueHandle])
     dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
   mappings.values.erase(valueHandle);
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  // Payload IR is removed from the mapping. This invalidates the respective
+  // iterators.
+  mappings.incrementTimestamp(valueHandle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   for (Operation *payloadOp : payloadOperations) {
     SmallVector&lt;Value&gt; opHandles;
@@ -342,7 +352,7 @@ void transform::TransformState::forgetValueMapping(
       dropMappingEntry(localMappings.direct, opHandle, payloadOp);
       dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
       // Payload IR is removed from the mapping. This invalidates the respective
       // iterators.
       localMappings.incrementTimestamp(opHandle);
@@ -439,6 +449,11 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) {
     // between the handles and the IR objects
     if (!replacement) {
       dropMappingEntry(mappings.values, handle, value);
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      mappings.incrementTimestamp(handle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     } else {
       auto it = mappings.values.find(handle);
       if (it == mappings.values.end())
@@ -647,7 +662,7 @@ void transform::TransformState::recordValueHandleInvalidation(
     OpOperand &amp;valueHandle,
     transform::TransformState::InvalidatedHandleMap &amp;newlyInvalidated) const {
   // Invalidate other handles to the same value.
-  for (Value payloadValue : getPayloadValues(valueHandle.get())) {
+  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
     SmallVector&lt;Value&gt; otherValueHandles;
     (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
     for (Value otherHandle : otherValueHandles) {
@@ -785,7 +800,7 @@ checkRepeatedConsumptionInOperand(ArrayRef&lt;T&gt; payload,
 void transform::TransformState::compactOpHandles() {
   for (Value handle : opHandlesToCompact) {
     Mappings &amp;mappings = getMapping(handle, /*allowOutOfScope=*/true);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
     if (llvm::find(mappings.direct[handle], nullptr) !=
         mappings.direct[handle].end())
       // Payload IR is removed from the mapping. This invalidates the respective
@@ -846,7 +861,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
         FULL_LDBG(&quot;--checkRepeatedConsumptionInOperand For Value\n&quot;);
         DiagnosedSilenceableFailure check =
             checkRepeatedConsumptionInOperand&lt;Value&gt;(
-                getPayloadValues(operand.get()), transform,
+                getPayloadValuesView(operand.get()), transform,
                 operand.getOperandNumber());
         if (!check.succeeded()) {
           FULL_LDBG(&quot;----FAILED\n&quot;);
@@ -912,7 +927,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
       continue;
     }
     if (llvm::isa&lt;TransformValueHandleTypeInterface&gt;(operand.getType())) {
-      for (Value payloadValue : getPayloadValues(operand)) {
+      for (Value payloadValue : getPayloadValuesView(operand)) {
         if (llvm::isa&lt;OpResult&gt;(payloadValue)) {
           origAssociatedOps.push_back(payloadValue.getDefiningOp());
           continue;
@@ -1170,19 +1185,6 @@ void transform::TransformResults::setParams(
   this-&gt;params.replace(position, params);
 }
 
-void transform::TransformResults::setValues(OpResult handle,
-                                            ValueRange values) {
-  int64_t position = handle.getResultNumber();
-  assert(position &lt; static_cast&lt;int64_t&gt;(this-&gt;values.size()) &amp;&amp;
-         &quot;setting values for a non-existent handle&quot;);
-  assert(this-&gt;values[position].data() == nullptr &amp;&amp; &quot;values already set&quot;);
-  assert(operations[position].data() == nullptr &amp;&amp;
-         &quot;another kind of results already set&quot;);
-  assert(params[position].data() == nullptr &amp;&amp;
-         &quot;another kind of results already set&quot;);
-  this-&gt;values.replace(position, values);
-}
-
 void transform::TransformResults::setMappedValues(
     OpResult handle, ArrayRef&lt;MappedValue&gt; values) {
   DiagnosedSilenceableFailure diag = dispatchMappedValues(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index de3cd1b28e435bc..cd4f628f1459ab7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1378,9 +1378,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &amp;rewriter,
                             transform::TransformResults &amp;results,
                             transform::TransformState &amp;state) {
   SmallVector&lt;Attribute&gt; params;
-  ArrayRef&lt;Value&gt; values = state.getPayloadValues(getValue());
-  params.reserve(values.size());
-  for (Value value : values) {
+  for (Value value : state.getPayloadValues(getValue())) {
     Type type = value.getType();
     if (getElemental()) {
       if (auto shaped = dyn_cast&lt;ShapedType&gt;(type)) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 59f045de3246f6b..e8c25aca237251a 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -136,7 +136,7 @@ DiagnosedSilenceableFailure
 mlir::test::TestProduceValueHandleToSelfOperand::apply(
     transform::TransformRewriter &amp;rewriter,
     transform::TransformResults &amp;results, transform::TransformState &amp;state) {
-  results.setValues(llvm::cast&lt;OpResult&gt;(getOut()), getIn());
+  results.setValues(llvm::cast&lt;OpResult&gt;(getOut()), {getIn()});
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -265,8 +265,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
 DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
     transform::TransformRewriter &amp;rewriter,
     transform::TransformResults &amp;results, transform::TransformState &amp;state) {
-  ArrayRef&lt;Value&gt; values = state.getPayloadValues(getIn());
-  for (Value value : values) {
+  for (Value value : state.getPayloadValues(getIn())) {
     std::string note;
     llvm::raw_string_ostream os(note);
     if (auto arg = llvm::dyn_cast&lt;BlockArgument&gt;(value)) {
@@ -712,7 +711,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
 DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
     transform::TransformRewriter &amp;rewriter,
     transform::TransformResults &amp;results, transform::TransformState &amp;state) {
-  results.setValues(llvm::cast&lt;OpResult&gt;(getOut()), Value());
+  results.setValues(llvm::cast&lt;OpResult&gt;(getOut()), {Value()});
   return DiagnosedSilenceableFailure::success();
 }
 
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66472


More information about the Mlir-commits mailing list