[Mlir-commits] [mlir] [mlir][Transforms][NFC] Move `ReconcileUnrealizedCasts` implementation (PR #104671)
Markus Böck
llvmlistbot at llvm.org
Sat Aug 17 03:28:43 PDT 2024
================
@@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
return success();
}
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+void mlir::reconcileUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVector<UnrealizedConversionCastOp> *remainingCastOps) {
+ SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
+ castOps.end());
+ // This set is maintained only if `remainingCastOps` is provided.
+ DenseSet<Operation *> erasedOps;
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(inputCastOp);
+ };
+
+ // Helper function that return the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Return "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
+ return {};
+ auto inputCastOp =
+ castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
+ if (!inputCastOp)
+ return {};
+ if (inputCastOp.getOutputs() != castOp.getInputs())
+ return {};
+ return inputCastOp;
+ };
+
+ // Process ops in the worklist bottom-to-top.
+ while (!worklist.empty()) {
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+ if (castOp->use_empty()) {
+ // DCE: If the op has no users, erase it. Add the operands to the
+ // worklist to find additional DCE opportunities.
+ enqueueOperands(castOp);
+ if (remainingCastOps)
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ continue;
+ }
+
+ // Traverse the chain of input cast ops to see if an op with the same
+ // input types can be found.
+ UnrealizedConversionCastOp nextCast = castOp;
+ while (nextCast) {
+ if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ // Found a cast where the input types match the output types of the
+ // matched op. We can directly use those inputs and the matched op can
+ // be removed.
+ enqueueOperands(castOp);
+ castOp.replaceAllUsesWith(nextCast.getInputs());
+ if (remainingCastOps)
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ break;
+ }
+ nextCast = getInputCast(nextCast);
+ }
+ }
+
+ if (remainingCastOps)
+ for (UnrealizedConversionCastOp op : castOps)
+ if (!erasedOps.contains(op.getOperation()))
----------------
zero9178 wrote:
This looks slightly scary, given that `op` has technically been erased and `op.getOperation()` now returns a dangling address. I can't come up with a better way of doing this either, though, and it shouldn't be a correctness issue: the pointer is only used as a key for the `erasedOps` lookup and is never dereferenced.
https://github.com/llvm/llvm-project/pull/104671
More information about the Mlir-commits
mailing list