[Mlir-commits] [mlir] [mlir][Transforms][NFC] Move `ReconcileUnrealizedCasts` implementation (PR #104671)
Matthias Springer
llvmlistbot at llvm.org
Sat Aug 17 03:36:23 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/104671
>From 8703175c448143e769957d9e2b5e02c9c1ac6e0f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 17 Aug 2024 12:01:58 +0200
Subject: [PATCH 1/3] [mlir][Transforms][NFC] Move `ReconcileUnrealizedCasts`
implementation
Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit. This commit is in preparation for decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations, which predicts IR changes to decide whether a cast op can be folded/erased, will become obsolete, as `ReconcileUnrealizedCasts` will perform this kind of folding on fully materialized IR.
---
.../mlir/Transforms/DialectConversion.h | 23 ++++++
.../ReconcileUnrealizedCasts.cpp | 58 +--------------
.../Transforms/Utils/DialectConversion.cpp | 74 +++++++++++++++++++
3 files changed, 100 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a51b00271f0aeb..86f0337dd90dfe 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1126,6 +1126,29 @@ struct ConversionConfig {
RewriterBase::Listener *listener = nullptr;
};
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+/// Try to reconcile all given UnrealizedConversionCastOps and store the
+/// left-over ops in `remainingCastOps` (if provided).
+///
+/// This function processes cast ops in a worklist-driven fashion. For each
+/// cast op, if the chain of input casts eventually reaches a cast op where the
+/// input types match the output types of the matched op, replace the matched
+/// op with the inputs.
+///
+/// Cast ops without uses are erased. Whenever a cast op is erased, the cast
+/// ops that define its inputs are (re)visited, so that chains of casts that
+/// become dead are removed as well. Such input cast ops may be erased even if
+/// they are not part of `castOps`.
+///
+/// `remainingCastOps` is appended to and never cleared: it receives the ops of
+/// `castOps` that were not erased, in their original order.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
+void reconcileUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVector<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 12e0029cebfd0d..d01e3dcbe8cc45 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -10,6 +10,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts
ReconcileUnrealizedCasts() = default;
+ /// Collect all unrealized_conversion_cast ops by walking the root operation
+ /// and reconcile them via `reconcileUnrealizedCasts`.
void runOnOperation() override {
- // Gather all unrealized_conversion_cast ops.
- SetVector<UnrealizedConversionCastOp> worklist;
+ SmallVector<UnrealizedConversionCastOp> ops;
getOperation()->walk(
[&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
-
- // Helper function that adds all operands to the worklist that are an
- // unrealized_conversion_cast op result.
- auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
- for (Value v : castOp.getInputs())
- if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
- worklist.insert(inputCastOp);
- };
-
- // Helper function that return the unrealized_conversion_cast op that
- // defines all inputs of the given op (in the same order). Return "nullptr"
- // if there is no such op.
- auto getInputCast =
- [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
- if (castOp.getInputs().empty())
- return {};
- auto inputCastOp = castOp.getInputs()
- .front()
- .getDefiningOp<UnrealizedConversionCastOp>();
- if (!inputCastOp)
- return {};
- if (inputCastOp.getOutputs() != castOp.getInputs())
- return {};
- return inputCastOp;
- };
-
- // Process ops in the worklist bottom-to-top.
- while (!worklist.empty()) {
- UnrealizedConversionCastOp castOp = worklist.pop_back_val();
- if (castOp->use_empty()) {
- // DCE: If the op has no users, erase it. Add the operands to the
- // worklist to find additional DCE opportunities.
- enqueueOperands(castOp);
- castOp->erase();
- continue;
- }
-
- // Traverse the chain of input cast ops to see if an op with the same
- // input types can be found.
- UnrealizedConversionCastOp nextCast = castOp;
- while (nextCast) {
- if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
- // Found a cast where the input types match the output types of the
- // matched op. We can directly use those inputs and the matched op can
- // be removed.
- enqueueOperands(castOp);
- castOp.replaceAllUsesWith(nextCast.getInputs());
- castOp->erase();
- break;
- }
- nextCast = getInputCast(nextCast);
- }
- }
+ reconcileUnrealizedCasts(ops);
}
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8a4c7463a69a95..0da8eabadb4ee1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
return success();
}
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+void mlir::reconcileUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
+ castOps.end());
+ // Ops erased below; used at the end to filter them out of
+ // `remainingCastOps`. This set is maintained only if `remainingCastOps` is
+ // provided.
+ DenseSet<Operation *> erasedOps;
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(inputCastOp);
+ };
+
+ // Helper function that returns the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Return "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
+ return {};
+ auto inputCastOp =
+ castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
+ if (!inputCastOp)
+ return {};
+ if (inputCastOp.getOutputs() != castOp.getInputs())
+ return {};
+ return inputCastOp;
+ };
+
+ // Process ops in the worklist bottom-to-top.
+ while (!worklist.empty()) {
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+ if (castOp->use_empty()) {
+ // DCE: If the op has no users, erase it. Add the operands to the
+ // worklist to find additional DCE opportunities.
+ enqueueOperands(castOp);
+ if (remainingCastOps)
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ continue;
+ }
+
+ // Traverse the chain of input cast ops to see if an op with the same
+ // input types can be found.
+ UnrealizedConversionCastOp nextCast = castOp;
+ while (nextCast) {
+ if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ // Found a cast where the input types match the output types of the
+ // matched op. We can directly use those inputs and the matched op can
+ // be removed.
+ enqueueOperands(castOp);
+ castOp.replaceAllUsesWith(nextCast.getInputs());
+ if (remainingCastOps)
+ erasedOps.insert(castOp.getOperation());
+ castOp->erase();
+ break;
+ }
+ nextCast = getInputCast(nextCast);
+ }
+ }
+
+ // Append the ops of `castOps` that survived, preserving their order.
+ if (remainingCastOps)
+ for (UnrealizedConversionCastOp op : castOps)
+ if (!erasedOps.contains(op.getOperation()))
+ remainingCastOps->push_back(op);
+}
+
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
>From ab5f5c155dddd3a65c8760844fcb65328c01debb Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 17 Aug 2024 12:36:05 +0200
Subject: [PATCH 2/3] Update
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
.../ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index d01e3dcbe8cc45..d503620b086000 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -42,7 +42,7 @@ struct ReconcileUnrealizedCasts
void runOnOperation() override {
SmallVector<UnrealizedConversionCastOp> ops;
getOperation()->walk(
- [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+ [&](UnrealizedConversionCastOp castOp) { ops.push_back(castOp); });
reconcileUnrealizedCasts(ops);
}
};
>From 7f16253d98a4f5cf6231f39c4108ab1568493c1a Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 17 Aug 2024 12:36:15 +0200
Subject: [PATCH 3/3] Update mlir/include/mlir/Transforms/DialectConversion.h
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
mlir/include/mlir/Transforms/DialectConversion.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 86f0337dd90dfe..60113bdef16a23 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1147,7 +1147,7 @@ struct ConversionConfig {
/// folded away.
+///
+/// Any SmallVector, regardless of its inline capacity, can be passed as
+/// `remainingCastOps`.
void reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
- SmallVector<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
More information about the Mlir-commits
mailing list