[Mlir-commits] [mlir] [mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass (PR #68330)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Oct 5 10:35:09 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/68330

>From 2f01dd85ed3e411090af31010b32831a9022d52a Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Thu, 5 Oct 2023 15:01:10 +0000
Subject: [PATCH 1/2] [mlir][Transform] Add a transform.match.operation_empty
 op to allow specifying negative conditions

In the process, get_parent_op gains an attribute to allow it to return empty handles explicitly and still succeed.
---
 .../Dialect/Transform/IR/MatchInterfaces.h    | 93 +++++++++++++++----
 .../Dialect/Transform/IR/MatchInterfaces.td   | 21 ++++-
 .../mlir/Dialect/Transform/IR/TransformOps.td | 27 +++++-
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 20 +++-
 .../Dialect/Transform/test-interpreter.mlir   | 72 ++++++++++++++
 5 files changed, 209 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index c8888f294f6ca1d..c52d8f976d607d9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -11,14 +11,46 @@
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/STLExtras.h"
+#include <optional>
+#include <type_traits>
 
 namespace mlir {
 namespace transform {
 class MatchOpInterface;
 
+namespace detail {
 template <typename OpTy>
-class SingleOpMatcherOpTrait
-    : public OpTrait::TraitBase<OpTy, SingleOpMatcherOpTrait> {
+DiagnosedSilenceableFailure
+matchOptionalOperationImpl(OpTy op, TransformResults &results,
+                           TransformState &state, std::false_type) {
+  return op.matchOperation(std::nullopt, results, state);
+}
+
+template <typename OpTy>
+DiagnosedSilenceableFailure
+matchOptionalOperationImpl(OpTy op, TransformResults &results,
+                           TransformState &state, std::true_type) {
+  return op.matchOperation(nullptr, results, state);
+}
+
+/// Invokes `op.matchOperation` without a payload operation, dispatching on the
+/// type of its first parameter: `nullptr` for `Operation *`, else `nullopt`.
+template <typename OpTy>
+DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
+                                                   TransformResults &results,
+                                                   TransformState &state) {
+  using uses_operation_ptr_t = typename std::is_same<
+      typename llvm::function_traits<
+          decltype(&OpTy::matchOperation)>::template arg_t<0>,
+      Operation *>;
+  return matchOptionalOperationImpl(op, results, state, uses_operation_ptr_t{});
+}
+} // namespace detail
+
+template <typename OpTy>
+class AtMostOneOpMatcherOpTrait
+    : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
   template <typename T>
   using has_get_operand_handle =
       decltype(std::declval<T &>().getOperandHandle());
@@ -30,20 +62,22 @@ class SingleOpMatcherOpTrait
 public:
   static LogicalResult verifyTrait(Operation *op) {
     static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
-                  "SingleOpMatcherOpTrait expects operation type to have the "
-                  "getOperandHandle() method");
+                  "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
+                  "operation type to have the getOperandHandle() method");
     static_assert(llvm::is_detected<has_match_operation, OpTy>::value,
-                  "SingleOpMatcherOpTrait expected operation type to have the "
-                  "matchOperation(Operation *, TransformResults &, "
-                  "TransformState &) method");
+                  "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected "
+                  "operation type to have the matchOperation(Operation *, "
+                  "TransformResults &, TransformState &) method");
 
     // This must be a dynamic assert because interface registration is dynamic.
-    assert(isa<MatchOpInterface>(op) &&
-           "SingleOpMatchOpTrait is only available on operations with "
-           "MatchOpInterface");
+    assert(
+        isa<MatchOpInterface>(op) &&
+        "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
+        "operations with MatchOpInterface");
     Value operandHandle = cast<OpTy>(op).getOperandHandle();
     if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
-      return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
+      return op->emitError() << "AtMostOneOpMatcherOpTrait/"
+                                "SingleOpMatchOpTrait requires the op handle "
                                 "to be of TransformHandleTypeInterface";
     }
 
@@ -55,12 +89,15 @@ class SingleOpMatcherOpTrait
                                     TransformState &state) {
     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
     auto payload = state.getPayloadOps(operandHandle);
-    if (!llvm::hasSingleElement(payload)) {
+    if (!payload.empty() && !llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
-             << "SingleOpMatchOpTrait requires the operand handle to point to "
-                "a single payload op";
+             << "AtMostOneOpMatcherOpTrait requires the operand handle to "
+                "point to at most one payload op";
+    }
+    if (payload.empty()) {
+      return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
+                                            results, state);
     }
-
     return cast<OpTy>(this->getOperation())
         .matchOperation(*payload.begin(), results, state);
   }
@@ -72,12 +109,32 @@ class SingleOpMatcherOpTrait
   }
 };
 
+template <typename OpTy>
+class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {
+  // Strengthens AtMostOneOpMatcherOpTrait: exactly one payload op required.
+public:
+  DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
+                                    TransformResults &results,
+                                    TransformState &state) {
+    Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
+    auto payload = state.getPayloadOps(operandHandle);
+    if (!llvm::hasSingleElement(payload)) {
+      return emitDefiniteFailure(this->getOperation()->getLoc())
+             << "SingleOpMatchOpTrait requires the operand handle to point to "
+                "a single payload op";
+    }
+    // The base trait forwards the single payload op to `matchOperation`.
+    return AtMostOneOpMatcherOpTrait<OpTy>::apply(rewriter, results, state);
+  }
+};
+
 template <typename OpTy>
 class SingleValueMatcherOpTrait
     : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
-    // This must be a dynamic assert because interface registration is dynamic.
+    // This must be a dynamic assert because interface registration is
+    // dynamic.
     assert(isa<MatchOpInterface>(op) &&
            "SingleValueMatchOpTrait is only available on operations with "
            "MatchOpInterface");
@@ -98,8 +155,8 @@ class SingleValueMatcherOpTrait
     auto payload = state.getPayloadValues(operandHandle);
     if (!llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
-             << "SingleValueMatchOpTrait requires the value handle to point to "
-                "a single payload value";
+             << "SingleValueMatchOpTrait requires the value handle to point "
+                "to a single payload value";
     }
 
     return cast<OpTy>(this->getOperation())
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
index 1f81fd5252eb45b..be92e4d91b42b32 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
@@ -14,11 +14,28 @@ def MatchOpInterface
   let cppNamespace = "::mlir::transform";
 }
 
+// Trait for "matcher" transform operations that apply to an operation handle
+// associated with at most one payload operation. Produces a definite failure
+// when the handle is associated with more than one payload operation. The
+// matching logic is implemented in `matchOperation` instead of `apply`; it is
+// passed `std::nullopt` when the handle is empty. The op with this trait must
+// provide a `Value getOperandHandle()` function returning the matched handle.
+def AtMostOneOpMatcher : NativeOpTrait<"AtMostOneOpMatcherOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+
+  string extraDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure matchOperation(
+          ::std::optional<::mlir::Operation *> maybeCurrent,
+          ::mlir::transform::TransformResults &results,
+          ::mlir::transform::TransformState &state);
+  }];
+}
+
 // Trait for "matcher" transform operations that apply to an operation handle
 // associated with exactly one payload operation. Checks that it is indeed
 // the case and produces a definite failure when it is not. The matching logic
 // is implemented in the `matchOperation` function instead of `apply`. The op
-// with this trait must provide a `Value getOperandHandle()` function that 
+// with this trait must provide a `Value getOperandHandle()` function that
 // returns the handle to be used for matching.
 def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
   let cppNamespace = "::mlir::transform";
@@ -35,7 +52,7 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
 // associated with exactly one payload value. Checks that it is indeed
 // the case and produces a definite failure when it is not. The matching logic
 // is implemented in the `matchValue` function instead of `apply`. The op
-// with this trait must provide a `Value getOperandHandle()` function that 
+// with this trait must provide a `Value getOperandHandle()` function that
 // returns the handle to be used for matching.
 def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> {
   let cppNamespace = "::mlir::transform";
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index ca5c915ef8c2caa..2c6917236d34ddf 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -595,8 +595,9 @@ def GetDefiningOp : TransformDialectOp<"get_defining_op",
 
 def GetParentOp : TransformDialectOp<"get_parent_op",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
+     MatchOpInterface,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
-  let summary = "Gets handles to the closest isolated-from-above parents";
+  let summary = "Gets handles to the closest parent ops";
   let description = [{
     The handle defined by this Transform op corresponds to the parents of the
     targeted payload ops (in the same order).
@@ -605,6 +606,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
     that case for each target op, the closest parent op that fulfills all
     requirements, is returned.
     - `isolated_from_above`: the parent op must be isolated from above
+    - `allow_empty_results`: get_parent_op is allowed to return an empty list and
+      still succeed. In that case, if no suitable parent is found for some
+      target op, the entire transform returns an empty handle.
     - `op_name`: the parent op must have the specified name
 
     If `deduplicate` is set, the result handle does not contain any duplicate
@@ -614,12 +618,14 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
     is applied, e.g., "B" may itself be a parent of "A". This may have an impact
     on the further transformation applied to the handle produced here.
 
-    If any of the given Payload IR ops has no such suitable parent, the
-    transformation fails silently.
+    If any of the given Payload IR ops has no such suitable parent, then:
+      - if `allow_empty_results` is set, the result handle is empty
+      - otherwise, the transformation fails silently.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        UnitAttr:$isolated_from_above,
+                       UnitAttr:$allow_empty_results,
                        OptionalAttr<StrAttr>:$op_name,
                        UnitAttr:$deduplicate);
   let results = (outs TransformHandleTypeInterface:$parent);
@@ -739,6 +745,21 @@ def IncludeOp : TransformDialectOp<"include",
   }];
 }
 
+def MatchOperationEmptyOp : Op<Transform_Dialect, "match.operation_empty", [
+    AtMostOneOpMatcher,
+    MatchOpInterface,
+    MemoryEffectsOpInterface]> {
+  let summary = "Matches if the handle is not associated with any op";
+  let description = [{
+    Succeeds if the handle is not associated with any op. Produces a
+    silenceable failure if it is associated with exactly one op and a
+    definite failure if it is associated with more than one op.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+  let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)";
+  let extraClassDeclaration = AtMostOneOpMatcher.extraDeclaration;
+}
+
 def MatchOperationNameOp : TransformDialectOp<"match.operation_name",
     [SingleOpMatcher,
      MatchOpInterface,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 44626260e2f9ef3..0e20b379cc2a3e7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1161,7 +1161,6 @@ void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   BlockArgument iterVar = getIterationVariable();
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
-
         return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
       })) {
     consumesHandle(getTarget(), effects);
@@ -1244,6 +1243,10 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
       parent = parent->getParentOp();
     }
     if (!parent) {
+      if (getAllowEmptyResults()) {
+        results.set(llvm::cast<OpResult>(getResult()), parents);
+        return DiagnosedSilenceableFailure::success();
+      }
       DiagnosedSilenceableFailure diag =
           emitSilenceableError()
           << "could not find a parent op that matches all requirements";
@@ -1545,6 +1548,21 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
       .checkAndReport();
 }
 
+//===----------------------------------------------------------------------===//
+// MatchOperationEmptyOp
+//===----------------------------------------------------------------------===//
+/// Succeeds when no payload op is provided; otherwise fails silenceably.
+DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
+    ::std::optional<::mlir::Operation *> maybeCurrent,
+    transform::TransformResults &results, transform::TransformState &state) {
+  if (!maybeCurrent.has_value()) {
+    DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
+    return DiagnosedSilenceableFailure::success();
+  }
+  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
+  return emitSilenceableError() << "operation is not empty";
+}
+
 //===----------------------------------------------------------------------===//
 // MatchOperationNameOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index daa179cb15408b4..3891c16b4115595 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2037,3 +2037,75 @@ transform.sequence failures(propagate) {
   // expected-remark @below{{0}}
   test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
 }
+
+
+// -----
+
+func.func @constant_under_loop(%lb: index, %ub: index, %step: index) {
+  scf.for %i = %lb to %ub step %step {
+    arith.constant 0 : index
+  }
+  return
+}
+
+module @named_inclusion attributes { transform.with_named_sequence } {
+// Match `arith.constant`s that are not nested under an `scf.for`. Here the
+// only constant is inside the loop, so nothing is matched or printed.
+
+transform.named_sequence @print(%root: !transform.any_op {transform.readonly}) {
+  transform.test_print_remark_at_operand %root, "matched func" : !transform.any_op
+  transform.yield
+}
+
+transform.named_sequence @match_constant_not_under_scf_for(%root: !transform.any_op {transform.readonly})
+  -> !transform.any_op {
+  transform.match.operation_name %root ["arith.constant"] : !transform.any_op
+  %for = transform.get_parent_op %root { op_name = "scf.for", allow_empty_results }
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.match.operation_empty %for : !transform.any_op
+  transform.yield %root : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  transform.foreach_match in %arg0
+      @match_constant_not_under_scf_for -> @print
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.yield
+}
+}
+
+// -----
+
+func.func @no_constant_under_loop(%lb: index, %ub: index, %step: index) {
+  // expected-remark @below {{no parent scf.for}}
+  arith.constant 0 : index
+  return
+}
+
+module @named_inclusion attributes { transform.with_named_sequence } {
+// Match `arith.constant`s that are not nested under an `scf.for`. Here the
+// constant sits directly in the function body, so it is matched and printed.
+
+transform.named_sequence @print(%root: !transform.any_op {transform.readonly}) {
+  transform.test_print_remark_at_operand %root, "no parent scf.for" : !transform.any_op
+  transform.yield
+}
+
+transform.named_sequence @match_constant_not_under_scf_for(%root: !transform.any_op {transform.readonly})
+  -> !transform.any_op {
+  transform.match.operation_name %root ["arith.constant"] : !transform.any_op
+  %for = transform.get_parent_op %root { op_name = "scf.for", allow_empty_results }
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.match.operation_empty %for : !transform.any_op
+  transform.yield %root : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  transform.foreach_match in %arg0
+      @match_constant_not_under_scf_for -> @print
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.yield
+}
+}

>From 94553f9d75485cba5078cf0409b4439b07ac851b Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Thu, 5 Oct 2023 16:11:00 +0000
Subject: [PATCH 2/2] [mlir][Transform] Provide a minimal set of utils that
 allow implementing a simple transform dialect interpreter pass

---
 .../Dialect/Transform/IR/TransformDialect.td  |  26 ++-
 .../Transform/IR/TransformInterfaces.h        |   6 +-
 .../Transforms/TransformInterpreterUtils.h    |  84 ++++++++
 .../Transform/IR/TransformInterfaces.cpp      |  23 +-
 .../Transform/Transforms/CMakeLists.txt       |   1 +
 .../Transforms/TransformInterpreterUtils.cpp  | 199 ++++++++++++++++++
 .../Dialect/Transform/CMakeLists.txt          |   2 +
 mlir/unittests/Dialect/Transform/Preload.cpp  |  92 ++++++++
 8 files changed, 419 insertions(+), 14 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
 create mode 100644 mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
 create mode 100644 mlir/unittests/Dialect/Transform/Preload.cpp

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 3448e27a41a6804..37cd8c4bddbea8a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -20,6 +20,10 @@ def Transform_Dialect : Dialect {
 
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
+      /// Symbol name of the named sequence used as the default entry point.
+      constexpr const static ::llvm::StringLiteral
+          kTransformEntryPointSymbolName = "__transform_main";
+
       /// Name of the attribute attachable to the symbol table operation
       /// containing named sequences. This is used to trigger verification.
       constexpr const static ::llvm::StringLiteral
@@ -63,6 +67,21 @@ def Transform_Dialect : Dialect {
       using ExtensionTypePrintingHook =
           std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;
 
+      /// Takes ownership of `library` and appends it to the symbol libraries
+      /// available to all users of this dialect.
+      void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
+                                 library) {
+        libraryModules.push_back(std::move(library));
+      }
+
+      /// Returns a lazy range of non-owning handles to registered libraries.
+      auto getLibraryModules() const {
+        return ::llvm::map_range(libraryModules, [
+        ](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
+          return library.get();
+        });
+      }
+
     private:
       /// Registers operations specified as template parameters with this
       /// dialect. Checks that they implement the required interfaces.
@@ -120,7 +139,12 @@ def Transform_Dialect : Dialect {
       /// A map from type TypeID to its printing function. No need to do string
       /// lookups when the type is fully constructed.
       ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
-      typePrintingHooks;
+          typePrintingHooks;
+
+      /// Modules containing symbols, e.g. named sequences, that will be
+      /// resolved by the interpreter when used.
+      ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
+          libraryModules;
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 0e72a93e685e32f..14285c661d253f4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -111,7 +111,8 @@ class TransformOptions {
 LogicalResult
 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
                 const RaggedArray<MappedValue> &extraMapping = {},
-                const TransformOptions &options = TransformOptions());
+                const TransformOptions &options = TransformOptions(),
+                bool enforceToplevelTransformOp = true);
 
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
@@ -193,7 +194,8 @@ class TransformState {
 
   friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
                                        const RaggedArray<MappedValue> &,
-                                       const TransformOptions &);
+                                       const TransformOptions &,
+                                       bool enforceToplevelTransformOp);
 
   friend TransformState
   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
new file mode 100644
index 000000000000000..7b203b5366bf9c2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -0,0 +1,84 @@
+//===- TransformInterpreterUtils.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lightweight transform dialect interpreter utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class Operation;
+template <typename>
+class OwningOpRef;
+class Region;
+
+namespace transform {
+namespace detail {
+/// Utility to parse and verify the content of a `transformFileName` MLIR file
+/// containing a transform dialect specification.
+LogicalResult
+parseTransformInterpreterModule(MLIRContext *context,
+                                llvm::StringRef transformFileName,
+                                OwningOpRef<ModuleOp> &transformModule);
+
+/// Utility to load a transform interpreter `module` from a module that has
+/// already been preloaded in the context.
+/// This mode is useful in cases where explicit parsing of a transform library
+/// from file is expected to be prohibitively expensive.
+/// In such cases, the transform module is expected to be found in the preloaded
+/// library modules of the transform dialect.
+LogicalResult
+getPreloadedTransformInterpreterModule(MLIRContext *context,
+                                       OwningOpRef<ModuleOp> &module);
+
+/// Finds the first TransformOpInterface op named `entryPoint` (by default
+/// `kTransformEntryPointSymbolName`) that is either:
+///   1. nested under `root` (takes precedence).
+///   2. nested under `module`, if not found in `root`.
+/// Returns null if no such op is found (presumably reporting an error first).
+TransformOpInterface findTransformEntryPoint(
+    Operation *root, ModuleOp *module = nullptr,
+    StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+/// Replaces external symbols in `block` with their (non-external) definitions
+/// from `definitions`.
+LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions);
+} // namespace detail
+
+/// Standalone util to apply the named sequence `entryPoint` to the payload.
+/// This is done in 3 steps:
+///   1. look up the `entryPoint` symbol in `{payload, sharedTransformModule}`
+///   by calling detail::findTransformEntryPoint.
+///   2. if the entry point is found and not nested under
+///   `sharedTransformModule`, call `detail::defineDeclaredSymbols` to "link" in
+///   the `sharedTransformModule`. Note: this may modify the transform IR
+///   embedded with the payload IR.
+///   3. apply the transform IR to the payload IR, relaxing the requirement that
+///   the transform IR is a top-level transform op. We are applying a named
+///   sequence anyway.
+LogicalResult applyTransformNamedSequence(
+    Operation *payload,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    const TransformOptions &options,
+    StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4a9bb2dba7d660c..9f1e68698f7844c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -2079,18 +2079,19 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 // Entry point.
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-transform::applyTransforms(Operation *payloadRoot,
-                           TransformOpInterface transform,
-                           const RaggedArray<MappedValue> &extraMapping,
-                           const TransformOptions &options) {
+LogicalResult transform::applyTransforms(
+    Operation *payloadRoot, TransformOpInterface transform,
+    const RaggedArray<MappedValue> &extraMapping,
+    const TransformOptions &options, bool enforceToplevelTransformOp) {
 #ifndef NDEBUG
-  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
-      transform->getNumOperands() != 0) {
-    transform->emitError()
-        << "expected transform to start at the top-level transform op";
-    llvm::report_fatal_error("could not run transforms",
-                             /*gen_crash_diag=*/false);
+  if (enforceToplevelTransformOp) {
+    if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
+        transform->getNumOperands() != 0) {
+      transform->emitError()
+          << "expected transform to start at the top-level transform op";
+      llvm::report_fatal_error("could not run transforms",
+                               /*gen_crash_diag=*/false);
+    }
   }
 #endif // NDEBUG
 
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 3f51ef1088f7af6..8774a8b86fb0d91 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
   CheckUses.cpp
   InferEffects.cpp
   TransformInterpreterPassBase.cpp
+  TransformInterpreterUtils.cpp
 
   DEPENDS
   MLIRTransformDialectTransformsIncGen
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
new file mode 100644
index 000000000000000..c2cfaa7884961f1
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -0,0 +1,199 @@
+//===- TransformInterpreterUtils.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lightweight transform dialect interpreter utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "transform-dialect-interpreter"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+/// Parses the file named `transformFileName` into `transformModule`.
+///
+/// An empty `transformFileName` is not an error: the transform IR is then
+/// assumed to be embedded in the payload IR, `transformModule` is left
+/// untouched and success is returned. Failure is returned when the file
+/// cannot be opened (an error is emitted), cannot be parsed, or when the
+/// parsed module fails verification.
+LogicalResult transform::detail::parseTransformInterpreterModule(
+    MLIRContext *context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  // A failed parse yields a null module; verifying it would dereference null.
+  if (!transformModule)
+    return failure();
+  return mlir::verify(*transformModule);
+}
+
+/// Retrieves the transform library preloaded into the transform dialect of
+/// `context`.
+///
+/// On success, `transformModule` owns a clone of the first registered library
+/// module. Returns failure and leaves `transformModule` unchanged when no
+/// library module has been registered.
+LogicalResult transform::detail::getPreloadedTransformInterpreterModule(
+    MLIRContext *context, OwningOpRef<ModuleOp> &transformModule) {
+  auto *transformDialect =
+      context->getOrLoadDialect<transform::TransformDialect>();
+  auto libraries = transformDialect->getLibraryModules();
+  if (libraries.empty())
+    return failure();
+
+  // Clone so that the caller gets its own copy, independent of the module held
+  // by the dialect.
+  transformModule = OwningOpRef<ModuleOp>((*libraries.begin()).clone());
+  return success();
+}
+
+/// Looks for a `transform.named_sequence` whose symbol name is `entryPoint`.
+///
+/// `root` is searched first, then `*module` when `module` points to a non-null
+/// module. Each search walks the IR in pre-order and stops at the first match.
+/// If no matching sequence is found, an error is emitted on `root` and a null
+/// interface is returned.
+transform::TransformOpInterface
+transform::detail::findTransformEntryPoint(Operation *root, ModuleOp *module,
+                                           StringRef entryPoint) {
+  SmallVector<Operation *, 2> searchRoots{root};
+  // `module` may point to a null ModuleOp (e.g. when no transform library is
+  // available); walking its null Operation* would crash.
+  if (module && *module)
+    searchRoots.push_back(module->getOperation());
+  for (Operation *op : searchRoots) {
+    transform::TransformOpInterface transform = nullptr;
+    op->walk<WalkOrder::PreOrder>(
+        [&](transform::NamedSequenceOp namedSequenceOp) {
+          if (namedSequenceOp.getSymName() == entryPoint) {
+            transform = cast<transform::TransformOpInterface>(
+                namedSequenceOp.getOperation());
+            return WalkResult::interrupt();
+          }
+          return WalkResult::advance();
+        });
+    if (transform)
+      return transform;
+  }
+  root->emitError() << "could not find a nested named sequence with name: "
+                    << entryPoint;
+  return nullptr;
+}
+
+/// Replaces symbol declarations in `block` with clones of their definitions
+/// found in `definitions`.
+///
+/// An op in `block` is treated as a declaration when it implements
+/// SymbolOpInterface and does not have exactly one non-empty region. A
+/// declaration is left untouched when `definitions` has no same-named symbol
+/// with a single non-empty region, or when either side does not implement
+/// FunctionOpInterface. Otherwise the function types must match. Argument
+/// consumed/readonly annotations present only on the declaration are copied
+/// onto the definition in `definitions` (so `definitions` may be modified).
+/// Annotations present on the definition must also be present on the
+/// declaration. The definition is then cloned in place of the declaration,
+/// which is erased. Returns failure, with an error emitted on the
+/// declaration, on a signature or annotation mismatch.
+LogicalResult transform::detail::defineDeclaredSymbols(Block &block,
+                                                       ModuleOp definitions) {
+  MLIRContext &ctx = *definitions->getContext();
+  auto consumedName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
+
+  // The loop never adds or removes symbols of `definitions` (it only adds
+  // argument attributes there), so build its symbol table once rather than
+  // rescanning `definitions` for every declaration.
+  SymbolTable symbolTable(definitions);
+
+  // Declarations are erased while iterating, hence the early-inc range.
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    LLVM_DEBUG(DBGS() << op << "\n");
+    auto symbol = dyn_cast<SymbolOpInterface>(op);
+    if (!symbol)
+      continue;
+    // A symbol with a single non-empty region already has a body.
+    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
+      continue;
+
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
+                      << symbol.getNameAttr() << ":");
+    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
+        externalSymbol->getRegion(0).empty()) {
+      LLVM_DEBUG(llvm::dbgs() << "not found\n");
+      continue;
+    }
+
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
+    if (!symbolFunc || !externalSymbolFunc) {
+      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+      continue;
+    }
+
+    // Print the symbol name, not the address of the definition op.
+    LLVM_DEBUG(llvm::dbgs() << "found @" << symbol.getNameAttr().getValue()
+                            << "\n");
+    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
+      return symbolFunc.emitError()
+             << "external definition has a mismatching signature ("
+             << externalSymbolFunc.getFunctionType() << ")";
+    }
+
+    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
+      bool isExternalConsumed =
+          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isExternalReadonly =
+          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      // Unannotated definition argument: adopt the declaration's annotation.
+      if (!isExternalConsumed && !isExternalReadonly) {
+        if (isConsumed)
+          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
+        else if (isReadonly)
+          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+        continue;
+      }
+
+      // Annotated definition argument: the declaration must carry it too.
+      if ((isExternalConsumed && !isConsumed) ||
+          (isExternalReadonly && !isReadonly)) {
+        return symbolFunc.emitError()
+               << "external definition has mismatching consumption annotations "
+                  "for argument #"
+               << i;
+      }
+    }
+
+    // Clone the definition right before the declaration, then drop the
+    // declaration. OpBuilder(Operation *) already inserts before that op.
+    OpBuilder builder(&op);
+    builder.clone(*externalSymbol);
+    symbol->erase();
+  }
+
+  return success();
+}
+
+/// Applies the transform named sequence `entryPoint` to `payload`.
+///
+/// The entry point is looked up in `payload` first and then in the module
+/// held by `sharedTransformModule`, if any. When the entry point does not live
+/// inside that module, symbol declarations in the entry point's block are
+/// resolved against it with defineDeclaredSymbols. Top-level transform op
+/// constraints are not enforced.
+LogicalResult transform::applyTransformNamedSequence(
+    Operation *payload,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    const TransformOptions &options, StringRef entryPoint) {
+  // Null when no shared module is given or when it holds no module.
+  ModuleOp transformModule =
+      (sharedTransformModule && sharedTransformModule->get())
+          ? sharedTransformModule->get()
+          : nullptr;
+  Operation *transformRoot =
+      detail::findTransformEntryPoint(payload, &transformModule, entryPoint);
+  if (!transformRoot)
+    return failure();
+
+  // Test the module itself, not only the shared_ptr: an empty OwningOpRef
+  // yields a null ModuleOp that must not be dereferenced.
+  // NOTE(review): defineDeclaredSymbols may add argument annotations to the
+  // definitions in `transformModule`, i.e. the shared module can be modified.
+  if (transformModule && !transformModule->isAncestor(transformRoot)) {
+    if (failed(detail::defineDeclaredSymbols(*transformRoot->getBlock(),
+                                             transformModule)))
+      return failure();
+  }
+
+  // Apply the transform to the IR, do not enforce top-level constraints.
+  RaggedArray<MappedValue> noExtraMappings;
+  return applyTransforms(payload, cast<TransformOpInterface>(transformRoot),
+                         noExtraMappings, options,
+                         /*enforceToplevelTransformOp=*/false);
+}
diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
index 1fecd21221c91c8..cb8978cc29b6bd3 100644
--- a/mlir/unittests/Dialect/Transform/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -1,8 +1,11 @@
 add_mlir_unittest(MLIRTransformDialectTests
   BuildOnlyExtensionTest.cpp
+  Preload.cpp
 )
 target_link_libraries(MLIRTransformDialectTests
   PRIVATE
   MLIRFuncDialect
+  MLIRTestTransformDialect
   MLIRTransformDialect
+  MLIRTransformDialectTransforms
 )
diff --git a/mlir/unittests/Dialect/Transform/Preload.cpp b/mlir/unittests/Dialect/Transform/Preload.cpp
new file mode 100644
index 000000000000000..f1f30a7c7527180
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/Preload.cpp
@@ -0,0 +1,105 @@
+//===- Preload.cpp - Test transform library preloading --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+// Forward declarations of helpers provided by the MLIR test libraries.
+namespace mlir {
+namespace test {
+std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
+} // namespace test
+} // namespace mlir
+namespace test {
+void registerTestTransformDialectExtension(DialectRegistry &registry);
+} // namespace test
+
+// Transform library: a public definition of the `__transform_main` entry
+// point applying the test-only `transform.test_print_remark_at_operand` op
+// to its argument.
+const static llvm::StringLiteral library = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence public @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
+    transform.yield
+  }
+})MLIR";
+
+// Input module: declares `__transform_main` without a body, so its definition
+// must come from the preloaded library, and calls it via `include`.
+const static llvm::StringLiteral input = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+})MLIR";
+
+TEST(Preload, ContextPreloadConstructedLibrary) {
+  registerPassManagerCLOptions();
+
+  MLIRContext context;
+  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
+  // The library uses an op from the test transform dialect extension, which
+  // must be registered before the library can be parsed.
+  DialectRegistry registry;
+  ::test::registerTestTransformDialectExtension(registry);
+  registry.applyExtensions(&context);
+  ParserConfig parserConfig(&context);
+
+  // ASSERT_* for preconditions: later steps dereference earlier results, so
+  // continuing after a failure would crash instead of failing cleanly.
+  OwningOpRef<ModuleOp> inputModule =
+      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
+  ASSERT_TRUE(inputModule) << "failed to parse input module";
+
+  OwningOpRef<ModuleOp> transformLibrary =
+      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
+  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
+  dialect->registerLibraryModule(std::move(transformLibrary));
+
+  OwningOpRef<ModuleOp> retrievedTransformLibrary;
+  auto res = transform::detail::getPreloadedTransformInterpreterModule(
+      &context, retrievedTransformLibrary);
+  ASSERT_TRUE(succeeded(res)) << "failed to retrieve transform module";
+  ASSERT_TRUE(retrievedTransformLibrary)
+      << "failed to retrieve transform module";
+
+  ModuleOp transformModule = retrievedTransformLibrary.get();
+  transform::TransformOpInterface entryPoint =
+      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
+                                                 &transformModule);
+  ASSERT_TRUE(entryPoint) << "failed to find entry point";
+
+  res = transform::detail::defineDeclaredSymbols(
+      inputModule->getBodyRegion().front(), transformModule);
+  ASSERT_TRUE(succeeded(res)) << "failed to define declared symbols";
+
+  transform::TransformOptions options;
+  auto sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
+      std::move(retrievedTransformLibrary));
+  res = transform::applyTransformNamedSequence(inputModule->getOperation(),
+                                               sharedTransformModule, options);
+  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
+}



More information about the Mlir-commits mailing list