[Mlir-commits] [mlir] [mlir][Conversion] Generalize and fix crash in `reconcile-unrealized-casts` (PR #95700)

Markus Böck llvmlistbot at llvm.org
Sun Jun 16 06:08:20 PDT 2024


================
@@ -22,113 +20,87 @@ using namespace mlir;
 
 namespace {
 
-/// Folds the DAGs of `unrealized_conversion_cast`s whose exit types are the
-/// same as their input types.
-/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
-/// represent a no-op within the IR, and thus the initial input values can be
-/// propagated.
-/// The same does not hold for 'open' chains of casts, such as
-/// `A -> B -> C`. In this last case there is no cycle among the types and thus
-/// the conversion is incomplete. The same holds for 'closed' chains like
-/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
-/// operations.
-/// Bifurcations (that is, when a chain starts midway through another one) are
-/// also taken into consideration, and all the above considerations remain
-/// valid.
-/// Special corner cases such as dead casts or single casts with same input and
-/// output types are also covered.
-struct UnrealizedConversionCastPassthrough
-    : public OpRewritePattern<UnrealizedConversionCastOp> {
-  using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
-                                PatternRewriter &rewriter) const override {
-    // The nodes that either are not used by any operation or have at least
-    // one user that is not an unrealized cast.
-    DenseSet<UnrealizedConversionCastOp> exitNodes;
-
-    // The nodes whose users are all unrealized casts
-    DenseSet<UnrealizedConversionCastOp> intermediateNodes;
-
-    // Stack used for the depth-first traversal of the use-def DAG.
-    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
-    visitStack.push_back(op);
-
-    while (!visitStack.empty()) {
-      UnrealizedConversionCastOp current = visitStack.pop_back_val();
-      auto users = current->getUsers();
-      bool isLive = false;
-
-      for (Operation *user : users) {
-        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
-          if (other.getInputs() != current.getOutputs())
-            return rewriter.notifyMatchFailure(
-                op, "mismatching values propagation");
-        } else {
-          isLive = true;
-        }
-
-        // Continue traversing the DAG of unrealized casts
-        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
-          visitStack.push_back(other);
-      }
-
-      // If the cast is live, then we need to check if the results of the last
-      // cast have the same types as the root inputs. If this is the case (e.g.
-      // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
-      // no-op and the inputs can be forwarded. If it's not (e.g.
-      // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
-
-      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
-
-      if (isLive && !isCycle)
-        return rewriter.notifyMatchFailure(op,
-                                           "live unrealized conversion cast");
-
-      bool isExitNode = users.empty() || isLive;
-
-      if (isExitNode) {
-        exitNodes.insert(current);
-      } else {
-        intermediateNodes.insert(current);
-      }
-    }
-
-    // Replace the sink nodes with the root input values
-    for (UnrealizedConversionCastOp exitNode : exitNodes)
-      rewriter.replaceOp(exitNode, op.getInputs());
-
-    // Erase all the other casts belonging to the DAG
-    for (UnrealizedConversionCastOp castOp : intermediateNodes)
-      rewriter.eraseOp(castOp);
-
-    return success();
-  }
-};
-
 /// Pass to simplify and eliminate unrealized conversion casts.
+///
+/// This pass processes unrealized_conversion_cast ops in a worklist-driven
+/// fashion. For each matched cast op, if the chain of input casts eventually
+/// reaches a cast op whose input types match the output types of the
+/// matched op, the matched op is replaced with that cast op's inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
 struct ReconcileUnrealizedCasts
     : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
   ReconcileUnrealizedCasts() = default;
 
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateReconcileUnrealizedCastsPatterns(patterns);
-    ConversionTarget target(getContext());
-    target.addIllegalOp<UnrealizedConversionCastOp>();
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      signalPassFailure();
+    // Gather all unrealized_conversion_cast ops.
+    SetVector<UnrealizedConversionCastOp> worklist;
+    getOperation()->walk(
+        [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+
+    // Helper function that adds to the worklist the defining op of every input
+    // that is an unrealized_conversion_cast op result.
+    auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+      for (Value v : castOp.getInputs())
+        if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+          worklist.insert(inputCastOp);
+    };
+
+    // Helper function that returns the unrealized_conversion_cast op that
+    // defines all inputs of the given op (in the same order). Returns
+    // "nullptr" if there is no such op.
+    auto getInputCast =
+        [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+      if (castOp.getInputs().empty())
----------------
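The chain walk described in the new doc comment, together with the `getInputCast` helper, can be modeled outside MLIR. This is a minimal sketch under stated assumptions, not the PR's code: `Value`, `CastOp`, `typesOf`, and `findReplacement` are hypothetical stand-ins, and types are plain strings. It only answers "which values could replace this cast's results?" and leaves use replacement and erasure out.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

struct CastOp;

// Hypothetical model of an SSA value (not the MLIR API): a type name plus the
// cast that defines it, or nullptr for block arguments and non-cast producers.
struct Value {
  std::string type;
  CastOp *definingCast = nullptr;
};

// Hypothetical model of an unrealized_conversion_cast op.
struct CastOp {
  std::vector<Value *> inputs;
  std::vector<Value *> outputs;
};

std::vector<std::string> typesOf(const std::vector<Value *> &values) {
  std::vector<std::string> types;
  for (const Value *v : values)
    types.push_back(v->type);
  return types;
}

// Modeled on the comment of `getInputCast`: return the cast that defines all
// inputs of `op` in the same order, or nullptr. A cast with no inputs has no
// input cast, so we return before touching inputs.front().
CastOp *getInputCast(const CastOp &op) {
  if (op.inputs.empty())
    return nullptr;
  CastOp *candidate = op.inputs.front()->definingCast;
  if (!candidate || candidate->outputs != op.inputs)
    return nullptr;
  return candidate;
}

// Walk up the chain of input casts, starting at `op` itself. The first cast
// whose input types equal `op`'s output types supplies the replacement values.
std::optional<std::vector<Value *>> findReplacement(const CastOp &op) {
  const std::vector<std::string> wanted = typesOf(op.outputs);
  for (const CastOp *next = &op; next != nullptr; next = getInputCast(*next))
    if (typesOf(next->inputs) == wanted)
      return next->inputs;
  return std::nullopt;
}
```

For the `!A -> !B -> !C -> !A` chain, `findReplacement` on the last cast yields `%0`; on the middle cast (`!B -> !C`) it yields nothing, since no cast in its chain takes `!C` as input. Starting the walk at `op` itself also covers a single cast whose input and output types already match.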
zero9178 wrote:

> This goes a bit outside the scope of this pass and I may be wrong, but it seems to me that invalid IR would be coming into the reconciliation pass. Following your example, if the cast is there and its operand is a block operand to be removed, then the use of that argument inside the cast should be replaced by some other SSA value. Otherwise, what would the cast operate on? (I know unrealized casts are "dummy casts", but still, it would conceptually generate a result from nothing.)

To give an example of why this is okay: the operand may simply be dropped during dialect conversion rather than replaced. I previously relied on this when translating out of a MemorySSA dialect, where the SSA values (and their uses) representing versions of memory were dropped and erased instead of replaced.
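To make the dropped-operand scenario concrete: once conversion erases an operand instead of replacing it, the cast that used it is left with zero inputs. The sketch below is a hypothetical model, assuming a worklist that erases casts whose results are unused and then revisits the casts that defined their inputs (the removed pattern's comment lists dead casts among the cases it covered). `Cast`, `numUses`, and `eraseDeadCasts` are illustrative names, not MLIR API. A zero-input cast is erased like any other dead cast; because it has no inputs, nothing further is enqueued.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model (not MLIR API): each cast lists the casts defining its
// inputs (one entry per input, nullptr for non-cast producers) and counts the
// remaining uses of its results.
struct Cast {
  std::string name;
  std::vector<Cast *> inputDefs; // empty when the operand was dropped
  int numUses = 0;
  bool erased = false;
};

// Erase dead casts using a worklist. Erasing a cast releases one use of each
// input-defining cast, which may make that cast dead too, so it is revisited.
void eraseDeadCasts(std::vector<Cast *> worklist) {
  while (!worklist.empty()) {
    Cast *c = worklist.back();
    worklist.pop_back();
    if (c->erased || c->numUses > 0)
      continue;
    c->erased = true;
    // For a zero-input cast this loop does nothing.
    for (Cast *def : c->inputDefs) {
      if (!def)
        continue;
      --def->numUses;
      worklist.push_back(def);
    }
  }
}
```

In this model, a zero-input cast whose only user is a dead cast disappears along with it, while a zero-input cast still used by some non-cast op stays in place.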

https://github.com/llvm/llvm-project/pull/95700

