[Mlir-commits] [mlir] [mlir][Conversion] Generalize and fix crash in `reconcile-unrealized-casts` (PR #95700)
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 17 12:37:19 PDT 2024
================
@@ -22,113 +20,87 @@ using namespace mlir;
namespace {
-/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
-/// the same as the input ones.
-/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
-/// represent a noop within the IR, and thus the initial input values can be
-/// propagated.
-/// The same does not hold for 'open' chains of casts, such as
-/// `A -> B -> C`. In this last case there is no cycle among the types and thus
-/// the conversion is incomplete. The same hold for 'closed' chains like
-/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
-/// operations.
-/// Bifurcations (that is when a chain starts in between of another one) are
-/// also taken into considerations, and all the above considerations remain
-/// valid.
-/// Special corner cases such as dead casts or single casts with same input and
-/// output types are also covered.
-struct UnrealizedConversionCastPassthrough
- : public OpRewritePattern<UnrealizedConversionCastOp> {
- using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
- PatternRewriter &rewriter) const override {
- // The nodes that either are not used by any operation or have at least
- // one user that is not an unrealized cast.
- DenseSet<UnrealizedConversionCastOp> exitNodes;
-
- // The nodes whose users are all unrealized casts
- DenseSet<UnrealizedConversionCastOp> intermediateNodes;
-
- // Stack used for the depth-first traversal of the use-def DAG.
- SmallVector<UnrealizedConversionCastOp, 2> visitStack;
- visitStack.push_back(op);
-
- while (!visitStack.empty()) {
- UnrealizedConversionCastOp current = visitStack.pop_back_val();
- auto users = current->getUsers();
- bool isLive = false;
-
- for (Operation *user : users) {
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
- if (other.getInputs() != current.getOutputs())
- return rewriter.notifyMatchFailure(
- op, "mismatching values propagation");
- } else {
- isLive = true;
- }
-
- // Continue traversing the DAG of unrealized casts
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
- visitStack.push_back(other);
- }
-
- // If the cast is live, then we need to check if the results of the last
- // cast have the same type of the root inputs. It this is the case (e.g.
- // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
- // no-op and the inputs can be forwarded. If it's not (e.g.
- // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
-
- bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
-
- if (isLive && !isCycle)
- return rewriter.notifyMatchFailure(op,
- "live unrealized conversion cast");
-
- bool isExitNode = users.empty() || isLive;
-
- if (isExitNode) {
- exitNodes.insert(current);
- } else {
- intermediateNodes.insert(current);
- }
- }
-
- // Replace the sink nodes with the root input values
- for (UnrealizedConversionCastOp exitNode : exitNodes)
- rewriter.replaceOp(exitNode, op.getInputs());
-
- // Erase all the other casts belonging to the DAG
- for (UnrealizedConversionCastOp castOp : intermediateNodes)
- rewriter.eraseOp(castOp);
-
- return success();
- }
-};
-
/// Pass to simplify and eliminate unrealized conversion casts.
+///
+/// This pass processes unrealized_conversion_cast ops in a worklist-driven
+/// fashion. For each matched cast op, if the chain of input casts eventually
+/// reaches a cast op where the input types match the output types of the
+/// matched op, replace the matched op with the inputs of that cast op.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
struct ReconcileUnrealizedCasts
: public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
ReconcileUnrealizedCasts() = default;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateReconcileUnrealizedCastsPatterns(patterns);
- ConversionTarget target(getContext());
- target.addIllegalOp<UnrealizedConversionCastOp>();
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- signalPassFailure();
+ // Gather all unrealized_conversion_cast ops.
+ SetVector<UnrealizedConversionCastOp> worklist;
+ getOperation()->walk(
+ [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(inputCastOp);
+ };
+
+ // Helper function that returns the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Return "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
+ return {};
+ auto inputCastOp = castOp.getInputs()
+ .front()
+ .getDefiningOp<UnrealizedConversionCastOp>();
+ if (!inputCastOp)
+ return {};
+ if (inputCastOp.getOutputs() != castOp.getInputs())
+ return {};
+ return inputCastOp;
+ };
+
+ // Process ops in the worklist bottom-to-top.
+ while (!worklist.empty()) {
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+ if (castOp->use_empty()) {
+ // DCE: If the op has no users, erase it. Add the operands to the
+ // worklist to find additional DCE opportunities.
+ enqueueOperands(castOp);
+ castOp->erase();
+ continue;
+ }
+
+ // Traverse the chain of input cast ops to see if an op with the same
+ // input types can be found.
+ UnrealizedConversionCastOp nextCast = castOp;
+ while (nextCast) {
+ if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ // Found a cast where the input types match the output types of the
----------------
matthias-springer wrote:
Usually transformations just drop them. It is also unclear where these attributes could be put. There may be no `unrealized_conversion_cast` ops left after running this pass.
https://github.com/llvm/llvm-project/pull/95700
More information about the Mlir-commits mailing list