[Mlir-commits] [mlir] 2225928 - [mlir][transform][NFC] Store all Mappings in region stack
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 26 08:36:08 PDT 2023
Author: Matthias Springer
Date: 2023-06-26T17:35:26+02:00
New Revision: 22259281c040d70bf1cd8f933a48b26950edd0dd
URL: https://github.com/llvm/llvm-project/commit/22259281c040d70bf1cd8f933a48b26950edd0dd
DIFF: https://github.com/llvm/llvm-project/commit/22259281c040d70bf1cd8f933a48b26950edd0dd.diff
LOG: [mlir][transform][NFC] Store all Mappings in region stack
Do not swap the Mappings when entering a region that is isolated from above. Simply push another Mappings struct to the stack and prevent invalid accesses during lookups.
Differential Revision: https://reviews.llvm.org/D153765
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index d54f9c404fba4..21ddd7143966f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -262,12 +262,6 @@ class TransformState {
// class body to comply with visibility and full-declaration requirements.
inline RegionScope make_region_scope(Region &region);
- /// Creates a new region scope for the given isolated-from-above region.
- /// Unlike the non-isolated counterpart, there is no nesting expectation.
- // Implementation note: this method is inline but implemented outside of the
- // class body to comply with visibility and full-declaration requirements
- inline RegionScope make_isolated_region_scope(Region &region);
-
/// A RAII object maintaining a "stack frame" for a transform IR region. When
/// applying a transform IR operation that contains a region, the caller is
/// expected to create a RegionScope before applying the ops contained in the
@@ -282,51 +276,25 @@ class TransformState {
~RegionScope();
private:
- /// Tag structure for differentiating the constructor for isolated regions.
- struct Isolated {};
-
/// Creates a new scope for mappings between values defined in the given
/// transform IR region and payload IR objects.
RegionScope(TransformState &state, Region &region)
: state(state), region(&region) {
- auto res = state.mappings.try_emplace(this->region);
+ auto res = state.mappings.insert(std::make_pair(&region, Mappings()));
assert(res.second && "the region scope is already present");
(void)res;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- assert(((state.regionStack.size() == 1 && !state.regionStack.back()) ||
- state.regionStack.back()->isProperAncestor(®ion)) &&
- "scope started at a non-nested region");
state.regionStack.push_back(®ion);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
- /// Creates a new scope for mappings between values defined in the given
- /// isolated-from-above transform IR region and payload IR objects.
- RegionScope(TransformState &state, Region &region, Isolated)
- : state(state), region(®ion) {
- // Store the previous mapping stack locally.
- storedMappings = llvm::SmallDenseMap<Region *, Mappings>();
- storedMappings->swap(state.mappings);
- state.mappings.try_emplace(this->region);
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- state.regionStack.push_back(this->region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
- }
-
/// Back-reference to the transform state.
TransformState &state;
/// The region this scope is associated with.
Region *region;
- /// Local copy of the mappings that existed before entering the current
- /// region. Used only when the current region is isolated so we don't
- /// accidentally look up the values defined outside the isolated region.
- std::optional<llvm::SmallDenseMap<Region *, Mappings>> storedMappings =
- std::nullopt;
-
friend RegionScope TransformState::make_region_scope(Region &);
- friend RegionScope TransformState::make_isolated_region_scope(Region &);
};
friend class RegionScope;
@@ -446,9 +414,19 @@ class TransformState {
return const_cast<TransformState *>(this)->getMapping(value);
}
Mappings &getMapping(Value value) {
- auto it = mappings.find(value.getParentRegion());
+ Region *region = value.getParentRegion();
+ auto it = mappings.find(region);
assert(it != mappings.end() &&
"trying to find a mapping for a value from an unmapped region");
+#ifndef NDEBUG
+ for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+ if (r == region)
+ break;
+ if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ llvm_unreachable(
+ "trying to get mapping beyond region that is isolated from above");
+ }
+#endif // NDEBUG
return it->second;
}
@@ -457,9 +435,19 @@ class TransformState {
return const_cast<TransformState *>(this)->getMapping(operation);
}
Mappings &getMapping(Operation *operation) {
- auto it = mappings.find(operation->getParentRegion());
+ Region *region = operation->getParentRegion();
+ auto it = mappings.find(region);
assert(it != mappings.end() &&
"trying to find a mapping for an operation from an unmapped region");
+#ifndef NDEBUG
+ for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+ if (r == region)
+ break;
+ if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ llvm_unreachable(
+ "trying to get mapping beyond region that is isolated from above");
+ }
+#endif // NDEBUG
return it->second;
}
@@ -676,9 +664,9 @@ class TransformState {
/// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
void compactOpHandles();
- /// The mappings between transform IR values and payload IR ops, aggregated by
- /// the region in which the transform IR values are defined.
- llvm::SmallDenseMap<Region *, Mappings> mappings;
+ /// A stack of mappings between transform IR values and payload IR ops,
+ /// aggregated by the region in which the transform IR values are defined.
+ llvm::MapVector<Region *, Mappings> mappings;
/// Op handles may be temporarily mapped to nullptr to avoid invalidating
/// payload op iterators. This set contains all op handles with nullptrs.
@@ -834,14 +822,6 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
return RegionScope(*this, region);
}
-/// Creates a RAII object the lifetime of which corresponds to the new mapping
-/// for transform IR values defined in the given isolated-from-above region.
-/// Values defined in surrounding regions cannot be accessed.
-TransformState::RegionScope
-TransformState::make_isolated_region_scope(Region &region) {
- return RegionScope(*this, region, RegionScope::Isolated());
-}
-
/// A listener that updates a TransformState based on IR modifications. This
/// listener can be used during a greedy pattern rewrite to keep the transform
/// state up-to-date.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 27794e6636959..ce6ec5a3a0c5a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -45,7 +45,7 @@ transform::TransformState::TransformState(
for (ArrayRef<MappedValue> mapping : extraMappings)
topLevelMappedValues.push_back(mapping);
- auto result = mappings.try_emplace(region);
+ auto result = mappings.insert(std::make_pair(region, Mappings()));
assert(result.second && "the region scope is already present");
(void)result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -85,12 +85,15 @@ transform::TransformState::getPayloadValues(Value handleValue) const {
LogicalResult transform::TransformState::getHandlesForPayloadOp(
Operation *op, SmallVectorImpl<Value> &handles) const {
bool found = false;
- for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+ for (const auto &[region, mapping] : llvm::reverse(mappings)) {
auto iterator = mapping.reverse.find(op);
if (iterator != mapping.reverse.end()) {
llvm::append_range(handles, iterator->getSecond());
found = true;
}
+ // Stop looking when reaching a region that is isolated from above.
+ if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ break;
}
return success(found);
@@ -99,12 +102,15 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
LogicalResult transform::TransformState::getHandlesForPayloadValue(
Value payloadValue, SmallVectorImpl<Value> &handles) const {
bool found = false;
- for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+ for (const auto &[region, mapping] : llvm::reverse(mappings)) {
auto iterator = mapping.reverseValues.find(payloadValue);
if (iterator != mapping.reverseValues.end()) {
llvm::append_range(handles, iterator->getSecond());
found = true;
}
+ // Stop looking when reaching a region that is isolated from above.
+ if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ break;
}
return success(found);
@@ -590,8 +596,10 @@ void transform::TransformState::recordOpHandleInvalidation(
// number of IR objects (operations and values). Alternatively, we could walk
// the IR nested in each payload op associated with the given handle and look
// for handles associated with each operation and value.
- for (const transform::TransformState::Mappings &mapping :
- llvm::make_second_range(mappings)) {
+ for (const auto &[region, mapping] : llvm::reverse(mappings)) {
+ // Stop lookup when reaching a region that is isolated from above.
+ if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ break;
// Go over all op handle mappings and mark as invalidated any handle
// pointing to any of the payload ops associated with the given handle or
// any op nested in them.
@@ -1102,8 +1110,6 @@ transform::TransformState::RegionScope::~RegionScope() {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
state.mappings.erase(region);
- if (storedMappings.has_value())
- state.mappings.swap(*storedMappings);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
// If the last handle to a payload op has gone out of scope, we no longer
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 874248993af03..1d89ff45f5dcc 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -446,7 +446,7 @@ static DiagnosedSilenceableFailure
matchBlock(Block &block, Operation *op, transform::TransformState &state,
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
assert(block.getParent() && "cannot match using a detached block");
- auto matchScope = state.make_isolated_region_scope(*block.getParent());
+ auto matchScope = state.make_region_scope(*block.getParent());
if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
return DiagnosedSilenceableFailure::definiteFailure();
@@ -524,7 +524,7 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
continue;
}
- auto scope = state.make_isolated_region_scope(action.getFunctionBody());
+ auto scope = state.make_region_scope(action.getFunctionBody());
for (auto &&[arg, map] : llvm::zip_equal(
action.getFunctionBody().front().getArguments(), mappings)) {
if (failed(state.mapBlockArgument(arg, map)))
@@ -1029,7 +1029,7 @@ transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
// Map operands to block arguments.
SmallVector<SmallVector<MappedValue>> mappings;
detail::prepareValueMappings(mappings, getOperands(), state);
- auto scope = state.make_isolated_region_scope(callee.getBody());
+ auto scope = state.make_region_scope(callee.getBody());
for (auto &&[arg, map] :
llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
if (failed(state.mapBlockArgument(arg, map)))
More information about the Mlir-commits
mailing list