[Mlir-commits] [mlir] [mlir][Transform] Add a transform.match.operation_empty op to allow s… (PR #68319)

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 6 00:15:49 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/68319

>From 711d3601ad93df805ccbe4f94fb14b321b7d9359 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Thu, 5 Oct 2023 15:01:10 +0000
Subject: [PATCH] [mlir][Transform] Add a transform.match.operation_empty op to
 allow specifying negative conditions

In the process, get_parent_op gains an `allow_empty_results` attribute that lets it return an empty handle explicitly and still succeed.
---
 .../Dialect/Transform/IR/MatchInterfaces.h    |  95 ++++++--
 .../Dialect/Transform/IR/MatchInterfaces.td   |  21 +-
 .../Dialect/Transform/IR/TransformDialect.td  | 204 +++++++++---------
 .../mlir/Dialect/Transform/IR/TransformOps.td |  27 ++-
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  20 +-
 .../Dialect/Transform/test-interpreter.mlir   |  72 +++++++
 6 files changed, 312 insertions(+), 127 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index c8888f294f6ca1d..b155b110677d6c7 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -11,39 +11,71 @@
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/STLExtras.h"
+#include <optional>
+#include <type_traits>
 
 namespace mlir {
 namespace transform {
 class MatchOpInterface;
 
+namespace detail {
+/// Calls `op.matchOperation` without a payload op, passing `nullptr` if its
+/// first parameter is `Operation *` and `std::nullopt` otherwise.
 template <typename OpTy>
-class SingleOpMatcherOpTrait
-    : public OpTrait::TraitBase<OpTy, SingleOpMatcherOpTrait> {
+DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
+                                                   TransformResults &results,
+                                                   TransformState &state) {
+  using FirstArgTy = typename llvm::function_traits<
+      decltype(&OpTy::matchOperation)>::template arg_t<0>;
+  // Express "no op" in whichever form the concrete matcher's signature takes.
+  constexpr bool takesRawPointer = std::is_same_v<FirstArgTy, Operation *>;
+  if constexpr (takesRawPointer)
+    return op.matchOperation(nullptr, results, state);
+  else
+    return op.matchOperation(std::nullopt, results, state);
+}
+} // namespace detail
+
+template <typename OpTy>
+class AtMostOneOpMatcherOpTrait
+    : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
   template <typename T>
   using has_get_operand_handle =
       decltype(std::declval<T &>().getOperandHandle());
   template <typename T>
-  using has_match_operation = decltype(std::declval<T &>().matchOperation(
+  using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
       std::declval<Operation *>(), std::declval<TransformResults &>(),
       std::declval<TransformState &>()));
+  template <typename T>
+  using has_match_operation_optional =
+      decltype(std::declval<T &>().matchOperation(
+          std::declval<std::optional<Operation *>>(),
+          std::declval<TransformResults &>(),
+          std::declval<TransformState &>()));
 
 public:
   static LogicalResult verifyTrait(Operation *op) {
     static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
-                  "SingleOpMatcherOpTrait expects operation type to have the "
-                  "getOperandHandle() method");
-    static_assert(llvm::is_detected<has_match_operation, OpTy>::value,
-                  "SingleOpMatcherOpTrait expected operation type to have the "
-                  "matchOperation(Operation *, TransformResults &, "
-                  "TransformState &) method");
+                  "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
+                  "operation type to have the getOperandHandle() method");
+    static_assert(
+        llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
+            llvm::is_detected<has_match_operation_optional, OpTy>::value,
+        "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
+        "type to have either the matchOperation(Operation *, TransformResults "
+        "&, TransformState &) or the matchOperation(std::optional<Operation*>, "
+        "TransformResults &, TransformState &) method");
 
     // This must be a dynamic assert because interface registration is dynamic.
-    assert(isa<MatchOpInterface>(op) &&
-           "SingleOpMatchOpTrait is only available on operations with "
-           "MatchOpInterface");
+    assert(
+        isa<MatchOpInterface>(op) &&
+        "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
+        "operations with MatchOpInterface");
     Value operandHandle = cast<OpTy>(op).getOperandHandle();
     if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
-      return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
+      return op->emitError() << "AtMostOneOpMatcherOpTrait/"
+                                "SingleOpMatchOpTrait requires the op handle "
                                 "to be of TransformHandleTypeInterface";
     }
 
@@ -55,12 +87,15 @@ class SingleOpMatcherOpTrait
                                     TransformState &state) {
     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
     auto payload = state.getPayloadOps(operandHandle);
-    if (!llvm::hasSingleElement(payload)) {
+    if (!llvm::hasNItemsOrLess(payload, 1)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
-             << "SingleOpMatchOpTrait requires the operand handle to point to "
-                "a single payload op";
+             << "AtMostOneOpMatcherOpTrait requires the operand handle to "
+                "point to at most one payload op";
+    }
+    if (payload.empty()) {
+      return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
+                                            results, state);
     }
-
     return cast<OpTy>(this->getOperation())
         .matchOperation(*payload.begin(), results, state);
   }
@@ -72,12 +107,32 @@ class SingleOpMatcherOpTrait
   }
 };
 
+/// Variant of AtMostOneOpMatcherOpTrait that requires exactly one payload op.
+template <typename OpTy>
+class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {
+public:
+  DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
+                                    TransformResults &results,
+                                    TransformState &state) {
+    Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
+    auto payload = state.getPayloadOps(operandHandle);
+    if (!llvm::hasSingleElement(payload)) {
+      return emitDefiniteFailure(this->getOperation()->getLoc())
+             << "SingleOpMatchOpTrait requires the operand handle to point to "
+                "a single payload op";
+    }
+    // Exactly one op is a special case of "at most one": reuse the base.
+    return AtMostOneOpMatcherOpTrait<OpTy>::apply(rewriter, results, state);
+  }
+};
+
 template <typename OpTy>
 class SingleValueMatcherOpTrait
     : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
-    // This must be a dynamic assert because interface registration is dynamic.
+    // This must be a dynamic assert because interface registration is
+    // dynamic.
     assert(isa<MatchOpInterface>(op) &&
            "SingleValueMatchOpTrait is only available on operations with "
            "MatchOpInterface");
@@ -98,8 +153,8 @@ class SingleValueMatcherOpTrait
     auto payload = state.getPayloadValues(operandHandle);
     if (!llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
-             << "SingleValueMatchOpTrait requires the value handle to point to "
-                "a single payload value";
+             << "SingleValueMatchOpTrait requires the value handle to point "
+                "to a single payload value";
     }
 
     return cast<OpTy>(this->getOperation())
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
index 1f81fd5252eb45b..be92e4d91b42b32 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
@@ -14,11 +14,28 @@ def MatchOpInterface
   let cppNamespace = "::mlir::transform";
 }
 
+// Trait for "matcher" transform operations that apply to an operation handle
+// associated with at most one payload operation. Produces a definite failure
+// when the handle is associated with more than one. The matching logic is
+// implemented in `matchOperation` (which receives `std::nullopt` for an empty
+// handle) instead of `apply`. The op with this trait must provide a
+// `Value getOperandHandle()` function returning the handle used for matching.
+def AtMostOneOpMatcher : NativeOpTrait<"AtMostOneOpMatcherOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+
+  string extraDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure matchOperation(
+          ::std::optional<::mlir::Operation *> maybeCurrent,
+          ::mlir::transform::TransformResults &results,
+          ::mlir::transform::TransformState &state);
+  }];
+}
+
 // Trait for "matcher" transform operations that apply to an operation handle
 // associated with exactly one payload operation. Checks that it is indeed
 // the case and produces a definite failure when it is not. The matching logic
 // is implemented in the `matchOperation` function instead of `apply`. The op
-// with this trait must provide a `Value getOperandHandle()` function that 
+// with this trait must provide a `Value getOperandHandle()` function that
 // returns the handle to be used for matching.
 def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
   let cppNamespace = "::mlir::transform";
@@ -35,7 +52,7 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
 // associated with exactly one payload value. Checks that it is indeed
 // the case and produces a definite failure when it is not. The matching logic
 // is implemented in the `matchValue` function instead of `apply`. The op
-// with this trait must provide a `Value getOperandHandle()` function that 
+// with this trait must provide a `Value getOperandHandle()` function that
 // returns the handle to be used for matching.
 def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> {
   let cppNamespace = "::mlir::transform";
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 3448e27a41a6804..70a76ab9670f907 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -20,107 +20,109 @@ def Transform_Dialect : Dialect {
 
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
-      /// Name of the attribute attachable to the symbol table operation
-      /// containing named sequences. This is used to trigger verification.
-      constexpr const static ::llvm::StringLiteral
-          kWithNamedSequenceAttrName = "transform.with_named_sequence";
-
-      /// Name of the attribute attachable to an operation so it can be
-      /// identified as root by the default interpreter pass.
-      constexpr const static ::llvm::StringLiteral
-          kTargetTagAttrName = "transform.target_tag";
-
-      /// Name of the attribute attachable to an operation, indicating that
-      /// TrackingListener failures should be silenced.
-      constexpr const static ::llvm::StringLiteral
-          kSilenceTrackingFailuresAttrName = "transform.silence_tracking_failures";
-
-      /// Names of the attributes indicating whether an argument of an external
-      /// transform dialect symbol is consumed or only read.
-      constexpr const static ::llvm::StringLiteral
-          kArgConsumedAttrName = "transform.consumed";
-      constexpr const static ::llvm::StringLiteral
-          kArgReadOnlyAttrName = "transform.readonly";
-
-      template <typename DataTy>
-      const DataTy &getExtraData() const {
-        return *static_cast<const DataTy *>(extraData.at(::mlir::TypeID::get<DataTy>()).get());
-      }
-
-      /// Parses a type registered by this dialect or one of its extensions.
-      ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
-
-      /// Prints a type registered by this dialect or one of its extensions.
-      void printType(::mlir::Type type,
-                     ::mlir::DialectAsmPrinter &printer) const override;
-
-      /// Parser callback for an individual type registered by this dialect or
-      /// its extensions.
-      using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);
-
-      /// Printer callback for an individual type registered by this dialect or
-      /// its extensions.
-      using ExtensionTypePrintingHook =
-          std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;
-
-    private:
-      /// Registers operations specified as template parameters with this
-      /// dialect. Checks that they implement the required interfaces.
-      template <typename... OpTys>
-      void addOperationsChecked() {
-        (addOperationIfNotRegistered<OpTys>(), ...);
-      }
-      template <typename OpTy>
-      void addOperationIfNotRegistered();
-
-      /// Reports a repeated registration error of an op with the given name.
-      [[noreturn]] void reportDuplicateOpRegistration(StringRef opName);
-
-      /// Registers the types specified as template parameters with the
-      /// Transform dialect. Checks that they meet the requirements for
-      /// Transform IR types.
-      template <typename... TypeTys>
-      void addTypesChecked() {
-        (addTypeIfNotRegistered<TypeTys>(), ...);
-      }
-      template <typename Type>
-      void addTypeIfNotRegistered();
-
-      /// Reports a repeated registration error of a type with the given
-      /// mnemonic.
-      [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
-
-      /// Registers dialect types with the context.
-      void initializeTypes();
-
-      // Give extensions access to injection functions.
-      template <typename, typename...>
-      friend class TransformDialectExtension;
-
-      /// Gets a mutable reference to extra data of the kind specified as
-      /// template argument. Allocates the data on the first call.
-      template <typename DataTy>
-      DataTy &getOrCreateExtraData();
-
-      //===----------------------------------------------------------------===//
-      // Data fields
-      //===----------------------------------------------------------------===//
-
-      /// Additional data associated with and owned by the dialect. Accessible
-      /// to extensions.
-      ::llvm::DenseMap<::mlir::TypeID, std::unique_ptr<
-            ::mlir::transform::detail::TransformDialectDataBase>>
-          extraData;
-
-      /// A map from type mnemonic to its parsing function for the remainder of
-      /// the syntax. The parser has access to the mnemonic, so it is used for
-      /// further dispatch.
-      ::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;
-
-      /// A map from type TypeID to its printing function. No need to do string
-      /// lookups when the type is fully constructed.
-      ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
-      typePrintingHooks;
+    /// Name of the attribute attachable to the symbol table operation
+    /// containing named sequences. This is used to trigger verification.
+    constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
+        "transform.with_named_sequence";
+
+    /// Name of the attribute attachable to an operation so it can be
+    /// identified as root by the default interpreter pass.
+    constexpr const static ::llvm::StringLiteral kTargetTagAttrName =
+        "transform.target_tag";
+
+    /// Name of the attribute attachable to an operation, indicating that
+    /// TrackingListener failures should be silenced.
+    constexpr const static ::llvm::StringLiteral
+        kSilenceTrackingFailuresAttrName =
+            "transform.silence_tracking_failures";
+
+    /// Names of the attributes indicating whether an argument of an external
+    /// transform dialect symbol is consumed or only read.
+    constexpr const static ::llvm::StringLiteral kArgConsumedAttrName =
+        "transform.consumed";
+    constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName =
+        "transform.readonly";
+
+    template <typename DataTy>
+    const DataTy &getExtraData() const {
+      return *static_cast<const DataTy *>(
+          extraData.at(::mlir::TypeID::get<DataTy>()).get());
+    }
+
+    /// Parses a type registered by this dialect or one of its extensions.
+    ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
+
+    /// Prints a type registered by this dialect or one of its extensions.
+    void printType(::mlir::Type type,
+                   ::mlir::DialectAsmPrinter &printer) const override;
+
+    /// Parser callback for an individual type registered by this dialect or
+    /// its extensions.
+    using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);
+
+    /// Printer callback for an individual type registered by this dialect or
+    /// its extensions.
+    using ExtensionTypePrintingHook =
+        std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
+
+  private:
+    /// Registers operations specified as template parameters with this
+    /// dialect. Checks that they implement the required interfaces.
+    template <typename... OpTys>
+    void addOperationsChecked() {
+      (addOperationIfNotRegistered<OpTys>(), ...);
+    }
+    template <typename OpTy>
+    void addOperationIfNotRegistered();
+
+    /// Reports a repeated registration error of an op with the given name.
+    [[noreturn]] void reportDuplicateOpRegistration(StringRef opName);
+
+    /// Registers types specified as template parameters with the Transform
+    /// dialect. Checks that they meet the requirements for Transform IR types.
+    template <typename... TypeTys>
+    void addTypesChecked() {
+      (addTypeIfNotRegistered<TypeTys>(), ...);
+    }
+    template <typename Type>
+    void addTypeIfNotRegistered();
+
+    /// Reports a repeated registration error of a type with the given
+    /// mnemonic.
+    [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
+
+    /// Registers dialect types with the context.
+    void initializeTypes();
+
+    // Give extensions access to injection functions.
+    template <typename, typename...>
+    friend class TransformDialectExtension;
+
+    /// Gets a mutable reference to extra data of the kind specified as
+    /// template argument. Allocates the data on the first call.
+    template <typename DataTy>
+    DataTy &getOrCreateExtraData();
+
+    //===----------------------------------------------------------------===//
+    // Data fields
+    //===----------------------------------------------------------------===//
+
+    /// Additional data associated with and owned by the dialect. Accessible
+    /// to extensions.
+    ::llvm::DenseMap<
+        ::mlir::TypeID,
+        std::unique_ptr<::mlir::transform::detail::TransformDialectDataBase>>
+        extraData;
+
+    /// A map from type mnemonic to its parsing function for the remainder of
+    /// the syntax. The parser has access to the mnemonic, so it is used for
+    /// further dispatch.
+    ::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;
+
+    /// A map from type TypeID to its printing function. No need to do string
+    /// lookups when the type is fully constructed.
+    ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
+        typePrintingHooks;
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index ca5c915ef8c2caa..5bc92e8e954eae7 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -595,8 +595,9 @@ def GetDefiningOp : TransformDialectOp<"get_defining_op",
 
 def GetParentOp : TransformDialectOp<"get_parent_op",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
+     MatchOpInterface,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
-  let summary = "Gets handles to the closest isolated-from-above parents";
+  let summary = "Gets handles to the closest parent ops";
   let description = [{
     The handle defined by this Transform op corresponds to the parents of the
     targeted payload ops (in the same order).
@@ -605,6 +606,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
     that case for each target op, the closest parent op that fulfills all
     requirements, is returned.
     - `isolated_from_above`: the parent op must be isolated from above
+    - `allow_empty_results`: get_parent_op may return an empty handle and
+      still succeed. In that case, if any target op has no suitable parent,
+      the entire result handle is empty.
     - `op_name`: the parent op must have the specified name
 
     If `deduplicate` is set, the result handle does not contain any duplicate
@@ -614,12 +618,14 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
     is applied, e.g., "B" may itself be a parent of "A". This may have an impact
     on the further transformation applied to the handle produced here.
 
-    If any of the given Payload IR ops has no such suitable parent, the
-    transformation fails silently.
+    If any of the given Payload IR ops has no such suitable parent, then:
+      - if `allow_empty_results` is set, the result handle is empty
+      - otherwise, the transformation produces a silenceable failure.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        UnitAttr:$isolated_from_above,
+                       UnitAttr:$allow_empty_results,
                        OptionalAttr<StrAttr>:$op_name,
                        UnitAttr:$deduplicate);
   let results = (outs TransformHandleTypeInterface:$parent);
@@ -739,6 +745,21 @@ def IncludeOp : TransformDialectOp<"include",
   }];
 }
 
+def MatchOperationEmptyOp : TransformDialectOp<"match.operation_empty", [
+    AtMostOneOpMatcher,
+    MatchOpInterface,
+    MemoryEffectsOpInterface]> {
+  let summary = "Matches if the handle is not associated with any op";
+  let description = [{
+    Succeeds if the handle is associated with no payload op. A handle with one
+    op yields a silenceable failure; more than one op is a definite failure.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+  let assemblyFormat =
+      "$operand_handle attr-dict `:` type($operand_handle)";
+  let extraClassDeclaration = AtMostOneOpMatcher.extraDeclaration;
+}
+
 def MatchOperationNameOp : TransformDialectOp<"match.operation_name",
     [SingleOpMatcher,
      MatchOpInterface,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 44626260e2f9ef3..0e20b379cc2a3e7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1161,7 +1161,6 @@ void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   BlockArgument iterVar = getIterationVariable();
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
-
         return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
       })) {
     consumesHandle(getTarget(), effects);
@@ -1244,6 +1243,10 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
       parent = parent->getParentOp();
     }
     if (!parent) {
+      if (getAllowEmptyResults()) {
+        results.set(llvm::cast<OpResult>(getResult()), parents);
+        return DiagnosedSilenceableFailure::success();
+      }
       DiagnosedSilenceableFailure diag =
           emitSilenceableError()
           << "could not find a parent op that matches all requirements";
@@ -1545,6 +1548,21 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
       .checkAndReport();
 }
 
+//===----------------------------------------------------------------------===//
+// MatchOperationEmptyOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
+    ::std::optional<::mlir::Operation *> maybeCurrent,
+    transform::TransformResults &results, transform::TransformState &state) {
+  if (maybeCurrent.has_value()) { // A payload op is associated: not empty.
+    DBGS_MATCHER() << "MatchOperationEmptyOp failure\n";
+    return emitSilenceableError() << "operation is not empty";
+  }
+  DBGS_MATCHER() << "MatchOperationEmptyOp success\n";
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MatchOperationNameOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index daa179cb15408b4..3891c16b4115595 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2037,3 +2037,75 @@ transform.sequence failures(propagate) {
   // expected-remark @below{{0}}
   test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
 }
+
+
+// -----
+
+func.func @constant_under_loop(%lb: index, %ub: index, %step: index) {
+  scf.for %i = %lb to %ub step %step {
+    arith.constant 0 : index
+  }
+  return
+}
+
+module @named_inclusion attributes { transform.with_named_sequence } {
+// Match `arith.constant`s that are not nested under an `scf.for`. The only
+// `arith.constant` here is inside an `scf.for`, so no remark is expected.
+
+transform.named_sequence @print(%root: !transform.any_op {transform.readonly}) {
+  transform.test_print_remark_at_operand %root, "matched constant" : !transform.any_op
+  transform.yield
+}
+
+transform.named_sequence @match_constant_not_under_scf_for(%root: !transform.any_op {transform.readonly})
+  -> !transform.any_op {
+  transform.match.operation_name %root ["arith.constant"] : !transform.any_op
+  %for = transform.get_parent_op %root { op_name = "scf.for", allow_empty_results }
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.match.operation_empty %for : !transform.any_op
+  transform.yield %root : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  transform.foreach_match in %arg0
+      @match_constant_not_under_scf_for -> @print
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.yield
+}
+}
+
+// -----
+
+func.func @no_constant_under_loop(%lb: index, %ub: index, %step: index) {
+  // expected-remark @below {{no parent scf.for}}
+  arith.constant 0 : index
+  return
+}
+
+module @named_inclusion attributes { transform.with_named_sequence } {
+// Match `arith.constant`s that are not nested under an `scf.for`. The only
+// `arith.constant` here has no such parent, so a remark is expected on it.
+
+transform.named_sequence @print(%root: !transform.any_op {transform.readonly}) {
+  transform.test_print_remark_at_operand %root, "no parent scf.for" : !transform.any_op
+  transform.yield
+}
+
+transform.named_sequence @match_constant_not_under_scf_for(%root: !transform.any_op {transform.readonly})
+  -> !transform.any_op {
+  transform.match.operation_name %root ["arith.constant"] : !transform.any_op
+  %for = transform.get_parent_op %root { op_name = "scf.for", allow_empty_results }
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.match.operation_empty %for : !transform.any_op
+  transform.yield %root : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  transform.foreach_match in %arg0
+      @match_constant_not_under_scf_for -> @print
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.yield
+}
+}



More information about the Mlir-commits mailing list