[Mlir-commits] [mlir] make transform.split_handle accept any handle kind (PR #118752)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 4 23:20:25 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
It can now split value and parameter handles in addition to operation handles. This is generally useful functionality.
---
Full diff: https://github.com/llvm/llvm-project/pull/118752.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+11-10)
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+52-12)
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+67-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b946fc8875860b..2d71d5b0892afe 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1062,36 +1062,37 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let summary = "Splits a handle of payload ops into handles with a single op";
+ let summary = "Splits a handle or parameter into multiple values";
let description = [{
Splits `handle` into one or multiple handles, as specified by the number
of results of this operation. `handle` should be mapped to as many payload
- ops as there are results. Otherwise, this transform will fail produces a
- silenceable failure by default. Each result handle is mapped to exactly one
- payload op. The order of the payload ops is preserved, i.e., the i-th
- payload op is mapped to the i-th result handle.
+ ops, values or parameters as there are results. Otherwise, this transform
+ will fail, producing a silenceable failure by default. Each result handle
+ is mapped to exactly one payload unless specified otherwise by attributes
+ described below. The order of the payloads is preserved, i.e., the i-th
+ payload is mapped to the i-th result handle.
This operation is useful for ensuring a statically known number of
- operations are tracked by the source `handle` and to extract them into
+ payloads are tracked by the source `handle` and to extract them into
individual handles that can be further manipulated in isolation.
- If there are more payload ops than results, the remaining ops are mapped to
+ If there are more payloads than results, the remaining payloads are mapped to
the result with index `overflow_result`. If no `overflow_result` is
specified, the transform produces a silenceable failure.
If there are fewer payload ops than results, the transform produces a
silenceable failure if `fail_on_payload_too_small` is set to "true".
Otherwise, it succeeds and the remaining result handles are not mapped to
- any op. It also succeeds if `handle` is empty and
+ anything. It also succeeds if `handle` is empty and
`pass_through_empty_handle` is set to "true", regardless of
`fail_on_payload_too_small`.
}];
- let arguments = (ins TransformHandleTypeInterface:$handle,
+ let arguments = (ins Transform_AnyHandleOrParamType:$handle,
DefaultValuedAttr<BoolAttr, "true">:$pass_through_empty_handle,
DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_too_small,
OptionalAttr<I64Attr>:$overflow_result);
- let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let hasVerifier = 1;
let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 590cae9aa0d667..68d1f2aef638a5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2415,32 +2415,63 @@ DiagnosedSilenceableFailure
transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
+ int64_t numPayloads =
+ llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
+ .Case<TransformHandleTypeInterface>([&](auto x) {
+ return llvm::range_size(state.getPayloadOps(getHandle()));
+ })
+ .Case<TransformValueHandleTypeInterface>([&](auto x) {
+ return llvm::range_size(state.getPayloadValues(getHandle()));
+ })
+ .Case<TransformParamTypeInterface>([&](auto x) {
+ return llvm::range_size(state.getParams(getHandle()));
+ })
+ .Default([](auto x) {
+ llvm_unreachable("unknown transform dialect type interface");
+ return -1;
+ });
+
auto produceNumOpsError = [&]() {
return emitSilenceableError()
<< getHandle() << " expected to contain " << this->getNumResults()
- << " payload ops but it contains " << numPayloadOps
- << " payload ops";
+ << " payloads but it contains " << numPayloads
+ << " payloads";
};
// Fail if there are more payload ops than results and no overflow result was
// specified.
- if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
+ if (numPayloads > getNumResults() && !getOverflowResult().has_value())
return produceNumOpsError();
// Fail if there are more results than payload ops. Unless:
// - "fail_on_payload_too_small" is set to "false", or
// - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
- if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
- (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
+ if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
+ (numPayloads != 0 || !getPassThroughEmptyHandle()))
return produceNumOpsError();
- // Distribute payload ops.
- SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
+ // Distribute payloads.
+ SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
if (getOverflowResult())
- resultHandles[*getOverflowResult()].reserve(numPayloadOps -
- getNumResults());
- for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
+ resultHandles[*getOverflowResult()].reserve(
+ std::max<int64_t>(numPayloads - getNumResults(), 0));
+ auto container = [&]() {
+ if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
+ return llvm::map_to_vector(
+ state.getPayloadOps(getHandle()),
+ [](Operation *op) -> MappedValue { return op; });
+ }
+ if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
+ return llvm::map_to_vector(state.getPayloadValues(getHandle()),
+ [](Value v) -> MappedValue { return v; });
+ }
+ assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
+ "unsupported kind of transform dialect type");
+ return llvm::map_to_vector(state.getParams(getHandle()),
+ [](Attribute a) -> MappedValue { return a; });
+ }();
+
+ for (auto &&en : llvm::enumerate(container)) {
int64_t resultNum = en.index();
if (resultNum >= getNumResults())
resultNum = *getOverflowResult();
@@ -2449,7 +2480,7 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
// Set transform op results.
for (auto &&it : llvm::enumerate(resultHandles))
- results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
+ results.setMappedValues(cast<OpResult>(getResult(it.index())), it.value());
return DiagnosedSilenceableFailure::success();
}
@@ -2466,6 +2497,15 @@ LogicalResult transform::SplitHandleOp::verify() {
if (getOverflowResult().has_value() &&
!(*getOverflowResult() < getNumResults()))
return emitOpError("overflow_result is not a valid result index");
+
+ for (Type resultType : getResultTypes()) {
+ if (implementSameTransformInterface(getHandle().getType(), resultType))
+ continue;
+
+ return emitOpError("expects result types to implement the same transform "
+ "interface as the operand type");
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 4fe2dbedff56e3..ecc234587cda95 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1094,7 +1094,7 @@ module attributes {transform.with_named_sequence} {
// expected-remark @below {{1}}
transform.debug.emit_param_as_remark %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+ // expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
%h_2:3 = transform.split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -1180,6 +1180,71 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func private @opaque() -> (i32, i32)
+
+func.func @split_handle() {
+ func.call @opaque() : () -> (i32, i32)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+ %op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
+ %val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
+ %p = transform.num_associations %val : (!transform.any_value) -> !transform.any_param
+ // expected-remark @below {{total 2}}
+ transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
+ %h:2 = transform.split_handle %val : (!transform.any_value) -> (!transform.any_value, !transform.any_value)
+ %p1 = transform.num_associations %h#0 : (!transform.any_value) -> !transform.any_param
+ %p2 = transform.num_associations %h#1 : (!transform.any_value) -> !transform.any_param
+ // expected-remark @below {{first 1}}
+ transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
+ // expected-remark @below {{second 1}}
+ transform.debug.emit_param_as_remark %p2, "second" : !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @opaque() -> (i32, i32)
+
+func.func @split_handle() {
+ func.call @opaque() : () -> (i32, i32)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+ %op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op
+ %val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value
+ %type = transform.get_type %val : (!transform.any_value) -> !transform.any_param
+ %p = transform.num_associations %type : (!transform.any_param) -> !transform.any_param
+ // expected-remark @below {{total 2}}
+ transform.debug.emit_param_as_remark %p, "total" : !transform.any_param
+ %h:2 = transform.split_handle %type : (!transform.any_param) -> (!transform.any_param, !transform.any_param)
+ %p1 = transform.num_associations %h#0 : (!transform.any_param) -> !transform.any_param
+ %p2 = transform.num_associations %h#1 : (!transform.any_param) -> !transform.any_param
+ // expected-remark @below {{first 1}}
+ transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param
+ // expected-remark @below {{second 1}}
+ transform.debug.emit_param_as_remark %p2, "second" : !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+ // expected-error @below {{op expects result types to implement the same transform interface as the operand type}}
+ transform.split_handle %fun : (!transform.any_op) -> (!transform.any_op, !transform.any_value)
+ transform.yield
+ }
+}
+
+// -----
+
"test.some_op"() : () -> ()
"other_dialect.other_op"() : () -> ()
@@ -1324,7 +1389,7 @@ module attributes {transform.with_named_sequence} {
transform.sequence %root : !transform.any_op -> !transform.any_op failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+ // expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}}
%h_2:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
/// Test that yield does not crash in the presence of silenceable error in
/// propagate mode.
``````````
</details>
https://github.com/llvm/llvm-project/pull/118752
More information about the Mlir-commits mailing list