[Mlir-commits] [mlir] 78711b6 - [mlir][Transforms] Legalize nested operations (#172158)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 16 07:47:51 PST 2025
Author: Matthias Springer
Date: 2025-12-16T16:47:47+01:00
New Revision: 78711b66bde63188a983c0b8227dc55814e86bb6
URL: https://github.com/llvm/llvm-project/commit/78711b66bde63188a983c0b8227dc55814e86bb6
DIFF: https://github.com/llvm/llvm-project/commit/78711b66bde63188a983c0b8227dc55814e86bb6.diff
LOG: [mlir][Transforms] Legalize nested operations (#172158)
This commit aligns the implementation of
`ConversionPatternRewriter::legalize` with its documentation:
```
/// Attempt to legalize the given region. This can be used within
...
LogicalResult legalize(Region *r);
```
This function now legalizes the entire region, including nested ops. The
implementation follows the same logic as the "main" traversal:
pre-order, forward-dominance.
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 09ad42364baaf..d4b1c8c7f0a74 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2257,37 +2257,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
return success();
}
-LogicalResult ConversionPatternRewriter::legalize(Region *r) {
- // Fast path: If the region is empty, there is nothing to legalize.
- if (r->empty())
- return success();
-
- // Gather a list of all operations to legalize. This is done before
- // converting the entry block signature because unrealized_conversion_cast
- // ops should not be included.
- SmallVector<Operation *> ops;
- for (Block &b : *r)
- for (Operation &op : b)
- ops.push_back(&op);
-
- // If the current pattern runs with a type converter, convert the entry block
- // signature.
- if (const TypeConverter *converter = impl->currentTypeConverter) {
- std::optional<TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(&r->front());
- if (!conversion)
- return failure();
- applySignatureConversion(&r->front(), *conversion, converter);
- }
-
- // Legalize all operations in the region.
- for (Operation *op : ops)
- if (failed(legalize(op)))
- return failure();
-
- return success();
-}
-
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
@@ -3287,8 +3256,20 @@ struct OperationConverter {
: rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
mode(mode) {}
- /// Converts the given operations to the conversion target.
- LogicalResult convertOperations(ArrayRef<Operation *> ops);
+ /// Applies the conversion to the given operations (and their nested
+ /// operations).
+ LogicalResult applyConversion(ArrayRef<Operation *> ops);
+
+ /// Legalizes the given operations (and their nested operations) to the
+ /// conversion target.
+ template <typename Fn>
+ LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
+ bool isRecursiveLegalization = false);
+ LogicalResult legalizeOperations(ArrayRef<Operation *> ops,
+ bool isRecursiveLegalization = false) {
+ return legalizeOperations(
+ ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
+ }
/// Converts a single operation. If `isRecursiveLegalization` is "true", the
/// conversion is a recursive legalization request, triggered from within a
@@ -3297,6 +3278,8 @@ struct OperationConverter {
/// legalization mechanism).
LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
+ const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
+
private:
/// The rewriter to use when converting operations.
ConversionPatternRewriter rewriter;
@@ -3309,10 +3292,6 @@ struct OperationConverter {
};
} // namespace mlir
-LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
- return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
-}
-
LogicalResult OperationConverter::convert(Operation *op,
bool isRecursiveLegalization) {
const ConversionConfig &config = rewriter.getConfig();
@@ -3398,12 +3377,15 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
return failure();
}
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+template <typename Fn>
+LogicalResult
+OperationConverter::legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
+ bool isRecursiveLegalization) {
const ConversionTarget &target = opLegalizer.getTarget();
// Compute the set of operations and blocks to convert.
SmallVector<Operation *> toConvert;
- for (auto *op : ops) {
+ for (Operation *op : ops) {
op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[&](Operation *op) {
toConvert.push_back(op);
@@ -3415,25 +3397,67 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return WalkResult::advance();
});
}
+ for (Operation *op : toConvert) {
+ if (failed(convert(op, isRecursiveLegalization))) {
+ // Failed to convert an operation.
+ onFailure();
+ return failure();
+ }
+ }
+ return success();
+}
- // Convert each operation and discard rewrites on failure.
- ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
+ return impl->opConverter.legalizeOperations(op,
+ /*isRecursiveLegalization=*/true);
+}
- for (auto *op : toConvert) {
- if (failed(convert(op))) {
- // Dialect conversion failed.
- if (rewriterImpl.config.allowPatternRollback) {
- // Rollback is allowed: restore the original IR.
- rewriterImpl.undoRewrites();
- } else {
- // Rollback is not allowed: apply all modifications that have been
- // performed so far.
- rewriterImpl.applyRewrites();
- }
+LogicalResult ConversionPatternRewriter::legalize(Region *r) {
+ // Fast path: If the region is empty, there is nothing to legalize.
+ if (r->empty())
+ return success();
+
+ // Gather a list of all operations to legalize. This is done before
+ // converting the entry block signature because unrealized_conversion_cast
+ // ops should not be included.
+ SmallVector<Operation *> ops;
+ for (Block &b : *r)
+ for (Operation &op : b)
+ ops.push_back(&op);
+
+ // If the current pattern runs with a type converter, convert the entry block
+ // signature.
+ if (const TypeConverter *converter = impl->currentTypeConverter) {
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(&r->front());
+ if (!conversion)
return failure();
- }
+ applySignatureConversion(&r->front(), *conversion, converter);
}
+ // Legalize all operations in the region. This includes all nested
+ // operations.
+ return impl->opConverter.legalizeOperations(ops,
+ /*isRecursiveLegalization=*/true);
+}
+
+LogicalResult OperationConverter::applyConversion(ArrayRef<Operation *> ops) {
+ // Convert each operation and discard rewrites on failure.
+ ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+ LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
+ // Dialect conversion failed.
+ if (rewriterImpl.config.allowPatternRollback) {
+ // Rollback is allowed: restore the original IR.
+ rewriterImpl.undoRewrites();
+ } else {
+ // Rollback is not allowed: apply all modifications that have been
+ // performed so far.
+ rewriterImpl.applyRewrites();
+ }
+ });
+ if (failed(status))
+ return failure();
+
// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();
@@ -4143,7 +4167,7 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
[&] {
OperationConverter opConverter(ops.front()->getContext(), target,
patterns, config, mode);
- status = opConverter.convertOperations(ops);
+ status = opConverter.applyConversion(ops);
},
irUnits);
return status;
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 88a71cc26ab0c..8d854aff1992f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -454,11 +454,16 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
// The region of "test.post_order_legalization" is converted before the op.
// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
+// CHECK: notifyOperationInserted: test.remaining_consumer
+// CHECK: notifyOperationInserted: test.legal_op
// CHECK: notifyOperationInserted: test.invalid
// CHECK: notifyBlockErased
// CHECK: notifyOperationInserted: test.valid, was unlinked
// CHECK: notifyOperationReplaced: test.invalid
// CHECK: notifyOperationErased: test.invalid
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.invalid
+// CHECK: notifyOperationErased: test.invalid
// CHECK: notifyOperationModified: test.post_order_legalization
// CHECK-LABEL: func @test_preorder_legalization
@@ -475,6 +480,9 @@ func.func @test_preorder_legalization() {
^bb0(%arg0: i64):
// expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
"test.remaining_consumer"(%arg0) : (i64) -> ()
+ "test.legal_op"() ({
+ "test.invalid"(%arg0) : (i64) -> ()
+ }) : () -> ()
"test.invalid"(%arg0) : (i64) -> ()
}) : () -> ()
// expected-remark @+1 {{'func.return' is not legalizable}}
More information about the Mlir-commits
mailing list