[Mlir-commits] [mlir] 7b4a755 - [mlir][transform] TransformStateExtension: Replace op/value handles separately

Matthias Springer llvmlistbot at llvm.org
Wed Mar 29 01:56:20 PDT 2023


Author: Matthias Springer
Date: 2023-03-29T10:56:08+02:00
New Revision: 7b4a7552719b7720b9c8ccb4bc04a9e6fa1ec0b6

URL: https://github.com/llvm/llvm-project/commit/7b4a7552719b7720b9c8ccb4bc04a9e6fa1ec0b6
DIFF: https://github.com/llvm/llvm-project/commit/7b4a7552719b7720b9c8ccb4bc04a9e6fa1ec0b6.diff

LOG: [mlir][transform] TransformStateExtension: Replace op/value handles separately

Differential Revision: https://reviews.llvm.org/D147038

Added: 
    mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/transform-state-extension.mlir
    mlir/test/lib/Dialect/Transform/CMakeLists.txt
    mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index c430642daa1f3..41b084008acf4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -313,8 +313,16 @@ class TransformState {
     /// Replaces the given payload op with another op. If the replacement op is
     /// null, removes the association of the payload op with its handle. Returns
     /// failure if the op is not associated with any handle.
+    ///
+    /// Note: This function does not update value handles. None of the original
+    /// op's results are allowed to be mapped to any value handle.
     LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
 
+    /// Replaces the given payload value with another value. If the replacement
+    /// value is null, removes the association of the payload value with its
+    /// handle. Returns failure if the value is not associated with any handle.
+    LogicalResult replacePayloadValue(Value value, Value replacement);
+
   private:
     /// Back-reference to the state that is being extended.
     TransformState &state;
@@ -484,18 +492,18 @@ class TransformState {
   void forgetValueMapping(Value valueHandle,
                           ArrayRef<Operation *> payloadOperations);
 
-  /// Updates the payload IR ops associated with the given transform IR value.
-  /// The callback function is called once per associated operation and is
-  /// expected to return the modified operation or nullptr. In the latter case,
-  /// the corresponding operation is no longer associated with the transform IR
-  /// value. Value handles associated with the results of the operation are
-  /// also updated to be associated with the results of the new operation. For
-  /// this reason, the new operation must have the same number of results.
+  /// Replaces the given payload op with another op. If the replacement op is
+  /// null, removes the association of the payload op with its handle.
   ///
-  /// Returns failure if the payload does not satisfy the conditions associated
-  /// with the type of the handle value.
+  /// Note: This function does not update value handles. None of the original
+  /// op's results are allowed to be mapped to any value handle.
   LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
 
+  /// Replaces the given payload value with another value. If the replacement
+  /// value is null, removes the association of the payload value with its
+  /// handle.
+  LogicalResult replacePayloadValue(Value value, Value replacement);
+
   /// If the operand is a handle consumed by the operation, i.e. has the "free"
   /// memory effect associated with it, identifies other handles that are
   /// pointing to payload IR operations nested in the operations pointed to by

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 2f6221601eb4a..af0627d891eb0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -336,18 +336,13 @@ transform::TransformState::replacePayloadOp(Operation *op,
     dropMappingEntry(mappings.reverse, op, handle);
   }
 
-  // Drop the mapping between the op results and all value handles that point to
-  // them. Don't care if there are no such handles.
-  RaggedArray<Value> resultValueHandles;
+#ifndef NDEBUG
   for (Value opResult : op->getResults()) {
     SmallVector<Value> valueHandles;
     (void)getHandlesForPayloadValue(opResult, valueHandles);
-    for (Value handle : valueHandles) {
-      Mappings &localMappings = getMapping(handle);
-      dropMappingEntry(localMappings.reverseValues, opResult, handle);
-    }
-    resultValueHandles.push_back(std::move(valueHandles));
+    assert(valueHandles.empty() && "expected no mapping to old results");
   }
+#endif // NDEBUG
 
   // TODO: consider invalidating the handles to nested objects here.
 
@@ -358,14 +353,6 @@ transform::TransformState::replacePayloadOp(Operation *op,
       Mappings &mappings = getMapping(handle);
       dropMappingEntry(mappings.direct, handle, op);
     }
-    for (Value opResult : op->getResults()) {
-      SmallVector<Value> valueHandles;
-      (void)getHandlesForPayloadValue(opResult, valueHandles);
-      for (Value handle : valueHandles) {
-        Mappings &localMappings = getMapping(handle);
-        dropMappingEntry(localMappings.values, handle, opResult);
-      }
-    }
     return success();
   }
 
@@ -386,33 +373,33 @@ transform::TransformState::replacePayloadOp(Operation *op,
     mappings.reverse[replacement].push_back(handle);
   }
 
-  // Second, replace the mapped results of the operation.
-  for (auto [origResult, handleList] :
-       llvm::zip(op->getResults(), resultValueHandles)) {
-    // No handles to the value, skip even if there is no replacement.
-    if (handleList.empty())
-      continue;
+  return success();
+}
 
-    unsigned resultNumber = origResult.getResultNumber();
-    if (resultNumber >= replacement->getNumResults()) {
-      return emitError(op->getLoc())
-             << "cannot replace an op with another op producing less results "
-                "while tracking handles";
-    }
+LogicalResult
+transform::TransformState::replacePayloadValue(Value value, Value replacement) {
+  SmallVector<Value> valueHandles;
+  (void)getHandlesForPayloadValue(value, valueHandles);
+
+  for (Value handle : valueHandles) {
+    Mappings &mappings = getMapping(handle);
+    dropMappingEntry(mappings.reverseValues, value, handle);
 
-    Value replacementResult = replacement->getResult(resultNumber);
-    for (Value resultHandle : handleList) {
-      Mappings &mappings = getMapping(resultHandle);
-      auto it = mappings.values.find(resultHandle);
+    // If replacing with null, that is erasing the mapping, drop the mapping
+    // between the handles and the IR objects
+    if (!replacement) {
+      dropMappingEntry(mappings.values, handle, value);
+    } else {
+      auto it = mappings.values.find(handle);
       if (it == mappings.values.end())
         continue;
 
       SmallVector<Value> &association = it->getSecond();
       for (Value &mapped : association) {
-        if (mapped == origResult)
-          mapped = replacementResult;
+        if (mapped == value)
+          mapped = replacement;
       }
-      mappings.reverseValues[replacementResult].push_back(resultHandle);
+      mappings.reverseValues[replacement].push_back(handle);
     }
   }
 
@@ -867,6 +854,16 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
   return state.replacePayloadOp(op, replacement);
 }
 
+LogicalResult transform::TransformState::Extension::replacePayloadValue(
+    Value value, Value replacement) {
+  // Values that are not associated with any handle cannot be replaced;
+  // report failure for them instead of forwarding to the state.
+  SmallVector<Value> knownHandles;
+  if (succeeded(state.getHandlesForPayloadValue(value, knownHandles)))
+    return state.replacePayloadValue(value, replacement);
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // TransformResults
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
index 7885e08291716..76cd1acf34493 100644
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -50,7 +50,7 @@ module {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   test_add_test_extension "A"
-   // This is okay because we are replacing the top-level module opeation
+   // This is okay because we are replacing the top-level module operation
    // (0 results) with this operation that has _more_ (1) results.
   %dummy = test_remap_operand_to_self %arg0 : !pdl.operation
 }
@@ -72,7 +72,7 @@ transform.sequence failures(propagate) {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   test_add_test_extension "A"
-  // expected-error @below {{cannot replace an op with another op producing less results while tracking handles}}
+  // expected-error @below {{cannot replace an op with another op producing fewer results while tracking handles}}
   %dummy = test_remap_operand_to_self %arg0 : !pdl.operation
   %valuehandle = transform.get_result %dummy[0] : (!pdl.operation) -> !transform.any_value
   test_remap_operand_to_self %dummy

diff  --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index 56c7031481b95..b86b8f56ba6c4 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -8,6 +8,7 @@ add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
 add_mlir_library(MLIRTestTransformDialect
   TestTransformDialectExtension.cpp
   TestTransformDialectInterpreter.cpp
+  TestTransformStateExtension.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp
new file mode 100644
index 0000000000000..e88f4df2d75c8
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp
@@ -0,0 +1,36 @@
+//===- TestTransformStateExtension.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformStateExtension.h"
+
+using namespace mlir;
+
+LogicalResult
+test::TestTransformStateExtension::updateMapping(Operation *previous,
+                                                 Operation *updated) {
+  // A payload value is tracked if at least one value handle is mapped to it.
+  auto isTracked = [&](Value value) {
+    SmallVector<Value> handles;
+    (void)getTransformState().getHandlesForPayloadValue(value, handles);
+    return !handles.empty();
+  };
+  // Fewer results are acceptable only if the surplus ones are untracked.
+  for (auto r = updated->getNumResults(); r < previous->getNumResults(); ++r) {
+    if (isTracked(previous->getResult(r)))
+      return emitError(previous->getLoc())
+             << "cannot replace an op with another op producing fewer results "
+                "while tracking handles";
+  }
+  // Update value handles. Untracked results are skipped: replacePayloadValue
+  // fails for values that are not mapped to any handle.
+  for (auto [oldValue, newValue] :
+       llvm::zip(previous->getResults(), updated->getResults()))
+    if (isTracked(oldValue) && failed(replacePayloadValue(oldValue, newValue)))
+      return failure();
+  return replacePayloadOp(previous, updated); // Update op handle.
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
index 3b2eb7602a7b5..752b3a78141ea 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
@@ -29,9 +29,7 @@ class TestTransformStateExtension
 
   StringRef getMessage() const { return message.getValue(); }
 
-  LogicalResult updateMapping(Operation *previous, Operation *updated) {
-    return replacePayloadOp(previous, updated);
-  }
+  LogicalResult updateMapping(Operation *previous, Operation *updated);
 
 private:
   StringAttr message;


        


More information about the Mlir-commits mailing list