[Mlir-commits] [mlir] [mlir][Conversion] Generalize and fix crash in `reconcile-unrealized-casts` (PR #95700)
Jacques Pienaar
llvmlistbot at llvm.org
Mon Jun 17 12:20:45 PDT 2024
================
@@ -22,113 +20,87 @@ using namespace mlir;
namespace {
-/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
-/// the same as the input ones.
-/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
-/// represent a noop within the IR, and thus the initial input values can be
-/// propagated.
-/// The same does not hold for 'open' chains of casts, such as
-/// `A -> B -> C`. In this last case there is no cycle among the types and thus
-/// the conversion is incomplete. The same hold for 'closed' chains like
-/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
-/// operations.
-/// Bifurcations (that is when a chain starts in between of another one) are
-/// also taken into considerations, and all the above considerations remain
-/// valid.
-/// Special corner cases such as dead casts or single casts with same input and
-/// output types are also covered.
-struct UnrealizedConversionCastPassthrough
- : public OpRewritePattern<UnrealizedConversionCastOp> {
- using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
- PatternRewriter &rewriter) const override {
- // The nodes that either are not used by any operation or have at least
- // one user that is not an unrealized cast.
- DenseSet<UnrealizedConversionCastOp> exitNodes;
-
- // The nodes whose users are all unrealized casts
- DenseSet<UnrealizedConversionCastOp> intermediateNodes;
-
- // Stack used for the depth-first traversal of the use-def DAG.
- SmallVector<UnrealizedConversionCastOp, 2> visitStack;
- visitStack.push_back(op);
-
- while (!visitStack.empty()) {
- UnrealizedConversionCastOp current = visitStack.pop_back_val();
- auto users = current->getUsers();
- bool isLive = false;
-
- for (Operation *user : users) {
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
- if (other.getInputs() != current.getOutputs())
- return rewriter.notifyMatchFailure(
- op, "mismatching values propagation");
- } else {
- isLive = true;
- }
-
- // Continue traversing the DAG of unrealized casts
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
- visitStack.push_back(other);
- }
-
- // If the cast is live, then we need to check if the results of the last
- // cast have the same type of the root inputs. It this is the case (e.g.
- // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
- // no-op and the inputs can be forwarded. If it's not (e.g.
- // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
-
- bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
-
- if (isLive && !isCycle)
- return rewriter.notifyMatchFailure(op,
- "live unrealized conversion cast");
-
- bool isExitNode = users.empty() || isLive;
-
- if (isExitNode) {
- exitNodes.insert(current);
- } else {
- intermediateNodes.insert(current);
- }
- }
-
- // Replace the sink nodes with the root input values
- for (UnrealizedConversionCastOp exitNode : exitNodes)
- rewriter.replaceOp(exitNode, op.getInputs());
-
- // Erase all the other casts belonging to the DAG
- for (UnrealizedConversionCastOp castOp : intermediateNodes)
- rewriter.eraseOp(castOp);
-
- return success();
- }
-};
-
/// Pass to simplify and eliminate unrealized conversion casts.
+///
+/// This pass processes unrealized_conversion_cast ops in a worklist-driven
+/// fashion. For each matched cast op, if the chain of input casts eventually
+/// reaches a cast op whose input types match the output types of the
+/// matched op, the matched op is replaced with that cast op's inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
struct ReconcileUnrealizedCasts
: public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
ReconcileUnrealizedCasts() = default;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateReconcileUnrealizedCastsPatterns(patterns);
- ConversionTarget target(getContext());
- target.addIllegalOp<UnrealizedConversionCastOp>();
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- signalPassFailure();
+ // Gather all unrealized_conversion_cast ops.
+ SetVector<UnrealizedConversionCastOp> worklist;
+ getOperation()->walk(
+ [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(inputCastOp);
+ };
+
+ // Helper function that returns the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Returns "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
+ return {};
+ auto inputCastOp = castOp.getInputs()
+ .front()
+ .getDefiningOp<UnrealizedConversionCastOp>();
+ if (!inputCastOp)
+ return {};
+ if (inputCastOp.getOutputs() != castOp.getInputs())
+ return {};
+ return inputCastOp;
+ };
+
+ // Process ops in the worklist bottom-to-top.
+ while (!worklist.empty()) {
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+ if (castOp->use_empty()) {
+ // DCE: If the op has no users, erase it. Add the operands to the
+ // worklist to find additional DCE opportunities.
+ enqueueOperands(castOp);
+ castOp->erase();
+ continue;
+ }
+
+ // Traverse the chain of input cast ops to see if an op with the same
+ // input types can be found.
+ UnrealizedConversionCastOp nextCast = castOp;
+ while (nextCast) {
+ if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ // Found a cast where the input types match the output types of the
----------------
jpienaar wrote:
Do we have to worry about other attributes? (I think here they are all discardable, so we need not. If downstream users wanted to assign meaning to them, they probably should not run this pass?)
https://github.com/llvm/llvm-project/pull/95700
More information about the Mlir-commits
mailing list