[Mlir-commits] [mlir] 7d5bef7 - [mlir] make DiagnosedSilenceableFailure(LogicalResult) ctor private

Alex Zinenko llvmlistbot at llvm.org
Mon Dec 12 04:52:12 PST 2022


Author: Alex Zinenko
Date: 2022-12-12T12:52:06Z
New Revision: 7d5bef77e560f172ebd13471eb24c4cb6063b568

URL: https://github.com/llvm/llvm-project/commit/7d5bef77e560f172ebd13471eb24c4cb6063b568
DIFF: https://github.com/llvm/llvm-project/commit/7d5bef77e560f172ebd13471eb24c4cb6063b568.diff

LOG: [mlir] make DiagnosedSilenceableFailure(LogicalResult) ctor private

Now that there are more convenient functions to construct silenceable
failures while emitting diagnostics, make the LogicalResult constructor
private: it is ambiguous because it does not say whether a failed
LogicalResult should be treated as a silenceable or a definite failure.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137257

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index bbcfabe754aa3..99e12a1067dd8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -34,7 +34,6 @@ namespace mlir {
 /// failures as their diagnostics have been already reported to the user.
 class [[nodiscard]] DiagnosedSilenceableFailure {
 public:
-  explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
   DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete;
   DiagnosedSilenceableFailure &
   operator=(const DiagnosedSilenceableFailure &) = delete;
@@ -156,6 +155,7 @@ class [[nodiscard]] DiagnosedSilenceableFailure {
   }
 
 private:
+  explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
   explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
       : result(failure()) {
     diagnostics.emplace_back(std::move(diagnostic));

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index fe29f303a630a..33781536239e9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -51,23 +51,12 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
   ];
 
   let extraSharedClassDeclaration = [{
-    /// Emits a generic transform error for the current transform operation
-    /// targeting the given Payload IR operation and returns failure. Should
-    /// be only used as a last resort when the transformation itself provides
-    /// no further indication as to the reason of the failure.
-    ::mlir::LogicalResult reportUnknownTransformError(
-        ::mlir::Operation *target) {
-      ::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply";
-      diag.attachNote(target->getLoc()) << "attempted to apply to this op";
-      return diag;
-    }
-
     /// Creates the silenceable failure object with a diagnostic located at the
     /// current operation. Silenceable failure must be suppressed or reported
     /// explicitly at some later time.
     DiagnosedSilenceableFailure
     emitSilenceableError(const ::llvm::Twine &message = {}) {
-      return ::mlir::emitSilenceableFailure($_op);
+      return ::mlir::emitSilenceableFailure($_op, message);
     }
 
     /// Creates the definite failure object with a diagnostic located at the
@@ -78,6 +67,17 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
       return ::mlir::emitDefiniteFailure($_op, message);
     }
 
+    /// Emits a generic definite failure for the current transform operation
+    /// targeting the given Payload IR operation and returns failure. Should
+    /// be only used as a last resort when the transformation itself provides
+    /// no further indication as to the reason of the failure.
+    DiagnosedDefiniteFailure emitDefaultDefiniteFailure(
+        ::mlir::Operation *target) {
+      auto diag = ::mlir::emitDefiniteFailure($_op, "failed to apply");
+      diag.attachNote(target->getLoc()) << "attempted to apply to this op";
+      return diag;
+    }
+
     /// Creates the default silenceable failure for a transform op that failed
     /// to properly apply to a target.
     DiagnosedSilenceableFailure emitDefaultSilenceableFailure(

diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index baf18dc4d329b..57cce1942803f 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -119,7 +119,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
                                        blkSizeX, blkSizeY, blkSizeZ);
   rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
   rewriter.create<TerminatorOp>(loc);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 /// Alter kernel configuration of the given kernel.

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3ae4163b5cc6a..c8dd269029026 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -79,20 +79,20 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                                      Conv1DNwcWcfOp>>(target);
   if (succeeded(windowedNhwc)) {
     results.push_back(*windowedNhwc);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   FailureOr<LinalgOp> windowedNchw =
       tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                      Conv1DNcwFcwOp>>(target);
   if (succeeded(windowedNchw)) {
     results.push_back(*windowedNchw);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   FailureOr<LinalgOp> depthwise =
       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
   if (succeeded(depthwise)) {
     results.push_back(*depthwise);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
@@ -206,7 +206,8 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
         return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
             rewriter, tilingInterfaceOp, tileAndFuseOptions);
       });
-  return DiagnosedSilenceableFailure(result);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
 }
 
 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
@@ -568,12 +569,12 @@ transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
   // Exit early if no transformation is needed.
   if (isa<GenericOp>(target)) {
     results.push_back(target);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
   if (succeeded(generic)) {
     results.push_back(generic->getOperation());
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
@@ -592,7 +593,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
   // Exit early if no transformation is needed.
   if (interchangeVector.empty()) {
     results.push_back(target);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   TrivialPatternRewriter rewriter(target->getContext());
   FailureOr<GenericOp> res =
@@ -600,7 +601,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
   if (failed(res))
     return DiagnosedSilenceableFailure::definiteFailure();
   results.push_back(res->getOperation());
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 LogicalResult transform::InterchangeOp::verify() {
@@ -639,8 +640,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
   if (payloadOps.size() != 1) {
     results.set(getResult().cast<OpResult>(), {});
-    return DiagnosedSilenceableFailure(
-        this->emitOpError("requires exactly one target handle"));
+    return emitDefiniteFailure("requires exactly one target handle");
   }
 
   SmallVector<Operation *> res;
@@ -687,7 +687,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
 
   payloadOps.front()->walk(matchFun);
   results.set(getResult().cast<OpResult>(), res);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===---------------------------------------------------------------------===//
@@ -792,7 +792,7 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
       tryApply<LinalgPaddingPattern>(target, paddingOptions);
   if (succeeded(result)) {
     results.push_back(result->getOperation());
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
 
   results.assign(1, nullptr);
@@ -866,15 +866,15 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target,
     promotionOptions = promotionOptions.setAlignment(*getAlignment());
 
   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
 
   TrivialPatternRewriter rewriter(target->getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
   if (failed(res))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
   results.push_back(target);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -909,7 +909,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
     replacements.push_back(replacement);
   }
   transformResults.set(getReplacement().cast<OpResult>(), replacements);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 void transform::ReplaceOp::getEffects(
@@ -972,10 +972,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
   FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
       rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
   if (failed(maybeTilingResult))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
 
   results.append(maybeTilingResult->tiledOps);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1171,13 +1171,13 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
           : splitReduction(rewriter, target, splitFn, getUseAlloc());
   if (failed(splitResult))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
 
   results.push_back(splitResult->initOrAlloc);
   results.push_back(splitResult->fillOp);
   results.push_back(splitResult->splitLinalgOp);
   results.push_back(splitResult->resultCombiningLinalgOp);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1200,12 +1200,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
       sizes);
 
   if (failed(result))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultSilenceableFailure(target);
   results.push_back(result->loops.front());
   results.push_back(result->initialOp);
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1235,7 +1235,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
   results.push_back(result->initialOp);
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1523,7 +1523,7 @@ static DiagnosedSilenceableFailure unpackPDLOperations(
     }
   }
 
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
@@ -1533,7 +1533,7 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> mapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
   if (targets.empty())
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
 
   // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
   // Convert to OpFoldResults[index attributes or payload op].
@@ -1577,7 +1577,7 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     tileOps.push_back(tilingResult->tileOp);
     tiledOps.push_back(tilingResult->tiledOp);
   }
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
@@ -1604,7 +1604,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
   transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
   transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
 
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 void transform::TileToForeachThreadOp::getEffects(
@@ -1852,10 +1852,10 @@ transform::VectorizeOp::applyToOne(Operation *target,
     linalg::populatePadOpVectorizationPatterns(patterns);
 
   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
 
   results.push_back(target);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 1d7a8b74ebe56..391164c76a07f 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -33,7 +33,7 @@ transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target,
   }
 
   results.push_back(newBuffer.value());
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 8777662ec3902..21deab6bc2a06 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -103,10 +103,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
         rewriter, location, exec.getRegion(), getFuncName(), &call);
 
-    if (failed(outlined)) {
-      (void)reportUnknownTransformError(target);
-      return DiagnosedSilenceableFailure::definiteFailure();
-    }
+    if (failed(outlined))
+      return emitDefaultDefiniteFailure(target);
 
     if (symbolTableOp) {
       SymbolTable &symbolTable =
@@ -139,7 +137,7 @@ transform::LoopPeelOp::applyToOne(scf::ForOp target,
       scf::peelAndCanonicalizeForLoop(rewriter, target, result);
   // TODO: Return both the peeled loop and the remainder loop.
   results.push_back(failed(status) ? target : result);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -200,7 +198,7 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
       pattern.returningMatchAndRewrite(target, rewriter);
   if (succeeded(patternResult)) {
     results.push_back(*patternResult);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
@@ -225,7 +223,7 @@ transform::LoopUnrollOp::applyToOne(Operation *op,
     diag << "Op failed to unroll";
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list