[Mlir-commits] [mlir] d064c48 - [mlir] Introduce Transform dialect

Alex Zinenko llvmlistbot at llvm.org
Thu Apr 14 04:48:53 PDT 2022


Author: Alex Zinenko
Date: 2022-04-14T13:48:45+02:00
New Revision: d064c4801c927ad5a706b66eaee4814fcc5ea45e

URL: https://github.com/llvm/llvm-project/commit/d064c4801c927ad5a706b66eaee4814fcc5ea45e
DIFF: https://github.com/llvm/llvm-project/commit/d064c4801c927ad5a706b66eaee4814fcc5ea45e.diff

LOG: [mlir] Introduce Transform dialect

This dialect provides operations that can be used to control transformation of
the IR using a different portion of the IR. It refers to the IR being
transformed as payload IR, and to the IR guiding the transformation as
transform IR.

The main use case for this dialect is orchestrating fine-grain transformations
on individual operations or sets thereof. For example, it may involve finding
loop-like operations with specific properties (e.g., large size) in the payload
IR, applying loop tiling to those and only those operations, and then applying
loop unrolling to the inner loops produced by the previous transformations. As
such, it is not intended as a replacement for the pass infrastructure, nor for
the pattern rewriting infrastructure. In the most common case, the transform IR
will be processed and applied to payload IR by a pass. Transformations
expressed by the transform dialect may be implemented using the pattern
infrastructure or any other relevant MLIR component.

This dialect is designed to be extensible, that is, clients of this dialect are
allowed to inject additional operations into this dialect using the
`TransformDialectExtension` mechanism introduced in this patch. This allows the
dialect to avoid a dependency on the implementation of the transformation as
well as to avoid introducing dialect-specific transform dialects.

See https://discourse.llvm.org/t/rfc-interfaces-and-dialects-for-precise-ir-transformation-control/60927.

Reviewed By: nicolasvasilache, Mogball, rriddle

Differential Revision: https://reviews.llvm.org/D123135

Added: 
    mlir/include/mlir/Dialect/Transform/CMakeLists.txt
    mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/lib/Dialect/Transform/CMakeLists.txt
    mlir/lib/Dialect/Transform/IR/CMakeLists.txt
    mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/test-dialect-injection.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/CMakeLists.txt
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
    mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
    mlir/test/lib/Dialect/Transform/lit.local.cfg

Modified: 
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/CMakeLists.txt
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/test/mlir-opt/commandline.mlir
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index a0b5209838bc4..2db29357000ce 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -27,5 +27,6 @@ add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
 add_subdirectory(Tensor)
 add_subdirectory(Tosa)
+add_subdirectory(Transform)
 add_subdirectory(Vector)
 add_subdirectory(X86Vector)

diff  --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..493e30d4a265d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
@@ -0,0 +1,8 @@
+# The dialect does not have its own ops, so just generate the dialect files.
+set(LLVM_TARGET_DEFINITIONS TransformDialect.td)
+mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform)
+mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform)
+add_public_tablegen_target(MLIRTransformDialectIncGen)
+add_dependencies(mlir-headers MLIRTransformDialectIncGen)
+
+add_mlir_interface(TransformInterfaces)

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
new file mode 100644
index 0000000000000..3b5f52beb670e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -0,0 +1,99 @@
+//===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LLVM.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
+
+namespace mlir {
+namespace transform {
+
+#ifndef NDEBUG
+namespace detail {
+/// Asserts that the operation provided as template argument implements the
+/// TransformOpInterface. This must be a dynamic assertion since interface
+/// implementations may be registered at runtime.
+template <typename OpTy>
+static inline void checkImplementsTransformInterface(MLIRContext *context) {
+  // Since the operation is being inserted into the Transform dialect and the
+  // dialect does not implement the interface fallback, only check for the op
+  // itself having the interface implementation.
+  RegisteredOperationName opName =
+      *RegisteredOperationName::lookup(OpTy::getOperationName(), context);
+  assert(opName.hasInterface<TransformOpInterface>() &&
+         "ops injected into the transform dialect must implement "
+         "TransformOpInterface");
+}
+} // namespace detail
+#endif // NDEBUG
+
+/// Base class for Transform dialect extensions, which inject operations into
+/// the Transform dialect at load time. Concrete extensions are expected to
+/// derive from this class and register operations in their constructor.
+/// They can be registered with the DialectRegistry and automatically applied
+/// to the Transform dialect when it is loaded.
+template <typename DerivedTy, typename... ExtraDialects>
+class TransformDialectExtension
+    : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
+  using Initializer = std::function<void(TransformDialect *)>;
+  using DialectLoader = std::function<void(MLIRContext *)>;
+
+public:
+  /// Extension application hook. Loads the dependent dialects first, then
+  /// registers the additional operations. Not expected to be called directly.
+  void apply(MLIRContext *context, TransformDialect *transformDialect,
+             ExtraDialects *...) const final {
+    for (const DialectLoader &loader : dialectLoaders)
+      loader(context);
+    for (const Initializer &init : opInitializers)
+      init(transformDialect);
+  }
+
+protected:
+  /// Registers the operations for injection when the extension is applied.
+  /// The ops must implement TransformOpInterface, and the implementation must
+  /// already be available at that point (checked only in debug builds).
+  template <typename... OpTys>
+  void registerTransformOps() {
+    opInitializers.push_back([](TransformDialect *transformDialect) {
+      transformDialect->addOperations<OpTys...>();
+
+#ifndef NDEBUG
+      std::initializer_list<int>{
+          (detail::checkImplementsTransformInterface<OpTys>(
+               transformDialect->getContext()),
+           0)...};
+#endif // NDEBUG
+    });
+  }
+
+  /// Declares that this Transform dialect extension depends on the dialect
+  /// provided as template parameter. When the Transform dialect is loaded,
+  /// dependent dialects will be loaded as well. This is intended for dialects
+  /// that contain attributes and types used in creation and canonicalization of
+  /// the injected operations.
+  template <typename DialectTy>
+  void declareDependentDialect() {
+    dialectLoaders.push_back(
+        [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
+  }
+
+private:
+  SmallVector<Initializer> opInitializers;
+  SmallVector<DialectLoader> dialectLoaders;
+};
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
new file mode 100644
index 0000000000000..db14953ec1510
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -0,0 +1,175 @@
+//===- TransformDialect.td - Transform dialect definition --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
+
+include "mlir/IR/OpBase.td"
+
+def Transform_Dialect : Dialect {
+  let summary = "Fine-grain transformation control dialect";
+  let description = [{
+    ## Disclaimer
+
+    ** Proceed with care: not ready for general use. **
+
+    This dialect is evolving rapidly and may change on very short notice. To
+    decrease the maintenance burden and churn, only a few in-tree use cases are
+    currently supported in the main tree:
+
+      - high-level transformations on "structured ops" (i.e. ops that operate on
+        chunks of data in a way that can be decomposed into operations on
+        smaller chunks of data and control flow) in Linalg, Tensor and Vector
+        dialects.
+
+    *Please post a description of the intended use case on the MLIR forum and
+    wait for confirmation.*
+
+    ## Overview
+
+    This dialect provides operations that can be used to control transformation
+    of the IR using a different portion of the IR. It refers to the IR being
+    transformed as payload IR, and to the IR guiding the transformation as
+    transform IR.
+
+    The main use case for this dialect is orchestrating fine-grain
+    transformations on individual operations or sets thereof. For example, it
+    may involve finding loop-like operations with specific properties (e.g.,
+    large size) in the payload IR, applying loop tiling to those and only those
+    operations, and then applying loop unrolling to the inner loops produced
+    by the previous transformations. As such, it is not intended as a
+    replacement for the pass infrastructure, nor for the pattern rewriting
+    infrastructure. In the most common case, the transform IR will be processed
+    and applied to the payload IR by a pass. Transformations expressed by the
+    transform dialect may be implemented using the pattern infrastructure or any
+    other relevant MLIR component.
+
+    The following IR gives a rough idea of what the operations in this dialect
+    may look like:
+
+    ```mlir
+    %0 = transform.loop.find { size > 42 }
+    %1:2 = transform.loop.tile %0 { tile_sizes = [2,3,4] }
+    transform.loop.unroll %1#1
+    ```
+
+    The values defined by operations in this dialect correspond to (groups of)
+    operations in the payload IR. In the example above, `%0` corresponds to the
+    set of loops found in the payload IR that satisfy the condition, and `%1#0`
+    and `%1#1` correspond to groups of outer and inner loops, respectively,
+    produced by the tiling transformation.
+
+    This dialect is designed to be extensible, that is, clients of this dialect
+    are allowed to inject additional operations into this dialect using the
+    `TransformDialectExtension` mechanism. This allows the dialect to avoid a
+    dependency on the implementation of the transformation as well as to avoid
+    introducing dialect-specific transform dialects. In the example above,
+    the operations may have been injected by a notional `loop` dialect rather
+    than defined in this dialect, hence the common prefix.
+
+    It is recommended to prefix injected operations with one or several
+    dot-separated words that indicate which extension adds them. For
+    dialect-specific transformations, the prefix is naturally the name of the
+    dialect, e.g., `transform.affine.reschedule`. For dialect-agnostic
+    transformations (typically implemented using interfaces), the prefix may
+    be derived from the interface name or from a common concept, e.g.,
+    `transform.loop.tile` may apply to any loop-like operation that implements
+    `TileableOpInterface`. The C++ classes for the dialect extension should
+    include the prefix in their name, e.g., `AffineTransformDialectExtension` or
+    `LoopTransformDialectExtension` in the cases above. Unprefixed operation
+    names are reserved for ops defined directly in the Transform dialect.
+
+    ## Intended Use and Integrations
+
+    The transformation control infrastructure provided by this dialect is
+    positioned roughly between rewrite patterns and passes. A transformation
+    that is executed by a transform operation is likely to be sufficiently
+    complex to require at least a set of patterns to be implemented. It is also
+    expected to be more focused than a pass: a pass typically applies identical
+    transformations everywhere in the IR, whereas a transform dialect-controlled
+    transformation would apply to a small subset of operations selected, e.g.,
+    by a pattern-matching operation or generated by a previous transformation.
+    It is discouraged, although technically possible, to run a pass pipeline as
+    part of the transform op implementation.
+
+    One of the main scenarios for using this dialect is fine-grain chaining of
+    transformations. For example, a loop-like operation may see its iteration
+    domain split into two parts, implemented as separate loops (transformation
+    known as index-set splitting), each of which is then transformed differently
+    (e.g., the first loop is tiled and the second unrolled) with the necessary
+    enabling and cleanup patterns around the main transformation:
+
+    ```mlir
+    // <generate %loop, e.g., by pattern-matching>
+    // ...
+    %parts:2 = transform.loop.split %loop { upper_bound_divisible_by = 8 }
+    transform.loop.tile %parts#0 { tile_sizes = [8] }
+    transform.loop.unroll %parts#1 { full }
+    ```
+
+    This composition would have been difficult to implement as separate passes
+    since the hypothetical "tiling" and "unrolling" passes would need to somehow
+    differentiate between the parts of the loop produced by the previous pass
+    (both are the same operation, and it is likely undesirable to pollute the
+    operation with pass-specific information). Implementing passes that run the
+    combined transformation would have run into the combinatorial explosion
+    issue due to multiple possible transform compositions or into the need for
+    deep pass parameterization, the ultimate form of which is an ad-hoc dialect
+    to specify which transformations the pass should run. The transform dialect
+    provides a uniform, extensible mechanism for controlling transformations in
+    such cases.
+
+    The transform dialect is supposed to be consumed by an "interpreter" pass
+    that drives the application of transformations. To ensure extensibility and
+    composability, this pass is not expected to actually perform the
+    transformations specified by the ops. Instead, the transformations are
+    implemented by the transform ops themselves via `TransformOpInterface`. The
+    pass serves as the entry point, handles the flow of transform operations and
+    takes care of bookkeeping. As such, the transform dialect does not provide
+    the interpreter pass. Instead, it provides a set of utilities that can be
+    used by clients to define their own interpreter passes or as part of a more
+    complex pass. For example, the mapping between values in the transform IR
+    and operations in the payload IR, or the function that applies the
+    transformations specified by ops in the given block sequentially. Note that
+    a transform op may have regions with further transform ops in them, with
+    the op itself guiding how to dispatch the transformation control flow to
+    those regions. This approach allows clients to decide on the relative
+    location of the transform IR in their input (e.g., nested modules, separate
+    modules, optional regions to certain operations, etc.), register additional
+    transform operations and perform client-specific bookkeeping.
+
+    ## Effects on the Infrastructure
+
+    Although scoped to a single dialect, this functionality conceptually belongs
+    to the MLIR infrastructure. It aims to be minimally intrusive and opt-in.
+
+    Some infrastructural components may grow extra functionality to support the
+    transform dialect. In particular, the pattern infrastructure may add extra
+    hooks to identify the "main results" of a transformation or to notify
+    external observers about changes made to certain operations. These are not
+    expected to affect the existing uses of the infrastructure.
+
+    For the sake of reusability, transformations should be implemented as
+    utility functions that are called from the interface methods of transform
+    ops rather than having the methods directly act on the payload IR.
+  }];
+
+  let name = "transform";
+  let cppNamespace = "::mlir::transform";
+
+  let extraClassDeclaration = [{
+    // Make addOperations available to the TransformDialectExtension class.
+    private:
+      using ::mlir::Dialect::addOperations;
+
+      template <typename, typename...>
+      friend class TransformDialectExtension;
+  }];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
new file mode 100644
index 0000000000000..2caa20e6cdc66
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -0,0 +1,131 @@
+//===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace transform {
+
+class TransformOpInterface;
+
+/// The state maintained across applications of various ops implementing the
+/// TransformOpInterface. The operations implementing this interface and the
+/// surrounding structure are referred to as transform IR. The operations to
+/// which transformations apply are referred to as payload IR. The state thus
+/// contains the mapping between values defined in the transform IR ops and
+/// payload IR ops. It assumes that each value in the transform IR can be used
+/// at most once (since transformations are likely to change the payload IR ops
+/// the value corresponds to). Checks that transform IR values correspond to
+/// disjoint sets of payload IR ops throughout the transformation.
+///
+/// A reference to this class is passed as an argument to "apply" methods of the
+/// transform op interface. Thus the "apply" method can call
+/// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
+/// associated with its operand and subject to transformation. The method is
+/// expected to populate the `TransformResults` class instance in order to
+/// update the mapping. The `applyTransform` method takes care of propagating
+/// the state of `TransformResults` into the instance of this class.
+class TransformState {
+  /// Mapping between a Value in the transform IR and the corresponding set of
+  /// operations in the payload IR.
+  using TransformOpMapping = DenseMap<Value, SmallVector<Operation *>>;
+
+  /// Mapping between a payload IR operation and the transform IR value it is
+  /// currently associated with.
+  using TransformOpReverseMapping = DenseMap<Operation *, Value>;
+
+public:
+  /// Creates a state for the transformation rooted at the given op.
+  explicit TransformState(Operation *root);
+
+  /// Returns the op at which the transformation state is rooted. This is
+  /// typically helpful for transformations that apply globally.
+  Operation *getTopLevel() const;
+
+  /// Returns the list of ops that the given transform IR value corresponds to.
+  /// This is helpful for transformations that apply to a particular handle.
+  ArrayRef<Operation *> getPayloadOps(Value value) const;
+
+  /// Applies the transformation specified by the given transform op and updates
+  /// the state accordingly.
+  LogicalResult applyTransform(TransformOpInterface transform);
+
+private:
+  /// Key under which the root op is stored in `operationMapping`.
+  static constexpr Value kTopLevelValue = Value();
+
+  /// Sets the payload IR ops associated with the given transform IR value.
+  /// Fails if this would result in multiple transform IR values with uses
+  /// corresponding to the same payload IR ops. For example, a hypothetical
+  /// "find function by name" transform op would (indirectly) call this
+  /// function for its result. Having two such calls in a row for different
+  /// values, e.g. coming from different ops:
+  ///
+  ///   %0 = transform.find_func_by_name { name = "myfunc" }
+  ///   %1 = transform.find_func_by_name { name = "myfunc" }
+  ///
+  /// would lead to both values pointing to the same operation. The second call
+  /// to setPayloadOps will fail, unless the association with the %0 value is
+  /// removed first by calling update/removePayloadOps.
+  LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
+
+  /// Forgets the payload IR ops associated with the given transform IR value.
+  void removePayloadOps(Value value);
+
+  /// Updates the payload IR ops associated with the given transform IR value.
+  /// The callback function is called once per associated operation and is
+  /// expected to return the modified operation or nullptr. In the latter case,
+  /// the corresponding operation is no longer associated with the transform IR
+  /// value.
+  void updatePayloadOps(Value value,
+                        function_ref<Operation *(Operation *)> callback);
+
+  /// The mappings from transform IR values to payload IR ops, and back.
+  TransformOpMapping operationMapping;
+  TransformOpReverseMapping reverseMapping;
+};
+
+/// Local mapping between values defined by a specific op implementing the
+/// TransformOpInterface and the payload IR ops they correspond to.
+class TransformResults {
+  friend class TransformState;
+
+public:
+  /// Indicates that the given result of the transform IR op corresponds to the
+  /// given list of payload IR ops. Each result must be set by the
+  /// transformation exactly once.
+  void set(OpResult value, ArrayRef<Operation *> ops);
+
+private:
+  /// Creates an instance of TransformResults that expects mappings for
+  /// `numSegments` values.
+  explicit TransformResults(unsigned numSegments);
+
+  /// Gets the list of operations associated with the result identified by its
+  /// number in the list of operation results.
+  ArrayRef<Operation *> get(unsigned resultNumber) const;
+
+  /// Storage for pointers to payload IR ops that are associated with results of
+  /// a transform IR op. `segments` contains as many entries as the transform IR
+  /// op has results. Each entry is a reference to a contiguous segment in
+  /// the `operations` list that contains the pointers to operations. This
+  /// allows for operations to be stored contiguously without nested vectors and
+  /// for different segments to be set in any order.
+  SmallVector<ArrayRef<Operation *>, 2> segments;
+  SmallVector<Operation *> operations;
+};
+
+} // namespace transform
+} // namespace mlir
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
new file mode 100644
index 0000000000000..5b8d4202f0b4d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -0,0 +1,52 @@
+//===- TransformInterfaces.td - Transform Op interfaces ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces for transformation-related ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+
+def TransformOpInterface : OpInterface<"TransformOpInterface"> {
+  let description = [{
+    This interface is to be implemented by operations that identify
+    transformations to be performed on other operations. The former are referred
+    to as transform IR operations. The latter are referred to as payload IR
+    operations. Such transform IR operations provide a fine-grain control
+    mechanism over how transformations are applied by using and defining
+    transform IR values, referred to as handles, that correspond to sets of
+    operations in the payload IR. Transformations are applied starting from the
+    operations identified by handles, but may affect other operations as well.
+    Further restrictions may be imposed by flows that rely on transform IR
+    operations to control transformations.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Applies the transformation represented by the current operation. This
+        accepts as arguments the object that must be populated with results of
+        the current transformation (each result set exactly once) and a state
+        object that can be used for queries, e.g., to obtain the list of payload
+        operations targeted by the transformation represented by the current op.
+      }],
+      /*returnType=*/"::mlir::LogicalResult",
+      /*name=*/"apply",
+      /*arguments=*/(ins
+          "::mlir::transform::TransformResults &":$transformResults,
+          "::mlir::transform::TransformState &":$state
+    )>,
+  ];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 7f370cd16bf7f..d2cc62241f0e6 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -51,6 +51,7 @@
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -92,6 +93,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   shape::ShapeDialect,
                   sparse_tensor::SparseTensorDialect,
                   tensor::TensorDialect,
+                  transform::TransformDialect,
                   tosa::TosaDialect,
                   x86vector::X86VectorDialect>();
   // clang-format on

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 6ffc8c3085a34..df5d9d22aae15 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -27,6 +27,7 @@ add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
 add_subdirectory(Tensor)
 add_subdirectory(Tosa)
+add_subdirectory(Transform)
 add_subdirectory(Utils)
 add_subdirectory(Vector)
 add_subdirectory(X86Vector)

diff  --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..3a75f30ec7430
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRTransformDialect
+  TransformDialect.cpp
+  TransformInterfaces.cpp
+
+  DEPENDS
+  MLIRTransformDialectIncGen
+  MLIRTransformInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  )

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
new file mode 100644
index 0000000000000..de0cbd6718621
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -0,0 +1,15 @@
+//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
+
+using namespace mlir;
+
+void transform::TransformDialect::initialize() {} // No ops of its own.

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
new file mode 100644
index 0000000000000..d7a8aee9447a0
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -0,0 +1,197 @@
+//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+
+constexpr const Value transform::TransformState::kTopLevelValue;
+
+/// Removes the reverse-mapping entry of `op` if it currently maps to `value`.
+/// Entries owned by another handle are left untouched.
+template <typename ReverseMapping>
+static void eraseReverseEntry(ReverseMapping &reverseMapping, Operation *op,
+                              Value value) {
+  auto it = reverseMapping.find(op);
+  if (it != reverseMapping.end() && it->second == value)
+    reverseMapping.erase(it);
+}
+
+transform::TransformState::TransformState(Operation *root) {
+  operationMapping[kTopLevelValue].push_back(root);
+}
+
+Operation *transform::TransformState::getTopLevel() const {
+  // Look the entry up in place instead of copying the payload list.
+  auto iter = operationMapping.find(kTopLevelValue);
+  assert(iter != operationMapping.end() && "missing top-level payload");
+  return iter->getSecond().front();
+}
+
+ArrayRef<Operation *>
+transform::TransformState::getPayloadOps(Value value) const {
+  auto iter = operationMapping.find(value);
+  assert(iter != operationMapping.end() && "unknown handle");
+  return iter->getSecond();
+}
+
+LogicalResult
+transform::TransformState::setPayloadOps(Value value,
+                                         ArrayRef<Operation *> targets) {
+  assert(value != kTopLevelValue &&
+         "attempting to reset the transformation root");
+
+  // A handle without uses is not an operand of any transform op; its payload
+  // is not recorded.
+  if (value.use_empty())
+    return success();
+
+  // Setting new payload for the value without cleaning it first is a misuse of
+  // the API, assert here.
+  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
+  bool inserted =
+      operationMapping.insert({value, std::move(storedTargets)}).second;
+  assert(inserted && "value is already associated with another list");
+  (void)inserted;
+
+  // Having multiple handles to the same operation is an error in the transform
+  // expressed using the dialect and may be constructed by valid API calls from
+  // valid IR. Emit an error here.
+  for (Operation *op : targets) {
+    auto insertionResult = reverseMapping.insert({op, value});
+    if (!insertionResult.second) {
+      InFlightDiagnostic diag = op->emitError()
+                                << "operation tracked by two handles";
+      diag.attachNote(value.getLoc()) << "handle";
+      diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
+      return diag;
+    }
+  }
+
+  return success();
+}
+
+void transform::TransformState::removePayloadOps(Value value) {
+  // Use find() rather than operator[]: the handle may already have been
+  // removed (e.g. when it appears twice among a transform's operands), and
+  // the lookup must not insert an empty entry only to erase it again.
+  auto it = operationMapping.find(value);
+  if (it == operationMapping.end())
+    return;
+
+  for (Operation *op : it->getSecond())
+    eraseReverseEntry(reverseMapping, op, value);
+  operationMapping.erase(it);
+}
+
+void transform::TransformState::updatePayloadOps(
+    Value value, function_ref<Operation *(Operation *)> callback) {
+  auto it = operationMapping.find(value);
+  assert(it != operationMapping.end() && "unknown handle");
+  SmallVector<Operation *> &association = it->getSecond();
+  SmallVector<Operation *> updated;
+  updated.reserve(association.size());
+
+  // Keep the reverse mapping in sync: the old payload operations are no
+  // longer tracked by `value`, their replacements are.
+  for (Operation *op : association) {
+    eraseReverseEntry(reverseMapping, op, value);
+    if (Operation *updatedOp = callback(op))
+      updated.push_back(updatedOp);
+  }
+  // Register the replacements only once all old entries are gone, since a
+  // replacement may itself appear later in `association`.
+  for (Operation *updatedOp : updated)
+    reverseMapping[updatedOp] = value;
+
+  std::swap(association, updated);
+}
+
+LogicalResult
+transform::TransformState::applyTransform(TransformOpInterface transform) {
+  transform::TransformResults results(transform->getNumResults());
+  if (failed(transform.apply(results, *this)))
+    return failure();
+
+  // Operands are consumed by the transform: drop their payload association.
+  for (Value target : transform->getOperands())
+    removePayloadOps(target);
+
+  for (auto &en : llvm::enumerate(transform->getResults()))
+    if (failed(setPayloadOps(en.value(), results.get(en.index()))))
+      return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+
+transform::TransformResults::TransformResults(unsigned numSegments) {
+  // A null data pointer marks a segment whose results have not been set yet.
+  segments.resize(numSegments,
+                  ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
+}
+
+void transform::TransformResults::set(OpResult value,
+                                      ArrayRef<Operation *> ops) {
+  unsigned position = value.getResultNumber();
+  assert(position < segments.size() &&
+         "setting results for a non-existent handle");
+  assert(segments[position].data() == nullptr && "results already set");
+
+  // Segments set earlier are views into `operations`. Appending may
+  // reallocate that storage and leave them dangling, so record their offsets
+  // beforehand and re-point them at the new storage afterwards.
+  bool mayReallocate = operations.size() + ops.size() > operations.capacity();
+  SmallVector<size_t> offsets;
+  if (mayReallocate) {
+    offsets.reserve(segments.size());
+    for (ArrayRef<Operation *> segment : segments) {
+      size_t offset = 0;
+      if (segment.data() != nullptr)
+        offset = static_cast<size_t>(segment.data() - operations.data());
+      offsets.push_back(offset);
+    }
+  }
+
+  unsigned start = operations.size();
+  llvm::append_range(operations, ops);
+
+  if (mayReallocate) {
+    for (unsigned i = 0, e = segments.size(); i < e; ++i) {
+      if (segments[i].data() == nullptr)
+        continue;
+      segments[i] =
+          makeArrayRef(operations).slice(offsets[i], segments[i].size());
+    }
+  }
+  segments[position] = makeArrayRef(operations).drop_front(start);
+}
+
+ArrayRef<Operation *>
+transform::TransformResults::get(unsigned resultNumber) const {
+  assert(resultNumber < segments.size() &&
+         "querying results for a non-existent handle");
+  assert(segments[resultNumber].data() != nullptr && "querying unset results");
+  return segments[resultNumber];
+}
+
+//===----------------------------------------------------------------------===//
+// Generated interface implementation.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"

diff  --git a/mlir/test/Dialect/Transform/test-dialect-injection.mlir b/mlir/test/Dialect/Transform/test-dialect-injection.mlir
new file mode 100644
index 0000000000000..4474666a25eb4
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-dialect-injection.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// These ops are defined by a test extension but should be okay to roundtrip.
+
+// CHECK: transform.test_transform_op
+transform.test_transform_op
+
+// CHECK: = transform.test_produce_param_or_forward_operand 42 {foo = "bar"}
+%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+
+// CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42]
+transform.test_consume_operand_if_matches_param_or_fail %0[42]

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
new file mode 100644
index 0000000000000..429d4f44da14c
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics
+
+// expected-remark @below {{applying transformation}}
+transform.test_transform_op
+
+// -----
+
+%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+// expected-remark @below {{succeeded}}
+transform.test_consume_operand_if_matches_param_or_fail %0[42]
+
+// -----
+
+%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+// expected-error @below {{expected the operand to be associated with 21 got 42}}
+transform.test_consume_operand_if_matches_param_or_fail %0[21]
+
+// -----
+
+// expected-error @below {{operation tracked by two handles}}
+%0 = transform.test_produce_param_or_forward_operand 42
+// expected-note @below {{handle}}
+%1 = transform.test_produce_param_or_forward_operand from %0
+// expected-note @below {{handle}}
+%2 = transform.test_produce_param_or_forward_operand from %0
+transform.test_consume_operand_if_matches_param_or_fail %1[42]
+transform.test_consume_operand_if_matches_param_or_fail %2[42]

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 4f5bbfba74063..7c8d1a709d137 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -11,4 +11,5 @@ add_subdirectory(SPIRV)
 add_subdirectory(Tensor)
 add_subdirectory(Test)
 add_subdirectory(Tosa)
+add_subdirectory(Transform)
 add_subdirectory(Vector)

diff  --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
new file mode 100644
index 0000000000000..2318f82795160
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -0,0 +1,20 @@
+set(LLVM_TARGET_DEFINITIONS TestTransformDialectExtension.td)
+mlir_tablegen(TestTransformDialectExtension.h.inc -gen-op-decls)
+mlir_tablegen(TestTransformDialectExtension.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
+
+add_mlir_library(MLIRTestTransformDialect
+  TestTransformDialectExtension.cpp
+  TestTransformDialectInterpreter.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  DEPENDS
+  MLIRTestTransformDialectExtensionIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRPDL
+  MLIRTransformDialect
+)

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
new file mode 100644
index 0000000000000..6cf00cf002b3f
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -0,0 +1,122 @@
+//===- TestTransformDialectExtension.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines an extension of the MLIR Transform dialect for testing
+// purposes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformDialectExtension.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple transform op defined outside of the dialect. Just emits a remark when
+/// applied.
+class TestTransformOp
+    : public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
+
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static constexpr llvm::StringLiteral getOperationName() {
+    return llvm::StringLiteral("transform.test_transform_op");
+  }
+
+  LogicalResult apply(transform::TransformResults &results,
+                      transform::TransformState &state) {
+    emitRemark() << "applying transformation";
+    return success();
+  }
+
+  /// The custom form is just the op name: there is nothing to parse.
+  static ParseResult parse(OpAsmParser &parser, OperationState &state) {
+    return success();
+  }
+
+  /// The custom form is just the op name: there is nothing to print.
+  void print(OpAsmPrinter &printer) {}
+};
+} // namespace
+
+LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  if (getOperation()->getNumOperands() != 0) {
+    // Forward the transform op that defines the operand rather than its
+    // payload; test-interpreter.mlir uses this to get two handles to one op.
+    results.set(getResult().cast<OpResult>(), getOperand(0).getDefiningOp());
+  } else {
+    // Smuggle the integer parameter through the payload mapping as a fake
+    // operation pointer. It must never be dereferenced, e.g. by the
+    // diagnostics TransformState emits on payload operations.
+    results.set(getResult().cast<OpResult>(),
+                reinterpret_cast<Operation *>(*parameter()));
+  }
+  return success();
+}
+
+LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
+  // Exactly one of the parameter and the operand must be present.
+  if (parameter().hasValue() ^ (getNumOperands() != 1))
+    return emitOpError() << "expects either a parameter or an operand";
+  return success();
+}
+
+LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
+  // Diagnose rather than assert: without assertions, payload[0] below would
+  // read out of bounds.
+  if (payload.size() != 1)
+    return emitOpError() << "expected a single target op";
+
+  // Decode the integer smuggled as a fake operation pointer. The I64Attr
+  // accessor returns uint64_t, so compare in uint64_t to avoid a
+  // signed/unsigned comparison.
+  auto value = reinterpret_cast<intptr_t>(payload[0]);
+  if (static_cast<uint64_t>(value) != parameter()) {
+    return emitOpError() << "expected the operand to be associated with "
+                         << parameter() << " got " << value;
+  }
+
+  emitRemark() << "succeeded";
+  return success();
+}
+
+namespace {
+/// Test extension of the Transform dialect. Registers additional ops and
+/// declares PDL as dependent dialect since the additional ops are using PDL
+/// types for operands and results.
+class TestTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          TestTransformDialectExtension> {
+public:
+  TestTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    registerTransformOps<TestTransformOp,
+#define GET_OP_LIST
+#include "TestTransformDialectExtension.cpp.inc"
+                         >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "TestTransformDialectExtension.cpp.inc"
+
+void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<TestTransformDialectExtension>();
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
new file mode 100644
index 0000000000000..891249c854e0e
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
@@ -0,0 +1,33 @@
+//===- TestTransformDialectExtension.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines an extension of the MLIR Transform dialect for testing
+// purposes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_H
+#define MLIR_TESTTRANSFORMDIALECTEXTENSION_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class DialectRegistry;
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "TestTransformDialectExtension.h.inc"
+
+namespace test {
+/// Registers the test extension to the Transform dialect.
+void registerTestTransformDialectExtension(::mlir::DialectRegistry &registry);
+} // namespace test
+
+#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_H

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
new file mode 100644
index 0000000000000..c263409c618d1
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -0,0 +1,54 @@
+//===- TestTransformDialectExtension.td --------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the operations that are injected into the Transform
+// dialect through the extension mechanism, as a test.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
+#define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+
+def TestProduceParamOrForwardOperandOp
+  : Op<Transform_Dialect, "test_produce_param_or_forward_operand",
+       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Produces a handle to a parameter or to the operand's definer";
+  let description = [{
+    Test-only op. With an integer `parameter`, the result handle is
+    associated with that integer disguised as a payload operation pointer.
+    With an `operand`, the result handle is associated with the transform
+    operation defining the operand. Exactly one of the two must be given.
+  }];
+  let arguments = (ins Optional<PDL_Operation>:$operand,
+                       OptionalAttr<I64Attr>:$parameter);
+  let results = (outs PDL_Operation:$res);
+  let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict";
+  let cppNamespace = "::mlir::test";
+  let hasVerifier = 1;
+}
+
+def TestConsumeOperandIfMatchesParamOrFail
+  : Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
+       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Checks that the operand's payload matches the parameter";
+  let description = [{
+    Test-only op. Succeeds with a remark if the payload associated with
+    `operand` is the integer `parameter` produced by
+    `test_produce_param_or_forward_operand`, and emits an error otherwise.
+  }];
+  let arguments = (ins PDL_Operation:$operand, I64Attr:$parameter);
+  let assemblyFormat = "$operand `[` $parameter `]` attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
new file mode 100644
index 0000000000000..854a41318dac2
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -0,0 +1,61 @@
+//===- TestTransformDialectInterpreter.cpp --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a test pass that interprets Transform dialect operations in
+// the module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that applies, in order, the transform ops found directly in the
+/// body of the module, using the module itself as the payload root. The pass
+/// stops and fails at the first transform that fails.
+class TestTransformDialectInterpreterPass
+    : public PassWrapper<TestTransformDialectInterpreterPass,
+                         OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestTransformDialectInterpreterPass)
+
+  StringRef getArgument() const override {
+    return "test-transform-dialect-interpreter";
+  }
+
+  StringRef getDescription() const override {
+    return "apply transform dialect operations one by one";
+  }
+
+  void runOnOperation() override {
+    ModuleOp root = getOperation();
+    transform::TransformState state(root);
+    auto transformOps =
+        root.getBody()->getOps<transform::TransformOpInterface>();
+    for (transform::TransformOpInterface transformOp : transformOps) {
+      if (succeeded(state.applyTransform(transformOp)))
+        continue;
+      signalPassFailure();
+      return;
+    }
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the test pass for applying transform dialect ops.
+void registerTestTransformDialectInterpreterPass() {
+  PassRegistration<TestTransformDialectInterpreterPass>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/test/lib/Dialect/Transform/lit.local.cfg b/mlir/test/lib/Dialect/Transform/lit.local.cfg
new file mode 100644
index 0000000000000..edb5b44b2e2fe
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/lit.local.cfg
@@ -0,0 +1 @@
+config.suffixes.remove('.td')
\ No newline at end of file

diff  --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 5ea1fd6e61394..1f606fdcc2029 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -33,5 +33,6 @@
 // CHECK-NEXT: tensor
 // CHECK-NEXT: test
 // CHECK-NEXT: tosa
+// CHECK-NEXT: transform
 // CHECK-NEXT: vector
 // CHECK-NEXT: x86vector

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index ad98114fcde23..7215629ca6786 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -30,6 +30,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestPass
     MLIRTestReducer
     MLIRTestRewrite
+    MLIRTestTransformDialect
     MLIRTestTransforms
     MLIRVectorTestPasses
     )

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2be6e15e3aec0..5e946de7a42e3 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -107,12 +107,14 @@ void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
+void registerTestTransformDialectInterpreterPass();
 void registerTestVectorLowerings();
 } // namespace test
 } // namespace mlir
 
 namespace test {
 void registerTestDialect(DialectRegistry &);
+void registerTestTransformDialectExtension(DialectRegistry &);
 } // namespace test
 
 #ifdef MLIR_INCLUDE_TESTS
@@ -196,6 +198,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorTransforms();
+  mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
 }
 #endif
@@ -209,6 +212,7 @@ int main(int argc, char **argv) {
   registerAllDialects(registry);
 #ifdef MLIR_INCLUDE_TESTS
   ::test::registerTestDialect(registry);
+  ::test::registerTestTransformDialectExtension(registry);
 #endif
   return mlir::asMainReturnCode(
       mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 927f9887acc5e..ee15429333664 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6019,6 +6019,7 @@ cc_library(
         ":TensorTransforms",
         ":TosaDialect",
         ":TosaToLinalg",
+        ":TransformDialect",
         ":Transforms",
         ":TransformsPassIncGen",
         ":VectorOps",
@@ -6079,6 +6080,7 @@ cc_binary(
         "//mlir/test:TestShapeDialect",
         "//mlir/test:TestTensor",
         "//mlir/test:TestTosaDialect",
+        "//mlir/test:TestTransformDialect",
         "//mlir/test:TestTransforms",
         "//mlir/test:TestTypeDialect",
         "//mlir/test:TestVector",
@@ -7583,6 +7585,71 @@ cc_library(
     ],
 )
 
+td_library(
+    name = "TransformDialectTdFiles",
+    srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]),
+    deps = [
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "TransformDialectInterfacesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            [
+                "-gen-op-interface-decls",
+            ],
+            "include/mlir/Dialect/Transform/IR/TransformInterfaces.h.inc",
+        ),
+        (
+            [
+                "-gen-op-interface-defs",
+            ],
+            "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td",
+    deps = [":TransformDialectTdFiles"],
+)
+
+gentbl_cc_library(
+    name = "TransformDialectIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            [
+                "-gen-dialect-decls",
+            ],
+            "include/mlir/Dialect/Transform/IR/TransformDialect.h.inc",
+        ),
+        (
+            [
+                "-gen-dialect-defs",
+            ],
+            "include/mlir/Dialect/Transform/IR/TransformDialect.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Transform/IR/TransformDialect.td",
+    deps = [":TransformDialectTdFiles"],
+)
+
+cc_library(
+    name = "TransformDialect",
+    srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]),
+    deps = [
+        ":IR",
+        ":Support",
+        ":TransformDialectIncGen",
+        ":TransformDialectInterfacesIncGen",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "ComplexOpsTdFiles",
     srcs = [

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 6627a2ea45443..b131f8365de78 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -192,6 +192,51 @@ gentbl_cc_library(
     ],
 )
 
+td_library(
+    name = "TransformDialectTdFiles",
+    srcs = glob(["lib/Dialect/Transform/*.td"]),
+    deps = [
+        "//mlir:OpBaseTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "TestTransformDialectExtensionIncGen",
+    strip_include_prefix = "lib/Dialect/Transform",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "lib/Dialect/Transform/TestTransformDialectExtension.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "lib/Dialect/Transform/TestTransformDialectExtension.cpp.inc",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "lib/Dialect/Transform/TestTransformDialectExtension.td",
+    test = True,
+    deps = [
+        ":TransformDialectTdFiles",
+        "//mlir:PDLDialectTdFiles",
+        "//mlir:TransformDialectTdFiles",
+    ],
+)
+
+cc_library(
+    name = "TestTransformDialect",
+    srcs = glob(["lib/Dialect/Transform/*.cpp"]),
+    hdrs = glob(["lib/Dialect/Transform/*.h"]),
+    includes = ["lib/Dialect/Transform"],
+    deps = [
+        ":TestTransformDialectExtensionIncGen",
+        "//mlir:IR",
+        "//mlir:PDLDialect",
+        "//mlir:Pass",
+        "//mlir:TransformDialect",
+    ],
+)
+
 cc_library(
     name = "TestDialect",
     srcs = glob(["lib/Dialect/Test/*.cpp"]),


        


More information about the Mlir-commits mailing list