[Mlir-commits] [mlir] [mlir][SCF] Use `transform.get_parent_op` instead of `transform.loop.get_parent_for` (PR #70757)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 30 19:19:53 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add a new attribute to `get_parent_op` to get the n-th parent. Remove `transform.loop.get_parent_for`, which is no longer needed.
---
Patch is 24.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70757.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (-24)
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+7-4)
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (-33)
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+22-19)
- (modified) mlir/python/mlir/dialects/transform/__init__.py (+23-20)
- (modified) mlir/python/mlir/dialects/transform/loop.py (-24)
- (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (+2-2)
- (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+6-102)
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+27-1)
- (modified) mlir/test/python/dialects/transform.py (+5-2)
- (modified) mlir/test/python/dialects/transform_loop_ext.py (-15)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 700a29139a35b10..14df7e23a430fb1 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -68,30 +68,6 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
-def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
- [NavigationTransformOpTrait, MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let summary = "Gets a handle to the parent 'for' loop of the given operation";
- let description = [{
- Produces a handle to the n-th (default 1) parent `scf.for` or `affine.for`
- (when the affine flag is true) loop for each Payload IR operation
- associated with the operand. Fails if such a loop cannot be found. The list
- of operations associated with the handle contains parent operations in the
- same order as the list associated with the operand, except for operations
- that are parents to more than one input which are only present once.
- }];
-
- let arguments =
- (ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
- "1">:$num_loops,
- DefaultValuedAttr<BoolAttr, "false">:$affine);
- let results = (outs TransformHandleTypeInterface : $parent);
-
- let assemblyFormat =
- "$target attr-dict `:` functional-type(operands, results)";
-}
-
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 2fd0e80db96feba..307257f4a582be5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -620,10 +620,11 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
that case for each target op, the closest parent op that fulfills all
requirements, is returned.
- `isolated_from_above`: the parent op must be isolated from above
- - `allow_empty_results`: get_parent_op is allowed to return an empty list and
- still succeeds. In such a case, if get_parent_op fails for any operation
- in the list, the entire transform returns an empty handle.
+ - `allow_empty_results`: get_parent_op is allowed to return an empty list
+ and still succeeds. In such a case, if get_parent_op fails for any
+ operation in the list, the entire transform returns an empty handle.
- `op_name`: the parent op must have the specified name
+ - `nth_parent`: get the n-th parent that satisfies the above requirements
If `deduplicate` is set, the result handle does not contain any duplicate
ops. For example, given the list
@@ -641,7 +642,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
UnitAttr:$isolated_from_above,
UnitAttr:$allow_empty_results,
OptionalAttr<StrAttr>:$op_name,
- UnitAttr:$deduplicate);
+ UnitAttr:$deduplicate,
+ DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
+ "1">:$nth_parent);
let results = (outs TransformHandleTypeInterface:$parent);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 65d503d7c4ad8b8..62370604142cd5b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -49,39 +49,6 @@ void transform::ApplySCFStructuralConversionPatternsOp::
conversionTarget);
}
-//===----------------------------------------------------------------------===//
-// GetParentForOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
- transform::TransformResults &results,
- transform::TransformState &state) {
- SetVector<Operation *> parents;
- for (Operation *target : state.getPayloadOps(getTarget())) {
- Operation *loop, *current = target;
- for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
- loop = getAffine()
- ? current->getParentOfType<AffineForOp>().getOperation()
- : current->getParentOfType<scf::ForOp>().getOperation();
- if (!loop) {
- DiagnosedSilenceableFailure diag =
- emitSilenceableError()
- << "could not find an '"
- << (getAffine() ? AffineForOp::getOperationName()
- : scf::ForOp::getOperationName())
- << "' parent";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
- }
- current = loop;
- }
- parents.insert(loop);
- }
- results.set(cast<OpResult>(getResult()), parents.getArrayRef());
- return DiagnosedSilenceableFailure::success();
-}
-
//===----------------------------------------------------------------------===//
// ForallToForOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 514a75b5d590469..7136e423470a28b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1232,27 +1232,30 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
SmallVector<Operation *> parents;
DenseSet<Operation *> resultSet;
for (Operation *target : state.getPayloadOps(getTarget())) {
- Operation *parent = target->getParentOp();
- while (parent) {
- bool checkIsolatedFromAbove =
- !getIsolatedFromAbove() ||
- parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
- bool checkOpName = !getOpName().has_value() ||
- parent->getName().getStringRef() == *getOpName();
- if (checkIsolatedFromAbove && checkOpName)
- break;
+ Operation *parent = target;
+ for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
parent = parent->getParentOp();
- }
- if (!parent) {
- if (getAllowEmptyResults()) {
- results.set(llvm::cast<OpResult>(getResult()), parents);
- return DiagnosedSilenceableFailure::success();
+ while (parent) {
+ bool checkIsolatedFromAbove =
+ !getIsolatedFromAbove() ||
+ parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
+ bool checkOpName = !getOpName().has_value() ||
+ parent->getName().getStringRef() == *getOpName();
+ if (checkIsolatedFromAbove && checkOpName)
+ break;
+ parent = parent->getParentOp();
+ }
+ if (!parent) {
+ if (getAllowEmptyResults()) {
+ results.set(llvm::cast<OpResult>(getResult()), parents);
+ return DiagnosedSilenceableFailure::success();
+ }
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "could not find a parent op that matches all requirements";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
}
- DiagnosedSilenceableFailure diag =
- emitSilenceableError()
- << "could not find a parent op that matches all requirements";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
}
if (getDeduplicate()) {
if (!resultSet.contains(parent)) {
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index f7a2026e800aeb0..b984f989d32ea09 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -52,26 +52,29 @@ def patterns(self) -> Block:
@_ods_cext.register_operation(_Dialect, replace=True)
class GetParentOp(GetParentOp):
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- isolated_from_above: bool = False,
- op_name: Optional[str] = None,
- deduplicate: bool = False,
- loc=None,
- ip=None,
- ):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- isolated_from_above=isolated_from_above,
- op_name=op_name,
- deduplicate=deduplicate,
- loc=loc,
- ip=ip,
- )
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ isolated_from_above: bool = False,
+ op_name: Optional[str] = None,
+ deduplicate: bool = False,
+ nth_parent: int = 1,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ isolated_from_above=isolated_from_above,
+ op_name=op_name,
+ deduplicate=deduplicate,
+ nth_parent=nth_parent,
+ loc=loc,
+ ip=ip,
+ )
@_ods_cext.register_operation(_Dialect, replace=True)
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
index 6c89025f413839e..3bdd9ca3b22f072 100644
--- a/mlir/python/mlir/dialects/transform/loop.py
+++ b/mlir/python/mlir/dialects/transform/loop.py
@@ -17,30 +17,6 @@
from typing import Optional, Union
-@_ods_cext.register_operation(_Dialect, replace=True)
-class GetParentForOp(GetParentForOp):
- """Extension for GetParentForOp."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- num_loops: Optional[int] = None,
- ip=None,
- loc=None,
- ):
- if num_loops is None:
- num_loops = 1
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- num_loops=num_loops,
- ip=ip,
- loc=loc,
- )
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopOutlineOp(LoopOutlineOp):
"""Extension for LoopOutlineOp."""
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index 96c57d4716d3755..59b824d4ca26209 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -32,7 +32,7 @@ func.func @test_loops_do_not_get_unrolled() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
+ %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">
// expected-error @below {{failed to unroll}}
transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for">
transform.yield
@@ -81,7 +81,7 @@ func.func @test_loops_do_not_get_peeled() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
// expected-error @below {{failed to peel}}
transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
transform.yield
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index 6d1ba48d3b935bb..74601cf5b34a178 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -1,53 +1,5 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
-// CHECK-LABEL: @get_parent_for_op
-func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
- // expected-remark @below {{first loop}}
- scf.for %i = %arg0 to %arg1 step %arg2 {
- // expected-remark @below {{second loop}}
- scf.for %j = %arg0 to %arg1 step %arg2 {
- // expected-remark @below {{third loop}}
- scf.for %k = %arg0 to %arg1 step %arg2 {
- arith.addi %i, %j : index
- }
- }
- }
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: = transform.loop.get_parent_for
- %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
- %2 = transform.loop.get_parent_for %0 { num_loops = 2 } : (!transform.any_op) -> !transform.op<"scf.for">
- %3 = transform.loop.get_parent_for %0 { num_loops = 3 } : (!transform.any_op) -> !transform.op<"scf.for">
- transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"scf.for">
- transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"scf.for">
- transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"scf.for">
- transform.yield
- }
-}
-
-// -----
-
-func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
- // expected-note @below {{target op}}
- arith.addi %arg0, %arg1 : index
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{could not find an 'scf.for' parent}}
- %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
- transform.yield
- }
-}
-
-// -----
-
// Outlined functions:
//
// CHECK: func @foo(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}})
@@ -81,7 +33,7 @@ func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
// CHECK: = transform.loop.outline %{{.*}}
transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
transform.yield
@@ -114,7 +66,7 @@ func.func @loop_peel_op() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
%main_loop, %remainder = transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">)
// Make sure
transform.test_print_remark_at_operand %main_loop, "main loop" : !transform.op<"scf.for">
@@ -152,7 +104,7 @@ func.func @loop_pipeline_op(%A: memref<?xf32>, %result: memref<?xf32>) {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
%2 = transform.loop.pipeline %1 : (!transform.op<"scf.for">) -> !transform.any_op
// Verify that the returned handle is usable.
transform.test_print_remark_at_operand %2, "transformed" : !transform.any_op
@@ -178,7 +130,7 @@ func.func @loop_unroll_op() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for">
transform.yield
}
@@ -186,54 +138,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @get_parent_for_op
-func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
- // expected-remark @below {{first loop}}
- affine.for %i = %arg0 to %arg1 {
- // expected-remark @below {{second loop}}
- affine.for %j = %arg0 to %arg1 {
- // expected-remark @below {{third loop}}
- affine.for %k = %arg0 to %arg1 {
- arith.addi %i, %j : index
- }
- }
- }
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // CHECK: = transform.loop.get_parent_for
- %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
- %2 = transform.loop.get_parent_for %0 { num_loops = 2, affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
- %3 = transform.loop.get_parent_for %0 { num_loops = 3, affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
- transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for">
- transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for">
- transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for">
- transform.yield
- }
-}
-
-// -----
-
-func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
- // expected-note @below {{target op}}
- arith.addi %arg0, %arg1 : index
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{could not find an 'affine.for' parent}}
- %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
- transform.yield
- }
-}
-
-// -----
-
func.func @loop_unroll_op() {
%c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index
@@ -250,7 +154,7 @@ func.func @loop_unroll_op() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
+ %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/70757
More information about the Mlir-commits mailing list