[Mlir-commits] [mlir] 1d45282 - [mlir] address post-commit review for D127724

Alex Zinenko llvmlistbot at llvm.org
Wed Jun 15 09:43:40 PDT 2022


Author: Alex Zinenko
Date: 2022-06-15T18:43:05+02:00
New Revision: 1d45282aa398a1afafd7070cf6a09efacab6dc27

URL: https://github.com/llvm/llvm-project/commit/1d45282aa398a1afafd7070cf6a09efacab6dc27
DIFF: https://github.com/llvm/llvm-project/commit/1d45282aa398a1afafd7070cf6a09efacab6dc27.diff

LOG: [mlir] address post-commit review for D127724

- make transform.alternatives op apply only to isolated-from-above payload IR
  scopes;
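- fix a potential leak of cloned scope operations in transform.alternatives
  when mapping the region's block arguments fails;
- fix several typos, including renaming DiagnosedSilencableFailure to
  DiagnosedSilenceableFailure.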

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 1a36e2f3d144c..27a68b05429c4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -170,7 +170,7 @@ def Transform_Dialect : Dialect {
     perform the required recovery steps thus succeeding themselves. On
     the other hand, they must propagate irrecoverable failures. For such
     failures, the diagnostics are emitted immediately whereas their emission is
-    postponed for recoverable faliures. Transformation container operations may
+    postponed for recoverable failures. Transformation container operations may
     also fail to recover from a theoretically recoverable failure, in which case
     they are expected to emit the diagnostic and turn the failure into an
     irrecoverable one. A recoverable failure produced by applying the top-level

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index df33f954f6820..acb6769ce6314 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -20,42 +20,42 @@ namespace mlir {
 ///   - success;
 ///   - silencable (recoverable) failure with yet-unreported diagnostic;
 ///   - definite failure.
-/// Silencable failure is intended to communicate information about
+/// Silenceable failure is intended to communicate information about
 /// transformations that did not apply but in a way that supports recovery,
 /// for example, they did not modify the payload IR or modified it in some
 /// predictable way. They are associated with a Diagnostic that provides more
-/// details on the failure. Silencable failure can be discarded, turning the
+/// details on the failure. Silenceable failure can be discarded, turning the
 /// result into success, or "reported", emitting the diagnostic and turning the
 /// result into definite failure. Transform IR operations containing other
 /// operations are allowed to do either with the results of the nested
 /// transformations, but must propagate definite failures as their diagnostics
 /// have been already reported to the user.
-class LLVM_NODISCARD DiagnosedSilencableFailure {
+class LLVM_NODISCARD DiagnosedSilenceableFailure {
 public:
-  explicit DiagnosedSilencableFailure(LogicalResult result) : result(result) {}
-  DiagnosedSilencableFailure(const DiagnosedSilencableFailure &) = delete;
-  DiagnosedSilencableFailure &
-  operator=(const DiagnosedSilencableFailure &) = delete;
-  DiagnosedSilencableFailure(DiagnosedSilencableFailure &&) = default;
-  DiagnosedSilencableFailure &
-  operator=(DiagnosedSilencableFailure &&) = default;
-
-  /// Constructs a DiagnosedSilencableFailure in the success state.
-  static DiagnosedSilencableFailure success() {
-    return DiagnosedSilencableFailure(::mlir::success());
+  explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
+  DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete;
+  DiagnosedSilenceableFailure &
+  operator=(const DiagnosedSilenceableFailure &) = delete;
+  DiagnosedSilenceableFailure(DiagnosedSilenceableFailure &&) = default;
+  DiagnosedSilenceableFailure &
+  operator=(DiagnosedSilenceableFailure &&) = default;
+
+  /// Constructs a DiagnosedSilenceableFailure in the success state.
+  static DiagnosedSilenceableFailure success() {
+    return DiagnosedSilenceableFailure(::mlir::success());
   }
 
-  /// Constructs a DiagnosedSilencableFailure in the failure state. Typically,
+  /// Constructs a DiagnosedSilenceableFailure in the failure state. Typically,
   /// a diagnostic has been emitted before this.
-  static DiagnosedSilencableFailure definiteFailure() {
-    return DiagnosedSilencableFailure(::mlir::failure());
+  static DiagnosedSilenceableFailure definiteFailure() {
+    return DiagnosedSilenceableFailure(::mlir::failure());
   }
 
-  /// Constructs a DiagnosedSilencableFailure in the silencable failure state,
+  /// Constructs a DiagnosedSilenceableFailure in the silenceable failure state,
   /// ready to emit the given diagnostic. This is considered a failure
   /// regardless of the diagnostic severity.
-  static DiagnosedSilencableFailure silencableFailure(Diagnostic &&diag) {
-    return DiagnosedSilencableFailure(std::forward<Diagnostic>(diag));
+  static DiagnosedSilenceableFailure silencableFailure(Diagnostic &&diag) {
+    return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag));
   }
 
   /// Converts all kinds of failure into a LogicalResult failure, emitting the
@@ -75,7 +75,7 @@ class LLVM_NODISCARD DiagnosedSilencableFailure {
   }
 
   /// Returns `true` if this is a silencable failure.
-  bool isSilencableFailure() const { return diagnostic.hasValue(); }
+  bool isSilenceableFailure() const { return diagnostic.hasValue(); }
 
   /// Returns `true` if this is a success.
   bool succeeded() const {
@@ -98,26 +98,26 @@ class LLVM_NODISCARD DiagnosedSilencableFailure {
 
   /// Streams the given values into the diagnotic. Expects this object to be a
   /// silencable failure.
-  template <typename T> DiagnosedSilencableFailure &operator<<(T &&value) & {
-    assert(isSilencableFailure() &&
+  template <typename T> DiagnosedSilenceableFailure &operator<<(T &&value) & {
+    assert(isSilenceableFailure() &&
            "can only append output in silencable failure state");
     *diagnostic << std::forward<T>(value);
     return *this;
   }
-  template <typename T> DiagnosedSilencableFailure &&operator<<(T &&value) && {
+  template <typename T> DiagnosedSilenceableFailure &&operator<<(T &&value) && {
     return std::move(this->operator<<(std::forward<T>(value)));
   }
 
   /// Attaches a note to the diagnostic. Expects this object to be a silencable
   /// failure.
   Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
-    assert(isSilencableFailure() &&
+    assert(isSilenceableFailure() &&
            "can only attach notes to silencable failures");
     return diagnostic->attachNote(loc);
   }
 
 private:
-  explicit DiagnosedSilencableFailure(Diagnostic &&diagnostic)
+  explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
       : diagnostic(std::move(diagnostic)), result(failure()) {}
 
   /// The diagnostic associated with this object. If present, the object is
@@ -226,7 +226,7 @@ class TransformState {
 
   /// Applies the transformation specified by the given transform op and updates
   /// the state accordingly.
-  DiagnosedSilencableFailure applyTransform(TransformOpInterface transform);
+  DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
 
   /// Records the mapping between a block argument in the transform IR and a
   /// list of operations in the payload IR. The arguments must be defined in
@@ -524,7 +524,7 @@ namespace detail {
 /// the payload IR, depending on what is available in the context.
 LogicalResult
 mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
-                                             Operation *op, unsigned region);
+                                             Operation *op, Region &region);
 
 /// Verification hook for PossibleTopLevelTransformOpTrait.
 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
@@ -562,10 +562,18 @@ class PossibleTopLevelTransformOpTrait
   /// and the relevant list of Payload IR operations in the given state. The
   /// state is expected to be already scoped at the region of this operation.
   /// Returns failure if the mapping failed, e.g., the value is already mapped.
-  LogicalResult mapBlockArguments(TransformState &state, unsigned region = 0) {
+  LogicalResult mapBlockArguments(TransformState &state, Region &region) {
+    assert(region.getParentOp() == this->getOperation() &&
+           "op comes from the wrong region");
     return detail::mapPossibleTopLevelTransformOpBlockArguments(
         state, this->getOperation(), region);
   }
+  LogicalResult mapBlockArguments(TransformState &state) {
+    assert(
+        this->getOperation()->getNumRegions() == 1 &&
+        "must indicate the region to map if the operation has more than one");
+    return mapBlockArguments(state, this->getOperation()->getRegion(0));
+  }
 };
 
 /// Trait implementing the TransformOpInterface for operations applying a
@@ -586,8 +594,8 @@ class TransformEachOpTrait
   /// Calls `applyToOne` for every payload operation associated with the operand
   /// of this transform IR op. If `applyToOne` returns ops, associates them with
   /// the result of this transform op.
-  DiagnosedSilencableFailure apply(TransformResults &transformResults,
-                                   TransformState &state);
+  DiagnosedSilenceableFailure apply(TransformResults &transformResults,
+                                    TransformState &state);
 
   /// Checks that the op matches the expectations of this trait.
   static LogicalResult verifyTrait(Operation *op);
@@ -738,7 +746,7 @@ appendTransformResultToVector(Ty result,
 ///   `targets` contains operations of the same class and a silencable failure
 ///   is reported if it does not.
 template <typename FnTy>
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 applyTransformToEach(ArrayRef<Operation *> targets,
                      SmallVectorImpl<Operation *> &results, FnTy transform) {
   using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
@@ -753,21 +761,21 @@ applyTransformToEach(ArrayRef<Operation *> targets,
     if (!specificOp) {
       Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
       diag << "attempted to apply transform to the wrong op kind";
-      return DiagnosedSilencableFailure::silencableFailure(std::move(diag));
+      return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
     }
 
     auto result = transform(specificOp);
     if (failed(appendTransformResultToVector(result, results)))
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
   }
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 } // namespace detail
 } // namespace transform
 } // namespace mlir
 
 template <typename OpTy>
-mlir::DiagnosedSilencableFailure
+mlir::DiagnosedSilenceableFailure
 mlir::transform::TransformEachOpTrait<OpTy>::apply(
     TransformResults &transformResults, TransformState &state) {
   using TransformOpType = typename llvm::function_traits<
@@ -775,7 +783,7 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
   ArrayRef<Operation *> targets =
       state.getPayloadOps(this->getOperation()->getOperand(0));
   SmallVector<Operation *> results;
-  DiagnosedSilencableFailure result = detail::applyTransformToEach(
+  DiagnosedSilenceableFailure result = detail::applyTransformToEach(
       targets, results, [&](TransformOpType specificOp) {
         return static_cast<OpTy *>(this)->applyToOne(specificOp);
       });
@@ -786,7 +794,7 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
     transformResults.set(
         this->getOperation()->getResult(0).template cast<OpResult>(), results);
   }
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 template <typename OpTy>

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index b8b6a0aae1a65..3ce99c4a3678c 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -42,7 +42,7 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
         special status object indicating whether the transformation succeeded
         or failed, and, if it failed, whether the failure is recoverable or not.
       }],
-      /*returnType=*/"::mlir::DiagnosedSilencableFailure",
+      /*returnType=*/"::mlir::DiagnosedSilenceableFailure",
       /*name=*/"apply",
       /*arguments=*/(ins
           "::mlir::transform::TransformResults &":$transformResults,
@@ -64,9 +64,9 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
 
     /// Creates the silencable failure object with a diagnostic located at the
     /// current operation.
-    DiagnosedSilencableFailure emitSilencableError() {
+    DiagnosedSilenceableFailure emitSilenceableError() {
       Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
-      return DiagnosedSilencableFailure::silencableFailure(std::move(diag));
+      return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
     }
   }];
 }

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index f415e0395080f..dd0e82e4ada79 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -44,13 +44,14 @@ def AlternativesOp : TransformDialectOp<"alternatives",
     The single operand of this operation is the scope in which the alternative
     transformation sequences are attempted, that is, an operation in the payload
     IR that contains all the other operations that may be modified by the
-    transformations. There is no check that the transforms are indeed scoped
-    as their "apply" methods can be arbitrarily complex. Therefore it is the
-    responsibility of the user to ensure that the transforms are scoped
-    correctly, or to produce an irrecoverable error and thus abort the execution
-    without attempting the remaining alternatives. Note that the payload IR
-    outside of the given scope is not necessarily in the valid state, or even
-    accessible to the tranfsormation.
+    transformations. The scope operation must be isolated from above. There is
+    no check that the transforms are indeed scoped as their "apply" methods can
+    be arbitrarily complex. Therefore it is the responsibility of the user to
+    ensure that the transforms are scoped correctly, or to produce an
+    irrecoverable error and thus abort the execution without attempting the
+    remaining alternatives. Note that the payload IR outside of the given scope
+    is not necessarily in the valid state, or even accessible to the
+    transformation.
     
     The changes to the IR within the scope performed by transforms in the failed
     alternative region are reverted before attempting the next region.
@@ -72,8 +73,8 @@ def AlternativesOp : TransformDialectOp<"alternatives",
     ```mlir
     %result = transform.alternatives %scope {
     ^bb0(%arg0: !pdl.operation):
-      // Try a failible transformation.
-      %0 = transform.failible %arg0 // ...
+      // Try a fallible transformation.
+      %0 = transform.fallible %arg0 // ...
       // If succeeded, yield the the result of the transformation.
       transform.yield %0 : !pdl.operation
     }, {

diff  --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index c1f532c89e9f0..66144cf073b4b 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -24,7 +24,7 @@ using namespace mlir::transform;
 // OneShotBufferizeOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
                                      TransformState &state) {
   OneShotBufferizationOptions options;
@@ -39,19 +39,19 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
   for (Operation *target : payloadOps) {
     auto moduleOp = dyn_cast<ModuleOp>(target);
     if (getTargetIsModule() && !moduleOp)
-      return emitSilencableError() << "expected ModuleOp target";
+      return emitSilenceableError() << "expected ModuleOp target";
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
-        return emitSilencableError() << "expected ModuleOp target";
+        return emitSilenceableError() << "expected ModuleOp target";
       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
-        return emitSilencableError() << "bufferization failed";
+        return emitSilenceableError() << "bufferization failed";
     } else {
       if (failed(bufferization::runOneShotBufferize(target, options)))
-        return emitSilencableError() << "bufferization failed";
+        return emitSilenceableError() << "bufferization failed";
     }
   }
 
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 void transform::OneShotBufferizeOp::getEffects(

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d239cad4b7ad0..a5e865b4b96d7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -164,7 +164,7 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
   return success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                          mlir::transform::TransformState &state) {
   LinalgTilingAndFusionOptions fusionOptions;
@@ -188,8 +188,8 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                                tileLoopNest->getLoopOps().end()};
         return tiledLinalgOp;
       });
-  return failed(result) ? DiagnosedSilencableFailure::definiteFailure()
-                        : DiagnosedSilencableFailure::success();
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
 }
 
 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
@@ -398,7 +398,7 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
 // TileOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::TileOp::apply(TransformResults &transformResults,
                          TransformState &state) {
   LinalgTilingOptions tilingOptions;
@@ -415,7 +415,7 @@ transform::TileOp::apply(TransformResults &transformResults,
         SimpleRewriter rewriter(linalgOp.getContext());
         return pattern.returningMatchAndRewrite(linalgOp, rewriter);
       });
-  return DiagnosedSilencableFailure(result);
+  return DiagnosedSilenceableFailure(result);
 }
 
 ParseResult transform::TileOp::parse(OpAsmParser &parser,

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 6b39a35976616..f7821da7a53f1 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -31,7 +31,7 @@ class SimpleRewriter : public PatternRewriter {
 // GetParentForOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::GetParentForOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
   SetVector<Operation *> parents;
@@ -41,10 +41,10 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
       loop = current->getParentOfType<scf::ForOp>();
       if (!loop) {
-        DiagnosedSilencableFailure diag = emitSilencableError()
-                                          << "could not find an '"
-                                          << scf::ForOp::getOperationName()
-                                          << "' parent";
+        DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                           << "could not find an '"
+                                           << scf::ForOp::getOperationName()
+                                           << "' parent";
         diag.attachNote(target->getLoc()) << "target op";
         return diag;
       }
@@ -53,7 +53,7 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
     parents.insert(loop);
   }
   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -85,7 +85,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
   return executeRegionOp;
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::LoopOutlineOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
   SmallVector<Operation *> transformed;
@@ -96,8 +96,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     SimpleRewriter rewriter(getContext());
     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
     if (!exec) {
-      DiagnosedSilencableFailure diag = emitSilencableError()
-                                        << "failed to outline";
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                         << "failed to outline";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
@@ -107,7 +107,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
 
     if (failed(outlined)) {
       (void)reportUnknownTransformError(target);
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
     }
 
     if (symbolTableOp) {
@@ -120,7 +120,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     transformed.push_back(*outlined);
   }
   results.set(getTransformed().cast<OpResult>(), transformed);
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ad6935fdc71f5..ecf1cbe8aa3ed 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -188,16 +188,16 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
   return success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::TransformState::applyTransform(TransformOpInterface transform) {
   LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
   if (options.getExpensiveChecksEnabled() &&
       failed(checkAndRecordHandleInvalidation(transform))) {
-    return DiagnosedSilencableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::definiteFailure();
   }
 
   transform::TransformResults results(transform->getNumResults());
-  DiagnosedSilencableFailure result(transform.apply(results, *this));
+  DiagnosedSilenceableFailure result(transform.apply(results, *this));
   if (!result.succeeded())
     return result;
 
@@ -223,10 +223,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
            "payload IR association for a value other than the result of the "
            "current transform op");
     if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
   }
 
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -277,15 +277,14 @@ transform::TransformResults::get(unsigned resultNumber) const {
 //===----------------------------------------------------------------------===//
 
 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
-    TransformState &state, Operation *op, unsigned region) {
+    TransformState &state, Operation *op, Region &region) {
   SmallVector<Operation *> targets;
   if (op->getNumOperands() != 0)
     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
   else
     targets.push_back(state.getTopLevel());
 
-  return state.mapBlockArguments(op->getRegion(region).front().getArgument(0),
-                                 targets);
+  return state.mapBlockArguments(region.front().getArgument(0), targets);
 }
 
 LogicalResult

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3037cf65b3a01..d071c8eea26cf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -164,7 +164,7 @@ static void forwardTerminatorOperands(Block *block,
   }
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::AlternativesOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
   SmallVector<Operation *> originals;
@@ -178,7 +178,14 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
       InFlightDiagnostic diag =
           emitError() << "scope must not contain the transforms being applied";
       diag.attachNote(original->getLoc()) << "scope";
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      InFlightDiagnostic diag =
+          emitError()
+          << "only isolated-from-above ops can be alternative scopes";
+      diag.attachNote(original->getLoc()) << "scope";
+      return DiagnosedSilenceableFailure(std::move(diag));
     }
   }
 
@@ -190,18 +197,18 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
     auto scope = state.make_region_scope(reg);
     auto clones = llvm::to_vector(
         llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
-    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
-      return DiagnosedSilencableFailure::definiteFailure();
     auto deleteClones = llvm::make_scope_exit([&] {
       for (Operation *clone : clones)
         clone->erase();
     });
+    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
+      return DiagnosedSilenceableFailure::definiteFailure();
 
     bool failed = false;
     for (Operation &transform : reg.front().without_terminator()) {
-      DiagnosedSilencableFailure result =
+      DiagnosedSilenceableFailure result =
           state.applyTransform(cast<TransformOpInterface>(transform));
-      if (result.isSilencableFailure()) {
+      if (result.isSilenceableFailure()) {
         LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                           << "\n");
         failed = true;
@@ -209,7 +216,7 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
       }
 
       if (::mlir::failed(result.silence()))
-        return DiagnosedSilencableFailure::definiteFailure();
+        return DiagnosedSilenceableFailure::definiteFailure();
     }
 
     // If all operations in the given alternative succeeded, no need to consider
@@ -227,10 +234,10 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
         rewriter.replaceOp(original, clone->getResults());
       }
       forwardTerminatorOperands(&reg.front(), state, results);
-      return DiagnosedSilencableFailure::success();
+      return DiagnosedSilenceableFailure::success();
     }
   }
-  return emitSilencableError() << "all alternatives failed";
+  return emitSilenceableError() << "all alternatives failed";
 }
 
 LogicalResult transform::AlternativesOp::verify() {
@@ -260,15 +267,15 @@ LogicalResult transform::AlternativesOp::verify() {
 // GetClosestIsolatedParentOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure transform::GetClosestIsolatedParentOp::apply(
+DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   SetVector<Operation *> parents;
   for (Operation *target : state.getPayloadOps(getTarget())) {
     Operation *parent =
         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
     if (!parent) {
-      DiagnosedSilencableFailure diag =
-          emitSilencableError()
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
           << "could not find an isolated-from-above parent op";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
@@ -276,14 +283,14 @@ DiagnosedSilencableFailure transform::GetClosestIsolatedParentOp::apply(
     parents.insert(parent);
   }
   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::PDLMatchOp::apply(transform::TransformResults &results,
                              transform::TransformState &state) {
   auto *extension = state.getExtension<PatternApplicatorExtension>();
@@ -294,28 +301,28 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
     if (failed(extension->findAllMatches(
             getPatternName().getLeafReference().getValue(), root, targets))) {
       emitOpError() << "could not find pattern '" << getPatternName() << "'";
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
     }
   }
   results.set(getResult().cast<OpResult>(), targets);
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
 // SequenceOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::SequenceOp::apply(transform::TransformResults &results,
                              transform::TransformState &state) {
   // Map the entry block argument to the list of operations.
   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
   if (failed(mapBlockArguments(state)))
-    return DiagnosedSilencableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::definiteFailure();
 
   // Apply the sequenced ops one by one.
   for (Operation &transform : getBodyBlock()->without_terminator()) {
-    DiagnosedSilencableFailure result =
+    DiagnosedSilenceableFailure result =
         state.applyTransform(cast<TransformOpInterface>(transform));
     if (!result.succeeded())
       return result;
@@ -324,7 +331,7 @@ transform::SequenceOp::apply(transform::TransformResults &results,
   // Forward the operation mapping for values yielded from the sequence to the
   // values produced by the sequence op.
   forwardTerminatorOperands(getBodyBlock(), state, results);
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 /// Returns `true` if the given op operand may be consuming the handle value in
@@ -486,7 +493,7 @@ void transform::SequenceOp::getRegionInvocationBounds(
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
                                     transform::TransformState &state) {
   OwningOpRef<ModuleOp> pdlModuleOp =
@@ -505,7 +512,7 @@ transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
 
   auto scope = state.make_region_scope(getBody());
   if (failed(mapBlockArguments(state)))
-    return DiagnosedSilencableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::definiteFailure();
   return state.applyTransform(transformOp);
 }
 

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 22cfc009af17c..45e01e82ab362 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -348,3 +348,33 @@ module {
   }
 }
 
+// -----
+
+func.func @foo(%arg0: index, %arg1: index, %arg2: index) {
+  // expected-note @below {{scope}}
+  scf.for %i = %arg0 to %arg1 step %arg2 {
+    %0 = arith.constant 0 : i32
+  }
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_const : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "arith.constant"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @match_const in %arg1
+    %1 = transform.loop.get_parent_for %0
+    // expected-error @below {{only isolated-from-above ops can be alternative scopes}}
+    alternatives %1 {
+    ^bb2(%arg2: !pdl.operation):
+    }
+  }
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 81cba585ffc60..59052189bbf87 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -38,13 +38,13 @@ class TestTransformOp
     return llvm::StringLiteral("transform.test_transform_op");
   }
 
-  DiagnosedSilencableFailure apply(transform::TransformResults &results,
-                                   transform::TransformState &state) {
+  DiagnosedSilenceableFailure apply(transform::TransformResults &results,
+                                    transform::TransformState &state) {
     InFlightDiagnostic remark = emitRemark() << "applying transformation";
     if (Attribute message = getMessage())
       remark << " " << message;
 
-    return DiagnosedSilencableFailure::success();
+    return DiagnosedSilenceableFailure::success();
   }
 
   Attribute getMessage() { return getOperation()->getAttr("message"); }
@@ -91,9 +91,9 @@ class TestTransformUnrestrictedOpNoInterface
         "transform.test_transform_unrestricted_op_no_interface");
   }
 
-  DiagnosedSilencableFailure apply(transform::TransformResults &results,
-                                   transform::TransformState &state) {
-    return DiagnosedSilencableFailure::success();
+  DiagnosedSilenceableFailure apply(transform::TransformResults &results,
+                                    transform::TransformState &state) {
+    return DiagnosedSilenceableFailure::success();
   }
 
   // No side effects.
@@ -101,7 +101,7 @@ class TestTransformUnrestrictedOpNoInterface
 };
 } // namespace
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestProduceParamOrForwardOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   if (getOperation()->getNumOperands() != 0) {
@@ -111,7 +111,7 @@ mlir::test::TestProduceParamOrForwardOperandOp::apply(
     results.set(getResult().cast<OpResult>(),
                 reinterpret_cast<Operation *>(*getParameter()));
   }
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
@@ -120,50 +120,51 @@ LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
   return success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
                                       transform::TransformState &state) {
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
   assert(payload.size() == 1 && "expected a single target op");
   auto value = reinterpret_cast<intptr_t>(payload[0]);
   if (static_cast<uint64_t>(value) != getParameter()) {
-    return emitSilencableError()
+    return emitSilenceableError()
            << "op expected the operand to be associated with " << getParameter()
            << " got " << value;
   }
 
   emitRemark() << "succeeded";
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
+DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
   for (Operation *op : payload)
     op->emitRemark() << getMessage();
 
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
                                           transform::TransformState &state) {
   state.addExtension<TestTransformStateExtension>(getMessageAttr());
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply(
+DiagnosedSilenceableFailure
+mlir::test::TestCheckIfTestExtensionPresentOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
   if (!extension) {
     emitRemark() << "extension absent";
-    return DiagnosedSilencableFailure::success();
+    return DiagnosedSilenceableFailure::success();
   }
 
   InFlightDiagnostic diag = emitRemark()
@@ -175,54 +176,54 @@ DiagnosedSilencableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply(
            "operations");
   }
 
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
+DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
   if (!extension) {
     emitError() << "TestTransformStateExtension missing";
-    return DiagnosedSilencableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::definiteFailure();
   }
 
   if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
                                       getOperation())))
-    return DiagnosedSilencableFailure::definiteFailure();
-  return DiagnosedSilencableFailure::success();
+    return DiagnosedSilenceableFailure::definiteFailure();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure mlir::test::TestRemoveTestExtensionOp::apply(
+DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   state.removeExtension<TestTransformStateExtension>();
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
-DiagnosedSilencableFailure mlir::test::TestTransformOpWithRegions::apply(
+DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
     transform::TransformResults &results, transform::TransformState &state) {
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 void mlir::test::TestTransformOpWithRegions::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestBranchingTransformOpTerminator::apply(
     transform::TransformResults &results, transform::TransformState &state) {
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 void mlir::test::TestBranchingTransformOpTerminator::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 
-DiagnosedSilencableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
+DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   emitRemark() << getRemark();
   for (Operation *op : state.getPayloadOps(getTarget()))
     op->erase();
 
   if (getFailAfterErase())
-    return emitSilencableError() << "silencable error";
-  return DiagnosedSilencableFailure::success();
+    return emitSilenceableError() << "silencable error";
+  return DiagnosedSilenceableFailure::success();
 }
 
 namespace {


        


More information about the Mlir-commits mailing list