[llvm-branch-commits] [mlir] [mlir][Transforms] Encapsulate dialect conversion options in `ConversionConfig` (PR #82250)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 19 06:03:40 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/82250

This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver.

A few existing options are moved to this object, simplifying the dialect conversion API.

>From 819e5f95ed0857e88972501cb10b9931b1e91a1c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 19 Feb 2024 14:02:14 +0000
Subject: [PATCH] [mlir][Transforms] Encapsulate dialect conversion options in
 `ConversionConfig`

This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver.

A few existing options are moved to this object, simplifying the dialect conversion API.
---
 .../mlir/Transforms/DialectConversion.h       |  75 ++++++----
 .../Transforms/Utils/DialectConversion.cpp    | 129 +++++++++---------
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  14 +-
 3 files changed, 118 insertions(+), 100 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5c91a9498b35d4..8da5dcb0be3fd0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -24,6 +24,7 @@ namespace mlir {
 // Forward declarations.
 class Attribute;
 class Block;
+struct ConversionConfig;
 class ConversionPatternRewriter;
 class MLIRContext;
 class Operation;
@@ -770,7 +771,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Conversion pattern rewriters must not be used outside of dialect
   /// conversions. They apply some IR rewrites in a delayed fashion and could
   /// bring the IR into an inconsistent state when used standalone.
-  explicit ConversionPatternRewriter(MLIRContext *ctx);
+  explicit ConversionPatternRewriter(MLIRContext *ctx,
+                                     const ConversionConfig &config);
 
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
@@ -1070,6 +1072,30 @@ class PDLConversionConfig final {
 
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
+//===----------------------------------------------------------------------===//
+// ConversionConfig
+//===----------------------------------------------------------------------===//
+
+/// Dialect conversion configuration.
+struct ConversionConfig {
+  /// An optional callback used to notify about match failure diagnostics during
+  /// the conversion. Diagnostics reported to this callback may only be
+  /// available in debug mode.
+  function_ref<void(Diagnostic &)> notifyCallback = nullptr;
+
+  /// Partial conversion only. All operations that are found not to be
+  /// legalizable are placed in this set. (Note that if there is an op
+  /// explicitly marked as illegal, the conversion terminates and the set will
+  /// not necessarily be complete.)
+  DenseSet<Operation *> *unlegalizedOps = nullptr;
+
+  /// Analysis conversion only. All operations that are found to be legalizable
+  /// are placed in this set. Note that no actual rewrites are applied to the
+  /// IR during an analysis conversion and only pre-existing operations are
+  /// added to the set.
+  DenseSet<Operation *> *legalizableOps = nullptr;
+};
+
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
@@ -1083,19 +1109,16 @@ class PDLConversionConfig final {
 /// Apply a partial conversion on the given operations and all nested
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If an
-/// `unconvertedOps` set is provided, all operations that are found not to be
-/// legalizable to the given `target` are placed within that set. (Note that if
-/// there is an op explicitly marked as illegal, the conversion terminates and
-/// the `unconvertedOps` set will not necessarily be complete.)
+/// returns failure if there are ops explicitly marked as illegal.
 LogicalResult
-applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
+applyPartialConversion(ArrayRef<Operation *> ops,
+                       const ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
-                       DenseSet<Operation *> *unconvertedOps = nullptr);
+                       ConversionConfig config = ConversionConfig());
 LogicalResult
 applyPartialConversion(Operation *op, const ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
-                       DenseSet<Operation *> *unconvertedOps = nullptr);
+                       ConversionConfig config = ConversionConfig());
 
 /// Apply a complete conversion on the given operations, and all nested
 /// operations. This method returns failure if the conversion of any operation
@@ -1103,31 +1126,27 @@ applyPartialConversion(Operation *op, const ConversionTarget &target,
 /// within 'ops'.
 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
                                   const ConversionTarget &target,
-                                  const FrozenRewritePatternSet &patterns);
+                                  const FrozenRewritePatternSet &patterns,
+                                  ConversionConfig config = ConversionConfig());
 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
-                                  const FrozenRewritePatternSet &patterns);
+                                  const FrozenRewritePatternSet &patterns,
+                                  ConversionConfig config = ConversionConfig());
 
 /// Apply an analysis conversion on the given operations, and all nested
 /// operations. This method analyzes which operations would be successfully
 /// converted to the target if a conversion was applied. All operations that
 /// were found to be legalizable to the given 'target' are placed within the
-/// provided 'convertedOps' set; note that no actual rewrites are applied to the
-/// operations on success and only pre-existing operations are added to the set.
-/// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops'. There's an additional argument
-/// `notifyCallback` which is used for collecting match failure diagnostics
-/// generated during the conversion. Diagnostics are only reported to this
-/// callback may only be available in debug mode.
-LogicalResult applyAnalysisConversion(
-    ArrayRef<Operation *> ops, ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns,
-    DenseSet<Operation *> &convertedOps,
-    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
-LogicalResult applyAnalysisConversion(
-    Operation *op, ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns,
-    DenseSet<Operation *> &convertedOps,
-    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+/// provided 'config.legalizableOps' set; note that no actual rewrites are
+/// applied to the operations on success. This method only returns failure if
+/// there are unreachable blocks in any of the regions nested within 'ops'.
+LogicalResult
+applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+                        const FrozenRewritePatternSet &patterns,
+                        ConversionConfig config = ConversionConfig());
+LogicalResult
+applyAnalysisConversion(Operation *op, ConversionTarget &target,
+                        const FrozenRewritePatternSet &patterns,
+                        ConversionConfig config = ConversionConfig());
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6cf178e149be7f..30fc2298b3deb3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -224,6 +224,8 @@ class IRRewrite {
   /// Erase the given block (unless it was already erased).
   void eraseBlock(Block *block);
 
+  const ConversionConfig &getConfig() const;
+
   const Kind kind;
   ConversionPatternRewriterImpl &rewriterImpl;
 };
@@ -723,9 +725,10 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+                                         const ConversionConfig &config)
       : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
-        notifyCallback(nullptr) {}
+        config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -931,10 +934,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// converting the arguments of blocks within that region.
   DenseMap<Region *, const TypeConverter *> regionToConverter;
 
-  /// This allows the user to collect the match failure message.
-  function_ref<void(Diagnostic &)> notifyCallback;
-
-  DenseSet<Operation *> *trackedOps = nullptr;
+  /// Dialect conversion configuration.
+  const ConversionConfig &config;
 
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
@@ -957,6 +958,10 @@ void IRRewrite::eraseBlock(Block *block) {
   rewriterImpl.eraseRewriter.eraseBlock(block);
 }
 
+const ConversionConfig &IRRewrite::getConfig() const {
+  return rewriterImpl.config;
+}
+
 void BlockTypeConversionRewrite::commit() {
   // Process the remapping for each of the original arguments.
   for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
@@ -1074,8 +1079,8 @@ void ReplaceOperationRewrite::commit() {
     if (Value newValue =
             rewriterImpl.mapping.lookupOrNull(result, result.getType()))
       result.replaceAllUsesWith(newValue);
-  if (rewriterImpl.trackedOps)
-    rewriterImpl.trackedOps->erase(op);
+  if (getConfig().unlegalizedOps)
+    getConfig().unlegalizedOps->erase(op);
   // Do not erase the operation yet. It may still be referenced in `mapping`.
   op->getBlock()->getOperations().remove(op);
 }
@@ -1510,8 +1515,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
-    if (notifyCallback)
-      notifyCallback(diag);
+    if (config.notifyCallback)
+      config.notifyCallback(diag);
   });
 }
 
@@ -1519,9 +1524,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
 
-ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+ConversionPatternRewriter::ConversionPatternRewriter(
+    MLIRContext *ctx, const ConversionConfig &config)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(*this)) {
+      impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
   setListener(impl.get());
 }
 
@@ -1972,12 +1978,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
-      if (rewriterImpl.notifyCallback) {
+      if (rewriterImpl.config.notifyCallback) {
         Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
         diag << "Failed to apply pattern \"" << pattern.getDebugName()
              << "\" on op:\n"
              << *op;
-        rewriterImpl.notifyCallback(diag);
+        rewriterImpl.config.notifyCallback(diag);
       }
     });
     rewriterImpl.resetState(curState);
@@ -2365,14 +2371,12 @@ namespace mlir {
 struct OperationConverter {
   explicit OperationConverter(const ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              OpConversionMode mode,
-                              DenseSet<Operation *> *trackedOps = nullptr)
-      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
+                              const ConversionConfig &config,
+                              OpConversionMode mode)
+      : opLegalizer(target, patterns), config(config), mode(mode) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult
-  convertOperations(ArrayRef<Operation *> ops,
-                    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+  LogicalResult convertOperations(ArrayRef<Operation *> ops);
 
 private:
   /// Converts an operation with the given rewriter.
@@ -2409,14 +2413,11 @@ struct OperationConverter {
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
 
+  /// Dialect conversion configuration.
+  ConversionConfig config;
+
   /// The conversion mode to use when legalizing operations.
   OpConversionMode mode;
-
-  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
-  /// this is populated with ops found to be legalizable to the target.
-  /// When mode == OpConversionMode::Partial, this is populated with ops found
-  /// *not* to be legalizable to the target.
-  DenseSet<Operation *> *trackedOps;
 };
 } // namespace mlir
 
@@ -2430,28 +2431,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
       return op->emitError()
              << "failed to legalize operation '" << op->getName() << "'";
     // Partial conversions allow conversions to fail iff the operation was not
-    // explicitly marked as illegal. If the user provided a nonlegalizableOps
-    // set, non-legalizable ops are included.
+    // explicitly marked as illegal. If the user provided an `unlegalizedOps`
+    // set, non-legalizable ops are added to that set.
     if (mode == OpConversionMode::Partial) {
       if (opLegalizer.isIllegal(op))
         return op->emitError()
                << "failed to legalize operation '" << op->getName()
                << "' that was explicitly marked illegal";
-      if (trackedOps)
-        trackedOps->insert(op);
+      if (config.unlegalizedOps)
+        config.unlegalizedOps->insert(op);
     }
   } else if (mode == OpConversionMode::Analysis) {
     // Analysis conversions don't fail if any operations fail to legalize,
     // they are only interested in the operations that were successfully
     // legalized.
-    trackedOps->insert(op);
+    if (config.legalizableOps)
+      config.legalizableOps->insert(op);
   }
   return success();
 }
 
-LogicalResult OperationConverter::convertOperations(
-    ArrayRef<Operation *> ops,
-    function_ref<void(Diagnostic &)> notifyCallback) {
+LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
   const ConversionTarget &target = opLegalizer.getTarget();
@@ -2472,10 +2472,8 @@ LogicalResult OperationConverter::convertOperations(
   }
 
   // Convert each operation and discard rewrites on failure.
-  ConversionPatternRewriter rewriter(ops.front()->getContext());
+  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  rewriterImpl.notifyCallback = notifyCallback;
-  rewriterImpl.trackedOps = trackedOps;
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
@@ -3461,56 +3459,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
 //===----------------------------------------------------------------------===//
 // Partial Conversion
 
-LogicalResult
-mlir::applyPartialConversion(ArrayRef<Operation *> ops,
-                             const ConversionTarget &target,
-                             const FrozenRewritePatternSet &patterns,
-                             DenseSet<Operation *> *unconvertedOps) {
-  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
-                                 unconvertedOps);
+LogicalResult mlir::applyPartialConversion(
+    ArrayRef<Operation *> ops, const ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+  OperationConverter opConverter(target, patterns, config,
+                                 OpConversionMode::Partial);
   return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
                              const FrozenRewritePatternSet &patterns,
-                             DenseSet<Operation *> *unconvertedOps) {
-  return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
-                                unconvertedOps);
+                             ConversionConfig config) {
+  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
 }
 
 //===----------------------------------------------------------------------===//
 // Full Conversion
 
-LogicalResult
-mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
-                          const FrozenRewritePatternSet &patterns) {
-  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
+                                        const ConversionTarget &target,
+                                        const FrozenRewritePatternSet &patterns,
+                                        ConversionConfig config) {
+  OperationConverter opConverter(target, patterns, config,
+                                 OpConversionMode::Full);
   return opConverter.convertOperations(ops);
 }
-LogicalResult
-mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
-                          const FrozenRewritePatternSet &patterns) {
-  return applyFullConversion(llvm::ArrayRef(op), target, patterns);
+LogicalResult mlir::applyFullConversion(Operation *op,
+                                        const ConversionTarget &target,
+                                        const FrozenRewritePatternSet &patterns,
+                                        ConversionConfig config) {
+  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
 }
 
 //===----------------------------------------------------------------------===//
 // Analysis Conversion
 
-LogicalResult
-mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
-                              ConversionTarget &target,
-                              const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps,
-                              function_ref<void(Diagnostic &)> notifyCallback) {
-  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
-                                 &convertedOps);
-  return opConverter.convertOperations(ops, notifyCallback);
+LogicalResult mlir::applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+  OperationConverter opConverter(target, patterns, config,
+                                 OpConversionMode::Analysis);
+  return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps,
-                              function_ref<void(Diagnostic &)> notifyCallback) {
-  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
-                                 convertedOps, notifyCallback);
+                              ConversionConfig config) {
+  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
 }
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 1c02232b8adbb1..c04d5cc80446f4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1135,8 +1135,10 @@ struct TestLegalizePatternDriver
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
-      if (failed(applyPartialConversion(
-              getOperation(), target, std::move(patterns), &unlegalizedOps))) {
+      ConversionConfig config;
+      config.unlegalizedOps = &unlegalizedOps;
+      if (failed(applyPartialConversion(getOperation(), target,
+                                        std::move(patterns), config))) {
         getOperation()->emitRemark() << "applyPartialConversion failed";
       }
       // Emit remarks for each legalizable operation.
@@ -1164,8 +1166,10 @@ struct TestLegalizePatternDriver
 
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
+    ConversionConfig config;
+    config.legalizableOps = &legalizedOps;
     if (failed(applyAnalysisConversion(getOperation(), target,
-                                       std::move(patterns), legalizedOps)))
+                                       std::move(patterns), config)))
       return signalPassFailure();
 
     // Emit remarks for each legalizable operation.
@@ -1789,8 +1793,10 @@ struct TestMergeBlocksPatternDriver
         });
 
     DenseSet<Operation *> unlegalizedOps;
+    ConversionConfig config;
+    config.unlegalizedOps = &unlegalizedOps;
     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
-                                 &unlegalizedOps);
+                                 config);
     for (auto *op : unlegalizedOps)
       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
   }



More information about the llvm-branch-commits mailing list