[Mlir-commits] [mlir] [mlir][Transform] Add a transform.match.operation_empty op to allow specifying negative conditions (PR #68319)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 5 08:09:06 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
Add a transform.match.operation_empty op to allow specifying negative conditions.
In the process, get_parent_op gains an attribute to allow it to return empty handles explicitly and still succeed.
---
Full diff: https://github.com/llvm/llvm-project/pull/68319.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h (+99-39)
- (modified) mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td (+19-2)
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+24-3)
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+20)
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+21)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index c8888f294f6ca1d..2cf008a911bd644 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -11,14 +11,46 @@
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/STLExtras.h"
+#include <functional>
+#include <optional>
+#include <type_traits>
namespace mlir {
namespace transform {
class MatchOpInterface;
+namespace detail {
template <typename OpTy>
-class SingleOpMatcherOpTrait
- : public OpTrait::TraitBase<OpTy, SingleOpMatcherOpTrait> {
+DiagnosedSilenceableFailure matchOptionalOperationImpl(
+ OpTy op, TransformResults &results, TransformState &state, std::false_type) {
+ return op.matchOperation(std::nullopt, results, state);
+}
+
+template <typename OpTy>
+DiagnosedSilenceableFailure
+matchOptionalOperationImpl(OpTy op, TransformResults &results,
+ TransformState &state, std::true_type) {
+ return op.matchOperation(nullptr, results, state);
+}
+
+/// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
+/// first operand.
+template <typename OpTy, typename... Args>
+DiagnosedSilenceableFailure
+matchOptionalOperation(OpTy op, TransformResults &results,
+ TransformState &state) {
+ using uses_operation_ptr_t =
+ typename std::is_same <
+ typename llvm::function_traits<decltype(&OpTy::matchOperation)>::template arg_t<0>,
+ Operation*>;
+ return matchOptionalOperationImpl(op, results, state, uses_operation_ptr_t{});
+}
+} // namespace detail
+
+template <typename OpTy>
+class AtMostOneOpMatcherOpTrait
+ : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
template <typename T>
using has_get_operand_handle =
decltype(std::declval<T &>().getOperandHandle());
@@ -30,20 +62,22 @@ class SingleOpMatcherOpTrait
public:
static LogicalResult verifyTrait(Operation *op) {
static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
- "SingleOpMatcherOpTrait expects operation type to have the "
+ "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects operation type to have the "
"getOperandHandle() method");
static_assert(llvm::is_detected<has_match_operation, OpTy>::value,
- "SingleOpMatcherOpTrait expected operation type to have the "
+ "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected "
+ "operation type to have the "
"matchOperation(Operation *, TransformResults &, "
"TransformState &) method");
// This must be a dynamic assert because interface registration is dynamic.
assert(isa<MatchOpInterface>(op) &&
- "SingleOpMatchOpTrait is only available on operations with "
+ "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
+ "operations with "
"MatchOpInterface");
Value operandHandle = cast<OpTy>(op).getOperandHandle();
if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
- return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
+ return op->emitError() << "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait requires the op handle "
"to be of TransformHandleTypeInterface";
}
@@ -55,12 +89,16 @@ class SingleOpMatcherOpTrait
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
auto payload = state.getPayloadOps(operandHandle);
- if (!llvm::hasSingleElement(payload)) {
+ if (!payload.empty() && !llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
- << "SingleOpMatchOpTrait requires the operand handle to point to "
- "a single payload op";
+ << "AtMostOneOpMatcherOpTrait requires the operand handle to "
+ "point to "
+ "at most one payload op";
+ }
+ if (payload.empty()) {
+ return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()), results,
+ state);
}
-
return cast<OpTy>(this->getOperation())
.matchOperation(*payload.begin(), results, state);
}
@@ -73,46 +111,68 @@ class SingleOpMatcherOpTrait
};
template <typename OpTy>
-class SingleValueMatcherOpTrait
- : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
-public:
- static LogicalResult verifyTrait(Operation *op) {
- // This must be a dynamic assert because interface registration is dynamic.
- assert(isa<MatchOpInterface>(op) &&
- "SingleValueMatchOpTrait is only available on operations with "
- "MatchOpInterface");
-
- Value operandHandle = cast<OpTy>(op).getOperandHandle();
- if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
- return op->emitError() << "SingleValueMatchOpTrait requires an operand "
- "of TransformValueHandleTypeInterface";
- }
-
- return success();
- }
+class SingleOpMatcherOpTrait
+ : public AtMostOneOpMatcherOpTrait<OpTy> {
+ public:
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
TransformResults &results,
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
- auto payload = state.getPayloadValues(operandHandle);
+ auto payload = state.getPayloadOps(operandHandle);
if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
- << "SingleValueMatchOpTrait requires the value handle to point to "
- "a single payload value";
+ << "SingleOpMatchOpTrait requires the operand handle to point to "
+ "a single payload op";
}
-
- return cast<OpTy>(this->getOperation())
- .matchValue(*payload.begin(), results, state);
- }
-
- void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- onlyReadsHandle(this->getOperation()->getOperands(), effects);
- producesHandle(this->getOperation()->getResults(), effects);
- onlyReadsPayload(effects);
+ return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
+ rewriter, results, state);
}
};
+template <typename OpTy>
+ class SingleValueMatcherOpTrait
+ : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
+ public:
+ static LogicalResult verifyTrait(Operation *op) {
+ // This must be a dynamic assert because interface registration is
+ // dynamic.
+ assert(isa<MatchOpInterface>(op) &&
+ "SingleValueMatchOpTrait is only available on operations with "
+ "MatchOpInterface");
+
+ Value operandHandle = cast<OpTy>(op).getOperandHandle();
+ if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
+ return op->emitError() << "SingleValueMatchOpTrait requires an operand "
+ "of TransformValueHandleTypeInterface";
+ }
+
+ return success();
+ }
+
+ DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
+ TransformResults &results,
+ TransformState &state) {
+ Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
+ auto payload = state.getPayloadValues(operandHandle);
+ if (!llvm::hasSingleElement(payload)) {
+ return emitDefiniteFailure(this->getOperation()->getLoc())
+ << "SingleValueMatchOpTrait requires the value handle to point "
+ "to "
+ "a single payload value";
+ }
+
+ return cast<OpTy>(this->getOperation())
+ .matchValue(*payload.begin(), results, state);
+ }
+
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(this->getOperation()->getOperands(), effects);
+ producesHandle(this->getOperation()->getResults(), effects);
+ onlyReadsPayload(effects);
+ }
+ };
+
} // namespace transform
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
index 1f81fd5252eb45b..be92e4d91b42b32 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
@@ -14,11 +14,28 @@ def MatchOpInterface
let cppNamespace = "::mlir::transform";
}
+// Trait for "matcher" transform operations that apply to an operation handle
+// associated with at most one payload operation. Checks that it is indeed
+// the case and produces a definite failure when it is not. The matching logic
+// is implemented in the `matchOperation` function instead of `apply`. The op
+// with this trait must provide a `Value getOperandHandle()` function that
+// returns the handle to be used for matching.
+def AtMostOneOpMatcher : NativeOpTrait<"AtMostOneOpMatcherOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+
+ string extraDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure matchOperation(
+ ::std::optional<::mlir::Operation *> maybeCurrent,
+ ::mlir::transform::TransformResults &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
// Trait for "matcher" transform operations that apply to an operation handle
// associated with exactly one payload operation. Checks that it is indeed
// the case and produces a definite failure when it is not. The matching logic
// is implemented in the `matchOperation` function instead of `apply`. The op
-// with this trait must provide a `Value getOperandHandle()` function that
+// with this trait must provide a `Value getOperandHandle()` function that
// returns the handle to be used for matching.
def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";
@@ -35,7 +52,7 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
// associated with exactly one payload value. Checks that it is indeed
// the case and produces a definite failure when it is not. The matching logic
// is implemented in the `matchValue` function instead of `apply`. The op
-// with this trait must provide a `Value getOperandHandle()` function that
+// with this trait must provide a `Value getOperandHandle()` function that
// returns the handle to be used for matching.
def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index ca5c915ef8c2caa..2c6917236d34ddf 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -595,8 +595,9 @@ def GetDefiningOp : TransformDialectOp<"get_defining_op",
def GetParentOp : TransformDialectOp<"get_parent_op",
[DeclareOpInterfaceMethods<TransformOpInterface>,
+ MatchOpInterface,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
- let summary = "Gets handles to the closest isolated-from-above parents";
+ let summary = "Gets handles to the closest parent ops";
let description = [{
The handle defined by this Transform op corresponds to the parents of the
targeted payload ops (in the same order).
@@ -605,6 +606,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
that case for each target op, the closest parent op that fulfills all
requirements, is returned.
- `isolated_from_above`: the parent op must be isolated from above
+ - `allow_empty_results`: get_parent_op is allowed to return an empty list and
+ still succeeds. In such a case, if get_parent_op fails for any operation
+ in the list, the entire transform returns an empty handle.
- `op_name`: the parent op must have the specified name
If `deduplicate` is set, the result handle does not contain any duplicate
@@ -614,12 +618,14 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
is applied, e.g., "B" may itself be a parent of "A". This may have an impact
on the further transformation applied to the handle produced here.
- If any of the given Payload IR ops has no such suitable parent, the
- transformation fails silently.
+ If any of the given Payload IR ops has no such suitable parent, then:
+ - if `allow_empty_results` is set, the result handle is empty
+ - otherwise, the transformation fails silently.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$isolated_from_above,
+ UnitAttr:$allow_empty_results,
OptionalAttr<StrAttr>:$op_name,
UnitAttr:$deduplicate);
let results = (outs TransformHandleTypeInterface:$parent);
@@ -739,6 +745,21 @@ def IncludeOp : TransformDialectOp<"include",
}];
}
+def MatchOperationEmptyOp : Op<Transform_Dialect, "match.operation_empty", [
+ AtMostOneOpMatcher,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> {
+ let summary =
+ "Matches if the handle is not associated to any op";
+ let description = [{
+ Succeeds if the handle is not associated to any op.
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+ let assemblyFormat =
+ "$operand_handle attr-dict `:` type($operand_handle)";
+ let extraClassDeclaration = AtMostOneOpMatcher.extraDeclaration;
+}
+
def MatchOperationNameOp : TransformDialectOp<"match.operation_name",
[SingleOpMatcher,
MatchOpInterface,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 44626260e2f9ef3..2dff1bf3d0a80ef 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Pass/Pass.h"
@@ -1244,6 +1245,10 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
parent = parent->getParentOp();
}
if (!parent) {
+ if (getAllowEmptyResults()) {
+ results.set(llvm::cast<OpResult>(getResult()), parents);
+ return DiagnosedSilenceableFailure::success();
+ }
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "could not find a parent op that matches all requirements";
@@ -1545,6 +1550,21 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
.checkAndReport();
}
+//===----------------------------------------------------------------------===//
+// MatchOperationEmptyOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
+ ::std::optional<::mlir::Operation *> maybeCurrent,
+ transform::TransformResults &results, transform::TransformState &state) {
+ if (!maybeCurrent.has_value()) {
+ DBGS_MATCHER() << "MatchOperationEmptyOp success\n";
+ return DiagnosedSilenceableFailure::success();
+ }
+ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n";
+ return emitSilenceableError() << "operation is not empty";
+}
+
//===----------------------------------------------------------------------===//
// MatchOperationNameOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index daa179cb15408b4..b641b21e876cc42 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2037,3 +2037,24 @@ transform.sequence failures(propagate) {
// expected-remark @below{{0}}
test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
}
+
+
+// -----
+
+func.func @no_constant_under_loop(%lb: index, %ub: index, %step: index) {
+ scf.for %i= %lb to %ub step %step {
+ // expected-remark @below{{found arith.constant}}
+ arith.constant 0 : index
+ }
+ return
+}
+
+// Match `func.func`s that are not nested under a `func.func` and ensure there are none in the program
+transform.named_sequence @match_func_for_dispatch(%root: !transform.any_op {transform.readonly})
+ -> !transform.any_op {
+ transform.match.operation_name %root ["arith.constant"] : !transform.any_op
+ %variant = transform.get_parent_op %root { op_name = "func.func", allow_empty_results }
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.match.operation_empty %variant : !transform.any_op
+ transform.yield %root : !transform.any_op
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/68319
More information about the Mlir-commits
mailing list