[Mlir-commits] [mlir] 40a8bd6 - [mlir] use side effects in the Transform dialect

Alex Zinenko llvmlistbot at llvm.org
Fri Apr 22 14:29:20 PDT 2022


Author: Alex Zinenko
Date: 2022-04-22T23:29:11+02:00
New Revision: 40a8bd635b08f310e4f95b0789a70953bba1e645

URL: https://github.com/llvm/llvm-project/commit/40a8bd635b08f310e4f95b0789a70953bba1e645
DIFF: https://github.com/llvm/llvm-project/commit/40a8bd635b08f310e4f95b0789a70953bba1e645.diff

LOG: [mlir] use side effects in the Transform dialect

Currently, the sequence of Transform dialect operations only supports a single
use of each operand (verified by the `transform.sequence` operation). This was
originally motivated by the need to guard against accessing a payload IR
operation associated with a transform IR value after this operation has likely
been rewritten by a transformation. However, not all Transform dialect
operations rewrite payload IR; in particular, "navigation" operations such as
`transform.pdl_match` do not.

Introduce memory effects to the Transform dialect operations to describe their
effect on the payload IR and the mapping between payload IR operations and
transform IR values. Use these effects to replace the single-use rule, allowing
repeated reads and disallowing use-after-free (where operations with the "free"
effect are considered to "consume" the transform IR value and rewrite the
corresponding payload IR operations). As an additional improvement, this
enables code motion transformations on the transform IR itself.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D124181

Added: 
    mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/CMakeLists.txt
    mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/Dialect/Transform/ops.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index d1607b57622f9..75e6a64fe5510 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -17,16 +17,13 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringMap.h"
 
-#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
-
 namespace mlir {
 namespace transform {
-
 #ifndef NDEBUG
 namespace detail {
 /// Asserts that the operations provided as template arguments implement the
-/// TransformOpInterface. This must be a dynamic assertion since interface
-/// implementations may be registered at runtime.
+/// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
+/// assertion since interface implementations may be registered at runtime.
 template <typename OpTy>
 static inline void checkImplementsTransformInterface(MLIRContext *context) {
   // Since the operation is being inserted into the Transform dialect and the
@@ -34,12 +31,23 @@ static inline void checkImplementsTransformInterface(MLIRContext *context) {
   // itself having the interface implementation.
   RegisteredOperationName opName =
       *RegisteredOperationName::lookup(OpTy::getOperationName(), context);
-  assert(opName.hasInterface<TransformOpInterface>() &&
+  assert((opName.hasInterface<TransformOpInterface>() ||
+          opName.hasTrait<OpTrait::IsTerminator>()) &&
+         "non-terminator ops injected into the transform dialect must "
+         "implement TransformOpInterface");
+  assert(opName.hasInterface<MemoryEffectOpInterface>() &&
          "ops injected into the transform dialect must implement "
-         "TransformOpInterface");
+         "MemoryEffectsOpInterface");
 }
 } // namespace detail
 #endif // NDEBUG
+} // namespace transform
+} // namespace mlir
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
+
+namespace mlir {
+namespace transform {
 
 /// Base class for extensions of the Transform dialect that supports injecting
 /// operations into the Transform dialect at load time. Concrete extensions are
@@ -66,19 +74,12 @@ class TransformDialectExtension
 
 protected:
   /// Injects the operations into the Transform dialect. The operations must
-  /// implement the TransformOpInterface and the implementation must be already
-  /// available when the operation is injected.
+  /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the
+  /// implementations must be already available when the operation is injected.
   template <typename... OpTys>
   void registerTransformOps() {
     opInitializers.push_back([](TransformDialect *transformDialect) {
-      transformDialect->addOperations<OpTys...>();
-
-#ifndef NDEBUG
-      (void)std::initializer_list<int>{
-          (detail::checkImplementsTransformInterface<OpTys>(
-               transformDialect->getContext()),
-           0)...};
-#endif // NDEBUG
+      transformDialect->addOperationsChecked<OpTys...>();
     });
   }
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index d695b850474a4..2802d659312ea 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -64,6 +64,15 @@ def Transform_Dialect : Dialect {
     correspond to groups of outer and inner loops, respectively, produced by
     the tiling transformation.
 
+    Overall, Transform IR ops are expected to be contained in a single top-level
+    op. Such top-level ops specify how to apply the transformations described
+    by the operations they contain, e.g., `transform.sequence` executes
+    transformations one by one and fails if any of them fails. Such ops are
+    expected to have the `PossibleTopLevelTransformOpTrait` and may be used
+    without arguments.
+
+    ## Dialect Extension Mechanism
+
     This dialect is designed to be extensible, that is, clients of this dialect
     are allowed to inject additional operations into this dialect using the
     `TransformDialectExtension` mechanism. This allows the dialect to avoid a
@@ -84,12 +93,63 @@ def Transform_Dialect : Dialect {
     `LoopTransformDialectExtension` in the cases above. Unprefixed operation
     names are reserved for ops defined directly in the Transform dialect.
 
-    Overall, Transform IR ops are expected to be contained in a single top-level
-    op. Such top-level ops specifie how to apply the transformations described
-    by operations they contain, e.g., `transform.sequence` executes
-    transformations one by one and fails if any of them fails. Such ops are
-    expected to have the `PossibleTopLevelTransformOpTrait` and may be used
-    without arguments.
+    Operations injected into the dialect must:
+
+      * Implement the `TransformOpInterface` to execute the corresponding
+        transformation on the payload IR.
+
+      * Implement the `MemoryEffectsOpInterface` to annotate the effects of
+        the transform IR operation on the payload IR as well as on the mapping
+        between transform IR values and payload IR operations. See below for
+        the description of available effects.
+
+    The presence of interface implementations is checked at runtime when the
+    dialect is loaded to allow for those implementations to be supplied by
+    separate dialect extensions if desired.
+
+    ## Side Effects
+
+    The Transform dialect relies on MLIR side effect modelling to enable
+    optimization of the transform IR. More specifically, it provides several
+    side effect resource objects and expects operations to describe their
+    effects on these resources.
+
+      * `TransformMappingResource` - side effect resource corresponding to the
+        mapping between transform IR values and payload IR operations.
+        
+        - An `Allocate` effect from this resource means creating a new mapping
+          entry, it is always accompanied by a `Write` effect.
+          
+        - A `Read` effect from this resource means accessing the mapping.
+        
+        - A `Free` effect on this resource indicates the removal of the mapping
+          entry, typically after a transformation that modifies the payload IR
+          operations associated with one of the transform IR operation's
+          operands. It is always accompanied by a `Read` effect.
+
+      * `PayloadIRResource` - side effect resource corresponding to the payload
+        IR itself.
+
+        - A `Read` effect from this resource means accessing the payload IR.
+
+        - A `Write` effect on this resource means mutating the payload IR. It is
+          almost always accompanied by a `Read`.
+
+    The typical flow of values in the transform IR is as follows. Most
+    operations produce new transform IR values and immediately associate them
+    with a list of payload IR operations. This corresponds to `Allocate` and
+    `Write` effects on the `TransformMappingResource`, and often requires at
+    least a `Read` effect on the `PayloadIRResource`. Transform operations that
+    only inspect the payload IR to produce new handles are usually limited to
+    these effects on their operands. Transform operations that mutate the
+    payload IR are thought to _consume_ the handles provided as operands, that
+    is have the `Read` and `Free` effects on them. As with the usual memory
+    effects, using a value after it was freed is incorrect. In case of the
+    transform IR, this value is likely associated with payload IR operations
+    that were modified or even removed by the transformation, so it is
+    meaningless to refer to them. When further transformations are desired, the
+    transform operations can return _new_ handles that can be read or consumed
+    by subsequent operations.
 
     ## Intended Use and Integrations
 
@@ -182,8 +242,18 @@ def Transform_Dialect : Dialect {
       getPDLConstraintHooks() const;
 
     private:
-      // Make addOperations available to the TransformDialectExtension class.
-      using ::mlir::Dialect::addOperations;
+      /// Registers operations specified as template parameters with this
+      /// dialect. Checks that they implement the required interfaces.
+      template <typename... OpTys>
+      void addOperationsChecked() {
+        addOperations<OpTys...>();
+
+        #ifndef NDEBUG
+        (void)std::initializer_list<int>{
+          (detail::checkImplementsTransformInterface<OpTys>(getContext()),
+           0)...};
+        #endif // NDEBUG
+      }
 
       template <typename, typename...>
       friend class TransformDialectExtension;

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td b/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td
new file mode 100644
index 0000000000000..b6106fe96c975
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td
@@ -0,0 +1,62 @@
+
+//===- TransformEffects.td - Transform side effects --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines side effects and associated resources for operations in the
+// Transform dialect and extensions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Effects on the mapping between Transform IR values and Payload IR ops.
+//===----------------------------------------------------------------------===//
+
+// Side effect resource corresponding to the mapping between transform IR values
+// and Payload IR operations.
+def TransformMappingResource
+    : Resource<"::mlir::transform::TransformMappingResource">;
+
+// Describes the creation of a new entry in the transform mapping. Should be
+// accompanied by the Write effect as the entry is immediately initialized by
+// any reasonable transform operation.
+def TransformMappingAlloc : MemAlloc<TransformMappingResource>;
+
+// Describes the removal of an entry in the transform mapping. Typically
+// accompanied by the Read effect.
+def TransformMappingFree : MemFree<TransformMappingResource>;
+
+// Describes the access to the mapping. Read-only accesses can be reordered.
+def TransformMappingRead : MemRead<TransformMappingResource>;
+
+// Describes a modification of an existing entry in the mapping. It is rarely
+// used alone, and is mostly accompanied by the Allocate effect.
+def TransformMappingWrite : MemWrite<TransformMappingResource>;
+
+//===----------------------------------------------------------------------===//
+// Effects on Payload IR.
+//===----------------------------------------------------------------------===//
+
+// Side effect resource corresponding to the Payload IR itself.
+def PayloadIRResource : Resource<"::mlir::transform::PayloadIRResource">;
+
+// Corresponds to the read-only access to the Payload IR through some operation
+// handles in the Transform IR.
+def PayloadIRRead : MemRead<PayloadIRResource>;
+
+// Corresponds to the mutation of the Payload IR through an operation handle in
+// the Transform IR. Should be accompanied by the Read effect for most transform
+// operations (only a complete overwrite of the root op of the Payload IR is a
+// write-only modification).
+def PayloadIRWrite : MemWrite<PayloadIRResource>;
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 49d3bcd8be454..1d0c6baa36383 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -11,6 +11,8 @@
 
 #include "mlir/IR/OpDefinition.h"
 
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
 namespace mlir {
 namespace transform {
 
@@ -376,6 +378,35 @@ class PossibleTopLevelTransformOpTrait
   }
 };
 
+/// Side effect resource corresponding to the mapping between Transform IR
+/// values and Payload IR operations. An Allocate effect from this resource
+/// means creating a new mapping entry, it is always accompanied by a Write
+/// effect. A Read effect from this resource means accessing the mapping. A Free
+/// effect on this resource indicates the removal of the mapping entry,
+/// typically after a transformation that modifies the Payload IR operations
+/// associated with one of the Transform IR operation's operands. It is always
+/// accompanied by a Read effect. Read-after-Free and double-Free are not
+/// allowed (they would be problematic with "regular" memory effects too) as
+/// they indicate an attempt to access Payload IR operations that have been
+/// modified, potentially erased, by the previous transformations.
+// TODO: consider custom effects if these are not enabling generic passes such
+// as CSE/DCE to work.
+struct TransformMappingResource
+    : public SideEffects::Resource::Base<TransformMappingResource> {
+  StringRef getName() override { return "transform.mapping"; }
+};
+
+/// Side effect resource corresponding to the Payload IR itself. Only Read and
+/// Write effects are expected on this resource, with Write always accompanied
+/// by a Read (short of fully replacing the top-level Payload IR operation, one
+/// cannot modify the Payload IR without reading it first). This is intended
+/// to disallow reordering of Transform IR operations that mutate the Payload IR
+/// while still allowing the reordering of those that only access it.
+struct PayloadIRResource
+    : public SideEffects::Resource::Base<PayloadIRResource> {
+  StringRef getName() override { return "transform.payload_ir"; }
+};
+
 } // namespace transform
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 489197fe46f60..2548492767df3 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -13,6 +13,7 @@ include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 
 def PDLMatchOp : TransformDialectOp<"pdl_match",
@@ -35,29 +36,30 @@ def PDLMatchOp : TransformDialectOp<"pdl_match",
     could not be looked up or compiled.
   }];
 
-  let arguments = (ins PDL_Operation:$root, SymbolRefAttr:$pattern_name);
-  let results = (outs PDL_Operation:$matched);
+  let arguments = (ins
+    Arg<PDL_Operation, "Payload IR scope to match within",
+        [TransformMappingRead, PayloadIRRead]>:$root, 
+    SymbolRefAttr:$pattern_name);
+  let results = (outs
+    Res<PDL_Operation, "Handle to the matched Payload IR ops",
+        [TransformMappingAlloc, TransformMappingWrite]>:$matched);
 
   let assemblyFormat = "$pattern_name `in` $root attr-dict";
 }
 
 def SequenceOp : TransformDialectOp<"sequence",
-    [DeclareOpInterfaceMethods<TransformOpInterface>, OpAsmOpInterface,
-     PossibleTopLevelTransformOpTrait,
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
      SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
   let summary = "Contains a sequence of other transform ops to apply";
   let description = [{
     The transformations indicated by the sequence are applied in order of their
     appearance. Each value produced by a transformation within the sequence
     corresponds to an operation or a group of operations in the payload IR.
-    Each value may be used at most once by another transformation operation as
-    the transformation is likely to replace the transformed operation with
-    another operation or a group thereof. In such cases, the transformation
-    operation is expected to produce a new value to denote the newly produced
-    operations that can be transformed further. During application, if any
-    transformation in the sequence fails, the entire sequence fails immediately
-    leaving the payload IR in potentially invalid state, i.e., this operation
-    offers no transformation rollback capabilities.
+    During application, if any transformation in the sequence fails, the entire
+    sequence fails immediately leaving the payload IR in potentially invalid
+    state, i.e., this operation offers no transformation rollback capabilities.
 
     The entry block of this operation has a single argument that maps to either
     the operand if provided or the top-level container operation of the payload
@@ -83,7 +85,8 @@ def SequenceOp : TransformDialectOp<"sequence",
 
 def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
     [DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
-     OpAsmOpInterface, PossibleTopLevelTransformOpTrait, SymbolTable]> {
+     OpAsmOpInterface, PossibleTopLevelTransformOpTrait, RecursiveSideEffects,
+     SymbolTable]> {
   let summary = "Contains PDL patterns available for use in transforms";
   let description = [{
     This op contains a set of named PDL patterns that are available for the
@@ -120,7 +123,9 @@ def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
     ops associated with its operand when provided.
   }];
 
-  let arguments = (ins Optional<PDL_Operation>:$root);
+  let arguments = (ins
+    Arg<Optional<PDL_Operation>, "Root operation of the Payload IR",
+        [TransformMappingRead]>:$root);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat = "($root^)? attr-dict-with-keyword regions";
 
@@ -140,7 +145,9 @@ def YieldOp : TransformDialectOp<"yield", [Terminator]> {
     any transformation on the payload IR and is used for flow purposes only.
   }];
 
-  let arguments = (ins Variadic<AnyType>:$operands);
+  let arguments = (ins 
+    Arg<Variadic<AnyType>, "Opration handles yielded back to the parent",
+        [TransformMappingRead]>:$operands);
   let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";
 
   let builders = [

diff  --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index a5ac053c91195..dd22d0d10c017 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRTransformDialect
   MLIRPDL
   MLIRPDLInterp
   MLIRRewrite
+  MLIRSideEffectInterfaces
   )

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 513f8736237a4..e4e09b9ba1536 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -14,7 +14,9 @@ using namespace mlir;
 #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
 
 void transform::TransformDialect::initialize() {
-  addOperations<
+  // Using the checked version to enable the same assertions as for the ops from
+  // extensions.
+  addOperationsChecked<
 #define GET_OP_LIST
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 2c9a2870dd616..c96f573363be0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -104,8 +104,22 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   if (failed(transform.apply(results, *this)))
     return failure();
 
-  for (Value target : transform->getOperands())
-    removePayloadOps(target);
+  // Remove the mapping for the operand if it is consumed by the operation. This
+  // allows us to catch use-after-free with assertions later on.
+  auto memEffectInterface =
+      cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance, 2> effects;
+  for (Value target : transform->getOperands()) {
+    effects.clear();
+    memEffectInterface.getEffectsOnValue(target, effects);
+    if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
+          return isa<transform::TransformMappingResource>(
+                     effect.getResource()) &&
+                 isa<MemoryEffects::Free>(effect.getEffect());
+        })) {
+      removePayloadOps(target);
+    }
+  }
 
   for (auto &en : llvm::enumerate(transform->getResults())) {
     assert(en.value().getDefiningOp() == transform.getOperation() &&

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index c68ba11e3a06f..3dcf41447b77b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -164,7 +164,58 @@ LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
   return success();
 }
 
+/// Returns `true` if the given use may consume the Transform IR handle, i.e.,
+/// if its owner may have a Free effect on TransformMappingResource for it.
+static bool isValueUsePotentialConsumer(OpOperand &use) {
+  // Without the interface, conservatively assume the handle may be consumed.
+  auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
+  if (!memEffectInterface)
+    return true;
+
+  SmallVector<MemoryEffects::EffectInstance, 2> effects;
+  memEffectInterface.getEffectsOnValue(use.get(), effects);
+  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
+    return isa<transform::TransformMappingResource>(effect.getResource()) &&
+           isa<MemoryEffects::Free>(effect.getEffect());
+  });
+}
+
+/// Checks that `value` has at most one use that may consume it. On violation,
+/// emits the diagnostic created by `reportError`, with notes pointing at the
+/// two conflicting uses, and returns failure.
+static LogicalResult
+checkDoubleConsume(Value value,
+                   function_ref<InFlightDiagnostic()> reportError) {
+  OpOperand *potentialConsumer = nullptr;
+  for (OpOperand &use : value.getUses()) {
+    if (!isValueUsePotentialConsumer(use))
+      continue;
+    if (!potentialConsumer) {
+      potentialConsumer = &use;
+      continue;
+    }
+    InFlightDiagnostic diag = reportError()
+                              << " has more than one potential consumer";
+    diag.attachNote(potentialConsumer->getOwner()->getLoc())
+        << "used here as operand #" << potentialConsumer->getOperandNumber();
+    diag.attachNote(use.getOwner()->getLoc())
+        << "used here as operand #" << use.getOperandNumber();
+    return diag;
+  }
+  return success();
+}
+
 LogicalResult transform::SequenceOp::verify() {
+  // Check if the block argument has more than one consuming use.
+  for (BlockArgument argument : getBodyBlock()->getArguments()) {
+    auto report = [&]() {
+      return (emitOpError() << "block argument #" << argument.getArgNumber());
+    };
+    if (failed(checkDoubleConsume(argument, report)))
+      return failure();
+  }
+
+  // Check properties of the nested operations they cannot check themselves.
   for (Operation &child : *getBodyBlock()) {
     if (!isa<TransformOpInterface>(child) &&
         &child != &getBodyBlock()->back()) {
@@ -176,16 +227,11 @@ LogicalResult transform::SequenceOp::verify() {
     }
 
     for (OpResult result : child.getResults()) {
-      if (llvm::hasNItemsOrLess(result.getUses(), 1))
-        continue;
-      InFlightDiagnostic diag = child.emitError()
-                                << "result #" << result.getResultNumber()
-                                << " has more than one use";
-      for (OpOperand &use : result.getUses()) {
-        diag.attachNote(use.getOwner()->getLoc())
-            << "used here as operand #" << use.getOperandNumber();
-      }
-      return diag;
+      auto report = [&]() {
+        return (child.emitError() << "result #" << result.getResultNumber());
+      };
+      if (failed(checkDoubleConsume(result, report)))
+        return failure();
     }
   }
 
@@ -200,6 +246,49 @@ LogicalResult transform::SequenceOp::verify() {
   return success();
 }
 
+/// Reports the memory effects of the sequence: a Read of the mapping for the
+/// root operand (which may be absent), Allocate and Write on the mapping for
+/// each result, and the effects of the operations in the body. When the root
+/// operand is present, only the effects of nested operations on the entry
+/// block argument are reported, remapped onto the root operand.
+void transform::SequenceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  auto *mappingResource = TransformMappingResource::get();
+  effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
+
+  for (Value result : getResults()) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), result,
+                         mappingResource);
+    effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
+  }
+
+  // Collect the effects of the nested operations. Without a root operand, all
+  // of their effects are forwarded as is. With a root operand, only the
+  // effects on the entry block argument are carried over as effects on the
+  // operand: this is the same value just remapped.
+  for (Operation &op : *getBodyBlock()) {
+    auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
+    if (!iface) {
+      // TODO: fill all possible effects; or require ops to actually implement
+      // the memory effect interface always.
+      assert(false && "nested op must implement MemoryEffectOpInterface");
+      // Do not dereference the null interface in builds without assertions.
+      continue;
+    }
+
+    if (!getRoot()) {
+      // Append directly to the output list; no temporary storage is needed.
+      iface.getEffects(effects);
+      continue;
+    }
+
+    SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
+    iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
+    for (const auto &effect : nestedEffects)
+      effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 61ed760d700f1..ade63ebfc4596 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -26,24 +26,6 @@ transform.sequence {
 
 // -----
 
-transform.sequence {
-^bb0(%arg0: !pdl.operation):
-  // expected-error @below {{result #0 has more than one use}}
-  %0 = transform.sequence %arg0 {
-  ^bb1(%arg1: !pdl.operation):
-  } : !pdl.operation
-  // expected-note @below {{used here as operand #0}}
-  transform.sequence %0 {
-  ^bb2(%arg2: !pdl.operation):
-  }
-  // expected-note @below {{used here as operand #0}}
-  transform.sequence %0 {
-  ^bb3(%arg3: !pdl.operation):
-  }
-}
-
-// -----
-
 // expected-error @below {{expects the types of the terminator operands to match the types of the resul}}
 %0 = transform.sequence {
 ^bb0(%arg0: !pdl.operation):
@@ -111,3 +93,63 @@ transform.with_pdl_patterns {
 ^bb1:
   "test.potential_terminator"() : () -> ()
 }) : () -> ()
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_param_or_forward_operand 42
+  // expected-note @below {{used here as operand #0}}
+  test_consume_operand_if_matches_param_or_fail %0[42]
+  // expected-note @below {{used here as operand #0}}
+  test_consume_operand_if_matches_param_or_fail %0[42]
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_param_or_forward_operand 42
+  // expected-note @below {{used here as operand #0}}
+  test_consume_operand_if_matches_param_or_fail %0[42]
+  // expected-note @below {{used here as operand #0}}
+  transform.sequence %0 {
+  ^bb1(%arg1: !pdl.operation):
+    test_consume_operand_if_matches_param_or_fail %arg1[42]
+  }
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_param_or_forward_operand 42
+  // expected-note @below {{used here as operand #0}}
+  test_consume_operand_if_matches_param_or_fail %0[42]
+  transform.sequence %0 {
+  ^bb1(%arg1: !pdl.operation):
+    // expected-note @below {{used here as operand #0}}
+    test_consume_operand_if_matches_param_or_fail %0[42]
+  }
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_param_or_forward_operand 42
+  // expected-note @below {{used here as operand #0}}
+  test_consume_operand_if_matches_param_or_fail %0[42]
+  // expected-note @below {{used here as operand #0}}
+  transform.sequence %0 {
+  ^bb1(%arg1: !pdl.operation):
+    transform.sequence %arg1 {
+    ^bb2(%arg2: !pdl.operation):
+      test_consume_operand_if_matches_param_or_fail %arg2[42]
+    }
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index 34ee62e0bbc75..e9e99de310b72 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -30,3 +30,22 @@ transform.sequence {
   ^bb1(%arg1: !pdl.operation):
   }
 }
+
+// Using the same value multiple times without consuming it is fine.
+// CHECK: transform.sequence
+// CHECK: %[[V:.+]] = sequence
+// CHECK: sequence %[[V]]
+// CHECK: sequence %[[V]]
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  %0 = transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    yield %arg1 : !pdl.operation
+  } : !pdl.operation
+  transform.sequence %0 {
+  ^bb2(%arg2: !pdl.operation):
+  }
+  transform.sequence %0 {
+  ^bb3(%arg3: !pdl.operation):
+  }
+}

diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index c3bbb5a66c61e..a58f12d1176ec 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -25,7 +25,8 @@ namespace {
 /// applied. This op is defined in C++ to test that C++ definitions also work
 /// for op injection into the Transform dialect.
 class TestTransformOp
-    : public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
+    : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
+                MemoryEffectOpInterface::Trait> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
 
@@ -63,6 +64,9 @@ class TestTransformOp
     if (getMessage())
       printer << " " << getMessage();
   }
+
+  // No side effects.
+  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 };
 
 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
@@ -72,7 +76,8 @@ class TestTransformOp
 class TestTransformUnrestrictedOpNoInterface
     : public Op<TestTransformUnrestrictedOpNoInterface,
                 transform::PossibleTopLevelTransformOpTrait,
-                transform::TransformOpInterface::Trait> {
+                transform::TransformOpInterface::Trait,
+                MemoryEffectOpInterface::Trait> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestTransformUnrestrictedOpNoInterface)
@@ -90,6 +95,9 @@ class TestTransformUnrestrictedOpNoInterface
                       transform::TransformState &state) {
     return success();
   }
+
+  // No side effects.
+  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 };
 } // namespace
 

diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 4596780ac131e..6fe34ae064232 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -16,15 +16,19 @@
 
 include "mlir/IR/OpBase.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 
 def TestProduceParamOrForwardOperandOp
   : Op<Transform_Dialect, "test_produce_param_or_forward_operand",
        [DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let arguments = (ins Optional<PDL_Operation>:$operand,
-                       OptionalAttr<I64Attr>:$parameter);
-  let results = (outs PDL_Operation:$res);
+  let arguments = (ins 
+     Arg<Optional<PDL_Operation>, "", [TransformMappingRead]>:$operand,
+     OptionalAttr<I64Attr>:$parameter);
+  let results = (outs 
+     Res<PDL_Operation, "",
+         [TransformMappingAlloc, TransformMappingWrite]>:$res);
   let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict";
   let cppNamespace = "::mlir::test";
   let hasVerifier = 1;
@@ -33,7 +37,10 @@ def TestProduceParamOrForwardOperandOp
 def TestConsumeOperandIfMatchesParamOrFail
   : Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
        [DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let arguments = (ins PDL_Operation:$operand, I64Attr:$parameter);
+  let arguments = (ins 
+    Arg<PDL_Operation, "",
+        [TransformMappingWrite, TransformMappingFree]>:$operand,
+    I64Attr:$parameter);
   let assemblyFormat = "$operand `[` $parameter `]` attr-dict";
   let cppNamespace = "::mlir::test";
 }
@@ -41,7 +48,10 @@ def TestConsumeOperandIfMatchesParamOrFail
 def TestPrintRemarkAtOperandOp
   : Op<Transform_Dialect, "test_print_remark_at_operand",
        [DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let arguments = (ins PDL_Operation:$operand, StrAttr:$message);
+  let arguments = (ins 
+    Arg<PDL_Operation, "",
+        [TransformMappingRead, PayloadIRRead]>:$operand,
+    StrAttr:$message);
   let assemblyFormat = "$operand `,` $message attr-dict";
   let cppNamespace = "::mlir::test";
 }

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 58b63c464f9d4..fbf23bab8d2a3 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7682,6 +7682,7 @@ td_library(
     deps = [
         ":OpBaseTdFiles",
         ":PDLDialectTdFiles",
+        ":SideEffectInterfacesTdFiles",
     ],
 )
 
@@ -7756,6 +7757,7 @@ cc_library(
         ":PDLDialect",
         ":PDLInterpDialect",
         ":Rewrite",
+        ":SideEffectInterfaces",
         ":Support",
         ":TransformDialectIncGen",
         ":TransformDialectInterfacesIncGen",


        


More information about the Mlir-commits mailing list