[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: No rollback during analysis conversion (PR #106414)
Matthias Springer
llvmlistbot at llvm.org
Wed Aug 28 09:15:42 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/106414
This commit changes the implementation of analysis conversions so that no rollback is needed. Instead, the dialect conversion is run on a clone of the IR, and the clone is erased afterwards.
The purpose of this commit is to reduce the number of rollbacks in the dialect conversion framework.
>From 6c274bdb9430183ad57ff37ff8b98440c41f9236 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 28 Aug 2024 18:11:51 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: No rollback during
analysis conversion
This commit changes the implementation of analysis conversions, so that no rollback is needed. Instead, the dialect conversion is run on a clone of the IR.
The purpose of this commit is to reduce the number of rollbacks in the dialect conversion framework.
---
.../Transforms/Utils/DialectConversion.cpp | 76 +++++++++++++++++--
1 file changed, 68 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cc9c9495e5155c..951bc7074277e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2489,13 +2489,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (failed(finalize(rewriter)))
return rewriterImpl.undoRewrites(), failure();
- // After a successful conversion, apply rewrites if this is not an analysis
- // conversion.
- if (mode == OpConversionMode::Analysis) {
- rewriterImpl.undoRewrites();
- } else {
- rewriterImpl.applyRewrites();
- }
+ // After a successful conversion, apply rewrites.
+ rewriterImpl.applyRewrites();
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
@@ -3311,13 +3306,78 @@ LogicalResult mlir::applyFullConversion(Operation *op,
//===----------------------------------------------------------------------===//
// Analysis Conversion
+/// Find the closest common IsolatedFromAbove ancestor of the given ops.
+///
+/// If one of the ops has no parent op (i.e., it is a top-level op such as a
+/// module), that op is returned; it is expected to be isolated from above and
+/// to contain all other ops. Otherwise, the result is the closest op with the
+/// IsIsolatedFromAbove trait that is a proper ancestor of every op in `ops`.
+/// `ops` must not be empty.
+static Operation *findCommonAncestor(ArrayRef<Operation *> ops) {
+  assert(!ops.empty() && "expected at least one op");
+
+  // Check if there is a top-level operation within `ops`. If so, return that
+  // op.
+  for (Operation *op : ops) {
+    if (!op->getParentOp()) {
+#ifndef NDEBUG
+      assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+             "expected top-level op to be isolated from above");
+      for (Operation *other : ops)
+        assert(op->isAncestor(other) &&
+               "expected ops to have a common ancestor");
+#endif // NDEBUG
+      return op;
+    }
+  }
+
+  // No top-level op. Start from the closest isolated-from-above parent of the
+  // first op and walk upwards until it properly contains every other op.
+  Operation *commonAncestor =
+      ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+  assert(commonAncestor &&
+         "expected to find an isolated from above ancestor");
+  for (Operation *op : ops.drop_front()) {
+    while (!commonAncestor->isProperAncestor(op)) {
+      commonAncestor =
+          commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+      assert(commonAncestor &&
+             "expected to find a common isolated from above ancestor");
+    }
+  }
+
+  return commonAncestor;
+}
+
LogicalResult mlir::applyAnalysisConversion(
ArrayRef<Operation *> ops, ConversionTarget &target,
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+#ifndef NDEBUG
+ if (config.legalizableOps)
+ assert(config.legalizableOps->empty() && "expected empty set");
+#endif // NDEBUG
+
+ // Clone closted common ancestor that is isolated from above.
+ Operation *commonAncestor = findCommonAncestor(ops);
+ IRMapping mapping;
+ Operation *clonedAncestor = commonAncestor->clone(mapping);
+ // Compute inverse IR mapping.
+ DenseMap<Operation *, Operation *> inverseOperationMap;
+ for (auto &it : mapping.getOperationMap())
+ inverseOperationMap[it.second] = it.first;
+
+ // Convert the cloned operations. The original IR will remain unchanged.
+ SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
+ ops, [&](Operation *op) { return mapping.lookup(op); });
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Analysis);
- return opConverter.convertOperations(ops);
+ LogicalResult status = opConverter.convertOperations(opsToConvert);
+
+ // Remap `legalizableOps`, so that they point to the original ops and not the
+ // cloned ops.
+ if (config.legalizableOps) {
+ DenseSet<Operation *> originalLegalizableOps;
+ for (Operation *op : *config.legalizableOps)
+ originalLegalizableOps.insert(inverseOperationMap[op]);
+ *config.legalizableOps = std::move(originalLegalizableOps);
+ }
+
+ // Erase the cloned IR.
+ clonedAncestor->erase();
+ return status;
}
+
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
More information about the Mlir-commits
mailing list