[Mlir-commits] [mlir] 0e37ef0 - [mlir][transform] Use TrackingListener-aware iterator for getPayloadOps
Matthias Springer
llvmlistbot at llvm.org
Mon May 15 01:38:01 PDT 2023
Author: Matthias Springer
Date: 2023-05-15T10:31:24+02:00
New Revision: 0e37ef08d4d8cf98941937874702825578fcb9c2
URL: https://github.com/llvm/llvm-project/commit/0e37ef08d4d8cf98941937874702825578fcb9c2
DIFF: https://github.com/llvm/llvm-project/commit/0e37ef08d4d8cf98941937874702825578fcb9c2.diff
LOG: [mlir][transform] Use TrackingListener-aware iterator for getPayloadOps
Instead of returning an `ArrayRef<Operation *>`, return an iterator range that skips ops that were erased/replaced while iterating over the payload ops.
This fixes an issue in conjunction with TrackingListener, where a tracked op was erased during the iteration. Elements may not be removed from an array while iterating over it; doing so invalidates the iterator.
When ops are erased/removed via `replacePayloadOp`, they are not immediately removed from the mappings data structure. Instead, their entries are set to `nullptr`. These `nullptr` entries are not enumerated by `getPayloadOps`. At the end of each transformation, the `nullptr` entries are removed from the mappings data structure.
Differential Revision: https://reviews.llvm.org/D149847
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index 8f8adffae50e3..c0d7545622e69 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -53,15 +53,15 @@ class SingleOpMatcherOpTrait
DiagnosedSilenceableFailure apply(TransformResults &results,
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
- ArrayRef<Operation *> payload = state.getPayloadOps(operandHandle);
- if (payload.size() != 1) {
+ auto payload = state.getPayloadOps(operandHandle);
+ if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleOpMatchOpTrait requires the operand handle to point to "
"a single payload op";
}
return cast<OpTy>(this->getOperation())
- .matchOperation(payload[0], results, state);
+ .matchOperation(*payload.begin(), results, state);
}
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 39a86d3828786..6730552c9c53a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -22,6 +22,36 @@ namespace transform {
class TransformOpInterface;
class TransformResults;
+class TransformState;
+
+using Param = Attribute;
+using MappedValue = llvm::PointerUnion<Operation *, Param, Value>;
+
+namespace detail {
+/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
+/// to either the list of operations associated with its operand or the root of
+/// the payload IR, depending on what is available in the context.
+LogicalResult
+mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
+ Operation *op, Region &region);
+
+/// Verification hook for PossibleTopLevelTransformOpTrait.
+LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
+
+/// Verification hook for TransformOpInterface.
+LogicalResult verifyTransformOpInterface(Operation *op);
+
+/// Populates `mappings` with mapped values associated with the given transform
+/// IR values in the given `state`.
+void prepareValueMappings(
+ SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
+ ValueRange values, const transform::TransformState &state);
+
+/// Populates `results` with payload associations that match exactly those of
+/// the operands to `block`'s terminator.
+void forwardTerminatorOperands(Block *block, transform::TransformState &state,
+ transform::TransformResults &results);
+} // namespace detail
/// Options controlling the application of transform operations by the
/// TransformState.
@@ -46,9 +76,6 @@ class TransformOptions {
bool expensiveChecksEnabled = true;
};
-using Param = Attribute;
-using MappedValue = llvm::PointerUnion<Operation *, Param, Value>;
-
/// Entry point to the Transform dialect infrastructure. Applies the
/// transformation specified by `transform` to payload IR contained in
/// `payloadRoot`. The `transform` operation may contain other operations that
@@ -140,9 +167,16 @@ class TransformState {
return topLevelMappedValues[position];
}
- /// Returns the list of ops that the given transform IR value corresponds to.
- /// This is helpful for transformations that apply to a particular handle.
- ArrayRef<Operation *> getPayloadOps(Value value) const;
+ /// Returns an iterator that enumerates all ops that the given transform IR
+ /// value corresponds to. Ops may be erased while iterating; erased ops are
+ /// not enumerated. This function is helpful for transformations that apply to
+ /// a particular handle.
+ auto getPayloadOps(Value value) const {
+ // When ops are replaced/erased, they are replaced with nullptr (until
+ // the data structure is compacted). Do not enumerate these ops.
+ return llvm::make_filter_range(getPayloadOpsView(value),
+ [](Operation *op) { return op != nullptr; });
+ }
/// Returns the list of parameters that the given transform IR value
/// corresponds to.
@@ -407,6 +441,12 @@ class TransformState {
LogicalResult updateStateFromResults(const TransformResults &results,
ResultRange opResults);
+ /// Returns a list of all ops that the given transform IR value corresponds to
+ /// at the time when this function is called. In case an op was erased, the
+ /// returned list contains nullptr. This function is helpful for
+ /// transformations that apply to a particular handle.
+ ArrayRef<Operation *> getPayloadOpsView(Value value) const;
+
/// Sets the payload IR ops associated with the given transform IR value
/// (handle). A payload op may be associated multiple handles as long as
/// at most one of them gets consumed by further transformations.
@@ -540,10 +580,19 @@ class TransformState {
LogicalResult
checkAndRecordHandleInvalidation(TransformOpInterface transform);
+ /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
+ void compactOpHandles();
+
/// The mappings between transform IR values and payload IR ops, aggregated by
/// the region in which the transform IR values are defined.
llvm::SmallDenseMap<Region *, Mappings> mappings;
+ /// Op handles may be temporarily mapped to nullptr to avoid invalidating
+ /// payload op iterators. This set contains all op handles with nullptrs.
+ /// These handles are "compacted" (i.e., nullptrs removed) at the end of each
+ /// transform.
+ DenseSet<Value> opHandlesToCompact;
+
/// Extensions attached to the TransformState, identified by the TypeID of
/// their type. Only one extension of any given type is allowed.
DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
@@ -595,7 +644,25 @@ class TransformResults {
/// corresponds to the given list of payload IR ops. Each result must be set
/// by the transformation exactly once in case of transformation succeeding.
/// The value must have a type implementing TransformHandleTypeInterface.
- void set(OpResult value, ArrayRef<Operation *> ops);
+ template <typename Range> void set(OpResult value, Range &&ops) {
+ int64_t position = value.getResultNumber();
+ assert(position < static_cast<int64_t>(operations.size()) &&
+ "setting results for a non-existent handle");
+ assert(operations[position].data() == nullptr && "results already set");
+ assert(params[position].data() == nullptr &&
+ "another kind of results already set");
+ assert(values[position].data() == nullptr &&
+ "another kind of results already set");
+ operations.replace(position, std::forward<Range>(ops));
+ }
+
+ /// Indicates that the result of the transform IR op at the given position
+ /// corresponds to the given list of payload IR ops. Each result must be set
+ /// by the transformation exactly once in case of transformation succeeding.
+ /// The value must have a type implementing TransformHandleTypeInterface.
+ void set(OpResult value, std::initializer_list<Operation *> ops) {
+ set(value, ArrayRef<Operation *>(ops));
+ }
/// Indicates that the result of the transform IR op at the given position
/// corresponds to the given list of parameters. Each result must be set by
@@ -682,32 +749,6 @@ TransformState::make_isolated_region_scope(Region &region) {
return RegionScope(*this, region, RegionScope::Isolated());
}
-namespace detail {
-/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
-/// to either the list of operations associated with its operand or the root of
-/// the payload IR, depending on what is available in the context.
-LogicalResult
-mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
- Operation *op, Region &region);
-
-/// Verification hook for PossibleTopLevelTransformOpTrait.
-LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
-
-/// Verification hook for TransformOpInterface.
-LogicalResult verifyTransformOpInterface(Operation *op);
-
-/// Populates `mappings` with mapped values associated with the given transform
-/// IR values in the given `state`.
-void prepareValueMappings(
- SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
- ValueRange values, const transform::TransformState &state);
-
-/// Populates `results` with payload associations that match exactly those of
-/// the operands to `block`'s terminator.
-void forwardTerminatorOperands(Block *block, transform::TransformState &state,
- transform::TransformResults &results);
-} // namespace detail
-
/// This trait is supposed to be attached to Transform dialect operations that
/// can be standalone top-level transforms. Such operations typically contain
/// other Transform dialect operations that can be executed following some
@@ -1069,9 +1110,9 @@ void setApplyToOneResults(Operation *transformOp,
/// - a concrete Op class, in which case a check is performed whether
/// `targets` contains operations of the same class and a silenceable failure
/// is reported if it does not.
-template <typename TransformOpTy>
+template <typename TransformOpTy, typename Range>
DiagnosedSilenceableFailure
-applyTransformToEach(TransformOpTy transformOp, ArrayRef<Operation *> targets,
+applyTransformToEach(TransformOpTy transformOp, Range &&targets,
SmallVectorImpl<ApplyToEachResultList> &results,
TransformState &state) {
using OpTy = typename llvm::function_traits<
@@ -1133,14 +1174,13 @@ template <typename OpTy>
mlir::DiagnosedSilenceableFailure
mlir::transform::TransformEachOpTrait<OpTy>::apply(
TransformResults &transformResults, TransformState &state) {
- ArrayRef<Operation *> targets =
- state.getPayloadOps(this->getOperation()->getOperand(0));
+ auto targets = state.getPayloadOps(this->getOperation()->getOperand(0));
// Step 1. Handle the corner case where no target is specified.
// This is typically the case when the matcher fails to apply and we need to
// propagate gracefully.
// In this case, we fill all results with an empty vector.
- if (targets.empty()) {
+ if (std::empty(targets)) {
SmallVector<Operation *> emptyPayload;
SmallVector<Attribute> emptyParams;
for (OpResult r : this->getOperation()->getResults()) {
@@ -1157,7 +1197,6 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
// Step 2. Call applyToOne on each target and record newly produced ops in its
// corresponding results entry.
SmallVector<ApplyToEachResultList, 1> results;
- results.reserve(targets.size());
DiagnosedSilenceableFailure result = detail::applyTransformToEach(
cast<OpTy>(this->getOperation()), targets, results, state);
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 728da70b85731..9a952e6c9dc58 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -75,8 +75,7 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
getUpperBounds())) {
Value handle = std::get<0>(it);
- ArrayRef<Operation *> boundedValueOps = state.getPayloadOps(handle);
- for (Operation *op : boundedValueOps) {
+ for (Operation *op : state.getPayloadOps(handle)) {
if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
auto diag =
emitDefiniteFailure()
@@ -104,8 +103,8 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
}
// Transform all targets.
- ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
- for (Operation *target : targets) {
+ SmallVector<Operation *> targets;
+ for (Operation *target : state.getPayloadOps(getTarget())) {
if (!isa<AffineMinOp, AffineMaxOp>(target)) {
auto diag = emitDefiniteFailure()
<< "target must be affine.min or affine.max";
@@ -118,6 +117,7 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
diag.attachNote(target->getLoc()) << "target/constrained op";
return diag;
}
+ targets.push_back(target);
}
SmallVector<Operation *> transformed;
RewritePatternSet patterns(getContext());
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 5e36b55cff840..6dd32b815afce 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -41,7 +41,7 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
options.setFunctionBoundaryTypeConversion(
*getFunctionBoundaryTypeConversion());
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+ auto payloadOps = state.getPayloadOps(getTarget());
for (Operation *target : payloadOps) {
if (!isa<ModuleOp, FunctionOpInterface>(target))
return emitSilenceableError() << "expected module or function target";
@@ -80,8 +80,7 @@ transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults,
OneShotBufferizationOptions options;
options.allowReturnAllocs = true;
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
- for (Operation *target : payloadOps) {
+ for (Operation *target : state.getPayloadOps(getTarget())) {
OneShotAnalysisState state(target, options);
if (failed(analyzeOp(target, state)))
return mlir::emitSilenceableFailure(target->getLoc())
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 02909bb69977f..0d9d533da3664 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -618,7 +618,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
}
Operation *firstUser = *result.getUsers().begin();
if (getAny()) {
- results.set(cast<OpResult>(getResult()), firstUser);
+ results.set(cast<OpResult>(getResult()), {firstUser});
return DiagnosedSilenceableFailure::success();
}
if (getSingle()) {
@@ -626,7 +626,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
return emitSilenceableError()
<< "more than one result user with single user requested";
}
- results.set(cast<OpResult>(getResult()), firstUser);
+ results.set(cast<OpResult>(getResult()), {firstUser});
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ea8d285cf52b7..afef59990afc1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -92,17 +92,17 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
result.push_back(ofr);
continue;
}
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
- if (payloadOps.size() != 1) {
+ auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+ if (!llvm::hasSingleElement(payloadOps)) {
DiagnosedSilenceableFailure diag =
transformOp.emitSilenceableError()
<< "handle must be mapped to exactly one payload op";
diag.attachNote(ofr.get<Value>().getLoc())
- << "mapped to " << payloadOps.size() << " payload ops";
+ << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
return diag;
}
- Operation *op = payloadOps[0];
+ Operation *op = *payloadOps.begin();
if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
DiagnosedSilenceableFailure diag =
transformOp.emitSilenceableError()
@@ -125,8 +125,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
transform::TransformState &state, TransformOpInterface transformOp,
SmallVector<OpFoldResult> &result, Value packedHandle) {
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
- for (Operation *op : payloadOps) {
+ for (Operation *op : state.getPayloadOps(packedHandle)) {
if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
DiagnosedSilenceableFailure diag =
transformOp.emitSilenceableError()
@@ -208,16 +207,14 @@ transform::DecomposeOp::applyToOne(LinalgOp target,
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
+template <typename Range>
static LogicalResult applyTilingToAll(
- RewriterBase &rewriter, Operation *transformOp,
- ArrayRef<Operation *> payloadOps, unsigned numLoops,
- transform::TransformResults &transformResults,
+ RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
+ unsigned numLoops, transform::TransformResults &transformResults,
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
applyFn) {
SmallVector<Operation *> tiledLinalgOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
- for (unsigned int i = 0; i < numLoops; ++i)
- loopOps[i].reserve(payloadOps.size());
for (Operation *target : payloadOps) {
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
@@ -578,19 +575,19 @@ DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Operation *> fusedOps;
- ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
+ auto producerOps = state.getPayloadOps(getProducerOp());
// If nothing to fuse, propagate success.
- if (producerOps.empty()) {
+ if (std::empty(producerOps)) {
results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
return DiagnosedSilenceableFailure::success();
}
- ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
- if (containingOps.size() != 1) {
+ auto containingOps = state.getPayloadOps(getContainingOp());
+ if (!llvm::hasSingleElement(containingOps)) {
return emitDefiniteFailure()
<< "requires exactly one containing_op handle (got "
- << containingOps.size() << ")";
+ << llvm::range_size(containingOps) << ")";
}
- Operation *containingOp = containingOps.front();
+ Operation *containingOp = *containingOps.begin();
// Helper function to find the next producer that should be fused. Take any
// producer that has a use inside the containing op.
@@ -810,8 +807,8 @@ transform::MatchOp::apply(transform::TransformResults &results,
strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
getOps()->getAsValueRange<StringAttr>().end());
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
- if (payloadOps.size() != 1) {
+ auto payloadOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(payloadOps)) {
return emitDefiniteFailure("requires exactly one target handle");
}
@@ -857,7 +854,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
return;
};
- payloadOps.front()->walk(matchFun);
+ (*payloadOps.begin())->walk(matchFun);
results.set(cast<OpResult>(getResult()), res);
return DiagnosedSilenceableFailure::success();
}
@@ -996,18 +993,19 @@ SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
DiagnosedSilenceableFailure
transform::PackOp::apply(transform::TransformResults &transformResults,
transform::TransformState &state) {
- ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+ auto targetOps = state.getPayloadOps(getTarget());
// If nothing to pack, propagate success.
- if (targetOps.empty()) {
- transformResults.set(cast<OpResult>(getPackedOp()), {});
+ if (std::empty(targetOps)) {
+ transformResults.set(cast<OpResult>(getPackedOp()),
+ ArrayRef<Operation *>({}));
return DiagnosedSilenceableFailure::success();
}
// Fail on multi-op handles.
- auto linalgOp = dyn_cast<LinalgOp>(targetOps.front());
- if (targetOps.size() != 1 || !linalgOp) {
+ auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
+ if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
return emitSilenceableError()
<< "requires target to map to exactly 1 LinalgOp (got "
- << targetOps.size() << ")";
+ << llvm::range_size(targetOps) << ")";
}
// Fail on mismatched number of pack sizes.
if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
@@ -1030,7 +1028,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
return emitDefiniteFailure("data tiling failed");
transformResults.set(cast<OpResult>(getPackedOp()),
- maybeResult->packedLinalgOp.getOperation());
+ {maybeResult->packedLinalgOp.getOperation()});
return DiagnosedSilenceableFailure::success();
}
@@ -1204,16 +1202,10 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
DiagnosedSilenceableFailure
PackGreedilyOp::apply(transform::TransformResults &transformResults,
transform::TransformState &state) {
- ArrayRef<Operation *> targetOpsView = state.getPayloadOps(getTarget());
- // Store payload ops into a separate SmallVector because the TrackingListener
- // removes erased ops from the transform state.
- SmallVector<Operation *> targetOps(targetOpsView.begin(),
- targetOpsView.end());
-
SmallVector<Operation *> results;
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
- for (Operation *op : targetOps) {
+ for (Operation *op : state.getPayloadOps(getTarget())) {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
continue;
@@ -1310,11 +1302,10 @@ bool isValidPackingPermutation(
DiagnosedSilenceableFailure
transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
transform::TransformState &state) {
- ArrayRef<Operation *> packOrUnpackOps =
- state.getPayloadOps(getTargetPackOrUnPackOp());
- ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
+ auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
+ auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
// Step 1. If nothing to pack, propagate success.
- if (packOrUnpackOps.empty()) {
+ if (std::empty(packOrUnpackOps)) {
transformResults.set(cast<OpResult>(getPackedOp()), {});
transformResults.set(cast<OpResult>(getPackOp()), {});
transformResults.set(cast<OpResult>(getUnPackOp()), {});
@@ -1323,21 +1314,23 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
// Step 2. Bunch of runtime sanity check and error messages.
// Step 2.1. Fail on multi-op handles.
- if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
- return emitSilenceableError() << "requires target to map to exactly 1 "
- "packing op and 1 packed op ("
- << "got " << packOrUnpackOps.size() << " and "
- << linalgOps.size() << ")";
+ if (!llvm::hasSingleElement(packOrUnpackOps) ||
+ !llvm::hasSingleElement(linalgOps)) {
+ return emitSilenceableError()
+ << "requires target to map to exactly 1 "
+ "packing op and 1 packed op ("
+ << "got " << llvm::range_size(packOrUnpackOps) << " and "
+ << llvm::range_size(linalgOps) << ")";
}
// Step 2.2. Fail on wrong type.
- auto packOp = dyn_cast<tensor::PackOp>(packOrUnpackOps.front());
- auto unPackOp = dyn_cast<tensor::UnPackOp>(packOrUnpackOps.front());
+ auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
+ auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
if ((!packOp && !unPackOp)) {
return emitSilenceableError() << "requires target to map to a "
"tensor.pack or tensor.unpack";
}
- LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(linalgOps.front());
+ LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
if (!linalgOpTarget)
return emitSilenceableError() << "requires a LinalgOp target";
@@ -1520,16 +1513,17 @@ LogicalResult transform::PadOp::verify() {
DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
transform::TransformResults &transformResults,
transform::TransformState &state) {
- ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
- ArrayRef<Operation *> loopOps = state.getPayloadOps(getLoop());
- if (targetOps.size() != 1 || loopOps.size() != 1) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ auto loopOps = state.getPayloadOps(getLoop());
+ if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
return emitDefiniteFailure()
<< "requires exactly one target and one loop handle (got "
- << targetOps.size() << " and " << loopOps.size() << ")";
+ << llvm::range_size(targetOps) << " and "
+ << llvm::range_size(loopOps) << ")";
}
- auto padOp = dyn_cast_or_null<tensor::PadOp>(targetOps.front());
- auto loopOp = dyn_cast_or_null<scf::ForOp>(loopOps.front());
+ auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
+ auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
if (!padOp || !loopOp)
return emitDefiniteFailure() << "requires exactly 2 non-null handles";
@@ -1543,13 +1537,13 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
if (result->clonedLoopIvs.empty()) {
transformResults.set(cast<OpResult>(getPackingLoop()),
- result->hoistedPadOp.getOperation());
+ {result->hoistedPadOp.getOperation()});
return DiagnosedSilenceableFailure::success();
}
auto outerPackedLoop =
scf::getForInductionVarOwner(result->clonedLoopIvs.front());
transformResults.set(cast<OpResult>(getPackingLoop()),
- outerPackedLoop.getOperation());
+ {outerPackedLoop.getOperation()});
return DiagnosedSilenceableFailure::success();
}
@@ -1679,7 +1673,7 @@ transform::PromoteOp::applyToOne(LinalgOp target,
DiagnosedSilenceableFailure
transform::ReplaceOp::apply(TransformResults &transformResults,
TransformState &state) {
- ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+ auto payload = state.getPayloadOps(getTarget());
// Check for invalid targets.
for (Operation *target : payload) {
@@ -1814,7 +1808,8 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
TransformState &state) {
// Collect the dynamic split points if provided.
- ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+ SmallVector<Operation *> payload =
+ llvm::to_vector(state.getPayloadOps(getTarget()));
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
SmallVector<OpFoldResult> splitPoints;
@@ -2199,8 +2194,9 @@ transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
ArrayRef<int64_t> tileSizes = getStaticSizes();
- ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
- SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
+ SmallVector<Operation *> targets =
+ llvm::to_vector(state.getPayloadOps(getTarget()));
+ SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
SmallVector<SmallVector<int64_t>> paramSizes;
dynamicSizeProducers.reserve(getDynamicSizes().size());
paramSizes.reserve(getDynamicSizes().size());
@@ -2226,7 +2222,8 @@ transform::TileOp::apply(TransformResults &transformResults,
continue;
}
paramSizes.push_back({});
- dynamicSizeProducers.push_back(state.getPayloadOps(transformValue));
+ dynamicSizeProducers.push_back(
+ llvm::to_vector(state.getPayloadOps(transformValue)));
if (dynamicSizeProducers.back().size() != targets.size()) {
DiagnosedSilenceableFailure diag =
@@ -2536,10 +2533,6 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults,
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
auto transformOp = cast<TransformOpInterface>(getOperation());
- ArrayRef<Operation *> targetsView = state.getPayloadOps(getTarget());
- // Store payload ops into a separate SmallVector because the TrackingListener
- // removes erased ops from the transform state.
- SmallVector<Operation *> targets(targetsView.begin(), targetsView.end());
// Result payload ops.
SmallVector<Operation *> tileOps;
@@ -2564,7 +2557,7 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults,
if (!status.succeeded())
return status;
- for (Operation *target : targets) {
+ for (Operation *target : state.getPayloadOps(getTarget())) {
linalg::ForallTilingResult tilingResult;
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
@@ -2652,12 +2645,13 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
TransformState &state) {
ArrayRef<int64_t> tileSizes = getStaticSizes();
- ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
- SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
+ SmallVector<Operation *> targets =
+ llvm::to_vector(state.getPayloadOps(getTarget()));
+ SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
dynamicSizeProducers.reserve(getDynamicSizes().size());
for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
dynamicSizeProducers.push_back(
- state.getPayloadOps(dynamicSizeProducerHandle));
+ llvm::to_vector(state.getPayloadOps(dynamicSizeProducerHandle)));
if (dynamicSizeProducers.back().size() != targets.size()) {
DiagnosedSilenceableFailure diag =
@@ -2884,8 +2878,8 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
mlir::transform::TransformState &state) {
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
- ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
- if (targets.empty())
+ auto targets = state.getPayloadOps(getTarget());
+ if (std::empty(targets))
return DiagnosedSilenceableFailure::success();
SmallVector<int64_t> vectorSizes;
@@ -2896,16 +2890,16 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
continue;
}
- ArrayRef<Operation *> szPayloads = state.getPayloadOps(sz.get<Value>());
- if (szPayloads.size() != 1) {
+ auto szPayloads = state.getPayloadOps(sz.get<Value>());
+ if (!llvm::hasSingleElement(szPayloads)) {
auto diag = this->emitOpError(
"requires vector size handle that is mapped to 1 payload op");
diag.attachNote(sz.get<Value>().getLoc())
- << "mapped to " << szPayloads.size() << " payload ops";
+ << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
return DiagnosedSilenceableFailure::definiteFailure();
}
- Operation *szPayloadOp = szPayloads[0];
+ Operation *szPayloadOp = *szPayloads.begin();
if (szPayloadOp->getNumResults() != 1 ||
!szPayloadOp->getResult(0).getType().isIndex()) {
auto diag = this->emitOpError(
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index ae2472db4f862..c523af936e235 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -35,9 +35,8 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
transform::TransformResults &transformResults,
transform::TransformState &state) {
SmallVector<Operation *> results;
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
IRRewriter rewriter(getContext());
- for (auto *op : payloadOps) {
+ for (Operation *op : state.getPayloadOps(getTarget())) {
bool canApplyMultiBuffer = true;
auto target = cast<memref::AllocOp>(op);
LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4f3afd8e695d7..bad1d74fb473c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -53,7 +53,7 @@ transform::TransformState::TransformState(
Operation *transform::TransformState::getTopLevel() const { return topLevel; }
ArrayRef<Operation *>
-transform::TransformState::getPayloadOps(Value value) const {
+transform::TransformState::getPayloadOpsView(Value value) const {
const TransformOpMapping &operationMapping = getMapping(value).direct;
auto iter = operationMapping.find(value);
assert(
@@ -357,18 +357,8 @@ transform::TransformState::replacePayloadOp(Operation *op,
// TODO: consider invalidating the handles to nested objects here.
- // If replacing with null, that is erasing the mapping, drop the mapping
- // between the handles and the IR objects and return.
- if (!replacement) {
- for (Value handle : opHandles) {
- Mappings &mappings = getMapping(handle);
- dropMappingEntry(mappings.direct, handle, op);
- }
- return success();
- }
-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- if (options.getExpensiveChecksEnabled()) {
+ if (replacement && options.getExpensiveChecksEnabled()) {
auto insertion = cachedNames.insert({replacement, replacement->getName()});
if (!insertion.second) {
assert(insertion.first->second == replacement->getName() &&
@@ -377,8 +367,15 @@ transform::TransformState::replacePayloadOp(Operation *op,
}
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
- // Otherwise, replace the pointed-to object of all handles while preserving
- // their relative order. First, replace the mapped operation if present.
+ // Replace the pointed-to object of all handles with the replacement object.
+ // In case a payload op was erased (replacement object is nullptr), a nullptr
+ // is stored in the mapping. These nullptrs are removed after each transform.
+ // Furthermore, nullptrs are not enumerated by payload op iterators. The
+ // relative order of ops is preserved.
+ //
+ // Removing an op from the mapping would be problematic because removing an
+ // element from an array invalidates iterators; merely changing the value of
+ // elements does not.
for (Value handle : opHandles) {
Mappings &mappings = getMapping(handle);
auto it = mappings.direct.find(handle);
@@ -391,7 +388,12 @@ transform::TransformState::replacePayloadOp(Operation *op,
if (mapped == op)
mapped = replacement;
}
- mappings.reverse[replacement].push_back(handle);
+
+ if (replacement) {
+ mappings.reverse[replacement].push_back(handle);
+ } else {
+ opHandlesToCompact.insert(handle);
+ }
}
return success();
@@ -645,7 +647,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
FULL_LDBG("----found consume effect -> SKIP\n");
if (llvm::isa<TransformHandleTypeInterface>(target.get().getType())) {
FULL_LDBG("----recordOpHandleInvalidation\n");
- ArrayRef<Operation *> payloadOps = getPayloadOps(target.get());
+ ArrayRef<Operation *> payloadOps = getPayloadOpsView(target.get());
recordOpHandleInvalidation(target, payloadOps);
} else if (llvm::isa<TransformValueHandleTypeInterface>(
target.get().getType())) {
@@ -686,6 +688,14 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
return DiagnosedSilenceableFailure::success();
}
+/// Removes the nullptr entries that `replacePayloadOp` leaves in the mappings
+/// of handles whose payload ops were erased (recorded in `opHandlesToCompact`).
+/// While a transform runs, erased entries are only nulled out because erasing
+/// elements from the mapped vector would invalidate live payload op
+/// iterators; this is called after each transform's `apply`.
+void transform::TransformState::compactOpHandles() {
+  for (Value handle : opHandlesToCompact) {
+    // NOTE(review): confirm getMapping tolerates handles whose region scope
+    // may have been exited since the handle was recorded.
+    Mappings &mappings = getMapping(handle);
+    // Look the handle up instead of using operator[], which would silently
+    // create an empty entry if the handle's mapping has been dropped since it
+    // was recorded.
+    auto it = mappings.direct.find(handle);
+    if (it == mappings.direct.end())
+      continue;
+    llvm::erase_value(it->second, nullptr);
+  }
+  opHandlesToCompact.clear();
+}
+
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
LLVM_DEBUG({
@@ -721,7 +731,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Operation *>(
- getPayloadOps(operand.get()), transform,
+ getPayloadOpsView(operand.get()), transform,
operand.getOperandNumber());
if (!check.succeeded()) {
FULL_LDBG("----FAILED\n");
@@ -835,6 +845,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// proceed on a best effort basis.
transform::TransformResults results(transform->getNumResults());
DiagnosedSilenceableFailure result(transform.apply(results, *this));
+ compactOpHandles();
if (result.isDefiniteFailure())
return result;
@@ -988,19 +999,6 @@ transform::TransformResults::TransformResults(unsigned numSegments) {
values.appendEmptyRows(numSegments);
}
-void transform::TransformResults::set(OpResult value,
- ArrayRef<Operation *> ops) {
- int64_t position = value.getResultNumber();
- assert(position < static_cast<int64_t>(operations.size()) &&
- "setting results for a non-existent handle");
- assert(operations[position].data() == nullptr && "results already set");
- assert(params[position].data() == nullptr &&
- "another kind of results already set");
- assert(values[position].data() == nullptr &&
- "another kind of results already set");
- operations.replace(position, ops);
-}
-
void transform::TransformResults::setParams(
OpResult value, ArrayRef<transform::TransformState::Param> params) {
int64_t position = value.getResultNumber();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index ecd5d2a915ab6..ad001707ddd64 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -767,10 +767,9 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
- for (Operation *op : payloadOps) {
+ for (Operation *op : state.getPayloadOps(getTarget())) {
auto scope = state.make_region_scope(getBody());
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
return DiagnosedSilenceableFailure::definiteFailure();
@@ -785,8 +784,7 @@ transform::ForeachOp::apply(transform::TransformResults &results,
// Append yielded payload ops to result list (if any).
for (unsigned i = 0; i < getNumResults(); ++i) {
- ArrayRef<Operation *> yieldedOps =
- state.getPayloadOps(getYieldOp().getOperand(i));
+ auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
}
}
@@ -882,16 +880,16 @@ DiagnosedSilenceableFailure
transform::GetConsumersOfResult::apply(transform::TransformResults &results,
transform::TransformState &state) {
int64_t resultNumber = getResultNumber();
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
- if (payloadOps.empty()) {
- results.set(llvm::cast<OpResult>(getResult()), {});
+ auto payloadOps = state.getPayloadOps(getTarget());
+ if (std::empty(payloadOps)) {
+ results.set(cast<OpResult>(getResult()), {});
return DiagnosedSilenceableFailure::success();
}
- if (payloadOps.size() != 1)
+ if (!llvm::hasSingleElement(payloadOps))
return emitDefiniteFailure()
<< "handle must be mapped to exactly one payload op";
- Operation *target = payloadOps.front();
+ Operation *target = *payloadOps.begin();
if (target->getNumResults() <= resultNumber)
return emitDefiniteFailure() << "result number overflow";
results.set(llvm::cast<OpResult>(getResult()),
@@ -1483,7 +1481,7 @@ void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
DiagnosedSilenceableFailure
transform::SplitHandleOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
+ int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
auto produceNumOpsError = [&]() {
return emitSilenceableError()
<< getHandle() << " expected to contain " << this->getNumResults()
@@ -1573,11 +1571,12 @@ void transform::PDLMatchOp::getEffects(
DiagnosedSilenceableFailure
transform::ReplicateOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
+ unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
for (const auto &en : llvm::enumerate(getHandles())) {
Value handle = en.value();
- if (llvm::isa<TransformHandleTypeInterface>(handle.getType())) {
- ArrayRef<Operation *> current = state.getPayloadOps(handle);
+ if (isa<TransformHandleTypeInterface>(handle.getType())) {
+ SmallVector<Operation *> current =
+ llvm::to_vector(state.getPayloadOps(handle));
SmallVector<Operation *> payload;
payload.reserve(numRepetitions * current.size());
for (unsigned i = 0; i < numRepetitions; ++i)
@@ -2011,8 +2010,7 @@ transform::PrintOp::apply(transform::TransformResults &results,
}
llvm::outs() << "]]]\n";
- ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
- for (Operation *target : targets)
+ for (Operation *target : state.getPayloadOps(getTarget()))
llvm::outs() << *target << "\n";
return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 8ceb72d8d46b3..fb89d57a9332f 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
@@ -1598,3 +1598,26 @@ module attributes { transform.with_named_sequence } {
return
}
}
+
+// -----
+
+// transform.test_tracked_rewrite replaces all three ops while iterating over
+// %0. The "test.drop_mapping" replacement is not tracked, so its entry is
+// dropped from the handle; the other two entries are updated in place. The
+// outer loop therefore visits the tracked replacement of the last op once
+// more (2 iterations), and %0 ends up associated with 2 payload ops.
+// CHECK-LABEL: func @test_tracked_rewrite() {
+// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"}
+// CHECK-NEXT: "test.drop_mapping"() {original_op = "test.replace_me"}
+// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"}
+// CHECK-NEXT: }
+func.func @test_tracked_rewrite() {
+ %0 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
+ %1 = "test.replace_me"() {replacement = "test.drop_mapping"} : () -> (i1)
+ %2 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["test.replace_me"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ // expected-remark @below {{2 iterations}}
+ transform.test_tracked_rewrite %0 : (!pdl.operation) -> ()
+ // One replacement op (test.drop_mapping) is dropped from the mapping.
+ // expected-remark @below {{2}}
+ test_print_number_of_associated_payload_ir_ops %0
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 1d1bbc3a57084..50a4c92da9aa4 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -109,10 +110,10 @@ DiagnosedSilenceableFailure
mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
if (getOperation()->getNumOperands() != 0) {
- results.set(llvm::cast<OpResult>(getResult()),
- getOperation()->getOperand(0).getDefiningOp());
+ results.set(cast<OpResult>(getResult()),
+ {getOperation()->getOperand(0).getDefiningOp()});
} else {
- results.set(llvm::cast<OpResult>(getResult()), getOperation());
+ results.set(cast<OpResult>(getResult()), {getOperation()});
}
return DiagnosedSilenceableFailure::success();
}
@@ -191,12 +192,13 @@ void mlir::test::TestConsumeOperand::getEffects(
DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
transform::TransformResults &results, transform::TransformState &state) {
- ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
- assert(payload.size() == 1 && "expected a single target op");
- if (payload[0]->getName().getStringRef() != getOpKind()) {
+ auto payload = state.getPayloadOps(getOperand());
+ assert(llvm::hasSingleElement(payload) && "expected a single target op");
+ if ((*payload.begin())->getName().getStringRef() != getOpKind()) {
return emitSilenceableError()
<< "op expected the operand to be associated a payload op of kind "
- << getOpKind() << " got " << payload[0]->getName().getStringRef();
+ << getOpKind() << " got "
+ << (*payload.begin())->getName().getStringRef();
}
emitRemark() << "succeeded";
@@ -230,7 +232,7 @@ void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
- ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
+ auto payload = state.getPayloadOps(getOperand());
for (Operation *op : payload)
op->emitRemark() << getMessage();
@@ -313,11 +315,11 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
if (!extension)
return emitDefiniteFailure("TestTransformStateExtension missing");
- if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
- getOperation())))
+ if (failed(extension->updateMapping(
+ *state.getPayloadOps(getOperand()).begin(), getOperation())))
return DiagnosedSilenceableFailure::definiteFailure();
if (getNumResults() > 0)
- results.set(llvm::cast<OpResult>(getResult(0)), getOperation());
+ results.set(cast<OpResult>(getResult(0)), {getOperation()});
return DiagnosedSilenceableFailure::success();
}
@@ -337,7 +339,7 @@ DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
DiagnosedSilenceableFailure
mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+ auto payloadOps = state.getPayloadOps(getTarget());
auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
results.set(llvm::cast<OpResult>(getResult()), reversedOps);
return DiagnosedSilenceableFailure::success();
@@ -431,7 +433,7 @@ mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
transform::TransformResults &results, transform::TransformState &state) {
if (!getHandle())
emitRemark() << 0;
- emitRemark() << state.getPayloadOps(getHandle()).size();
+ emitRemark() << llvm::range_size(state.getPayloadOps(getHandle()));
return DiagnosedSilenceableFailure::success();
}
@@ -599,11 +601,11 @@ mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
} else if (getFirstResultIsNull()) {
results.push_back(nullptr);
} else {
- results.push_back(state.getPayloadOps(getIn()).front());
+ results.push_back(*state.getPayloadOps(getIn()).begin());
}
if (getSecondResultIsHandle()) {
- results.push_back(state.getPayloadOps(getIn()).front());
+ results.push_back(*state.getPayloadOps(getIn()).begin());
} else {
results.push_back(builder.getI64IntegerAttr(42));
}
@@ -667,6 +669,70 @@ DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
return DiagnosedSilenceableFailure::success();
}
+/// Memory effects of `transform.test_tracked_rewrite`: the `in` handle is only
+/// read (not consumed), and the payload IR is modified.
+void mlir::test::TestTrackedRewriteOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getIn(), effects);
+ transform::modifiesPayload(effects);
+}
+
+namespace {
+/// A TrackingListener for test cases. When the replacement op is
+/// "test.update_mapping", it is considered as a replacement op in the transform
+/// state mapping. Otherwise, no replacement is tracked and the original op is
+/// dropped from the mapping.
+class TestTrackingListener : public transform::TrackingListener {
+ using transform::TrackingListener::TrackingListener;
+
+protected:
+ /// Returns the op defining the single replacement value if that op is named
+ /// "test.update_mapping". Returns nullptr (no tracked replacement) if there
+ /// is not exactly one new value, if the value has no defining op, or if the
+ /// defining op has any other name.
+ Operation *findReplacementOp(Operation *op,
+ ValueRange newValues) const override {
+ if (newValues.size() != 1)
+ return nullptr;
+ Operation *replacement = newValues[0].getDefiningOp();
+ if (!replacement)
+ return nullptr;
+ if (replacement->getName().getStringRef() != "test.update_mapping")
+ return nullptr;
+ return replacement;
+ }
+};
+} // namespace
+
+/// Replaces every "test.replace_me" payload op associated with `in` by a new
+/// op whose name is taken from its "replacement" string attribute (ops without
+/// that attribute are left alone), while iterating over the payload ops of the
+/// same handle. Each replacement carries an "original_op" attribute holding the
+/// replaced op's name. Emits a remark with the number of outer-loop iterations.
+DiagnosedSilenceableFailure
+mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  TestTrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
+  int64_t numIterations = 0;
+
+  // `getPayloadOps` returns an iterator that skips mapping entries that were
+  // set to nullptr (ops erased without a tracked replacement) in the loop
+  // body. Entries that are updated in place with a tracked replacement op are
+  // still visited if the iterator has not reached them yet.
+  for (Operation *op : state.getPayloadOps(getIn())) {
+    ++numIterations;
+    (void)op;
+
+    // Replace all "test.replace_me" payload ops. TestTrackingListener only
+    // tracks "test.update_mapping" replacements, which this loop skips, so a
+    // later outer iteration that visits one of them leaves the IR unchanged.
+    for (Operation *target : state.getPayloadOps(getIn())) {
+      if (target->getName().getStringRef() != "test.replace_me")
+        continue;
+      auto replacementName = target->getAttrOfType<StringAttr>("replacement");
+      if (!replacementName)
+        continue;
+      // Create the replacement right before the op it replaces. Inserting at
+      // the end of the block would place it after the block's terminator, if
+      // there is one, and in the wrong block if payload ops span several.
+      rewriter.setInsertionPoint(target);
+      SmallVector<NamedAttribute> attributes;
+      attributes.emplace_back(rewriter.getStringAttr("original_op"),
+                              target->getName().getIdentifier());
+      OperationState opState(target->getLoc(), replacementName,
+                             /*operands=*/ValueRange(),
+                             /*types=*/target->getResultTypes(), attributes);
+      Operation *newOp = rewriter.create(opState);
+      rewriter.replaceOp(target, newOp->getResults());
+    }
+  }
+
+  emitRemark() << numIterations << " iterations";
+  return DiagnosedSilenceableFailure::success();
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index e5a23ede64088..c77f6ea320b68 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -449,4 +449,14 @@ def TestRequiredMemoryEffectsOp
let cppNamespace = "::mlir::test";
}
+// Test op that, while iterating over the payload ops associated with `$in`,
+// replaces every "test.replace_me" op carrying a "replacement" string attribute
+// with a new op of that name, using a TrackingListener that only tracks
+// "test.update_mapping" replacements. Emits a remark with the number of
+// outer-loop iterations.
+def TestTrackedRewriteOp
+ : Op<Transform_Dialect, "test_tracked_rewrite",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins TransformHandleTypeInterface:$in);
+ let results = (outs);
+ let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
+ let cppNamespace = "::mlir::test";
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits
mailing list