[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Add option to disable folding (PR #92683)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun May 19 03:01:26 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds a new flag to `ConversionConfig` that deactivates op folding during a dialect conversion.

Op folding is problematic because op folders may assume that the entire IR is in a valid state. (See #<!-- -->89770 for an example.) However, the dialect conversion driver does not guarantee that the IR is valid during a dialect conversion; it only guarantees that the IR is valid at the end of a dialect conversion. E.g., IR may be invalid after a conversion pattern application because some IR modifications (e.g., op/block replacements) are applied in a delayed fashion at the end of a dialect conversion. This makes op folders generally unsafe to use with a dialect conversion.

Note: For the same reason, it is also not safe to use non-conversion patterns with a dialect conversion. Regular rewrite patterns may assume that the entire IR is in a valid state, but conversion patterns cannot -- and developers of conversion patterns must take that into account.


---
Full diff: https://github.com/llvm/llvm-project/pull/92683.diff


6 Files Affected:

- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+10-1) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+1-1) 
- (modified) mlir/test/Transforms/test-legalizer-analysis.mlir (+1-1) 
- (modified) mlir/test/Transforms/test-legalizer-full.mlir (+1-1) 
- (added) mlir/test/Transforms/test-legalizer-no-fold.mlir (+11) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+23-19) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db54..ea41e7c5ba803 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType> TargetType convertType(Type t) const {
+  template <typename TargetType>
+  TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
@@ -1118,6 +1119,14 @@ struct ConversionConfig {
   // already been modified) and iterators into past IR state cannot be
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
+
+  /// If set to "true", the dialect conversion driver attempts to fold
+  /// operations throughout the conversion. This is problematic because op
+  /// folders may assume that the IR is in a valid state at the beginning of
+  /// the folding process. However, the dialect conversion does not guarantee
+  /// that because some IR modifications are delayed until the end of the
+  /// conversion.
+  bool foldOps = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..3c684e9a208ac 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2030,7 +2030,7 @@ OperationLegalizer::legalize(Operation *op,
   // If the operation isn't legal, try to fold it in-place.
   // TODO: Should we always try to do this, even if the op is
   // already legal?
-  if (succeeded(legalizeWithFold(op, rewriter))) {
+  if (config.foldOps && succeeded(legalizeWithFold(op, rewriter))) {
     LLVM_DEBUG({
       logSuccess(logger, "operation was folded");
       logger.startLine() << logLineComment;
diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir
index 19a13100159a2..829415b9af414 100644
--- a/mlir/test/Transforms/test-legalizer-analysis.mlir
+++ b/mlir/test/Transforms/test-legalizer-analysis.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s
 // expected-remark at -2 {{op 'builtin.module' is legalizable}}
 
 // expected-remark at +1 {{op 'func.func' is legalizable}}
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index 5f1148cac6501..ea163a5767755 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s
 
 // CHECK-LABEL: func @multi_level_mapping
 func.func @multi_level_mapping() {
diff --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir
new file mode 100644
index 0000000000000..61afd72c934bc
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="fold-ops=0" %s | FileCheck %s
+
+// CHECK-LABEL: @remove_foldable_op(
+func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
+  // Check that op was not folded.
+  // CHECK: "test.op_with_region_fold"
+  %0 = "test.op_with_region_fold"(%arg0) ({
+    "foo.op_with_region_terminator"() : () -> ()
+  }) : (i32) -> (i32)
+  "test.return"(%0) : (i32) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f9f7d4eacf948..97ef7a51aa203 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1102,7 +1102,9 @@ struct TestLegalizePatternDriver
   /// The mode of conversion to use with the driver.
   enum class ConversionMode { Analysis, Full, Partial };
 
-  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
+  TestLegalizePatternDriver() = default;
+  TestLegalizePatternDriver(const TestLegalizePatternDriver &other)
+      : PassWrapper(other) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<func::FuncDialect, test::TestDialect>();
@@ -1179,6 +1181,7 @@ struct TestLegalizePatternDriver
       DumpNotifications dumpNotifications;
       config.listener = &dumpNotifications;
       config.unlegalizedOps = &unlegalizedOps;
+      config.foldOps = foldOps;
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns), config))) {
         getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1197,6 +1200,7 @@ struct TestLegalizePatternDriver
       });
 
       ConversionConfig config;
+      config.foldOps = foldOps;
       DumpNotifications dumpNotifications;
       config.listener = &dumpNotifications;
       if (failed(applyFullConversion(getOperation(), target,
@@ -1212,6 +1216,7 @@ struct TestLegalizePatternDriver
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
     ConversionConfig config;
+    config.foldOps = foldOps;
     config.legalizableOps = &legalizedOps;
     if (failed(applyAnalysisConversion(getOperation(), target,
                                        std::move(patterns), config)))
@@ -1222,24 +1227,25 @@ struct TestLegalizePatternDriver
       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
   }
 
-  /// The mode of conversion to use.
-  ConversionMode mode;
+  Option<bool> foldOps{
+      *this, "fold-ops",
+      llvm::cl::desc("Fold ops throughout the conversion process"),
+      llvm::cl::init(true)};
+
+  Option<TestLegalizePatternDriver::ConversionMode> mode{
+      *this, "legalize-mode",
+      llvm::cl::desc("The legalization mode to use with the test driver"),
+      llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
+      llvm::cl::values(
+          clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
+                     "analysis", "Perform an analysis conversion"),
+          clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
+                     "Perform a full conversion"),
+          clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
+                     "partial", "Perform a partial conversion"))};
 };
 } // namespace
 
-static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
-    legalizerConversionMode(
-        "test-legalize-mode",
-        llvm::cl::desc("The legalization mode to use with the test driver"),
-        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
-        llvm::cl::values(
-            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
-                       "analysis", "Perform an analysis conversion"),
-            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
-                       "Perform a full conversion"),
-            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
-                       "partial", "Perform a partial conversion")));
-
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriter::getRemappedValue testing. This method is used
 // to get the remapped value of an original value that was replaced using
@@ -1909,9 +1915,7 @@ void registerPatternsTestPass() {
   PassRegistration<TestPatternDriver>();
   PassRegistration<TestStrictPatternDriver>();
 
-  PassRegistration<TestLegalizePatternDriver>([] {
-    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
-  });
+  PassRegistration<TestLegalizePatternDriver>();
 
   PassRegistration<TestRemappedValue>();
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/92683


More information about the Mlir-commits mailing list