[Mlir-commits] [mlir] [mlir][scf] Track replacements using a listener in TileAndFuse (PR #120999)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 23 13:02:07 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Kunwar Grover (Groverkss)
<details>
<summary>Changes</summary>
This PR makes TileAndFuse explicitly track replacements using a listener instead of assuming that the results always come from the outermost tiling loop. scf::tileUsingInterface can introduce merge operations whose results, rather than the outermost loop's results, are the actual replacements to use.
---
Full diff: https://github.com/llvm/llvm-project/pull/120999.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+59-21)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 90db42d479a193..2277989bf8411b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -28,6 +28,7 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -1467,6 +1468,47 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
ValueRange replacement) {
removeOp(op);
}
+
+//===----------------------------------------------------------------------===//
+// ReplacementListener
+//===----------------------------------------------------------------------===//
+
/// Listener that tracks and updates the replacements for values that may be mutated.
+/// This listener runs on top of the existing listener for the rewriter,
+/// to make sure external users can still run listeners.
+class ReplacementListener : public RewriterBase::ForwardingListener {
+public:
+ ReplacementListener(DenseMap<Value, Value> &replacements,
+ OpBuilder::Listener *listener)
+ : ForwardingListener(listener), replacements(replacements) {}
+
+ void updateReplacementValues(ValueRange origValues,
+ ValueRange replaceValues) {
 // This could be written more efficiently; for now it simply checks every
 // map entry against every (original, replacement) pair.
+ for (auto &[key, val] : replacements) {
+ for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
+ if (val == orig) {
+ val = replace;
+ }
+ }
+ }
+ }
+
+ void notifyOperationReplaced(Operation *op, Operation *newOp) override {
+ ForwardingListener::notifyOperationReplaced(op, newOp);
+ updateReplacementValues(op->getResults(), newOp->getResults());
+ }
+
+ void notifyOperationReplaced(Operation *op, ValueRange values) override {
+ ForwardingListener::notifyOperationReplaced(op, values);
+ updateReplacementValues(op->getResults(), values);
+ }
+
+private:
+ DenseMap<Value, Value> &replacements;
+};
+
} // namespace
/// Implementation of tile consumer and fuse producer greedily.
@@ -1493,26 +1535,27 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
for (auto *tiledOp : tilingResult->tiledOps)
tiledAndFusedOps.insert(tiledOp);
+ DenseMap<Value, Value> replacements;
+ for (auto [origVal, replacement] : llvm::zip_equal(
+ consumer->getResults(), tilingResult->mergeResult.replacements)) {
+ replacements[origVal] = replacement;
+ }
+
// If there are no loops generated, fusion is immaterial.
auto &loops = tilingResult->loops;
if (loops.empty()) {
- DenseMap<Value, Value> replacements;
- for (auto [origVal, replacement] : llvm::zip_equal(
- consumer->getResults(), tilingResult->mergeResult.replacements)) {
- replacements[origVal] = replacement;
- }
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
replacements};
}
- // To keep track of replacements for now just record the map from the
- // original untiled value to the result number of the for loop. Since the
- // loop gets potentially replaced during fusion, keeping the value directly
- // wont work.
- DenseMap<Value, size_t> origValToResultNumber;
- for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
- origValToResultNumber[result] = index;
- }
+ // Since the loop may be replaced during fusion, we need to track changes
+ // to the replacement values. To do this, we attach a listener that updates
+ // the replacements as they happen.
+ OpBuilder::Listener *previousListener = rewriter.getListener();
+ auto resetListener =
+ llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
+ ReplacementListener replaceListener(replacements, previousListener);
+ rewriter.setListener(&replaceListener);
// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
@@ -1581,9 +1624,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
worklistCandidates.append(newSlices.value());
for (auto [index, result] :
llvm::enumerate(fusableProducerOp->getResults())) {
- origValToResultNumber[result] = loops.front()->getNumResults() -
- fusableProducerOp->getNumResults() +
- index;
+ replacements[result] = loops.front()->getResult(
+ loops.front()->getNumResults() -
+ fusableProducerOp->getNumResults() + index);
}
}
if (Operation *tiledAndFusedOp =
@@ -1597,11 +1640,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
}
}
- DenseMap<Value, Value> replacements;
- for (auto [origVal, resultNumber] : origValToResultNumber) {
- replacements[origVal] = loops.front()->getResult(resultNumber);
- }
-
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
replacements};
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/120999
More information about the Mlir-commits
mailing list