[Mlir-commits] [mlir] 3c196f1 - [mlir][transform] Remove redundant handle check in `replacePayload...`
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 26 09:02:05 PDT 2023
Author: Matthias Springer
Date: 2023-06-26T17:59:06+02:00
New Revision: 3c196f1658f3c5fd368fdaa3c2fb165ed6d7fefa
URL: https://github.com/llvm/llvm-project/commit/3c196f1658f3c5fd368fdaa3c2fb165ed6d7fefa
DIFF: https://github.com/llvm/llvm-project/commit/3c196f1658f3c5fd368fdaa3c2fb165ed6d7fefa.diff
LOG: [mlir][transform] Remove redundant handle check in `replacePayload...`
Differential Revision: https://reviews.llvm.org/D153766
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 21ddd7143966f..20f9b2122e933 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -555,7 +555,8 @@ class TransformState {
ArrayRef<Operation *> payloadOperations);
/// Replaces the given payload op with another op. If the replacement op is
- /// null, removes the association of the payload op with its handle.
+ /// null, removes the association of the payload op with its handle. Returns
+ /// failure if the op is not associated with any handle.
///
/// Note: This function does not update value handles. None of the original
/// op's results are allowed to be mapped to any value handle.
@@ -563,7 +564,7 @@ class TransformState {
/// Replaces the given payload value with another value. If the replacement
/// value is null, removes the association of the payload value with its
- /// handle.
+ /// handle. Returns failure if the value is not associated with any handle.
LogicalResult replacePayloadValue(Value value, Value replacement);
/// Records handle invalidation reporters into `newlyInvalidated`.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ce6ec5a3a0c5a..5bd95169233b1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -338,14 +338,7 @@ void transform::TransformState::forgetValueMapping(
LogicalResult
transform::TransformState::replacePayloadOp(Operation *op,
Operation *replacement) {
- // Drop the mapping between the op and all handles that point to it. Don't
- // care if there are on such handles.
- SmallVector<Value> opHandles;
- (void)getHandlesForPayloadOp(op, opHandles);
- for (Value handle : opHandles) {
- Mappings &mappings = getMapping(handle);
- dropMappingEntry(mappings.reverse, op, handle);
- }
+ // TODO: consider invalidating the handles to nested objects here.
#ifndef NDEBUG
for (Value opResult : op->getResults()) {
@@ -355,23 +348,29 @@ transform::TransformState::replacePayloadOp(Operation *op,
}
#endif // NDEBUG
+ // Drop the mapping between the op and all handles that point to it. Fail if
+ // there are no handles.
+ SmallVector<Value> opHandles;
+ if (failed(getHandlesForPayloadOp(op, opHandles)))
+ return failure();
+ for (Value handle : opHandles) {
+ Mappings &mappings = getMapping(handle);
+ dropMappingEntry(mappings.reverse, op, handle);
+ }
+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
if (options.getExpensiveChecksEnabled()) {
auto it = cachedNames.find(op);
assert(it != cachedNames.end() && "entry not found");
assert(it->second == op->getName() && "operation name mismatch");
cachedNames.erase(it);
- }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
- // TODO: consider invalidating the handles to nested objects here.
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- if (replacement && options.getExpensiveChecksEnabled()) {
- auto insertion = cachedNames.insert({replacement, replacement->getName()});
- if (!insertion.second) {
- assert(insertion.first->second == replacement->getName() &&
"operation is already cached with a different name");
+ if (replacement) {
+ auto insertion =
+ cachedNames.insert({replacement, replacement->getName()});
+ if (!insertion.second) {
+ assert(insertion.first->second == replacement->getName() &&
"operation is already cached with a different name");
+ }
}
}
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -411,7 +410,8 @@ transform::TransformState::replacePayloadOp(Operation *op,
LogicalResult
transform::TransformState::replacePayloadValue(Value value, Value replacement) {
SmallVector<Value> valueHandles;
- (void)getHandlesForPayloadValue(value, valueHandles);
+ if (failed(getHandlesForPayloadValue(value, valueHandles)))
+ return failure();
for (Value handle : valueHandles) {
Mappings &mappings = getMapping(handle);
@@ -537,30 +537,30 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
Location ancestorLoc = ancestor->getLoc();
Location opLoc = definingOp->getLoc();
Location valueLoc = payloadValue.getLoc();
- newlyInvalidated[valueHandle] =
- [valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo,
- ancestorLoc, opLoc, valueLoc](Location currentLoc) {
- InFlightDiagnostic diag = emitError(currentLoc)
- << "op uses a handle invalidated by a "
- "previously executed transform op";
- diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
- diag.attachNote(owner->getLoc())
- << "invalidated by this transform op that consumes its operand #"
- << operandNo
- << " and invalidates all handles to payload IR entities "
- "associated with this operand and entities nested in them";
- diag.attachNote(ancestorLoc)
- << "ancestor op associated with the consumed handle";
- if (resultNo) {
- diag.attachNote(opLoc)
- << "op defining the value as result #" << *resultNo;
- } else {
- diag.attachNote(opLoc)
- << "op defining the value as block argument #" << argumentNo
- << " of block #" << blockNo << " in region #" << regionNo;
- }
- diag.attachNote(valueLoc) << "payload value";
- };
+ newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
+ argumentNo, blockNo, regionNo, ancestorLoc,
+ opLoc, valueLoc](Location currentLoc) {
+ InFlightDiagnostic diag = emitError(currentLoc)
+ << "op uses a handle invalidated by a "
+ "previously executed transform op";
+ diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
+ diag.attachNote(owner->getLoc())
+ << "invalidated by this transform op that consumes its operand #"
+ << operandNo
+ << " and invalidates all handles to payload IR entities "
+ "associated with this operand and entities nested in them";
+ diag.attachNote(ancestorLoc)
+ << "ancestor op associated with the consumed handle";
+ if (resultNo) {
+ diag.attachNote(opLoc)
+ << "op defining the value as result #" << *resultNo;
+ } else {
+ diag.attachNote(opLoc)
+ << "op defining the value as block argument #" << argumentNo
+ << " of block #" << blockNo << " in region #" << regionNo;
+ }
+ diag.attachNote(valueLoc) << "payload value";
+ };
}
}
@@ -1064,10 +1064,6 @@ transform::TransformState::Extension::~Extension() = default;
LogicalResult
transform::TransformState::Extension::replacePayloadOp(Operation *op,
Operation *replacement) {
- SmallVector<Value> handles;
- if (failed(state.getHandlesForPayloadOp(op, handles)))
- return failure();
-
// TODO: we may need to invalidate handles to operations and values nested in
// the operation being replaced.
return state.replacePayloadOp(op, replacement);
@@ -1076,10 +1072,6 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
LogicalResult
transform::TransformState::Extension::replacePayloadValue(Value value,
Value replacement) {
- SmallVector<Value> handles;
- if (failed(state.getHandlesForPayloadValue(value, handles)))
- return failure();
-
return state.replacePayloadValue(value, replacement);
}
More information about the Mlir-commits mailing list