[Mlir-commits] [mlir] 0242b96 - [mlir] more side effect verification in transform dialect
Alex Zinenko
llvmlistbot at llvm.org
Mon Feb 6 05:15:44 PST 2023
Author: Alex Zinenko
Date: 2023-02-06T13:15:36Z
New Revision: 0242b96214a608ee3dd509d0967128bc172d2da6
URL: https://github.com/llvm/llvm-project/commit/0242b96214a608ee3dd509d0967128bc172d2da6
DIFF: https://github.com/llvm/llvm-project/commit/0242b96214a608ee3dd509d0967128bc172d2da6.diff
LOG: [mlir] more side effect verification in transform dialect
Add a verifier checking that if a transform operation consumes a handle
(which is associated with a payload operation being erased or
recreated), it also indicates modification of the payload IR. This
hasn't been consistent in the past because of the "no-aliasing"
assumption, under which there could not be more than one handle to an
operation; this required some handle-manipulation operations, such as
`transform.merge_handles`, to consume their operands. That assumption has
been lifted, so it is no longer necessary for these operations to
consume handles, which made life harder for clients.
Additionally, remove TransformEffects.td, which used the ODS mechanism for
indicating side effects; that mechanism works only for operands and results.
It was being used incorrectly to also indicate effects on the payload IR,
which are not associated with any IR value, and it lacked the consume/produce
semantics available via helpers in C++.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D142361
Added:
Modified:
mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index 1ab7bd3628e7c..76f0d8e6439cb 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -11,7 +11,6 @@
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
index 2431488e01f2d..3a663914e3b52 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -11,7 +11,6 @@
include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 802a915d8ec70..2ee13548c9171 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -10,7 +10,6 @@
#define GPU_TRANSFORM_OPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 630000c57f8f7..5e776e0e004b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -10,7 +10,6 @@
#define LINALG_TRANSFORM_OPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
@@ -89,7 +88,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
def FuseIntoContainingOp :
Op<Transform_Dialect, "structured.fuse_into_containing_op",
- [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Fuse a producer into a containing operation.";
let description = [{
@@ -125,14 +125,9 @@ def FuseIntoContainingOp :
This operation reads the containing op handle.
}];
- let arguments = (ins Arg<PDL_Operation, "",
- [TransformMappingRead,
- TransformMappingFree]>:$producer_op,
- Arg<PDL_Operation, "",
- [TransformMappingRead]>:$containing_op);
- let results = (outs Res<PDL_Operation, "",
- [TransformMappingAlloc,
- TransformMappingWrite]>:$fused_op);
+ let arguments = (ins PDL_Operation:$producer_op,
+ PDL_Operation:$containing_op);
+ let results = (outs PDL_Operation:$fused_op);
let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
let builders = [
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index c480e7c7a7077..f16fe8a493d2c 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -10,7 +10,6 @@
#define MEMREF_TRANSFORM_OPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index affa9abac31a0..b286850ad9895 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -10,7 +10,6 @@
#define SCF_TRANSFORM_OPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td b/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td
deleted file mode 100644
index b6106fe96c975..0000000000000
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td
+++ /dev/null
@@ -1,62 +0,0 @@
-
-//===- TransformEffect.td - Transform side effects ---------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines side effects and associated resources for operations in the
-// Transform dialect and extensions.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD
-#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD
-
-include "mlir/Interfaces/SideEffectInterfaces.td"
-
-//===----------------------------------------------------------------------===//
-// Effects on the mapping between Transform IR values and Payload IR ops.
-//===----------------------------------------------------------------------===//
-
-// Side effect resource corresponding to the mapping between transform IR values
-// and Payload IR operations.
-def TransformMappingResource
- : Resource<"::mlir::transform::TransformMappingResource">;
-
-// Describes the creation of a new entry in the transform mapping. Should be
-// accompanied by the Write effect as the entry is immediately initialized by
-// any reasonable transform operation.
-def TransformMappingAlloc : MemAlloc<TransformMappingResource>;
-
-// Describes the removal of an entry in the transform mapping. Typically
-// accompanied by the Read effect.
-def TransformMappingFree : MemFree<TransformMappingResource>;
-
-// Describes the access to the mapping. Read-only accesses can be reordered.
-def TransformMappingRead : MemRead<TransformMappingResource>;
-
-// Describes a modification of an existing entry in the mapping. It is rarely
-// used alone, and is mostly accompanied by the Allocate effect.
-def TransformMappingWrite : MemWrite<TransformMappingResource>;
-
-//===----------------------------------------------------------------------===//
-// Effects on Payload IR.
-//===----------------------------------------------------------------------===//
-
-// Side effect resource corresponding to the Payload IR itself.
-def PayloadIRResource : Resource<"::mlir::transform::PayloadIRResource">;
-
-// Corresponds to the read-only access to the Payload IR through some operation
-// handles in the Transform IR.
-def PayloadIRRead : MemRead<PayloadIRResource>;
-
-// Corresponds to the mutation of the Payload IR through an operation handle in
-// the Transform IR. Should be accompanied by the Read effect for most transform
-// operations (only a complete overwrite of the root op of the Payload IR is a
-// write-only modification).
-def PayloadIRWrite : MemWrite<PayloadIRResource>;
-
-#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 6f3b4cf2e1077..dd66e61880416 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -12,11 +12,11 @@
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
def AlternativesOp : TransformDialectOp<"alternatives",
@@ -466,7 +466,8 @@ def SequenceOp : TransformDialectOp<"sequence",
def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
[DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
- OpAsmOpInterface, PossibleTopLevelTransformOpTrait, RecursiveMemoryEffects,
+ OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SymbolTable]> {
let summary = "Contains PDL patterns available for use in transforms";
let description = [{
@@ -505,8 +506,8 @@ def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
}];
let arguments = (ins
- Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR",
- [TransformMappingRead]>:$root);
+ Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR"
+ >:$root);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
@@ -518,7 +519,8 @@ def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
}];
}
-def YieldOp : TransformDialectOp<"yield", [Terminator]> {
+def YieldOp : TransformDialectOp<"yield",
+ [Terminator, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Yields operation handles from a transform IR region";
let description = [{
This terminator operation yields operation handles from regions of the
@@ -527,8 +529,8 @@ def YieldOp : TransformDialectOp<"yield", [Terminator]> {
}];
let arguments = (ins
- Arg<Variadic<TransformHandleTypeInterface>, "Operation handles yielded back to the parent",
- [TransformMappingRead]>:$operands);
+ Arg<Variadic<TransformHandleTypeInterface>, "Operation handles yielded back to the parent"
+ >:$operands);
let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";
let builders = [
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 060e6bcef5cf8..4533c5a8d6425 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -10,7 +10,6 @@
#define VECTOR_TRANSFORM_OPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index a1b2d48a67066..a2f8a1d38eb84 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -60,13 +60,14 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
void transform::OneShotBufferizeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
- TransformMappingResource::get());
-
// Handles that are not modules are not longer usable.
- if (!getTargetIsModule())
- effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
- TransformMappingResource::get());
+ if (!getTargetIsModule()) {
+ consumesHandle(getTarget(), effects);
+ } else {
+ onlyReadsHandle(getTarget(), effects);
+ }
+
+ modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 022c94e39e238..94725fa043406 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -713,6 +713,14 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
+void transform::FuseIntoContainingOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getProducerOp(), effects);
+ onlyReadsHandle(getContainingOp(), effects);
+ producesHandle(getFusedOp(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//
@@ -2668,6 +2676,7 @@ void transform::TileToForeachThreadOp::getEffects(
onlyReadsHandle(getPackedNumThreads(), effects);
onlyReadsHandle(getPackedTileSizes(), effects);
producesHandle(getResults(), effects);
+ modifiesPayload(effects);
}
SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
@@ -2997,6 +3006,7 @@ void transform::MaskedVectorizeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
onlyReadsHandle(getVectorSizes(), effects);
+ modifiesPayload(effects);
}
SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5ecc1f47573c6..e14fca25af77c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -783,6 +783,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
});
};
+ std::optional<unsigned> firstConsumedOperand = std::nullopt;
for (OpOperand &operand : op->getOpOperands()) {
auto range = effectsOn(operand.get());
if (range.empty()) {
@@ -793,7 +794,30 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
<< operand.getOperandNumber();
return diag;
}
+ if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
+ InFlightDiagnostic diag = op->emitError()
+ << "TransformOpInterface did not expect "
+ "'allocate' memory effect on an operand";
+ diag.attachNote() << "specified for operand #"
+ << operand.getOperandNumber();
+ return diag;
+ }
+ if (!firstConsumedOperand &&
+ ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
+ firstConsumedOperand = operand.getOperandNumber();
+ }
+ }
+
+ if (firstConsumedOperand &&
+ !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
+ InFlightDiagnostic diag =
+ op->emitError()
+ << "TransformOpInterface expects ops consuming operands to have a "
+ "'write' effect on the payload resource";
+ diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
+ return diag;
}
+
for (OpResult result : op->getResults()) {
auto range = effectsOn(result);
if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
@@ -806,6 +830,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
return diag;
}
}
+
return success();
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index f9bdc0ee5793e..bc1ac6d5b0ffe 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -292,7 +292,7 @@ transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results,
void transform::CastOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsPayload(effects);
- consumesHandle(getInput(), effects);
+ onlyReadsHandle(getInput(), effects);
producesHandle(getOutput(), effects);
}
@@ -501,7 +501,7 @@ bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
void transform::MergeHandlesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getHandles(), effects);
+ onlyReadsHandle(getHandles(), effects);
producesHandle(getResult(), effects);
// There are no effects on the Payload IR as this is only a handle
@@ -557,7 +557,7 @@ transform::SplitHandlesOp::apply(transform::TransformResults &results,
void transform::SplitHandlesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getHandle(), effects);
+ onlyReadsHandle(getHandle(), effects);
producesHandle(getResults(), effects);
// There are no effects on the Payload IR as this is only a handle
// manipulation.
@@ -626,7 +626,7 @@ transform::ReplicateOp::apply(transform::TransformResults &results,
void transform::ReplicateOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getPattern(), effects);
- consumesHandle(getHandles(), effects);
+ onlyReadsHandle(getHandles(), effects);
producesHandle(getReplicated(), effects);
}
@@ -832,34 +832,62 @@ remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
effects.emplace_back(effect.getEffect(), target, effect.getResource());
}
-void transform::SequenceOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- onlyReadsHandle(getRoot(), effects);
- onlyReadsHandle(getExtraBindings(), effects);
- producesHandle(getResults(), effects);
+namespace {
+template <typename T>
+using has_get_extra_bindings = decltype(std::declval<T &>().getExtraBindings());
+} // namespace
+
+/// Populate `effects` with transform dialect memory effects for the potential
+/// top-level operation. Such operations have recursive effects from nested
+/// operations. When they have an operand, we can additionally remap effects on
+/// the block argument to be effects on the operand.
+template <typename OpTy>
+static void getPotentialTopLevelEffects(
+ OpTy operation, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(operation->getOperands(), effects);
+ transform::producesHandle(operation->getResults(), effects);
+
+ if (!operation.getRoot()) {
+ for (Operation &op : *operation.getBodyBlock()) {
+ auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
+ if (!iface)
+ continue;
- if (!getRoot()) {
- for (Operation &op : *getBodyBlock()) {
- auto iface = cast<MemoryEffectOpInterface>(&op);
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
iface.getEffects(effects);
}
return;
}
- // Carry over all effects on the argument of the entry block as those on the
- // operand, this is the same value just remapped.
- for (Operation &op : *getBodyBlock()) {
- auto iface = cast<MemoryEffectOpInterface>(&op);
+ // Carry over all effects on arguments of the entry block as those on the
+ // operands, this is the same value just remapped.
+ for (Operation &op : *operation.getBodyBlock()) {
+ auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
+ if (!iface)
+ continue;
- remapEffects(iface, getBodyBlock()->getArgument(0), getRoot(), effects);
- for (auto [source, target] : llvm::zip(
- getBodyBlock()->getArguments().drop_front(), getExtraBindings())) {
- remapEffects(iface, source, target, effects);
+ remapEffects(iface, operation.getBodyBlock()->getArgument(0),
+ operation.getRoot(), effects);
+ if constexpr (llvm::is_detected<has_get_extra_bindings, OpTy>::value) {
+ for (auto [source, target] :
+ llvm::zip(operation.getBodyBlock()->getArguments().drop_front(),
+ operation.getExtraBindings())) {
+ remapEffects(iface, source, target, effects);
+ }
}
+
+ SmallVector<MemoryEffects::EffectInstance> nestedEffects;
+ iface.getEffectsOnResource(transform::PayloadIRResource::get(),
+ nestedEffects);
+ llvm::append_range(effects, nestedEffects);
}
}
+void transform::SequenceOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ getPotentialTopLevelEffects(*this, effects);
+}
+
OperandRange transform::SequenceOp::getSuccessorEntryOperands(
std::optional<unsigned> index) {
assert(index && *index == 0 && "unexpected region index");
@@ -983,6 +1011,11 @@ transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
return state.applyTransform(transformOp);
}
+void transform::WithPDLPatternsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ getPotentialTopLevelEffects(*this, effects);
+}
+
LogicalResult transform::WithPDLPatternsOp::verify() {
Block *body = getBodyBlock();
Operation *topLevelOp = nullptr;
@@ -1065,3 +1098,12 @@ void transform::PrintOp::getEffects(
// writes into the default resource.
effects.emplace_back(MemoryEffects::Write::get());
}
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+void transform::YieldOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getOperands(), effects);
+}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 2fd0a37c86b85..500fe61e01481 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -251,7 +251,7 @@ transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}
// expected-note @below {{no effects specified for operand #0}}
- transform.test_required_memory_effects %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.test_required_memory_effects %arg0 {modifies_payload} : (!transform.any_op) -> !transform.any_op
}
// -----
@@ -260,5 +260,5 @@ transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{TransformOpInterface requires 'allocate' memory effect to be specified for results}}
// expected-note @below {{no 'allocate' effect specified for result #0}}
- transform.test_required_memory_effects %arg0 {has_operand_effect} : (!transform.any_op) -> !transform.any_op
+ transform.test_required_memory_effects %arg0 {has_operand_effect, modifies_payload} : (!transform.any_op) -> !transform.any_op
}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 63d682831114b..0bd3031df9132 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -118,6 +118,13 @@ mlir::test::TestProduceParamOrForwardOperandOp::apply(
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestProduceParamOrForwardOperandOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ if (getOperand())
+ transform::onlyReadsHandle(getOperand(), effects);
+ transform::producesHandle(getRes(), effects);
+}
+
LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
if (getParameter().has_value() ^ (getNumOperands() != 1))
return emitOpError() << "expects either a parameter or an operand";
@@ -130,6 +137,14 @@ mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestConsumeOperand::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getOperand(), effects);
+ if (getSecondOperand())
+ transform::consumesHandle(getSecondOperand(), effects);
+ transform::modifiesPayload(effects);
+}
+
DiagnosedSilenceableFailure
mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
transform::TransformResults &results, transform::TransformState &state) {
@@ -146,6 +161,12 @@ mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestConsumeOperandIfMatchesParamOrFail::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getOperand(), effects);
+ transform::modifiesPayload(effects);
+}
+
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
@@ -155,6 +176,12 @@ DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getOperand(), effects);
+ transform::onlyReadsPayload(effects);
+}
+
DiagnosedSilenceableFailure
mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
@@ -187,6 +214,12 @@ mlir::test::TestCheckIfTestExtensionPresentOp::apply(
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getOperand(), effects);
+ transform::onlyReadsPayload(effects);
+}
+
DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
auto *extension = state.getExtension<TestTransformStateExtension>();
@@ -199,6 +232,12 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getOperand(), effects);
+ transform::onlyReadsPayload(effects);
+}
+
DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
state.removeExtension<TestTransformStateExtension>();
@@ -312,6 +351,13 @@ mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestCopyPayloadOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getHandle(), effects);
+ transform::producesHandle(getCopy(), effects);
+ transform::onlyReadsPayload(effects);
+}
+
DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
Location loc, ArrayRef<Operation *> payload) const {
if (payload.empty())
@@ -491,6 +537,9 @@ void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
transform::producesHandle(getOut(), effects);
else
transform::onlyReadsHandle(getOut(), effects);
+
+ if (getModifiesPayload())
+ transform::modifiesPayload(effects);
}
DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 02e8a691db3c3..cc67c2ac2afd6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -14,10 +14,10 @@
#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
#define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
+include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
@@ -41,35 +41,33 @@ def TestTransformTestDialectParamType
def TestProduceParamOrForwardOperandOp
: Op<Transform_Dialect, "test_produce_param_or_forward_operand",
- [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let arguments = (ins
- Arg<Optional<PDL_Operation>, "", [TransformMappingRead]>:$operand,
+ Optional<PDL_Operation>:$operand,
OptionalAttr<I64Attr>:$parameter);
- let results = (outs
- Res<PDL_Operation, "",
- [TransformMappingAlloc, TransformMappingWrite]>:$res);
+ let results = (outs PDL_Operation:$res);
let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict";
let cppNamespace = "::mlir::test";
let hasVerifier = 1;
}
def TestConsumeOperand : Op<Transform_Dialect, "test_consume_operand",
- [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let arguments = (ins
- Arg<PDL_Operation, "",
- [TransformMappingRead, TransformMappingFree]>:$operand,
- Arg<Optional<PDL_Operation>, "",
- [TransformMappingRead, TransformMappingFree]>:$second_operand);
+ PDL_Operation:$operand,
+ Optional<PDL_Operation>:$second_operand);
let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict";
let cppNamespace = "::mlir::test";
}
def TestConsumeOperandIfMatchesParamOrFail
: Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
- [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let arguments = (ins
- Arg<PDL_Operation, "",
- [TransformMappingRead, TransformMappingFree]>:$operand,
+ PDL_Operation:$operand,
I64Attr:$parameter);
let assemblyFormat = "$operand `[` $parameter `]` attr-dict";
let cppNamespace = "::mlir::test";
@@ -77,10 +75,10 @@ def TestConsumeOperandIfMatchesParamOrFail
def TestPrintRemarkAtOperandOp
: Op<Transform_Dialect, "test_print_remark_at_operand",
- [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let arguments = (ins
- Arg<TransformHandleTypeInterface, "",
- [TransformMappingRead, PayloadIRRead]>:$operand,
+ TransformHandleTypeInterface:$operand,
StrAttr:$message);
let assemblyFormat =
"$operand `,` $message attr-dict `:` type($operand)";
@@ -98,19 +96,18 @@ def TestAddTestExtensionOp
def TestCheckIfTestExtensionPresentOp
: Op<Transform_Dialect, "test_check_if_test_extension_present",
- [DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let arguments = (ins
- Arg<PDL_Operation, "", [TransformMappingRead, PayloadIRRead]>:$operand);
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let arguments = (ins PDL_Operation:$operand);
let assemblyFormat = "$operand attr-dict";
let cppNamespace = "::mlir::test";
}
def TestRemapOperandPayloadToSelfOp
: Op<Transform_Dialect, "test_remap_operand_to_self",
- [DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let arguments = (ins
- Arg<PDL_Operation, "",
- [TransformMappingRead, TransformMappingWrite, PayloadIRRead]>:$operand);
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let arguments = (ins PDL_Operation:$operand);
let assemblyFormat = "$operand attr-dict";
let cppNamespace = "::mlir::test";
}
@@ -255,10 +252,10 @@ def TestPrintNumberOfAssociatedPayloadIROps
def TestCopyPayloadOp
: Op<Transform_Dialect, "test_copy_payload",
- [DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let arguments = (ins Arg<PDL_Operation, "", [TransformMappingRead]>:$handle);
- let results = (outs Res<PDL_Operation, "",
- [TransformMappingAlloc, TransformMappingWrite]>:$copy);
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let arguments = (ins PDL_Operation:$handle);
+ let results = (outs PDL_Operation:$copy);
let cppNamespace = "::mlir::test";
let assemblyFormat = "$handle attr-dict";
}
@@ -358,7 +355,8 @@ def TestRequiredMemoryEffectsOp
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins TransformHandleTypeInterface:$in,
UnitAttr:$has_operand_effect,
- UnitAttr:$has_result_effect);
+ UnitAttr:$has_result_effect,
+ UnitAttr:$modifies_payload);
let results = (outs TransformHandleTypeInterface:$out);
let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
let cppNamespace = "::mlir::test";
More information about the Mlir-commits
mailing list