[Mlir-commits] [mlir] 7b4a755 - [mlir][transform] TransformStateExtension: Replace op/value handles separately
Matthias Springer
llvmlistbot at llvm.org
Wed Mar 29 01:56:20 PDT 2023
Author: Matthias Springer
Date: 2023-03-29T10:56:08+02:00
New Revision: 7b4a7552719b7720b9c8ccb4bc04a9e6fa1ec0b6
URL: https://github.com/llvm/llvm-project/commit/7b4a7552719b7720b9c8ccb4bc04a9e6fa1ec0b6
DIFF: https://github.com/llvm/llvm-project/commit/7b4a7552719b7720b9c8ccb4bc04a9e6fa1ec0b6.diff
LOG: [mlir][transform] TransformStateExtension: Replace op/value handles separately
Differential Revision: https://reviews.llvm.org/D147038
Added:
mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/transform-state-extension.mlir
mlir/test/lib/Dialect/Transform/CMakeLists.txt
mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index c430642daa1f3..41b084008acf4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -313,8 +313,16 @@ class TransformState {
/// Replaces the given payload op with another op. If the replacement op is
/// null, removes the association of the payload op with its handle. Returns
/// failure if the op is not associated with any handle.
+ ///
+ /// Note: This function does not update value handles. None of the original
+ /// op's results are allowed to be mapped to any value handle.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
+ /// Replaces the given payload value with another value. If the replacement
+ /// value is null, removes the association of the payload value with its
+ /// handle. Returns failure if the value is not associated with any handle.
+ LogicalResult replacePayloadValue(Value value, Value replacement);
+
private:
/// Back-reference to the state that is being extended.
TransformState &state;
@@ -484,18 +492,18 @@ class TransformState {
void forgetValueMapping(Value valueHandle,
ArrayRef<Operation *> payloadOperations);
- /// Updates the payload IR ops associated with the given transform IR value.
- /// The callback function is called once per associated operation and is
- /// expected to return the modified operation or nullptr. In the latter case,
- /// the corresponding operation is no longer associated with the transform IR
- /// value. Value handles associated with the results of the operation are
- /// also updated to be associated with the results of the new operation. For
- /// this reason, the new operation must have the same number of results.
+ /// Replaces the given payload op with another op. If the replacement op is
+ /// null, removes the association of the payload op with its handle.
///
- /// Returns failure if the payload does not satisfy the conditions associated
- /// with the type of the handle value.
+ /// Note: This function does not update value handles. None of the original
+ /// op's results are allowed to be mapped to any value handle.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
+ /// Replaces the given payload value with another value. If the replacement
+ /// value is null, removes the association of the payload value with its
+ /// handle.
+ LogicalResult replacePayloadValue(Value value, Value replacement);
+
/// If the operand is a handle consumed by the operation, i.e. has the "free"
/// memory effect associated with it, identifies other handles that are
/// pointing to payload IR operations nested in the operations pointed to by
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 2f6221601eb4a..af0627d891eb0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -336,18 +336,13 @@ transform::TransformState::replacePayloadOp(Operation *op,
dropMappingEntry(mappings.reverse, op, handle);
}
- // Drop the mapping between the op results and all value handles that point to
- // them. Don't care if there are no such handles.
- RaggedArray<Value> resultValueHandles;
+#ifndef NDEBUG
for (Value opResult : op->getResults()) {
SmallVector<Value> valueHandles;
(void)getHandlesForPayloadValue(opResult, valueHandles);
- for (Value handle : valueHandles) {
- Mappings &localMappings = getMapping(handle);
- dropMappingEntry(localMappings.reverseValues, opResult, handle);
- }
- resultValueHandles.push_back(std::move(valueHandles));
+ assert(valueHandles.empty() && "expected no mapping to old results");
}
+#endif // NDEBUG
// TODO: consider invalidating the handles to nested objects here.
@@ -358,14 +353,6 @@ transform::TransformState::replacePayloadOp(Operation *op,
Mappings &mappings = getMapping(handle);
dropMappingEntry(mappings.direct, handle, op);
}
- for (Value opResult : op->getResults()) {
- SmallVector<Value> valueHandles;
- (void)getHandlesForPayloadValue(opResult, valueHandles);
- for (Value handle : valueHandles) {
- Mappings &localMappings = getMapping(handle);
- dropMappingEntry(localMappings.values, handle, opResult);
- }
- }
return success();
}
@@ -386,33 +373,33 @@ transform::TransformState::replacePayloadOp(Operation *op,
mappings.reverse[replacement].push_back(handle);
}
- // Second, replace the mapped results of the operation.
- for (auto [origResult, handleList] :
- llvm::zip(op->getResults(), resultValueHandles)) {
- // No handles to the value, skip even if there is no replacement.
- if (handleList.empty())
- continue;
+ return success();
+}
- unsigned resultNumber = origResult.getResultNumber();
- if (resultNumber >= replacement->getNumResults()) {
- return emitError(op->getLoc())
- << "cannot replace an op with another op producing less results "
- "while tracking handles";
- }
+LogicalResult
+transform::TransformState::replacePayloadValue(Value value, Value replacement) {
+ SmallVector<Value> valueHandles;
+ (void)getHandlesForPayloadValue(value, valueHandles);
+
+ for (Value handle : valueHandles) {
+ Mappings &mappings = getMapping(handle);
+ dropMappingEntry(mappings.reverseValues, value, handle);
- Value replacementResult = replacement->getResult(resultNumber);
- for (Value resultHandle : handleList) {
- Mappings &mappings = getMapping(resultHandle);
- auto it = mappings.values.find(resultHandle);
+ // If the replacement is null (i.e., the mapping is being erased), drop the
+ // mapping between the handles and the payload value.
+ if (!replacement) {
+ dropMappingEntry(mappings.values, handle, value);
+ } else {
+ auto it = mappings.values.find(handle);
if (it == mappings.values.end())
continue;
SmallVector<Value> &association = it->getSecond();
for (Value &mapped : association) {
- if (mapped == origResult)
- mapped = replacementResult;
+ if (mapped == value)
+ mapped = replacement;
}
- mappings.reverseValues[replacementResult].push_back(resultHandle);
+ mappings.reverseValues[replacement].push_back(handle);
}
}
@@ -867,6 +854,16 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
return state.replacePayloadOp(op, replacement);
}
+LogicalResult
+transform::TransformState::Extension::replacePayloadValue(Value value,
+ Value replacement) {
+ SmallVector<Value> handles; // Only queried to detect an unmapped `value`.
+ if (failed(state.getHandlesForPayloadValue(value, handles)))
+ return failure();
+ // `value` is mapped to at least one handle here; let the state remap it.
+ return state.replacePayloadValue(value, replacement);
+}
+
//===----------------------------------------------------------------------===//
// TransformResults
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
index 7885e08291716..76cd1acf34493 100644
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -50,7 +50,7 @@ module {
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
test_add_test_extension "A"
- // This is okay because we are replacing the top-level module opeation
+ // This is okay because we are replacing the top-level module operation
// (0 results) with this operation that has _more_ (1) results.
%dummy = test_remap_operand_to_self %arg0 : !pdl.operation
}
@@ -72,7 +72,7 @@ transform.sequence failures(propagate) {
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
test_add_test_extension "A"
- // expected-error @below {{cannot replace an op with another op producing less results while tracking handles}}
+ // expected-error @below {{cannot replace an op with another op producing fewer results while tracking handles}}
%dummy = test_remap_operand_to_self %arg0 : !pdl.operation
%valuehandle = transform.get_result %dummy[0] : (!pdl.operation) -> !transform.any_value
test_remap_operand_to_self %dummy
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index 56c7031481b95..b86b8f56ba6c4 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -8,6 +8,7 @@ add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
add_mlir_library(MLIRTestTransformDialect
TestTransformDialectExtension.cpp
TestTransformDialectInterpreter.cpp
+ TestTransformStateExtension.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp
new file mode 100644
index 0000000000000..e88f4df2d75c8
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp
@@ -0,0 +1,36 @@
+//===- TestTransformStateExtension.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformStateExtension.h"
+
+using namespace mlir;
+
+LogicalResult
+test::TestTransformStateExtension::updateMapping(Operation *previous,
+ Operation *updated) {
+ // Update value handles first. `updated` should have at least as many results
+ // as `previous`. Fewer results are acceptable only if the dropped results of
+ // `previous` are not mapped to any value handle.
+ for (auto r = updated->getNumResults(); r < previous->getNumResults(); ++r) {
+ SmallVector<Value> handles;
+ (void)getTransformState().getHandlesForPayloadValue(previous->getResult(r),
+ handles);
+ if (!handles.empty())
+ return emitError(previous->getLoc())
+ << "cannot replace an op with another op producing fewer results "
+ "while tracking handles";
+ }
+ // Pair results by index; llvm::zip stops at the shorter of the two lists.
+ for (auto [oldValue, newValue] :
+ llvm::zip(previous->getResults(), updated->getResults()))
+ if (failed(replacePayloadValue(oldValue, newValue)))
+ return failure();
+
+ // Update op handles last: replacePayloadOp expects old results to be unmapped.
+ return replacePayloadOp(previous, updated);
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
index 3b2eb7602a7b5..752b3a78141ea 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
@@ -29,9 +29,7 @@ class TestTransformStateExtension
StringRef getMessage() const { return message.getValue(); }
- LogicalResult updateMapping(Operation *previous, Operation *updated) {
- return replacePayloadOp(previous, updated);
- }
+ LogicalResult updateMapping(Operation *previous, Operation *updated);
private:
StringAttr message;
More information about the Mlir-commits
mailing list