[Mlir-commits] [mlir] 6ec0985 - [mlir] Enable disabling folding in dialect conversion (#152890)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Aug 10 09:00:34 PDT 2025


Author: Jacques Pienaar
Date: 2025-08-10T18:00:31+02:00
New Revision: 6ec0985ad07d68016587353fce715dce73cf1faa

URL: https://github.com/llvm/llvm-project/commit/6ec0985ad07d68016587353fce715dce73cf1faa
DIFF: https://github.com/llvm/llvm-project/commit/6ec0985ad07d68016587353fce715dce73cf1faa.diff

LOG: [mlir] Enable disabling folding in dialect conversion (#152890)

Previously, folding was attempted only after checking whether the op is
legal, but it was then done unconditionally, before any legalization
patterns were applied. Add an option to never attempt folding, and one to
attempt it only as a last resort after the patterns have failed.

An "always attempt to fold" option (which would have folded whether or not
the op is legal) was considered but not added; the TODO about it is removed.

Added: 
    mlir/test/Transforms/test-legalizer-fold-after.mlir
    mlir/test/Transforms/test-legalizer-fold-before.mlir
    mlir/test/Transforms/test-legalizer-no-fold.mlir

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 4e651a0489899..ea828480be4cb 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1161,6 +1161,16 @@ class PDLConversionConfig final {
 // ConversionConfig
 //===----------------------------------------------------------------------===//
 
+/// An enum to control folding behavior during dialect conversion.
+enum class DialectConversionFoldingMode {
+  /// Never attempt to fold.
+  Never,
+  /// Only attempt to fold not legal operations before applying patterns.
+  BeforePatterns,
+  /// Only attempt to fold not legal operations after applying patterns.
+  AfterPatterns,
+};
+
 /// Dialect conversion configuration.
 struct ConversionConfig {
   /// An optional callback used to notify about match failure diagnostics during
@@ -1243,6 +1253,10 @@ struct ConversionConfig {
   /// your patterns do not trigger any IR rollbacks. For details, see
   /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
   bool allowPatternRollback = true;
+
+  /// The folding mode to use during conversion.
+  DialectConversionFoldingMode foldingMode =
+      DialectConversionFoldingMode::BeforePatterns;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0c26b4ed46b31..2470f2b122def 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2257,15 +2257,17 @@ OperationLegalizer::legalize(Operation *op,
     return success();
   }
 
-  // If the operation isn't legal, try to fold it in-place.
-  // TODO: Should we always try to do this, even if the op is
-  // already legal?
-  if (succeeded(legalizeWithFold(op, rewriter))) {
-    LLVM_DEBUG({
-      logSuccess(logger, "operation was folded");
-      logger.startLine() << logLineComment;
-    });
-    return success();
+  // If the operation is not legal, try to fold it in-place if the folding mode
+  // is 'BeforePatterns'. 'Never' will skip this.
+  const ConversionConfig &config = rewriter.getConfig();
+  if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
+    if (succeeded(legalizeWithFold(op, rewriter))) {
+      LLVM_DEBUG({
+        logSuccess(logger, "operation was folded");
+        logger.startLine() << logLineComment;
+      });
+      return success();
+    }
   }
 
   // Otherwise, we need to apply a legalization pattern to this operation.
@@ -2277,6 +2279,18 @@ OperationLegalizer::legalize(Operation *op,
     return success();
   }
 
+  // If the operation can't be legalized via patterns, try to fold it in-place
+  // if the folding mode is 'AfterPatterns'.
+  if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
+    if (succeeded(legalizeWithFold(op, rewriter))) {
+      LLVM_DEBUG({
+        logSuccess(logger, "operation was folded");
+        logger.startLine() << logLineComment;
+      });
+      return success();
+    }
+  }
+
   LLVM_DEBUG({
     logFailure(logger, "no matched legalization pattern");
     logger.startLine() << logLineComment;

diff  --git a/mlir/test/Transforms/test-legalizer-fold-after.mlir b/mlir/test/Transforms/test-legalizer-fold-after.mlir
new file mode 100644
index 0000000000000..7f80252dc9604
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-fold-after.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=after-patterns" | FileCheck %s
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+  // CHECK-NOT: op_in_place_self_fold
+  // CHECK: 97
+  %1 = "test.op_in_place_self_fold"() : () -> (i32)
+  "test.return"(%1) : (i32) -> ()
+}

diff  --git a/mlir/test/Transforms/test-legalizer-fold-before.mlir b/mlir/test/Transforms/test-legalizer-fold-before.mlir
new file mode 100644
index 0000000000000..fe6e29351a5d7
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-fold-before.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=before-patterns" | FileCheck %s
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+  // CHECK: op_in_place_self_fold
+  // CHECK-SAME: folded
+  %1 = "test.op_in_place_self_fold"() : () -> (i32)
+  "test.return"(%1) : (i32) -> ()
+}

diff  --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir
new file mode 100644
index 0000000000000..720d17f41943d
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-legalize-patterns="test-legalize-folding-mode=never" | FileCheck %s
+
+// CHECK-LABEL: @remove_foldable_op(
+func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
+  // Check that op was not folded.
+  // CHECK: "test.op_with_region_fold"
+  %0 = "test.op_with_region_fold"(%arg0) ({
+    "foo.op_with_region_terminator"() : () -> ()
+  }) : (i32) -> (i32)
+  "test.return"(%0) : (i32) -> ()
+}
+

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3cdd2f226687e..231400ec9cd29 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1498,6 +1498,8 @@ def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
   let results = (outs I32);
   let hasFolder = 1;
 }
+def : Pat<(TestOpInPlaceSelfFold:$op $_),
+          (TestOpConstant ConstantAttr<I32Attr, "97">)>;
 
 // Test op that simply returns success.
 def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f92f0982f85b2..ff958d9a3d2b4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1507,8 +1507,8 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
-                      TerminatorOp, OneRegionOp, TestValidProducerOp,
-                      TestValidConsumerOp>();
+                      TerminatorOp, TestOpConstant, OneRegionOp,
+                      TestValidProducerOp, TestValidConsumerOp>();
     target.addLegalOp(OperationName("test.legal_op", &getContext()));
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
@@ -1563,6 +1563,7 @@ struct TestLegalizePatternDriver
       DumpNotifications dumpNotifications;
       config.listener = &dumpNotifications;
       config.unlegalizedOps = &unlegalizedOps;
+      config.foldingMode = foldingMode;
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns), config))) {
         getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1582,6 +1583,7 @@ struct TestLegalizePatternDriver
 
       ConversionConfig config;
       DumpNotifications dumpNotifications;
+      config.foldingMode = foldingMode;
       config.listener = &dumpNotifications;
       if (failed(applyFullConversion(getOperation(), target,
                                      std::move(patterns), config))) {
@@ -1596,6 +1598,7 @@ struct TestLegalizePatternDriver
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
     ConversionConfig config;
+    config.foldingMode = foldingMode;
     config.legalizableOps = &legalizedOps;
     if (failed(applyAnalysisConversion(getOperation(), target,
                                        std::move(patterns), config)))
@@ -1616,6 +1619,21 @@ struct TestLegalizePatternDriver
           clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"),
           clEnumValN(ConversionMode::Partial, "partial",
                      "Perform a partial conversion"))};
+
+  Option<DialectConversionFoldingMode> foldingMode{
+      *this, "test-legalize-folding-mode",
+      llvm::cl::desc("The folding mode to use with the test driver"),
+      llvm::cl::init(DialectConversionFoldingMode::BeforePatterns),
+      llvm::cl::values(clEnumValN(DialectConversionFoldingMode::Never, "never",
+                                  "Never attempt to fold"),
+                       clEnumValN(DialectConversionFoldingMode::BeforePatterns,
+                                  "before-patterns",
+                                  "Only attempt to fold not legal operations "
+                                  "before applying patterns"),
+                       clEnumValN(DialectConversionFoldingMode::AfterPatterns,
+                                  "after-patterns",
+                                  "Only attempt to fold not legal operations "
+                                  "after applying patterns"))};
 };
 } // namespace
 


        


More information about the Mlir-commits mailing list