[Mlir-commits] [mlir] 8483d18 - [mlir][Transform] Relax the applicability of transform.foreach_match to also take into account the op itself

Nicolas Vasilache llvmlistbot at llvm.org
Mon Oct 30 04:53:16 PDT 2023


Author: Nicolas Vasilache
Date: 2023-10-30T11:53:04Z
New Revision: 8483d18be5b6b5e8721a10eb558be06008307ec6

URL: https://github.com/llvm/llvm-project/commit/8483d18be5b6b5e8721a10eb558be06008307ec6
DIFF: https://github.com/llvm/llvm-project/commit/8483d18be5b6b5e8721a10eb558be06008307ec6.diff

LOG: [mlir][Transform] Relax the applicability of transform.foreach_match to also take into account the op itself

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Linalg/match-ops-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b14c89eadb097d9..2fd0e80db96feba 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -481,8 +481,16 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     This operation consumes the operand and produces a new handle associated
     with the same payload. This is necessary to trigger invalidation of handles
     to any of the payload operations nested in the payload operations associated
-    with the operand, as those are likely to be modified by actions. Note that
-    the root payload operation associated with the operand are not matched.
+    with the operand, as those are likely to be modified by actions.
+
+    By default, the root payload operations associated with the operand are not
+    matched. This supports the conservative case where applied actions may
+    invalidate the root payload operations. If the optional `restrict_root`
+    attribute is set, the user guarantees that the root payload operations are
+    not invalidated by any applied action, and they are matched as well.
+    This is useful because matching the root payload operation is a common
+    idiom, e.g., when matching a `func.func` together with the operations
+    nested under it.
 
     The operation succeeds if none of the matchers produced a definite failure
     during application and if all of the applied actions produced success. Note
@@ -495,13 +503,19 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$root,
+                       UnitAttr:$restrict_root,
                        SymbolRefArrayAttr:$matchers,
                        SymbolRefArrayAttr:$actions);
   let results = (outs TransformHandleTypeInterface:$updated);
 
-  let assemblyFormat =
-      "`in` $root custom<ForeachMatchSymbols>($matchers, $actions) "
-      "attr-dict `:` functional-type($root, $updated)";
+  let assemblyFormat = [{
+    (`restrict_root` $restrict_root^)?
+    `in`
+    $root
+    custom<ForeachMatchSymbols>($matchers, $actions)
+    attr-dict
+    `:` functional-type($root, $updated)
+  }];
 
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 8db77b6059dd2e3..514a75b5d590469 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -850,8 +850,9 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
 
   for (Operation *root : state.getPayloadOps(getRoot())) {
     WalkResult walkResult = root->walk([&](Operation *op) {
-      // Skip over the root op itself so we don't invalidate it.
-      if (op == root)
+      // Unless `restrict_root` is set, skip over the root op itself so we
+      // don't invalidate it.
+      if (!getRestrictRoot() && op == root)
         return WalkResult::advance();
 
       DEBUG_MATCHER({
@@ -1556,10 +1557,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
     ::std::optional<::mlir::Operation *> maybeCurrent,
     transform::TransformResults &results, transform::TransformState &state) {
   if (!maybeCurrent.has_value()) {
-    DBGS_MATCHER() << "MatchOperationEmptyOp success\n";
+    DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
     return DiagnosedSilenceableFailure::success();
   }
-  DBGS_MATCHER() << "MatchOperationEmptyOp failure\n";
+  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
   return emitSilenceableError() << "operation is not empty";
 }
 
@@ -1961,7 +1962,8 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(symName));
   state.addAttribute(getFunctionTypeAttrName(state.name),
-                     TypeAttr::get(FunctionType::get(builder.getContext(), rootType, resultTypes)));
+                     TypeAttr::get(FunctionType::get(builder.getContext(),
+                                                     rootType, resultTypes)));
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
 

diff  --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 9489aadac843d7b..c88945c8a5c60fd 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -100,12 +100,13 @@ module attributes { transform.with_named_sequence } {
   }
 
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
-    transform.foreach_match in %arg0
+    transform.foreach_match restrict_root in %arg0
         @match_structured_suppress -> @do_nothing
         : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 
+  // expected-remark @below {{other}}
   func.func @payload() attributes { transform.target_tag = "start_here" } {
     // expected-remark @below {{other}}
     %D = arith.constant dense<1.0> : tensor<2x4xf32>


        


More information about the Mlir-commits mailing list