[Mlir-commits] [mlir] [mlir][Conversion] Generalize and fix crash in `reconcile-unrealized-casts` (PR #95700)

Markus Böck llvmlistbot at llvm.org
Sun Jun 16 06:37:27 PDT 2024


================
@@ -22,113 +20,87 @@ using namespace mlir;
 
 namespace {
 
-/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
-/// the same as the input ones.
-/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
-/// represent a noop within the IR, and thus the initial input values can be
-/// propagated.
-/// The same does not hold for 'open' chains of casts, such as
-/// `A -> B -> C`. In this last case there is no cycle among the types and thus
-/// the conversion is incomplete. The same hold for 'closed' chains like
-/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
-/// operations.
-/// Bifurcations (that is when a chain starts in between of another one) are
-/// also taken into considerations, and all the above considerations remain
-/// valid.
-/// Special corner cases such as dead casts or single casts with same input and
-/// output types are also covered.
-struct UnrealizedConversionCastPassthrough
-    : public OpRewritePattern<UnrealizedConversionCastOp> {
-  using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
-                                PatternRewriter &rewriter) const override {
-    // The nodes that either are not used by any operation or have at least
-    // one user that is not an unrealized cast.
-    DenseSet<UnrealizedConversionCastOp> exitNodes;
-
-    // The nodes whose users are all unrealized casts
-    DenseSet<UnrealizedConversionCastOp> intermediateNodes;
-
-    // Stack used for the depth-first traversal of the use-def DAG.
-    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
-    visitStack.push_back(op);
-
-    while (!visitStack.empty()) {
-      UnrealizedConversionCastOp current = visitStack.pop_back_val();
-      auto users = current->getUsers();
-      bool isLive = false;
-
-      for (Operation *user : users) {
-        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
-          if (other.getInputs() != current.getOutputs())
-            return rewriter.notifyMatchFailure(
-                op, "mismatching values propagation");
-        } else {
-          isLive = true;
-        }
-
-        // Continue traversing the DAG of unrealized casts
-        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
-          visitStack.push_back(other);
-      }
-
-      // If the cast is live, then we need to check if the results of the last
-      // cast have the same type of the root inputs. It this is the case (e.g.
-      // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
-      // no-op and the inputs can be forwarded. If it's not (e.g.
-      // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
-
-      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
-
-      if (isLive && !isCycle)
-        return rewriter.notifyMatchFailure(op,
-                                           "live unrealized conversion cast");
-
-      bool isExitNode = users.empty() || isLive;
-
-      if (isExitNode) {
-        exitNodes.insert(current);
-      } else {
-        intermediateNodes.insert(current);
-      }
-    }
-
-    // Replace the sink nodes with the root input values
-    for (UnrealizedConversionCastOp exitNode : exitNodes)
-      rewriter.replaceOp(exitNode, op.getInputs());
-
-    // Erase all the other casts belonging to the DAG
-    for (UnrealizedConversionCastOp castOp : intermediateNodes)
-      rewriter.eraseOp(castOp);
-
-    return success();
-  }
-};
-
 /// Pass to simplify and eliminate unrealized conversion casts.
+///
+/// This pass processes unrealized_conversion_cast ops in a worklist-driven
+/// fashion. For each matched cast op, if the chain of input casts eventually
+/// reaches a cast op whose input types match the output types of the
+/// matched op, the matched op is replaced with that cast op's inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
 struct ReconcileUnrealizedCasts
     : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
   ReconcileUnrealizedCasts() = default;
 
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateReconcileUnrealizedCastsPatterns(patterns);
-    ConversionTarget target(getContext());
-    target.addIllegalOp<UnrealizedConversionCastOp>();
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      signalPassFailure();
+    // Gather all unrealized_conversion_cast ops.
+    SetVector<UnrealizedConversionCastOp> worklist;
+    getOperation()->walk(
+        [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+
+    // Helper function that adds all operands to the worklist that are an
+    // unrealized_conversion_cast op result.
+    auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+      for (Value v : castOp.getInputs())
+        if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+          worklist.insert(castOp);
+    };
+
+    // Helper function that returns the unrealized_conversion_cast op that
+    // defines all inputs of the given op (in the same order). Returns "nullptr"
+    // if there is no such op.
+    auto getInputCast =
+        [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+      if (castOp.getInputs().empty())
----------------
zero9178 wrote:

Should we also add a test for this case (a cast op with no inputs)? I don't think any existing test exercises it yet.

https://github.com/llvm/llvm-project/pull/95700


More information about the Mlir-commits mailing list