[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: No rollback during analysis conversion (PR #106414)

Matthias Springer llvmlistbot at llvm.org
Wed Aug 28 09:15:42 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/106414

This commit changes the implementation of analysis conversions, so that no rollback is needed. Instead, the dialect conversion is run on a clone of the IR.

The purpose of this commit is to reduce the number of rollbacks in the dialect conversion framework.

>From 6c274bdb9430183ad57ff37ff8b98440c41f9236 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 28 Aug 2024 18:11:51 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: No rollback during
 analysis conversion

This commit changes the implementation of analysis conversions, so that no rollback is needed. Instead, the dialect conversion is run on a clone of the IR.

The purpose of this commit is to reduce the number of rollbacks in the dialect conversion framework.
---
 .../Transforms/Utils/DialectConversion.cpp    | 76 +++++++++++++++++--
 1 file changed, 68 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cc9c9495e5155c..951bc7074277e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2489,13 +2489,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (failed(finalize(rewriter)))
     return rewriterImpl.undoRewrites(), failure();
 
-  // After a successful conversion, apply rewrites if this is not an analysis
-  // conversion.
-  if (mode == OpConversionMode::Analysis) {
-    rewriterImpl.undoRewrites();
-  } else {
-    rewriterImpl.applyRewrites();
-  }
+  // After a successful conversion, apply rewrites.
+  rewriterImpl.applyRewrites();
 
   // Gather all unresolved materializations.
   SmallVector<UnrealizedConversionCastOp> allCastOps;
@@ -3311,13 +3306,78 @@ LogicalResult mlir::applyFullConversion(Operation *op,
 //===----------------------------------------------------------------------===//
 // Analysis Conversion
 
+/// Find a common IsIsolatedFromAbove ancestor of the given ops (`ops` must not
+/// be empty). If at least one op is a top-level op (which is expected to be
+/// isolated from above and to contain all other ops), return that op.
+static Operation *findCommonAncestor(ArrayRef<Operation *> ops) {
+  assert(!ops.empty() && "expected at least one op");
+  // A top-level op within `ops` is trivially the common ancestor.
+  for (Operation *op : ops) {
+    if (!op->getParentOp()) {
+#ifndef NDEBUG
+      assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+             "expected top-level op to be isolated from above");
+      for (Operation *other : ops)
+        assert(op->isAncestor(other) &&
+               "expected ops to have a common ancestor");
+#endif // NDEBUG
+      return op;
+    }
+  }
+
+  // No top-level op. Start at the closest isolated-from-above ancestor of the
+  // first op and walk up until the ancestor encloses all other ops.
+  Operation *commonAncestor =
+      ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+  for (Operation *op : ops.drop_front()) {
+    while (commonAncestor && !commonAncestor->isProperAncestor(op))
+      commonAncestor =
+          commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+  }
+
+  assert(commonAncestor &&
+         "expected to find a common isolated from above ancestor");
+  return commonAncestor;
+}
+
 LogicalResult mlir::applyAnalysisConversion(
     ArrayRef<Operation *> ops, ConversionTarget &target,
     const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+  assert((!config.legalizableOps || config.legalizableOps->empty()) &&
+         "expected empty set");
+  // Nothing to analyze; there is also no ancestor that could be cloned.
+  if (ops.empty())
+    return success();
+
+  // Clone the closest common ancestor that is isolated from above.
+  IRMapping mapping;
+  Operation *clonedAncestor = findCommonAncestor(ops)->clone(mapping);
+  // Compute the inverse mapping (cloned op -> original op).
+  DenseMap<Operation *, Operation *> inverseOperationMap;
+  for (auto &it : mapping.getOperationMap())
+    inverseOperationMap[it.second] = it.first;
+
+  // Convert the cloned operations. The original IR will remain unchanged.
+  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
+      ops, [&](Operation *op) { return mapping.lookup(op); });
   OperationConverter opConverter(target, patterns, config,
                                  OpConversionMode::Analysis);
-  return opConverter.convertOperations(ops);
+  LogicalResult status = opConverter.convertOperations(opsToConvert);
+
+  // Remap `legalizableOps`, so that they point to the original ops and not the
+  // cloned ops.
+  if (config.legalizableOps) {
+    DenseSet<Operation *> originalLegalizableOps;
+    for (Operation *op : *config.legalizableOps)
+      originalLegalizableOps.insert(inverseOperationMap.lookup(op));
+    *config.legalizableOps = std::move(originalLegalizableOps);
+  }
+
+  // Erase the cloned IR.
+  clonedAncestor->erase();
+  return status;
 }
+
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,



More information about the Mlir-commits mailing list