[Mlir-commits] [mlir] b9e40cd - [mlir] multi-argument binding for top-level transform ops
Alex Zinenko
llvmlistbot at llvm.org
Tue Jan 31 06:21:37 PST 2023
Author: Alex Zinenko
Date: 2023-01-31T14:21:28Z
New Revision: b9e40cde3b35c64f0e4f2156224a671104e04efd
URL: https://github.com/llvm/llvm-project/commit/b9e40cde3b35c64f0e4f2156224a671104e04efd
DIFF: https://github.com/llvm/llvm-project/commit/b9e40cde3b35c64f0e4f2156224a671104e04efd.diff
LOG: [mlir] multi-argument binding for top-level transform ops
`applyTransforms` now takes an optional mapping to be associated with
trailing block arguments of the top-level transform op, in addition to
the payload root. This allows for more advanced forms of communication
between C++ code and the transform dialect interpreter, in particular
supplying operations without having to re-match them during
interpretation.
Reviewed By: shabalin
Differential Revision: https://reviews.llvm.org/D142559
Added:
mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
Modified:
mlir/docs/Dialects/Transform.md
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/python/mlir/dialects/_transform_ops_ext.py
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/ops.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
mlir/test/python/dialects/transform.py
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md
index eb86bdca5c2a0..e73bb79283f24 100644
--- a/mlir/docs/Dialects/Transform.md
+++ b/mlir/docs/Dialects/Transform.md
@@ -109,13 +109,19 @@ A program transformation expressed using the Transform dialect can be
programmatically triggered by calling:
```c++
-LogicalResult transform::applyTransforms(Operation *payloadRoot,
- TransformOpInterface transform,
- const TransformOptions &options);
+LogicalResult transform::applyTransforms(
+ Operation *payloadRoot,
+ TransformOpInterface transform,
+ ArrayRef<ArrayRef<PointerUnion<Operation *, Attribute>>> extraMappings,
+ const TransformOptions &options);
```
that applies the transformations specified by the top-level `transform` to
-payload IR contained in `payloadRoot`.
+payload IR contained in `payloadRoot`. The payload root operation will be
+associated with the first argument of the entry block of the top-level transform
+op. This block may have additional arguments, either handles or parameters,
+which will be associated with values provided as `extraMappings`. The call will
+report an error and return if the wrong number of mappings is provided.
## Dialect Extension Mechanism
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index e523c08f99ec6..063b6dec7dc97 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -42,6 +42,9 @@ class TransformOptions {
bool expensiveChecksEnabled = true;
};
+using Param = Attribute;
+using MappedValue = llvm::PointerUnion<Operation *, Param>;
+
/// Entry point to the Transform dialect infrastructure. Applies the
/// transformation specified by `transform` to payload IR contained in
/// `payloadRoot`. The `transform` operation may contain other operations that
@@ -50,6 +53,7 @@ class TransformOptions {
/// This function internally keeps track of the transformation state.
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
+ ArrayRef<ArrayRef<MappedValue>> extraMapping = {},
const TransformOptions &options = TransformOptions());
/// The state maintained across applications of various ops implementing the
@@ -85,7 +89,7 @@ applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
/// using `mapBlockArguments`.
class TransformState {
public:
- using Param = Attribute;
+ using Param = transform::Param;
private:
/// Mapping between a Value in the transform IR and the corresponding set of
@@ -109,15 +113,23 @@ class TransformState {
ParamMapping params;
};
- friend LogicalResult applyTransforms(Operation *payloadRoot,
- TransformOpInterface transform,
- const TransformOptions &options);
+ friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
+ ArrayRef<ArrayRef<MappedValue>>,
+ const TransformOptions &);
public:
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
Operation *getTopLevel() const;
+ /// Returns the number of extra mappings for the top-level operation.
+ size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
+
+ /// Returns the position-th extra mapping for the top-level operation.
+ ArrayRef<MappedValue> getTopLevelMapping(size_t position) const {
+ return topLevelMappedValues[position];
+ }
+
/// Returns the list of ops that the given transform IR value corresponds to.
/// This is helpful for transformations that apply to a particular handle.
ArrayRef<Operation *> getPayloadOps(Value value) const;
@@ -150,6 +162,8 @@ class TransformState {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return setPayloadOps(argument, operations);
}
+ LogicalResult mapBlockArgument(BlockArgument argument,
+ ArrayRef<MappedValue> values);
// Forward declarations to support limited visibility.
class RegionScope;
@@ -302,6 +316,7 @@ class TransformState {
/// which may or may not contain the region with transform ops. Additional
/// options can be provided through the trailing configuration object.
TransformState(Region *region, Operation *payloadRoot,
+ ArrayRef<ArrayRef<MappedValue>> extraMappings = {},
const TransformOptions &options = TransformOptions());
/// Returns the mappings frame for the region in which the value is defined.
@@ -403,6 +418,15 @@ class TransformState {
/// The top-level operation that contains all payload IR, typically a module.
Operation *topLevel;
+ /// Storage for extra mapped values (payload operations or parameters) to be
+ /// associated with additional entry block arguments of the top-level
+ /// transform operation. Each entry in `topLevelMappedValues` is a reference
+ /// to a contiguous block in `topLevelMappedValueStorage`.
+ // TODO: turn this into a proper named data structure, there are several more
+ // below.
+ SmallVector<ArrayRef<MappedValue>> topLevelMappedValues;
+ SmallVector<MappedValue> topLevelMappedValueStorage;
+
/// Additional options controlling the transformation state behavior.
TransformOptions options;
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 0a737d5313750..7eb9a01fc0f07 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -26,6 +26,9 @@ class FailurePropagationModeAttr;
/// A builder function that populates the body of a SequenceOp.
using SequenceBodyBuilderFn = ::llvm::function_ref<void(
::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)>;
+using SequenceBodyBuilderArgsFn =
+ ::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location,
+ ::mlir::BlockArgument, ::mlir::ValueRange)>;
} // namespace transform
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 4bb6700dbca29..6f3b4cf2e1077 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -384,7 +384,8 @@ def SequenceOp : TransformDialectOp<"sequence",
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
- SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+ SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+ AttrSizedOperandSegments]> {
let summary = "Contains a sequence of other transform ops to apply";
let description = [{
The transformations indicated by the sequence are applied in order of their
@@ -417,12 +418,14 @@ def SequenceOp : TransformDialectOp<"sequence",
}];
let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
- Optional<TransformHandleTypeInterface>:$root);
+ Optional<TransformHandleTypeInterface>:$root,
+ Variadic<Transform_AnyHandleOrParamType>:$extra_bindings);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
- "($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` "
+ "custom<SequenceOpOperands>($root, type($root), $extra_bindings, type($extra_bindings))"
+ " (`->` type($results)^)? `failures` `(` "
"$failure_propagation_mode `)` attr-dict-with-keyword regions";
let builders = [
@@ -432,11 +435,25 @@ def SequenceOp : TransformDialectOp<"sequence",
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
"::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>,
- // Build a sequence without a root but a certain bbArg type.
+ // Build a sequence with a root and additional arguments.
+ OpBuilder<(ins
+ "::mlir::TypeRange":$resultTypes,
+ "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+ "::mlir::Value":$root, "::mlir::ValueRange":$extraBindings,
+ "SequenceBodyBuilderArgsFn":$bodyBuilder)>,
+
+ // Build a top-level sequence (no root).
+ OpBuilder<(ins
+ "::mlir::TypeRange":$resultTypes,
+ "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+ "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>,
+
+ // Build a top-level sequence (no root) with extra arguments.
OpBuilder<(ins
"::mlir::TypeRange":$resultTypes,
"::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
- "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>
+ "::mlir::Type":$bbArgType, "::mlir::TypeRange":$extraBindingTypes,
+ "SequenceBodyBuilderArgsFn":$bodyBuilder)>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index e2ab48e21a8bf..5ecc1f47573c6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -27,10 +27,20 @@ using namespace mlir;
constexpr const Value transform::TransformState::kTopLevelValue;
-transform::TransformState::TransformState(Region *region,
- Operation *payloadRoot,
- const TransformOptions &options)
+transform::TransformState::TransformState(
+ Region *region, Operation *payloadRoot,
+ ArrayRef<ArrayRef<MappedValue>> extraMappings,
+ const TransformOptions &options)
: topLevel(payloadRoot), options(options) {
+ topLevelMappedValues.reserve(extraMappings.size());
+ for (ArrayRef<MappedValue> mapping : extraMappings) {
+ size_t start = topLevelMappedValueStorage.size();
+ llvm::append_range(topLevelMappedValueStorage, mapping);
+ topLevelMappedValues.push_back(
+ ArrayRef<MappedValue>(topLevelMappedValueStorage)
+ .slice(start, mapping.size()));
+ }
+
auto result = mappings.try_emplace(region);
assert(result.second && "the region scope is already present");
(void)result;
@@ -72,6 +82,38 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
return success(found);
}
+LogicalResult
+transform::TransformState::mapBlockArgument(BlockArgument argument,
+ ArrayRef<MappedValue> values) {
+ if (argument.getType().isa<TransformHandleTypeInterface>()) {
+ SmallVector<Operation *> operations;
+ operations.reserve(values.size());
+ for (MappedValue value : values) {
+ if (auto *op = value.dyn_cast<Operation *>()) {
+ operations.push_back(op);
+ continue;
+ }
+ return emitError(argument.getLoc())
+ << "wrong kind of value provided for top-level operation handle";
+ }
+ return setPayloadOps(argument, operations);
+ }
+
+ assert(argument.getType().isa<TransformParamTypeInterface>() &&
+ "unsupported kind of block argument");
+ SmallVector<Param> parameters;
+ parameters.reserve(values.size());
+ for (MappedValue value : values) {
+ if (auto attr = value.dyn_cast<Attribute>()) {
+ parameters.push_back(attr);
+ continue;
+ }
+ return emitError(argument.getLoc())
+ << "wrong kind of value provided for top-level parameter";
+ }
+ return setParams(argument, parameters);
+}
+
LogicalResult
transform::TransformState::setPayloadOps(Value value,
ArrayRef<Operation *> targets) {
@@ -522,12 +564,43 @@ void transform::detail::setApplyToOneResults(
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
TransformState &state, Operation *op, Region ®ion) {
SmallVector<Operation *> targets;
- if (op->getNumOperands() != 0)
+ SmallVector<SmallVector<MappedValue>> extraMappings;
+ if (op->getNumOperands() != 0) {
llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
- else
+ for (Value operand : op->getOperands().drop_front()) {
+ SmallVector<MappedValue> &mapped = extraMappings.emplace_back();
+ if (operand.getType().isa<TransformHandleTypeInterface>()) {
+ llvm::append_range(mapped, state.getPayloadOps(operand));
+ } else {
+ assert(operand.getType().isa<TransformParamTypeInterface>() &&
+ "unsupported kind of transform dialect value");
+ llvm::append_range(mapped, state.getParams(operand));
+ }
+ }
+ } else {
+ if (state.getNumTopLevelMappings() !=
+ region.front().getNumArguments() - 1) {
+ return emitError(op->getLoc())
+ << "operation expects " << region.front().getNumArguments() - 1
+ << " extra value bindings, but " << state.getNumTopLevelMappings()
+ << " were provided to the interpreter";
+ }
+
targets.push_back(state.getTopLevel());
+ for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
+ extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
+ }
+
+ if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
+ return failure();
+
+ for (BlockArgument argument : region.front().getArguments().drop_front()) {
+ if (failed(state.mapBlockArgument(
+ argument, extraMappings[argument.getArgNumber() - 1])))
+ return failure();
+ }
- return state.mapBlockArguments(region.front().getArgument(0), targets);
+ return success();
}
LogicalResult
@@ -547,19 +620,42 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
return op->emitOpError() << "expects a single-block region";
Block *body = &bodyRegion->front();
- if (body->getNumArguments() != 1 ||
- !body->getArgumentTypes()[0].isa<TransformHandleTypeInterface>()) {
+ if (body->getNumArguments() == 0) {
+ return op->emitOpError()
+ << "expects the entry block to have at least one argument";
+ }
+ if (!body->getArgument(0).getType().isa<TransformHandleTypeInterface>()) {
return op->emitOpError()
- << "expects the entry block to have one argument "
- "of type implementing TransformHandleTypeInterface";
+ << "expects the first entry block argument to be of type "
+ "implementing TransformHandleTypeInterface";
+ }
+ BlockArgument arg = body->getArgument(0);
+ if (op->getNumOperands() != 0) {
+ if (arg.getType() != op->getOperand(0).getType()) {
+ return op->emitOpError()
+ << "expects the type of the block argument to match "
+ "the type of the operand";
+ }
+ }
+ for (BlockArgument arg : body->getArguments().drop_front()) {
+ if (arg.getType()
+ .isa<TransformHandleTypeInterface, TransformParamTypeInterface>())
+ continue;
+
+ InFlightDiagnostic diag =
+ op->emitOpError()
+ << "expects trailing entry block arguments to be of type implementing "
+ "TransformHandleTypeInterface or TransformParamTypeInterface";
+ diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
+ return diag;
}
if (auto *parent =
op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
- if (op->getNumOperands() == 0) {
+ if (op->getNumOperands() != body->getNumArguments()) {
InFlightDiagnostic diag =
op->emitOpError()
- << "expects the root operation to be provided for a nested op";
+ << "expects operands to be provided for a nested op";
diag.attachNote(parent->getLoc())
<< "nested in another possible top-level op";
return diag;
@@ -717,9 +813,11 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
// Entry point.
//===----------------------------------------------------------------------===//
-LogicalResult transform::applyTransforms(Operation *payloadRoot,
- TransformOpInterface transform,
- const TransformOptions &options) {
+LogicalResult
+transform::applyTransforms(Operation *payloadRoot,
+ TransformOpInterface transform,
+ ArrayRef<ArrayRef<MappedValue>> extraMapping,
+ const TransformOptions &options) {
#ifndef NDEBUG
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
@@ -730,7 +828,8 @@ LogicalResult transform::applyTransforms(Operation *payloadRoot,
}
#endif // NDEBUG
- TransformState state(transform->getParentRegion(), payloadRoot, options);
+ TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
+ options);
return state.applyTransform(transform).checkAndReport();
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 8bd5cab2f460b..0314932fc994b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -26,6 +26,16 @@
using namespace mlir;
+static ParseResult parseSequenceOpOperands(
+ OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &root,
+ Type &rootType,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
+ SmallVectorImpl<Type> &extraBindingTypes);
+static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
+ Value root, Type rootType,
+ ValueRange extraBindings,
+ TypeRange extraBindingTypes);
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
@@ -654,6 +664,76 @@ transform::SequenceOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
+static ParseResult parseSequenceOpOperands(
+ OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &root,
+ Type &rootType,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
+ SmallVectorImpl<Type> &extraBindingTypes) {
+ OpAsmParser::UnresolvedOperand rootOperand;
+ OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
+ if (!hasRoot.has_value()) {
+ root = std::nullopt;
+ return success();
+ }
+ if (failed(hasRoot.value()))
+ return failure();
+ root = rootOperand;
+
+ if (succeeded(parser.parseOptionalComma())) {
+ if (failed(parser.parseOperandList(extraBindings)))
+ return failure();
+ }
+ if (failed(parser.parseColon()))
+ return failure();
+
+ // The paren is truly optional.
+ (void)parser.parseOptionalLParen();
+
+ if (failed(parser.parseType(rootType))) {
+ return failure();
+ }
+
+ if (!extraBindings.empty()) {
+ if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
+ return failure();
+ }
+
+ if (extraBindingTypes.size() != extraBindings.size()) {
+ return parser.emitError(parser.getNameLoc(),
+ "expected types to be provided for all operands");
+ }
+
+ // The paren is truly optional.
+ (void)parser.parseOptionalRParen();
+ return success();
+}
+
+static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
+ Value root, Type rootType,
+ ValueRange extraBindings,
+ TypeRange extraBindingTypes) {
+ if (!root)
+ return;
+
+ printer << root;
+ bool hasExtras = !extraBindings.empty();
+ if (hasExtras) {
+ printer << ", ";
+ printer.printOperands(extraBindings);
+ }
+
+ printer << " : ";
+ if (hasExtras)
+ printer << "(";
+
+ printer << rootType;
+ if (hasExtras) {
+ printer << ", ";
+ llvm::interleaveComma(extraBindingTypes, printer.getStream());
+ printer << ")";
+ }
+}
+
/// Returns `true` if the given op operand may be consuming the handle value in
/// the Transform IR. That is, if it may have a Free effect on it.
static bool isValueUsePotentialConsumer(OpOperand &use) {
@@ -691,22 +771,22 @@ checkDoubleConsume(Value value,
}
LogicalResult transform::SequenceOp::verify() {
- assert(getBodyBlock()->getNumArguments() == 1 &&
- "the number of arguments must have been verified to be 1 by "
+ assert(getBodyBlock()->getNumArguments() >= 1 &&
+ "the number of arguments must have been verified to be more than 1 by "
"PossibleTopLevelTransformOpTrait");
- BlockArgument arg = getBodyBlock()->getArgument(0);
- if (getRoot()) {
- if (arg.getType() != getRoot().getType()) {
- return emitOpError() << "expects the type of the block argument to match "
- "the type of the operand";
- }
+ if (!getRoot() && !getExtraBindings().empty()) {
+ return emitOpError()
+ << "does not expect extra operands when used as top-level";
}
- // Check if the block argument has more than one consuming use.
- if (failed(checkDoubleConsume(
- arg, [this]() { return (emitOpError() << "block argument #0"); }))) {
- return failure();
+ // Check if a block argument has more than one consuming use.
+ for (BlockArgument arg : getBodyBlock()->getArguments()) {
+ if (failed(checkDoubleConsume(arg, [this, arg]() {
+ return (emitOpError() << "block argument #" << arg.getArgNumber());
+ }))) {
+ return failure();
+ }
}
// Check properties of the nested operations they cannot check themselves.
@@ -740,26 +820,26 @@ LogicalResult transform::SequenceOp::verify() {
return success();
}
+/// Appends to `effects` the memory effect instances on `target` with the same
+/// resource and effect as the ones the operation `iface` has on `source`.
+static void
+remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ SmallVector<MemoryEffects::EffectInstance> nestedEffects;
+ iface.getEffectsOnValue(source, nestedEffects);
+ for (const auto &effect : nestedEffects)
+ effects.emplace_back(effect.getEffect(), target, effect.getResource());
+}
+
void transform::SequenceOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- auto *mappingResource = TransformMappingResource::get();
- effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
-
- for (Value result : getResults()) {
- effects.emplace_back(MemoryEffects::Allocate::get(), result,
- mappingResource);
- effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
- }
+ onlyReadsHandle(getRoot(), effects);
+ onlyReadsHandle(getExtraBindings(), effects);
+ producesHandle(getResults(), effects);
if (!getRoot()) {
for (Operation &op : *getBodyBlock()) {
- auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
- if (!iface) {
- // TODO: fill all possible effects; or require ops to actually implement
- // the memory effect interface always
- assert(false);
- }
-
+ auto iface = cast<MemoryEffectOpInterface>(&op);
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
iface.getEffects(effects);
}
@@ -769,24 +849,20 @@ void transform::SequenceOp::getEffects(
// Carry over all effects on the argument of the entry block as those on the
// operand, this is the same value just remapped.
for (Operation &op : *getBodyBlock()) {
- auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
- if (!iface) {
- // TODO: fill all possible effects; or require ops to actually implement
- // the memory effect interface always
- assert(false);
- }
+ auto iface = cast<MemoryEffectOpInterface>(&op);
- SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
- iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
- for (const auto &effect : nestedEffects)
- effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
+ remapEffects(iface, getBodyBlock()->getArgument(0), getRoot(), effects);
+ for (auto [source, target] : llvm::zip(
+ getBodyBlock()->getArguments().drop_front(), getExtraBindings())) {
+ remapEffects(iface, source, target, effects);
+ }
}
}
OperandRange transform::SequenceOp::getSuccessorEntryOperands(
std::optional<unsigned> index) {
assert(index && *index == 0 && "unexpected region index");
- if (getOperation()->getNumOperands() == 1)
+ if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
@@ -813,21 +889,51 @@ void transform::SequenceOp::getRegionInvocationBounds(
bounds.emplace_back(1, 1);
}
+template <typename FnTy>
+static void buildSequenceBody(OpBuilder &builder, OperationState &state,
+ Type bbArgType, TypeRange extraBindingTypes,
+ FnTy bodyBuilder) {
+ SmallVector<Type> types;
+ types.reserve(1 + extraBindingTypes.size());
+ types.push_back(bbArgType);
+ llvm::append_range(types, extraBindingTypes);
+
+ OpBuilder::InsertionGuard guard(builder);
+ Region *region = state.regions.back().get();
+ Block *bodyBlock = builder.createBlock(region, region->begin(),
+ extraBindingTypes, {state.location});
+
+ // Populate body.
+ builder.setInsertionPointToStart(bodyBlock);
+ if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
+ bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+ } else {
+ bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
+ bodyBlock->getArguments().drop_front());
+ }
+}
+
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes,
FailurePropagationMode failurePropagationMode,
Value root,
SequenceBodyBuilderFn bodyBuilder) {
- build(builder, state, resultTypes, failurePropagationMode, root);
- Region *region = state.regions.back().get();
+ build(builder, state, resultTypes, failurePropagationMode, root,
+ /*extraBindings=*/ValueRange());
Type bbArgType = root.getType();
- OpBuilder::InsertionGuard guard(builder);
- Block *bodyBlock = builder.createBlock(
- region, region->begin(), TypeRange{bbArgType}, {state.location});
+ buildSequenceBody(builder, state, bbArgType,
+ /*extraBindingTypes=*/TypeRange(), bodyBuilder);
+}
- // Populate body.
- builder.setInsertionPointToStart(bodyBlock);
- bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+ TypeRange resultTypes,
+ FailurePropagationMode failurePropagationMode,
+ Value root, ValueRange extraBindings,
+ SequenceBodyBuilderArgsFn bodyBuilder) {
+ build(builder, state, resultTypes, failurePropagationMode, root,
+ extraBindings);
+ buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
+ bodyBuilder);
}
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
@@ -835,15 +941,20 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
FailurePropagationMode failurePropagationMode,
Type bbArgType,
SequenceBodyBuilderFn bodyBuilder) {
- build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
- Region *region = state.regions.back().get();
- OpBuilder::InsertionGuard guard(builder);
- Block *bodyBlock = builder.createBlock(
- region, region->begin(), TypeRange{bbArgType}, {state.location});
+ build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
+ /*extraBindings=*/ValueRange());
+ buildSequenceBody(builder, state, bbArgType,
+ /*extraBindingTypes=*/TypeRange(), bodyBuilder);
+}
- // Populate body.
- builder.setInsertionPointToStart(bodyBlock);
- bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+ TypeRange resultTypes,
+ FailurePropagationMode failurePropagationMode,
+ Type bbArgType, TypeRange extraBindingTypes,
+ SequenceBodyBuilderArgsFn bodyBuilder) {
+ build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
+ /*extraBindings=*/ValueRange());
+ buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index 5cd57b050012a..593b8855c935f 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -89,7 +89,9 @@ def __init__(self,
class SequenceOp:
def __init__(self, failure_propagation_mode, results: Sequence[Type],
- target: Union[Operation, Value, Type]):
+ target: Union[Operation, Value, Type],
+ extra_bindings: Optional[Union[Sequence[Value], Sequence[Type],
+ Operation, OpView]] = None):
root = _get_op_result_or_value(target) if isinstance(
target, (Operation, Value)) else None
root_type = root.type if not isinstance(target, Type) else target
@@ -98,10 +100,25 @@ def __init__(self, failure_propagation_mode, results: Sequence[Type],
IntegerType.get_signless(32), failure_propagation_mode._as_int())
else:
failure_propagation_mode = failure_propagation_mode
+
+ if extra_bindings is None:
+ extra_bindings = []
+ if isinstance(extra_bindings, (Operation, OpView)):
+ extra_bindings = _get_op_results_or_values(extra_bindings)
+
+ extra_binding_types = []
+ if len(extra_bindings) != 0:
+ if isinstance(extra_bindings[0], Type):
+ extra_binding_types = extra_bindings
+ extra_bindings = []
+ else:
+ extra_binding_types = [v.type for v in extra_bindings]
+
super().__init__(results_=results,
failure_propagation_mode=failure_propagation_mode_attr,
- root=root)
- self.regions[0].blocks.append(root_type)
+ root=root,
+ extra_bindings=extra_bindings)
+ self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
@property
def body(self) -> Block:
@@ -111,6 +128,10 @@ def body(self) -> Block:
def bodyTarget(self) -> Value:
return self.body.arguments[0]
+ @property
+ def bodyExtraArgs(self) -> BlockArgumentList:
+ return self.body.arguments[1:]
+
class WithPDLPatternsOp:
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
new file mode 100644
index 0000000000000..447c6b4be9259
--- /dev/null
+++ b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \
+// RUN: --split-input-file --verify-diagnostics
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+ transform.test_print_remark_at_operand %arg1, "first extra" : !transform.any_op
+ transform.test_print_remark_at_operand %arg2, "second extra" : !transform.any_op
+}
+
+// expected-remark @below {{first extra}}
+func.func @foo() {
+ // expected-remark @below {{second extra}}
+ return
+}
+
+// expected-remark @below {{first extra}}
+func.func @bar(%arg0: i1) {
+ cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ // expected-remark @below {{second extra}}
+ return
+^bb2:
+ // expected-remark @below {{second extra}}
+ return
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
+ // expected-error @above {{wrong kind of value provided for top-level parameter}}
+}
+
+func.func @foo() {
+ return
+}
+
+// -----
+
+// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+ transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
+ ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+ transform.test_print_remark_at_operand %arg4, "first extra" : !transform.any_op
+ transform.test_print_remark_at_operand %arg5, "second extra" : !transform.any_op
+ }
+}
+
+// expected-remark @below {{first extra}}
+func.func @foo() {
+ // expected-remark @below {{second extra}}
+ return
+}
+
+// expected-remark @below {{first extra}}
+func.func @bar(%arg0: i1) {
+ cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ // expected-remark @below {{second extra}}
+ return
+^bb2:
+ // expected-remark @below {{second extra}}
+ return
+}
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
new file mode 100644
index 0000000000000..f5d7f8f595f1d
--- /dev/null
+++ b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \
+// RUN: --split-input-file --verify-diagnostics
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>):
+ // expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
+ transform.test_print_param %arg1 : !transform.param<i64>
+ // expected-remark @below {{42 : i64, 45 : i64}}
+ transform.test_print_param %arg2 : !transform.param<i64>
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
+ // expected-error @above {{wrong kind of value provided for top-level operation handle}}
+}
+
+// -----
+
+// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>, %arg3: !transform.param<i64>):
+}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index e957d7a26e575..2fd0a37c86b85 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -1,15 +1,22 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
+// expected-error @below {{expects the entry block to have at least one argument}}
transform.sequence failures(propagate) {
}
// -----
+// expected-error @below {{expects the first entry block argument to be of type implementing TransformHandleTypeInterface}}
+transform.sequence failures(propagate) {
+^bb0(%rag0: i64):
+}
+
+// -----
+
// expected-note @below {{nested in another possible top-level op}}
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
- // expected-error @below {{expects the root operation to be provided for a nested op}}
+ // expected-error @below {{expects operands to be provided for a nested op}}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
}
@@ -17,6 +24,14 @@ transform.sequence failures(propagate) {
// -----
+// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface or TransformParamTypeInterface}}
+// expected-note @below {{argument #1 does not}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: i64):
+}
+
+// -----
+
// expected-error @below {{expected children ops to implement TransformOpInterface}}
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
@@ -46,10 +61,29 @@ transform.sequence failures(propagate) {
// -----
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+ // expected-error @below {{expected types to be provided for all operands}}
+ transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op) failures(propagate) {
+ ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+ }
+}
+
+// -----
+
+%0 = "test.generate_something"() : () -> !transform.any_op
+// expected-error @below {{does not expect extra operands when used as top-level}}
+"transform.sequence"(%0) ({
+^bb0(%arg0: !transform.any_op):
+ "transform.yield"() : () -> ()
+}) {failure_propagation_mode = 1 : i32, operand_segment_sizes = array<i32: 0, 1>} : (!transform.any_op) -> ()
+
+// -----
+
// expected-note @below {{nested in another possible top-level op}}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
- // expected-error @below {{expects the root operation to be provided for a nested op}}
+ // expected-error @below {{expects operands to be provided for a nested op}}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
}
@@ -190,7 +224,7 @@ transform.sequence failures(propagate) {
// -----
-// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
+// expected-error @below {{expects the entry block to have at least one argument}}
transform.alternatives {
^bb0:
transform.yield
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index 0d27f92485560..73171a8f8cd03 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -50,6 +50,33 @@ transform.sequence failures(propagate) {
}
}
+// CHECK: transform.sequence failures(propagate)
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+ // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
+ transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
+ ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+ }
+}
+
+// CHECK: transform.sequence failures(propagate)
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+ // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
+ transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) {
+ ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+ }
+}
+
+// CHECK: transform.sequence failures(propagate)
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+ // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
+ transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) {
+ ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+ }
+}
+
// CHECK: transform.sequence
// CHECK: foreach
transform.sequence failures(propagate) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index 1696cae0b4467..7d049eb98be51 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
@@ -39,12 +40,72 @@ class TestTransformDialectInterpreterPass
return "apply transform dialect operations one by one";
}
+ ArrayRef<transform::MappedValue>
+ findOperationsByName(Operation *root, StringRef name,
+ SmallVectorImpl<transform::MappedValue> &storage) {
+ size_t start = storage.size();
+ root->walk([&](Operation *op) {
+ if (op->getName().getStringRef() == name) {
+ storage.push_back(op);
+ }
+ });
+ return ArrayRef(storage).drop_front(start);
+ }
+
+ ArrayRef<transform::MappedValue>
+ createParameterMapping(MLIRContext &context, ArrayRef<int> values,
+ SmallVectorImpl<transform::MappedValue> &storage) {
+ size_t start = storage.size();
+ llvm::append_range(storage, llvm::map_range(values, [&](int v) {
+ Builder b(&context);
+ return transform::MappedValue(b.getI64IntegerAttr(v));
+ }));
+ return ArrayRef(storage).drop_front(start);
+ }
+
void runOnOperation() override {
+ if (!bindFirstExtraToOps.empty() && !bindFirstExtraToParams.empty()) {
+ emitError(UnknownLoc::get(&getContext()))
+ << "cannot bind the first extra top-level argument to both "
+ "operations and parameters";
+ return signalPassFailure();
+ }
+ if (!bindSecondExtraToOps.empty() && !bindSecondExtraToParams.empty()) {
+ emitError(UnknownLoc::get(&getContext()))
+ << "cannot bind the second extra top-level argument to both "
+ "operations and parameters";
+ return signalPassFailure();
+ }
+ if ((!bindSecondExtraToOps.empty() || !bindSecondExtraToParams.empty()) &&
+ bindFirstExtraToOps.empty() && bindFirstExtraToParams.empty()) {
+ emitError(UnknownLoc::get(&getContext()))
+ << "cannot bind the second extra top-level argument without binding "
+ "the first";
+ return signalPassFailure();
+ }
+
+ SmallVector<transform::MappedValue> extraMappingStorage;
+ SmallVector<ArrayRef<transform::MappedValue>> extraMapping;
+ if (!bindFirstExtraToOps.empty()) {
+ extraMapping.push_back(findOperationsByName(
+ getOperation(), bindFirstExtraToOps.getValue(), extraMappingStorage));
+ } else if (!bindFirstExtraToParams.empty()) {
+ extraMapping.push_back(createParameterMapping(
+ getContext(), bindFirstExtraToParams, extraMappingStorage));
+ }
+ if (!bindSecondExtraToOps.empty()) {
+ extraMapping.push_back(findOperationsByName(
+ getOperation(), bindSecondExtraToOps, extraMappingStorage));
+ } else if (!bindSecondExtraToParams.empty()) {
+ extraMapping.push_back(createParameterMapping(
+ getContext(), bindSecondExtraToParams, extraMappingStorage));
+ }
+
ModuleOp module = getOperation();
for (auto op :
module.getBody()->getOps<transform::TransformOpInterface>()) {
if (failed(transform::applyTransforms(
- module, op,
+ module, op, extraMapping,
transform::TransformOptions().enableExpensiveChecks(
enableExpensiveChecks))))
return signalPassFailure();
@@ -55,6 +116,24 @@ class TestTransformDialectInterpreterPass
*this, "enable-expensive-checks", llvm::cl::init(false),
llvm::cl::desc("perform expensive checks to better report errors in the "
"transform IR")};
+
+ Option<std::string> bindFirstExtraToOps{
+ *this, "bind-first-extra-to-ops",
+ llvm::cl::desc("bind the first extra argument of the top-level op to "
+ "payload operations of the given kind")};
+ ListOption<int> bindFirstExtraToParams{
+ *this, "bind-first-extra-to-params",
+ llvm::cl::desc("bind the first extra argument of the top-level op to "
+ "the given integer parameters")};
+
+ Option<std::string> bindSecondExtraToOps{
+ *this, "bind-second-extra-to-ops",
+ llvm::cl::desc("bind the second extra argument of the top-level op to "
+ "payload operations of the given kind")};
+ ListOption<int> bindSecondExtraToParams{
+ *this, "bind-second-extra-to-params",
+ llvm::cl::desc("bind the second extra argument of the top-level op to "
+ "the given integer parameters")};
};
struct TestTransformDialectEraseSchedulePass
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index c2ee6c1976037..ed6b68edcece6 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -69,6 +69,38 @@ def testNestedSequenceOp():
# CHECK: }
+@run
+def testSequenceOpWithExtras():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
+ [transform.AnyOpType.get(),
+ transform.OperationType.get("foo.bar")])
+ with InsertionPoint(sequence.body):
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+
+
+@run
+def testNestedSequenceOpWithExtras():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
+ [transform.AnyOpType.get(),
+ transform.OperationType.get("foo.bar")])
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+ [], sequence.bodyTarget,
+ sequence.bodyExtraArgs)
+ with InsertionPoint(nested.body):
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+ # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+
+
@run
def testTransformPDLOps():
withPdl = transform.WithPDLPatternsOp(pdl.OperationType.get())
More information about the Mlir-commits
mailing list