[Mlir-commits] [mlir] [mlir][Conversion] Generalize and fix crash in `reconcile-unrealized-casts` (PR #95700)
Matthias Springer
llvmlistbot at llvm.org
Sun Jun 16 10:21:12 PDT 2024
================
@@ -22,113 +20,87 @@ using namespace mlir;
namespace {
-/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
-/// the same as the input ones.
-/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
-/// represent a noop within the IR, and thus the initial input values can be
-/// propagated.
-/// The same does not hold for 'open' chains of casts, such as
-/// `A -> B -> C`. In this last case there is no cycle among the types and thus
-/// the conversion is incomplete. The same holds for 'closed' chains like
-/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
-/// operations.
-/// Bifurcations (that is, when a chain starts in the middle of another one) are
-/// also taken into consideration, and all the above considerations remain
-/// valid.
-/// Special corner cases such as dead casts or single casts with same input and
-/// output types are also covered.
-struct UnrealizedConversionCastPassthrough
- : public OpRewritePattern<UnrealizedConversionCastOp> {
- using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
- PatternRewriter &rewriter) const override {
- // The nodes that either are not used by any operation or have at least
- // one user that is not an unrealized cast.
- DenseSet<UnrealizedConversionCastOp> exitNodes;
-
- // The nodes whose users are all unrealized casts
- DenseSet<UnrealizedConversionCastOp> intermediateNodes;
-
- // Stack used for the depth-first traversal of the use-def DAG.
- SmallVector<UnrealizedConversionCastOp, 2> visitStack;
- visitStack.push_back(op);
-
- while (!visitStack.empty()) {
- UnrealizedConversionCastOp current = visitStack.pop_back_val();
- auto users = current->getUsers();
- bool isLive = false;
-
- for (Operation *user : users) {
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
- if (other.getInputs() != current.getOutputs())
- return rewriter.notifyMatchFailure(
- op, "mismatching values propagation");
- } else {
- isLive = true;
- }
-
- // Continue traversing the DAG of unrealized casts
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
- visitStack.push_back(other);
- }
-
- // If the cast is live, then we need to check if the results of the last
- // cast have the same type as the root inputs. If this is the case (e.g.
- // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
- // no-op and the inputs can be forwarded. If it's not (e.g.
- // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
-
- bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
-
- if (isLive && !isCycle)
- return rewriter.notifyMatchFailure(op,
- "live unrealized conversion cast");
-
- bool isExitNode = users.empty() || isLive;
-
- if (isExitNode) {
- exitNodes.insert(current);
- } else {
- intermediateNodes.insert(current);
- }
- }
-
- // Replace the sink nodes with the root input values
- for (UnrealizedConversionCastOp exitNode : exitNodes)
- rewriter.replaceOp(exitNode, op.getInputs());
-
- // Erase all the other casts belonging to the DAG
- for (UnrealizedConversionCastOp castOp : intermediateNodes)
- rewriter.eraseOp(castOp);
-
- return success();
- }
-};
-
/// Pass to simplify and eliminate unrealized conversion casts.
+///
+/// This pass processes unrealized_conversion_cast ops in a worklist-driven
+/// fashion. For each matched cast op, if the chain of input casts eventually
+/// reaches a cast op where the input types match the output types of the
+/// matched op, replace the matched op with the inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
struct ReconcileUnrealizedCasts
: public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
ReconcileUnrealizedCasts() = default;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateReconcileUnrealizedCastsPatterns(patterns);
- ConversionTarget target(getContext());
- target.addIllegalOp<UnrealizedConversionCastOp>();
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- signalPassFailure();
+ // Gather all unrealized_conversion_cast ops.
+ SetVector<UnrealizedConversionCastOp> worklist;
+ getOperation()->walk(
+ [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(castOp);
+ };
+
+ // Helper function that returns the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Returns "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
----------------
matthias-springer wrote:
Added a test case with a cast op that has no inputs. (The pass does not do anything with it.)
https://github.com/llvm/llvm-project/pull/95700
More information about the Mlir-commits
mailing list