[Mlir-commits] [mlir] 6e3631d - [mlir][scf] Track replacements using a listener in TileAndFuse (#120999)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 24 10:01:44 PST 2024


Author: Kunwar Grover
Date: 2024-12-24T18:01:41Z
New Revision: 6e3631d0e3316394ff4eae2913013d323e685790

URL: https://github.com/llvm/llvm-project/commit/6e3631d0e3316394ff4eae2913013d323e685790
DIFF: https://github.com/llvm/llvm-project/commit/6e3631d0e3316394ff4eae2913013d323e685790.diff

LOG: [mlir][scf] Track replacements using a listener in TileAndFuse (#120999)

This PR makes TileAndFuse explicitly track replacements using a listener
instead of assuming that the results always come from the outermost
tiling loop. scf::tileUsingInterface can introduce merge operations
whose results are the actual replacements to use, instead of the
outermost loop results.

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 90db42d479a193..2277989bf8411b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -28,6 +28,7 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -1467,6 +1468,47 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
                                                     ValueRange replacement) {
   removeOp(op);
 }
+
+//===----------------------------------------------------------------------===//
+// ReplacementListener
+//===----------------------------------------------------------------------===//
+
+/// Listener that updates a map of tracked replacement values when ops are
+/// replaced. It forwards every notification to the listener it wraps, so a
+/// listener already attached to the rewriter keeps receiving events.
+class ReplacementListener : public RewriterBase::ForwardingListener {
+public:
+  ReplacementListener(DenseMap<Value, Value> &replacements,
+                      OpBuilder::Listener *listener)
+      : ForwardingListener(listener), replacements(replacements) {}
+
+  void updateReplacementValues(ValueRange origValues,
+                               ValueRange replaceValues) {
+    // For every map entry whose current value is one of `origValues`, switch
+    // it to the matching `replaceValues` entry (simple nested scan for now).
+    for (auto &[key, val] : replacements) {
+      for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
+        if (val == orig) {
+          val = replace;
+        }
+      }
+    }
+  }
+
+  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
+    ForwardingListener::notifyOperationReplaced(op, newOp);
+    updateReplacementValues(op->getResults(), newOp->getResults());
+  }
+
+  void notifyOperationReplaced(Operation *op, ValueRange values) override {
+    ForwardingListener::notifyOperationReplaced(op, values);
+    updateReplacementValues(op->getResults(), values);
+  }
+
+private:
+  DenseMap<Value, Value> &replacements;
+};
+
 } // namespace
 
 /// Implementation of tile consumer and fuse producer greedily.
@@ -1493,26 +1535,27 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   for (auto *tiledOp : tilingResult->tiledOps)
     tiledAndFusedOps.insert(tiledOp);
 
+  DenseMap<Value, Value> replacements;
+  for (auto [origVal, replacement] : llvm::zip_equal(
+           consumer->getResults(), tilingResult->mergeResult.replacements)) {
+    replacements[origVal] = replacement;
+  }
+
   // If there are no loops generated, fusion is immaterial.
   auto &loops = tilingResult->loops;
   if (loops.empty()) {
-    DenseMap<Value, Value> replacements;
-    for (auto [origVal, replacement] : llvm::zip_equal(
-             consumer->getResults(), tilingResult->mergeResult.replacements)) {
-      replacements[origVal] = replacement;
-    }
     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                      replacements};
   }
 
-  // To keep track of replacements for now just record the map from the
-  // original untiled value to the result number of the for loop. Since the
-  // loop gets potentially replaced during fusion, keeping the value directly
-  // wont work.
-  DenseMap<Value, size_t> origValToResultNumber;
-  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
-    origValToResultNumber[result] = index;
-  }
+  // Since the loop gets potentially replaced during fusion, we need to track
+  // the mutation of replacement values. To do this, we attach a listener to
+  // update the replacements as they happen.
+  OpBuilder::Listener *previousListener = rewriter.getListener();
+  auto resetListener =
+      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
+  ReplacementListener replaceListener(replacements, previousListener);
+  rewriter.setListener(&replaceListener);
 
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
@@ -1581,9 +1624,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
       worklistCandidates.append(newSlices.value());
       for (auto [index, result] :
            llvm::enumerate(fusableProducerOp->getResults())) {
-        origValToResultNumber[result] = loops.front()->getNumResults() -
-                                        fusableProducerOp->getNumResults() +
-                                        index;
+        replacements[result] = loops.front()->getResult(
+            loops.front()->getNumResults() -
+            fusableProducerOp->getNumResults() + index);
       }
     }
     if (Operation *tiledAndFusedOp =
@@ -1597,11 +1640,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
     }
   }
 
-  DenseMap<Value, Value> replacements;
-  for (auto [origVal, resultNumber] : origValToResultNumber) {
-    replacements[origVal] = loops.front()->getResult(resultNumber);
-  }
-
   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                    replacements};
 }


        


More information about the Mlir-commits mailing list