[Mlir-commits] [mlir] e3890b7 - [mlir] Introduce transform.alternatives op

Alex Zinenko llvmlistbot at llvm.org
Tue Jun 14 08:51:38 PDT 2022


Author: Alex Zinenko
Date: 2022-06-14T17:51:30+02:00
New Revision: e3890b7fd65595685fa9e781c78bf3f4d4231e32

URL: https://github.com/llvm/llvm-project/commit/e3890b7fd65595685fa9e781c78bf3f4d4231e32
DIFF: https://github.com/llvm/llvm-project/commit/e3890b7fd65595685fa9e781c78bf3f4d4231e32.diff

LOG: [mlir] Introduce transform.alternatives op

Introduce a transform dialect op that allows one to attempt different
transformation sequences on the same piece of payload IR until one of them
succeeds. This op fundamentally expands the scope of possibilities in the
transform dialect that, until now, could only propagate transformation failure,
at least using in-tree operations. This requires a more detailed specification
of the execution model for the transform dialect that now indicates how failure
is handled and propagated.

Transformations described by transform operations now have tri-state results,
with some errors being fundamentally irrecoverable (e.g., generating malformed
IR) and some others being recoverable by containing ops. Existing transform ops
directly implementing the `apply` interface method are updated to return the
new tri-state result. Transform ops with the `TransformEachOpTrait` are currently
considered to produce only irrecoverable failures and will be updated
separately.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D127724

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
    mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 598a0649b524a..1a36e2f3d144c 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -151,6 +151,38 @@ def Transform_Dialect : Dialect {
     transform operations can return _new_ handles that can be read or consumed
     by subsequent operations.
 
+    ## Execution Model
+
+    The transformation starts at the specified top-level transform IR operation
+    and applies to some payload IR scope, identified by the payload IR op that
+    contains the IR to transform. It is the responsibility of the user to
+    properly select the scope and/or to prevent the transformations from
+    modifying the IR outside of the given scope. The top-level transform IR
+    operation may contain nested transform ops and run them in the desired order.
+
+    Transformation application functions produce a tri-state status:
+
+    - success;
+    - recoverable (silencable) failure;
+    - irrecoverable failure.
+
+    Transformation container operations may intercept recoverable failures and
+    perform the required recovery steps, thus succeeding themselves. On
+    the other hand, they must propagate irrecoverable failures. For such
+    failures, the diagnostics are emitted immediately, whereas their emission is
+    postponed for recoverable failures. Transformation container operations may
+    also fail to recover from a theoretically recoverable failure, in which case
+    they are expected to emit the diagnostic and turn the failure into an
+    irrecoverable one. A recoverable failure produced by applying the top-level
+    transform IR operation is considered irrecoverable.
+
+    Transformation container operations are allowed to "step over" some nested
+    operations if the application of some previous operation produced a failure.
+    This can be conceptually thought of as having a global "recoverable error
+    register" that is read/write accessed by each transform operation as a side
+    effect. The transformation is skipped if the register already contains an
+    error description, and the control flow proceeds to the following operation.
+
     ## Handle Invalidation
 
     The execution model of the transform dialect expects that a payload IR

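For readers of the execution model above, the following is a minimal, MLIR-free C++ sketch of the tri-state status and the "recoverable error register", written under stated assumptions. `Status`, `Transform`, `runSequence` and `runTopLevel` are hypothetical names, nested transforms are modeled as plain callables, and the container performs no recovery of its own; it is not the dialect's actual interpreter.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical model of the tri-state status: success, recoverable
// (silenceable) failure, irrecoverable (definite) failure.
enum class Status { Success, Silenceable, Definite };

// A nested transform is modeled as a callable reporting its status.
using Transform = std::function<Status()>;

// Final status of a container plus how many nested transforms actually ran.
struct RunResult {
  Status status;
  std::size_t applied;
};

// Runs nested transforms in order against a "recoverable error register".
// Once the register holds a silenceable failure, the remaining transforms are
// stepped over; a definite failure is propagated immediately.
RunResult runSequence(const std::vector<Transform> &body) {
  bool errorRegister = false; // true once a silenceable failure occurred
  std::size_t applied = 0;
  for (const Transform &transform : body) {
    if (errorRegister)
      continue; // skip: the register already contains an error
    Status status = transform();
    ++applied;
    if (status == Status::Definite)
      return {Status::Definite, applied};
    if (status == Status::Silenceable)
      errorRegister = true;
  }
  return {errorRegister ? Status::Silenceable : Status::Success, applied};
}

// At the top level, a leftover silenceable failure counts as irrecoverable.
Status runTopLevel(const std::vector<Transform> &body) {
  RunResult result = runSequence(body);
  return result.status == Status::Silenceable ? Status::Definite
                                              : result.status;
}
```

In this model a silenceable failure only makes the container skip its remaining ops and report the failure upward, whereas a definite failure returns immediately; `runTopLevel` then promotes a leftover silenceable failure to a definite one.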
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index c42474c671e97..df33f954f6820 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -14,6 +14,129 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
+
+/// The result of a transform IR operation application. This can have one of the
+/// three states:
+///   - success;
+///   - silencable (recoverable) failure with yet-unreported diagnostic;
+///   - definite failure.
+/// Silencable failure is intended to communicate information about
+/// transformations that did not apply but in a way that supports recovery,
+/// for example, they did not modify the payload IR or modified it in some
+/// predictable way. They are associated with a Diagnostic that provides more
+/// details on the failure. Silencable failure can be discarded, turning the
+/// result into success, or "reported", emitting the diagnostic and turning the
+/// result into definite failure. Transform IR operations containing other
+/// operations are allowed to do either with the results of the nested
+/// transformations, but must propagate definite failures as their diagnostics
+/// have been already reported to the user.
+class LLVM_NODISCARD DiagnosedSilencableFailure {
+public:
+  explicit DiagnosedSilencableFailure(LogicalResult result) : result(result) {}
+  DiagnosedSilencableFailure(const DiagnosedSilencableFailure &) = delete;
+  DiagnosedSilencableFailure &
+  operator=(const DiagnosedSilencableFailure &) = delete;
+  DiagnosedSilencableFailure(DiagnosedSilencableFailure &&) = default;
+  DiagnosedSilencableFailure &
+  operator=(DiagnosedSilencableFailure &&) = default;
+
+  /// Constructs a DiagnosedSilencableFailure in the success state.
+  static DiagnosedSilencableFailure success() {
+    return DiagnosedSilencableFailure(::mlir::success());
+  }
+
+  /// Constructs a DiagnosedSilencableFailure in the failure state. Typically,
+  /// a diagnostic has been emitted before this.
+  static DiagnosedSilencableFailure definiteFailure() {
+    return DiagnosedSilencableFailure(::mlir::failure());
+  }
+
+  /// Constructs a DiagnosedSilencableFailure in the silencable failure state,
+  /// ready to emit the given diagnostic. This is considered a failure
+  /// regardless of the diagnostic severity.
+  static DiagnosedSilencableFailure silencableFailure(Diagnostic &&diag) {
+    return DiagnosedSilencableFailure(std::forward<Diagnostic>(diag));
+  }
+
+  /// Converts all kinds of failure into a LogicalResult failure, emitting the
+  /// diagnostic if necessary. Must not be called more than once.
+  LogicalResult checkAndReport() {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    assert(!reported && "attempting to report a diagnostic more than once");
+    reported = true;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+    if (diagnostic) {
+      diagnostic->getLocation().getContext()->getDiagEngine().emit(
+          std::move(*diagnostic));
+      diagnostic.reset();
+      result = ::mlir::failure();
+    }
+    return result;
+  }
+
+  /// Returns `true` if this is a silencable failure.
+  bool isSilencableFailure() const { return diagnostic.hasValue(); }
+
+  /// Returns `true` if this is a success.
+  bool succeeded() const {
+    return !diagnostic.hasValue() && ::mlir::succeeded(result);
+  }
+
+  /// Returns the diagnostic message without emitting it. Expects this object
+  /// to be a silencable failure.
+  std::string getMessage() const { return diagnostic->str(); }
+
+  /// Converts silencable failure into LogicalResult success without reporting
+  /// the diagnostic, preserves the other states.
+  LogicalResult silence() {
+    if (diagnostic) {
+      diagnostic.reset();
+      result = ::mlir::success();
+    }
+    return result;
+  }
+
+  /// Streams the given values into the diagnostic. Expects this object to be a
+  /// silencable failure.
+  template <typename T> DiagnosedSilencableFailure &operator<<(T &&value) & {
+    assert(isSilencableFailure() &&
+           "can only append output in silencable failure state");
+    *diagnostic << std::forward<T>(value);
+    return *this;
+  }
+  template <typename T> DiagnosedSilencableFailure &&operator<<(T &&value) && {
+    return std::move(this->operator<<(std::forward<T>(value)));
+  }
+
+  /// Attaches a note to the diagnostic. Expects this object to be a silencable
+  /// failure.
+  Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
+    assert(isSilencableFailure() &&
+           "can only attach notes to silencable failures");
+    return diagnostic->attachNote(loc);
+  }
+
+private:
+  explicit DiagnosedSilencableFailure(Diagnostic &&diagnostic)
+      : diagnostic(std::move(diagnostic)), result(failure()) {}
+
+  /// The diagnostic associated with this object. If present, the object is
+  /// considered to be in the silencable failure state regardless of the
+  /// `result` field.
+  Optional<Diagnostic> diagnostic;
+
+  /// The "definite" logical state, either success or failure. Ignored if the
+  /// diagnostic message is present.
+  LogicalResult result;
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  /// Whether the associated diagnostic has been reported. Diagnostic reporting
+  /// consumes the diagnostic, so we need a mechanism to differentiate a
+  /// reported diagnostic from a state where it was never created.
+  bool reported = false;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+};
+
 namespace transform {
 
 class TransformOpInterface;
@@ -103,7 +226,7 @@ class TransformState {
 
   /// Applies the transformation specified by the given transform op and updates
   /// the state accordingly.
-  LogicalResult applyTransform(TransformOpInterface transform);
+  DiagnosedSilencableFailure applyTransform(TransformOpInterface transform);
 
   /// Records the mapping between a block argument in the transform IR and a
   /// list of operations in the payload IR. The arguments must be defined in
@@ -401,7 +524,7 @@ namespace detail {
 /// the payload IR, depending on what is available in the context.
 LogicalResult
 mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
-                                             Operation *op);
+                                             Operation *op, unsigned region);
 
 /// Verification hook for PossibleTopLevelTransformOpTrait.
 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
@@ -411,7 +534,7 @@ LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
 /// can be standalone top-level transforms. Such operations typically contain
 /// other Transform dialect operations that can be executed following some
 /// control flow logic specific to the current operation. The operations with
-/// this trait are expected to have exactly one single-block region with one
+/// this trait are expected to have at least one single-block region with one
 /// argument of PDL Operation type. The operations are also expected to be valid
 /// without operands, in which case they are considered top-level, and with one
 /// or more arguments, in which case they are considered nested. Top-level
@@ -430,16 +553,18 @@ class PossibleTopLevelTransformOpTrait
     return detail::verifyPossibleTopLevelTransformOpTrait(op);
   }
 
-  /// Returns the single block of the op's only region.
-  Block *getBodyBlock() { return &this->getOperation()->getRegion(0).front(); }
+  /// Returns the single block of the given region.
+  Block *getBodyBlock(unsigned region = 0) {
+    return &this->getOperation()->getRegion(region).front();
+  }
 
-  /// Sets up the mapping between the entry block of the only region of this op
+  /// Sets up the mapping between the entry block of the given region of this op
   /// and the relevant list of Payload IR operations in the given state. The
   /// state is expected to be already scoped at the region of this operation.
   /// Returns failure if the mapping failed, e.g., the value is already mapped.
-  LogicalResult mapBlockArguments(TransformState &state) {
+  LogicalResult mapBlockArguments(TransformState &state, unsigned region = 0) {
     return detail::mapPossibleTopLevelTransformOpBlockArguments(
-        state, this->getOperation());
+        state, this->getOperation(), region);
   }
 };
 
@@ -461,8 +586,8 @@ class TransformEachOpTrait
   /// Calls `applyToOne` for every payload operation associated with the operand
   /// of this transform IR op. If `applyToOne` returns ops, associates them with
   /// the result of this transform op.
-  LogicalResult apply(TransformResults &transformResults,
-                      TransformState &state);
+  DiagnosedSilencableFailure apply(TransformResults &transformResults,
+                                   TransformState &state);
 
   /// Checks that the op matches the expectations of this trait.
   static LogicalResult verifyTrait(Operation *op);
@@ -497,22 +622,22 @@ struct PayloadIRResource
   StringRef getName() override { return "transform.payload_ir"; }
 };
 
-/// Trait implementing the MemoryEffectOpInterface for single-operand operations
-/// that "consume" their operand and produce a new result.
+/// Trait implementing the MemoryEffectOpInterface for operations that "consume"
+/// their operands and produce new results.
 template <typename OpTy>
 class FunctionalStyleTransformOpTrait
     : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
 public:
-  /// This op "consumes" the operand by reading and freeing it, "produces" the
-  /// results by allocating and writing it and reads/writes the payload IR in
-  /// the process.
+  /// This op "consumes" the operands by reading and freeing them, "produces"
+  /// the results by allocating and writing them and reads/writes the payload IR
+  /// in the process.
   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-    effects.emplace_back(MemoryEffects::Read::get(),
-                         this->getOperation()->getOperand(0),
-                         TransformMappingResource::get());
-    effects.emplace_back(MemoryEffects::Free::get(),
-                         this->getOperation()->getOperand(0),
-                         TransformMappingResource::get());
+    for (Value operand : this->getOperation()->getOperands()) {
+      effects.emplace_back(MemoryEffects::Read::get(), operand,
+                           TransformMappingResource::get());
+      effects.emplace_back(MemoryEffects::Free::get(), operand,
+                           TransformMappingResource::get());
+    }
     for (Value result : this->getOperation()->getResults()) {
       effects.emplace_back(MemoryEffects::Allocate::get(), result,
                            TransformMappingResource::get());
@@ -525,8 +650,6 @@ class FunctionalStyleTransformOpTrait
 
   /// Checks that the op matches the expectations of this trait.
   static LogicalResult verifyTrait(Operation *op) {
-    static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
-                  "expected single-operand op");
     if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
       op->emitError()
           << "FunctionalStyleTransformOpTrait should only be attached to ops "
@@ -612,12 +735,12 @@ appendTransformResultToVector(Ty result,
 /// where OpTy is either
 ///   - Operation *, in which case the transform is always applied;
 ///   - a concrete Op class, in which case a check is performed whether
-///   `targets` contains operations of the same class and a failure is reported
-///   if it does not.
+///   `targets` contains operations of the same class and a silencable failure
+///   is reported if it does not.
 template <typename FnTy>
-LogicalResult applyTransformToEach(ArrayRef<Operation *> targets,
-                                   SmallVectorImpl<Operation *> &results,
-                                   FnTy transform) {
+DiagnosedSilencableFailure
+applyTransformToEach(ArrayRef<Operation *> targets,
+                     SmallVectorImpl<Operation *> &results, FnTy transform) {
   using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
   static_assert(std::is_convertible<OpTy, Operation *>::value,
                 "expected transform function to take an operation");
@@ -627,37 +750,43 @@ LogicalResult applyTransformToEach(ArrayRef<Operation *> targets,
                 "FailureOr<convertible-to-Operation*>");
   for (Operation *target : targets) {
     auto specificOp = dyn_cast<OpTy>(target);
-    if (!specificOp)
-      return failure();
+    if (!specificOp) {
+      Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+      diag << "attempted to apply transform to the wrong op kind";
+      return DiagnosedSilencableFailure::silencableFailure(std::move(diag));
+    }
 
     auto result = transform(specificOp);
     if (failed(appendTransformResultToVector(result, results)))
-      return failure();
+      return DiagnosedSilencableFailure::definiteFailure();
   }
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 } // namespace detail
 } // namespace transform
 } // namespace mlir
 
 template <typename OpTy>
-mlir::LogicalResult mlir::transform::TransformEachOpTrait<OpTy>::apply(
+mlir::DiagnosedSilencableFailure
+mlir::transform::TransformEachOpTrait<OpTy>::apply(
     TransformResults &transformResults, TransformState &state) {
   using TransformOpType = typename llvm::function_traits<
       decltype(&OpTy::applyToOne)>::template arg_t<0>;
   ArrayRef<Operation *> targets =
       state.getPayloadOps(this->getOperation()->getOperand(0));
   SmallVector<Operation *> results;
-  if (failed(detail::applyTransformToEach(
-          targets, results, [&](TransformOpType specificOp) {
-            return static_cast<OpTy *>(this)->applyToOne(specificOp);
-          })))
-    return failure();
+  DiagnosedSilencableFailure result = detail::applyTransformToEach(
+      targets, results, [&](TransformOpType specificOp) {
+        return static_cast<OpTy *>(this)->applyToOne(specificOp);
+      });
+  if (!result.succeeded())
+    return result;
+
   if (OpTy::template hasTrait<OpTrait::OneResult>()) {
     transformResults.set(
         this->getOperation()->getResult(0).template cast<OpResult>(), results);
   }
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 template <typename OpTy>

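The three states of `DiagnosedSilencableFailure` and its `checkAndReport`/`silence` transitions can be mirrored in a self-contained sketch. This is an assumption-laden simplification, not the class from the patch: `TriStateResult` and `reportedDiagnostics` are hypothetical, the diagnostic is a plain `std::string`, "emitting" appends to a vector, and the ABI-breaking-checks guard against double reporting is omitted.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the MLIR diagnostic engine: "emitted" messages
// are collected here instead of being printed.
inline std::vector<std::string> &reportedDiagnostics() {
  static std::vector<std::string> sink;
  return sink;
}

// MLIR-free sketch of the three states of DiagnosedSilencableFailure. A
// pending diagnostic means "silenceable failure"; otherwise `ok` holds the
// definite success/failure state.
class [[nodiscard]] TriStateResult {
public:
  static TriStateResult success() { return TriStateResult(true); }
  static TriStateResult definiteFailure() { return TriStateResult(false); }
  static TriStateResult silenceableFailure(std::string message) {
    TriStateResult result(false);
    result.diagnostic = std::move(message);
    return result;
  }

  // Move-only, as in the patch, so a pending diagnostic is never duplicated.
  TriStateResult(const TriStateResult &) = delete;
  TriStateResult &operator=(const TriStateResult &) = delete;
  TriStateResult(TriStateResult &&) = default;
  TriStateResult &operator=(TriStateResult &&) = default;

  bool isSilenceableFailure() const { return diagnostic.has_value(); }
  bool succeeded() const { return !diagnostic && ok; }

  // Emits a pending diagnostic, which makes the failure definite. Returns
  // true only on success.
  bool checkAndReport() {
    if (diagnostic) {
      reportedDiagnostics().push_back(std::move(*diagnostic));
      diagnostic.reset();
      ok = false;
    }
    return ok;
  }

  // Drops a pending diagnostic, turning it into success; a definite failure
  // stays a failure. Returns true only on success.
  bool silence() {
    if (diagnostic) {
      diagnostic.reset();
      ok = true;
    }
    return ok;
  }

private:
  explicit TriStateResult(bool isOk) : ok(isOk) {}

  std::optional<std::string> diagnostic;
  bool ok;
};
```

A containing op would call `silence()` when it recovers (for example, by trying another alternative) and `checkAndReport()` when it gives up, which is the point at which the postponed diagnostic finally reaches the user.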
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index ff85a74500725..b8b6a0aae1a65 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -38,9 +38,11 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
         accepts as arguments the object that must be populated with results of
         the current transformation and a transformation state object that can be
         used for queries, e.g., to obtain the list of operations on which the
-        transformation represented by the current op is targeted.
+        transformation represented by the current op is targeted. Returns a
+        special status object indicating whether the transformation succeeded
+        or failed, and, if it failed, whether the failure is recoverable or not.
       }],
-      /*returnType=*/"::mlir::LogicalResult",
+      /*returnType=*/"::mlir::DiagnosedSilencableFailure",
       /*name=*/"apply",
       /*arguments=*/(ins
           "::mlir::transform::TransformResults &":$transformResults,
@@ -59,6 +61,13 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
       diag.attachNote(target->getLoc()) << "attempted to apply to this op";
       return diag;
     }
+
+    /// Creates the silencable failure object with a diagnostic located at the
+    /// current operation.
+    DiagnosedSilencableFailure emitSilencableError() {
+      Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
+      return DiagnosedSilencableFailure::silencableFailure(std::move(diag));
+    }
   }];
 }
 

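The two ref-qualified `operator<<` overloads in `DiagnosedSilencableFailure` are what allow `emitSilencableError()` to be used as `return emitSilencableError() << "...";` in the transform ops, even though the object is move-only. A reduced sketch of that C++ pattern follows; `ErrorBuilder`, `makeError` and `checkTarget` are hypothetical names and the message is kept in a `std::ostringstream`.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <utility>

// Hypothetical move-only error builder illustrating the ref-qualified
// operator<< overloads used by DiagnosedSilencableFailure.
class ErrorBuilder {
public:
  ErrorBuilder() = default;
  ErrorBuilder(const ErrorBuilder &) = delete;
  ErrorBuilder(ErrorBuilder &&) = default;

  // Lvalue overload: append and return a reference for further chaining.
  template <typename T> ErrorBuilder &operator<<(T &&value) & {
    stream << std::forward<T>(value);
    return *this;
  }

  // Rvalue overload: append and keep the expression an rvalue, so a
  // temporary can be streamed into and then moved out by `return`.
  template <typename T> ErrorBuilder &&operator<<(T &&value) && {
    return std::move(this->operator<<(std::forward<T>(value)));
  }

  std::string message() const { return stream.str(); }

private:
  std::ostringstream stream;
};

// Hypothetical analogue of emitSilencableError().
ErrorBuilder makeError() { return ErrorBuilder(); }

// Mirrors the "expected ModuleOp target" check in OneShotBufferizeOp.
ErrorBuilder checkTarget(bool isModule) {
  if (!isModule)
    return makeError() << "expected " << "ModuleOp" << " target";
  return ErrorBuilder();
}
```

Without the `&&` overload, streaming into the temporary returned by `makeError()` would not compile (an `&`-qualified member cannot be called on an rvalue), and with only an unqualified overload returning an lvalue reference, the `return` would need the deleted copy constructor.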
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 28ef83eafa119..f415e0395080f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -17,6 +17,83 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 
+def AlternativesOp : TransformDialectOp<"alternatives",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+        ["getSuccessorEntryOperands", "getSuccessorRegions",
+         "getRegionInvocationBounds"]>,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     IsolatedFromAbove, PossibleTopLevelTransformOpTrait,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+  let summary = "Attempts sequences of transforms until one succeeds";
+  let description = [{
+    This op may have an arbitrary number of regions, each of which represents a
+    sequence of transform operations to be applied to the same payload IR. The
+    regions are visited in order of appearance, and transforms in them are
+    applied in their respective order of appearance. If one of these transforms
+    fails to apply, the remaining ops in the same region are skipped and the next
+    region is attempted. If all transformations in a region succeed, the
+    remaining regions are skipped and the entire "alternatives" transformation
+    succeeds. If all regions contained a failing transformation, the entire
+    "alternatives" transformation fails.
+
+    It is up to the nested operations to define which errors are "recoverable"
+    (or "silencable") and allow other alternatives to be attempted, and which
+    errors should be propagated without attempting the other alternatives.
+
+    The single operand of this operation is the scope in which the alternative
+    transformation sequences are attempted, that is, an operation in the payload
+    IR that contains all the other operations that may be modified by the
+    transformations. There is no check that the transforms are indeed scoped
+    as their "apply" methods can be arbitrarily complex. Therefore it is the
+    responsibility of the user to ensure that the transforms are scoped
+    correctly, or to produce an irrecoverable error and thus abort the execution
+    without attempting the remaining alternatives. Note that the payload IR
+    outside of the given scope is not necessarily in a valid state, or even
+    accessible to the transformation.
+    
+    The changes to the IR within the scope performed by transforms in the failed
+    alternative region are reverted before attempting the next region.
+    Practically, this is achieved by cloning the scope. Therefore it is advised
+    to limit the scope as much as possible and place the most likely
+    alternatives early in the region list. The operation is also isolated from
+    above and requires rediscovering the operations within the given scope to
+    avoid additional handle invalidation. The latter restriction may be lifted
+    in the future.
+
+    Each of the regions may yield transform IR handles. The handles of the first
+    successful alternative region are returned as the results of the
+    "alternatives" op. Therefore, each alternative region must yield the same
+    number of results, which should also match the number and the types of the
+    "alternatives" op results.
+
+    Remark: this op allows one to implement a simple "try" construct as follows:
+
+    ```mlir
+    %result = transform.alternatives %scope -> !pdl.operation {
+    ^bb0(%arg0: !pdl.operation):
+      // Try a fallible transformation.
+      %0 = transform.failible %arg0 // ...
+      // If succeeded, yield the result of the transformation.
+      transform.yield %0 : !pdl.operation
+    }, {
+    ^bb0(%arg0: !pdl.operation):
+      // Otherwise, the second alternative is tried and it always succeeds by
+      // returning the original handle.
+      transform.yield %arg0 : !pdl.operation
+    }
+    ```
+  }];
+
+  let arguments = (ins Optional<PDL_Operation>:$scope);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);
+
+  let assemblyFormat =
+    "($scope^)? (`->` type($results)^)? attr-dict-with-keyword regions";
+  let hasVerifier = 1;
+}
+
 def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {

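The control flow described for `transform.alternatives` can be approximated in plain C++. In this sketch, which is an assumption rather than the op's implementation, a `std::vector<int>` stands in for the payload scope, each region is a list of callables, "cloning the scope" is a copy of the vector, and the outcome when every alternative fails is reported as a silenceable failure (the description only says the op fails). `Scope`, `Step`, `Alternative` and `applyAlternatives` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

enum class Status { Success, Silenceable, Definite };

// Hypothetical payload "scope": a plain vector stands in for the payload op.
using Scope = std::vector<int>;
// One transform inside a region, applied to the (cloned) scope.
using Step = std::function<Status(Scope &)>;
// One alternative region: a sequence of transforms.
using Alternative = std::vector<Step>;

// Final status and, on success, the index of the alternative that was kept.
struct AltResult {
  Status status;
  std::optional<std::size_t> chosen;
};

AltResult applyAlternatives(Scope &scope,
                            const std::vector<Alternative> &alternatives) {
  for (std::size_t i = 0; i < alternatives.size(); ++i) {
    Scope clone = scope; // work on a copy so a failed attempt can be dropped
    bool failed = false;
    for (const Step &step : alternatives[i]) {
      Status status = step(clone);
      if (status == Status::Definite)
        return {Status::Definite, std::nullopt}; // must be propagated as-is
      if (status == Status::Silenceable) {
        failed = true; // skip the rest of this region
        break;
      }
    }
    if (!failed) {
      scope = std::move(clone); // commit: replace the original with the clone
      return {Status::Success, i};
    }
    // Otherwise `clone` is discarded here, reverting this region's changes.
  }
  return {Status::Silenceable, std::nullopt}; // assumption: all failed
}
```

Copying the whole scope on every attempt is what makes the advice in the description (keep the scope small, put likely alternatives first) matter.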
diff  --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 97d0bda35af73..c1f532c89e9f0 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -24,7 +24,7 @@ using namespace mlir::transform;
 // OneShotBufferizeOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
+DiagnosedSilencableFailure
 transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
                                      TransformState &state) {
   OneShotBufferizationOptions options;
@@ -39,19 +39,19 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
   for (Operation *target : payloadOps) {
     auto moduleOp = dyn_cast<ModuleOp>(target);
     if (getTargetIsModule() && !moduleOp)
-      return emitError("expected ModuleOp target");
+      return emitSilencableError() << "expected ModuleOp target";
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
-        return emitError("expected ModuleOp target");
+        return emitSilencableError() << "expected ModuleOp target";
       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
-        return emitError("bufferization failed");
+        return emitSilencableError() << "bufferization failed";
     } else {
       if (failed(bufferization::runOneShotBufferize(target, options)))
-        return emitError("bufferization failed");
+        return emitSilencableError() << "bufferization failed";
     }
   }
 
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 void transform::OneShotBufferizeOp::getEffects(

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b5cfb2ab58dd2..d239cad4b7ad0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -13,10 +13,8 @@
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -166,14 +164,14 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
   return success();
 }
 
-LogicalResult
+DiagnosedSilencableFailure
 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                          mlir::transform::TransformState &state) {
   LinalgTilingAndFusionOptions fusionOptions;
   fusionOptions.tileSizes = extractI64Array(getTileSizes());
   fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
 
-  return applyTilingToAll(
+  LogicalResult result = applyTilingToAll(
       getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
       state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
@@ -190,6 +188,8 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                                tileLoopNest->getLoopOps().end()};
         return tiledLinalgOp;
       });
+  return failed(result) ? DiagnosedSilencableFailure::definiteFailure()
+                        : DiagnosedSilencableFailure::success();
 }
 
 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
@@ -398,8 +398,9 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
 // TileOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult transform::TileOp::apply(TransformResults &transformResults,
-                                       TransformState &state) {
+DiagnosedSilencableFailure
+transform::TileOp::apply(TransformResults &transformResults,
+                         TransformState &state) {
   LinalgTilingOptions tilingOptions;
   SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
 
@@ -408,12 +409,13 @@ LogicalResult transform::TileOp::apply(TransformResults &transformResults,
   tilingOptions.setInterchange(extractUIntArray(getInterchange()));
   LinalgTilingPattern pattern(getContext(), tilingOptions);
 
-  return applyTilingToAll(getOperation(), getTarget(), tileSizes,
-                          transformResults, state, [&](LinalgOp linalgOp) {
-                            SimpleRewriter rewriter(linalgOp.getContext());
-                            return pattern.returningMatchAndRewrite(linalgOp,
-                                                                    rewriter);
-                          });
+  LogicalResult result = applyTilingToAll(
+      getOperation(), getTarget(), tileSizes, transformResults, state,
+      [&](LinalgOp linalgOp) {
+        SimpleRewriter rewriter(linalgOp.getContext());
+        return pattern.returningMatchAndRewrite(linalgOp, rewriter);
+      });
+  return DiagnosedSilencableFailure(result);
 }
 
 ParseResult transform::TileOp::parse(OpAsmParser &parser,

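Ops implemented through `applyToOne`, such as `ScalarizeOp`, go through `TransformEachOpTrait::apply` and `applyTransformToEach`, which after this patch distinguish two failure kinds: a target of the wrong op kind yields a silenceable failure, while a failure inside the transform itself stays definite. A hypothetical sketch of that contract follows, with `PayloadOp` reduced to a kind name and a failed transform signalled by `std::nullopt`; `applyToEach` is an illustrative name, not the MLIR helper.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

enum class Status { Success, Silenceable, Definite };

// Hypothetical payload op, reduced to its kind name.
struct PayloadOp {
  std::string kind;
};

// Sketch of the applyTransformToEach contract: a target of the wrong kind is
// reported as a silenceable failure, whereas a failure of the transform itself
// is definite. As in the patched loop, targets processed before the failing
// one are not rolled back.
template <typename Fn>
Status applyToEach(const std::vector<PayloadOp> &targets,
                   const std::string &expectedKind,
                   std::vector<PayloadOp> &results, Fn transform) {
  for (const PayloadOp &target : targets) {
    if (target.kind != expectedKind)
      return Status::Silenceable; // "wrong op kind"
    std::optional<PayloadOp> result = transform(target);
    if (!result)
      return Status::Definite; // the transformation itself failed
    results.push_back(*result);
  }
  return Status::Success;
}
```

This is why a mismatched target can be absorbed by a container such as `transform.alternatives`, while a failing `applyToOne` still aborts the whole interpretation, consistent with the commit message's note that these ops are considered to produce only irrecoverable failures for now.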
diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 7910e83eac981..6b39a35976616 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
 using namespace mlir;
@@ -30,7 +31,7 @@ class SimpleRewriter : public PatternRewriter {
 // GetParentForOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
+DiagnosedSilencableFailure
 transform::GetParentForOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
   SetVector<Operation *> parents;
@@ -40,9 +41,10 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
       loop = current->getParentOfType<scf::ForOp>();
       if (!loop) {
-        InFlightDiagnostic diag = emitError() << "could not find an '"
-                                              << scf::ForOp::getOperationName()
-                                              << "' parent";
+        DiagnosedSilencableFailure diag = emitSilencableError()
+                                          << "could not find an '"
+                                          << scf::ForOp::getOperationName()
+                                          << "' parent";
         diag.attachNote(target->getLoc()) << "target op";
         return diag;
       }
@@ -51,7 +53,7 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
     parents.insert(loop);
   }
   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -83,7 +85,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
   return executeRegionOp;
 }
 
-LogicalResult
+DiagnosedSilencableFailure
 transform::LoopOutlineOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
   SmallVector<Operation *> transformed;
@@ -94,7 +96,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     SimpleRewriter rewriter(getContext());
     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
     if (!exec) {
-      InFlightDiagnostic diag = emitError() << "failed to outline";
+      DiagnosedSilencableFailure diag = emitSilencableError()
+                                        << "failed to outline";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
@@ -102,8 +105,10 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
         rewriter, location, exec.getRegion(), getFuncName(), &call);
 
-    if (failed(outlined))
-      return reportUnknownTransformError(target);
+    if (failed(outlined)) {
+      (void)reportUnknownTransformError(target);
+      return DiagnosedSilencableFailure::definiteFailure();
+    }
 
     if (symbolTableOp) {
       SymbolTable &symbolTable =
@@ -115,7 +120,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     transformed.push_back(*outlined);
   }
   results.set(getTransformed().cast<OpResult>(), transformed);
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//

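The SCF ops above distinguish the two failure kinds. A missing `scf.for` parent or a failed outlining attempt is reported with `emitSilencableError()`. An unknown error from `outlineSingleBlockRegion` becomes a definite failure. The sketch below is a minimal model of the `GetParentForOp` logic. It assumes a toy payload where each op only knows its name and parent, and the `Node`, `Status` and `Outcome` types are hypothetical.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical payload-IR node: an op name and a pointer to its parent op.
struct Node {
  std::string name;
  const Node *parent = nullptr;
};

enum class Status { Success, SilenceableFailure, DefiniteFailure };

struct Outcome {
  Status status;
  std::string message;
  std::vector<const Node *> results;
};

// Returns the closest proper ancestor named `opName`, or nullptr if none.
const Node *getParentOfType(const Node *op, const std::string &opName) {
  for (const Node *p = op->parent; p; p = p->parent)
    if (p->name == opName)
      return p;
  return nullptr;
}

// Sketch of GetParentForOp::apply: for each target, climb `numLoops`
// enclosing "scf.for" ops. A missing loop is a silenceable failure.
// Assumes numLoops >= 1.
Outcome getParentFor(const std::vector<const Node *> &targets,
                     unsigned numLoops) {
  Outcome outcome{Status::Success, "", {}};
  for (const Node *target : targets) {
    const Node *current = target;
    const Node *loop = nullptr;
    for (unsigned i = 0; i < numLoops; ++i) {
      loop = getParentOfType(current, "scf.for");
      if (!loop)
        return Outcome{Status::SilenceableFailure,
                       "could not find an 'scf.for' parent", {}};
      current = loop;
    }
    // Like SetVector::insert: keep only the first occurrence of each parent.
    if (std::find(outcome.results.begin(), outcome.results.end(), loop) ==
        outcome.results.end())
      outcome.results.push_back(loop);
  }
  return outcome;
}
```

To walk `numLoops` levels, the sketch repeats the closest-ancestor lookup once per level. Targets that share a loop yield a single result, as they would with the `SetVector` used by the real op.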
diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ff28f447e43bd..ad6935fdc71f5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -10,8 +10,10 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Operation.h"
-#include "llvm/ADT/ScopeExit.h"
-#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "transform-dialect"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
 
 using namespace mlir;
 
@@ -186,16 +188,18 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
   return success();
 }
 
-LogicalResult
+DiagnosedSilencableFailure
 transform::TransformState::applyTransform(TransformOpInterface transform) {
+  LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
   if (options.getExpensiveChecksEnabled() &&
       failed(checkAndRecordHandleInvalidation(transform))) {
-    return failure();
+    return DiagnosedSilencableFailure::definiteFailure();
   }
 
   transform::TransformResults results(transform->getNumResults());
-  if (failed(transform.apply(results, *this)))
-    return failure();
+  DiagnosedSilencableFailure result(transform.apply(results, *this));
+  if (!result.succeeded())
+    return result;
 
   // Remove the mapping for the operand if it is consumed by the operation. This
   // allows us to catch use-after-free with assertions later on.
@@ -219,10 +223,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
            "payload IR association for a value other than the result of the "
            "current transform op");
     if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
-      return failure();
+      return DiagnosedSilencableFailure::definiteFailure();
   }
 
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,14 +277,14 @@ transform::TransformResults::get(unsigned resultNumber) const {
 //===----------------------------------------------------------------------===//
 
 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
-    TransformState &state, Operation *op) {
+    TransformState &state, Operation *op, unsigned region) {
   SmallVector<Operation *> targets;
   if (op->getNumOperands() != 0)
     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
   else
     targets.push_back(state.getTopLevel());
 
-  return state.mapBlockArguments(op->getRegion(0).front().getArgument(0),
+  return state.mapBlockArguments(op->getRegion(region).front().getArgument(0),
                                  targets);
 }
 
@@ -293,8 +297,8 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
          "should implement TransformOpInterface to have "
          "PossibleTopLevelTransformOpTrait");
 
-  if (op->getNumRegions() != 1)
-    return op->emitOpError() << "expects one region";
+  if (op->getNumRegions() < 1)
+    return op->emitOpError() << "expects at least one region";
 
   Region *bodyRegion = &op->getRegion(0);
   if (!llvm::hasNItems(*bodyRegion, 1))

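`TransformState::applyTransform` now forwards any non-success result unchanged instead of collapsing it into `failure()`, and `SequenceOp::apply` stops at the first such result. A silenceable failure therefore stays silenceable until it reaches an op that can recover from it, such as `transform.alternatives`. The sketch below illustrates that propagation rule. The `Step` type is hypothetical, and a log vector stands in for the `LLVM_DEBUG` "applying:" trace.

```cpp
#include <functional>
#include <string>
#include <vector>

enum class Status { Success, SilenceableFailure, DefiniteFailure };

// Hypothetical transform step: a name plus a callable producing a Status.
struct Step {
  std::string name;
  std::function<Status()> apply;
};

// Runs steps in order and returns the first non-success result unchanged,
// so the caller can still tell a silenceable failure from a definite one.
Status applySequence(const std::vector<Step> &steps,
                     std::vector<std::string> &log) {
  for (const Step &step : steps) {
    log.push_back("applying: " + step.name);
    Status status = step.apply();
    if (status != Status::Success)
      return status;
  }
  return Status::Success;
}
```

In an assertions-enabled build, the real trace should be printable with `-debug-only=transform-dialect`, the `DEBUG_TYPE` defined in this file.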
diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 95d43601a6ea6..51956ab9eddc2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -10,13 +10,16 @@
 #include "mlir/Dialect/PDL/IR/PDLOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "transform-dialect"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
 
 using namespace mlir;
 
@@ -115,34 +118,174 @@ LogicalResult PatternApplicatorExtension::findAllMatches(
 }
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+OperandRange
+transform::AlternativesOp::getSuccessorEntryOperands(unsigned index) {
+  if (getOperation()->getNumOperands() == 1)
+    return getOperation()->getOperands();
+  return OperandRange(getOperation()->operand_end(),
+                      getOperation()->operand_end());
+}
+
+void transform::AlternativesOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  for (Region &alternative :
+       llvm::drop_begin(getAlternatives(), index.hasValue() ? *index + 1 : 0)) {
+    regions.emplace_back(&alternative, !getOperands().empty()
+                                           ? alternative.getArguments()
+                                           : Block::BlockArgListType());
+  }
+  if (index.hasValue())
+    regions.emplace_back(getOperation()->getResults());
+}
+
+void transform::AlternativesOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+  (void)operands;
+  // The region corresponding to the first alternative is always executed, the
+  // remaining may or may not be executed.
+  bounds.reserve(getNumRegions());
+  bounds.emplace_back(1, 1);
+  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
+}
+
+static void forwardTerminatorOperands(Block *block,
+                                      transform::TransformState &state,
+                                      transform::TransformResults &results) {
+  for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
+                                    block->getParentOp()->getOpResults())) {
+    Value terminatorOperand = std::get<0>(pair);
+    OpResult result = std::get<1>(pair);
+    results.set(result, state.getPayloadOps(terminatorOperand));
+  }
+}
+
+DiagnosedSilencableFailure
+transform::AlternativesOp::apply(transform::TransformResults &results,
+                                 transform::TransformState &state) {
+  SmallVector<Operation *> originals;
+  if (Value scopeHandle = getScope())
+    llvm::append_range(originals, state.getPayloadOps(scopeHandle));
+  else
+    originals.push_back(state.getTopLevel());
+
+  for (Operation *original : originals) {
+    if (original->isAncestor(getOperation())) {
+      InFlightDiagnostic diag =
+          emitError() << "scope must not contain the transforms being applied";
+      diag.attachNote(original->getLoc()) << "scope";
+      return DiagnosedSilencableFailure::definiteFailure();
+    }
+  }
+
+  for (Region &reg : getAlternatives()) {
+    // Clone the scope operations and make the transforms in this alternative
+    // region apply to them by virtue of mapping the block argument (the only
+    // visible handle) to the cloned scope operations. This effectively prevents
+    // the transformation from accessing any IR outside the scope.
+    auto scope = state.make_region_scope(reg);
+    auto clones = llvm::to_vector(
+        llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
+    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
+      return DiagnosedSilencableFailure::definiteFailure();
+    auto deleteClones = llvm::make_scope_exit([&] {
+      for (Operation *clone : clones)
+        clone->erase();
+    });
+
+    bool failed = false;
+    for (Operation &transform : reg.front().without_terminator()) {
+      DiagnosedSilencableFailure result =
+          state.applyTransform(cast<TransformOpInterface>(transform));
+      if (result.isSilencableFailure()) {
+        LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
+                          << "\n");
+        failed = true;
+        break;
+      }
+
+      if (::mlir::failed(result.silence()))
+        return DiagnosedSilencableFailure::definiteFailure();
+    }
+
+    // If all operations in the given alternative succeeded, no need to consider
+    // the rest. Replace the original scoping operation with the clone on which
+    // the transformations were performed.
+    if (!failed) {
+      // We will be using the clones, so cancel their scheduled deletion.
+      deleteClones.release();
+      IRRewriter rewriter(getContext());
+      for (const auto &kvp : llvm::zip(originals, clones)) {
+        Operation *original = std::get<0>(kvp);
+        Operation *clone = std::get<1>(kvp);
+        original->getBlock()->getOperations().insert(original->getIterator(),
+                                                     clone);
+        rewriter.replaceOp(original, clone->getResults());
+      }
+      forwardTerminatorOperands(&reg.front(), state, results);
+      return DiagnosedSilencableFailure::success();
+    }
+  }
+  return emitSilencableError() << "all alternatives failed";
+}
+
+LogicalResult transform::AlternativesOp::verify() {
+  for (Region &alternative : getAlternatives()) {
+    Block &block = alternative.front();
+    if (block.getNumArguments() != 1 ||
+        !block.getArgument(0).getType().isa<pdl::OperationType>()) {
+      return emitOpError()
+             << "expects region blocks to have one operand of type "
+             << pdl::OperationType::get(getContext());
+    }
+
+    Operation *terminator = block.getTerminator();
+    if (terminator->getOperands().getTypes() != getResults().getTypes()) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "expects terminator operands to have the "
+                                   "same type as results of the operation";
+      diag.attachNote(terminator->getLoc()) << "terminator";
+      return diag;
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GetClosestIsolatedParentOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult transform::GetClosestIsolatedParentOp::apply(
+DiagnosedSilencableFailure transform::GetClosestIsolatedParentOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   SetVector<Operation *> parents;
   for (Operation *target : state.getPayloadOps(getTarget())) {
     Operation *parent =
         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
     if (!parent) {
-      InFlightDiagnostic diag =
-          emitError() << "could not find an isolated-from-above parent op";
+      DiagnosedSilencableFailure diag =
+          emitSilencableError()
+          << "could not find an isolated-from-above parent op";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
     parents.insert(parent);
   }
   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
-                                           transform::TransformState &state) {
+DiagnosedSilencableFailure
+transform::PDLMatchOp::apply(transform::TransformResults &results,
+                             transform::TransformState &state) {
   auto *extension = state.getExtension<PatternApplicatorExtension>();
   assert(extension &&
          "expected PatternApplicatorExtension to be attached by the parent op");
@@ -150,41 +293,38 @@ LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
   for (Operation *root : state.getPayloadOps(getRoot())) {
     if (failed(extension->findAllMatches(
             getPatternName().getLeafReference().getValue(), root, targets))) {
-      return emitOpError() << "could not find pattern '" << getPatternName()
-                           << "'";
+      emitOpError() << "could not find pattern '" << getPatternName() << "'";
+      return DiagnosedSilencableFailure::definiteFailure();
     }
   }
   results.set(getResult().cast<OpResult>(), targets);
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
 // SequenceOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
-                                           transform::TransformState &state) {
+DiagnosedSilencableFailure
+transform::SequenceOp::apply(transform::TransformResults &results,
+                             transform::TransformState &state) {
   // Map the entry block argument to the list of operations.
   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
   if (failed(mapBlockArguments(state)))
-    return failure();
+    return DiagnosedSilencableFailure::definiteFailure();
 
   // Apply the sequenced ops one by one.
-  for (Operation &transform : getBodyBlock()->without_terminator())
-    if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
-      return failure();
+  for (Operation &transform : getBodyBlock()->without_terminator()) {
+    DiagnosedSilencableFailure result =
+        state.applyTransform(cast<TransformOpInterface>(transform));
+    if (!result.succeeded())
+      return result;
+  }
 
   // Forward the operation mapping for values yielded from the sequence to the
   // values produced by the sequence op.
-  for (const auto &pair :
-       llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
-                 getOperation()->getOpResults())) {
-    Value terminatorOperand = std::get<0>(pair);
-    OpResult result = std::get<1>(pair);
-    results.set(result, state.getPayloadOps(terminatorOperand));
-  }
-
-  return success();
+  forwardTerminatorOperands(getBodyBlock(), state, results);
+  return DiagnosedSilencableFailure::success();
 }
 
 /// Returns `true` if the given op operand may be consuming the handle value in
@@ -346,7 +486,7 @@ void transform::SequenceOp::getRegionInvocationBounds(
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
+DiagnosedSilencableFailure
 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
                                     transform::TransformState &state) {
   OwningOpRef<ModuleOp> pdlModuleOp =
@@ -365,7 +505,7 @@ transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
 
   auto scope = state.make_region_scope(getBody());
   if (failed(mapBlockArguments(state)))
-    return failure();
+    return DiagnosedSilencableFailure::definiteFailure();
   return state.applyTransform(transformOp);
 }
 

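`AlternativesOp` describes its control flow through `RegionBranchOpInterface`. From region `i`, control may move to any later alternative or back to the parent op's results, and only the first alternative is guaranteed to run exactly once. The sketch below reproduces `getSuccessorRegions` and `getRegionInvocationBounds` over plain integers. Encoding the parent's results as `-1` is an assumption of this sketch.

```cpp
#include <optional>
#include <utility>
#include <vector>

// Hypothetical encoding of "the parent op's results" as a successor.
constexpr int kParentResults = -1;

// From the parent (no index), every alternative is listed; from region
// `index`, only later alternatives and the parent's results are listed.
std::vector<int> successorRegions(std::optional<unsigned> index,
                                  unsigned numRegions) {
  std::vector<int> successors;
  unsigned first = index ? *index + 1 : 0;
  for (unsigned region = first; region < numRegions; ++region)
    successors.push_back(static_cast<int>(region));
  if (index)
    successors.push_back(kParentResults);
  return successors;
}

// The first alternative runs exactly once; the others run at most once.
std::vector<std::pair<unsigned, unsigned>> invocationBounds(
    unsigned numRegions) {
  std::vector<std::pair<unsigned, unsigned>> bounds;
  bounds.reserve(numRegions);
  bounds.emplace_back(1u, 1u);
  bounds.resize(numRegions, {0u, 1u});
  return bounds;
}
```

Entering from the parent lists every alternative as a successor, even though execution actually starts with the first one. This mirrors the `drop_begin(..., 0)` call in the diff.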
diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index ade63ebfc4596..b76bd07d3a475 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -81,7 +81,7 @@ transform.with_pdl_patterns {
 
 // -----
 
-// expected-error @below {{expects one region}}
+// expected-error @below {{expects at least one region}}
 "transform.test_transform_unrestricted_op_no_interface"() : () -> ()
 
 // -----
@@ -153,3 +153,34 @@ transform.sequence {
     }
   }
 }
+
+// -----
+
+transform.sequence {
+^bb1(%arg1: !pdl.operation):
+  // expected-error @below {{expects at least one region}}
+  transform.alternatives
+}
+
+// -----
+
+transform.sequence {
+^bb1(%arg1: !pdl.operation):
+  // expected-error @below {{expects terminator operands to have the same type as results of the operation}}
+  %2 = transform.alternatives %arg1 -> !pdl.operation {
+  ^bb2(%arg2: !pdl.operation):
+    transform.yield %arg2 : !pdl.operation
+  }, {
+  ^bb2(%arg2: !pdl.operation):
+    // expected-note @below {{terminator}}
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-error @below {{expects the entry block to have one argument of type '!pdl.operation'}}
+transform.alternatives {
+^bb0:
+  transform.yield
+}

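The new invalid-IR tests exercise two checks. The first is the region-count check in `verifyPossibleTopLevelTransformOpTrait`, which this diff relaxes to "at least one region". The second is `AlternativesOp::verify`, which requires each region's block to take a single `!pdl.operation` argument and to yield values whose types match the op's results. The sketch below approximates both rules with a toy verifier that writes type names as strings. The `RegionModel` type and the exact message strings are illustrative assumptions.

```cpp
#include <string>
#include <vector>

// Hypothetical summary of one alternative region: its entry block argument
// types and the operand types of its terminator.
struct RegionModel {
  std::vector<std::string> blockArgTypes;
  std::vector<std::string> terminatorOperandTypes;
};

// Returns an empty string when the op is valid, else an error message.
std::string verifyAlternatives(const std::vector<RegionModel> &regions,
                               const std::vector<std::string> &resultTypes) {
  if (regions.empty())
    return "expects at least one region";
  for (const RegionModel &region : regions) {
    if (region.blockArgTypes.size() != 1 ||
        region.blockArgTypes[0] != "!pdl.operation")
      return "expects region blocks to have one operand of type "
             "!pdl.operation";
    if (region.terminatorOperandTypes != resultTypes)
      return "expects terminator operands to have the same type as results "
             "of the operation";
  }
  return "";
}
```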
diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 530af50df2950..22cfc009af17c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -128,3 +128,223 @@ transform.with_pdl_patterns {
     test_print_remark_at_operand %m, "parent function"
   }
 }
+
+// -----
+
+func.func @foo() {
+  %0 = arith.constant 0 : i32
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_func : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    // This is necessary to run the transformation on something other than the
+    // top-level module; "alternatives" cannot be run on that.
+    %0 = pdl_match @match_func in %arg1
+    transform.alternatives %0 {
+    ^bb2(%arg2: !pdl.operation):
+      %1 = transform.test_produce_param_or_forward_operand 42
+      // This operation fails, which triggers the next alternative without
+      // reporting the error.
+      transform.test_consume_operand_if_matches_param_or_fail %1[43]
+    }, {
+    ^bb2(%arg2: !pdl.operation):
+      %1 = transform.test_produce_param_or_forward_operand 42
+      // expected-remark @below {{succeeded}}
+      transform.test_consume_operand_if_matches_param_or_fail %1[42]
+    }
+  }
+}
+
+// -----
+
+func.func private @bar()
+
+func.func @foo() {
+  call @bar() : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_call in %arg1
+    %1 = get_closest_isolated_parent %0
+    // expected-error @below {{all alternatives failed}}
+    transform.alternatives %1 {
+    ^bb2(%arg2: !pdl.operation):
+      %2 = transform.pdl_match @match_call in %arg2
+      // expected-remark @below {{applying}}
+      transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase}
+    }
+  }
+}
+
+// -----
+
+func.func private @bar()
+
+func.func @foo() {
+  // expected-remark @below {{still here}}
+  call @bar() : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_call in %arg1
+    %1 = get_closest_isolated_parent %0
+    transform.alternatives %1 {
+    ^bb2(%arg2: !pdl.operation):
+      %2 = transform.pdl_match @match_call in %arg2
+      // expected-remark @below {{applying}}
+      transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase}
+    }, {
+    ^bb2(%arg2: !pdl.operation):
+      %2 = transform.pdl_match @match_call in %arg2
+      transform.test_print_remark_at_operand %2, "still here"
+      // This alternative succeeds.
+    }, {
+    ^bb2(%arg2: !pdl.operation):
+      // This alternative is never run, so we must not have a remark here.
+      %2 = transform.pdl_match @match_call in %arg2
+      transform.test_emit_remark_and_erase_operand %2, "should not happen" {fail_after_erase}
+    }
+  }
+}
+
+// -----
+
+func.func private @bar()
+
+// CHECK-LABEL: @erase_call
+func.func @erase_call() {
+  // CHECK-NOT: call @bar
+  call @bar() : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_call in %arg1
+    %1 = get_closest_isolated_parent %0
+    transform.alternatives %1 {
+    ^bb2(%arg2: !pdl.operation):
+      %2 = transform.pdl_match @match_call in %arg2
+      // expected-remark @below {{applying}}
+      transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase}
+    }, {
+    ^bb2(%arg2: !pdl.operation):
+      %2 = transform.pdl_match @match_call in %arg2
+      // expected-remark @below {{applying second time}}
+      transform.test_emit_remark_and_erase_operand %2, "applying second time"
+    }
+  }
+}
+
+// -----
+
+func.func private @bar()
+
+func.func @foo() {
+  call @bar() : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_call in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.alternatives %1 -> !pdl.operation {
+    ^bb2(%arg2: !pdl.operation):
+      %3 = transform.pdl_match @match_call in %arg2
+      // expected-remark @below {{applying}}
+      transform.test_emit_remark_and_erase_operand %3, "applying" {fail_after_erase}
+      %4 = transform.test_produce_param_or_forward_operand 43
+      transform.yield %4 : !pdl.operation
+    }, {
+    ^bb2(%arg2: !pdl.operation):
+      %4 = transform.test_produce_param_or_forward_operand 42
+      transform.yield %4 : !pdl.operation
+    }
+    // The first alternative failed, so the returned value is taken from the
+    // second alternative.
+    // expected-remark @below {{succeeded}}
+    transform.test_consume_operand_if_matches_param_or_fail %2[42]
+  }
+}
+
+// -----
+
+// expected-note @below {{scope}}
+module {
+  func.func @foo() {
+    %0 = arith.constant 0 : i32
+    return
+  }
+
+  func.func @bar() {
+    %0 = arith.constant 0 : i32
+    %1 = arith.constant 1 : i32
+    return
+  }
+
+  transform.sequence {
+  ^bb1(%arg1: !pdl.operation):
+    // expected-error @below {{scope must not contain the transforms being applied}}
+    transform.alternatives %arg1 {
+    ^bb2(%arg2: !pdl.operation):
+      %0 = transform.test_produce_param_or_forward_operand 42
+      transform.test_consume_operand_if_matches_param_or_fail %0[43]
+    }, {
+    ^bb2(%arg2: !pdl.operation):
+      %0 = transform.test_produce_param_or_forward_operand 42
+      transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    }
+  }
+}
+

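The interpreter tests above rely on each alternative running on a clone of the scope. An alternative that erases a `func.call` and then fails leaves the original function untouched. Only a successful alternative's clone replaces the original, which the real op does via `IRRewriter::replaceOp`. The sketch below models that fallback loop, representing a function body as a list of op names. `Payload`, `Alternative`, and `eraseCalls` are hypothetical stand-ins for the payload IR and for `test_emit_remark_and_erase_operand`.

```cpp
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

enum class Status { Success, SilenceableFailure, DefiniteFailure };

// Hypothetical payload: the op names in a function body.
using Payload = std::vector<std::string>;
// Hypothetical alternative: transforms a clone in place and reports a Status.
using Alternative = std::function<Status(Payload &)>;

Status applyAlternatives(Payload &scope,
                         const std::vector<Alternative> &alternatives) {
  for (const Alternative &alternative : alternatives) {
    Payload clone = scope; // each alternative works on a fresh copy
    Status status = alternative(clone);
    if (status == Status::SilenceableFailure)
      continue; // discard the clone and try the next alternative
    if (status == Status::DefiniteFailure)
      return status; // irrecoverable: stop without trying other alternatives
    scope = std::move(clone); // success: the clone replaces the original
    return Status::Success;
  }
  return Status::SilenceableFailure; // "all alternatives failed"
}

// Hypothetical helper: erases every "func.call", then optionally fails.
Alternative eraseCalls(bool failAfterErase) {
  return [failAfterErase](Payload &payload) {
    payload.erase(std::remove(payload.begin(), payload.end(), "func.call"),
                  payload.end());
    return failAfterErase ? Status::SilenceableFailure : Status::Success;
  };
}
```

As in `AlternativesOp::apply`, a definite failure aborts the search instead of falling through to the next alternative. Exhausting all alternatives yields a silenceable failure, which is reported only if no enclosing op silences it.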
diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index e2d0a74b9f429..81cba585ffc60 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -38,13 +38,13 @@ class TestTransformOp
     return llvm::StringLiteral("transform.test_transform_op");
   }
 
-  LogicalResult apply(transform::TransformResults &results,
-                      transform::TransformState &state) {
+  DiagnosedSilencableFailure apply(transform::TransformResults &results,
+                                   transform::TransformState &state) {
     InFlightDiagnostic remark = emitRemark() << "applying transformation";
     if (Attribute message = getMessage())
       remark << " " << message;
 
-    return success();
+    return DiagnosedSilencableFailure::success();
   }
 
   Attribute getMessage() { return getOperation()->getAttr("message"); }
@@ -91,9 +91,9 @@ class TestTransformUnrestrictedOpNoInterface
         "transform.test_transform_unrestricted_op_no_interface");
   }
 
-  LogicalResult apply(transform::TransformResults &results,
-                      transform::TransformState &state) {
-    return success();
+  DiagnosedSilencableFailure apply(transform::TransformResults &results,
+                                   transform::TransformState &state) {
+    return DiagnosedSilencableFailure::success();
   }
 
   // No side effects.
@@ -101,7 +101,8 @@ class TestTransformUnrestrictedOpNoInterface
 };
 } // namespace
 
-LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
+DiagnosedSilencableFailure
+mlir::test::TestProduceParamOrForwardOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   if (getOperation()->getNumOperands() != 0) {
     results.set(getResult().cast<OpResult>(),
@@ -110,7 +111,7 @@ LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
     results.set(getResult().cast<OpResult>(),
                 reinterpret_cast<Operation *>(*getParameter()));
   }
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
@@ -119,48 +120,50 @@ LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
   return success();
 }
 
-LogicalResult
+DiagnosedSilencableFailure
 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
                                       transform::TransformState &state) {
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
-LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
+DiagnosedSilencableFailure
+mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
   assert(payload.size() == 1 && "expected a single target op");
   auto value = reinterpret_cast<intptr_t>(payload[0]);
   if (static_cast<uint64_t>(value) != getParameter()) {
-    return emitOpError() << "expected the operand to be associated with "
-                         << getParameter() << " got " << value;
+    return emitSilencableError()
+           << "op expected the operand to be associated with " << getParameter()
+           << " got " << value;
   }
 
   emitRemark() << "succeeded";
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
-LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
+DiagnosedSilencableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
   for (Operation *op : payload)
     op->emitRemark() << getMessage();
 
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
-LogicalResult
+DiagnosedSilencableFailure
 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
                                           transform::TransformState &state) {
   state.addExtension<TestTransformStateExtension>(getMessageAttr());
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
-LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply(
+DiagnosedSilencableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
   if (!extension) {
     emitRemark() << "extension absent";
-    return success();
+    return DiagnosedSilencableFailure::success();
   }
 
   InFlightDiagnostic diag = emitRemark()
@@ -172,40 +175,56 @@ LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply(
            "operations");
   }
 
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
-LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply(
+DiagnosedSilencableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
-  if (!extension)
-    return emitError() << "TestTransformStateExtension missing";
+  if (!extension) {
+    emitError() << "TestTransformStateExtension missing";
+    return DiagnosedSilencableFailure::definiteFailure();
+  }
 
-  return extension->updateMapping(state.getPayloadOps(getOperand()).front(),
-                                  getOperation());
+  if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
+                                      getOperation())))
+    return DiagnosedSilencableFailure::definiteFailure();
+  return DiagnosedSilencableFailure::success();
 }
 
-LogicalResult mlir::test::TestRemoveTestExtensionOp::apply(
+DiagnosedSilencableFailure mlir::test::TestRemoveTestExtensionOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   state.removeExtension<TestTransformStateExtension>();
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
-LogicalResult mlir::test::TestTransformOpWithRegions::apply(
+DiagnosedSilencableFailure mlir::test::TestTransformOpWithRegions::apply(
     transform::TransformResults &results, transform::TransformState &state) {
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 void mlir::test::TestTransformOpWithRegions::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 
-LogicalResult mlir::test::TestBranchingTransformOpTerminator::apply(
+DiagnosedSilencableFailure
+mlir::test::TestBranchingTransformOpTerminator::apply(
     transform::TransformResults &results, transform::TransformState &state) {
-  return success();
+  return DiagnosedSilencableFailure::success();
 }
 
 void mlir::test::TestBranchingTransformOpTerminator::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 
+DiagnosedSilencableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  emitRemark() << getRemark();
+  for (Operation *op : state.getPayloadOps(getTarget()))
+    op->erase();
+
+  if (getFailAfterErase())
+    return emitSilencableError() << "silencable error";
+  return DiagnosedSilencableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

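The diff above changes every `apply` method to return `DiagnosedSilencableFailure` instead of `LogicalResult`. That type has three states: success, a silenceable (recoverable) failure, and a definite (irrecoverable) failure. The sketch below is a standalone model of the tri-state idea, not the MLIR class. `TriStateResult`, `LogicalResultModel` and `fromLogicalResult` are hypothetical names. `fromLogicalResult` mirrors the pattern in `TestRemapOperandPayloadToSelfOp::apply`, where a failing helper that still returns `LogicalResult` is mapped to a definite failure.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for mlir::LogicalResult (two states only).
struct LogicalResultModel {
  bool ok;
  static LogicalResultModel success() { return {true}; }
  static LogicalResultModel failure() { return {false}; }
};

// Hypothetical standalone model of a tri-state transform result; it is not
// the real DiagnosedSilencableFailure API.
class TriStateResult {
public:
  enum class Kind { Success, SilenceableFailure, DefiniteFailure };

  static TriStateResult success() { return TriStateResult(Kind::Success, ""); }
  static TriStateResult definiteFailure() {
    return TriStateResult(Kind::DefiniteFailure, "");
  }
  // A silenceable failure carries its diagnostic instead of emitting it, so
  // that an enclosing op may choose to drop it and recover.
  static TriStateResult silenceableFailure(std::string message) {
    return TriStateResult(Kind::SilenceableFailure, std::move(message));
  }

  Kind kind() const { return kind_; }
  const std::string &message() const { return message_; }

private:
  TriStateResult(Kind kind, std::string message)
      : kind_(kind), message_(std::move(message)) {}

  Kind kind_;
  std::string message_;
};

// Mirrors the pattern in TestRemapOperandPayloadToSelfOp::apply: a failing
// LogicalResult-returning helper becomes a definite (irrecoverable) failure.
inline TriStateResult fromLogicalResult(LogicalResultModel result) {
  if (!result.ok)
    return TriStateResult::definiteFailure();
  return TriStateResult::success();
}
```

Only a silenceable failure holds on to its message in this model. In the diff, the definite-failure paths emit their diagnostic directly (for example, `emitError()` before `definiteFailure()`).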
diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 0cf6b8fb4612d..a8dab8be106fb 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -118,4 +118,14 @@ def TestBranchingTransformOpTerminator
   let cppNamespace = "::mlir::test";
 }
 
+def TestEmitRemarkAndEraseOperandOp
+  : Op<Transform_Dialect, "test_emit_remark_and_erase_operand",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> {
+  let arguments = (ins PDL_Operation:$target, StrAttr:$remark,
+                   UnitAttr:$fail_after_erase);
+  let assemblyFormat = "$target `,` $remark attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
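`test_emit_remark_and_erase_operand` erases its target payload ops. When the `fail_after_erase` unit attribute is set, it then returns a silenceable error. A test op like this is presumably meant to check that a containing op such as `transform.alternatives` can recover from a silenceable failure even after the payload was modified.

The sketch below models that interplay in plain C++ under simplifying assumptions. Payload IR is a vector of strings. Each alternative runs on its own copy of the payload, and that copy is committed only on success. `Outcome`, `Payload`, `emitRemarkAndErase` and `tryAlternatives` are hypothetical names, not MLIR APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical three-way outcome of applying one transformation.
enum class Outcome { Success, SilenceableFailure, DefiniteFailure };

// Hypothetical payload model: each string stands for one payload op.
using Payload = std::vector<std::string>;
using Alternative = std::function<Outcome(Payload &)>;

// Models test_emit_remark_and_erase_operand: record the remark, erase every
// op matching `target`, then optionally fail in a silenceable way.
inline Outcome emitRemarkAndErase(Payload &payload, const std::string &target,
                                  const std::string &remark,
                                  bool failAfterErase,
                                  std::vector<std::string> &remarks) {
  remarks.push_back(remark);
  payload.erase(std::remove(payload.begin(), payload.end(), target),
                payload.end());
  return failAfterErase ? Outcome::SilenceableFailure : Outcome::Success;
}

// Tries each alternative on a scratch copy of the payload. The first success
// is committed; a silenceable failure discards the copy and moves on; a
// definite failure is propagated immediately.
inline Outcome tryAlternatives(Payload &payload,
                               const std::vector<Alternative> &alternatives) {
  for (const Alternative &alternative : alternatives) {
    Payload scratch = payload;
    Outcome outcome = alternative(scratch);
    if (outcome == Outcome::Success) {
      payload = std::move(scratch);
      return Outcome::Success;
    }
    if (outcome == Outcome::DefiniteFailure)
      return outcome;
  }
  // Every alternative failed silenceably; report that to the caller.
  return Outcome::SilenceableFailure;
}
```

With `failAfterErase` set on the first alternative, its erasure is discarded together with the scratch copy. The second alternative therefore still sees the original payload.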

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index e54b2c96dd91a..e74be0d67a676 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -47,7 +47,7 @@ class TestTransformDialectInterpreterPass
             enableExpensiveChecks));
     for (auto op :
          module.getBody()->getOps<transform::TransformOpInterface>()) {
-      if (failed(state.applyTransform(op)))
+      if (failed(state.applyTransform(op).checkAndReport()))
         return signalPassFailure();
     }
   }
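The interpreter pass now calls `checkAndReport()` on the result of `applyTransform` before testing it with `failed(...)`. The following is a rough model of that top-level boundary. It is an assumption based on how the result is used here, not a copy of MLIR's implementation:

- A silenceable failure that reaches the top level has its held-back diagnostic emitted and becomes an ordinary failure.
- A definite failure was already diagnosed where it occurred (as in `TestRemapOperandPayloadToSelfOp::apply`), so it simply maps to failure.
- Success stays success.

`ResultKind`, `ResultModel` and `checkAndReportModel` are hypothetical names.

```cpp
#include <cassert>
#include <string>
#include <vector>

enum class ResultKind { Success, SilenceableFailure, DefiniteFailure };

// Hypothetical model of a tri-state result that may hold back diagnostics.
struct ResultModel {
  ResultKind kind;
  std::vector<std::string> heldDiagnostics; // only used for silenceable failures
  bool checked = false;
};

// Hypothetical model of checkAndReport(): collapses the three states into a
// plain success/failure flag, moving held diagnostics into `emitted`.
// Returns true on success. Checking twice is treated as a bug in this model.
inline bool checkAndReportModel(ResultModel &result,
                                std::vector<std::string> &emitted) {
  assert(!result.checked && "a result should be checked only once");
  result.checked = true;
  if (result.kind == ResultKind::SilenceableFailure) {
    emitted.insert(emitted.end(), result.heldDiagnostics.begin(),
                   result.heldDiagnostics.end());
    result.heldDiagnostics.clear();
  }
  return result.kind == ResultKind::Success;
}
```

Doing the reporting only at the outermost level means a silenceable diagnostic reaches the user only if no enclosing op chose to recover from it.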
