[Mlir-commits] [mlir] 3fe7127 - [mlir] add structured (Linalg) transform op matchers
Alex Zinenko
llvmlistbot at llvm.org
Thu Apr 13 05:38:00 PDT 2023
Author: Alex Zinenko
Date: 2023-04-13T12:37:51Z
New Revision: 3fe7127d48fde531901c795826b566d93dbf1ea4
URL: https://github.com/llvm/llvm-project/commit/3fe7127d48fde531901c795826b566d93dbf1ea4
DIFF: https://github.com/llvm/llvm-project/commit/3fe7127d48fde531901c795826b566d93dbf1ea4.diff
LOG: [mlir] add structured (Linalg) transform op matchers
Add a set of transform operations into the "structured" extension of the
Transform dialect that allow one to select transformation targets more
precisely than the currently available matching operations. In particular, add
the mechanism for identifying the producers of operands (input and init
in destination-passing style) and users of results, as well as
mechanisms for reasoning about the shape of the iteration space.
Additionally, add several transform operations to manipulate parameters
that could be useful to implement more advanced selectors. Specifically,
new operations let one produce and compare parameter values to implement
shape-driven transformations.
New operations are placed in separate files to decrease compilation
time. Some relayering of the extension is necessary to avoid repeated
generation of enums.
Depends on D148013
Depends on D148014
Depends on D148015
Reviewed By: chelini
Differential Revision: https://reviews.llvm.org/D148017
Added:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td
mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
mlir/lib/Dialect/Transform/IR/TransformAttrs.cpp
mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
mlir/test/Dialect/Linalg/match-ops-invalid.mlir
mlir/test/Dialect/Linalg/match-ops.mlir
Modified:
mlir/docs/Dialects/Transform.md
mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/CMakeLists.txt
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md
index 123d66136d390..586c12f2fa523 100644
--- a/mlir/docs/Dialects/Transform.md
+++ b/mlir/docs/Dialects/Transform.md
@@ -431,6 +431,10 @@ ops rather than having the methods directly act on the payload IR.
[include "Dialects/MemRefTransformOps.md"]
+## Structured (Linalg) Match Operations
+
+[include "Dialects/LinalgStructuredMatchOps.md"]
+
## Structured (Linalg) Transform Operations
[include "Dialects/LinalgStructuredTransformOps.md"]
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
index 737d9af5c7641..4f6b251bf6964 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -1,8 +1,17 @@
+set(LLVM_TARGET_DEFINITIONS LinalgMatchOps.td)
+mlir_tablegen(LinalgMatchOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgMatchOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgMatchOpsIncGen) # transform.match.structured* ops
+
set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td)
mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgTransformOpsIncGen) # Ops only; enums below.
+
+set(LLVM_TARGET_DEFINITIONS LinalgTransformEnums.td) # Shared by match/transform .td
mlir_tablegen(LinalgTransformOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LinalgTransformOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRLinalgTransformOpsIncGen)
+add_public_tablegen_target(MLIRLinalgTransformEnumsIncGen) # Enums generated once.
+add_mlir_doc(LinalgMatchOps LinalgStructuredMatchOps Dialects/ -gen-op-doc) # See Transform.md
add_mlir_doc(LinalgTransformOps LinalgStructuredTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
new file mode 100644
index 0000000000000..c409adb05f4b3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -0,0 +1,465 @@
+//===- LinalgMatchOps.td - Linalg transform matcher ops ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LINALG_MATCH_OPS
+#define LINALG_MATCH_OPS
+
+include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
+include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Structured match op and predicates usable inside it.
+//===----------------------------------------------------------------------===//
+
+def MatchStructuredOp : Op<Transform_Dialect, "match.structured", [
+ MatchOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ SingleOpMatcher,
+ SingleBlockImplicitTerminator<"::mlir::transform::MatchStructuredYieldOp">]> {
+ let summary =
+ "Matches a structured (linalg) operation with additional conditions";
+ let description = [{
+ Checks if the payload operation associated with the operand handle is a
+ structured operation, that is, an operation that implements
+ `LinalgOpInterface`, and that all conditions listed in the body of this
+ operation are satisfied. Produces a silenceable failure if the payload
+ operation is not structured.
+
+ The transform operations nested in the body region are applied one by one.
+ If any of them produces a failure, silenceable or definite, the following
+ operations are not applied. If the failure propagation mode is "propagate",
+ silenceable failures are forwarded as the result of this operation. If it is
+ "suppress", they are ignored and this operation immediately succeeds.
+ Definite failures are always propagated immediately.
+
+ In case of success, the transform values produced by this operation are
+ associated with the same payload as the operands of the block terminator. If
+ any of the nested operations produced a silenceable failure, regardless of
+ the failure propagation mode, the transform values produced by this
+ operation that correspond to the already defined terminator operands are
+ associated with the same payload as the already defined terminator operands.
+ Other values produced by this operation are associated with empty payloads.
+
+ If the failure propagation mode is not specified, it is considered
+ "propagate" by default. The "suppress" mode can be used to specify optional
+ matches.
+
+ #### Return modes
+
+ This operation only reads all operand handles and produces all resulting
+ handles. It succeeds in "propagate" mode if the payload operation is a
+ structured operation and if all the nested operations succeed. It succeeds
+ in "suppress" mode as long as the operand handle is associated with exactly
+ one payload operation. It produces a definite failure when the handle is
+ not associated with exactly one payload operation.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$current,
+ OptionalAttr<FailurePropagationMode>:$failure_propagation_mode);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$outputs); // Mirror the terminator operands.
+
+ let regions = (region SizedRegion<1>:$body_region); // Single block ending in match.structured.yield.
+ let assemblyFormat =
+ "(`failures` `(` $failure_propagation_mode^ `)`)?"
+ "$current `:` custom<SemiFunctionType>(type($current), type($outputs))"
+ "attr-dict-with-keyword regions";
+ let hasVerifier = 1; // verify() is implemented in C++.
+
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
+ ::mlir::Value getOperandHandle() { return getCurrent(); }
+ }];
+}
+
+def StructuredPredicate : NativeOpTrait<"StructuredOpPredicateOpTrait"> {
+ let cppNamespace = "::mlir::transform"; // C++ template is in LinalgTransformOps.h.
+ string extraDescription = [{
+ This op can only appear immediately inside a `transform.match.structured`
+ op and apply to its first block argument because it assumes the payload
+ to have been already checked for being a single structured op.
+ }]; // Appended to the descriptions of ops that carry this trait.
+}
+
+def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
+ SingleOpMatcher,
+ StructuredPredicate,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> { // getEffects() is provided by SingleOpMatcher.
+ let summary =
+ "Checks if the body of the structured op satisfies some criteria";
+ let description = !strconcat([{
+ Checks if the body of the structured payload op satisfies one of the
+ following mutually exclusive criteria specified by attributes:
+
+ * `reduction_position`: the body of the structured payload op implements
+ a reduction of the `n`-th operand (`n` is the value of the attribute)
+ using a single combiner operation;
+
+ * `passthrough`: the body of the structured payload op only forwards
+ inputs to the outputs (copy or broadcast).
+
+ }], StructuredPredicate.extraDescription, [{
+
+ #### Return modes
+
+ Succeeds if the operation body satisfies the specified criteria, produces a
+ silenceable failure otherwise. Produces a definite failure if the operand is
+ not associated with a single payload op.
+ }]);
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle,
+ OptionalAttr<I64Attr>:$reduction_position, // `n`: index of the reduced operand.
+ UnitAttr:$passthrough); // Mutually exclusive with reduction_position.
+ let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)";
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+ let hasVerifier = 1; // verify() is in C++; presumably enforces exclusivity.
+}
+
+class StructuredDimDescription<string kind> { // Shared doc text; `kind` names the indexed entity.
+ string description = !strconcat([{
+ The following }], kind ,[{ specifications are supported:
+
+ * `all`: all }], kind ,[{s are checked and captured;
+ * list of integers: the listed }], kind, [{s are checked and captured;
+ * `except(` list of integers `)`: all }], kind, [{s except the
+ specified ones are checked and captured.
+
+ Negative indexes are interpreted by counting values from the last one
+ (similarly to Python). For example, `-1` means the last }], kind, [{ and
+ `except(-1)` means all }], kind, [{s but the last. Indexes must be unique,
+ including after interpretation of negative ones.
+
+ Produces a silenceable failure in case of index overflow, including backward
+ counting.
+ }]); // Used with kind = "dimension", "input" and "init".
+}
+
+def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
+ StructuredPredicate,
+ SingleOpMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> { // getEffects() is provided by SingleOpMatcher.
+ let summary =
+ "Checks if the dimensions of the structured op satisfy some criteria";
+ let description = !strconcat([{
+ Checks if the dimensions (loop ranges) of the structured payload op satisfy
+ the criteria specified as attributes. May capture the numeric value of the
+ dimension into a parameter that it returns.
+
+ }],
+ StructuredDimDescription<"dimension">.description,
+ [{
+
+ The following mutually exclusive conditions are available as unit
+ attributes:
+
+ * `parallel`: the dimension corresponds to a parallel loop;
+ * `reduction`: the dimension corresponds to a reduction loop.
+
+ If the result type is specified, associates the parameter with the (static)
+ values of dimensions in the same order as listed and preserving the natural
+ order for `all` and `except`. Specifically, if `-1, -2` are specified, the
+ parameter will be associated with the value of the second-to-last dimension
+ followed by the last dimension. If the dimension is dynamic, the parameter
+ will contain a negative value corresponding to kDynamic in C++.
+
+ }], StructuredPredicate.extraDescription, [{
+
+ #### Return modes
+
+ Succeeds if the specified dimensions satisfy the specified criteria,
+ produces a silenceable failure otherwise. Produces a definite failure if
+ the operand is not associated with a single payload op.
+ }]); // NOTE(review): `-1, -2` example seems to contradict "same order as listed"; confirm.
+
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle,
+ DenseI64ArrayAttr:$raw_dim_list, // Parsed via custom<StructuredTransformDims>.
+ UnitAttr:$is_inverted,
+ UnitAttr:$is_all,
+ UnitAttr:$parallel,
+ UnitAttr:$reduction);
+
+ let results = (outs Optional<TransformParamTypeInterface>:$result);
+ let assemblyFormat =
+ "$operand_handle `[`"
+ "custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
+ "`]` attr-dict `:` "
+ "custom<SemiFunctionType>(type($operand_handle), type($result))";
+
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
+ ::mlir::DiagnosedSilenceableFailure getDimensionsFor(
+ ::mlir::linalg::LinalgOp op,
+ ::llvm::SmallVectorImpl<int64_t> &dims);
+ }];
+
+ let hasVerifier = 1; // verify() is implemented in C++.
+}
+
+def MatchStructuredElementalBitwidthOp
+ : Op<Transform_Dialect, "match.structured.elemental_bitwidth", [
+ SingleValueMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> {
+ let summary =
+ "Captures the bitwidth of the value's elemental type as a parameter";
+ let description = [{
+ Produces a transform dialect parameter associated with the bitwidth of the
+ elemental type of the payload value passed as the operand.
+
+ #### Return modes
+
+ Succeeds if the operand is associated with exactly one payload value of
+ `ShapedType`. Produces a silenceable failure if the payload value is not of
+ `ShapedType` and a definite failure if the operand is not associated with
+ exactly one payload value.
+ }];
+ let arguments = (ins TransformValueHandleTypeInterface:$operand_handle);
+ let results = (outs TransformParamTypeInterface:$result);
+ let assemblyFormat =
+ "$operand_handle attr-dict `:` functional-type(operands, results)";
+ let extraClassDeclaration = SingleValueMatcher.extraDeclaration;
+}
+
+class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
+ StructuredPredicate,
+ SingleOpMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> { // getEffects() is provided by SingleOpMatcher.
+ // Common base of match.structured.input and match.structured.init.
+ // TODO: consider an attribute controlling whether to fail or succeed on
+ // out-of-bounds accesses.
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle,
+ DenseI64ArrayAttr:$raw_position_list, // Parsed via custom<StructuredTransformDims>.
+ UnitAttr:$is_inverted,
+ UnitAttr:$is_all,
+ UnitAttr:$permutation, // NOTE(review): not described in derived op docs.
+ UnitAttr:$projected_permutation); // NOTE(review): likewise undocumented.
+
+ // TODO: allow this to bind multiple inputs simultaneously after checking that
+ // `transform.foreach` works well in matches.
+ let results = (outs Optional<TransformAnyHandle>:$result);
+ let assemblyFormat =
+ "$operand_handle `[`"
+ "custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
+ "`]` attr-dict "
+ "`:` custom<SemiFunctionType>(type($operand_handle), type($result))";
+
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
+ ::mlir::DiagnosedSilenceableFailure getPositionsFor(
+ ::mlir::linalg::LinalgOp op,
+ ::llvm::SmallVectorImpl<int64_t> &positions);
+ }];
+
+ let hasVerifier = 1; // verify() is implemented in C++.
+}
+
+def MatchStructuredInputOp : MatchStructuredOperandOp<"match.structured.input"> {
+ let summary =
+ "Captures input operand(s) of a structured operation in an op or value handle";
+ let description = !strconcat([{
+ Produces a transform dialect value handle associated with the payload value
+ supplied as input operand to the given structured payload operation, or an
+ operation handle to the structured payload operation producing said payload
+ value depending on the result type.
+
+ }],
+ StructuredDimDescription<"input">.description,
+ [{
+
+ }],
+ StructuredPredicate.extraDescription,
+ [{
+
+ #### Return modes
+
+ Succeeds if all input indexes are in bounds, produces a silenceable failure
+ otherwise. Additionally, when the result is an operation handle, produces a
+ silenceable failure if the input specification defines more than one input
+ or if the operand is not an operation result.
+ }]);
+} // Arguments, results and format come from MatchStructuredOperandOp.
+
+def MatchStructuredInitOp : MatchStructuredOperandOp<"match.structured.init"> {
+ let summary =
+ "Captures init operand(s) of a structured operation in an op or value handle";
+ let description = !strconcat([{
+ Produces a transform dialect value handle associated with the payload value
+ supplied as init(outs) operand to the given structured payload operation,
+ or an operation handle to the structured payload operation producing said
+ payload value depending on the result type.
+
+ }],
+ StructuredDimDescription<"init">.description,
+ [{
+
+ }],
+ StructuredPredicate.extraDescription,
+ [{
+
+ #### Return modes
+
+ Succeeds if all init(outs) indexes are in bounds, produces a silenceable
+ failure otherwise. Additionally, when the result is an operation handle,
+ produces a silenceable failure if the init(outs) specification defines
+ more than one init(outs) or if the operand is not an operation result.
+ }]);
+} // Arguments, results and format come from MatchStructuredOperandOp.
+
+
+def MatchStructuredNumInputsOp
+ : Op<Transform_Dialect, "match.structured.num_inputs", [
+ StructuredPredicate,
+ SingleOpMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> {
+ let summary = "Captures the number of input operands of a structured "
+ "operation as parameter";
+ let description = !strconcat([{
+ Produces a transform dialect parameter value associated with an integer
+ attribute containing the number of input operands of the payload operation
+ associated with the operand handle.
+
+ }], StructuredPredicate.extraDescription, [{
+
+ #### Return modes
+
+ Succeeds if the operand is associated with exactly one structured payload
+ operation. Produces a definite failure if there is not exactly one payload op.
+ }]);
+
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+ let results = (outs TransformParamTypeInterface:$result);
+ let assemblyFormat =
+ "$operand_handle attr-dict `:` functional-type(operands, results)";
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+}
+
+def MatchStructuredNumInitsOp
+ : Op<Transform_Dialect, "match.structured.num_inits", [
+ StructuredPredicate,
+ SingleOpMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> {
+ let summary = "Captures the number of init(outs) operands of a structured "
+ "operation as parameter";
+ let description = !strconcat([{
+ Produces a transform dialect parameter value associated with an integer
+ attribute containing the number of init(outs) operands of the payload
+ operation associated with the operand handle.
+
+ }], StructuredPredicate.extraDescription, [{
+
+ #### Return modes
+
+ Succeeds if the operand is associated with exactly one structured payload
+ operation. Produces a definite failure if there is not exactly one payload op.
+ }]);
+
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+ let results = (outs TransformParamTypeInterface:$result);
+ let assemblyFormat =
+ "$operand_handle attr-dict `:` functional-type(operands, results)";
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+}
+
+def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [
+ StructuredPredicate,
+ SingleOpMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> {
+ let summary = "Captures the rank of a structured operation as parameter";
+ let description = !strconcat([{
+ Produces a transform dialect parameter value associated with an integer
+ attribute containing the rank of the structured payload operation associated
+ with the operand handle.
+
+ }], StructuredPredicate.extraDescription, [{
+
+ #### Return modes
+
+ Succeeds if the operand is associated with exactly one structured payload
+ operation. Produces a definite failure if there is not exactly one payload op.
+ }]);
+
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+ let results = (outs TransformParamTypeInterface:$rank);
+ let assemblyFormat =
+ "$operand_handle attr-dict `:` "
+ "custom<SemiFunctionType>(type($operand_handle), type($rank))";
+
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+}
+
+def MatchStructuredResultOp : Op<Transform_Dialect, "match.structured.result", [
+ StructuredPredicate,
+ SingleOpMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> {
+ let summary = "Captures the result of a structured payload operation in an "
+ "op or value handle";
+ let description = !strconcat([{
+ Produces a transform dialect value handle associated with the payload value
+ defined as a result of the payload operation associated with the operand
+ handle, or an operation handle to an operation using the produced result
+ with additional constraints specified by the attributes as follows.
+
+ * If `any` is specified, binds the resulting handle to any operation using
+ the result and succeeds.
+ * If `single` is specified, binds the resulting handle to the only
+ operation using the result or fails if there is more than one (or no)
+ such operation.
+
+ The number of the result is specified as `position` attribute. It may take
+ positive and negative values. Negative values are interpreted as counting
+ results from the end, e.g., `-1` means the last result and `-2` means the
+ second-to-last result. The position must be in bounds for the given payload
+ operation; a silenceable failure is produced for out-of-bounds positions.
+
+ }], StructuredPredicate.extraDescription, [{
+
+ #### Return modes
+
+ Succeeds if the position is in bounds and if the user operation could be
+ found when requested, produces a silenceable failure otherwise. Produces a
+ definite failure if the operand is not associated with a single payload op.
+ }]);
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle,
+ I64Attr:$position,
+ UnitAttr:$any,
+ UnitAttr:$single);
+ let results = (outs TransformAnyHandle:$result);
+ let assemblyFormat =
+ "$operand_handle `[` $position `]` (`any` $any^)? (`single` $single^)?"
+ "attr-dict `:` functional-type(operands, results)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
+ ::mlir::DiagnosedSilenceableFailure
+ getPositionFor(::mlir::linalg::LinalgOp op, int64_t &position);
+ }];
+}
+
+def MatchStructuredYieldOp : Op<Transform_Dialect, "match.structured.yield", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, // getEffects() is in C++.
+ Terminator]> {
+ let summary = "Terminator for transform.match.structured blocks";
+ let description = [{
+ Forwards the payload association from the operands to the results of the
+ parent op. Always succeeds.
+ }];
+ let builders = [
+ OpBuilder<(ins)> // No-arg builder, as SingleBlockImplicitTerminator requires.
+ ];
+
+ let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$handles); // Forwarded to parent results.
+ let assemblyFormat = "$handles attr-dict (`:` type($handles)^)?";
+}
+
+#endif // LINALG_MATCH_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td
new file mode 100644
index 0000000000000..bf997ac2b068c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td
@@ -0,0 +1,24 @@
+//===- LinalgTransformEnums.td - Linalg transform enums ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LINALG_TRANSFORM_ENUMS
+#define LINALG_TRANSFORM_ENUMS
+
+include "mlir/IR/EnumAttr.td"
+
+// Kept in its own file, included by both LinalgMatchOps.td and
+// LinalgTransformOps.td, so that the enum is generated only once.
+def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match",
+ [
+ I32EnumAttrCase<"LinalgOp", 0>,
+ I32EnumAttrCase<"TilingInterface", 1>
+ ]>{
+ let cppNamespace = "mlir::transform";
+}
+
+#endif // LINALG_TRANSFORM_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 091c419f1774a..6276c5687808c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -11,6 +11,9 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/OpImplementation.h"
@@ -40,15 +43,6 @@ struct NumThreadsSpec {};
} // namespace transform
} // namespace mlir
-//===----------------------------------------------------------------------===//
-// Linalg Transform Operations
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.h.inc"
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"
-
namespace mlir {
class DialectRegistry;
@@ -62,6 +56,25 @@ DiagnosedSilenceableFailure tileToForallOpImpl(
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
+namespace detail {
+LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op,
+ Value structuredOpHandle);
+} // namespace detail
+/// Op trait for predicates meant to sit directly in `transform.match.structured`.
+template <typename OpTy>
+class StructuredOpPredicateOpTrait
+ : public OpTrait::TraitBase<OpTy, StructuredOpPredicateOpTrait> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ static_assert(
+ OpTy::template hasTrait<SingleOpMatcherOpTrait>(),
+ "StructuredOpPredicateOpTrait requires SingleOpMatcherOpTrait");
+ // SingleOpMatcher ops expose getOperandHandle(); checks are out of line.
+ return detail::verifyStructuredOpPredicateOpTrait(
+ op, cast<OpTy>(op).getOperandHandle());
+ }
+};
+
} // namespace transform
namespace linalg {
@@ -69,4 +82,16 @@ void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Linalg Transform Operations
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc"
+
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f60e6bdc79abb..af76660d988f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -9,13 +9,14 @@
#ifndef LINALG_TRANSFORM_OPS
#define LINALG_TRANSFORM_OPS
+include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/RegionKindInterface.td"
@@ -344,14 +345,6 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
// MatchOp
//===----------------------------------------------------------------------===//
-def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match",
- [
- I32EnumAttrCase<"LinalgOp", 0>,
- I32EnumAttrCase<"TilingInterface", 1>
- ]>{
- let cppNamespace = "mlir::transform";
-}
-
def MatchOp : Op<Transform_Dialect, "structured.match",
[MemoryEffectsOpInterface,
NavigationTransformOpTrait,
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index 6301edf0f1c9c..d362524528f1f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -63,6 +63,52 @@ class SingleOpMatcherOpTrait
return cast<OpTy>(this->getOperation())
.matchOperation(payload[0], results, state);
}
+
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(this->getOperation()->getOperands(), effects); // Not consumed.
+ producesHandle(this->getOperation()->getResults(), effects);
+ onlyReadsPayload(effects); // Matching only reads the payload IR.
+ }
+};
+/// Op trait for matchers applied to a handle to exactly one payload value.
+template <typename OpTy>
+class SingleValueMatcherOpTrait
+ : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ // This must be a dynamic assert because interface registration is dynamic.
+ assert(isa<MatchOpInterface>(op) &&
+ "SingleValueMatchOpTrait is only available on operations with "
+ "MatchOpInterface");
+ // The handle to match must be a value handle, not an operation handle.
+ Value operandHandle = cast<OpTy>(op).getOperandHandle();
+ if (!operandHandle.getType().isa<TransformValueHandleTypeInterface>()) {
+ return op->emitError() << "SingleValueMatchOpTrait requires an operand "
+ "of TransformValueHandleTypeInterface";
+ }
+
+ return success();
+ }
+ // Definite failure unless the handle maps to exactly one payload value.
+ DiagnosedSilenceableFailure apply(TransformResults &results,
+ TransformState &state) {
+ Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
+ ValueRange payload = state.getPayloadValues(operandHandle);
+ if (payload.size() != 1) {
+ return emitDefiniteFailure(this->getOperation()->getLoc())
+ << "SingleValueMatchOpTrait requires the value handle to point to "
+ "a single payload value";
+ }
+ // The op-specific matching logic lives in OpTy::matchValue.
+ return cast<OpTy>(this->getOperation())
+ .matchValue(payload[0], results, state);
+ }
+ // Matchers only read operand handles and payload; results are new handles.
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(this->getOperation()->getOperands(), effects);
+ producesHandle(this->getOperation()->getResults(), effects);
+ onlyReadsPayload(effects);
+ }
};
} // namespace transform
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
index fdfce6cd90f52..1f81fd5252eb4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
@@ -14,6 +14,12 @@ def MatchOpInterface
let cppNamespace = "::mlir::transform";
}
+// Trait for "matcher" transform operations that apply to an operation handle
+// associated with exactly one payload operation. Checks that it is indeed
+// the case and produces a definite failure when it is not. The matching logic
+// is implemented in the `matchOperation` function instead of `apply`. The op
+// with this trait must provide a `Value getOperandHandle()` function that
+// returns the handle to be used for matching.
def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";
@@ -24,3 +30,20 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
::mlir::transform::TransformState &state);
}];
}
+
+// Trait for "matcher" transform operations that apply to a value handle
+// associated with exactly one payload value. Verifies that the operand is a
+// value handle and produces a definite failure when it is not associated with
+// exactly one payload value. The matching logic is implemented in the
+// `matchValue` function instead of `apply`. The op with this trait must
+// provide a `Value getOperandHandle()` function returning the handle to match.
+def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+ // Ops with this trait splice this into their `extraClassDeclaration`.
+ string extraDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure matchValue(
+ ::mlir::Value current,
+ ::mlir::transform::TransformResults &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
new file mode 100644
index 0000000000000..3cb935003b4c4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
@@ -0,0 +1,20 @@
+//===- TransformAttrs.h - Transform Dialect Attributes ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+#include <cstdint>
+#include <optional>
+
+#include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
index 652f71434b697..ebad2994880e7 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
@@ -20,4 +20,17 @@ def FailurePropagationMode : I32EnumAttr<
let cppNamespace = "::mlir::transform";
}
+// Predicates for comparing integer transform parameters, as used by
+// `transform.match.param.cmpi`: equal, not equal, less than, less than or
+// equal, greater than, greater than or equal.
+def MatchCmpIPredicateAttr : I32EnumAttr<
+  "MatchCmpIPredicate", "integer comparison predicate for parameter matching",
+  [
+    I32EnumAttrCase<"eq", 0>,
+    I32EnumAttrCase<"ne", 1>,
+    I32EnumAttrCase<"lt", 2>,
+    I32EnumAttrCase<"le", 3>,
+    I32EnumAttrCase<"gt", 4>,
+    I32EnumAttrCase<"ge", 5>,
+  ]> {
+  let cppNamespace = "::mlir::transform";
+}
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index 41f3da05f888b..36712add2eb05 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -243,6 +243,4 @@ class BuildOnly : public DerivedTy {
} // namespace transform
} // namespace mlir
-#include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc"
-
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 177eb0309964b..e2e2354876524 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -617,6 +617,10 @@ class TransformResults {
/// operations.
void setMappedValues(OpResult handle, ArrayRef<MappedValue> values);
+ /// Sets the currently unset results to empty lists of the kind expected by
+ /// the corresponding results of the given `transform` op. Results that have
+ /// already been set are left as is, so this can be called after setting only
+ /// a subset of the results, e.g., when a matcher suppresses a failure.
+ void setRemainingToEmpty(TransformOpInterface transform);
+
private:
/// Creates an instance of TransformResults that expects mappings for
/// `numSegments` values, which may be associated with payload operations or
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 78d2a65e63379..b6bc094d8ba55 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 954fc174b84c2..4bee1d4c6eb2e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -378,6 +378,46 @@ def IncludeOp : TransformDialectOp<"include",
}];
}
+// Matcher checking the name of the single payload operation associated with
+// the operand handle. The `SingleOpMatcher` trait performs the
+// single-payload-operation check and routes the logic to `matchOperation`.
+def MatchOperationNameOp : TransformDialectOp<"match.operation_name",
+ [SingleOpMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> {
+ let summary = "Matches a single operation of one of the given kinds";
+ let description = [{
+ Succeeds if the operation associated with the operand handle has one of the
+ given operation names. Produces a silenceable failure otherwise.
+
+ If more than one payload operation is associated with the operand handle,
+ produces a definite failure.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle,
+ StrArrayAttr:$op_names);
+ let assemblyFormat =
+ "$operand_handle $op_names attr-dict `:` type($operand_handle)";
+ // Declares the `matchOperation` hook required by `SingleOpMatcher`.
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+}
+
+// Compares the co-indexed values of two parameters using a
+// `MatchCmpIPredicate`; usable inside matcher bodies thanks to
+// `MatchOpInterface`.
+def MatchParamCmpIOp : Op<Transform_Dialect, "match.param.cmpi", [
+    DeclareOpInterfaceMethods<TransformOpInterface>,
+    MatchOpInterface,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    SameTypeOperands]> {
+  let summary =
+    "Matches if two parameter lists relate as specified by the predicate";
+  let description = [{
+    Succeeds if all of the co-indexed values associated with the given
+    parameters relate as specified by the predicate (greater than, less than,
+    equal to, or their combinations). Comparison treats all values as signed.
+    Produces a silenceable failure otherwise.
+
+    Both parameters must have the same type.
+  }];
+  let arguments = (ins TransformParamTypeInterface:$param,
+                       TransformParamTypeInterface:$reference,
+                       MatchCmpIPredicateAttr:$predicate);
+  let assemblyFormat =
+    "$predicate $param `,` $reference attr-dict `:` type($param)";
+}
+
def MergeHandlesOp : TransformDialectOp<"merge_handles",
[DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -501,6 +541,24 @@ def SplitHandlesOp : TransformDialectOp<"split_handles",
}];
}
+// Creates a parameter from an attribute literal. It implements
+// `MatchOpInterface`, which allows it to appear among the nested operations of
+// matcher bodies such as `MatchStructuredOp` (whose verifier requires that
+// interface on nested operations).
+def ParamConstantOp : Op<Transform_Dialect, "param.constant", [
+ MatchOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ MemoryEffectsOpInterface,
+ ParamProducerTransformOpTrait]> {
+ let summary = "Produces a new transform dialect parameter value associated "
+ "with the given attribute";
+ let description = [{
+ Produces a new transform dialect parameter associated with the singleton
+ list containing the given attribute. The operation itself always succeeds,
+ but the general association check may fail if the parameter type does not
+ accept the given kind of attribute as valid.
+ }];
+ let arguments = (ins AnyAttr:$value);
+ let results = (outs TransformParamTypeInterface:$param);
+ let assemblyFormat = "$value attr-dict `->` type($param)";
+}
+
def PDLMatchOp : TransformDialectOp<"pdl_match",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
index 9eece0fb3dcfa..2a2ecc287392e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
@@ -68,4 +68,9 @@ class Transform_ConcreteOpType<string opname>
"Transform IR handle to " # opname # " operations",
"::mlir::transform::OperationType">;
+// Type constraint satisfied by any type implementing either the operation
+// handle type interface or the value handle type interface.
+def TransformAnyHandle : Type<
+ Or<[TransformHandleTypeInterface.predicate,
+ TransformValueHandleTypeInterface.predicate]>,
+ "transform operation or value handle">;
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index b7d9812ada0b1..079d585d4ea37 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -1,11 +1,14 @@
add_mlir_dialect_library(MLIRLinalgTransformOps
+ LinalgMatchOps.cpp
LinalgTransformOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/TransformOps
DEPENDS
+ MLIRLinalgMatchOpsIncGen
MLIRLinalgTransformOpsIncGen
+ MLIRLinalgTransformEnumsIncGen
LINK_LIBS PUBLIC
MLIRAffineDialect
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
new file mode 100644
index 0000000000000..e2d895f12f536
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -0,0 +1,826 @@
+//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/FunctionImplementation.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "linalg-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  // Silenceable failures are propagated to the caller unless the op requests
+  // them to be suppressed, in which case this op succeeds and the results that
+  // could not be computed are set to empty lists.
+  bool propagateFailures = getFailurePropagationMode().value_or(
+                               FailurePropagationMode::Propagate) ==
+                           FailurePropagationMode::Propagate;
+
+  // First, check if the payload operation is a structured Linalg operation.
+  if (!isa<linalg::LinalgOp>(current)) {
+    if (propagateFailures)
+      return emitSilenceableError() << "expected a Linalg op";
+    // If errors are suppressed, succeed and set all results to empty lists.
+    LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op\n");
+    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Bind `current` to the block argument.
+  auto scope = state.make_region_scope(getBodyRegion());
+  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
+                                    MappedValue(current)))) {
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  Operation *terminator = getBody()->getTerminator();
+  for (Operation &nested : getBody()->without_terminator()) {
+    DiagnosedSilenceableFailure diag =
+        state.applyTransform(cast<TransformOpInterface>(nested));
+    if (diag.isDefiniteFailure())
+      return diag;
+    if (diag.succeeded())
+      continue;
+
+    // If propagating errors, do this immediately.
+    assert(diag.isSilenceableFailure());
+    if (propagateFailures)
+      return diag;
+
+    // If suppressing errors, print the message into the debug stream before
+    // silencing it. Then set all result values that are already known.
+    // Results come from the terminator operands, which may be defined in the
+    // (single) block of this operation or above it. When they are defined
+    // above, they are known to be mapped at this point per SSA dominance.
+    // When they are defined in this block, we additionally check if we have
+    // already applied the operation that defines them. If not, the
+    // corresponding results will be set to empty lists.
+    LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
+                      << "\n");
+    (void)diag.silence();
+    SmallVector<OpOperand *> undefinedOperands;
+    for (OpOperand &terminatorOperand : terminator->getOpOperands()) {
+      Operation *definingOp = terminatorOperand.get().getDefiningOp();
+      // Block arguments (of this body or of enclosing blocks) are mapped.
+      if (!definingOp)
+        continue;
+      // Values defined above this block are mapped per SSA dominance.
+      if (definingOp->getBlock() != getBody())
+        continue;
+      // Operations preceding the failing one have already been applied.
+      if (definingOp->isBeforeInBlock(&nested))
+        continue;
+
+      undefinedOperands.push_back(&terminatorOperand);
+    }
+
+    SmallVector<SmallVector<transform::MappedValue>> mappings;
+    auto filtered = llvm::make_filter_range(
+        terminator->getOpOperands(), [&](OpOperand &opOperand) {
+          return !llvm::is_contained(undefinedOperands, &opOperand);
+        });
+    SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
+        filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
+    detail::prepareValueMappings(mappings, definedOperands, state);
+    for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
+      results.setMappedValues(getResults()[operand.getOperandNumber()],
+                              mapping);
+    }
+    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Set the results.
+  detail::forwardTerminatorOperands(getBody(), state, results);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// The matched handle and the payload are only read; handles for the outputs
+/// are produced.
+void transform::MatchStructuredOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getCurrent(), effects);
+  onlyReadsPayload(effects);
+  producesHandle(getOutputs(), effects);
+}
+
+/// Verifies that the body has exactly one argument, that this argument is an
+/// operation handle, and that every non-terminator nested operation is a
+/// matcher (implements MatchOpInterface).
+LogicalResult transform::MatchStructuredOp::verify() {
+  Block *body = getBody();
+  if (body->getNumArguments() != 1)
+    return emitOpError() << "expected one body argument";
+
+  Type bodyArgType = body->getArgument(0).getType();
+  if (!isa<TransformHandleTypeInterface>(bodyArgType)) {
+    return emitOpError() << "expected body argument to implement "
+                            "TransformHandleTypeInterface";
+  }
+
+  auto nestedOps = body->without_terminator();
+  auto firstNonMatcher = llvm::find_if(nestedOps, [](Operation &candidate) {
+    return !isa<MatchOpInterface>(candidate);
+  });
+  if (firstNonMatcher == nestedOps.end())
+    return success();
+
+  InFlightDiagnostic diag =
+      emitOpError()
+      << "expects nested operations to implement MatchOpInterface";
+  diag.attachNote(firstNonMatcher->getLoc()) << "offending operation";
+  return diag;
+}
+
+//===----------------------------------------------------------------------===//
+// StructuredOpPredicateOpTrait
+//===----------------------------------------------------------------------===//
+
+/// Verifies that an op carrying the structured predicate trait is nested
+/// directly in a `MatchStructuredOp` and that `structuredOpHandle` is the first
+/// argument of that parent's body, i.e., the structured op being matched.
+LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
+    Operation *op, Value structuredOpHandle) {
+  Operation *parent = op->getParentOp();
+  if (!isa_and_nonnull<MatchStructuredOp>(parent)) {
+    return op->emitOpError() << "expects parent op to be '"
+                             << MatchStructuredOp::getOperationName() << "'";
+  }
+
+  // If the parent body is malformed, bail out and let the parent's own
+  // verifier complain.
+  bool parentHasBodyArgument =
+      parent->getNumRegions() >= 1 && !parent->getRegion(0).empty() &&
+      parent->getRegion(0).front().getNumArguments() >= 1;
+  if (!parentHasBodyArgument)
+    return success();
+
+  Value matchedHandle = parent->getRegion(0).front().getArgument(0);
+  if (structuredOpHandle == matchedHandle)
+    return success();
+  return op->emitOpError()
+         << "expected predicate to apply to the surrounding structured op";
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredBodyOp
+//===----------------------------------------------------------------------===//
+
+/// Checks a property of the payload structured op body: either that the
+/// region output argument at the requested reduction position is combined by
+/// a single operation, or that the body merely yields its input arguments.
+DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast<linalg::LinalgOp>(current);
+  if (std::optional<uint64_t> position = getReductionPosition()) {
+    // The position comes from the transform script while the number of outputs
+    // comes from the payload, so it may be out of range. Reject it here rather
+    // than letting `matchReduction` index past the output arguments (it also
+    // takes the position as `unsigned`, so large values would be truncated).
+    size_t numOutputArgs = linalgOp.getRegionOutputArgs().size();
+    if (*position >= numOutputArgs) {
+      return emitSilenceableError()
+             << "reduction position " << *position
+             << " is out of bounds for the payload op with " << numOutputArgs
+             << " output(s)";
+    }
+    SmallVector<Operation *> combinerOps;
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
+                        combinerOps)) {
+      return emitSilenceableError() << "could not match reduction";
+    }
+    if (combinerOps.size() != 1) {
+      return emitSilenceableError() << "reduction combiner is not a single op";
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+  if (getPassthrough()) {
+    // A passthrough body yields exactly its input block arguments, in order.
+    Block &body = linalgOp->getRegion(0).front();
+    if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
+      return emitSilenceableError() << "not a passthrough";
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+  return emitDefiniteFailure() << "unknown body condition";
+}
+
+/// The reduction and passthrough conditions cannot be requested together.
+LogicalResult transform::MatchStructuredBodyOp::verify() {
+  if (getReductionPosition() && getPassthrough()) {
+    return emitOpError() << "reduction position and passthrough conditions are "
+                            "mutually exclusive";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Utilities for structured match predicates.
+//===----------------------------------------------------------------------===//
+
+/// Returns success if every element of `list` also occurs in `reference`.
+/// Otherwise, returns a silenceable failure at `loc` whose text is `message`
+/// formatted with the first element of `list` missing from `reference`;
+/// `message` must therefore contain the "{0}" placeholder.
+static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
+                                               ArrayRef<int64_t> list,
+                                               Location loc,
+                                               const char *message) {
+  const auto *missing = llvm::find_if(list, [&](int64_t candidate) {
+    return llvm::none_of(reference, [&](unsigned ref) {
+      return static_cast<int64_t>(ref) == candidate;
+    });
+  });
+  if (missing == list.end())
+    return DiagnosedSilenceableFailure::success();
+  return emitSilenceableFailure(loc) << llvm::formatv(message, *missing);
+}
+
+/// Populates `result` with the positional identifiers relative to `maxNumber`.
+/// If `isAll` is set, the result will contain all numbers from `0` to
+/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
+/// values from `rawList` are interpreted as counting backwards from
+/// `maxNumber`, i.e., `-1` is interpreted as `maxNumber - 1`, while
+/// non-negative numbers remain as is. If `isInverted` is set, populates
+/// `result` with those values from the `0` to `maxNumber - 1` inclusive range
+/// that don't appear in `rawList`. If `rawList` contains values that are
+/// greater than or equal to `maxNumber` or less than `-maxNumber`, produces a
+/// silenceable error at the given location. `maxNumber` must be non-negative;
+/// it may be zero (e.g., a payload op without loops or without operands of the
+/// requested kind), in which case `isAll`/`isInverted` produce an empty list
+/// and any explicitly listed position is reported as an error. If `rawList`
+/// contains duplicate numbers or numbers that become duplicate after negative
+/// value remapping, emits a silenceable error.
+static DiagnosedSilenceableFailure
+expandTargetSpecification(Location loc, bool isAll, bool isInverted,
+                          ArrayRef<int64_t> rawList, int64_t maxNumber,
+                          SmallVectorImpl<int64_t> &result) {
+  // `maxNumber` is derived from the payload op, so zero must be handled
+  // gracefully instead of being rejected by an assertion.
+  assert(maxNumber >= 0 && "expected size to be non-negative");
+  assert(!(isAll && isInverted) && "cannot invert all");
+  if (isAll) {
+    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Normalized positions seen so far; used for duplicate detection and, when
+  // inverting, for constant-time membership checks in the complement below.
+  llvm::SmallDenseSet<int64_t> visited;
+  for (int64_t raw : rawList) {
+    int64_t updated = raw < 0 ? maxNumber + raw : raw;
+    if (updated >= maxNumber) {
+      return emitSilenceableFailure(loc)
+             << "position overflow " << updated << " (updated from " << raw
+             << ") for maximum " << maxNumber;
+    }
+    if (updated < 0) {
+      return emitSilenceableFailure(loc) << "position underflow " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    if (!visited.insert(updated).second) {
+      return emitSilenceableFailure(loc) << "repeated position " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    if (!isInverted)
+      result.push_back(updated);
+  }
+
+  if (!isInverted)
+    return DiagnosedSilenceableFailure::success();
+
+  result.reserve(result.size() + (maxNumber - visited.size()));
+  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
+    if (!visited.contains(candidate))
+      result.push_back(candidate);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Checks if the positional specification defined is valid and reports errors
+/// otherwise. `all` is mutually exclusive with both `inverted` and an explicit
+/// list of values; without `all`, an explicit list is required. Listed values
+/// must be unique. A negative and a non-negative value may still denote the
+/// same position after normalization; that can only be detected once the
+/// payload is known and is reported when matching.
+LogicalResult verifyStructuredTransformDimsOp(Operation *op,
+                                              ArrayRef<int64_t> raw,
+                                              bool inverted, bool all) {
+  if (all) {
+    if (inverted) {
+      return op->emitOpError()
+             << "cannot request both 'all' and 'inverted' values in the list";
+    }
+    if (!raw.empty()) {
+      return op->emitOpError()
+             << "cannot both request 'all' and specific values in the list";
+    }
+  }
+  if (!all && raw.empty()) {
+    return op->emitOpError() << "must request specific values in the list if "
+                                "'all' is not specified";
+  }
+  // `std::unique` only collapses adjacent duplicates of an unsorted list, so
+  // track all previously seen values to also catch lists such as [0, 1, 0].
+  llvm::SmallDenseSet<int64_t> seen;
+  for (int64_t value : raw) {
+    if (!seen.insert(value).second)
+      return op->emitOpError() << "expected the listed values to be unique";
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredDimOp
+//===----------------------------------------------------------------------===//
+
+/// Selects loop dimensions of the payload structured op, optionally checks
+/// that all of them are parallel or all are reductions, and, if a result is
+/// requested, captures their static sizes as parameters.
+DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast<linalg::LinalgOp>(current);
+  SmallVector<int64_t> dimensions;
+  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
+  if (!diag.succeeded())
+    return diag;
+
+  // If asked to check for the kind of dimension, perform the check. The
+  // verifier guarantees that at most one of the two kinds is requested.
+  if (getParallel() || getReduction()) {
+    bool checkParallel = getParallel();
+    SmallVector<unsigned> reference;
+    if (checkParallel)
+      linalgOp.getParallelDims(reference);
+    else
+      linalgOp.getReductionDims(reference);
+
+    DiagnosedSilenceableFailure kindDiag =
+        containsAll(reference, dimensions, getLoc(),
+                    checkParallel ? "expects dimension #{0} to be parallel"
+                                  : "expects dimension #{0} to be reduction");
+    if (!kindDiag.succeeded())
+      return kindDiag;
+  }
+
+  // If not capturing, we are done here.
+  if (!getResult())
+    return DiagnosedSilenceableFailure::success();
+
+  // Capture the sizes reported by `getStaticLoopRanges`; dynamic sizes are
+  // presumably captured as the dynamic-size sentinel value (TODO confirm).
+  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
+  Builder builder(current);
+  SmallVector<Attribute> captured = llvm::to_vector(
+      llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
+        return builder.getI64IntegerAttr(ranges[dim]);
+      }));
+  results.setParams(cast<OpResult>(getResult()), captured);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Expands the dimension specification of this op into concrete loop
+/// dimensions of `op`. On silenceable failure, attaches a note pointing at the
+/// payload operation.
+DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
+    linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
+  int64_t numLoops = op.getNumLoops();
+  DiagnosedSilenceableFailure diag = expandTargetSpecification(
+      getLoc(), getIsAll(), getIsInverted(), getRawDimList(), numLoops, dims);
+  if (!diag.isSilenceableFailure())
+    return diag;
+  diag.attachNote(op->getLoc())
+      << "while considering dimensions of this payload operation";
+  return diag;
+}
+
+/// Verifies that the parallel and reduction checks are not both requested and
+/// that the dimension specification is well-formed.
+LogicalResult transform::MatchStructuredDimOp::verify() {
+  bool requestsBothKinds = getParallel() && getReduction();
+  if (requestsBothKinds) {
+    return emitOpError() << "cannot request the same dimension to be both "
+                            "parallel and reduction";
+  }
+  return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
+                                         getIsInverted(), getIsAll());
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredElementalBitwidthOp
+//===----------------------------------------------------------------------===//
+
+/// Captures the elemental bitwidth of `current`'s type as a parameter: the
+/// bitwidth of the type itself for scalar integers/floats, or of the element
+/// type for shaped types with integer/float elements. Any other type produces
+/// a silenceable failure.
+DiagnosedSilenceableFailure
+transform::MatchStructuredElementalBitwidthOp::matchValue(
+    Value current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  Type type = current.getType();
+  std::optional<int64_t> bitwidth;
+  if (type.isIntOrFloat()) {
+    bitwidth = type.getIntOrFloatBitWidth();
+  } else if (auto shapedType = dyn_cast<ShapedType>(type)) {
+    if (shapedType.getElementType().isIntOrFloat())
+      bitwidth = shapedType.getElementTypeBitWidth();
+  }
+
+  if (!bitwidth) {
+    return emitSilenceableError()
+           << "unsupported type for bitwidth extraction: " << type;
+  }
+
+  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(*bitwidth);
+  results.setParams(cast<OpResult>(getResult()), {attr});
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredInputOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the indexing maps of the selected DPS inputs of the payload op. If a
+/// result is requested, captures either the input values themselves (for value
+/// handle results) or the operations producing them (for operation handles).
+DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast<linalg::LinalgOp>(current);
+  SmallVector<int64_t> positions;
+  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
+  if (!diag.succeeded())
+    return diag;
+
+  bool captureValues =
+      getResult() &&
+      isa<TransformValueHandleTypeInterface>(getResult().getType());
+  SmallVector<MappedValue> captured;
+  captured.reserve(positions.size());
+  for (int64_t position : positions) {
+    OpOperand *input = linalgOp.getDpsInputOperand(position);
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(input);
+    if (getPermutation() && !indexingMap.isPermutation()) {
+      return emitSilenceableError() << "the indexing map for input #"
+                                    << position << " is not a permutation";
+    }
+    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
+      return emitSilenceableError()
+             << "the indexing map for input #" << position
+             << " is not a projected permutation";
+    }
+
+    // Nothing to capture when no result is requested.
+    if (!getResult())
+      continue;
+
+    Value operand = input->get();
+    if (captureValues) {
+      captured.emplace_back(operand);
+      continue;
+    }
+
+    Operation *producer = operand.getDefiningOp();
+    if (!producer) {
+      return emitSilenceableError()
+             << "input #" << position << " is not produced by an operation";
+    }
+    captured.emplace_back(producer);
+  }
+  if (getResult())
+    results.setMappedValues(cast<OpResult>(getResult()), captured);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Expands the position specification of this op into indices of DPS inputs
+/// of `op`. On silenceable failure, attaches a note pointing at the payload
+/// operation.
+DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
+    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
+  int64_t numInputs = op.getNumDpsInputs();
+  DiagnosedSilenceableFailure diag =
+      expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
+                                getRawPositionList(), numInputs, positions);
+  if (!diag.isSilenceableFailure())
+    return diag;
+  diag.attachNote(op->getLoc())
+      << "while considering DPS inputs of this payload operation";
+  return diag;
+}
+
+/// Verifies a matcher op for structured input or output, specifically the
+/// attributes specifying the operand positions. The permutation and projected
+/// permutation checks cannot be combined, and a result cannot be bound when
+/// more than one position is listed explicitly.
+template <typename OpTy>
+LogicalResult verifyStructuredOperandOp(OpTy op) {
+  bool requestsBothPermutationKinds =
+      op.getPermutation() && op.getProjectedPermutation();
+  if (requestsBothPermutationKinds) {
+    return op.emitOpError()
+           << op.getPermutationAttrName() << " and "
+           << op.getProjectedPermutationAttrName() << " are mutually exclusive";
+  }
+
+  bool bindsSeveralOperands =
+      op.getRawPositionList().size() > 1 && op.getResult();
+  if (bindsSeveralOperands) {
+    return op.emitOpError()
+           << "cannot bind multiple inputs/inits to the same value";
+  }
+  return success();
+}
+
+/// Verifies the operand-related attributes, then the positional specification.
+LogicalResult transform::MatchStructuredInputOp::verify() {
+  if (succeeded(verifyStructuredOperandOp(*this))) {
+    return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
+                                           getIsInverted(), getIsAll());
+  }
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredInitOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the indexing maps of the selected DPS inits (outputs) of the payload
+/// op. If a result is requested, captures either the init values themselves
+/// (for value handle results) or the operations producing them (for operation
+/// handles).
+DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast<linalg::LinalgOp>(current);
+  SmallVector<int64_t> positions;
+  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
+  if (!diag.succeeded())
+    return diag;
+
+  SmallVector<MappedValue> operandMapping;
+  operandMapping.reserve(positions.size());
+  for (int64_t position : positions) {
+    OpOperand *init = linalgOp.getDpsInitOperand(position);
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(init);
+    if (getPermutation() && !indexingMap.isPermutation()) {
+      return emitSilenceableError() << "the indexing map for output(init) #"
+                                    << position << " is not a permutation";
+    }
+    // Report the property that was actually checked, as the input matcher does.
+    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
+      return emitSilenceableError()
+             << "the indexing map for output(init) #" << position
+             << " is not a projected permutation";
+    }
+
+    // If capture not requested, skip it.
+    if (!getResult())
+      continue;
+
+    Value operand = init->get();
+    if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
+      operandMapping.emplace_back(operand);
+      continue;
+    }
+
+    Operation *operandProducer = operand.getDefiningOp();
+    if (!operandProducer) {
+      return emitSilenceableError() << "output(init) #" << position
+                                    << " is not produced by an operation";
+    }
+    operandMapping.emplace_back(operandProducer);
+  }
+  if (getResult())
+    results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Expands the position specification of this op into indices of DPS inits
+/// (outputs) of `op`. On silenceable failure, attaches a note pointing at the
+/// payload operation.
+DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
+    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
+  int64_t numInits = op.getNumDpsInits();
+  DiagnosedSilenceableFailure diag =
+      expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
+                                getRawPositionList(), numInits, positions);
+  if (!diag.isSilenceableFailure())
+    return diag;
+  diag.attachNote(op->getLoc())
+      << "while considering DPS inits (outputs) of this payload operation";
+  return diag;
+}
+
+/// Verifies the operand-related attributes, then the positional specification.
+LogicalResult transform::MatchStructuredInitOp::verify() {
+  if (succeeded(verifyStructuredOperandOp(*this))) {
+    return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
+                                           getIsInverted(), getIsAll());
+  }
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredNumInputsOp
+//===----------------------------------------------------------------------===//
+
+/// Captures the number of DPS inputs of the payload structured op as an i64
+/// parameter.
+DiagnosedSilenceableFailure
+transform::MatchStructuredNumInputsOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  int64_t numInputs = cast<linalg::LinalgOp>(current).getNumDpsInputs();
+  Builder builder(current);
+  Attribute countAttr = builder.getI64IntegerAttr(numInputs);
+  results.setParams(cast<OpResult>(getResult()), {countAttr});
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredNumInitsOp
+//===----------------------------------------------------------------------===//
+
+/// Captures the number of DPS inits (outputs) of the payload structured op as
+/// an i64 parameter.
+DiagnosedSilenceableFailure
+transform::MatchStructuredNumInitsOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  int64_t numInits = cast<linalg::LinalgOp>(current).getNumDpsInits();
+  Builder builder(current);
+  Attribute countAttr = builder.getI64IntegerAttr(numInits);
+  results.setParams(cast<OpResult>(getResult()), {countAttr});
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredRankOp
+//===----------------------------------------------------------------------===//
+
+/// Captures the number of loops of the payload structured op (the rank of its
+/// iteration space) as an i64 parameter.
+DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast<linalg::LinalgOp>(current);
+  Builder builder(linalgOp->getContext());
+  Attribute rankAttr =
+      builder.getI64IntegerAttr(static_cast<int64_t>(linalgOp.getNumLoops()));
+  results.setParams(cast<OpResult>(getRank()), {rankAttr});
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredResultOp
+//===----------------------------------------------------------------------===//
+
+/// Selects the payload op result tied to the init at the requested position.
+/// For a value handle result, captures that value. For an operation handle
+/// result, captures a user of that value: the first one with `any`, or the
+/// only one with `single` (failing silenceably if there are several).
+DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
+    Operation *op, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  int64_t position;
+  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
+  if (!diag.succeeded())
+    return diag;
+
+  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
+  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
+    results.setValues(cast<OpResult>(getResult()), result);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  auto users = result.getUsers();
+  if (users.empty()) {
+    return emitSilenceableError()
+           << "no users of the result #" << getPosition();
+  }
+  Operation *firstUser = *users.begin();
+  if (getAny()) {
+    results.set(cast<OpResult>(getResult()), firstUser);
+    return DiagnosedSilenceableFailure::success();
+  }
+  if (getSingle()) {
+    if (!llvm::hasSingleElement(users)) {
+      return emitSilenceableError()
+             << "more than one result user with single user requested";
+    }
+    results.set(cast<OpResult>(getResult()), firstUser);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  return emitDefiniteFailure() << "unknown sub-predicate";
+}
+
+/// Normalizes the op's `position` against the payload's DPS inits: negative
+/// values count backwards from the end. On success `position` holds the
+/// normalized index; otherwise a silenceable error is returned.
+DiagnosedSilenceableFailure
+transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
+                                                   int64_t &position) {
+  int64_t numInits = op.getNumDpsInits();
+  auto rawPosition = static_cast<int64_t>(getPosition());
+  position = rawPosition >= 0 ? rawPosition : numInits + rawPosition;
+  bool inBounds = position >= 0 && position < numInits;
+  if (inBounds)
+    return DiagnosedSilenceableFailure::success();
+  // NOTE(review): "results(ints)" looks like a typo for "results (inits)";
+  // kept verbatim in case tests match this message.
+  return emitSilenceableError()
+         << "position " << rawPosition
+         << " overflows the number of results(ints) of the payload operation";
+}
+
+LogicalResult transform::MatchStructuredResultOp::verify() {
+  // The `any`/`single` keywords must be present exactly when the result is an
+  // op handle; without them the result must not be an op handle.
+  bool hasUserSelector = getAny() || getSingle();
+  bool returnsOpHandle =
+      getResult().getType().isa<TransformHandleTypeInterface>();
+  if (hasUserSelector != returnsOpHandle) {
+    return emitOpError() << "expects either the any/single keyword or the type "
+                            "value handle result type";
+  }
+  if (getAny() && getSingle())
+    return emitOpError() << "'any' and 'single' are mutually exclusive";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredYieldOp
+//===----------------------------------------------------------------------===//
+
+/// The yield only reads the handles it forwards and only reads the payload.
+void transform::MatchStructuredYieldOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getHandles(), effects);
+  onlyReadsPayload(effects);
+}
+
+/// Builds a yield that forwards no handles.
+void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
+                                              OperationState &state) {
+  build(builder, state, ValueRange{});
+}
+
+//===----------------------------------------------------------------------===//
+// Printing and parsing for structured match ops.
+//===----------------------------------------------------------------------===//
+
+/// Keyword introducing an inverted positional specification: `except(...)`.
+static constexpr llvm::StringLiteral kDimExceptKeyword = "except";
+
+/// Keyword selecting every dimension in a positional specification.
+static constexpr llvm::StringLiteral kDimAllKeyword = "all";
+
+/// Parses a positional specification for structured transform operations.
+/// Accepted forms:
+///
+/// - `all`: sets `isAll` and yields an empty `rawDimList`;
+/// - comma-separated-integer-list: stores the integers in `rawDimList`;
+/// - `except` `(` comma-separated-integer-list `)`: stores the integers in
+///   `rawDimList` and sets `isInverted`.
+static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
+                                                DenseI64ArrayAttr &rawDimList,
+                                                UnitAttr &isInverted,
+                                                UnitAttr &isAll) {
+  Builder &builder = parser.getBuilder();
+  isAll = nullptr;
+  isInverted = nullptr;
+
+  // `all` stands alone and carries no explicit list.
+  if (succeeded(parser.parseOptionalKeyword(kDimAllKeyword))) {
+    rawDimList = builder.getDenseI64ArrayAttr({});
+    isAll = builder.getUnitAttr();
+    return success();
+  }
+
+  bool inverted = succeeded(parser.parseOptionalKeyword(kDimExceptKeyword));
+  if (inverted) {
+    isInverted = builder.getUnitAttr();
+    if (failed(parser.parseLParen()))
+      return failure();
+  }
+
+  SmallVector<int64_t> values;
+  if (failed(parser.parseCommaSeparatedList(
+          [&]() { return parser.parseInteger(values.emplace_back()); })))
+    return failure();
+  rawDimList = builder.getDenseI64ArrayAttr(values);
+
+  if (inverted && failed(parser.parseRParen()))
+    return failure();
+  return success();
+}
+
+/// Prints a positional specification for structured transform operations in
+/// one of the forms `all`, `a, b, ...`, or `except(a, b, ...)`.
+static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
+                                         DenseI64ArrayAttr rawDimList,
+                                         UnitAttr isInverted, UnitAttr isAll) {
+  if (isAll) {
+    printer << kDimAllKeyword;
+    return;
+  }
+
+  // Inverted lists are wrapped as `except(...)`; plain lists are bare.
+  bool wrap = static_cast<bool>(isInverted);
+  if (wrap)
+    printer << kDimExceptKeyword << "(";
+  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
+                        [&](int64_t dim) { printer << dim; });
+  if (wrap)
+    printer << ")";
+}
+
+/// Parses either a bare type or a function-like type with exactly one
+/// argument. This allows for the following syntax:
+///
+/// - type: just the argument type;
+/// - `(` type `)` `->` type: one argument and one result type;
+/// - `(` type `)` `->` `(` comma-separated-type-list `)`: one argument and
+///   multiple result types.
+///
+/// Unlike FunctionType, this allows and requires one to omit the parens around
+/// the argument type in absence of result types, and does not accept the
+/// trailing `-> ()` construct, which makes the syntax nicer for operations.
+///
+/// This overload reads a single type after `->`; `resultType` stays null when
+/// the bare form is parsed.
+static ParseResult parseSemiFunctionType(OpAsmParser &parser,
+                                         Type &argumentType, Type &resultType) {
+  argumentType = nullptr;
+  resultType = nullptr;
+  bool isFunctionForm = succeeded(parser.parseOptionalLParen());
+  if (failed(parser.parseType(argumentType)))
+    return failure();
+  if (isFunctionForm) {
+    if (failed(parser.parseRParen()) || failed(parser.parseArrow()) ||
+        failed(parser.parseType(resultType)))
+      return failure();
+  }
+  return success();
+}
+/// Multi-result variant of parseSemiFunctionType. After `->` it accepts either
+/// one bare type or a parenthesized comma-separated type list. On failure
+/// inside the parenthesized list, `resultTypes` is cleared.
+static ParseResult parseSemiFunctionType(OpAsmParser &parser,
+                                         Type &argumentType,
+                                         SmallVectorImpl<Type> &resultTypes) {
+  argumentType = nullptr;
+  bool isFunctionForm = succeeded(parser.parseOptionalLParen());
+  if (failed(parser.parseType(argumentType)))
+    return failure();
+  if (!isFunctionForm)
+    return success();
+  if (failed(parser.parseRParen()) || failed(parser.parseArrow()))
+    return failure();
+
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (failed(parser.parseTypeList(resultTypes)) ||
+        failed(parser.parseRParen())) {
+      resultTypes.clear();
+      return failure();
+    }
+    return success();
+  }
+
+  Type singleResult;
+  if (failed(parser.parseType(singleResult)))
+    return failure();
+  resultTypes.push_back(singleResult);
+  return success();
+}
+
+/// Prints argument and result types in a syntax similar to that of FunctionType
+/// but allowing and requiring one to omit the parens around the argument type
+/// in absence of result types, and without the trailing `-> ()`. A single
+/// result type is printed without parentheses.
+static void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
+                                  Type argumentType, TypeRange resultType) {
+  if (resultType.empty()) {
+    printer << argumentType;
+    return;
+  }
+
+  printer << "(" << argumentType << ") -> ";
+  bool parenthesizeResults = resultType.size() > 1;
+  if (parenthesizeResults)
+    printer << "(";
+  llvm::interleaveComma(resultType, printer.getStream());
+  if (parenthesizeResults)
+    printer << ")";
+}
+/// Single-result convenience overload; a null `resultType` means no results.
+static void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
+                                  Type argumentType, Type resultType) {
+  TypeRange results = resultType ? TypeRange(resultType) : TypeRange();
+  printSemiFunctionType(printer, op, argumentType, results);
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a33abe9a508c6..2970d3476d512 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3257,6 +3257,10 @@ class LinalgTransformDialectExtension
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
+ >();
}
};
} // namespace
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index 64619309d5d86..2fed20f927380 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTransformDialect
MatchInterfaces.cpp
+ TransformAttrs.cpp
TransformDialect.cpp
TransformInterfaces.cpp
TransformOps.cpp
diff --git a/mlir/lib/Dialect/Transform/IR/TransformAttrs.cpp b/mlir/lib/Dialect/Transform/IR/TransformAttrs.cpp
new file mode 100644
index 0000000000000..1549055e38ad9
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/TransformAttrs.cpp
@@ -0,0 +1,12 @@
+//===- TransformAttrs.cpp - Transform Dialect Attribute Definitions -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index d4578e0648179..5d29e0bd3b3d9 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -186,5 +186,3 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
return emitError(op->getLoc())
<< "unknown attribute: " << attribute.getName();
}
-
-#include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 046dcb5bbac87..0fe1801f629ed 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -840,19 +840,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// If a silenceable failure was produced, some results may be unset, set them
// to empty lists.
- if (result.isSilenceableFailure()) {
- for (OpResult opResult : transform->getResults()) {
- if (results.isSet(opResult.getResultNumber()))
- continue;
-
- if (opResult.getType().isa<TransformParamTypeInterface>())
- results.setParams(opResult, {});
- else if (opResult.getType().isa<TransformValueHandleTypeInterface>())
- results.setValues(opResult, {});
- else
- results.set(opResult, {});
- }
- }
+ if (result.isSilenceableFailure())
+ results.setRemainingToEmpty(transform);
// Remove the mapping for the operand if it is consumed by the operation. This
// allows us to catch use-after-free with assertions later on.
@@ -1058,6 +1047,14 @@ void transform::TransformResults::setMappedValues(
(void)diag.silence();
}
+/// Maps every result of `transform` that has not been set yet to an empty
+/// list of payload entities.
+void transform::TransformResults::setRemainingToEmpty(
+    transform::TransformOpInterface transform) {
+  for (OpResult opResult : transform->getResults()) {
+    if (isSet(opResult.getResultNumber()))
+      continue;
+    setMappedValues(opResult, {});
+  }
+}
+
ArrayRef<Operation *>
transform::TransformResults::get(unsigned resultNumber) const {
assert(resultNumber < operations.size() &&
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 0e20999a6bdbb..6284b41302e2d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -531,8 +532,8 @@ transform::ForeachMatchOp::apply(transform::TransformResults &results,
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
- DEBUG_MATCHER(DBGS_MATCHER()
- << "matcher " << matcher.getName() << " failed\n");
+ DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage());
continue;
}
@@ -1171,6 +1172,118 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
.checkAndReport();
}
+//===----------------------------------------------------------------------===//
+// MatchOperationNameOp
+//===----------------------------------------------------------------------===//
+
+/// Succeeds if the payload op's name equals one of the accepted op names;
+/// otherwise produces a silenceable error.
+DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  StringRef currentOpName = current->getName().getStringRef();
+  auto acceptedNames = getOpNames().getAsRange<StringAttr>();
+  for (StringAttr accepted : acceptedNames)
+    if (accepted.getValue() == currentOpName)
+      return DiagnosedSilenceableFailure::success();
+  return emitSilenceableError() << "wrong operation name";
+}
+
+//===----------------------------------------------------------------------===//
+// MatchParamCmpIOp
+//===----------------------------------------------------------------------===//
+
+/// Compares the values associated with `param` and `reference` pairwise,
+/// using the predicate with signed integer semantics. Produces a silenceable
+/// error if the two params hold different numbers of values or if any pair
+/// fails the predicate. Produces a definite failure if a value is not an
+/// IntegerAttr or if the two values of a pair have different types.
+DiagnosedSilenceableFailure
+transform::MatchParamCmpIOp::apply(transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  // Renders an integer in signed decimal form for diagnostics.
+  auto signedAPIntAsString = [](const APInt &value) {
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    value.print(os, /*isSigned=*/true);
+    return os.str();
+  };
+
+  ArrayRef<Attribute> params = state.getParams(getParam());
+  ArrayRef<Attribute> references = state.getParams(getReference());
+
+  // Values are compared pairwise, so both params must hold the same number
+  // of values before zipping them below.
+  if (params.size() != references.size()) {
+    return emitSilenceableError()
+           << "parameters have different payload lengths (" << params.size()
+           << " vs " << references.size() << ")";
+  }
+
+  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
+    auto intAttr = param.dyn_cast<IntegerAttr>();
+    auto refAttr = reference.dyn_cast<IntegerAttr>();
+    if (!intAttr || !refAttr) {
+      return emitDefiniteFailure()
+             << "non-integer parameter value not expected";
+    }
+    // Requiring identical types keeps the APInt operands the same bit width.
+    if (intAttr.getType() != refAttr.getType()) {
+      return emitDefiniteFailure()
+             << "mismatching integer attribute types in parameter #" << i;
+    }
+    APInt value = intAttr.getValue();
+    APInt refValue = refAttr.getValue();
+
+    // Lambdas cannot capture structured bindings before C++20, so copy the
+    // index into a plain variable for use in `reportError`.
+    int64_t position = i;
+    auto reportError = [&](StringRef direction) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "expected parameter to be " << direction
+                                 << " " << signedAPIntAsString(refValue)
+                                 << ", got " << signedAPIntAsString(value);
+      diag.attachNote(getParam().getLoc())
+          << "value # " << position
+          << " associated with the parameter defined here";
+      return diag;
+    };
+
+    // All ordering predicates use signed comparisons.
+    switch (getPredicate()) {
+    case MatchCmpIPredicate::eq:
+      if (value.eq(refValue))
+        break;
+      return reportError("equal to");
+    case MatchCmpIPredicate::ne:
+      if (value.ne(refValue))
+        break;
+      return reportError("not equal to");
+    case MatchCmpIPredicate::lt:
+      if (value.slt(refValue))
+        break;
+      return reportError("less than");
+    case MatchCmpIPredicate::le:
+      if (value.sle(refValue))
+        break;
+      return reportError("less than or equal to");
+    case MatchCmpIPredicate::gt:
+      if (value.sgt(refValue))
+        break;
+      return reportError("greater than");
+    case MatchCmpIPredicate::ge:
+      if (value.sge(refValue))
+        break;
+      return reportError("greater than or equal to");
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Both compared params are only read; `param` is listed before `reference`.
+void transform::MatchParamCmpIOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  Value readHandles[] = {getParam(), getReference()};
+  onlyReadsHandle(readHandles, effects);
+}
+
+//===----------------------------------------------------------------------===//
+// ParamConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Associates the result param with a single value: the attribute returned by
+/// getValue().
+DiagnosedSilenceableFailure
+transform::ParamConstantOp::apply(transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  Attribute constant = getValue();
+  results.setParams(cast<OpResult>(getParam()), {constant});
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// MergeHandlesOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
new file mode 100644
index 0000000000000..a8678a766f6b5
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -0,0 +1,754 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{debug-payload-root-tag=start_here})" --split-input-file --verify-diagnostics
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @print_structured(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "structured" : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @match_structured_empty(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ %0 = transform.match.structured %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+
+ // Entry point. Match any structured operation and emit a remark.
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0
+ @match_structured_empty -> @print_structured
+ : (!transform.any_op) -> !transform.any_op
+ }
+
+ func.func @payload() attributes { transform.target_tag = "start_here" } {
+ %preA = tensor.empty() : tensor<2x3xf32>
+ %cA = arith.constant 1.0 : f32
+ // expected-remark @below {{structured}}
+ %A = linalg.fill ins(%cA : f32) outs(%preA : tensor<2x3xf32>) -> tensor<2x3xf32>
+
+ %B = arith.constant dense<1.0> : tensor<3x4xf32>
+ %C = arith.constant dense<1000.0> : tensor<2x4xf32>
+ // expected-remark @below {{structured}}
+ %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
+ outs(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
+
+ %E = arith.constant dense<2.0> : tensor<2x4xf32>
+ // expected-remark @below {{structured}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%D : tensor<2x4xf32>) outs(%E : tensor<2x4xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ linalg.yield %arg0 : f32
+ } -> tensor<2x4xf32>
+
+ return
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) {
+ transform.yield
+ }
+
+ // Matcher used by the entry point below. Match any structured operation and
+ // emit a remark. Also emit a different remark at all considered operations.
+ // When the match fails, the failure is suppressed and the resulting handle is
+ // associated with an empty list, so no "structured" remark is printed for
+ // that operation. Both remark printing operations happen after the check in
+ // the sequence, so they only apply if the check operation produced success
+ // (due to failure suppression or not).
+ transform.named_sequence @match_structured_suppress(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ %0 = transform.match.structured failures(suppress) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.test_print_remark_at_operand %0, "structured" : !transform.any_op
+ transform.test_print_remark_at_operand %arg0, "other" : !transform.any_op
+ transform.yield %0 : !transform.any_op
+ }
+
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0
+ @match_structured_suppress -> @do_nothing
+ : (!transform.any_op) -> !transform.any_op
+ }
+
+ func.func @payload() attributes { transform.target_tag = "start_here" } {
+ // expected-remark @below {{other}}
+ %D = arith.constant dense<1.0> : tensor<2x4xf32>
+ // expected-remark @below {{other}}
+ %E = arith.constant dense<2.0> : tensor<2x4xf32>
+ // expected-remark @below {{structured}}
+ // expected-remark @below {{other}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%D : tensor<2x4xf32>) outs(%E : tensor<2x4xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ // expected-remark @below {{other}}
+ linalg.yield %arg0 : f32
+ } -> tensor<2x4xf32>
+
+ // expected-remark @below {{other}}
+ return
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @print_passthrough(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "passthrough" : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @match_structured_body_passthrough(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.body %arg1 { passthrough } : !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0
+ @match_structured_body_passthrough -> @print_passthrough
+ : (!transform.any_op) -> !transform.any_op
+ }
+
+ func.func @payload(%in: tensor<2xf32>, %out: tensor<2xf32>) attributes { transform.target_tag = "start_here" } {
+ // expected-remark @below {{passthrough}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ linalg.yield %arg0 : f32
+ } -> tensor<2xf32>
+
+ linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ %0 = arith.mulf %arg0, %arg1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<2xf32>
+
+ // expected-remark @below {{passthrough}}
+ linalg.copy ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) -> tensor<2xf32>
+
+ return
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "reduction" : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @match_structured_body_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.body %arg1 { reduction_position = 0 } : !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0
+ @match_structured_body_reduction -> @print_reduction
+ : (!transform.any_op) -> !transform.any_op
+ }
+
+ func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
+ // expected-remark @below {{reduction}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ %0 = arith.mulf %arg0, %arg1 : f32
+ %1 = arith.addf %0, %arg2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<2x3xf32>
+
+ %r = tensor.empty() : tensor<2x3xf32>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out, %r: tensor<2x3xf32>, tensor<2x3xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
+ %0 = arith.mulf %arg0, %arg1 : f32
+ %1 = arith.cmpf olt, %0, %arg2 : f32
+ %2 = arith.select %1, %0, %arg2 : f32
+ %3 = arith.select %1, %arg3, %0 : f32
+ linalg.yield %2, %3 : f32, f32
+ } -> (tensor<2x3xf32>, tensor<2x3xf32>)
+
+ // expected-remark @below {{reduction}}
+ linalg.matmul ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) -> tensor<2x3xf32>
+
+ %e = tensor.empty() : tensor<2x4xf32>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%lhs: tensor<2x4xf32>) outs(%e: tensor<2x4xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ linalg.yield %arg0 : f32
+ } -> tensor<2x4xf32>
+
+ return
+ }
+}
+
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) {
+ transform.yield
+ }
+
+ transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ // Capture multiple dimension values. Suppress failures so we can print them anyway after the capture.
+ %0:9 = transform.match.structured failures(suppress) %arg0
+ : (!transform.any_op) -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
+ !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+ ^bb0(%arg1: !transform.any_op):
+ // This also tests the positional specification used by other ops, which may not test it again.
+ %1 = transform.match.structured.dim %arg1[all] : (!transform.any_op) -> !transform.param<i64>
+ %2 = transform.match.structured.dim %arg1[0] : (!transform.any_op) -> !transform.param<i64>
+ %3 = transform.match.structured.dim %arg1[-1] : (!transform.any_op) -> !transform.param<i64>
+ %4 = transform.match.structured.dim %arg1[0, 2] : (!transform.any_op) -> !transform.param<i64>
+ %5 = transform.match.structured.dim %arg1[0, -1] : (!transform.any_op) -> !transform.param<i64>
+ %6 = transform.match.structured.dim %arg1[except(-1)] : (!transform.any_op) -> !transform.param<i64>
+ %7 = transform.match.structured.dim %arg1[except(0, -2)] : (!transform.any_op) -> !transform.param<i64>
+ %8 = transform.match.structured.dim %arg1[0, -3] : (!transform.any_op) -> !transform.param<i64>
+ transform.match.structured.yield %arg1, %1, %2, %3, %4, %5, %6, %7, %8
+ : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
+ !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+ }
+ transform.test_print_param %0#1, "dimensions all:" at %0#0 : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %0#2, "dimension 0:" at %0#0 : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %0#3, "dimension -1:" at %0#0 : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %0#4, "dimensions 0, 2:" at %0#0 : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %0#5, "dimensions 0, -1:" at %0#0 : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %0#6, "dimensions except -1:" at %0#0 : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %0#7, "dimensions except 0, -2:" at %0#0 : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %0#8, "dimensions 0, -3:" at %0#0 : !transform.param<i64>, !transform.any_op
+ transform.yield %0#0 : !transform.any_op
+ }
+
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ }
+
+ func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
+ // The last capture prints no values because its match fails: in the 3D
+ // case, 0 and -3 designate the same dimension.
+ // expected-remark @below {{dimensions all: 2 : i64, 3 : i64, 4 : i64}}
+ // expected-remark @below {{dimension 0: 2 : i64}}
+ // expected-remark @below {{dimension -1: 4 : i64}}
+ // expected-remark @below {{dimensions 0, 2: 2 : i64, 4 : i64}}
+ // expected-remark @below {{dimensions 0, -1: 2 : i64, 4 : i64}}
+ // expected-remark @below {{dimensions except -1: 2 : i64, 3 : i64}}
+ // expected-remark @below {{dimensions except 0, -2: 4 : i64}}
+ // expected-remark @below {{dimensions 0, -3:}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ %0 = arith.mulf %arg0, %arg1 : f32
+ %1 = arith.addf %0, %arg2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<2x3xf32>
+
+ return
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @print_all_reduction(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "all reduction" : !transform.any_op
+ transform.yield
+ }
+ transform.named_sequence @print_all_parallel(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "all parallel" : !transform.any_op
+ transform.yield
+ }
+ transform.named_sequence @print_last_reduction(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "last reduction" : !transform.any_op
+ transform.yield
+ }
+ transform.named_sequence @print_parallel_except_last(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "parallel except last" : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @match_all_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.structured failures(propagate) %arg0 : !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.dim %arg1[all] { reduction } : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield %arg0 : !transform.any_op
+ }
+ transform.named_sequence @match_all_parallel(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.structured failures(propagate) %arg0 : !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.dim %arg1[all] { parallel } : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield %arg0 : !transform.any_op
+ }
+ transform.named_sequence @match_last_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.structured failures(propagate) %arg0 : !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.dim %arg1[-1] { reduction } : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield %arg0 : !transform.any_op
+ }
+ transform.named_sequence @match_parallel_except_last(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.structured failures(propagate) %arg0 : !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.dim %arg1[except(-1)] { parallel } : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.foreach_match in %arg0 @match_all_reduction -> @print_all_reduction : (!transform.any_op) -> !transform.any_op
+ %1 = transform.foreach_match in %0 @match_all_parallel -> @print_all_parallel : (!transform.any_op) -> !transform.any_op
+ %2 = transform.foreach_match in %1 @match_last_reduction -> @print_last_reduction : (!transform.any_op) -> !transform.any_op
+ %3 = transform.foreach_match in %2 @match_parallel_except_last -> @print_parallel_except_last : (!transform.any_op) -> !transform.any_op
+ }
+
+ func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
+ // expected-remark @below {{last reduction}}
+ // expected-remark @below {{parallel except last}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ %0 = arith.mulf %arg0, %arg1 : f32
+ %1 = arith.addf %0, %arg2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<2x3xf32>
+
+ // expected-remark @below {{last reduction}}
+ // expected-remark @below {{parallel except last}}
+ linalg.matmul ins(%lhs, %rhs : tensor<2x4xf32>, tensor<4x3xf32>) outs(%out : tensor<2x3xf32>) -> tensor<2x3xf32>
+
+ %cst = arith.constant 1.0 : f32
+ // expected-remark @below {{all parallel}}
+ // expected-remark @below {{parallel except last}}
+ linalg.fill ins(%cst : f32) outs(%out: tensor<2x3xf32>) -> tensor<2x3xf32>
+
+ return
+ }
+}
+
+// -----
+// Checks match.structured.elemental_bitwidth on init #0: the matched bitwidth is printed as a remark.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @match_bitwidth(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.param<i64>) {
+ %bw = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.param<i64> {
+ ^bb0(%arg1: !transform.any_op):
+ %0 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_value
+ %1 = transform.match.structured.elemental_bitwidth %0 : (!transform.any_value) -> !transform.param<i64>
+ transform.match.structured.yield %1 : !transform.param<i64>
+ }
+ transform.yield %arg0, %bw : !transform.any_op, !transform.param<i64>
+ }
+ // Action: emit the matched bitwidth as a "bitwidth:" remark at the matched op.
+ transform.named_sequence @print_bitwidth(%arg0: !transform.any_op {transform.readonly}, %arg1: !transform.param<i64> {transform.readonly}) {
+ transform.test_print_param %arg1, "bitwidth:" at %arg0 : !transform.param<i64>, !transform.any_op
+ transform.yield
+ }
+ // Entry point (tagged "transform"): run @match_bitwidth -> @print_bitwidth over the payload ops.
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0 @match_bitwidth -> @print_bitwidth : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ // Payload: only the f32 fill expects a remark; the index fill is expected not to match (presumably index has no fixed bitwidth).
+ func.func @payload(%f32: f32, %tf32: tensor<?xf32>,
+ %index: index, %tindex: tensor<?xindex>)
+ attributes { transform.target_tag = "start_here" } {
+ // expected-remark @below {{bitwidth: 32}}
+ linalg.fill ins(%f32: f32) outs(%tf32: tensor<?xf32>) -> tensor<?xf32>
+ linalg.fill ins(%index: index) outs(%tindex: tensor<?xindex>) -> tensor<?xindex>
+ return
+ }
+}
+
+// -----
+// Checks match.structured.init: init #0 and all inits bound as values, and init #0 bound to its producer op.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @match_init(%arg0: !transform.any_op {transform.readonly})
+ -> (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) {
+ %outs:3 = transform.match.structured failures(suppress) %arg0
+ : (!transform.any_op) -> (!transform.any_value, !transform.any_value, !transform.any_op) {
+ ^bb0(%arg1: !transform.any_op):
+ %0 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_value
+ %1 = transform.match.structured.init %arg1 [all] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_op
+ transform.match.structured.yield %0, %1, %2 : !transform.any_value, !transform.any_value, !transform.any_op
+ }
+ transform.yield %arg0, %outs#0, %outs#1, %outs#2 : !transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op
+ }
+ // Action: remark on init #0 ("output 0"), on its producer ("output producer") and on every init ("all output").
+ transform.named_sequence @print_init(%arg0: !transform.any_op {transform.readonly},
+ %arg1: !transform.any_value {transform.readonly},
+ %arg2: !transform.any_value {transform.readonly},
+ %arg3: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand_value %arg1, "output 0" : !transform.any_value
+ transform.test_print_remark_at_operand %arg3, "output producer" : !transform.any_op
+ transform.test_print_remark_at_operand_value %arg2, "all output" : !transform.any_value
+ transform.yield
+ }
+ // Entry point (tagged "transform"): run @match_init -> @print_init over the payload ops.
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0 @match_init -> @print_init : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ // Payload: the fill's init is block argument %tf32; the generic's inits are the fill result %0 and %tf32_2.
+ func.func @payload(%f32: f32,
+ // expected-remark @below {{output 0}}
+ // expected-remark @below {{all output}}
+ // expected-note @below {{value handle points to a block argument #1 in block #0 in region #0}}
+ %tf32: tensor<?xf32>,
+ // expected-remark @below {{all output}}
+ // expected-note @below {{value handle points to a block argument #2 in block #0 in region #0}}
+ %tf32_2: tensor<?xf32>)
+ attributes { transform.target_tag = "start_here" } {
+ // expected-remark @below {{output 0}}
+ // expected-remark @below {{output producer}}
+ // expected-remark @below {{all output}}
+ // expected-note @below {{value handle points to an op result #0}}
+ %0 = linalg.fill ins(%f32: f32) outs(%tf32: tensor<?xf32>) -> tensor<?xf32>
+ // No remark is expected on the generic itself; its inits are reported at their definitions above.
+ linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%tf32: tensor<?xf32>) outs(%0, %tf32_2: tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ linalg.yield %arg0, %arg0 : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ return
+ }
+}
+
+// -----
+// Checks indexing-map predicates on inits: init #0/#1 as permutations, init #2 as a projected permutation.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @match_init_0_permutation(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.init %arg1[0] { permutation }: !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+ transform.named_sequence @match_init_1_permutation(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.init %arg1[1] { permutation }: !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+ transform.named_sequence @match_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.init %arg1[2] { projected_permutation }: !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+ // Actions: one remark per matcher, emitted at the matched op.
+ transform.named_sequence @print_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "matched output 0 permutation" : !transform.any_op
+ transform.yield
+ }
+ transform.named_sequence @print_init_1_permutation(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "matched output 1 permutation" : !transform.any_op
+ transform.yield
+ }
+ transform.named_sequence @print_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "matched output 2 projected permutation" : !transform.any_op
+ transform.yield
+ }
+ // Entry point: chain the three match/action pairs, feeding each foreach_match result into the next.
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.foreach_match in %arg0 @match_init_0_permutation -> @print_init_0_permutation : (!transform.any_op) -> !transform.any_op
+ %1 = transform.foreach_match in %0 @match_init_1_permutation -> @print_init_1_permutation : (!transform.any_op) -> !transform.any_op
+ %2 = transform.foreach_match in %1 @match_init_2_projected_permutation -> @print_init_2_projected_permutation : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ // Payload: init maps are (d0 + d1) (neither property), (d1) (projected only) and (d1, d0) (a full permutation, hence also projected).
+ func.func @payload(%f32: f32,
+ %oned: tensor<?xf32>,
+ %oned2: tensor<?xf32>,
+ %twod: tensor<?x?xf32>)
+ attributes { transform.target_tag = "start_here" } {
+ // expected-remark @below {{matched output 2 projected permutation}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0 + d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%oned: tensor<?xf32>) outs(%oned, %oned2, %twod: tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
+ linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>)
+ // Same maps as above with inits #1 and #2 swapped.
+ // expected-remark @below {{matched output 2 projected permutation}}
+ // expected-remark @below {{matched output 1 permutation}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0 + d1)>,
+ affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%oned: tensor<?xf32>) outs(%oned, %twod, %oned2: tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
+ linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32
+ } -> (tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>)
+ return
+ }
+}
+
+// -----
+
+
+// Checks match.structured.num_inputs and num_inits; the number of inits is reported as "outputs".
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @match_num_io(%arg0: !transform.any_op {transform.readonly})
+ -> (!transform.param<i64>, !transform.param<i64>, !transform.any_op) {
+ %0:3 = transform.match.structured failures(propagate) %arg0
+ : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.any_op) {
+ ^bb0(%arg1: !transform.any_op):
+ %1 = transform.match.structured.num_inputs %arg1 : (!transform.any_op) -> !transform.param<i64>
+ %2 = transform.match.structured.num_inits %arg1 : (!transform.any_op) -> !transform.param<i64>
+ transform.match.structured.yield %1, %2, %arg1 : !transform.param<i64>, !transform.param<i64>, !transform.any_op
+ }
+ transform.yield %0#0, %0#1, %0#2 : !transform.param<i64>, !transform.param<i64>, !transform.any_op
+ }
+
+ // Action: remark "inputs N" and "outputs N" (the number of inits) at the matched op.
+ transform.named_sequence @print_num_io(
+ %arg0: !transform.param<i64> {transform.readonly},
+ %arg1: !transform.param<i64> {transform.readonly},
+ %arg2: !transform.any_op {transform.readonly}) {
+ transform.test_print_param %arg0, "inputs" at %arg2 : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %arg1, "outputs" at %arg2 : !transform.param<i64>, !transform.any_op
+ transform.yield
+ }
+
+ // Entry point (tagged "transform"): run @match_num_io -> @print_num_io over the payload ops.
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.foreach_match in %arg0 @match_num_io -> @print_num_io : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ // Payload: a generic with 1 input and 3 inits, then one with 2 inputs and 2 inits.
+ func.func @payload(%f32: f32,
+ %oned: tensor<?xf32>,
+ %oned2: tensor<?xf32>,
+ %twod: tensor<?x?xf32>)
+ attributes { transform.target_tag = "start_here" } {
+ // expected-remark @below {{inputs 1}}
+ // expected-remark @below {{outputs 3}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0 + d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%oned: tensor<?xf32>) outs(%oned, %oned2, %twod: tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
+ linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>)
+
+ // expected-remark @below {{inputs 2}}
+ // expected-remark @below {{outputs 2}}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0 + d1)>,
+ affine_map<(d0, d1) -> (d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%oned, %twod: tensor<?xf32>, tensor<?x?xf32>) outs(%oned, %oned2: tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):
+ linalg.yield %arg0, %arg0 : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ return
+ }
+}
+
+// -----
+// Checks match.structured.rank, i.e. the iteration-space rank of the matched structured op.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @match_rank(%arg0: !transform.any_op {transform.readonly})
+ -> (!transform.param<i64>, !transform.any_op) {
+ %0:2 = transform.match.structured failures(propagate) %arg0
+ : (!transform.any_op) -> (!transform.param<i64>, !transform.any_op) {
+ ^bb0(%arg1: !transform.any_op):
+ %1 = transform.match.structured.rank %arg1 : (!transform.any_op) -> !transform.param<i64>
+ transform.match.structured.yield %1, %arg1 : !transform.param<i64>, !transform.any_op
+ }
+ transform.yield %0#0, %0#1 : !transform.param<i64>, !transform.any_op
+ }
+
+ // Action: remark "rank N" at the matched op.
+ transform.named_sequence @print_rank(%arg0: !transform.param<i64> {transform.readonly},
+ %arg2: !transform.any_op {transform.readonly}) {
+ transform.test_print_param %arg0, "rank" at %arg2 : !transform.param<i64>, !transform.any_op
+ transform.yield
+ }
+ // Entry point (tagged "transform"): run @match_rank -> @print_rank over the payload ops.
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.foreach_match in %arg0 @match_rank -> @print_rank : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ // Payload: the 2-D fill is expected to report rank 2 and the matmul rank 3.
+ func.func @payload(%f32: f32,
+ %twod: tensor<42x42xf32>)
+ attributes { transform.target_tag = "start_here" } {
+ %0 = tensor.empty() : tensor<42x42xf32>
+ // expected-remark @below {{rank 2}}
+ %1 = linalg.fill ins(%f32 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ // expected-remark @below {{rank 3}}
+ linalg.matmul ins(%twod, %twod : tensor<42x42xf32>, tensor<42x42xf32>)
+ outs(%1 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ return
+ }
+}
+
+// -----
+// Checks match.structured.result: result #0 with a single user, result #0 as a value, and the last result with any user.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @match_single_result(%arg0: !transform.any_op {transform.readonly})
+ -> (!transform.any_op, !transform.any_op) {
+ %0:2 = transform.match.structured failures(propagate) %arg0
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op) {
+ ^bb0(%arg1: !transform.any_op):
+ %1 = transform.match.structured.result %arg1[0] { single } : (!transform.any_op) -> !transform.any_op
+ transform.match.structured.yield %1, %arg1 : !transform.any_op, !transform.any_op
+ }
+ transform.yield %0#0, %0#1 : !transform.any_op, !transform.any_op
+ }
+ transform.named_sequence @match_result_value(%arg0: !transform.any_op {transform.readonly})
+ -> (!transform.any_value, !transform.any_op) {
+ %0:2 = transform.match.structured failures(propagate) %arg0
+ : (!transform.any_op) -> (!transform.any_value, !transform.any_op) {
+ ^bb0(%arg1: !transform.any_op):
+ %1 = transform.match.structured.result %arg1[0] : (!transform.any_op) -> !transform.any_value
+ transform.match.structured.yield %1, %arg1 : !transform.any_value, !transform.any_op
+ }
+ transform.yield %0#0, %0#1 : !transform.any_value, !transform.any_op
+ }
+ transform.named_sequence @match_any_result(%arg0: !transform.any_op {transform.readonly})
+ -> (!transform.any_op) {
+ %0 = transform.match.structured failures(propagate) %arg0
+ : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ %1 = transform.match.structured.result %arg1[-1] { any } : (!transform.any_op) -> !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+ // Actions: remark on the matched op; @print_single_result also remarks on the user, @print_result_value on the result value.
+ transform.named_sequence @print_single_result(%arg0: !transform.any_op {transform.readonly},
+ %arg2: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg2, "matched single result" : !transform.any_op
+ transform.test_print_remark_at_operand %arg0, "single user" : !transform.any_op
+ transform.yield
+ }
+ transform.named_sequence @print_result_value(%arg0: !transform.any_value {transform.readonly},
+ %arg1: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg1, "matched result value" : !transform.any_op
+ transform.test_print_remark_at_operand_value %arg0, "op result" : !transform.any_value
+ transform.yield
+ }
+ transform.named_sequence @print_any_result(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "matched any result" : !transform.any_op
+ transform.yield
+ }
+ // Entry point: chain the three match/action pairs, feeding each foreach_match result into the next.
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.foreach_match in %arg0 @match_single_result -> @print_single_result : (!transform.any_op) -> !transform.any_op
+ %1 = transform.foreach_match in %0 @match_result_value -> @print_result_value : (!transform.any_op) -> !transform.any_op
+ %2 = transform.foreach_match in %1 @match_any_result -> @print_any_result : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ // Payload: %1 has no users, %2 has exactly one user and %3 has two users.
+ func.func @payload(%f32: f32, %f322: f32, %f323: f32,
+ %twod: tensor<42x42xf32>)
+ attributes { transform.target_tag = "start_here" } {
+ %0 = tensor.empty() : tensor<42x42xf32>
+ // Each linalg op matches @match_result_value; only %2 and %3 (which have users) match @match_any_result.
+ // expected-remark @below {{matched result value}}
+ // expected-remark @below {{op result}}
+ // expected-note @below {{value handle points to an op result #0}}
+ %1 = linalg.fill ins(%f32 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ // expected-remark @below {{matched result value}}
+ // expected-remark @below {{op result}}
+ // expected-note @below {{value handle points to an op result #0}}
+ // expected-remark @below {{matched single result}}
+ // expected-remark @below {{matched any result}}
+ %2 = linalg.fill ins(%f322 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ // expected-remark @below {{matched result value}}
+ // expected-remark @below {{op result}}
+ // expected-note @below {{value handle points to an op result #0}}
+ // expected-remark @below {{matched any result}}
+ %3 = linalg.fill ins(%f323 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ // These results are unused; the first op below also gets "single user" as the sole user of %2.
+ // expected-remark @below {{matched result value}}
+ // expected-remark @below {{op result}}
+ // expected-note @below {{value handle points to an op result #0}}
+ // expected-remark @below {{single user}}
+ linalg.elemwise_unary {fun = #linalg.unary_fn<negf>} ins(%2 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ // expected-remark @below {{matched result value}}
+ // expected-remark @below {{op result}}
+ // expected-note @below {{value handle points to an op result #0}}
+ linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ // expected-remark @below {{matched result value}}
+ // expected-remark @below {{op result}}
+ // expected-note @below {{value handle points to an op result #0}}
+ linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ return
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
new file mode 100644
index 0000000000000..6dade865c64c1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
@@ -0,0 +1,225 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+// The match.structured body must take exactly one block argument.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected one body argument}}
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1:
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// The match.structured body argument must have a transform handle type (i32 is rejected).
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected body argument to implement TransformHandleTypeInterface}}
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: i32):
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// Only ops implementing MatchOpInterface may appear in the match.structured body.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expects nested operations to implement MatchOpInterface}}
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-note @below {{offending operation}}
+ transform.test_consume_operand %arg1 : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+// -----
+// match.structured.body must have transform.match.structured as its parent op.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expects parent op to be 'transform.match.structured'}}
+ transform.match.structured.body %arg0 { passthrough } : !transform.any_op
+ transform.yield
+}
+
+
+// -----
+// Predicates must apply to the body argument (%arg1), not to a handle from outside match.structured.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{expected predicate to apply to the surrounding structured op}}
+ transform.match.structured.body %arg0 { passthrough } : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// match.structured.body: passthrough and reduction_position cannot be combined.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{reduction position and passthrough conditions are mutually exclusive}}
+ transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// Generic-form dim op: is_all and is_inverted cannot both be set.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{cannot request both 'all' and 'inverted' values in the list}}
+ "transform.match.structured.dim"(%arg1) { is_all, is_inverted, raw_dim_list = array<i64> } : (!transform.any_op) -> ()
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// Generic-form dim op: is_all cannot be combined with an explicit dimension list.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{cannot both request 'all' and specific values in the list}}
+ "transform.match.structured.dim"(%arg1) { is_all, raw_dim_list = array<i64: 0, 1> } : (!transform.any_op) -> ()
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+// -----
+// Generic-form dim op: without is_all, the dimension list must be non-empty.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{must request specific values in the list if 'all' is not specified}}
+ "transform.match.structured.dim"(%arg1) { raw_dim_list = array<i64> } : (!transform.any_op) -> ()
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// Dimension list entries must be unique.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{op expected the listed values to be unique}}
+ "transform.match.structured.dim"(%arg1) { raw_dim_list = array<i64: 0, 0> } : (!transform.any_op) -> ()
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// A dimension cannot be required to be both parallel and reduction.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{cannot request the same dimension to be both parallel and reduction}}
+ "transform.match.structured.dim"(%arg1) { is_all, parallel, reduction, raw_dim_list = array<i64> } : (!transform.any_op) -> ()
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// match.structured.input: permutation and projected_permutation are mutually exclusive.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{"permutation" and "projected_permutation" are mutually exclusive}}
+ transform.match.structured.input %arg1[all] { permutation, projected_permutation } : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+// -----
+// match.structured.input: a single result cannot be bound to multiple input positions.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{cannot bind multiple inputs/inits to the same value}}
+ transform.match.structured.input %arg1[0, 1] : (!transform.any_op) -> !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// match.structured.init: permutation and projected_permutation are mutually exclusive.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{"permutation" and "projected_permutation" are mutually exclusive}}
+ transform.match.structured.init %arg1[all] { permutation, projected_permutation } : !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+// -----
+// match.structured.init: a single result cannot be bound to multiple init positions.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{cannot bind multiple inputs/inits to the same value}}
+ transform.match.structured.init %arg1[0, 1] : (!transform.any_op) -> !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// match.structured.result: an op-handle result requires the 'any' or 'single' keyword.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{expects either the any/single keyword or the type value handle result type}}
+ transform.match.structured.result %arg1[0] : (!transform.any_op) -> !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// match.structured.result: the 'any'/'single' keywords cannot be used with a value-handle result.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{expects either the any/single keyword or the type value handle result type}}
+ transform.match.structured.result %arg1[0] {any} : (!transform.any_op) -> !transform.any_value
+ transform.match.structured.yield
+ }
+ transform.yield
+}
+
+// -----
+// match.structured.result: 'any' and 'single' are mutually exclusive.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // expected-error @below {{'any' and 'single' are mutually exclusive}}
+ transform.match.structured.result %arg1[0] {any, single} : (!transform.any_op) -> !transform.any_op
+ transform.match.structured.yield
+ }
+ transform.yield
+}
diff --git a/mlir/test/Dialect/Linalg/match-ops.mlir b/mlir/test/Dialect/Linalg/match-ops.mlir
new file mode 100644
index 0000000000000..7d15a101446ec
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/match-ops.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Round-trip test: the ops below are printed, re-parsed by a second mlir-opt and the output is checked.
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ // Positional specifications (all, explicit lists, except(...)) must survive the round trip.
+ // CHECK: dim %{{.*}}[all]
+ transform.match.structured.dim %arg1[all] : !transform.any_op
+ // CHECK: dim %{{.*}}[0]
+ transform.match.structured.dim %arg1[0] : !transform.any_op
+ // CHECK: dim %{{.*}}[0, 1, -2]
+ transform.match.structured.dim %arg1[0, 1, -2] : !transform.any_op
+ // CHECK: dim %{{.*}}[except(0)]
+ transform.match.structured.dim %arg1[except(0)] : !transform.any_op
+ // CHECK: dim %{{.*}}[except(0, -1, 2)]
+ transform.match.structured.dim %arg1[except(0, -1, 2)] : !transform.any_op
+
+ transform.match.structured.yield
+ }
+
+ // Trailing types with zero, one and two results must survive the round trip.
+ // CHECK: structured %{{.*}} : !transform.any_op
+ transform.match.structured %arg0 : !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ transform.match.structured.yield
+ }
+ // CHECK: structured %{{.*}} : (!transform.any_op) -> !transform.any_op
+ transform.match.structured %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb1(%arg1: !transform.any_op):
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ // CHECK: structured %{{.*}} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.match.structured %arg0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) {
+ ^bb1(%arg1: !transform.any_op):
+ transform.match.structured.yield %arg1, %arg1 : !transform.any_op, !transform.any_op
+ }
+
+ transform.yield
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 44124e9bca747..17af525156056 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1411,7 +1411,7 @@ module attributes { transform.with_named_sequence } {
module attributes { transform.with_named_sequence } {
transform.named_sequence @match(%arg: !transform.any_op {transform.readonly}) {
// expected-error @below {{expected operations in the match part to implement MatchOpInterface}}
- transform.test_print_remark_at_operand %arg, "remark" : !transform.any_op
+ "test.unknown_op"() : () -> ()
transform.yield
}
transform.named_sequence @action() {
@@ -1424,3 +1424,128 @@ module attributes { transform.with_named_sequence } {
@match -> @action : (!transform.any_op) -> !transform.any_op
}
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @match_func(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.named_sequence @print_func(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "matched func" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0 @match_func -> @print_func : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ // expected-remark @below {{matched func}}
+ func.func @payload() {
+ return
+ }
+
+ // expected-remark @below {{matched func}}
+ func.func private @declaration()
+
+ "test.something_else"() : () -> ()
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @eq_1(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
+ %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op
+ %1 = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+ transform.match.param.cmpi eq %0, %1 : !transform.test_dialect_param
+ transform.test_print_remark_at_operand %arg0, "matched == 1" : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.named_sequence @ne_0(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
+ %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op
+ %1 = transform.param.constant 0 : i32 -> !transform.test_dialect_param
+ transform.match.param.cmpi ne %0, %1 : !transform.test_dialect_param
+ transform.test_print_remark_at_operand %arg0, "matched != 0" : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.named_sequence @gt_m1(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
+ %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op
+ %1 = transform.param.constant -1 : i32 -> !transform.test_dialect_param
+ transform.match.param.cmpi gt %0, %1 : !transform.test_dialect_param
+ transform.test_print_remark_at_operand %arg0, "matched > -1" : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.named_sequence @ge_1(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
+ %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op
+ %1 = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+ transform.match.param.cmpi ge %0, %1 : !transform.test_dialect_param
+ transform.test_print_remark_at_operand %arg0, "matched >= 1" : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.named_sequence @lt_1(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
+ %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op
+ %1 = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+ transform.match.param.cmpi lt %0, %1 : !transform.test_dialect_param
+ transform.test_print_remark_at_operand %arg0, "matched < 1" : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.named_sequence @le_1(%arg0: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ transform.match.operation_name %arg0 ["func.func"] : !transform.any_op
+ %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op
+ %1 = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+ transform.match.param.cmpi le %0, %1 : !transform.test_dialect_param
+ transform.test_print_remark_at_operand %arg0, "matched <= 1" : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) {
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb(%arg0: !transform.any_op):
+ %0 = transform.foreach_match in %arg0 @eq_1 -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ %1 = transform.foreach_match in %0 @ne_0 -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ %2 = transform.foreach_match in %1 @gt_m1 -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ %3 = transform.foreach_match in %2 @ge_1 -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ %4 = transform.foreach_match in %3 @lt_1 -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ %5 = transform.foreach_match in %4 @le_1 -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ // expected-remark @below {{matched > -1}}
+ // expected-remark @below {{matched < 1}}
+ // expected-remark @below {{matched <= 1}}
+ func.func private @declaration()
+
+ // expected-remark @below {{matched == 1}}
+ // expected-remark @below {{matched != 0}}
+ // expected-remark @below {{matched > -1}}
+ // expected-remark @below {{matched >= 1}}
+ // expected-remark @below {{matched <= 1}}
+ func.func @definition() {
+ "test.something"() : () -> ()
+ return
+ }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index f264d8dec429a..9cb26c11eb892 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -505,6 +505,9 @@ mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
void mlir::test::TestPrintParamOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getParam(), effects);
+ if (getAnchor())
+ transform::onlyReadsHandle(getAnchor(), effects);
+ transform::onlyReadsPayload(effects);
}
DiagnosedSilenceableFailure
@@ -512,8 +515,15 @@ mlir::test::TestPrintParamOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
std::string str;
llvm::raw_string_ostream os(str);
+ if (getMessage())
+ os << *getMessage() << " ";
llvm::interleaveComma(state.getParams(getParam()), os);
- auto diag = emitRemark() << os.str();
+ if (!getAnchor()) {
+ emitRemark() << os.str();
+ return DiagnosedSilenceableFailure::success();
+ }
+ for (Operation *payload : state.getPayloadOps(getAnchor()))
+ ::mlir::emitRemark(payload->getLoc()) << os.str();
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 36fb471d28b85..e5a23ede64088 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -132,7 +132,8 @@ def TestSucceedIfOperandOfOpKind
def TestPrintRemarkAtOperandOp
: Op<Transform_Dialect, "test_print_remark_at_operand",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
+ [MatchOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let arguments = (ins
TransformHandleTypeInterface:$operand,
@@ -340,16 +341,22 @@ def TestReportNumberOfTrackedHandlesNestedUnder
def TestPrintParamOp
: Op<Transform_Dialect, "test_print_param",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ [MatchOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let arguments = (ins TransformParamTypeInterface:$param);
- let assemblyFormat = "$param attr-dict `:` type($param)";
+ let arguments = (ins TransformParamTypeInterface:$param,
+ Optional<TransformHandleTypeInterface>:$anchor,
+ OptionalAttr<StrAttr>:$message);
+ let assemblyFormat = "$param (`,` $message^)? (`at` $anchor^)?"
+ "attr-dict `:` type($param) (`,` type($anchor)^)?";
let cppNamespace = "::mlir::test";
}
def TestAddToParamOp
: Op<Transform_Dialect, "test_add_to_param",
- [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+ [MatchOpInterface,
+ MemoryEffectsOpInterface,
+ ParamProducerTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins Optional<TestTransformTestDialectParamType>:$param,
I32Attr:$addendum);
@@ -360,7 +367,9 @@ def TestAddToParamOp
def TestProduceParamWithNumberOfTestOps
: Op<Transform_Dialect, "test_produce_param_with_number_of_test_ops",
- [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+ [MatchOpInterface,
+ MemoryEffectsOpInterface,
+ ParamProducerTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins TransformHandleTypeInterface:$handle);
let results = (outs TestTransformTestDialectParamType:$result);
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 335d7c66c2fb9..1d11779330605 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2386,15 +2386,14 @@ cc_library(
":MemRefTransforms",
":NVVMDialect",
":Pass",
+ ":SerializeToCubin",
":SparseTensorDialect",
":SparseTensorTransforms",
":TensorTransforms",
":Transforms",
":VectorToLLVM",
":VectorTransforms",
- ] + if_cuda_available([
- ":SerializeToCubin",
- ]),
+ ],
)
##---------------------------------------------------------------------------##
@@ -3039,7 +3038,6 @@ cc_library(
":SCFPassIncGen",
":Support",
":TensorDialect",
- ":ValueBoundsOpInterface",
":ViewLikeInterface",
"//llvm:Support",
],
@@ -5721,7 +5719,6 @@ cc_library(
":TensorDialect",
":TensorUtils",
":TilingInterface",
- ":ValueBoundsOpInterface",
"//llvm:Support",
],
)
@@ -6530,7 +6527,6 @@ cc_library(
":LLVMDialect",
":MathDialect",
":Pass",
- ":SCFDialect",
":Transforms",
":VectorDialect",
":VectorUtils",
@@ -7131,25 +7127,6 @@ cc_library(
],
)
-cc_library(
- name = "PluginsLib",
- srcs = [
- "lib/Tools/Plugins/DialectPlugin.cpp",
- "lib/Tools/Plugins/PassPlugin.cpp",
- ],
- hdrs = [
- "include/mlir/Tools/Plugins/DialectPlugin.h",
- "include/mlir/Tools/Plugins/PassPlugin.h",
- ],
- includes = ["include"],
- deps = [
- ":IR",
- ":Pass",
- ":Support",
- "//llvm:Support",
- ],
-)
-
cc_library(
name = "MlirOptLib",
srcs = [
@@ -7166,7 +7143,6 @@ cc_library(
":Observers",
":Parser",
":Pass",
- ":PluginsLib",
":Support",
"//llvm:Support",
],
@@ -8361,9 +8337,9 @@ td_library(
td_library(
name = "LinalgTransformOpsTdFiles",
- srcs = [
- "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td",
- ],
+ srcs = glob([
+ "include/mlir/Dialect/Linalg/TransformOps/*.td",
+ ]),
includes = ["include"],
deps = [
":PDLDialectTdFiles",
@@ -8431,17 +8407,32 @@ gentbl_cc_library(
)
gentbl_cc_library(
- name = "LinalgTransformOpsIncGen",
+ name = "LinalgMatchOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-decls"],
- "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc",
+ "include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc",
),
(
["-gen-op-defs"],
- "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc",
+ "include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc",
),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td",
+ deps = [
+ ":LinalgTransformEnumsIncGen",
+ ":LinalgTransformOpsIncGen",
+ ":LinalgTransformOpsTdFiles",
+ ":SCFDeviceMappingInterfacesIncGen",
+ ],
+)
+
+gentbl_cc_library(
+ name = "LinalgTransformEnumsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
(
["-gen-enum-decls"],
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.h.inc",
@@ -8452,8 +8443,30 @@ gentbl_cc_library(
),
],
tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td",
+ deps = [
+ ":LinalgTransformOpsTdFiles",
+ ":SCFDeviceMappingInterfacesIncGen",
+ ],
+)
+
+gentbl_cc_library(
+ name = "LinalgTransformOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td",
deps = [
+ ":LinalgTransformEnumsIncGen",
":LinalgTransformOpsTdFiles",
":SCFDeviceMappingInterfacesIncGen",
],
@@ -8677,7 +8690,6 @@ cc_library(
":Support",
":TensorDialect",
":TilingInterface",
- ":ValueBoundsOpInterface",
":ViewLikeInterface",
"//llvm:Support",
],
@@ -8685,30 +8697,30 @@ cc_library(
cc_library(
name = "LinalgTransformOps",
- srcs = [
- "lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp",
- ],
+ srcs = glob([
+ "lib/Dialect/Linalg/TransformOps/*.cpp",
+ ]),
hdrs = [
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h",
],
includes = ["include"],
deps = [
":AffineDialect",
+ ":Analysis",
":ArithDialect",
":AsmParser",
- ":ControlFlowDialect",
":DialectUtils",
":FuncDialect",
":GPUDialect",
":IR",
":LinalgDialect",
+ ":LinalgMatchOpsIncGen",
+ ":LinalgTransformEnumsIncGen",
":LinalgTransformOpsIncGen",
":LinalgTransforms",
":LinalgUtils",
":PDLDialect",
- ":Parser",
":SCFTransforms",
- ":SideEffectInterfaces",
":Support",
":TensorDialect",
":TensorUtils",
@@ -8796,7 +8808,6 @@ cc_library(
deps = [
":AffineAnalysis",
":AffineDialect",
- ":AffineTransforms",
":AffineUtils",
":Analysis",
":ArithDialect",
@@ -8834,7 +8845,6 @@ cc_library(
":TilingInterface",
":TransformUtils",
":Transforms",
- ":ValueBoundsOpInterface",
":VectorDialect",
":VectorToSCF",
":VectorTransforms",
@@ -9486,62 +9496,62 @@ gentbl_cc_library(
)
gentbl_cc_library(
- name = "TransformDialectInterfacesIncGen",
+ name = "TransformDialectMatchInterfacesIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
[
"-gen-op-interface-decls",
],
- "include/mlir/Dialect/Transform/IR/TransformInterfaces.h.inc",
+ "include/mlir/Dialect/Transform/IR/MatchInterfaces.h.inc",
),
(
[
"-gen-op-interface-defs",
],
- "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc",
- ),
- (
- [
- "-gen-type-interface-decls",
- ],
- "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc",
- ),
- (
- [
- "-gen-type-interface-defs",
- ],
- "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc",
+ "include/mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td",
- deps = [":TransformDialectTdFiles"],
+ td_file = "include/mlir/Dialect/Transform/IR/MatchInterfaces.td",
+ deps = [
+ ":TransformDialectInterfacesIncGen",
+ ":TransformDialectTdFiles",
+ ],
)
gentbl_cc_library(
- name = "TransformDialectMatchInterfacesIncGen",
+ name = "TransformDialectInterfacesIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
[
"-gen-op-interface-decls",
],
- "include/mlir/Dialect/Transform/IR/MatchInterfaces.h.inc",
+ "include/mlir/Dialect/Transform/IR/TransformInterfaces.h.inc",
),
(
[
"-gen-op-interface-defs",
],
- "include/mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc",
+ "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc",
+ ),
+ (
+ [
+ "-gen-type-interface-decls",
+ ],
+ "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc",
+ ),
+ (
+ [
+ "-gen-type-interface-defs",
+ ],
+ "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/Transform/IR/MatchInterfaces.td",
- deps = [
- ":TransformDialectTdFiles",
- ":TransformDialectInterfacesIncGen",
- ],
+ td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td",
+ deps = [":TransformDialectTdFiles"],
)
gentbl_cc_library(
More information about the Mlir-commits
mailing list