[Mlir-commits] [mlir] 0eb403a - [mlir][transform] Introduce transform.sequence op

Alex Zinenko llvmlistbot at llvm.org
Tue Apr 19 12:41:09 PDT 2022


Author: Alex Zinenko
Date: 2022-04-19T21:41:02+02:00
New Revision: 0eb403ad1b0c4d052ed93b58892c6432f78e0e7d

URL: https://github.com/llvm/llvm-project/commit/0eb403ad1b0c4d052ed93b58892c6432f78e0e7d
DIFF: https://github.com/llvm/llvm-project/commit/0eb403ad1b0c4d052ed93b58892c6432f78e0e7d.diff

LOG: [mlir][transform] Introduce transform.sequence op

Sequence is an important transform combination primitive that just indicates
that transform ops are applied in a row. The simplest version fails
immediately if any transformation in the sequence fails. Introducing this
operation allows one to start placing transform IR within other IR.

Depends On D123135

Reviewed By: Mogball, rriddle

Differential Revision: https://reviews.llvm.org/D123664

Added: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/Dialect/Transform/ops.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/CMakeLists.txt
    mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
index 493e30d4a265d..d2d98c183f315 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
@@ -1,8 +1,13 @@
-# The dialect does not have its own ops, so just generate the dialect files.
+# Generate the dialect files from the dialect .td.
+#
+# TODO: Make it possible to use XDialect instead of XOpsDialect in
+# add_mlir_dialect.
 set(LLVM_TARGET_DEFINITIONS TransformDialect.td)
 mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform)
 mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform)
 add_public_tablegen_target(MLIRTransformDialectIncGen)
 add_dependencies(mlir-headers MLIRTransformDialectIncGen)
 
+add_mlir_dialect(TransformOps transform)
+
 add_mlir_interface(TransformInterfaces)

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index db14953ec1510..aca6497bcb9c1 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -161,6 +161,7 @@ def Transform_Dialect : Dialect {
 
   let name = "transform";
   let cppNamespace = "::mlir::transform";
+  let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
 
   let extraClassDeclaration = [{
     // Make addOperations available to the TransformDialectExtension class.
@@ -172,4 +173,9 @@ def Transform_Dialect : Dialect {
   }];
 }
 
+// Base class for ops that belong to the transform dialect. Ops defined in
+// extensions of this dialect may also use this.
+class TransformDialectOp<string mnemonic, list<Trait> traits = []>
+    : Op<Transform_Dialect, mnemonic, traits>;
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 2caa20e6cdc66..f109ad599b841 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -33,6 +33,14 @@ class TransformOpInterface;
 /// expected to populate the `TransformResults` class instance in order to
 /// update the mapping. The `applyTransform` method takes care of propagating
 /// the state of `TransformResults` into the instance of this class.
+///
+/// When applying transform IR operations with regions, the client is expected
+/// to create a RegionScope RAII object to create a new "stack frame" for
+/// values defined inside the region. The mappings from and to these values will
+/// be automatically dropped when the object goes out of scope, typically at the
+/// end of the "apply" function of the parent operation. If a region contains
+/// blocks with arguments, the client can map those arguments to payload IR ops
+/// using "mapBlockArguments".
 class TransformState {
   /// Mapping between a Value in the transform IR and the corresponding set of
   /// operations in the payload IR.
@@ -42,9 +50,19 @@ class TransformState {
   /// currently associated with.
   using TransformOpReverseMapping = DenseMap<Operation *, Value>;
 
+  /// Bidirectional mappings between transform IR values and payload IR
+  /// operations.
+  struct Mappings {
+    TransformOpMapping direct;
+    TransformOpReverseMapping reverse;
+  };
+
 public:
-  /// Creates a state for the transformation rooted at the given op.
-  explicit TransformState(Operation *root);
+  /// Creates a state for transform ops living in the given region; a mapping
+  /// frame for values defined in this region is created immediately. The
+  /// second argument points to the root operation in the payload IR being
+  /// transformed, which may or may not contain the region with transform ops.
+  TransformState(Region &region, Operation *root);
 
   /// Returns the op at which the transformation state is rooted. This is
   /// typically helpful for transformations that apply globally.
@@ -58,10 +76,96 @@ class TransformState {
   /// the state accordingly.
   LogicalResult applyTransform(TransformOpInterface transform);
 
+  /// Records the mapping between a block argument in the transform IR and a
+  /// list of operations in the payload IR. The arguments must be defined in
+  /// blocks of the currently processed transform IR region, typically after a
+  /// region scope is defined.
+  LogicalResult mapBlockArguments(BlockArgument argument,
+                                  ArrayRef<Operation *> operations) {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    assert(argument.getParentRegion() == regionStack.back() &&
+           "mapping block arguments from a region other than the active one");
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+    return setPayloadOps(argument, operations);
+  }
+
+  // Forward declarations to support limited visibility.
+  class RegionScope;
+
+  /// Creates a new region scope for the given region. The region is expected to
+  /// be nested in the currently processed region.
+  // Implementation note: this method is inline but implemented outside of the
+  // class body to comply with visibility and full-declaration requirements.
+  inline RegionScope make_region_scope(Region &region);
+
+  /// A RAII object maintaining a "stack frame" for a transform IR region. When
+  /// applying a transform IR operation that contains a region, the caller is
+  /// expected to create a RegionScope before applying the ops contained in the
+  /// region. This ensures that the mappings between values defined in the
+  /// transform IR region and payload IR operations are cleared when the region
+  /// processing ends; such values cannot be accessed outside the region.
+  class RegionScope {
+  public:
+    /// Forgets the mapping from or to values defined in the associated
+    /// transform IR region.
+    ~RegionScope() {
+      state.mappings.erase(region);
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+      state.regionStack.pop_back();
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+    }
+
+  private:
+    /// Creates a new scope for mappings between values defined in the given
+    /// transform IR region and payload IR operations.
+    RegionScope(TransformState &state, Region &region)
+        : state(state), region(&region) {
+      auto res = state.mappings.try_emplace(this->region);
+      assert(res.second && "the region scope is already present");
+      (void)res;
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+      assert(state.regionStack.back()->isProperAncestor(&region) &&
+             "scope started at a non-nested region");
+      state.regionStack.push_back(&region);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+    }
+
+    /// Back-reference to the transform state.
+    TransformState &state;
+
+    /// The region this scope is associated with.
+    Region *region;
+
+    friend RegionScope TransformState::make_region_scope(Region &);
+  };
+  friend class RegionScope;
+
 private:
   /// Identifier for storing top-level value in the `operations` mapping.
   static constexpr Value kTopLevelValue = Value();
 
+  /// Returns the mappings frame for the region in which the value is defined.
+  const Mappings &getMapping(Value value) const {
+    return const_cast<TransformState *>(this)->getMapping(value);
+  }
+  Mappings &getMapping(Value value) {
+    auto it = mappings.find(value.getParentRegion());
+    assert(it != mappings.end() &&
+           "trying to find a mapping for a value from an unmapped region");
+    return it->second;
+  }
+
+  /// Returns the mappings frame for the region in which the operation resides.
+  const Mappings &getMapping(Operation *operation) const {
+    return const_cast<TransformState *>(this)->getMapping(operation);
+  }
+  Mappings &getMapping(Operation *operation) {
+    auto it = mappings.find(operation->getParentRegion());
+    assert(it != mappings.end() &&
+           "trying to find a mapping for an operation from an unmapped region");
+    return it->second;
+  }
+
   /// Sets the payload IR ops associated with the given transform IR value.
   /// Fails if this would result in multiple transform IR values with uses
   /// corresponding to the same payload IR ops. For example, a hypothetical
@@ -88,9 +192,19 @@ class TransformState {
   void updatePayloadOps(Value value,
                         function_ref<Operation *(Operation *)> callback);
 
-  /// The mapping between payload IR values and transform IR ops.
-  TransformOpMapping operationMapping;
-  TransformOpReverseMapping reverseMapping;
+  /// The mappings between transform IR values and payload IR ops, aggregated by
+  /// the region in which the transform IR values are defined.
+  llvm::SmallDenseMap<Region *, Mappings> mappings;
+
+  /// The top-level operation that contains all payload IR, typically a module.
+  Operation *topLevel;
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  /// A stack of nested regions that are being processed in the transform IR.
+  /// Each region must be an ancestor of the following regions in this list.
+  /// These are also the keys for "mappings".
+  SmallVector<Region *> regionStack;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 };
 
 /// Local mapping between values defined by a specific op implementing the
@@ -123,6 +237,10 @@ class TransformResults {
   SmallVector<Operation *> operations;
 };
 
+TransformState::RegionScope TransformState::make_region_scope(Region &region) {
+  return RegionScope(*this, region);
+}
+
 } // namespace transform
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
new file mode 100644
index 0000000000000..a12b5abd8ffc8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -0,0 +1,20 @@
+//===- TransformOps.h - Transform dialect operations ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/IR/TransformOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
new file mode 100644
index 0000000000000..246de281568b1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -0,0 +1,78 @@
+//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
+
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+
+def SequenceOp : TransformDialectOp<"sequence",
+    [DeclareOpInterfaceMethods<TransformOpInterface>, OpAsmOpInterface,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+  let summary = "Contains a sequence of other transform ops to apply";
+  let description = [{
+    The transformations indicated by the sequence are applied in order of their
+    appearance. Each value produced by a transformation within the sequence
+    corresponds to an operation or a group of operations in the payload IR.
+    Each value may be used at most once by another transformation operation as
+    the transformation is likely to replace the transformed operation with
+    another operation or a group thereof. In such cases, the transformation
+    operation is expected to produce a new value to denote the newly produced
+    operations that can be transformed further. During application, if any
+    transformation in the sequence fails, the entire sequence fails immediately
+    leaving the payload IR in a potentially invalid state, i.e., this operation
+    offers no transformation rollback capabilities.
+
+    The entry block of this operation has a single argument that maps to either
+    the operand if provided or the top-level container operation of the payload
+    IR, typically the root operation of the pass interpreting the transform
+    dialect. Operand omission is only allowed for sequences not contained in
+    another sequence.
+  }];
+
+  let arguments = (ins Optional<PDL_Operation>:$root);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat =
+    "($root^)? attr-dict-with-keyword regions (`:` type($results)^)?";
+
+  let extraClassDeclaration = [{
+    /// Allow the dialect prefix to be omitted.
+    static StringRef getDefaultDialect() { return "transform"; }
+
+    Block *getBodyBlock() {
+      return &getBody().front();
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def YieldOp : TransformDialectOp<"yield", [Terminator]> {
+  let summary = "Yields operation handles from a transform IR region";
+  let description = [{
+    This terminator operation yields operation handles from regions of the
+    transform IR ops back to the containing op. It is not itself associated with
+    any transformation on the payload IR and is used for flow purposes only.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";
+
+  let builders = [
+    OpBuilder<(ins), [{
+      return build($_builder, $_state, ::mlir::ValueRange());
+    }]>
+  ];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS

diff  --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index 3a75f30ec7430..760ce9364b0aa 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTransformDialect
   TransformDialect.cpp
   TransformInterfaces.cpp
+  TransformOps.cpp
 
   DEPENDS
   MLIRTransformDialectIncGen
@@ -8,4 +9,6 @@ add_mlir_dialect_library(MLIRTransformDialect
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRPDL
+  MLIRPDLInterp
   )

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index de0cbd6718621..a566cb91ee750 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -7,9 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
-
-#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 
 using namespace mlir;
 
-void transform::TransformDialect::initialize() {}
+#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
+
+void transform::TransformDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
+      >();
+}

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index d7a8aee9447a0..7df299a94cfb6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
 using namespace mlir;
@@ -19,16 +20,21 @@ using namespace mlir;
 
 constexpr const Value transform::TransformState::kTopLevelValue;
 
-transform::TransformState::TransformState(Operation *root) {
-  operationMapping[kTopLevelValue].push_back(root);
+transform::TransformState::TransformState(Region &region, Operation *root)
+    : topLevel(root) {
+  auto result = mappings.try_emplace(&region);
+  assert(result.second && "the region scope is already present");
+  (void)result;
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  regionStack.push_back(&region);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 }
 
-Operation *transform::TransformState::getTopLevel() const {
-  return operationMapping.lookup(kTopLevelValue).front();
-}
+Operation *transform::TransformState::getTopLevel() const { return topLevel; }
 
 ArrayRef<Operation *>
 transform::TransformState::getPayloadOps(Value value) const {
+  const TransformOpMapping &operationMapping = getMapping(value).direct;
   auto iter = operationMapping.find(value);
   assert(iter != operationMapping.end() && "unknown handle");
   return iter->getSecond();
@@ -46,8 +52,9 @@ transform::TransformState::setPayloadOps(Value value,
   // Setting new payload for the value without cleaning it first is a misuse of
   // the API, assert here.
   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
+  Mappings &mappings = getMapping(value);
   bool inserted =
-      operationMapping.insert({value, std::move(storedTargets)}).second;
+      mappings.direct.insert({value, std::move(storedTargets)}).second;
   assert(inserted && "value is already associated with another list");
   (void)inserted;
 
@@ -55,7 +62,7 @@ transform::TransformState::setPayloadOps(Value value,
   // expressed using the dialect and may be constructed by valid API calls from
   // valid IR. Emit an error here.
   for (Operation *op : targets) {
-    auto insertionResult = reverseMapping.insert({op, value});
+    auto insertionResult = mappings.reverse.insert({op, value});
     if (!insertionResult.second) {
       InFlightDiagnostic diag = op->emitError()
                                 << "operation tracked by two handles";
@@ -69,15 +76,16 @@ transform::TransformState::setPayloadOps(Value value,
 }
 
 void transform::TransformState::removePayloadOps(Value value) {
-  for (Operation *op : operationMapping[value])
-    reverseMapping.erase(op);
-  operationMapping.erase(value);
+  Mappings &mappings = getMapping(value);
+  for (Operation *op : mappings.direct[value])
+    mappings.reverse.erase(op);
+  mappings.direct.erase(value);
 }
 
 void transform::TransformState::updatePayloadOps(
     Value value, function_ref<Operation *(Operation *)> callback) {
-  auto it = operationMapping.find(value);
-  assert(it != operationMapping.end() && "unknown handle");
+  auto it = getMapping(value).direct.find(value);
+  assert(it != getMapping(value).direct.end() && "unknown handle");
   SmallVector<Operation *> &association = it->getSecond();
   SmallVector<Operation *> updated;
   updated.reserve(association.size());
@@ -98,9 +106,13 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   for (Value target : transform->getOperands())
     removePayloadOps(target);
 
-  for (auto &en : llvm::enumerate(transform->getResults()))
+  for (auto &en : llvm::enumerate(transform->getResults())) {
+    assert(en.value().getDefiningOp() == transform.getOperation() &&
+           "payload IR association for a value other than the result of the "
+           "current transform op");
     if (failed(setPayloadOps(en.value(), results.get(en.index()))))
       return failure();
+  }
 
   return success();
 }

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
new file mode 100644
index 0000000000000..3018e3b5b68bf
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -0,0 +1,101 @@
+//===- TransformOps.cpp - Transform dialect operations --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/Builders.h"
+
+#include "mlir/IR/OpImplementation.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
+
+LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
+                                           transform::TransformState &state) {
+  SmallVector<Operation *> targets;
+  if (getRoot())
+    llvm::append_range(targets, state.getPayloadOps(getRoot()));
+  else
+    targets.push_back(state.getTopLevel());
+
+  // Map the entry block argument to the list of operations.
+  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
+  if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets)))
+    return failure();
+
+  // Apply the sequenced ops one by one.
+  for (Operation &transform : getBodyBlock()->without_terminator())
+    if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
+      return failure();
+
+  // Forward the operation mapping for values yielded from the sequence to the
+  // values produced by the sequence op.
+  for (const auto &pair :
+       llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
+                 getOperation()->getOpResults())) {
+    Value terminatorOperand = std::get<0>(pair);
+    OpResult result = std::get<1>(pair);
+    results.set(result, state.getPayloadOps(terminatorOperand));
+  }
+
+  return success();
+}
+
+LogicalResult transform::SequenceOp::verify() {
+  if (getBodyBlock()->getNumArguments() != 1 ||
+      !getBodyBlock()->getArgumentTypes()[0].isa<pdl::OperationType>()) {
+    return emitOpError()
+           << "expected the entry block to have one argument of type "
+           << pdl::OperationType::get(getContext());
+  }
+
+  if (auto parent = getOperation()->getParentOfType<transform::SequenceOp>()) {
+    if (!getRoot()) {
+      InFlightDiagnostic diag =
+          emitOpError()
+          << "expected the root operation to be provided for a nested sequence";
+      diag.attachNote(parent.getLoc()) << "nested in another sequence";
+      return diag;
+    }
+  }
+
+  for (Operation &child : *getBodyBlock()) {
+    if (!isa<TransformOpInterface>(child) &&
+        &child != &getBodyBlock()->back()) {
+      InFlightDiagnostic diag =
+          emitOpError()
+          << "expected children ops to implement TransformOpInterface";
+      diag.attachNote(child.getLoc()) << "op without interface";
+      return diag;
+    }
+
+    for (OpResult result : child.getResults()) {
+      if (llvm::hasNItemsOrLess(result.getUses(), 1))
+        continue;
+      InFlightDiagnostic diag = child.emitError()
+                                << "result #" << result.getResultNumber()
+                                << " has more than one use";
+      for (OpOperand &use : result.getUses()) {
+        diag.attachNote(use.getOwner()->getLoc())
+            << "used here as operand #" << use.getOperandNumber();
+      }
+      return diag;
+    }
+  }
+
+  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
+      getOperation()->getResultTypes()) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "expects the types of the terminator operands "
+                                 "to match the types of the result";
+    diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
+    return diag;
+  }
+  return success();
+}

diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
new file mode 100644
index 0000000000000..614628107834e
--- /dev/null
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// expected-error @below {{expected the entry block to have one argument of type '!pdl.operation'}}
+transform.sequence {
+}
+
+// -----
+
+// expected-note @below {{nested in another sequence}}
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{expected the root operation to be provided for a nested sequence}}
+  transform.sequence {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}
+
+// -----
+
+// expected-error @below {{expected children ops to implement TransformOpInterface}}
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-note @below {{op without interface}}
+  arith.constant 42.0 : f32
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{result #0 has more than one use}}
+  %0 = transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+  } : !pdl.operation
+  // expected-note @below {{used here as operand #0}}
+  transform.sequence %0 {
+  ^bb2(%arg2: !pdl.operation):
+  }
+  // expected-note @below {{used here as operand #0}}
+  transform.sequence %0 {
+  ^bb3(%arg3: !pdl.operation):
+  }
+}
+
+// -----
+
+// expected-error @below {{expects the types of the terminator operands to match the types of the result}}
+%0 = transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-note @below {{terminator}}
+  transform.yield
+} : !pdl.operation

diff  --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
new file mode 100644
index 0000000000000..c3aab426aad26
--- /dev/null
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK: transform.sequence
+// CHECK: ^{{.+}}(%{{.+}}: !pdl.operation):
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // CHECK: sequence %{{.+}}
+  // CHECK: ^{{.+}}(%{{.+}}: !pdl.operation):
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 429d4f44da14c..a6ceeea82a5c8 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -25,3 +25,47 @@ transform.test_consume_operand_if_matches_param_or_fail %0[21]
 %2 = transform.test_produce_param_or_forward_operand from %0
 transform.test_consume_operand_if_matches_param_or_fail %1[42]
 transform.test_consume_operand_if_matches_param_or_fail %2[42]
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    // expected-remark @below {{applying transformation "a"}}
+    test_transform_op "a"
+    // expected-remark @below {{applying transformation "b"}}
+    test_transform_op "b"
+    // expected-remark @below {{applying transformation "c"}}
+    test_transform_op "c"
+  }
+  // expected-remark @below {{applying transformation "d"}}
+  test_transform_op "d"
+  // expected-remark @below {{applying transformation "e"}}
+  test_transform_op "e"
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  %0 = test_produce_param_or_forward_operand 42
+  sequence %0 {
+  ^bb0(%arg1: !pdl.operation):
+    // expected-remark @below {{succeeded}}
+    test_consume_operand_if_matches_param_or_fail %arg1[42]
+  }
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  %0 = sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %1 = test_produce_param_or_forward_operand 42
+    yield %1 : !pdl.operation
+  } : !pdl.operation
+  // expected-remark @below {{succeeded}}
+  test_consume_operand_if_matches_param_or_fail %0[42]
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index b0226b598dbe2..4aed0aae1e776 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -38,31 +38,47 @@ class TestTransformOp
 
   LogicalResult apply(transform::TransformResults &results,
                       transform::TransformState &state) {
-    emitRemark() << "applying transformation";
+    InFlightDiagnostic remark = emitRemark() << "applying transformation";
+    if (Attribute message = getMessage())
+      remark << " " << message;
+
     return success();
   }
 
+  Attribute getMessage() { return getOperation()->getAttr("message"); }
+
   static ParseResult parse(OpAsmParser &parser, OperationState &state) {
-    return success();
+    StringAttr message;
+    OptionalParseResult result = parser.parseOptionalAttribute(message);
+    if (!result.hasValue())
+      return success();
+
+    if (result.getValue().succeeded())
+      state.addAttribute("message", message);
+    return result.getValue();
   }
 
-  void print(OpAsmPrinter &printer) {}
+  void print(OpAsmPrinter &printer) {
+    if (getMessage())
+      printer << " " << getMessage();
+  }
 };
 } // namespace
 
 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   if (getOperation()->getNumOperands() != 0) {
-    results.set(getResult().cast<OpResult>(), getOperand(0).getDefiningOp());
+    results.set(getResult().cast<OpResult>(),
+                getOperation()->getOperand(0).getDefiningOp());
   } else {
     results.set(getResult().cast<OpResult>(),
-                reinterpret_cast<Operation *>(*parameter()));
+                reinterpret_cast<Operation *>(*getParameter()));
   }
   return success();
 }
 
 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
-  if (parameter().hasValue() ^ (getNumOperands() != 1))
+  if (getParameter().hasValue() ^ (getNumOperands() != 1))
     return emitOpError() << "expects either a parameter or an operand";
   return success();
 }
@@ -72,9 +88,9 @@ LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
   assert(payload.size() == 1 && "expected a single target op");
   auto value = reinterpret_cast<intptr_t>(payload[0]);
-  if (static_cast<uint64_t>(value) != parameter()) {
+  if (static_cast<uint64_t>(value) != getParameter()) {
     return emitOpError() << "expected the operand to be associated with "
-                         << parameter() << " got " << value;
+                         << getParameter() << " got " << value;
   }
 
   emitRemark() << "succeeded";

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index 854a41318dac2..8c4aa1aee5529 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -37,7 +37,7 @@ class TestTransformDialectInterpreterPass
 
   void runOnOperation() override {
     ModuleOp module = getOperation();
-    transform::TransformState state(module);
+    transform::TransformState state(module.getBodyRegion(), module);
     for (auto op :
          module.getBody()->getOps<transform::TransformOpInterface>()) {
       if (failed(state.applyTransform(op)))

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index bc01efa76f74b..346e7f7d16a7d 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7699,6 +7699,7 @@ td_library(
     srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]),
     deps = [
         ":OpBaseTdFiles",
+        ":PDLDialectTdFiles",
     ],
 )
 
@@ -7746,15 +7747,35 @@ gentbl_cc_library(
     deps = [":TransformDialectTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "TransformOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/Transform/IR/TransformOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/Transform/IR/TransformOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Transform/IR/TransformOps.td",
+    deps = [":TransformDialectTdFiles"],
+)
+
 cc_library(
     name = "TransformDialect",
     srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]),
     hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]),
     deps = [
         ":IR",
+        ":PDLDialect",
         ":Support",
         ":TransformDialectIncGen",
         ":TransformDialectInterfacesIncGen",
+        ":TransformOpsIncGen",
         "//llvm:Support",
     ],
 )


        


More information about the Mlir-commits mailing list