[flang-commits] [flang] 8582025 - [mlir][Transforms] Turn 1:N -> 1:1 dispatch fatal error into match failure (#153605)
via flang-commits
flang-commits at lists.llvm.org
Fri Aug 15 02:45:30 PDT 2025
Author: Markus Böck
Date: 2025-08-15T11:45:25+02:00
New Revision: 8582025f1fb9485ced594efe0661ed4a4a80d5c9
URL: https://github.com/llvm/llvm-project/commit/8582025f1fb9485ced594efe0661ed4a4a80d5c9
DIFF: https://github.com/llvm/llvm-project/commit/8582025f1fb9485ced594efe0661ed4a4a80d5c9.diff
LOG: [mlir][Transforms] Turn 1:N -> 1:1 dispatch fatal error into match failure (#153605)
Prior to this PR, the default behaviour of a conversion pattern that
receives operands produced by a 1:N type conversion was to abort the
compilation. This was historically useful when the 1:N type conversion
was merged into the dialect conversion, as it allowed us to easily find
patterns that should be capable of handling 1:N type conversions but
were not.
However, this behaviour has the disadvantage of being non-composable:
while the pattern in question cannot handle the 1:N type conversion,
another pattern in the set might, but it never gets the chance because
compilation is aborted.
This PR fixes this behaviour by failing to match instead of aborting,
giving other patterns the chance to legalize an op. The
implementation uses a reusable function called `dispatchTo1To1` to allow
derived conversion patterns to also implement the behaviour.
Added:
Modified:
flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
index b7fa8fc3848f2..7d816a8843371 100644
--- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
+++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
@@ -237,9 +237,7 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
- llvm::SmallVector<mlir::Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ return dispatchTo1To1(*this, op, adaptor, rewriter);
}
private:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 79b102b43a15f..c292e3727f46c 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -243,9 +243,7 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ return dispatchTo1To1(*this, op, adaptor, rewriter);
}
private:
@@ -286,7 +284,7 @@ class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern {
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
- return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ return dispatchTo1To1(*this, op, operands, rewriter);
}
private:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e601c821e1e4e..220431e6ee2f1 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -521,8 +521,8 @@ class ConversionPattern : public RewritePattern {
/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
- /// by the driver. By default, it calls this 1:1 overload or reports a fatal
- /// error if 1:N replacements were found.
+ /// by the driver. By default, it calls this 1:1 overload or fails to match
+ /// if 1:N replacements were found.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -534,7 +534,7 @@ class ConversionPattern : public RewritePattern {
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
- return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ return dispatchTo1To1(*this, op, operands, rewriter);
}
/// Attempt to match and rewrite the IR root at the specified operation.
@@ -567,11 +567,26 @@ class ConversionPattern : public RewritePattern {
/// try to extract the single value of each range to construct a the inputs
/// for a 1:1 adaptor.
///
- /// This function produces a fatal error if at least one range has 0 or
- /// more than 1 value: "pattern 'name' does not support 1:N conversion"
- SmallVector<Value>
+ /// Returns failure if at least one range has 0 or more than 1 value.
+ FailureOr<SmallVector<Value>>
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
+ /// Overloaded method used to dispatch to the 1:1 'matchAndRewrite' method
+ /// if possible and emit diagnostic with a failure return value otherwise.
+ /// 'self' should be '*this' of the derived-pattern and is used to dispatch
+ /// to the correct 'matchAndRewrite' method in the derived pattern.
+ template <typename SelfPattern, typename SourceOp>
+ static LogicalResult dispatchTo1To1(const SelfPattern &self, SourceOp op,
+ ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter);
+
+ /// Same as above, but accepts an adaptor as operand.
+ template <typename SelfPattern, typename SourceOp>
+ static LogicalResult dispatchTo1To1(
+ const SelfPattern &self, SourceOp op,
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
+ ConversionPatternRewriter &rewriter);
+
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
@@ -620,9 +635,7 @@ class OpConversionPattern : public ConversionPattern {
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ return dispatchTo1To1(*this, op, adaptor, rewriter);
}
private:
@@ -666,7 +679,7 @@ class OpInterfaceConversionPattern : public ConversionPattern {
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
- return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ return dispatchTo1To1(*this, op, operands, rewriter);
}
private:
@@ -865,6 +878,35 @@ class ConversionPatternRewriter final : public PatternRewriter {
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
+template <typename SelfPattern, typename SourceOp>
+LogicalResult
+ConversionPattern::dispatchTo1To1(const SelfPattern &self, SourceOp op,
+ ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) {
+ FailureOr<SmallVector<Value>> oneToOneOperands =
+ self.getOneToOneAdaptorOperands(operands);
+ if (failed(oneToOneOperands))
+ return rewriter.notifyMatchFailure(op,
+ "pattern '" + self.getDebugName() +
+ "' does not support 1:N conversion");
+ return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
+}
+
+template <typename SelfPattern, typename SourceOp>
+LogicalResult ConversionPattern::dispatchTo1To1(
+ const SelfPattern &self, SourceOp op,
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
+ ConversionPatternRewriter &rewriter) {
+ FailureOr<SmallVector<Value>> oneToOneOperands =
+ self.getOneToOneAdaptorOperands(adaptor.getOperands());
+ if (failed(oneToOneOperands))
+ return rewriter.notifyMatchFailure(op,
+ "pattern '" + self.getDebugName() +
+ "' does not support 1:N conversion");
+ return self.matchAndRewrite(
+ op, typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
+}
+
//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 001c13e1ab08c..ff34a58965763 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2244,17 +2244,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
-SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
ArrayRef<ValueRange> operands) const {
SmallVector<Value> oneToOneOperands;
oneToOneOperands.reserve(operands.size());
for (ValueRange operand : operands) {
if (operand.size() != 1)
- llvm::report_fatal_error("pattern '" + getDebugName() +
- "' does not support 1:N conversion");
+ return failure();
+
oneToOneOperands.push_back(operand.front());
}
- return oneToOneOperands;
+ return std::move(oneToOneOperands);
}
LogicalResult
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 9a04da7904863..55d153db7f4bb 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -439,3 +439,24 @@ func.func @test_lookup_without_converter() {
// expected-remark at +1 {{op 'func.return' is not legalizable}}
return
}
+
+// -----
+// expected-remark at -1 {{applyPartialConversion failed}}
+
+func.func @test_skip_1to1_pattern(%arg0: f32) {
+ // expected-error at +1 {{failed to legalize operation 'test.type_consumer'}}
+ "test.type_consumer"(%arg0) : (f32) -> ()
+ return
+}
+
+// -----
+
+// Demonstrate that the pattern generally works, but only for 1:1 type
+// conversions.
+
+// CHECK-LABEL: @test_working_1to1_pattern(
+func.func @test_working_1to1_pattern(%arg0: f16) {
+ // CHECK-NEXT: "test.return"() : () -> ()
+ "test.type_consumer"(%arg0) : (f16) -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 657dfd2bac6ec..6300c5b0ca21c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1386,6 +1386,23 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
}
};
+/// Pattern that erases 'test.type_consumers' iff the input operand is the
+/// result of a 1:1 type conversion.
+/// Used to test correct skipping of 1:1 patterns in the 1:N case.
+class TestTypeConsumerOpPattern
+ : public OpConversionPattern<TestTypeConsumerOp> {
+public:
+ TestTypeConsumerOpPattern(MLIRContext *ctx, const TypeConverter &converter)
+ : OpConversionPattern<TestTypeConsumerOp>(converter, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(TestTypeConsumerOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1497,8 +1514,8 @@ struct TestLegalizePatternDriver
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
- TestBlockArgReplace, TestReplaceWithValidConsumer>(
- &getContext(), converter);
+ TestBlockArgReplace, TestReplaceWithValidConsumer,
+ TestTypeConsumerOpPattern>(&getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
More information about the flang-commits
mailing list