[Mlir-commits] [mlir] [mlir][SCF] Use `transform.get_parent_op` instead of `transform.loop.get_parent_for` (PR #70757)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 30 19:19:53 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add a new `nth_parent` attribute to `transform.get_parent_op` to get the n-th parent that satisfies the op's requirements. Remove `transform.loop.get_parent_for`, which is no longer needed.


---

Patch is 24.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70757.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (-24) 
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+7-4) 
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (-33) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+22-19) 
- (modified) mlir/python/mlir/dialects/transform/__init__.py (+23-20) 
- (modified) mlir/python/mlir/dialects/transform/loop.py (-24) 
- (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (+2-2) 
- (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+6-102) 
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+27-1) 
- (modified) mlir/test/python/dialects/transform.py (+5-2) 
- (modified) mlir/test/python/dialects/transform_loop_ext.py (-15) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 700a29139a35b10..14df7e23a430fb1 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -68,30 +68,6 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
 }
 
-def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
-    [NavigationTransformOpTrait, MemoryEffectsOpInterface,
-     DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let summary = "Gets a handle to the parent 'for' loop of the given operation";
-  let description = [{
-    Produces a handle to the n-th (default 1) parent `scf.for` or `affine.for`
-    (when the affine flag is true) loop for each Payload IR operation
-    associated with the operand. Fails if such a loop cannot be found. The list
-    of operations associated with the handle contains parent operations in the
-    same order as the list associated with the operand, except for operations
-    that are parents to more than one input which are only present once.
-  }];
-
-  let arguments =
-    (ins TransformHandleTypeInterface:$target,
-         DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
-                           "1">:$num_loops,
-         DefaultValuedAttr<BoolAttr, "false">:$affine);
-  let results = (outs TransformHandleTypeInterface : $parent);
-
-  let assemblyFormat =
-    "$target attr-dict `:` functional-type(operands, results)";
-}
-
 def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 2fd0e80db96feba..307257f4a582be5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -620,10 +620,11 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
     that case for each target op, the closest parent op that fulfills all
     requirements, is returned.
     - `isolated_from_above`: the parent op must be isolated from above
-    - `allow_empty_results`: get_parent_op is allowed to return an empty list and
-      still succeeds. In such a case, if get_parent_op fails for any operation
-      in the list, the entire transform returns an empty handle.
+    - `allow_empty_results`: get_parent_op is allowed to return an empty list
+      and still succeeds. In such a case, if get_parent_op fails for any
+      operation in the list, the entire transform returns an empty handle.
     - `op_name`: the parent op must have the specified name
+    - `nth_parent`: get the n-th parent that satisfies the above requirements
 
     If `deduplicate` is set, the result handle does not contain any duplicate
     ops. For example, given the list
@@ -641,7 +642,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
                        UnitAttr:$isolated_from_above,
                        UnitAttr:$allow_empty_results,
                        OptionalAttr<StrAttr>:$op_name,
-                       UnitAttr:$deduplicate);
+                       UnitAttr:$deduplicate,
+                       DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
+                                         "1">:$nth_parent);
   let results = (outs TransformHandleTypeInterface:$parent);
   let assemblyFormat =
     "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 65d503d7c4ad8b8..62370604142cd5b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -49,39 +49,6 @@ void transform::ApplySCFStructuralConversionPatternsOp::
                                                  conversionTarget);
 }
 
-//===----------------------------------------------------------------------===//
-// GetParentForOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
-                                 transform::TransformResults &results,
-                                 transform::TransformState &state) {
-  SetVector<Operation *> parents;
-  for (Operation *target : state.getPayloadOps(getTarget())) {
-    Operation *loop, *current = target;
-    for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
-      loop = getAffine()
-                 ? current->getParentOfType<AffineForOp>().getOperation()
-                 : current->getParentOfType<scf::ForOp>().getOperation();
-      if (!loop) {
-        DiagnosedSilenceableFailure diag =
-            emitSilenceableError()
-            << "could not find an '"
-            << (getAffine() ? AffineForOp::getOperationName()
-                            : scf::ForOp::getOperationName())
-            << "' parent";
-        diag.attachNote(target->getLoc()) << "target op";
-        return diag;
-      }
-      current = loop;
-    }
-    parents.insert(loop);
-  }
-  results.set(cast<OpResult>(getResult()), parents.getArrayRef());
-  return DiagnosedSilenceableFailure::success();
-}
-
 //===----------------------------------------------------------------------===//
 // ForallToForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 514a75b5d590469..7136e423470a28b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1232,27 +1232,30 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
   SmallVector<Operation *> parents;
   DenseSet<Operation *> resultSet;
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    Operation *parent = target->getParentOp();
-    while (parent) {
-      bool checkIsolatedFromAbove =
-          !getIsolatedFromAbove() ||
-          parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
-      bool checkOpName = !getOpName().has_value() ||
-                         parent->getName().getStringRef() == *getOpName();
-      if (checkIsolatedFromAbove && checkOpName)
-        break;
+    Operation *parent = target;
+    for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
       parent = parent->getParentOp();
-    }
-    if (!parent) {
-      if (getAllowEmptyResults()) {
-        results.set(llvm::cast<OpResult>(getResult()), parents);
-        return DiagnosedSilenceableFailure::success();
+      while (parent) {
+        bool checkIsolatedFromAbove =
+            !getIsolatedFromAbove() ||
+            parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
+        bool checkOpName = !getOpName().has_value() ||
+                           parent->getName().getStringRef() == *getOpName();
+        if (checkIsolatedFromAbove && checkOpName)
+          break;
+        parent = parent->getParentOp();
+      }
+      if (!parent) {
+        if (getAllowEmptyResults()) {
+          results.set(llvm::cast<OpResult>(getResult()), parents);
+          return DiagnosedSilenceableFailure::success();
+        }
+        DiagnosedSilenceableFailure diag =
+            emitSilenceableError()
+            << "could not find a parent op that matches all requirements";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
       }
-      DiagnosedSilenceableFailure diag =
-          emitSilenceableError()
-          << "could not find a parent op that matches all requirements";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
     }
     if (getDeduplicate()) {
       if (!resultSet.contains(parent)) {
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index f7a2026e800aeb0..b984f989d32ea09 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -52,26 +52,29 @@ def patterns(self) -> Block:
 
 @_ods_cext.register_operation(_Dialect, replace=True)
 class GetParentOp(GetParentOp):
-    def __init__(
-        self,
-        result_type: Type,
-        target: Union[Operation, Value],
-        *,
-        isolated_from_above: bool = False,
-        op_name: Optional[str] = None,
-        deduplicate: bool = False,
-        loc=None,
-        ip=None,
-    ):
-        super().__init__(
-            result_type,
-            _get_op_result_or_value(target),
-            isolated_from_above=isolated_from_above,
-            op_name=op_name,
-            deduplicate=deduplicate,
-            loc=loc,
-            ip=ip,
-        )
+
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      *,
+      isolated_from_above: bool = False,
+      op_name: Optional[str] = None,
+      deduplicate: bool = False,
+      nth_parent: int = 1,
+      loc=None,
+      ip=None,
+  ):
+    super().__init__(
+        result_type,
+        _get_op_result_or_value(target),
+        isolated_from_above=isolated_from_above,
+        op_name=op_name,
+        deduplicate=deduplicate,
+        nth_parent=nth_parent,
+        loc=loc,
+        ip=ip,
+    )
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
index 6c89025f413839e..3bdd9ca3b22f072 100644
--- a/mlir/python/mlir/dialects/transform/loop.py
+++ b/mlir/python/mlir/dialects/transform/loop.py
@@ -17,30 +17,6 @@
 from typing import Optional, Union
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class GetParentForOp(GetParentForOp):
-    """Extension for GetParentForOp."""
-
-    def __init__(
-        self,
-        result_type: Type,
-        target: Union[Operation, Value],
-        *,
-        num_loops: Optional[int] = None,
-        ip=None,
-        loc=None,
-    ):
-        if num_loops is None:
-            num_loops = 1
-        super().__init__(
-            result_type,
-            _get_op_result_or_value(target),
-            num_loops=num_loops,
-            ip=ip,
-            loc=loc,
-        )
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class LoopOutlineOp(LoopOutlineOp):
     """Extension for LoopOutlineOp."""
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index 96c57d4716d3755..59b824d4ca26209 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -32,7 +32,7 @@ func.func @test_loops_do_not_get_unrolled() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
+    %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">
     // expected-error @below {{failed to unroll}}
     transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for">
     transform.yield
@@ -81,7 +81,7 @@ func.func @test_loops_do_not_get_peeled() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
     // expected-error @below {{failed to peel}}
     transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
     transform.yield
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index 6d1ba48d3b935bb..74601cf5b34a178 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -1,53 +1,5 @@
 // RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
 
-// CHECK-LABEL: @get_parent_for_op
-func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
-  // expected-remark @below {{first loop}}
-  scf.for %i = %arg0 to %arg1 step %arg2 {
-    // expected-remark @below {{second loop}}
-    scf.for %j = %arg0 to %arg1 step %arg2 {
-      // expected-remark @below {{third loop}}
-      scf.for %k = %arg0 to %arg1 step %arg2 {
-        arith.addi %i, %j : index
-      }
-    }
-  }
-  return
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // CHECK: = transform.loop.get_parent_for
-    %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
-    %2 = transform.loop.get_parent_for %0 { num_loops = 2 } : (!transform.any_op) -> !transform.op<"scf.for">
-    %3 = transform.loop.get_parent_for %0 { num_loops = 3 } : (!transform.any_op) -> !transform.op<"scf.for">
-    transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"scf.for">
-    transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"scf.for">
-    transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"scf.for">
-    transform.yield
-  }
-}
-
-// -----
-
-func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
-  // expected-note @below {{target op}}
-  arith.addi %arg0, %arg1 : index
-  return
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{could not find an 'scf.for' parent}}
-    %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
-    transform.yield
-  }
-}
-
-// -----
-
 // Outlined functions:
 //
 // CHECK: func @foo(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}})
@@ -81,7 +33,7 @@ func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.loop.get_parent_for %0  : (!transform.any_op) -> !transform.op<"scf.for">
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
     // CHECK: = transform.loop.outline %{{.*}}
     transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
     transform.yield
@@ -114,7 +66,7 @@ func.func @loop_peel_op() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
     %main_loop, %remainder = transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">)
     // Make sure 
     transform.test_print_remark_at_operand %main_loop, "main loop" : !transform.op<"scf.for">
@@ -152,7 +104,7 @@ func.func @loop_pipeline_op(%A: memref<?xf32>, %result: memref<?xf32>) {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
     %2 = transform.loop.pipeline %1 : (!transform.op<"scf.for">) -> !transform.any_op
     // Verify that the returned handle is usable.
     transform.test_print_remark_at_operand %2, "transformed" : !transform.any_op
@@ -178,7 +130,7 @@ func.func @loop_unroll_op() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
     transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for">
     transform.yield
   }
@@ -186,54 +138,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: @get_parent_for_op
-func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
-  // expected-remark @below {{first loop}}
-  affine.for %i = %arg0 to %arg1 {
-    // expected-remark @below {{second loop}}
-    affine.for %j = %arg0 to %arg1 {
-      // expected-remark @below {{third loop}}
-      affine.for %k = %arg0 to %arg1 {
-        arith.addi %i, %j : index
-      }
-    }
-  }
-  return
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // CHECK: = transform.loop.get_parent_for
-    %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
-    %2 = transform.loop.get_parent_for %0 { num_loops = 2, affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
-    %3 = transform.loop.get_parent_for %0 { num_loops = 3, affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
-    transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for">
-    transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for">
-    transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for">
-    transform.yield
-  }
-}
-
-// -----
-
-func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
-  // expected-note @below {{target op}}
-  arith.addi %arg0, %arg1 : index
-  return
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{could not find an 'affine.for' parent}}
-    %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
-    transform.yield
-  }
-}
-
-// -----
-
 func.func @loop_unroll_op() {
   %c0 = arith.constant 0 : index
   %c42 = arith.constant 42 : index
@@ -250,7 +154,7 @@ func.func @loop_unroll_op() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
+    %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/70757


More information about the Mlir-commits mailing list