[Mlir-commits] [mlir] 60fbd60 - Revert "[mlir][Transforms] Encapsulate dialect conversion options in `ConversionConfig` (#83662)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 2 14:41:44 PST 2024


Author: Mehdi Amini
Date: 2024-03-02T14:41:40-08:00
New Revision: 60fbd6050107875956960c3ce35cf94b202d8675

URL: https://github.com/llvm/llvm-project/commit/60fbd6050107875956960c3ce35cf94b202d8675
DIFF: https://github.com/llvm/llvm-project/commit/60fbd6050107875956960c3ce35cf94b202d8675.diff

LOG: Revert "[mlir][Transforms] Encapsulate dialect conversion options in `ConversionConfig` (#83662)

This reverts commit 5f1319bb385342c7ef4124b05b83b89ef8588ee8.

A FIR test is broken on Windows

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 84396529eb7c2e..88eefa69a8003f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -24,7 +24,6 @@ namespace mlir {
 // Forward declarations.
 class Attribute;
 class Block;
-struct ConversionConfig;
 class ConversionPatternRewriter;
 class MLIRContext;
 class Operation;
@@ -768,8 +767,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Conversion pattern rewriters must not be used outside of dialect
   /// conversions. They apply some IR rewrites in a delayed fashion and could
   /// bring the IR into an inconsistent state when used standalone.
-  explicit ConversionPatternRewriter(MLIRContext *ctx,
-                                     const ConversionConfig &config);
+  explicit ConversionPatternRewriter(MLIRContext *ctx);
 
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
@@ -1069,30 +1067,6 @@ class PDLConversionConfig final {
 
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
-//===----------------------------------------------------------------------===//
-// ConversionConfig
-//===----------------------------------------------------------------------===//
-
-/// Dialect conversion configuration.
-struct ConversionConfig {
-  /// An optional callback used to notify about match failure diagnostics during
-  /// the conversion. Diagnostics reported to this callback may only be
-  /// available in debug mode.
-  function_ref<void(Diagnostic &)> notifyCallback = nullptr;
-
-  /// Partial conversion only. All operations that are found not to be
-  /// legalizable are placed in this set. (Note that if there is an op
-  /// explicitly marked as illegal, the conversion terminates and the set will
-  /// not necessarily be complete.)
-  DenseSet<Operation *> *unlegalizedOps = nullptr;
-
-  /// Analysis conversion only. All operations that are found to be legalizable
-  /// are placed in this set. Note that no actual rewrites are applied to the
-  /// IR during an analysis conversion and only pre-existing operations are
-  /// added to the set.
-  DenseSet<Operation *> *legalizableOps = nullptr;
-};
-
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
@@ -1106,16 +1080,20 @@ struct ConversionConfig {
 /// Apply a partial conversion on the given operations and all nested
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal.
+/// returns failure if there ops explicitly marked as illegal. If an
+/// `unconvertedOps` set is provided, all operations that are found not to be
+/// legalizable to the given `target` are placed within that set. (Note that if
+/// there is an op explicitly marked as illegal, the conversion terminates and
+/// the `unconvertedOps` set will not necessarily be complete.)
 LogicalResult
 applyPartialConversion(ArrayRef<Operation *> ops,
                        const ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
-                       ConversionConfig config = ConversionConfig());
+                       DenseSet<Operation *> *unconvertedOps = nullptr);
 LogicalResult
 applyPartialConversion(Operation *op, const ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
-                       ConversionConfig config = ConversionConfig());
+                       DenseSet<Operation *> *unconvertedOps = nullptr);
 
 /// Apply a complete conversion on the given operations, and all nested
 /// operations. This method returns failure if the conversion of any operation
@@ -1123,27 +1101,31 @@ applyPartialConversion(Operation *op, const ConversionTarget &target,
 /// within 'ops'.
 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
                                   const ConversionTarget &target,
-                                  const FrozenRewritePatternSet &patterns,
-                                  ConversionConfig config = ConversionConfig());
+                                  const FrozenRewritePatternSet &patterns);
 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
-                                  const FrozenRewritePatternSet &patterns,
-                                  ConversionConfig config = ConversionConfig());
+                                  const FrozenRewritePatternSet &patterns);
 
 /// Apply an analysis conversion on the given operations, and all nested
 /// operations. This method analyzes which operations would be successfully
 /// converted to the target if a conversion was applied. All operations that
 /// were found to be legalizable to the given 'target' are placed within the
-/// provided 'config.legalizableOps' set; note that no actual rewrites are
-/// applied to the operations on success. This method only returns failure if
-/// there are unreachable blocks in any of the regions nested within 'ops'.
-LogicalResult
-applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
-                        const FrozenRewritePatternSet &patterns,
-                        ConversionConfig config = ConversionConfig());
-LogicalResult
-applyAnalysisConversion(Operation *op, ConversionTarget &target,
-                        const FrozenRewritePatternSet &patterns,
-                        ConversionConfig config = ConversionConfig());
+/// provided 'convertedOps' set; note that no actual rewrites are applied to the
+/// operations on success and only pre-existing operations are added to the set.
+/// This method only returns failure if there are unreachable blocks in any of
+/// the regions nested within 'ops'. There's an additional argument
+/// `notifyCallback` which is used for collecting match failure diagnostics
+/// generated during the conversion. Diagnostics are only reported to this
+/// callback may only be available in debug mode.
+LogicalResult applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+LogicalResult applyAnalysisConversion(
+    Operation *op, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 26899301eb742e..ffdb442033d323 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -230,8 +230,6 @@ class IRRewrite {
   /// Erase the given block (unless it was already erased).
   void eraseBlock(Block *block);
 
-  const ConversionConfig &getConfig() const;
-
   const Kind kind;
   ConversionPatternRewriterImpl &rewriterImpl;
 };
@@ -734,9 +732,8 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
-                                         const ConversionConfig &config)
-      : eraseRewriter(ctx), config(config) {}
+  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+      : eraseRewriter(rewriter.getContext()) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -936,8 +933,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// converting the arguments of blocks within that region.
   DenseMap<Region *, const TypeConverter *> regionToConverter;
 
-  /// Dialect conversion configuration.
-  const ConversionConfig &config;
+  /// This allows the user to collect the match failure message.
+  function_ref<void(Diagnostic &)> notifyCallback;
+
+  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
+  /// this is populated with ops found to be legalizable to the target.
+  /// When mode == OpConversionMode::Partial, this is populated with ops found
+  /// *not* to be legalizable to the target.
+  DenseSet<Operation *> *trackedOps = nullptr;
 
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
@@ -960,10 +963,6 @@ void IRRewrite::eraseBlock(Block *block) {
   rewriterImpl.eraseRewriter.eraseBlock(block);
 }
 
-const ConversionConfig &IRRewrite::getConfig() const {
-  return rewriterImpl.config;
-}
-
 void BlockTypeConversionRewrite::commit() {
   // Process the remapping for each of the original arguments.
   for (auto [origArg, info] :
@@ -1081,8 +1080,8 @@ void ReplaceOperationRewrite::commit() {
     if (Value newValue =
             rewriterImpl.mapping.lookupOrNull(result, result.getType()))
       result.replaceAllUsesWith(newValue);
-  if (getConfig().unlegalizedOps)
-    getConfig().unlegalizedOps->erase(op);
+  if (rewriterImpl.trackedOps)
+    rewriterImpl.trackedOps->erase(op);
   // Do not erase the operation yet. It may still be referenced in `mapping`.
   op->getBlock()->getOperations().remove(op);
 }
@@ -1505,8 +1504,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
-    if (config.notifyCallback)
-      config.notifyCallback(diag);
+    if (notifyCallback)
+      notifyCallback(diag);
   });
 }
 
@@ -1514,10 +1513,9 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
 
-ConversionPatternRewriter::ConversionPatternRewriter(
-    MLIRContext *ctx, const ConversionConfig &config)
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+      impl(new detail::ConversionPatternRewriterImpl(*this)) {
   setListener(impl.get());
 }
 
@@ -1986,12 +1984,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
-      if (rewriterImpl.config.notifyCallback) {
+      if (rewriterImpl.notifyCallback) {
         Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
         diag << "Failed to apply pattern \"" << pattern.getDebugName()
              << "\" on op:\n"
              << *op;
-        rewriterImpl.config.notifyCallback(diag);
+        rewriterImpl.notifyCallback(diag);
       }
     });
     rewriterImpl.resetState(curState);
@@ -2379,12 +2377,14 @@ namespace mlir {
 struct OperationConverter {
   explicit OperationConverter(const ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              const ConversionConfig &config,
-                              OpConversionMode mode)
-      : opLegalizer(target, patterns), config(config), mode(mode) {}
+                              OpConversionMode mode,
+                              DenseSet<Operation *> *trackedOps = nullptr)
+      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult convertOperations(ArrayRef<Operation *> ops);
+  LogicalResult
+  convertOperations(ArrayRef<Operation *> ops,
+                    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 
 private:
   /// Converts an operation with the given rewriter.
@@ -2421,11 +2421,14 @@ struct OperationConverter {
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
 
-  /// Dialect conversion configuration.
-  ConversionConfig config;
-
   /// The conversion mode to use when legalizing operations.
   OpConversionMode mode;
+
+  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
+  /// this is populated with ops found to be legalizable to the target.
+  /// When mode == OpConversionMode::Partial, this is populated with ops found
+  /// *not* to be legalizable to the target.
+  DenseSet<Operation *> *trackedOps;
 };
 } // namespace mlir
 
@@ -2439,27 +2442,28 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
       return op->emitError()
              << "failed to legalize operation '" << op->getName() << "'";
     // Partial conversions allow conversions to fail iff the operation was not
-    // explicitly marked as illegal. If the user provided a `unlegalizedOps`
-    // set, non-legalizable ops are added to that set.
+    // explicitly marked as illegal. If the user provided a nonlegalizableOps
+    // set, non-legalizable ops are included.
     if (mode == OpConversionMode::Partial) {
       if (opLegalizer.isIllegal(op))
         return op->emitError()
                << "failed to legalize operation '" << op->getName()
                << "' that was explicitly marked illegal";
-      if (config.unlegalizedOps)
-        config.unlegalizedOps->insert(op);
+      if (trackedOps)
+        trackedOps->insert(op);
     }
   } else if (mode == OpConversionMode::Analysis) {
     // Analysis conversions don't fail if any operations fail to legalize,
     // they are only interested in the operations that were successfully
     // legalized.
-    if (config.legalizableOps)
-      config.legalizableOps->insert(op);
+    trackedOps->insert(op);
   }
   return success();
 }
 
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult OperationConverter::convertOperations(
+    ArrayRef<Operation *> ops,
+    function_ref<void(Diagnostic &)> notifyCallback) {
   if (ops.empty())
     return success();
   const ConversionTarget &target = opLegalizer.getTarget();
@@ -2480,8 +2484,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   }
 
   // Convert each operation and discard rewrites on failure.
-  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
+  ConversionPatternRewriter rewriter(ops.front()->getContext());
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  rewriterImpl.notifyCallback = notifyCallback;
+  rewriterImpl.trackedOps = trackedOps;
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
@@ -3468,51 +3474,57 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
 //===----------------------------------------------------------------------===//
 // Partial Conversion
 
-LogicalResult mlir::applyPartialConversion(
-    ArrayRef<Operation *> ops, const ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns, ConversionConfig config) {
-  OperationConverter opConverter(target, patterns, config,
-                                 OpConversionMode::Partial);
+LogicalResult
+mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+                             const ConversionTarget &target,
+                             const FrozenRewritePatternSet &patterns,
+                             DenseSet<Operation *> *unconvertedOps) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
+                                 unconvertedOps);
   return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
                              const FrozenRewritePatternSet &patterns,
-                             ConversionConfig config) {
-  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
+                             DenseSet<Operation *> *unconvertedOps) {
+  return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
+                                unconvertedOps);
 }
 
 //===----------------------------------------------------------------------===//
 // Full Conversion
 
-LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
-                                        const ConversionTarget &target,
-                                        const FrozenRewritePatternSet &patterns,
-                                        ConversionConfig config) {
-  OperationConverter opConverter(target, patterns, config,
-                                 OpConversionMode::Full);
+LogicalResult
+mlir::applyFullConversion(ArrayRef<Operation *> ops,
+                          const ConversionTarget &target,
+                          const FrozenRewritePatternSet &patterns) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
   return opConverter.convertOperations(ops);
 }
-LogicalResult mlir::applyFullConversion(Operation *op,
-                                        const ConversionTarget &target,
-                                        const FrozenRewritePatternSet &patterns,
-                                        ConversionConfig config) {
-  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
+LogicalResult
+mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
+                          const FrozenRewritePatternSet &patterns) {
+  return applyFullConversion(llvm::ArrayRef(op), target, patterns);
 }
 
 //===----------------------------------------------------------------------===//
 // Analysis Conversion
 
-LogicalResult mlir::applyAnalysisConversion(
-    ArrayRef<Operation *> ops, ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns, ConversionConfig config) {
-  OperationConverter opConverter(target, patterns, config,
-                                 OpConversionMode::Analysis);
-  return opConverter.convertOperations(ops);
+LogicalResult
+mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
+                              ConversionTarget &target,
+                              const FrozenRewritePatternSet &patterns,
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
+                                 &convertedOps);
+  return opConverter.convertOperations(ops, notifyCallback);
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              ConversionConfig config) {
-  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
+  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
+                                 convertedOps, notifyCallback);
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index abc0e43c7b7f2d..157bfcc1eb23be 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1152,10 +1152,8 @@ struct TestLegalizePatternDriver
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
-      ConversionConfig config;
-      config.unlegalizedOps = &unlegalizedOps;
-      if (failed(applyPartialConversion(getOperation(), target,
-                                        std::move(patterns), config))) {
+      if (failed(applyPartialConversion(
+              getOperation(), target, std::move(patterns), &unlegalizedOps))) {
         getOperation()->emitRemark() << "applyPartialConversion failed";
       }
       // Emit remarks for each legalizable operation.
@@ -1183,10 +1181,8 @@ struct TestLegalizePatternDriver
 
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
-    ConversionConfig config;
-    config.legalizableOps = &legalizedOps;
     if (failed(applyAnalysisConversion(getOperation(), target,
-                                       std::move(patterns), config)))
+                                       std::move(patterns), legalizedOps)))
       return signalPassFailure();
 
     // Emit remarks for each legalizable operation.
@@ -1809,10 +1805,8 @@ struct TestMergeBlocksPatternDriver
         });
 
     DenseSet<Operation *> unlegalizedOps;
-    ConversionConfig config;
-    config.unlegalizedOps = &unlegalizedOps;
     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
-                                 config);
+                                 &unlegalizedOps);
     for (auto *op : unlegalizedOps)
       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
   }


        


More information about the Mlir-commits mailing list