[Mlir-commits] [mlir] ed02fa8 - [mlir] introduce parameters into the transform dialect
Alex Zinenko
llvmlistbot at llvm.org
Fri Jan 6 04:23:48 PST 2023
Author: Alex Zinenko
Date: 2023-01-06T12:23:29Z
New Revision: ed02fa81fd11b6ac048a574691ecd2b9d74ec274
URL: https://github.com/llvm/llvm-project/commit/ed02fa81fd11b6ac048a574691ecd2b9d74ec274
DIFF: https://github.com/llvm/llvm-project/commit/ed02fa81fd11b6ac048a574691ecd2b9d74ec274.diff
LOG: [mlir] introduce parameters into the transform dialect
Introduce a new kind of value into the transform dialect -- parameter
values. These values have a type implementing the new
`TransformParamTypeInterface` and are associated with lists of
attributes rather than lists of payload operations. This mechanism
allows one to wrap numeric calculations, typically heuristics, into
transform operations separate from those actually applying the
transformation. For example, tile size computation can now be separated
from tiling itself instead of being hardcoded in the transform dialect. This
further improves the separation of concerns between transform choice and
implementation.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D140976
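The separation of concerns described in this log message can be pictured with a small, self-contained C++ sketch that deliberately does not use MLIR. Everything in it is hypothetical: `Loop` stands in for a payload loop, `computeTrailingTileSize` plays the role of a parameter-producing heuristic, and `tileLoopNest` plays the role of a transformation that only consumes the sizes it is given.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a payload loop; not an MLIR type.
struct Loop {
  int64_t tripCount;
};

// Hypothetical heuristic: picks the largest candidate size that evenly divides
// the trip count, falling back to 1. It performs no transformation.
int64_t computeTrailingTileSize(const Loop &loop) {
  const int64_t candidates[] = {32, 16, 8, 4, 2};
  for (int64_t candidate : candidates)
    if (loop.tripCount % candidate == 0)
      return candidate;
  return 1;
}

// Hypothetical transformation: consumes one tile size per loop and returns the
// number of tiles per loop (ceiling division). It contains no heuristic.
std::vector<int64_t> tileLoopNest(const std::vector<Loop> &nest,
                                  const std::vector<int64_t> &tileSizes) {
  assert(nest.size() == tileSizes.size() && "expected one tile size per loop");
  std::vector<int64_t> numTiles;
  for (std::size_t i = 0; i < nest.size(); ++i) {
    assert(tileSizes[i] > 0 && "tile sizes must be positive");
    numTiles.push_back((nest[i].tripCount + tileSizes[i] - 1) / tileSizes[i]);
  }
  return numTiles;
}
```

With this split, `computeTrailingTileSize(Loop{48})` yields 16, and feeding `{1, 4, 16}` to `tileLoopNest` for trip counts 10, 12 and 48 yields 10, 3 and 3 tiles. Replacing the heuristic does not touch the transformation, which is the property the commit aims for at the transform IR level.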
Added:
Modified:
mlir/docs/Dialects/Transform.md
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md
index 70ea6e9498fec..a74ae5de73d33 100644
--- a/mlir/docs/Dialects/Transform.md
+++ b/mlir/docs/Dialects/Transform.md
@@ -42,47 +42,61 @@ may look like:
```mlir
%0 = transform.loop.find { size > 42 } : !transform.interface<tileable>
-%1:2 = transform.loop.tile %0 { tile_sizes = [2,3,4] }
+%1 = transform.compute_trailing_tile_size %0 : !transform.param<index>
+%2:2 = transform.loop.tile %0 tile_sizes(1, 4, %1)
: (!transform.interface<tileable>)
- -> (!transform.op<loop>, !transform.op<loop>)
+ -> (!transform.op<loop>, !transform.op<loop>)
transform.loop.unroll %1#1 : !transform.op<loop>
```
-The values used in the Transform dialect, also referred to as *handles*,
-correspond to (groups of) operations in the payload IR. In the example
+The values used in the Transform dialect may correspond to either:
+
+ * sets of operations in the payload IR;
+
+ * sets of parameters (attributes) known at the execution time of the
+ transform dialect.
+
+The former kind of values is also referred to as *handles*. In the example
above, `%0` corresponds to the set of loops found in the payload IR that
-satisfy the condition, and `%1` correspond to groups of outer and inner
-loops, respectively, produced by the tiling transformation.
+satisfy the condition, and `%2` correspond to groups of outer and inner
+loops, respectively, produced by the tiling transformation, whereas `%1`
+corresponds to a list of tile sizes selected for each of the operations
+that `%0` corresponds to.
A transform handle such as `%0` may be associated with multiple payload
-operations. This is conceptually a set of operations and no assumptions
-should be made about the order of ops unless specified otherwise by the
-operation. Most Transform IR ops support operand values that are mapped to
-multiple operations. They usually apply the respective transformation for
-every mapped op ("batched execution"). Deviations from this convention are
-described in the documentation of Transform IR ops.
-
-The handle values have transform IR types. These types describe properties
-of payload IR operations associated with the value that are known to the
-transform dialect, for example, all associated payload operations implement
-a "TileableOp" interface, or have a specific "loop" kind. These properties
-are used to statically indicate pre- and post-conditions of a
-transformation connected to a Transform dialect operation. The conditions
-are verified when payload IR operations are first associated with a
-transform handle. By convention, Transform dialect operations are expected
-to indicate narrow preconditions for their operands by enforcing operand
-type constraints in the their definitions and verifiers. On the contrary,
-operations are expected to have few constraints on their results. Specific
-instances of a transform operation can then be created with a more
-restricted result type than the constraint in the operation (e.g., the
-"find" operation only constrains the result type to be a transform IR type
-while its concrete instance can have a type with stricter constraints such
-as implementing the "tilable" interface). The verification will then happen
-at transform execution time. This approach allows one to capture payload IR
-operation properties in the transform IR without resorting to excessive
-use of type casts or coupling dialect extensions between themselves. It is
-a trade-off between verbosity/complexity and static hardening, which can
-be revised in the future.
+operations. This is conceptually a set of operations and no assumptions should
+be made about the order of ops unless specified otherwise by the operation.
+Operations may take as operands, and produce as results, an arbitrary
+combination of handles and parameters. Most Transform IR ops support operand
+values that are mapped to multiple operations. They usually apply the respective
+transformation for every mapped op ("batched execution"). Deviations from this
+convention are described in the documentation of Transform IR ops.
+
+The transform IR values have transform IR types, which implement either
+[TransformTypeInterface](Transform.md#transformtypeinterface-transformtypeinterface)
+or
+[TransformParamTypeInterface](Transform.md#transformparamtypeinterface-transformparamtypeinterface).
+The former interface verifies properties of payload IR operations associated
+with the value that are known to the transform dialect, for example, all
+associated payload operations implement a "TileableOp" interface, or have a
+specific "loop" kind. Similarly, the latter interface verifies properties of
+attributes associated with the parameter value. These properties are used to
+statically indicate pre- and post-conditions of a transformation connected to a
+Transform dialect operation. The conditions are verified when attributes or
+payload IR operations are first associated with a transform IR value. By
+convention, Transform dialect operations are expected to indicate narrow
+preconditions for their operands by enforcing operand type constraints in
+their definitions and verifiers. Conversely, operations are expected to
+have few constraints on their results. Specific instances of a transform
+operation can then be created with a more restricted result type than the
+constraint in the operation (e.g., the "find" operation only constrains the
+result type to be a transform IR type while its concrete instance can have a
+type with stricter constraints such as implementing the "tileable" interface).
+The verification will then happen at transform execution time. This approach
+allows one to capture payload IR operation properties in the transform IR
+without resorting to excessive use of type casts or coupling dialect extensions
+to each other. It is a trade-off between verbosity/complexity and static
+hardening, which can be revised in the future.
Overall, Transform IR ops are expected to be contained in a single top-level
op. Such top-level ops specify how to apply the transformations described
@@ -96,8 +110,8 @@ programmatically triggered by calling:
```c++
LogicalResult transform::applyTransforms(Operation *payloadRoot,
- TransformOpInterface transform,
- const TransformOptions &options);
+ TransformOpInterface transform,
+ const TransformOptions &options);
```
that applies the transformations specified by the top-level `transform` to
@@ -139,6 +153,12 @@ The presence of interface implementations is checked at runtime when the
dialect is loaded to allow for those implementations to be supplied by
separate dialect extensions if desired.
+Similarly to operations, additional types can be injected into the dialect using
+the same extension mechanism. The types must:
+
+ * Implement exactly one of `TransformTypeInterface`,
+ `TransformParamTypeInterface`.
+
## Side Effects
The Transform dialect relies on MLIR side effect modelling to enable
@@ -250,6 +270,8 @@ must be enabled explicitly through `TransformOptions`. Additionally, the
after it has been consumed, but does so abstractly, without processing the
payload IR.
+Values associated with parameters (non-handles) cannot be invalidated.
+
## Intended Use and Integrations
The transformation control infrastructure provided by this dialect is
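The distinction between handles and parameter values, and the rule that parameters cannot be invalidated, can be modeled with the following toy C++ state. It is a sketch under simplifying assumptions, not the `TransformState` implementation: values and payload operations are plain integer ids, parameters are `long long` instead of `Attribute`, and erasing a payload operation simply marks the handles found through the reverse mapping as invalid.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Hypothetical stand-ins for Value, Operation * and Attribute.
using ValueId = int;
using OpId = int;
using Param = long long;

// Toy state: direct and reverse mappings for handles, plus a separate,
// direct-only mapping for parameters.
class ToyTransformState {
public:
  void setPayloadOps(ValueId v, const std::vector<OpId> &ops) {
    assert(!params.count(v) && "value is a parameter, not a handle");
    bool inserted = direct.emplace(v, ops).second;
    assert(inserted && "value is already associated with payload ops");
    (void)inserted;
    for (OpId op : ops)
      reverse[op].push_back(v);
  }

  void setParams(ValueId v, const std::vector<Param> &ps) {
    assert(!direct.count(v) && "value is a handle, not a parameter");
    bool inserted = params.emplace(v, ps).second;
    assert(inserted && "value is already associated with params");
    (void)inserted;
  }

  const std::vector<OpId> &getPayloadOps(ValueId v) const {
    auto it = direct.find(v);
    assert(it != direct.end() && "cannot find mapping for payload handle");
    return it->second;
  }

  const std::vector<Param> &getParams(ValueId v) const {
    auto it = params.find(v);
    assert(it != params.end() && "cannot find mapping for param value");
    return it->second;
  }

  // Marks every handle associated with `op` as invalid, using the reverse
  // mapping. Parameters are never affected: they do not refer to payload IR.
  void eraseOp(OpId op) {
    auto it = reverse.find(op);
    if (it == reverse.end())
      return;
    for (ValueId v : it->second)
      invalidated.insert(v);
    reverse.erase(it);
  }

  bool isValidHandle(ValueId v) const {
    return direct.count(v) != 0 && invalidated.count(v) == 0;
  }

private:
  std::map<ValueId, std::vector<OpId>> direct;
  std::map<OpId, std::vector<ValueId>> reverse;
  std::map<ValueId, std::vector<Param>> params;
  std::set<ValueId> invalidated;
};
```

After `setPayloadOps(0, {10, 11})`, `setParams(2, {4})` and `eraseOp(11)`, handle 0 is reported as invalid while `getParams(2)` still returns `{4}`. Keeping parameters in a separate map without a reverse mapping is what makes them immune to payload erasure in this model.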
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 99e12a1067dd8..bfb6879e5d8b3 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -312,45 +312,58 @@ applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
/// surrounding structure are referred to as transform IR. The operations to
-/// which transformations apply are referred to as payload IR. The state thus
-/// contains the many-to-many mapping between values defined in the transform IR
-/// ops and payload IR ops. The "expensive-checks" option can be passed to
-/// the constructor at transformation execution time that transform IR values
-/// used as operands by a transform IR operation are not associated with
-/// dangling pointers to payload IR operations that are known to have been
-/// erased by previous transformation through the same or a different transform
-/// IR value.
+/// which transformations apply are referred to as payload IR. Transform IR
+/// operates on values that can be associated either with a list of payload IR
+/// operations (such values are referred to as handles) or with a list of
+/// parameters represented as attributes. The state thus contains the mapping
+/// between values defined in the transform IR ops and either payload IR ops or
+/// parameters. For payload ops, the mapping is many-to-many and the reverse
+/// mapping is also stored. The "expensive-checks" option can be passed to the
+/// constructor at transformation execution time to check that transform IR
+/// values used as operands by a transform IR operation are not associated with
+/// dangling pointers to payload IR operations that are known to have been
+/// erased by previous transformations through the same or a different
+/// transform IR value.
///
/// A reference to this class is passed as an argument to "apply" methods of the
-/// transform op interface. Thus the "apply" method can call
+/// transform op interface. Thus the "apply" method can call either
/// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
-/// associated with its operand and subject to transformation. The method is
-/// expected to populate the `TransformResults` class instance in order to
-/// update the mapping. The `applyTransform` method takes care of propagating
-/// the state of `TransformResults` into the instance of this class.
+/// or `state.getParams( getSomeOperand() )` to obtain the list of parameters
+/// associated with its operand. The method is expected to populate the
+/// `TransformResults` class instance in order to update the mapping. The
+/// `applyTransform` method takes care of propagating the state of
+/// `TransformResults` into the instance of this class.
///
/// When applying transform IR operations with regions, the client is expected
-/// to create a RegionScope RAII object to create a new "stack frame" for
+/// to create a `RegionScope` RAII object to create a new "stack frame" for
/// values defined inside the region. The mappings from and to these values will
/// be automatically dropped when the object goes out of scope, typically at the
-/// end of the "apply" function of the parent operation. If a region contains
+/// end of the `apply` function of the parent operation. If a region contains
/// blocks with arguments, the client can map those arguments to payload IR ops
-/// using "mapBlockArguments".
+/// using `mapBlockArguments`.
class TransformState {
+public:
+ using Param = Attribute;
+
+private:
/// Mapping between a Value in the transform IR and the corresponding set of
/// operations in the payload IR.
- using TransformOpMapping = DenseMap<Value, SmallVector<Operation *>>;
+ using TransformOpMapping = DenseMap<Value, SmallVector<Operation *, 2>>;
/// Mapping between a payload IR operation and the transform IR values it is
/// associated with.
using TransformOpReverseMapping =
DenseMap<Operation *, SmallVector<Value, 2>>;
- /// Bidirectional mappings between transform IR values and payload IR
- /// operations.
+ /// Mapping between a Value in the transform IR and the corresponding list of
+ /// parameters.
+ using ParamMapping = DenseMap<Value, SmallVector<Param>>;
+
+ /// The bidirectional mappings between transform IR values and payload IR
+ /// operations, and the mapping between transform IR values and parameters.
struct Mappings {
TransformOpMapping direct;
TransformOpReverseMapping reverse;
+ ParamMapping params;
};
friend LogicalResult applyTransforms(Operation *payloadRoot,
@@ -366,6 +379,10 @@ class TransformState {
/// This is helpful for transformations that apply to a particular handle.
ArrayRef<Operation *> getPayloadOps(Value value) const;
+ /// Returns the list of parameters that the given transform IR value
+ /// corresponds to.
+ ArrayRef<Attribute> getParams(Value value) const;
+
/// Populates `handles` with all handles pointing to the given Payload IR op.
/// Returns success if such handles exist, failure otherwise.
LogicalResult getHandlesForPayloadOp(Operation *op,
@@ -590,9 +607,16 @@ class TransformState {
/// that the associated payload operation may no longer exist.
///
/// Returns failure if the payload does not satisfy the conditions associated
- /// with the type of the handle value.
+ /// with the type of the handle value. The value is expected to have a type
+ /// implementing TransformTypeInterface.
LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
+ /// Sets the parameters associated with the given transform IR value. Returns
+ /// failure if the parameters do not satisfy the conditions associated with
+ /// the type of the value. The value is expected to have a type implementing
+ /// TransformParamTypeInterface.
+ LogicalResult setParams(Value value, ArrayRef<Param> params);
+
/// Forgets the payload IR ops associated with the given transform IR value.
void removePayloadOps(Value value);
@@ -661,26 +685,56 @@ class TransformResults {
public:
/// Indicates that the result of the transform IR op at the given position
/// corresponds to the given list of payload IR ops. Each result must be set
- /// by the transformation exactly once.
+ /// by the transformation exactly once. The value must have a type
+ /// implementing TransformTypeInterface.
void set(OpResult value, ArrayRef<Operation *> ops);
+ /// Indicates that the result of the transform IR op at the given position
+ /// corresponds to the given list of parameters. Each result must be set by
+ /// the transformation exactly once. The value must have a type implementing
+ /// TransformParamTypeInterface.
+ void setParams(OpResult value, ArrayRef<TransformState::Param> params);
+
private:
/// Creates an instance of TransformResults that expects mappings for
- /// `numSegments` values.
+ /// `numSegments` values, which may be associated with payload operations or
+ /// parameters.
explicit TransformResults(unsigned numSegments);
/// Gets the list of operations associated with the result identified by its
- /// number in the list of operation results.
+ /// number in the list of operation results. The result must have been set to
+ /// be associated with payload IR operations.
ArrayRef<Operation *> get(unsigned resultNumber) const;
+ /// Gets the list of parameters associated with the result identified by its
+ /// number in the list of operation results. The result must have been set to
+ /// be associated with parameters.
+ ArrayRef<TransformState::Param> getParams(unsigned resultNumber) const;
+
+ /// Returns `true` if the result identified by its number in the list of
+ /// operation results is associated with a list of parameters, `false` if it
+ /// is associated with the list of payload IR operations.
+ bool isParam(unsigned resultNumber) const;
+
/// Storage for pointers to payload IR ops that are associated with results of
/// a transform IR op. `segments` contains as many entries as the transform IR
- /// op has results. Each entry is a reference to a contiguous segment in
- /// the `operations` list that contains the pointers to operations. This
- /// allows for operations to be stored contiguously without nested vectors and
- /// for different segments to be set in any order.
+ /// op has results, even if some of them are not associated with payload IR
+ /// operations. Each entry is a reference to a contiguous segment in the
+ /// `operations` list that contains the pointers to operations. This allows
+ /// for operations to be stored contiguously without nested vectors and for
+ /// different segments to be set in any order.
SmallVector<ArrayRef<Operation *>, 2> segments;
SmallVector<Operation *> operations;
+
+ /// Storage for parameters that are associated with results of the transform
+ /// IR op. `paramSegments` contains as many entries as the transform IR op has
+ /// results, even if some of them are not associated with parameters. Each
+ /// entry is a reference to a contiguous segment in the `params` list that
+ /// contains the actual parameters. This allows for parameters to be stored
+ /// contiguously without nested vectors and for different segments to be set
+ /// in any order.
+ SmallVector<ArrayRef<TransformState::Param>, 2> paramSegments;
+ SmallVector<TransformState::Param> params;
};
TransformState::RegionScope TransformState::make_region_scope(Region &region) {
@@ -895,6 +949,39 @@ class NavigationTransformOpTrait
}
};
+namespace detail {
+/// Non-template implementation of ParamProducerTransformOpTrait::getEffects().
+void getParamProducerTransformOpTraitEffects(
+ Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+/// Non-template implementation of ParamProducerTransformOpTrait::verify().
+LogicalResult verifyParamProducerTransformOpTrait(Operation *op);
+} // namespace detail
+
+/// Trait implementing the MemoryEffectOpInterface for operations that produce
+/// transform dialect parameters. It marks all op results as produced by the op,
+/// all operands as only read by the op and, if at least one of the operands is
+/// a handle to payload ops, the entire payload as potentially read. The op must
+/// only produce results of a type implementing TransformParamTypeInterface.
+template <typename OpTy>
+class ParamProducerTransformOpTrait
+ : public OpTrait::TraitBase<OpTy, ParamProducerTransformOpTrait> {
+public:
+ /// Populates `effects` with effect instances described in the trait
+ /// documentation.
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ detail::getParamProducerTransformOpTraitEffects(this->getOperation(),
+ effects);
+ }
+
+ /// Checks that the op matches the expectation of this trait, i.e., that it
+ /// implements the MemoryEffectOpInterface and only produces parameter-typed
+ /// results.
+ static LogicalResult verifyTrait(Operation *op) {
+ return detail::verifyParamProducerTransformOpTrait(op);
+ }
+};
+
} // namespace transform
} // namespace mlir
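The storage layout described for `TransformResults` (one flat list per kind of object plus one segment per result) can be sketched in standalone C++ as follows. `SegmentedStorage` and `ToyResults` are hypothetical names; payload operations are modeled as `int` and parameters as `long long`. Unlike the diff, which stores `ArrayRef` views and uses a null `data()` pointer to mean "unset", the sketch stores offsets and an explicit flag, so earlier segments remain valid even if a later append reallocates the flat vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One flat vector holds the objects of all results; each result records an
// (offset, size) pair into it, so results can be set in any order.
template <typename T>
class SegmentedStorage {
public:
  explicit SegmentedStorage(std::size_t numSegments) : segments(numSegments) {}

  void set(std::size_t position, const std::vector<T> &values) {
    assert(position < segments.size() && "setting a non-existent result");
    assert(!segments[position].isSet && "result already set");
    segments[position] = Segment{storage.size(), values.size(), true};
    storage.insert(storage.end(), values.begin(), values.end());
  }

  bool isSet(std::size_t position) const {
    assert(position < segments.size() && "querying a non-existent result");
    return segments[position].isSet;
  }

  std::vector<T> get(std::size_t position) const {
    assert(isSet(position) && "querying an unset result");
    const Segment &s = segments[position];
    auto first = storage.begin() + static_cast<std::ptrdiff_t>(s.offset);
    return std::vector<T>(first, first + static_cast<std::ptrdiff_t>(s.size));
  }

private:
  struct Segment {
    std::size_t offset = 0;
    std::size_t size = 0;
    bool isSet = false;
  };
  std::vector<Segment> segments;
  std::vector<T> storage;
};

// Hypothetical analogue of TransformResults: both storages have one segment
// per op result, even though each result is populated in only one of them.
class ToyResults {
public:
  explicit ToyResults(std::size_t numResults)
      : ops(numResults), params(numResults) {}
  void set(std::size_t pos, const std::vector<int> &v) { ops.set(pos, v); }
  void setParams(std::size_t pos, const std::vector<long long> &v) {
    params.set(pos, v);
  }
  std::vector<int> get(std::size_t pos) const { return ops.get(pos); }
  std::vector<long long> getParams(std::size_t pos) const {
    return params.get(pos);
  }
  bool isParam(std::size_t pos) const { return params.isSet(pos); }

private:
  SegmentedStorage<int> ops;
  SegmentedStorage<long long> params;
};
```

Results may be set in any order: setting result 2 before result 0 still lets `get(0)` return exactly the three operations given for it, and `isParam` only reports `true` for results populated through `setParams`, mirroring `TransformResults::isParam`.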
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index c760cb12598bc..319af25f0c6af 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -103,27 +103,21 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
}];
}
-def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
- let description = [{
- Types that can be used for Transform dialect handle values. Such types
- define the properties of Payload IR operations associated with the handle.
- A user of such a handle can assume that these properties have been verified
- for any Payload IR operation associated with it.
- }];
-
+class TransformTypeInterfaceBase<string cppClass, string cppObjectType>
+ : TypeInterface<cppClass> {
let cppNamespace = "::mlir::transform";
let methods = [
InterfaceMethod<
/*desc=*/[{
- Checks if the given list of associated Payload IR operations satisfy
- the conditions defined by this type. If not, produces a silenceable
+ Checks if the given associated objects (Payload IR operations or attributes)
+ satisfy the conditions defined by this type. If not, produces a silenceable
error at the specified location.
}],
/*returnType=*/"::mlir::DiagnosedSilenceableFailure",
/*name=*/"checkPayload",
/*arguments=*/(ins "::mlir::Location":$loc,
- "::mlir::ArrayRef<::mlir::Operation *>":$payload)
+ "::mlir::ArrayRef<" # cppObjectType # ">":$payload)
>
];
@@ -135,6 +129,29 @@ def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
}];
}
+def TransformTypeInterface
+ : TransformTypeInterfaceBase<"TransformTypeInterface",
+ "::mlir::Operation *"> {
+ let description = [{
+ Types that can be used for the Transform dialect handle values. Such types
+ define the properties of Payload IR operations associated with the handle.
+ A user of such a handle can assume that these properties have been verified
+ for any Payload IR operation associated with it.
+ }];
+}
+
+def TransformParamTypeInterface
+ : TransformTypeInterfaceBase<"TransformParamTypeInterface",
+ "::mlir::Attribute"> {
+ let description = [{
+ Types that can be used for the Transform dialect parameter values. Such types
+ define the structure of the parameters associated with the value, e.g., their
+ underlying type. A user of the value can assume that the parameters have been
+ verified.
+ }];
+
+}
+
def FunctionalStyleTransformOpTrait
: NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
let cppNamespace = "::mlir::transform";
@@ -148,4 +165,8 @@ def NavigationTransformOpTrait : NativeOpTrait<"NavigationTransformOpTrait"> {
let cppNamespace = "::mlir::transform";
}
+def ParamProducerTransformOpTrait : NativeOpTrait<"ParamProducerTransformOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
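The new TableGen class `TransformTypeInterfaceBase<cppClass, cppObjectType>` declares `checkPayload` once and instantiates it for `::mlir::Operation *` and `::mlir::Attribute`. A rough C++ analogue is a class template parameterized by the kind of associated object. In the sketch below, `ToyOp`, `ToyAttr`, the error-string return (in place of `DiagnosedSilenceableFailure`) and both concrete types are hypothetical; `ToyOperationType` only imitates, in spirit, a handle type that constrains the operation name.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for payload operations and attributes.
struct ToyOp {
  std::string name;
};
struct ToyAttr {
  long long value;
};

// One interface declaration, parameterized by the associated object kind.
// Returns an empty string on success and an error message otherwise.
template <typename ObjectT>
class ToyTypeInterfaceBase {
public:
  virtual ~ToyTypeInterfaceBase() = default;
  virtual std::string checkPayload(const std::vector<ObjectT> &payload) const = 0;
};

using ToyHandleTypeInterface = ToyTypeInterfaceBase<ToyOp>;  // handles
using ToyParamTypeInterface = ToyTypeInterfaceBase<ToyAttr>; // parameters

// Handle type: every associated op must have the given name.
class ToyOperationType : public ToyHandleTypeInterface {
public:
  explicit ToyOperationType(std::string name) : opName(std::move(name)) {}
  std::string checkPayload(const std::vector<ToyOp> &payload) const override {
    for (const ToyOp &op : payload)
      if (op.name != opName)
        return "incompatible payload operation name: " + op.name;
    return "";
  }

private:
  std::string opName;
};

// Parameter type: every associated attribute must hold a non-negative value.
class ToyNonNegativeParamType : public ToyParamTypeInterface {
public:
  std::string checkPayload(const std::vector<ToyAttr> &payload) const override {
    for (const ToyAttr &attr : payload)
      if (attr.value < 0)
        return "expected a non-negative parameter";
    return "";
  }
};
```

Both interfaces share one declaration of the check, while each concrete type decides what a valid payload means for its kind of object.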
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
index ec440d33e6e86..b7f39fae002b0 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
@@ -36,6 +36,22 @@ def Transform_OperationType : TypeDef<Transform_Dialect, "Operation",
let assemblyFormat = "`<` $operation_name `>`";
}
+def Transform_ParamType : TypeDef<Transform_Dialect, "Param",
+ [DeclareTypeInterfaceMethods<TransformParamTypeInterface>]> {
+ let description = [{
+ Transform IR value that can be associated with the list of parameters
+ of the given type. Types are currently limited to integers, but may be
+ extended in the future to other types whose values can be contained
+ in attributes.
+ }];
+ let mnemonic = "param";
+ let parameters = (ins
+ TypeParameter<"::mlir::Type", "Underlying type of the parameter">:$type
+ );
+ let assemblyFormat = "`<` $type `>`";
+ let genVerifyDecl = 1;
+}
+
class Transform_ConcreteOpType<string opname>
: Type<And<[Transform_OperationType.predicate,
CPred<"$_self.cast<::mlir::transform::OperationType>()"
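The `!transform.param<type>` description states that the underlying type is currently limited to integers, and `genVerifyDecl = 1` requests a custom verifier. Assuming the verifier enforces that restriction and that `checkPayload` compares each attribute's type against the underlying type (the implementation is not part of this excerpt), the behavior could be modeled as follows. `ToyType`, `ToyTypedAttr` and `ToyParamType` are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of a builtin type: an integer or float of some width.
struct ToyType {
  enum class Kind { Integer, Float } kind;
  unsigned width;
  bool operator==(const ToyType &other) const {
    return kind == other.kind && width == other.width;
  }
};

// Hypothetical model of a typed attribute such as `42 : i64`.
struct ToyTypedAttr {
  long long value;
  ToyType type;
};

// Sketch of a param type: verify() rejects non-integer underlying types and
// checkPayload() rejects attributes whose type differs from the underlying one.
class ToyParamType {
public:
  static bool verify(const ToyType &type) {
    return type.kind == ToyType::Kind::Integer;
  }

  explicit ToyParamType(ToyType type) : type(type) {
    assert(verify(type) && "only integer types are supported");
  }

  std::string checkPayload(const std::vector<ToyTypedAttr> &params) const {
    for (const ToyTypedAttr &param : params)
      if (!(param.type == type))
        return "expected parameter of the underlying type";
    return "";
  }

private:
  ToyType type;
};
```

In this model, a `param<i64>`-like value accepts `4 : i64` but rejects `4 : i32`, and an empty parameter list trivially passes the check.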
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index d3bd79e86c8d6..e7a08996fb03e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -38,7 +38,11 @@ void transform::detail::checkImplementsTransformOpInterface(
void transform::detail::checkImplementsTransformTypeInterface(
TypeID typeID, MLIRContext *context) {
const auto &abstractType = AbstractType::lookup(typeID, context);
- assert(abstractType.hasInterface(TransformTypeInterface::getInterfaceID()));
+ assert(
+ (abstractType.hasInterface(TransformTypeInterface::getInterfaceID()) ||
+ abstractType.hasInterface(
+ TransformParamTypeInterface::getInterfaceID())) &&
+ "expected Transform dialect type to implement one of the two interfaces");
}
#endif // NDEBUG
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 10a4381e1d75a..5a178f31e96e8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -44,7 +44,16 @@ ArrayRef<Operation *>
transform::TransformState::getPayloadOps(Value value) const {
const TransformOpMapping &operationMapping = getMapping(value).direct;
auto iter = operationMapping.find(value);
- assert(iter != operationMapping.end() && "unknown handle");
+ assert(iter != operationMapping.end() &&
+ "cannot find mapping for payload handle (param handle provided?)");
+ return iter->getSecond();
+}
+
+ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
+ const ParamMapping &mapping = getMapping(value).params;
+ auto iter = mapping.find(value);
+ assert(iter != mapping.end() &&
+ "cannot find mapping for param handle (payload handle provided?)");
return iter->getSecond();
}
@@ -67,6 +76,8 @@ transform::TransformState::setPayloadOps(Value value,
ArrayRef<Operation *> targets) {
assert(value != kTopLevelValue &&
"attempting to reset the transformation root");
+ assert(!value.getType().isa<TransformParamTypeInterface>() &&
+ "cannot associate payload ops with a value of parameter type");
auto iface = value.getType().cast<TransformTypeInterface>();
DiagnosedSilenceableFailure result =
@@ -89,6 +100,26 @@ transform::TransformState::setPayloadOps(Value value,
return success();
}
+LogicalResult transform::TransformState::setParams(Value value,
+ ArrayRef<Param> params) {
+ assert(value != nullptr && "attempting to set params for a null value");
+
+ auto valueType = value.getType().dyn_cast<TransformParamTypeInterface>();
+ assert(valueType &&
+ "cannot associate parameter with a value of non-parameter type");
+ DiagnosedSilenceableFailure result =
+ valueType.checkPayload(value.getLoc(), params);
+ if (failed(result.checkAndReport()))
+ return failure();
+
+ Mappings &mappings = getMapping(value);
+ bool inserted =
+ mappings.params.insert({value, llvm::to_vector(params)}).second;
+ assert(inserted && "value is already associated with another list of params");
+ (void)inserted;
+ return success();
+}
+
void transform::TransformState::dropReverseMapping(Mappings &mappings,
Operation *op, Value value) {
auto it = mappings.reverse.find(op);
@@ -112,8 +143,8 @@ LogicalResult transform::TransformState::updatePayloadOps(
Mappings &mappings = getMapping(value);
auto it = mappings.direct.find(value);
assert(it != mappings.direct.end() && "unknown handle");
- SmallVector<Operation *> &association = it->getSecond();
- SmallVector<Operation *> updated;
+ SmallVector<Operation *, 2> &association = it->getSecond();
+ SmallVector<Operation *, 2> updated;
updated.reserve(association.size());
for (Operation *op : association) {
@@ -269,8 +300,21 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
assert(result.getDefiningOp() == transform.getOperation() &&
"payload IR association for a value other than the result of the "
"current transform op");
- if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
- return DiagnosedSilenceableFailure::definiteFailure();
+ if (result.getType().isa<TransformParamTypeInterface>()) {
+ assert(results.isParam(result.getResultNumber()) &&
+ "expected parameters for the parameter-typed result");
+ if (failed(
+ setParams(result, results.getParams(result.getResultNumber())))) {
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ } else {
+ assert(!results.isParam(result.getResultNumber()) &&
+ "expected payload ops for the non-parameter typed result");
+ if (failed(
+ setPayloadOps(result, results.get(result.getResultNumber())))) {
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ }
}
printOnFailureRAII.release();
@@ -312,6 +356,8 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
transform::TransformResults::TransformResults(unsigned numSegments) {
segments.resize(numSegments,
ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
+ paramSegments.resize(numSegments, ArrayRef<TransformState::Param>(
+ nullptr, static_cast<size_t>(0)));
}
void transform::TransformResults::set(OpResult value,
@@ -325,14 +371,41 @@ void transform::TransformResults::set(OpResult value,
segments[position] = makeArrayRef(operations).drop_front(start);
}
+void transform::TransformResults::setParams(
+ OpResult value, ArrayRef<transform::TransformState::Param> params) {
+ int64_t position = value.getResultNumber();
+ assert(position < static_cast<int64_t>(paramSegments.size()) &&
+ "setting params for a non-existent handle");
+ assert(paramSegments[position].data() == nullptr && "params already set");
+ size_t start = this->params.size();
+ llvm::append_range(this->params, params);
+ paramSegments[position] = makeArrayRef(this->params).drop_front(start);
+}
+
ArrayRef<Operation *>
transform::TransformResults::get(unsigned resultNumber) const {
assert(resultNumber < segments.size() &&
"querying results for a non-existent handle");
- assert(segments[resultNumber].data() != nullptr && "querying unset results");
+ assert(segments[resultNumber].data() != nullptr &&
+ "querying unset results (param expected?)");
return segments[resultNumber];
}
+ArrayRef<transform::TransformState::Param>
+transform::TransformResults::getParams(unsigned resultNumber) const {
+ assert(resultNumber < paramSegments.size() &&
+ "querying params for a non-existent handle");
+ assert(paramSegments[resultNumber].data() != nullptr &&
+ "querying unset params (payload ops expected?)");
+ return paramSegments[resultNumber];
+}
+
+bool transform::TransformResults::isParam(unsigned resultNumber) const {
+ assert(resultNumber < paramSegments.size() &&
+ "querying association for a non-existent handle");
+ return paramSegments[resultNumber].data() != nullptr;
+}
+
//===----------------------------------------------------------------------===//
// Utilities for PossibleTopLevelTransformOpTrait.
//===----------------------------------------------------------------------===//
@@ -386,6 +459,43 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// Utilities for ParamProducerTransformOpTrait.
+//===----------------------------------------------------------------------===//
+
+void transform::detail::getParamProducerTransformOpTraitEffects(
+ Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ producesHandle(op->getResults(), effects);
+ bool hasPayloadOperands = false;
+ for (Value operand : op->getOperands()) {
+ onlyReadsHandle(operand, effects);
+ if (operand.getType().isa<TransformTypeInterface>())
+ hasPayloadOperands = true;
+ }
+ if (hasPayloadOperands)
+ onlyReadsPayload(effects);
+}
+
+LogicalResult
+transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
+ // Interfaces can be attached dynamically, so this cannot be a static
+ // assert.
+ if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
+ llvm::report_fatal_error(
+ Twine("ParamProducerTransformOpTrait must be attached to an op that "
+ "implements MemoryEffectsOpInterface, found on ") +
+ op->getName().getStringRef());
+ }
+ for (Value result : op->getResults()) {
+ if (result.getType().isa<TransformParamTypeInterface>())
+ continue;
+ return op->emitOpError()
+ << "ParamProducerTransformOpTrait attached to this op expects "
+ "result types to implement TransformParamTypeInterface";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Memory effects.
//===----------------------------------------------------------------------===//
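A minimal sketch of the storage scheme behind TransformResults::setParams, getParams and isParam: all parameters of a transform op are appended to one flat buffer, and each result records which slice of it belongs to that result. The class name ParamStore and the use of int64_t in place of mlir::Attribute are assumptions for illustration. The sketch throws where the patch asserts, so misuse is observable, and it keeps (offset, length) pairs rather than views so that earlier slices stay valid even if a later append reallocates the flat buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for a parameter value; the patch stores mlir::Attribute.
using Param = int64_t;

// Hypothetical model of the per-result parameter storage.
class ParamStore {
public:
  explicit ParamStore(std::size_t numResults) : segments(numResults) {}

  // Appends `values` to the flat buffer and records the slice for `position`.
  void setParams(std::size_t position, const std::vector<Param> &values) {
    if (position >= segments.size())
      throw std::out_of_range("setting params for a non-existent handle");
    if (segments[position])
      throw std::logic_error("params already set");
    std::size_t start = params.size();
    params.insert(params.end(), values.begin(), values.end());
    segments[position] = std::make_pair(start, values.size());
  }

  // True if the result at `position` has been associated with parameters.
  bool isParam(std::size_t position) const {
    if (position >= segments.size())
      throw std::out_of_range("querying association for a non-existent handle");
    return segments[position].has_value();
  }

  // Copies out the slice of the flat buffer that belongs to `position`.
  std::vector<Param> getParams(std::size_t position) const {
    if (!isParam(position))
      throw std::logic_error("querying unset params (payload ops expected?)");
    auto [start, length] = *segments[position];
    auto first = params.begin() + static_cast<std::ptrdiff_t>(start);
    return std::vector<Param>(first, first + static_cast<std::ptrdiff_t>(length));
  }

private:
  std::vector<Param> params; // flat storage shared by all results
  std::vector<std::optional<std::pair<std::size_t, std::size_t>>> segments;
};
```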
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index 33771e658e6b7..4e4fb2da91df6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
@@ -37,12 +38,20 @@ void transform::TransformDialect::initializeTypes() {
>();
}
+//===----------------------------------------------------------------------===//
+// transform::AnyOpType
+//===----------------------------------------------------------------------===//
+
DiagnosedSilenceableFailure
transform::AnyOpType::checkPayload(Location loc,
ArrayRef<Operation *> payload) const {
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// transform::OperationType
+//===----------------------------------------------------------------------===//
+
DiagnosedSilenceableFailure
transform::OperationType::checkPayload(Location loc,
ArrayRef<Operation *> payload) const {
@@ -58,3 +67,35 @@ transform::OperationType::checkPayload(Location loc,
return DiagnosedSilenceableFailure::success();
}
+
+//===----------------------------------------------------------------------===//
+// transform::ParamType
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+transform::ParamType::verify(function_ref<InFlightDiagnostic()> emitError,
+ Type type) {
+ IntegerType intType = type.dyn_cast<IntegerType>();
+ if (!intType || intType.getWidth() > 64)
+ return emitError() << "only supports integer types with width <=64";
+ return success();
+}
+
+DiagnosedSilenceableFailure
+transform::ParamType::checkPayload(Location loc,
+ ArrayRef<Attribute> payload) const {
+ for (Attribute attr : payload) {
+ auto integerAttr = attr.dyn_cast<IntegerAttr>();
+ if (!integerAttr) {
+ return emitSilenceableError(loc)
+ << "expected parameter to be an integer attribute, got " << attr;
+ }
+ if (integerAttr.getType() != getType()) {
+ return emitSilenceableError(loc)
+ << "expected the type of the parameter attribute ("
+ << integerAttr.getType() << ") to match the parameter type ("
+ << getType() << ")";
+ }
+ }
+ return DiagnosedSilenceableFailure::success();
+}
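The two checks added for transform::ParamType can be modelled without MLIR as below. IntType and Attr are hypothetical stand-ins for IntegerType and IntegerAttr that track only the bit width, and diagnostics are returned as strings instead of being emitted. The mismatch message follows the format expected by the second test case (an i32 attribute associated with an i64 parameter).

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of an integer type: only the bit width is tracked.
struct IntType {
  unsigned width;
  bool operator==(const IntType &other) const { return width == other.width; }
  bool operator!=(const IntType &other) const { return !(*this == other); }
};

// Hypothetical model of a payload attribute; empty intType means "not an integer".
struct Attr {
  std::optional<IntType> intType;
  long long value = 0;
};

std::string typeName(IntType t) { return "i" + std::to_string(t.width); }

// Mirrors ParamType::verify: the element type must be an integer of width <= 64.
std::optional<std::string> verifyParamType(std::optional<IntType> type) {
  if (!type || type->width > 64)
    return "only supports integer types with width <=64";
  return std::nullopt;
}

// Mirrors ParamType::checkPayload: each attribute must be an integer attribute
// whose type equals the element type of the parameter.
std::optional<std::string> checkPayload(IntType paramType,
                                        const std::vector<Attr> &payload) {
  for (const Attr &attr : payload) {
    if (!attr.intType)
      return std::string("expected parameter to be an integer attribute");
    if (*attr.intType != paramType)
      return "expected the type of the parameter attribute ('" +
             typeName(*attr.intType) + "') to match the parameter type ('" +
             typeName(paramType) + "')";
  }
  return std::nullopt;
}
```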
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 6246716782dcc..6b8c98459d436 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -920,6 +920,7 @@ transform.with_pdl_patterns {
}
"test.some_op"() : () -> ()
+
// -----
func.func @split_handles(%a: index, %b: index, %c: index) {
@@ -937,3 +938,56 @@ transform.sequence -> !pdl.operation failures(propagate) {
/// propagate mode.
yield %fun : !pdl.operation
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.test_produce_integer_param_with_type i32 : !transform.test_dialect_param
+ // expected-remark @below {{0 : i32}}
+ transform.test_print_param %0 : !transform.test_dialect_param
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected the type of the parameter attribute ('i32') to match the parameter type ('i64')}}
+ transform.test_produce_integer_param_with_type i32 : !transform.param<i64>
+}
+
+// -----
+
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.test_add_to_param 40
+ %1 = transform.test_add_to_param %0, 2
+ // expected-remark @below {{42 : i32}}
+ transform.test_print_param %1 : !transform.test_dialect_param
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg0
+ %1 = transform.test_produce_param_with_number_of_test_ops %0 : !pdl.operation
+ // expected-remark @below {{1 : i32, 3 : i32}}
+ transform.test_print_param %1 : !transform.test_dialect_param
+ %2 = transform.test_add_to_param %1, 100
+ // expected-remark @below {{101 : i32, 103 : i32}}
+ transform.test_print_param %2 : !transform.test_dialect_param
+}
+
+func.func private @one_test_op(%arg0: i32) {
+ "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
+ return
+}
+
+func.func private @three_test_ops(%arg0: i32) {
+ "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
+ "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
+ "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
+ return
+}
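The last test case expects the remark `1 : i32, 3 : i32` because transform.structured.match returns both functions and the op counts the test-dialect operations nested in each of them (func.func and func.return are not counted). A plain C++ sketch of that counting, assuming a hypothetical Op tree in which the dialect is the name prefix before the first dot:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical payload op: a name such as "test.op_a" plus nested ops
// (regions and blocks are flattened into one list for simplicity).
struct Op {
  std::string name;
  std::vector<Op> nested;
};

// Counts "test" dialect ops in `op` and everything nested under it,
// analogous to the walk in TestProduceParamWithNumberOfTestOps::apply.
int countTestOps(const Op &op) {
  int count = op.name.rfind("test.", 0) == 0 ? 1 : 0;
  for (const Op &child : op.nested)
    count += countTestOps(child);
  return count;
}

// One parameter value per payload op, in the order of the handle.
std::vector<int> paramPerPayloadOp(const std::vector<Op> &payload) {
  std::vector<int> result;
  for (const Op &op : payload)
    result.push_back(countTestOps(op));
  return result;
}
```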
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 483d4feda7abc..fcdecb0c9e28c 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -17,8 +17,10 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -317,15 +319,27 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
for (Operation *op : payload) {
if (op->getName().getDialectNamespace() != "test") {
- Diagnostic diag(loc, DiagnosticSeverity::Error);
- diag << "expected the payload operation to belong to the 'test' dialect";
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ return emitSilenceableError(loc) << "expected the payload operation to "
+ "belong to the 'test' dialect";
}
}
return DiagnosedSilenceableFailure::success();
}
+DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
+ Location loc, ArrayRef<Attribute> payload) const {
+ for (Attribute attr : payload) {
+ auto integerAttr = attr.dyn_cast<IntegerAttr>();
+ if (integerAttr && integerAttr.getType().isSignlessInteger(32))
+ continue;
+ return emitSilenceableError(loc)
+ << "expected the parameter to be a i32 integer attribute";
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getTarget(), effects);
@@ -346,6 +360,75 @@ mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestPrintParamOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getParam(), effects);
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestPrintParamOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ std::string str;
+ llvm::raw_string_ostream os(str);
+ llvm::interleaveComma(state.getParams(getParam()), os);
+ auto diag = emitRemark() << os.str();
+ return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestAddToParamOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0);
+ if (Value param = getParam()) {
+ values = llvm::to_vector(
+ llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
+ return attr.cast<IntegerAttr>().getValue().getLimitedValue(
+ UINT32_MAX);
+ }));
+ }
+
+ Builder builder(getContext());
+ SmallVector<Attribute> result = llvm::to_vector(
+ llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
+ return builder.getI32IntegerAttr(value + getAddendum());
+ }));
+ results.setParams(getResult().cast<OpResult>(), result);
+ return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceParamWithNumberOfTestOps::apply(
+ transform::TransformResults &results, transform::TransformState &state) {
+ Builder builder(getContext());
+ SmallVector<Attribute> result = llvm::to_vector(
+ llvm::map_range(state.getPayloadOps(getHandle()),
+ [&builder](Operation *payload) -> Attribute {
+ int32_t count = 0;
+ payload->walk([&count](Operation *op) {
+ if (op->getName().getDialectNamespace() == "test")
+ ++count;
+ });
+ return builder.getI32IntegerAttr(count);
+ }));
+ results.setParams(getResult().cast<OpResult>(), result);
+ return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceIntegerParamWithTypeOp::apply(
+ transform::TransformResults &results, transform::TransformState &state) {
+ Attribute zero = IntegerAttr::get(getType(), 0);
+ results.setParams(getResult().cast<OpResult>(), zero);
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() {
+ if (!getType().isa<IntegerType>()) {
+ return emitOpError() << "expects an integer type";
+ }
+ return success();
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
@@ -371,9 +454,6 @@ class TestTransformDialectExtension
};
} // namespace
-#define GET_OP_CLASSES
-#include "TestTransformDialectExtension.cpp.inc"
-
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
@@ -384,6 +464,9 @@ generatedTypePrinter(Type def, AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"
+#define GET_OP_CLASSES
+#include "TestTransformDialectExtension.cpp.inc"
+
void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<TestTransformDialectExtension>();
}
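A sketch of the arithmetic in TestAddToParamOp::apply and of the remark emitted by TestPrintParamOp, using the hypothetical helper names addToParam and printParams. Values are plain integers rather than IntegerAttr; the clamping mirrors getLimitedValue(UINT32_MAX), and the formatting assumes i32 attributes print as `<value> : i32`, as in the expected remarks of the tests.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Without an input parameter the op starts from a single 0; otherwise each
// incoming value is clamped to UINT32_MAX before the addendum is added.
std::vector<uint32_t> addToParam(const std::optional<std::vector<uint64_t>> &param,
                                 uint32_t addendum) {
  std::vector<uint32_t> values(1, 0);
  if (param) {
    values.clear();
    for (uint64_t v : *param)
      values.push_back(static_cast<uint32_t>(v > UINT32_MAX ? UINT32_MAX : v));
  }
  std::vector<uint32_t> result;
  for (uint32_t v : values)
    result.push_back(v + addendum); // stored as uint32_t, so it wraps modulo 2^32
  return result;
}

// Joins the values with ", " and prints each one as an i32 attribute.
std::string printParams(const std::vector<uint32_t> &params) {
  std::ostringstream os;
  for (std::size_t i = 0; i < params.size(); ++i) {
    if (i != 0)
      os << ", ";
    os << static_cast<int32_t>(params[i]) << " : i32";
  }
  return os.str();
}
```

Chaining the calls as in the third test case (`test_add_to_param 40` followed by `test_add_to_param %0, 2`) yields a single value of 42.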
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
index e5785fdec6cda..dac611f3636f8 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
@@ -23,12 +23,12 @@ namespace mlir {
class DialectRegistry;
} // namespace mlir
-#define GET_OP_CLASSES
-#include "TestTransformDialectExtension.h.inc"
-
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.h.inc"
+#define GET_OP_CLASSES
+#include "TestTransformDialectExtension.h.inc"
+
namespace test {
/// Registers the test extension to the Transform dialect.
void registerTestTransformDialectExtension(::mlir::DialectRegistry &registry);
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 58dc53431f0a6..59ad0442bfe80 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -29,6 +29,16 @@ def TestTransformTestDialectHandleType
let assemblyFormat = "";
}
+def TestTransformTestDialectParamType
+ : TypeDef<Transform_Dialect, "TestDialectParam",
+ [DeclareTypeInterfaceMethods<TransformParamTypeInterface>]> {
+ let description = [{
+ Parameter associated with an i32 attribute for testing purposes.
+ }];
+ let mnemonic = "test_dialect_param";
+ let assemblyFormat = "";
+}
+
def TestProduceParamOrForwardOperandOp
: Op<Transform_Dialect, "test_produce_param_or_forward_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -262,4 +272,45 @@ def TestReportNumberOfTrackedHandlesNestedUnder
let cppNamespace = "::mlir::test";
}
+def TestPrintParamOp
+ : Op<Transform_Dialect, "test_print_param",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins TransformParamTypeInterface:$param);
+ let assemblyFormat = "$param attr-dict `:` type($param)";
+ let cppNamespace = "::mlir::test";
+}
+
+def TestAddToParamOp
+ : Op<Transform_Dialect, "test_add_to_param",
+ [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins Optional<TestTransformTestDialectParamType>:$param,
+ I32Attr:$addendum);
+ let results = (outs TestTransformTestDialectParamType:$result);
+ let assemblyFormat = "($param^ `,`)? $addendum attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
+def TestProduceParamWithNumberOfTestOps
+ : Op<Transform_Dialect, "test_produce_param_with_number_of_test_ops",
+ [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins TransformTypeInterface:$handle);
+ let results = (outs TestTransformTestDialectParamType:$result);
+ let assemblyFormat = "$handle attr-dict `:` type($handle)";
+ let cppNamespace = "::mlir::test";
+}
+
+def TestProduceIntegerParamWithTypeOp
+ : Op<Transform_Dialect, "test_produce_integer_param_with_type",
+ [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins TypeAttr:$type);
+ let results = (outs TransformParamTypeInterface:$result);
+ let assemblyFormat = "$type attr-dict `:` type($result)";
+ let cppNamespace = "::mlir::test";
+ let hasVerifier = 1;
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
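The ops above that list ParamProducerTransformOpTrait take their memory effects from getParamProducerTransformOpTraitEffects and are checked by verifyParamProducerTransformOpTrait. The simplified model below, in which ValueKind and the effect strings are hypothetical and not part of the MLIR API, shows why test_produce_param_with_number_of_test_ops reads the payload while test_add_to_param, whose optional operand is a parameter rather than a handle, does not.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical classification of transform values for this sketch.
enum class ValueKind { Handle, Param };

// Every result is produced, every operand is only read, and the payload IR
// is read only if at least one operand is a handle to payload operations.
std::vector<std::string> paramProducerEffects(const std::vector<ValueKind> &operands,
                                              std::size_t numResults) {
  std::vector<std::string> effects;
  for (std::size_t i = 0; i < numResults; ++i)
    effects.push_back("produces result #" + std::to_string(i));
  bool hasPayloadOperands = false;
  for (std::size_t i = 0; i < operands.size(); ++i) {
    effects.push_back("only reads operand #" + std::to_string(i));
    if (operands[i] == ValueKind::Handle)
      hasPayloadOperands = true;
  }
  if (hasPayloadOperands)
    effects.push_back("only reads payload");
  return effects;
}

// The verifier's result check: all results must be parameters.
bool resultsAreParams(const std::vector<ValueKind> &results) {
  for (ValueKind kind : results)
    if (kind != ValueKind::Param)
      return false;
  return true;
}
```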