[Mlir-commits] [mlir] [mlir][transform] Add transform.get_operand op (PR #78397)
Quinn Dawkins
llvmlistbot at llvm.org
Wed Jan 17 08:41:51 PST 2024
https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/78397
>From 91418fce183ca6cbcae419df1cb35a13f7cca7c7 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 16 Jan 2024 23:53:33 -0500
Subject: [PATCH 1/2] [mlir][transform] Add transform.get_operand op
Similar to `transform.get_result`, except it returns a handle to the
operand indicated by `operand_number`, or all operands if no index is
given.
Additionally updates `get_result` to make the `result_number`
optional. This makes the use case of wanting to get all of the
results of an operation easier by no longer requiring the user to
reconstruct the list of results one-by-one.
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 31 ++++++--
.../lib/Dialect/Transform/IR/TransformOps.cpp | 37 +++++++++-
.../Dialect/Transform/test-interpreter.mlir | 73 +++++++++++++++++++
3 files changed, 135 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index fe2c28f45aea04..6637d81dab5e2a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -725,22 +725,43 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
"functional-type(operands, results)";
}
+// Navigation op producing value handles to the operands of the payload ops
+// associated with `target`; the operand-side counterpart of `get_result`.
+def GetOperandOp : TransformDialectOp<"get_operand",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
+ let summary = "Get a handle to the operand(s) of the targeted op";
+ let description = [{
+ The handle defined by this Transform op corresponds to the Operands of the
+ given `target` operation. Optionally `operand_number` can be specified to
+ select a specific operand.
+
+ This transform fails silently if the targeted operation does not have enough
+ operands. It reads the target handle and produces the result handle.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ OptionalAttr<I64Attr>:$operand_number);
+ let results = (outs TransformValueHandleTypeInterface:$result);
+ let assemblyFormat = "$target (`[` $operand_number^ `]`)? attr-dict `:` "
+ "functional-type(operands, results)";
+}
+
def GetResultOp : TransformDialectOp<"get_result",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
- let summary = "Get handle to the a result of the targeted op";
+ let summary = "Get a handle to the result(s) of the targeted op";
let description = [{
- The handle defined by this Transform op corresponds to the OpResult with
- `result_number` that is defined by the given `target` operation.
+ The handle defined by this Transform op correspond to the OpResults of the
+ given `target` operation. Optionally `result_number` can be specified to
+ select a specific result.
This transform fails silently if the targeted operation does not have enough
results. It reads the target handle and produces the result handle.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- I64Attr:$result_number);
+ OptionalAttr<I64Attr>:$result_number);
let results = (outs TransformValueHandleTypeInterface:$result);
- let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+ let assemblyFormat = "$target (`[` $result_number^ `]`)? attr-dict `:` "
"functional-type(operands, results)";
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index b80fc09751d2aa..56baae9b5fadf2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1464,6 +1464,35 @@ transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// GetOperandOp
+//===----------------------------------------------------------------------===//
+
+/// Collects value handles to the operands of every payload op associated with
+/// `target`: all operands when `operand_number` is absent, otherwise only the
+/// operand at that index. Emits a silenceable error, with a note pointing at
+/// the payload op, when an op has too few operands.
+/// NOTE(review): a negative `operand_number` is not rejected by the signed
+/// `>=` bounds check below -- confirm it cannot reach `getOperand`.
+DiagnosedSilenceableFailure
+transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ std::optional<int64_t> maybeOperandNumber = getOperandNumber();
+ SmallVector<Value> operands;
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ if (!maybeOperandNumber) {
+ for (Value operand : target->getOperands())
+ operands.push_back(operand);
+ continue;
+ }
+ int64_t operandNumber = *maybeOperandNumber;
+ if (operandNumber >= target->getNumOperands()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "targeted op does not have enough operands";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ operands.push_back(target->getOperand(operandNumber));
+ }
+ results.setValues(llvm::cast<OpResult>(getResult()), operands);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// GetResultOp
//===----------------------------------------------------------------------===//
@@ -1472,9 +1501,15 @@ DiagnosedSilenceableFailure
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- int64_t resultNumber = getResultNumber();
+ std::optional<int64_t> maybeResultNumber = getResultNumber();
SmallVector<Value> opResults;
for (Operation *target : state.getPayloadOps(getTarget())) {
+ if (!maybeResultNumber) {
+ for (Value result : target->getResults())
+ opResults.push_back(result);
+ continue;
+ }
+ int64_t resultNumber = *maybeResultNumber;
if (resultNumber >= target->getNumResults()) {
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "targeted op does not have enough results";
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 96f2122e976df5..b89b52e2f403d5 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1483,6 +1483,60 @@ module attributes {transform.with_named_sequence} {
// -----
+// The three `get_operand` tests below cover an explicit index, an out-of-range
+// index, and an omitted index (handle to all operands).
+// expected-remark @below {{addi operand}}
+// expected-note @below {{value handle points to a block argument #0}}
+func.func @get_operand_of_op(%arg0: index, %arg1: index) -> index {
+ %r = arith.addi %arg0, %arg1 : index
+ return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %operand = transform.get_operand %addi[0] : (!transform.any_op) -> !transform.any_value
+ transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
+ // expected-note @below {{target op}}
+ %r = arith.addi %arg0, %arg1 : index
+ return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{targeted op does not have enough operands}}
+ %operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
+ transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
+ %r = arith.addi %arg0, %arg1 : index
+ return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %operands = transform.get_operand %addui : (!transform.any_op) -> !transform.any_value
+ %p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
+ // expected-remark @below {{2}}
+ transform.debug.emit_param_as_remark %p : !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
// expected-remark @below {{addi result}}
// expected-note @below {{value handle points to an op result #0}}
@@ -1537,6 +1591,25 @@ module attributes {transform.with_named_sequence} {
// -----
+// Check that omitting the index yields a handle to every result; mapping the
+// two values back to their defining op gives two associations to split.
+func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
+ // expected-remark @below {{matched bool}}
+ %r, %b = arith.addui_extended %arg0, %arg1 : index, i1
+ return %r, %b : index, i1
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %results = transform.get_result %addui : (!transform.any_op) -> !transform.any_value
+ %adds = transform.get_defining_op %results : (!transform.any_value) -> !transform.any_op
+ %_, %add_again = transform.split_handle %adds : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.debug.emit_remark_at %add_again, "matched bool" : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// expected-note @below {{target value}}
func.func @get_result_of_op_bbarg(%arg0: index, %arg1: index) -> index {
%r = arith.addi %arg0, %arg1 : index
>From f3ec38244f58f345e404391244c0212d51ffc2d3 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 17 Jan 2024 11:41:27 -0500
Subject: [PATCH 2/2] Address comments and switch to mirroring the linalg.match
positional spec
---
.../Linalg/TransformOps/LinalgMatchOps.td | 4 +-
.../Dialect/Transform/IR/MatchInterfaces.h | 51 ++++++-
.../mlir/Dialect/Transform/IR/TransformOps.td | 53 +++++--
.../Linalg/TransformOps/LinalgMatchOps.cpp | 112 +--------------
.../Dialect/Transform/IR/MatchInterfaces.cpp | 135 ++++++++++++++++++
.../lib/Dialect/Transform/IR/TransformOps.cpp | 60 ++++----
.../Dialect/Transform/test-interpreter.mlir | 37 +++--
7 files changed, 295 insertions(+), 157 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 9e108529ec129b..162dd05f93030f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
let results = (outs Optional<TransformParamTypeInterface>:$result);
let assemblyFormat =
"$operand_handle `[`"
- "custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
+ "custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
"`]` attr-dict `:` "
"custom<SemiFunctionType>(type($operand_handle), type($result))";
@@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
(outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
let assemblyFormat =
"$operand_handle `[`"
- "custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
+ "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
"`]` attr-dict "
"`:` custom<SemiFunctionType>(type($operand_handle), type($result))";
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index b155b110677d6c..36aeb4583029c9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -168,6 +168,52 @@ class SingleValueMatcherOpTrait
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Printing/parsing for positional specification matchers
+//===----------------------------------------------------------------------===//
+
+/// Parses a positional index specification for transform match operations.
+/// The following forms are accepted:
+///
+///   - `all`: sets `isAll` and returns;
+///   - comma-separated-integer-list: populates `rawDimList` with the values;
+///   - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
+///     with the values and sets `isInverted`.
+ParseResult parseTransformMatchDims(OpAsmParser &parser,
+                                    DenseI64ArrayAttr &rawDimList,
+                                    UnitAttr &isInverted, UnitAttr &isAll);
+
+/// Prints a positional index specification for transform match operations.
+void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
+                             DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
+                             UnitAttr isAll);
+
+//===----------------------------------------------------------------------===//
+// Utilities for positional specification matchers
+//===----------------------------------------------------------------------===//
+
+/// Checks if the positional specification defined is valid and reports errors
+/// otherwise.
+LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
+                                         bool inverted, bool all);
+
+/// Populates `result` with the positional identifiers relative to `maxNumber`.
+/// If `isAll` is set, the result will contain all numbers from `0` to
+/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
+/// values from `rawList` are interpreted as counting backwards from
+/// `maxNumber`, i.e., `-1` is interpreted as `maxNumber - 1`, while
+/// non-negative numbers remain as is. If `isInverted` is set, populates
+/// `result` with those values from the `0` to `maxNumber - 1` inclusive range
+/// that don't appear in `rawList`. If `rawList` contains values that are
+/// greater than or equal to `maxNumber` or less than `-maxNumber`, produces a
+/// silenceable error at the given location. `maxNumber` must be positive. If
+/// `rawList` contains duplicate numbers or numbers that become duplicate after
+/// negative value remapping, emits a silenceable error.
+DiagnosedSilenceableFailure
+expandTargetSpecification(Location loc, bool isAll, bool isInverted,
+                          ArrayRef<int64_t> rawList, int64_t maxNumber,
+                          SmallVectorImpl<int64_t> &result);
+
 } // namespace transform
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 6637d81dab5e2a..1ca7c0bcb51e06 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -730,19 +730,31 @@ def GetOperandOp : TransformDialectOp<"get_operand",
      NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
   let summary = "Get a handle to the operand(s) of the targeted op";
   let description = [{
-    The handle defined by this Transform op corresponds to the Operands of the
-    given `target` operation. Optionally `operand_number` can be specified to
-    select a specific operand.
+    The handle defined by this Transform op corresponds to the operands of the
+    given `target` operation specified by the given set of positions. There are
+    three possible modes:
+
+    - Position list directly, e.g. `%target[0, 1, 2]`. This will return the
+      operands at the specified positions.
+    - Inverted position list, e.g. `%target[except(0, 1, 2)]`. This will return
+      all operands except those at the given positions.
+    - All, i.e. `%target[all]`. This will return all operands of the operation.
 
-    This transform fails silently if the targeted operation does not have enough
-    operands. It reads the target handle and produces the result handle.
+    This transform produces a silenceable failure if any of the operand indices
+    is out of range for the target. Negative indices count backwards from the
+    end. It reads the target handle and produces the result handle.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       OptionalAttr<I64Attr>:$operand_number);
+                       DenseI64ArrayAttr:$raw_position_list,
+                       UnitAttr:$is_inverted,
+                       UnitAttr:$is_all);
   let results = (outs TransformValueHandleTypeInterface:$result);
-  let assemblyFormat = "$target (`[` $operand_number^ `]`)? attr-dict `:` "
-                       "functional-type(operands, results)";
+  let assemblyFormat =
+    "$target `[`"
+    "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
+    "`]` attr-dict `:` functional-type(operands, results)";
+  let hasVerifier = 1;
 }
 
 def GetResultOp : TransformDialectOp<"get_result",
@@ -750,19 +762,31 @@ def GetResultOp : TransformDialectOp<"get_result",
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
   let summary = "Get a handle to the result(s) of the targeted op";
   let description = [{
-    The handle defined by this Transform op correspond to the OpResults of the
-    given `target` operation. Optionally `result_number` can be specified to
-    select a specific result.
-
-    This transform fails silently if the targeted operation does not have enough
-    results. It reads the target handle and produces the result handle.
+    The handle defined by this Transform op corresponds to the results of the
+    given `target` operation specified by the given set of positions. There are
+    three possible modes:
+
+    - Position list directly, e.g. `%target[0, 1, 2]`. This will return the
+      results at the specified positions.
+    - Inverted position list, e.g. `%target[except(0, 1, 2)]`. This will return
+      all results except those at the given positions.
+    - All, i.e. `%target[all]`. This will return all results of the operation.
+
+    This transform produces a silenceable failure if any of the result indices
+    is out of range for the target. Negative indices count backwards from the
+    end. It reads the target handle and produces the result handle.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       OptionalAttr<I64Attr>:$result_number);
+                       DenseI64ArrayAttr:$raw_position_list,
+                       UnitAttr:$is_inverted,
+                       UnitAttr:$is_all);
   let results = (outs TransformValueHandleTypeInterface:$result);
-  let assemblyFormat = "$target (`[` $result_number^ `]`)? attr-dict `:` "
-                       "functional-type(operands, results)";
+  let assemblyFormat =
+    "$target `[`"
+    "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
+    "`]` attr-dict `:` functional-type(operands, results)";
+  let hasVerifier = 1;
 }
 
 def GetTypeOp : TransformDialectOp<"get_type",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index fb021ed76242e9..c739813d01195b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -388,33 +388,6 @@ expandTargetSpecification(Location loc, bool isAll, bool isInverted,
return DiagnosedSilenceableFailure::success();
}
-/// Checks if the positional specification defined is valid and reports errors
-/// otherwise.
-LogicalResult verifyStructuredTransformDimsOp(Operation *op,
- ArrayRef<int64_t> raw,
- bool inverted, bool all) {
- if (all) {
- if (inverted) {
- return op->emitOpError()
- << "cannot request both 'all' and 'inverted' values in the list";
- }
- if (!raw.empty()) {
- return op->emitOpError()
- << "cannot both request 'all' and specific values in the list";
- }
- }
- if (!all && raw.empty()) {
- return op->emitOpError() << "must request specific values in the list if "
- "'all' is not specified";
- }
- SmallVector<int64_t> rawVector = llvm::to_vector(raw);
- auto *it = std::unique(rawVector.begin(), rawVector.end());
- if (it != rawVector.end())
- return op->emitOpError() << "expected the listed values to be unique";
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// MatchStructuredDimOp
//===----------------------------------------------------------------------===//
@@ -475,8 +448,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
return emitOpError() << "cannot request the same dimension to be both "
"parallel and reduction";
}
- return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
- getIsInverted(), getIsAll());
+ return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
+ getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
@@ -592,8 +565,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
LogicalResult transform::MatchStructuredInputOp::verify() {
if (failed(verifyStructuredOperandOp(*this)))
return failure();
- return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
- getIsInverted(), getIsAll());
+ // The positional-spec checks are shared with the other transform match ops.
+ return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+ getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
@@ -665,8 +638,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
LogicalResult transform::MatchStructuredInitOp::verify() {
if (failed(verifyStructuredOperandOp(*this)))
return failure();
- return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
- getIsInverted(), getIsAll());
+ // The positional-spec checks are shared with the other transform match ops.
+ return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+ getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
@@ -793,78 +766,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
build(builder, state, ValueRange());
}
-//===----------------------------------------------------------------------===//
-// Printing and parsing for structured match ops.
-//===----------------------------------------------------------------------===//
-
-/// Keyword syntax for positional specification inversion.
-constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
-
-/// Keyword syntax for full inclusion in positional specification.
-constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
-
-/// Parses a positional specification for structured transform operations. The
-/// following forms are accepted:
-///
-/// - `all`: sets `isAll` and returns;
-/// - comma-separated-integer-list: populates `rawDimList` with the values;
-/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
-/// with the values and sets `isInverted`.
-static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
- DenseI64ArrayAttr &rawDimList,
- UnitAttr &isInverted,
- UnitAttr &isAll) {
- Builder &builder = parser.getBuilder();
- if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
- rawDimList = builder.getDenseI64ArrayAttr({});
- isInverted = nullptr;
- isAll = builder.getUnitAttr();
- return success();
- }
-
- isAll = nullptr;
- isInverted = nullptr;
- if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
- isInverted = builder.getUnitAttr();
- }
-
- if (isInverted) {
- if (parser.parseLParen().failed())
- return failure();
- }
-
- SmallVector<int64_t> values;
- ParseResult listResult = parser.parseCommaSeparatedList(
- [&]() { return parser.parseInteger(values.emplace_back()); });
- if (listResult.failed())
- return failure();
-
- rawDimList = builder.getDenseI64ArrayAttr(values);
-
- if (isInverted) {
- if (parser.parseRParen().failed())
- return failure();
- }
- return success();
-}
-
-/// Prints a positional specification for structured transform operations.
-static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
- DenseI64ArrayAttr rawDimList,
- UnitAttr isInverted, UnitAttr isAll) {
- if (isAll) {
- printer << kDimAllKeyword;
- return;
- }
- if (isInverted) {
- printer << kDimExceptKeyword << "(";
- }
- llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
- [&](int64_t value) { printer << value; });
- if (isInverted) {
- printer << ")";
- }
-}
-
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
index 6049ee767e1bb7..b9b6dabc26216e 100644
--- a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
@@ -10,6 +10,141 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Printing and parsing for match ops.
+//===----------------------------------------------------------------------===//
+
+/// Keyword that inverts a positional specification, as in `except(0, 1)`.
+static constexpr llvm::StringLiteral kDimExceptKeyword = "except";
+
+/// Keyword that selects every available position, as in `[all]`.
+static constexpr llvm::StringLiteral kDimAllKeyword = "all";
+
+ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
+                                               DenseI64ArrayAttr &rawDimList,
+                                               UnitAttr &isInverted,
+                                               UnitAttr &isAll) {
+  Builder &b = parser.getBuilder();
+
+  // `all` stands alone: no list follows it.
+  if (succeeded(parser.parseOptionalKeyword(kDimAllKeyword))) {
+    isAll = b.getUnitAttr();
+    isInverted = nullptr;
+    rawDimList = b.getDenseI64ArrayAttr({});
+    return success();
+  }
+
+  // Otherwise an optional `except` wraps the integer list in parentheses.
+  const bool inverted =
+      succeeded(parser.parseOptionalKeyword(kDimExceptKeyword));
+  isAll = nullptr;
+  isInverted = inverted ? b.getUnitAttr() : UnitAttr();
+  if (inverted && failed(parser.parseLParen()))
+    return failure();
+
+  SmallVector<int64_t> positions;
+  auto parseOnePosition = [&]() -> ParseResult {
+    return parser.parseInteger(positions.emplace_back());
+  };
+  if (failed(parser.parseCommaSeparatedList(parseOnePosition)))
+    return failure();
+  rawDimList = b.getDenseI64ArrayAttr(positions);
+
+  if (inverted && failed(parser.parseRParen()))
+    return failure();
+  return success();
+}
+
+void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
+                                        DenseI64ArrayAttr rawDimList,
+                                        UnitAttr isInverted, UnitAttr isAll) {
+  // `all` subsumes any explicit list, so nothing else is printed.
+  if (isAll) {
+    printer << kDimAllKeyword;
+    return;
+  }
+
+  // An inverted specification is printed as `except(<list>)`.
+  const bool inverted = static_cast<bool>(isInverted);
+  if (inverted)
+    printer << kDimExceptKeyword << "(";
+  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
+                        [&](int64_t position) { printer << position; });
+  if (inverted)
+    printer << ")";
+}
+
+/// Verifies the static form of a positional specification: `all` excludes
+/// both inversion and explicit values, a non-`all` spec must list at least one
+/// value, and the listed values must be pairwise distinct. Aliasing between a
+/// negative and a non-negative value (e.g. `-1` and the last index) depends on
+/// the payload and is only diagnosed by `expandTargetSpecification`.
+LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
+                                                    ArrayRef<int64_t> raw,
+                                                    bool inverted, bool all) {
+  if (all) {
+    if (inverted) {
+      return op->emitOpError()
+             << "cannot request both 'all' and 'inverted' values in the list";
+    }
+    if (!raw.empty()) {
+      return op->emitOpError()
+             << "cannot both request 'all' and specific values in the list";
+    }
+  }
+  if (!all && raw.empty()) {
+    return op->emitOpError() << "must request specific values in the list if "
+                                "'all' is not specified";
+  }
+  // Duplicate detection only compares neighbors, so sort a copy first;
+  // otherwise lists such as `[0, 1, 0]` would slip through.
+  SmallVector<int64_t> rawVector = llvm::to_vector(raw);
+  llvm::sort(rawVector);
+  if (std::adjacent_find(rawVector.begin(), rawVector.end()) != rawVector.end())
+    return op->emitOpError() << "expected the listed values to be unique";
+
+  return success();
+}
+
+/// Expands the positional specification (`isAll` / `isInverted` / `rawList`)
+/// against `maxNumber` available positions. `isAll` overwrites `result`; the
+/// other modes append to it. A `maxNumber` of zero is accepted because callers
+/// pass `getNumOperands()` / `getNumResults()`, which may be zero: `all` then
+/// yields nothing and any explicit position is diagnosed as an overflow or
+/// underflow.
+DiagnosedSilenceableFailure transform::expandTargetSpecification(
+    Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
+    int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
+  assert(maxNumber >= 0 && "expected size to be non-negative");
+  assert(!(isAll && isInverted) && "cannot invert all");
+  if (isAll) {
+    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Normalize negative positions and reject out-of-range or repeated ones. In
+  // the inverted case the positions are only recorded in `visited`, which then
+  // drives the complement computation below.
+  llvm::SmallDenseSet<int64_t> visited;
+  for (int64_t raw : rawList) {
+    int64_t updated = raw < 0 ? maxNumber + raw : raw;
+    if (updated >= maxNumber) {
+      return emitSilenceableFailure(loc)
+             << "position overflow " << updated << " (updated from " << raw
+             << ") for maximum " << maxNumber;
+    }
+    if (updated < 0) {
+      return emitSilenceableFailure(loc) << "position underflow " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    if (!visited.insert(updated).second) {
+      return emitSilenceableFailure(loc) << "repeated position " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    if (!isInverted)
+      result.push_back(updated);
+  }
+
+  if (!isInverted)
+    return DiagnosedSilenceableFailure::success();
+
+  // A set lookup keeps the complement linear in `maxNumber` instead of
+  // scanning the listed positions for every candidate.
+  result.reserve(result.size() + (maxNumber - visited.size()));
+  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
+    if (!visited.contains(candidate))
+      result.push_back(candidate);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 56baae9b5fadf2..485d4448e7c368 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1472,27 +1472,31 @@ DiagnosedSilenceableFailure
transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- std::optional<int64_t> maybeOperandNumber = getOperandNumber();
+ // Operands selected from every payload op, concatenated in payload order.
SmallVector<Value> operands;
for (Operation *target : state.getPayloadOps(getTarget())) {
- if (!maybeOperandNumber) {
- for (Value operand : target->getOperands())
- operands.push_back(operand);
- continue;
- }
- int64_t operandNumber = *maybeOperandNumber;
- if (operandNumber >= target->getNumOperands()) {
- DiagnosedSilenceableFailure diag =
- emitSilenceableError() << "targeted op does not have enough operands";
- diag.attachNote(target->getLoc()) << "target op";
+ // Resolve the positional spec against this op's operand count; a
+ // silenceable failure is annotated with the offending payload op.
+ SmallVector<int64_t> operandPositions;
+ DiagnosedSilenceableFailure diag = expandTargetSpecification(
+ getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+ target->getNumOperands(), operandPositions);
+ if (diag.isSilenceableFailure()) {
+ diag.attachNote(target->getLoc())
+ << "while considering positions of this payload operation";
return diag;
}
- operands.push_back(target->getOperand(operandNumber));
+ llvm::append_range(operands,
+ llvm::map_range(operandPositions, [&](int64_t pos) {
+ return target->getOperand(pos);
+ }));
}
- results.setValues(llvm::cast<OpResult>(getResult()), operands);
+ results.setValues(cast<OpResult>(getResult()), operands);
return DiagnosedSilenceableFailure::success();
}
+/// Checks the static form of the positional specification using the helper
+/// shared with the other transform match ops.
+LogicalResult transform::GetOperandOp::verify() {
+ return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+ getIsInverted(), getIsAll());
+}
+
//===----------------------------------------------------------------------===//
// GetResultOp
//===----------------------------------------------------------------------===//
@@ -1501,27 +1505,31 @@ DiagnosedSilenceableFailure
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- std::optional<int64_t> maybeResultNumber = getResultNumber();
+ // Results selected from every payload op, concatenated in payload order.
SmallVector<Value> opResults;
for (Operation *target : state.getPayloadOps(getTarget())) {
- if (!maybeResultNumber) {
- for (Value result : target->getResults())
- opResults.push_back(result);
- continue;
- }
- int64_t resultNumber = *maybeResultNumber;
- if (resultNumber >= target->getNumResults()) {
- DiagnosedSilenceableFailure diag =
- emitSilenceableError() << "targeted op does not have enough results";
- diag.attachNote(target->getLoc()) << "target op";
+ // Resolve the positional spec against this op's result count; a
+ // silenceable failure is annotated with the offending payload op.
+ SmallVector<int64_t> resultPositions;
+ DiagnosedSilenceableFailure diag = expandTargetSpecification(
+ getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+ target->getNumResults(), resultPositions);
+ if (diag.isSilenceableFailure()) {
+ diag.attachNote(target->getLoc())
+ << "while considering positions of this payload operation";
return diag;
}
- opResults.push_back(target->getOpResult(resultNumber));
+ llvm::append_range(opResults,
+ llvm::map_range(resultPositions, [&](int64_t pos) {
+ return target->getResult(pos);
+ }));
}
- results.setValues(llvm::cast<OpResult>(getResult()), opResults);
+ results.setValues(cast<OpResult>(getResult()), opResults);
return DiagnosedSilenceableFailure::success();
}
+/// Checks the static form of the positional specification using the helper
+/// shared with the other transform match ops.
+LogicalResult transform::GetResultOp::verify() {
+ return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+ getIsInverted(), getIsAll());
+}
+
//===----------------------------------------------------------------------===//
// GetTypeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b89b52e2f403d5..de5807b2874b27 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1502,7 +1502,7 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
- // expected-note @below {{target op}}
+ // expected-note @below {{while considering positions of this payload operation}}
%r = arith.addi %arg0, %arg1 : index
return %r : index
}
@@ -1510,7 +1510,7 @@ func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{targeted op does not have enough operands}}
+ // expected-error @below {{position overflow 2 (updated from 2) for maximum 2}}
%operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
transform.yield
@@ -1519,6 +1519,24 @@ module attributes {transform.with_named_sequence} {
// -----
+// expected-remark @below {{addi operand}}
+// expected-note @below {{value handle points to a block argument #1}}
+func.func @get_inverted_operand_of_op(%arg0: index, %arg1: index) -> index {
+ %r = arith.addi %arg0, %arg1 : index
+ return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %operand = transform.get_operand %addi[except(0)] : (!transform.any_op) -> !transform.any_value
+ transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+ transform.yield
+ }
+}
+
+// -----
+
func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
%r = arith.addi %arg0, %arg1 : index
return %r : index
@@ -1527,7 +1545,7 @@ func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %operands = transform.get_operand %addui : (!transform.any_op) -> !transform.any_value
+ %operands = transform.get_operand %addui[all] : (!transform.any_op) -> !transform.any_value
%p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{2}}
transform.debug.emit_param_as_remark %p : !transform.param<i64>
@@ -1556,7 +1574,7 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
- // expected-note @below {{target op}}
+ // expected-note @below {{while considering positions of this payload operation}}
%r = arith.addi %arg0, %arg1 : index
return %r : index
}
@@ -1564,7 +1582,7 @@ func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{targeted op does not have enough results}}
+ // expected-error @below {{position overflow 1 (updated from 1) for maximum 1}}
%result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value
transform.debug.emit_remark_at %result, "addi result" : !transform.any_value
transform.yield
@@ -1592,7 +1610,6 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
- // expected-remark @below {{matched bool}}
%r, %b = arith.addui_extended %arg0, %arg1 : index, i1
return %r, %b : index, i1
}
@@ -1600,10 +1617,10 @@ func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1)
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %results = transform.get_result %addui : (!transform.any_op) -> !transform.any_value
- %adds = transform.get_defining_op %results : (!transform.any_value) -> !transform.any_op
- %_, %add_again = transform.split_handle %adds : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.debug.emit_remark_at %add_again, "matched bool" : !transform.any_op
+ %results = transform.get_result %addui[all] : (!transform.any_op) -> !transform.any_value
+ %p = transform.num_associations %results : (!transform.any_value) -> !transform.param<i64>
+ // expected-remark @below {{2}}
+ transform.debug.emit_param_as_remark %p : !transform.param<i64>
transform.yield
}
}
More information about the Mlir-commits mailing list