[Mlir-commits] [mlir] [mlir][Transforms][NFC] Move `ReconcileUnrealizedCasts` implementation (PR #104671)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Aug 17 03:09:06 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit.
This commit is in preparation for decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as `ReconcileUnrealizedCasts` will perform these kinds of foldings on fully materialized IR.
---
Full diff: https://github.com/llvm/llvm-project/pull/104671.diff
3 Files Affected:
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+23)
- (modified) mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp (+3-55)
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+74)
``````````diff
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a51b00271f0aeb..86f0337dd90dfe 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1126,6 +1126,29 @@ struct ConversionConfig {
RewriterBase::Listener *listener = nullptr;
};
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+/// Try to reconcile all given UnrealizedConversionCastOps and store the
+/// left-over ops in `remainingCastOps` (if provided).
+///
+/// This function processes cast ops in a worklist-driven fashion. For each
+/// cast op, if the chain of input casts eventually reaches a cast op where the
+/// input types match the output types of the matched op, replace the matched
+/// op with the inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
+void reconcileUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVector<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 12e0029cebfd0d..d01e3dcbe8cc45 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -10,6 +10,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts
ReconcileUnrealizedCasts() = default;
void runOnOperation() override {
- // Gather all unrealized_conversion_cast ops.
- SetVector<UnrealizedConversionCastOp> worklist;
+ SmallVector<UnrealizedConversionCastOp> ops;
getOperation()->walk(
[&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
-
- // Helper function that adds all operands to the worklist that are an
- // unrealized_conversion_cast op result.
- auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
- for (Value v : castOp.getInputs())
- if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
- worklist.insert(inputCastOp);
- };
-
- // Helper function that return the unrealized_conversion_cast op that
- // defines all inputs of the given op (in the same order). Return "nullptr"
- // if there is no such op.
- auto getInputCast =
- [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
- if (castOp.getInputs().empty())
- return {};
- auto inputCastOp = castOp.getInputs()
- .front()
- .getDefiningOp<UnrealizedConversionCastOp>();
- if (!inputCastOp)
- return {};
- if (inputCastOp.getOutputs() != castOp.getInputs())
- return {};
- return inputCastOp;
- };
-
- // Process ops in the worklist bottom-to-top.
- while (!worklist.empty()) {
- UnrealizedConversionCastOp castOp = worklist.pop_back_val();
- if (castOp->use_empty()) {
- // DCE: If the op has no users, erase it. Add the operands to the
- // worklist to find additional DCE opportunities.
- enqueueOperands(castOp);
- castOp->erase();
- continue;
- }
-
- // Traverse the chain of input cast ops to see if an op with the same
- // input types can be found.
- UnrealizedConversionCastOp nextCast = castOp;
- while (nextCast) {
- if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
- // Found a cast where the input types match the output types of the
- // matched op. We can directly use those inputs and the matched op can
- // be removed.
- enqueueOperands(castOp);
- castOp.replaceAllUsesWith(nextCast.getInputs());
- castOp->erase();
- break;
- }
- nextCast = getInputCast(nextCast);
- }
- }
+ reconcileUnrealizedCasts(ops);
}
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8a4c7463a69a95..0da8eabadb4ee1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
return success();
}
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+void mlir::reconcileUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVector<UnrealizedConversionCastOp> *remainingCastOps) {
+ SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
+ castOps.end());
+ // This set is maintained only if `remainingCastOps` is provided.
+ DenseSet<Operation *> erasedOps;
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(inputCastOp);
+ };
+
+ // Helper function that return the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Return "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
+ return {};
+ auto inputCastOp =
+ castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
+ if (!inputCastOp)
+ return {};
+ if (inputCastOp.getOutputs() != castOp.getInputs())
+ return {};
+ return inputCastOp;
+ };
+
+ // Process ops in the worklist bottom-to-top.
+ while (!worklist.empty()) {
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+ if (castOp->use_empty()) {
+ // DCE: If the op has no users, erase it. Add the operands to the
+ // worklist to find additional DCE opportunities.
+ enqueueOperands(castOp);
+ if (remainingCastOps)
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ continue;
+ }
+
+ // Traverse the chain of input cast ops to see if an op with the same
+ // input types can be found.
+ UnrealizedConversionCastOp nextCast = castOp;
+ while (nextCast) {
+ if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ // Found a cast where the input types match the output types of the
+ // matched op. We can directly use those inputs and the matched op can
+ // be removed.
+ enqueueOperands(castOp);
+ castOp.replaceAllUsesWith(nextCast.getInputs());
+ if (remainingCastOps)
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ break;
+ }
+ nextCast = getInputCast(nextCast);
+ }
+ }
+
+ if (remainingCastOps)
+ for (UnrealizedConversionCastOp op : castOps)
+ if (!erasedOps.contains(op.getOperation()))
+ remainingCastOps->push_back(op);
+}
+
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/104671
More information about the Mlir-commits
mailing list