[Mlir-commits] [mlir] [mlir][transform] Check for invalidated iterators on payload values (PR #66472)
Matthias Springer
llvmlistbot at llvm.org
Fri Sep 15 01:57:02 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/66472
Same as #66369 but for payload values. (#66369 added checks only for payload operations.)
It was necessary to change the signature of `getPayloadValues` so that it returns an iterator range instead of an `ArrayRef<Value>`. This is now similar to `getPayloadOps` for payload operations.
Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted (`#ifndef` was used instead of `#ifdef`).
From 7c2a3d741089099825bb287dfebb787d08427597 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 15 Sep 2023 10:54:26 +0200
Subject: [PATCH] [mlir][transform] Check for invalidated iterators on payload
values
Same as #66369 but for payload values. (#66369 added checks only for payload operations.)
It was necessary to change the signature of `getPayloadValues` so that it returns an iterator range instead of an `ArrayRef<Value>`. This is now similar to `getPayloadOps` for payload operations.
Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted (`#ifndef` was used instead of `#ifdef`).
---
.../Dialect/Transform/IR/MatchInterfaces.h | 6 +-
.../Transform/IR/TransformInterfaces.h | 70 +++++++++++++++----
.../Linalg/TransformOps/LinalgMatchOps.cpp | 2 +-
.../Transform/IR/TransformInterfaces.cpp | 42 +++++------
.../lib/Dialect/Transform/IR/TransformOps.cpp | 4 +-
.../TestTransformDialectExtension.cpp | 7 +-
6 files changed, 87 insertions(+), 44 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index 1a6afc58fef2704..c8888f294f6ca1d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -95,15 +95,15 @@ class SingleValueMatcherOpTrait
TransformResults &results,
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
- ValueRange payload = state.getPayloadValues(operandHandle);
- if (payload.size() != 1) {
+ auto payload = state.getPayloadValues(operandHandle);
+ if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleValueMatchOpTrait requires the value handle to point to "
"a single payload value";
}
return cast<OpTy>(this->getOperation())
- .matchValue(payload[0], results, state);
+ .matchValue(*payload.begin(), results, state);
}
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 86af59142b77d9c..31a93b05cf7a153 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -170,7 +170,7 @@ class TransformState {
/// should be emitted when the value is used.
using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Debug only: A timestamp is associated with each transform IR value, so
/// that invalid iterator usage can be detected more reliably.
using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
@@ -185,7 +185,7 @@ class TransformState {
ValueMapping values;
ValueMapping reverseValues;
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
TransformIRTimestampMapping timestamps;
void incrementTimestamp(Value value) { ++timestamps[value]; }
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -220,7 +220,7 @@ class TransformState {
auto getPayloadOps(Value value) const {
ArrayRef<Operation *> view = getPayloadOpsView(value);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Memorize the current timestamp and make sure that it has not changed
// when incrementing or dereferencing the iterator returned by this
// function. The timestamp is incremented when the "direct" mapping is
@@ -231,7 +231,7 @@ class TransformState {
// When ops are replaced/erased, they are replaced with nullptr (until
// the data structure is compacted). Do not enumerate these ops.
return llvm::make_filter_range(view, [=](Operation *op) {
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
bool sameTimestamp =
currentTimestamp == this->getMapping(value).timestamps.lookup(value);
assert(sameTimestamp && "iterator was invalidated during iteration");
@@ -244,9 +244,29 @@ class TransformState {
/// corresponds to.
ArrayRef<Attribute> getParams(Value value) const;
- /// Returns the list of payload IR values that the given transform IR value
- /// corresponds to.
- ArrayRef<Value> getPayloadValues(Value handleValue) const;
+ /// Returns an iterator that enumerates all payload IR values that the given
+ /// transform IR value corresponds to.
+ auto getPayloadValues(Value handleValue) const {
+ ArrayRef<Value> view = getPayloadValuesView(handleValue);
+
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+ // Memorize the current timestamp and make sure that it has not changed
+ // when incrementing or dereferencing the iterator returned by this
+ // function. The timestamp is incremented when the "values" mapping is
+ // resized; this would invalidate the iterator returned by this function.
+ int64_t currentTimestamp =
+ getMapping(handleValue).timestamps.lookup(handleValue);
+ return llvm::make_filter_range(view, [=](Value v) {
+ bool sameTimestamp =
+ currentTimestamp ==
+ this->getMapping(handleValue).timestamps.lookup(handleValue);
+ assert(sameTimestamp && "iterator was invalidated during iteration");
+ return true;
+ });
+#else
+ return llvm::make_range(view.begin(), view.end());
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ }
/// Populates `handles` with all handles pointing to the given Payload IR op.
/// Returns success if such handles exist, failure otherwise.
@@ -501,12 +521,15 @@ class TransformState {
LogicalResult updateStateFromResults(const TransformResults &results,
ResultRange opResults);
- /// Returns a list of all ops that the given transform IR value corresponds to
- /// at the time when this function is called. In case an op was erased, the
- /// returned list contains nullptr. This function is helpful for
- /// transformations that apply to a particular handle.
+ /// Returns a list of all ops that the given transform IR value corresponds
+ /// to. In case an op was erased, the returned list contains nullptr. This
+ /// function is helpful for transformations that apply to a particular handle.
ArrayRef<Operation *> getPayloadOpsView(Value value) const;
+ /// Returns a list of payload IR values that the given transform IR value
+ /// corresponds to.
+ ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
+
/// Sets the payload IR ops associated with the given transform IR value
/// (handle). A payload op may be associated multiple handles as long as
/// at most one of them gets consumed by further transformations.
@@ -774,7 +797,8 @@ class TransformResults {
/// corresponds to the given list of payload IR ops. Each result must be set
/// by the transformation exactly once in case of transformation succeeding.
/// The value must have a type implementing TransformHandleTypeInterface.
- template <typename Range> void set(OpResult value, Range &&ops) {
+ template <typename Range>
+ void set(OpResult value, Range &&ops) {
int64_t position = value.getResultNumber();
assert(position < static_cast<int64_t>(operations.size()) &&
"setting results for a non-existent handle");
@@ -805,7 +829,27 @@ class TransformResults {
/// set by the transformation exactly once in case of transformation
/// succeeding. The value must have a type implementing
/// TransformValueHandleTypeInterface.
- void setValues(OpResult handle, ValueRange values);
+ template <typename Range>
+ void setValues(OpResult handle, Range &&values) {
+ int64_t position = handle.getResultNumber();
+ assert(position < static_cast<int64_t>(this->values.size()) &&
+ "setting values for a non-existent handle");
+ assert(this->values[position].data() == nullptr && "values already set");
+ assert(operations[position].data() == nullptr &&
+ "another kind of results already set");
+ assert(params[position].data() == nullptr &&
+ "another kind of results already set");
+ this->values.replace(position, std::forward<Range>(values));
+ }
+
+ /// Indicates that the result of the transform IR op at the given position
+ /// corresponds to the given range of payload IR values. Each result must be
+ /// set by the transformation exactly once in case of transformation
+ /// succeeding. The value must have a type implementing
+ /// TransformValueHandleTypeInterface.
+ void setValues(OpResult handle, std::initializer_list<Value> values) {
+ setValues(handle, ArrayRef<Value>(values));
+ }
/// Indicates that the result of the transform IR op at the given position
/// corresponds to the given range of mapped values. All mapped values are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 7b8bf6fc5d8f5a4..fb021ed76242e90 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -728,7 +728,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
- results.setValues(cast<OpResult>(getResult()), result);
+ results.setValues(cast<OpResult>(getResult()), {result});
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 9cac178d3c2b869..fd2cf8816ae2162 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -75,7 +75,7 @@ ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
}
ArrayRef<Value>
-transform::TransformState::getPayloadValues(Value handleValue) const {
+transform::TransformState::getPayloadValuesView(Value handleValue) const {
const ValueMapping &mapping = getMapping(handleValue).values;
auto iter = mapping.find(handleValue);
assert(iter != mapping.end() && "cannot find mapping for value handle "
@@ -310,7 +310,7 @@ void transform::TransformState::forgetMapping(Value opHandle,
for (Operation *op : mappings.direct[opHandle])
dropMappingEntry(mappings.reverse, op, opHandle);
mappings.direct.erase(opHandle);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(opHandle);
@@ -322,6 +322,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
for (Value resultHandle : resultHandles) {
Mappings &localMappings = getMapping(resultHandle);
dropMappingEntry(localMappings.values, resultHandle, opResult);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+ // Payload IR is removed from the mapping. This invalidates the respective
+ // iterators.
+ mappings.incrementTimestamp(resultHandle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
}
}
@@ -333,6 +338,11 @@ void transform::TransformState::forgetValueMapping(
for (Value payloadValue : mappings.reverseValues[valueHandle])
dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
mappings.values.erase(valueHandle);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+ // Payload IR is removed from the mapping. This invalidates the respective
+ // iterators.
+ mappings.incrementTimestamp(valueHandle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
for (Operation *payloadOp : payloadOperations) {
SmallVector<Value> opHandles;
@@ -342,7 +352,7 @@ void transform::TransformState::forgetValueMapping(
dropMappingEntry(localMappings.direct, opHandle, payloadOp);
dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
localMappings.incrementTimestamp(opHandle);
@@ -439,6 +449,11 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) {
// between the handles and the IR objects
if (!replacement) {
dropMappingEntry(mappings.values, handle, value);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+ // Payload IR is removed from the mapping. This invalidates the respective
+ // iterators.
+ mappings.incrementTimestamp(handle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
} else {
auto it = mappings.values.find(handle);
if (it == mappings.values.end())
@@ -647,7 +662,7 @@ void transform::TransformState::recordValueHandleInvalidation(
OpOperand &valueHandle,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
// Invalidate other handles to the same value.
- for (Value payloadValue : getPayloadValues(valueHandle.get())) {
+ for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
SmallVector<Value> otherValueHandles;
(void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
for (Value otherHandle : otherValueHandles) {
@@ -785,7 +800,7 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
void transform::TransformState::compactOpHandles() {
for (Value handle : opHandlesToCompact) {
Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
if (llvm::find(mappings.direct[handle], nullptr) !=
mappings.direct[handle].end())
// Payload IR is removed from the mapping. This invalidates the respective
@@ -846,7 +861,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Value>(
- getPayloadValues(operand.get()), transform,
+ getPayloadValuesView(operand.get()), transform,
operand.getOperandNumber());
if (!check.succeeded()) {
FULL_LDBG("----FAILED\n");
@@ -912,7 +927,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
continue;
}
if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
- for (Value payloadValue : getPayloadValues(operand)) {
+ for (Value payloadValue : getPayloadValuesView(operand)) {
if (llvm::isa<OpResult>(payloadValue)) {
origAssociatedOps.push_back(payloadValue.getDefiningOp());
continue;
@@ -1170,19 +1185,6 @@ void transform::TransformResults::setParams(
this->params.replace(position, params);
}
-void transform::TransformResults::setValues(OpResult handle,
- ValueRange values) {
- int64_t position = handle.getResultNumber();
- assert(position < static_cast<int64_t>(this->values.size()) &&
- "setting values for a non-existent handle");
- assert(this->values[position].data() == nullptr && "values already set");
- assert(operations[position].data() == nullptr &&
- "another kind of results already set");
- assert(params[position].data() == nullptr &&
- "another kind of results already set");
- this->values.replace(position, values);
-}
-
void transform::TransformResults::setMappedValues(
OpResult handle, ArrayRef<MappedValue> values) {
DiagnosedSilenceableFailure diag = dispatchMappedValues(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index de3cd1b28e435bc..cd4f628f1459ab7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1378,9 +1378,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Attribute> params;
- ArrayRef<Value> values = state.getPayloadValues(getValue());
- params.reserve(values.size());
- for (Value value : values) {
+ for (Value value : state.getPayloadValues(getValue())) {
Type type = value.getType();
if (getElemental()) {
if (auto shaped = dyn_cast<ShapedType>(type)) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 59f045de3246f6b..e8c25aca237251a 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -136,7 +136,7 @@ DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToSelfOperand::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
- results.setValues(llvm::cast<OpResult>(getOut()), getIn());
+ results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
return DiagnosedSilenceableFailure::success();
}
@@ -265,8 +265,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
- ArrayRef<Value> values = state.getPayloadValues(getIn());
- for (Value value : values) {
+ for (Value value : state.getPayloadValues(getIn())) {
std::string note;
llvm::raw_string_ostream os(note);
if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
@@ -712,7 +711,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
- results.setValues(llvm::cast<OpResult>(getOut()), Value());
+ results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
return DiagnosedSilenceableFailure::success();
}
More information about the Mlir-commits
mailing list