[Mlir-commits] [mlir] [mlir] update remaining transform tests to main pass (PR #81279)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 9 09:04:00 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

Use the main transform interpreter pass instead of the test pass in the remaining transform tests. The only tests left unchanged are those that specifically exercise the behavior of the test pass itself.

---

Patch is 64.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81279.diff


15 Files Affected:

- (modified) mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp (+73-22) 
- (modified) mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir (+15-10) 
- (modified) mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir (+8-6) 
- (modified) mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir (+42-21) 
- (modified) mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir (+26-13) 
- (modified) mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir (+21-11) 
- (modified) mlir/test/Dialect/Transform/test-interpreter-debug.mlir (+17-41) 
- (modified) mlir/test/Dialect/Transform/test-interpreter-external-concurrent.mlir (+3-1) 
- (modified) mlir/test/Dialect/Transform/test-interpreter-external.mlir (+3-1) 
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (-1) 
- (modified) mlir/test/Dialect/Transform/test-pass-application.mlir (+45-33) 
- (modified) mlir/test/Dialect/Transform/test-pattern-application.mlir (+206-178) 
- (modified) mlir/test/Dialect/Transform/test-pdl-extension.mlir (+39-29) 
- (modified) mlir/test/Dialect/Transform/transform-state-extension.mlir (+54-48) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir (+10-8) 


``````````diff
diff --git a/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
index 5073234a7e35e9..7adf223f3440a5 100644
--- a/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
@@ -50,12 +50,79 @@ static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
     return WalkResult::interrupt();
   });
 
+  if (!target) {
+    passRoot->emitError()
+        << "could not find the operation with transform.target_tag=\"" << tag
+        << "\" attribute";
+    return nullptr;
+  }
+
   return walkResult.wasInterrupted() ? nullptr : target;
 }
 
 namespace {
 class InterpreterPass
     : public transform::impl::InterpreterPassBase<InterpreterPass> {
+  // Parses the pass arguments to bind trailing arguments of the entry point.
+  std::optional<RaggedArray<transform::MappedValue>>
+  parseArguments(Operation *payloadRoot) {
+    MLIRContext *context = payloadRoot->getContext();
+
+    SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
+    trailingBindings.resize(debugBindTrailingArgs.size());
+
+    // Construct lists of op names to match.
+    SmallVector<std::optional<OperationName>> debugBindNames;
+    debugBindNames.reserve(debugBindTrailingArgs.size());
+    for (auto &&[position, nameString] :
+         llvm::enumerate(debugBindTrailingArgs)) {
+      StringRef name = nameString;
+
+      // Parse the integer literals.
+      if (name.starts_with("#")) {
+        debugBindNames.push_back(std::nullopt);
+        StringRef lhs = "";
+        StringRef rhs = name.drop_front();
+        do {
+          std::tie(lhs, rhs) = rhs.split(';');
+          int64_t value;
+          if (lhs.getAsInteger(10, value)) {
+            emitError(UnknownLoc::get(context))
+                << "couldn't parse integer pass argument " << name;
+            return std::nullopt;
+          }
+          trailingBindings[position].push_back(
+              Builder(context).getI64IntegerAttr(value));
+        } while (!rhs.empty());
+      } else if (name.starts_with("^")) {
+        debugBindNames.emplace_back(OperationName(name.drop_front(), context));
+      } else {
+        debugBindNames.emplace_back(OperationName(name, context));
+      }
+    }
+
+    // Collect operations or results for extra bindings.
+    payloadRoot->walk([&](Operation *payload) {
+      for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
+        if (!name || payload->getName() != *name)
+          continue;
+
+        if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
+                .starts_with("^")) {
+          llvm::append_range(trailingBindings[position], payload->getResults());
+        } else {
+          trailingBindings[position].push_back(payload);
+        }
+      }
+    });
+
+    RaggedArray<transform::MappedValue> bindings;
+    bindings.push_back(ArrayRef<Operation *>{payloadRoot});
+    for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
+      bindings.push_back(std::move(trailing));
+    return bindings;
+  }
+
 public:
   using Base::Base;
 
@@ -67,34 +134,18 @@ class InterpreterPass
         findPayloadRoot(getOperation(), debugPayloadRootTag);
     if (!payloadRoot)
       return signalPassFailure();
-    auto debugBindNames = llvm::map_to_vector(
-        debugBindTrailingArgs,
-        [&](const std::string &name) { return OperationName(name, context); });
-    SmallVector<SmallVector<Operation *>, 2> trailingBindings;
-    trailingBindings.resize(debugBindNames.size());
-    payloadRoot->walk([&](Operation *payload) {
-      for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
-        if (payload->getName() == name)
-          trailingBindings[position].push_back(payload);
-      }
-    });
 
     Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
         getOperation(), transformModule, entryPoint);
-    if (!transformEntryPoint) {
-      getOperation()->emitError()
-          << "could not find transform entry point: " << entryPoint
-          << " in either payload or transform module";
+    if (!transformEntryPoint)
       return signalPassFailure();
-    }
-
-    RaggedArray<transform::MappedValue> bindings;
-    bindings.push_back(ArrayRef<Operation *>{payloadRoot});
-    for (SmallVector<Operation *> &trailing : trailingBindings)
-      bindings.push_back(std::move(trailing));
 
+    std::optional<RaggedArray<transform::MappedValue>> bindings =
+        parseArguments(payloadRoot);
+    if (!bindings)
+      return signalPassFailure();
     if (failed(transform::applyTransformNamedSequence(
-            bindings,
+            *bindings,
             cast<transform::TransformOpInterface>(transformEntryPoint),
             transformModule,
             options.enableExpensiveChecks(!disableExpensiveChecks)))) {
diff --git a/mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir b/mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir
index 316b90f85236e4..255ff5f31ed3f5 100644
--- a/mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir
+++ b/mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir
@@ -1,16 +1,21 @@
 // RUN: mlir-opt %s
 // No need to check anything else than parsing here, this is being used by another test as data.
 
-transform.with_pdl_patterns {
-^bb0(%arg0: !transform.any_op):
-  pdl.pattern @func_return : benefit(1) {
-    %0 = pdl.operation "func.return"
-    pdl.rewrite %0 with "transform.dialect"
-  }
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    transform.with_pdl_patterns %root : !transform.any_op {
+    ^bb0(%arg0: !transform.any_op):
+      pdl.pattern @func_return : benefit(1) {
+        %0 = pdl.operation "func.return"
+        pdl.rewrite %0 with "transform.dialect"
+      }
 
-  sequence %arg0 : !transform.any_op failures(propagate) {
-  ^bb1(%arg1: !transform.any_op):
-    %0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return">
-    transform.debug.emit_remark_at %0, "matched" : !transform.op<"func.return">
+      sequence %arg0 : !transform.any_op failures(propagate) {
+      ^bb1(%arg1: !transform.any_op):
+        %0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return">
+        transform.debug.emit_remark_at %0, "matched" : !transform.op<"func.return">
+      }
+    }
+    transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir b/mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir
index 5956c86ebbe4b2..f6b7f787cc2c38 100644
--- a/mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir
+++ b/mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir
@@ -1,11 +1,13 @@
 // RUN: mlir-opt %s
 // No need to check anything else than parsing here, this is being used by another test as data.
 
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op):
-  transform.debug.emit_remark_at %arg0, "outer" : !transform.any_op
-  transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
-  ^bb1(%arg1: !transform.any_op):
-    transform.debug.emit_remark_at %arg1, "inner" : !transform.any_op
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    transform.debug.emit_remark_at %arg0, "outer" : !transform.any_op
+    transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
+    ^bb1(%arg1: !transform.any_op):
+      transform.debug.emit_remark_at %arg1, "inner" : !transform.any_op
+    }
+    transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
index 9a7e7ca2f9536e..1c018b1b1f7796 100644
--- a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
+++ b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
@@ -1,10 +1,15 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \
-// RUN:             --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(transform-interpreter{\
+// RUN:       debug-bind-trailing-args=func.func,func.return})" \
+// RUN:   --split-input-file --verify-diagnostics
 
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
-  transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_op
-  transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_op
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.any_op,
+      %arg2: !transform.any_op) {
+    transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_op
+    transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_op
+    transform.yield
+  }
 }
 
 // expected-remark @below {{first extra}}
@@ -26,9 +31,13 @@ func.func @bar(%arg0: i1) {
 
 // -----
 
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
-  // expected-error @above {{wrong kind of value provided for top-level parameter}}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.any_op,
+      %arg2: !transform.param<i64>) {
+    // expected-error @above {{wrong kind of value provided for top-level parameter}}
+    transform.yield
+  }
 }
 
 func.func @foo() {
@@ -37,9 +46,13 @@ func.func @foo() {
 
 // -----
 
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
-  // expected-error @above {{wrong kind of value provided for the top-level value handle}}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.any_op,
+      %arg2: !transform.any_value) {
+    // expected-error @above {{wrong kind of value provided for the top-level value handle}}
+    transform.yield
+  }
 }
 
 func.func @foo() {
@@ -48,19 +61,27 @@ func.func @foo() {
 
 // -----
 
-// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
+
+module attributes {transform.with_named_sequence} {
+  // expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.any_op) {
+    transform.yield
+  }
 }
 
 // -----
 
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
-  transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
-  ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
-    transform.debug.emit_remark_at %arg4, "first extra" : !transform.any_op
-    transform.debug.emit_remark_at %arg5, "second extra" : !transform.any_op
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.any_op,
+      %arg2: !transform.any_op) {
+    transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
+    ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+      transform.debug.emit_remark_at %arg4, "first extra" : !transform.any_op
+      transform.debug.emit_remark_at %arg5, "second extra" : !transform.any_op
+    }
+    transform.yield
   }
 }
 
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
index f59a4b6d4ccc32..6486bcae3294e4 100644
--- a/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
+++ b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
@@ -1,24 +1,37 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter{\
+// RUN:                             debug-bind-trailing-args=#1;2;3,#42;45})' \
 // RUN:          --split-input-file --verify-diagnostics
 
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>):
-  // expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
-  transform.debug.emit_param_as_remark %arg1 : !transform.param<i64>
-  // expected-remark @below {{42 : i64, 45 : i64}}
-  transform.debug.emit_param_as_remark %arg2 : !transform.param<i64>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.param<i64>,
+      %arg2: !transform.param<i64>) {
+    // expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
+    transform.debug.emit_param_as_remark %arg1 : !transform.param<i64>
+    // expected-remark @below {{42 : i64, 45 : i64}}
+    transform.debug.emit_param_as_remark %arg2 : !transform.param<i64>
+    transform.yield
+  }
 }
 
 // -----
 
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
-  // expected-error @above {{wrong kind of value provided for top-level operation handle}}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.any_op,
+      // expected-error @above {{wrong kind of value provided for top-level operation handle}}
+      %arg2: !transform.param<i64>) {
+    transform.yield
+  }
 }
 
 // -----
 
-// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>, %arg3: !transform.param<i64>):
+module attributes {transform.with_named_sequence} {
+  // expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.param<i64>,
+      %arg2: !transform.param<i64>, %arg3: !transform.param<i64>) {
+    transform.yield
+  }
 }
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir
index 38d7e28697774d..dcc1079267dc7c 100644
--- a/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir
+++ b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-results-of-ops=test.some_returning_op bind-second-extra-to-results-of-ops=test.some_other_returning_op})' \
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter{\
+// RUN:     debug-bind-trailing-args=^test.some_returning_op,^test.some_other_returning_op})' \
 // RUN:             --split-input-file --verify-diagnostics
 
 // Note that diagnostic checker will merge two diagnostics with the same message
@@ -21,10 +22,14 @@
 // expected-note @below {{value handle points to an op result #1}}
 %2:2 = "test.some_other_returning_op"() : () -> (f32, f64)
 
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value, %arg2: !transform.any_value):
-  transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_value
-  transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_value
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %arg0: !transform.any_op, %arg1: !transform.any_value,
+      %arg2: !transform.any_value) {
+    transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_value
+    transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_value
+    transform.yield
+  }
 }
 
 // -----
@@ -32,14 +37,19 @@ transform.sequence failures(propagate) {
 %0:2 = "test.some_returning_op"() : () -> (i32, i64)
 %1 = "test.some_returning_op"() : () -> index
 
-transform.sequence failures(propagate) {
-// expected-error @below {{wrong kind of value provided for top-level operation handle}}
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      // expected-error @below {{wrong kind of value provided for top-level operation handle}}
+      %arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value) {
+    transform.yield
+  }
 }
 
 // -----
 
-// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value):
+module attributes {transform.with_named_sequence} {
+  // expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op, %arg1: !transform.any_value) {
+    transform.yield
+  }
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter-debug.mlir b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
index c7dad582dd432c..99301ea23c6f8d 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
@@ -1,19 +1,21 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{debug-payload-root-tag=payload debug-transform-root-tag=transform})" \
-// RUN:             --allow-unregistered-dialect --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(transform-interpreter{\
+// RUN:         debug-payload-root-tag=payload \
+// RUN:         entry-point=transform})" \
+// RUN:   --allow-unregistered-dialect --split-input-file --verify-diagnostics
 
 // expected-error @below {{could not find the operation with transform.target_tag="payload" attribute}}
-module {
-  transform.sequence failures(suppress) {
-  ^bb0(%arg0: !transform.any_op):
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @transform(%arg0: !transform.any_op) {
+    transform.yield
   }
 }
 
 // -----
 
-// expected-error @below {{could not find the operation with transform.target_tag="transform" attribute}}
-module {
-  transform.sequence failures(suppress) {
-  ^bb0(%arg0: !transform.any_op):
+// expected-error @below {{could not find a nested named sequence with name: transform}}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @not_transform(%arg0: !transform.any_op) {
+    transform.yield
   }
 
   module attributes {transform.target_tag="payload"} {}
@@ -21,42 +23,16 @@ module {
 
 // -----
 
-// expected-error @below {{more than one operation with transform.target_tag="transform" attribute}}
-module {
-  // expected-note @below {{first operation}}
-  transform.sequence failures(propagate) attributes {transform.target_tag="transform"} {
-  ^bb0(%arg0: !transform.any_op):
-  }
-
-  // expected-note @below {{other operation}}
-  transform.sequence failures(propagate) attributes {transform.target_tag="transform"} {
-  ^bb0(%arg0: !transform.any_op):
-  }
-
-  module attributes {transform.target_tag="payload"} {}
-}
-
-// -----
-
-module {
-  // expected-error @below {{expected the transform entry point to be a top-level transform op}}
-  func.func private @foo() attributes {transf...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81279


More information about the Mlir-commits mailing list