[Mlir-commits] [mlir] [mlir] Do not abort when encountering a 1:1 conversion pattern (PR #153605)

Markus Böck llvmlistbot at llvm.org
Thu Aug 14 08:56:31 PDT 2025


https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/153605

Prior to this PR, the default behaviour of a conversion pattern that receives operands of a 1:N type conversion is to abort compilation. This was historically useful when 1:N type conversion was merged into the dialect conversion framework, as it made it easy to find patterns that should be capable of handling 1:N type conversions but did not.

However, this behaviour is not composable: while the pattern in question cannot handle the 1:N type conversion, another pattern in the same set might, but it never gets the chance because compilation is aborted.

This PR fixes this behaviour by failing to match instead of aborting, which gives other patterns the chance to legalize the op. The implementation uses a reusable function called `dispatchTo1To1` so that derived conversion patterns can implement the same behaviour.

>From 4db31082e154e41ad4af05df5a6adddcb10a8f93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <mboeck at nvidia.com>
Date: Thu, 14 Aug 2025 17:55:33 +0200
Subject: [PATCH] [mlir] Do not abort when encountering a 1:1 conversion
 pattern

Prior to this PR, the default behaviour of a conversion pattern that receives operands of a 1:N type conversion is to abort compilation. This was historically useful when 1:N type conversion was merged into the dialect conversion framework, as it made it easy to find patterns that should be capable of handling 1:N type conversions but did not.

However, this behaviour is not composable: while the pattern in question cannot handle the 1:N type conversion, another pattern in the same set might, but it never gets the chance because compilation is aborted.

This PR fixes this behaviour by failing to match instead of aborting, which gives other patterns the chance to legalize the op. The implementation uses a reusable function called `dispatchTo1To1` so that derived conversion patterns can implement the same behaviour.
---
 .../mlir/Conversion/LLVMCommon/Pattern.h      |  6 +-
 .../mlir/Transforms/DialectConversion.h       | 62 ++++++++++++++++---
 .../Transforms/Utils/DialectConversion.cpp    |  8 +--
 mlir/test/Transforms/test-legalizer.mlir      | 21 +++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 21 ++++++-
 5 files changed, 98 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 969154abe8830..19148a5d783f3 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -233,9 +233,7 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+    return dispatchTo1To1(*this, op, adaptor, rewriter);
   }
 
 private:
@@ -276,7 +274,7 @@ class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern {
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const {
-    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+    return dispatchTo1To1(*this, op, operands, rewriter);
   }
 
 private:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e601c821e1e4e..80a09b643ce38 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -521,8 +521,8 @@ class ConversionPattern : public RewritePattern {
 
   /// Hook for derived classes to implement combined matching and rewriting.
   /// This overload supports only 1:1 replacements. The 1:N overload is called
-  /// by the driver. By default, it calls this 1:1 overload or reports a fatal
-  /// error if 1:N replacements were found.
+  /// by the driver. By default, it calls this 1:1 overload or fails to match
+  /// if 1:N replacements were found.
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
@@ -534,7 +534,7 @@ class ConversionPattern : public RewritePattern {
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const {
-    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+    return dispatchTo1To1(*this, op, operands, rewriter);
   }
 
   /// Attempt to match and rewrite the IR root at the specified operation.
@@ -567,11 +567,26 @@ class ConversionPattern : public RewritePattern {
   /// try to extract the single value of each range to construct a the inputs
   /// for a 1:1 adaptor.
   ///
-  /// This function produces a fatal error if at least one range has 0 or
-  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
-  SmallVector<Value>
+  /// Returns failure if at least one range has 0 or more than 1 value.
+  FailureOr<SmallVector<Value>>
   getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
 
+  /// Dispatches to the 1:1 'matchAndRewrite' overload of 'self' if every
+  /// operand range holds exactly one value; otherwise reports a match failure
+  /// via 'notifyMatchFailure' and returns failure. 'self' should be '*this' of
+  /// the derived pattern so that the correct 'matchAndRewrite' is selected.
+  template <typename Pattern, typename SourceOp>
+  static LogicalResult dispatchTo1To1(const Pattern &self, SourceOp op,
+                                      ArrayRef<ValueRange> operands,
+                                      ConversionPatternRewriter &rewriter);
+
+  /// Same as above, but takes the operands as a 1:N op adaptor.
+  template <typename Pattern, typename SourceOp>
+  static LogicalResult dispatchTo1To1(
+      const Pattern &self, SourceOp op,
+      typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
+      ConversionPatternRewriter &rewriter);
+
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
@@ -620,9 +635,7 @@ class OpConversionPattern : public ConversionPattern {
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+    return dispatchTo1To1(*this, op, adaptor, rewriter);
   }
 
 private:
@@ -666,7 +679,7 @@ class OpInterfaceConversionPattern : public ConversionPattern {
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const {
-    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+    return dispatchTo1To1(*this, op, operands, rewriter);
   }
 
 private:
@@ -865,6 +878,35 @@ class ConversionPatternRewriter final : public PatternRewriter {
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
+template <typename Pattern, typename SourceOp>
+LogicalResult
+ConversionPattern::dispatchTo1To1(const Pattern &self, SourceOp op,
+                                  ArrayRef<ValueRange> operands,
+                                  ConversionPatternRewriter &rewriter) {
+  FailureOr<SmallVector<Value>> oneToOneOperands =
+      self.getOneToOneAdaptorOperands(operands);
+  if (failed(oneToOneOperands))
+    return rewriter.notifyMatchFailure(op,
+                                       "pattern '" + self.getDebugName() +
+                                           "' does not support 1:N conversion");
+  return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
+}
+
+template <typename Pattern, typename SourceOp>
+LogicalResult ConversionPattern::dispatchTo1To1(
+    const Pattern &self, SourceOp op,
+    typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
+    ConversionPatternRewriter &rewriter) {
+  FailureOr<SmallVector<Value>> oneToOneOperands =
+      self.getOneToOneAdaptorOperands(adaptor.getOperands());
+  if (failed(oneToOneOperands))
+    return rewriter.notifyMatchFailure(op,
+                                       "pattern '" + self.getDebugName() +
+                                           "' does not support 1:N conversion");
+  return self.matchAndRewrite(
+      op, typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 001c13e1ab08c..ff34a58965763 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2244,17 +2244,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
 // ConversionPattern
 //===----------------------------------------------------------------------===//
 
-SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
     ArrayRef<ValueRange> operands) const {
   SmallVector<Value> oneToOneOperands;
   oneToOneOperands.reserve(operands.size());
   for (ValueRange operand : operands) {
     if (operand.size() != 1)
-      llvm::report_fatal_error("pattern '" + getDebugName() +
-                               "' does not support 1:N conversion");
+      return failure();
+    // Exactly one replacement value: forward it as the 1:1 operand.
     oneToOneOperands.push_back(operand.front());
   }
-  return oneToOneOperands;
+  return std::move(oneToOneOperands);
 }
 
 LogicalResult
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 9a04da7904863..55d153db7f4bb 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -439,3 +439,24 @@ func.func @test_lookup_without_converter() {
   // expected-remark@+1 {{op 'func.return' is not legalizable}}
   return
 }
+
+// -----
+// expected-remark@-1 {{applyPartialConversion failed}}
+
+func.func @test_skip_1to1_pattern(%arg0: f32) {
+  // expected-error@+1 {{failed to legalize operation 'test.type_consumer'}}
+  "test.type_consumer"(%arg0) : (f32) -> ()
+  return
+}
+
+// -----
+
+// Demonstrate that the pattern generally works, but only for 1:1 type
+// conversions.
+
+// CHECK-LABEL: @test_working_1to1_pattern(
+func.func @test_working_1to1_pattern(%arg0: f16) {
+  // CHECK-NEXT: "test.return"() : () -> ()
+  "test.type_consumer"(%arg0) : (f16) -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 657dfd2bac6ec..6300c5b0ca21c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1386,6 +1386,23 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
   }
 };
 
+/// Pattern that erases 'test.type_consumer' ops iff their operands were
+/// converted 1:1. For 1:N conversions, the default 1:N 'matchAndRewrite'
+/// fails to match, which tests that 1:1-only patterns are skipped.
+class TestTypeConsumerOpPattern
+    : public OpConversionPattern<TestTypeConsumerOp> {
+public:
+  TestTypeConsumerOpPattern(MLIRContext *ctx, const TypeConverter &converter)
+      : OpConversionPattern<TestTypeConsumerOp>(converter, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(TestTypeConsumerOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Test unambiguous overload resolution of replaceOpWithMultiple. This
 /// function is just to trigger compiler errors. It is never executed.
 [[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1497,8 +1514,8 @@ struct TestLegalizePatternDriver
         TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
-                 TestBlockArgReplace, TestReplaceWithValidConsumer>(
-        &getContext(), converter);
+                 TestBlockArgReplace, TestReplaceWithValidConsumer,
+                 TestTypeConsumerOpPattern>(&getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);



More information about the Mlir-commits mailing list