[Mlir-commits] [mlir] [mlir] Add option for a cleanup pattern set to SCF tiling helper (PR #109554)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 25 09:37:24 PDT 2024


================
@@ -1315,6 +1317,172 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
   return generatedSlices;
 }
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// SliceWorklist
+//===----------------------------------------------------------------------===//
+
+/// Per-operation bookkeeping for SliceWorklist.
+///
+/// `count` is the total number of copies of the operation currently sitting in
+/// the queue: all stale copies plus, when `isValid` is set, the one live copy.
+/// `isValid` is set by a push and cleared by a remove, so it tells whether the
+/// most recently pushed copy is still live.
+struct EntryCount {
+  bool isValid = true;
+  unsigned count = 0;
+};
+
+/// A FIFO worklist of operations with efficient removal and set semantics.
+///
+/// Operations are kept in a queue alongside a map from each queued operation
+/// to an `EntryCount`. Removing an operation does not touch the queue: it only
+/// marks the map entry as invalid. Stale queue copies are discarded lazily when
+/// they reach the front during `pop`. Counting the queued copies per operation
+/// means an operation can be removed and later pushed again without its stale
+/// copy being mistaken for the new, valid one.
+///
+/// This is similar to the worklist used by the GreedyPatternRewriteDriver,
+/// except that it is FIFO so that slices for fusion can be processed breadth
+/// first.
+class SliceWorklist {
+public:
+  SliceWorklist() = default;
+
+  /// Push an operation to the back of the worklist. The operation must not
+  /// already have a valid (not removed) entry on the worklist.
+  void push(Operation *op);
+
+  /// Pop the next valid operation from the front of the worklist, discarding
+  /// any stale entries encountered on the way. Returns nullptr if there are no
+  /// remaining valid operations.
+  Operation *pop();
+
+  /// Remove an operation from the worklist. Does nothing if the operation is
+  /// not on the worklist.
+  void remove(Operation *op);
+
+protected:
+  /// The queue of operations; may contain stale (removed) copies.
+  std::deque<Operation *> list;
+
+  /// A mapping of each queued operation to the number of its copies in the
+  /// queue and whether one of those copies is still valid.
+  DenseMap<Operation *, EntryCount> map;
+};
+
+void SliceWorklist::push(Operation *op) {
+  assert(op && "cannot push nullptr to worklist");
+  list.push_back(op);
+  // One hash lookup serves both the duplicate check and the update. A freshly
+  // inserted entry is default-constructed with count 0. An existing entry
+  // always has count >= 1, because every push increments it and pop erases
+  // the entry instead of letting its count reach 0.
+  EntryCount &entry = map[op];
+  // Because operations are only pushed on creation, valid duplicates are
+  // never added. Re-pushing after a `remove` is allowed: the earlier copy
+  // stays in the queue as a stale entry and is accounted for in `count`.
+  assert((entry.count == 0 || !entry.isValid) &&
+         "cannot push a duplicate operation");
+  entry.isValid = true;
+  ++entry.count;
+}
+
+Operation *SliceWorklist::pop() {
+  // Pop the front of the queue until we hit a valid entry.
+  while (!list.empty()) {
+    Operation *op = list.front();
+    list.pop_front();
+
+    auto it = map.find(op);
+    // A queued op without a map entry behaves like a single valid copy.
+    if (it == map.end())
+      return op;
+
+    EntryCount &entry = it->second;
+    // If more than one copy is queued, or none of them is valid, this copy is
+    // stale: drop one from the count (erasing the entry once none remain) and
+    // keep looking. The comparison stays in unsigned arithmetic so a count of
+    // 0 can never wrap around to a huge value.
+    if (entry.count > 1 || !entry.isValid) {
+      if (entry.count <= 1)
+        map.erase(it);
+      else
+        --entry.count;
+      continue;
+    }
+
+    // Last remaining copy and it is valid: stop tracking it and hand it out.
+    map.erase(it);
+    return op;
+  }
+  return nullptr;
+}
+
+// Invalidate the operation's entry if it is tracked. Its queued copies are
+// left in place; `pop` discards them (and the map entry) when it reaches them.
+void SliceWorklist::remove(Operation *op) {
+  auto it = map.find(op);
+  if (it != map.end())
+    it->second.isValid = false;
+}
+
+//===----------------------------------------------------------------------===//
+// SliceTrackingListener
+//===----------------------------------------------------------------------===//
+
+/// This class is a listener for tracking the insertion and removal of
+/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
+/// fusion algorithm to apply cleanup patterns in between fusion steps.
+class SliceTrackingListener : public RewriterBase::Listener {
+public:
+  SliceTrackingListener() = default;
+  explicit SliceTrackingListener(
+      std::optional<FrozenRewritePatternSet> patterns);
+
+  /// Queues every `tensor.extract_slice` found in `newOps` and, when a pattern
+  /// set was supplied at construction, runs those patterns over the given
+  /// operations and over any operations the patterns themselves create.
+  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
+
+  /// Listener hook: queue `op` on the worklist if it is an extract_slice.
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
+
+  /// Listener hook: drop the erased `op` from the worklist.
+  void notifyOperationErased(Operation *op) override;
+
+  /// Listener hook: drop the replaced `op` from the worklist.
+  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
+
+  /// Operations that still have to be (re)visited by the transformation.
+  SliceWorklist worklist;
+
+private:
+  /// Cleanup patterns applied by `insertAndApplyPatterns`; empty means none.
+  std::optional<FrozenRewritePatternSet> patterns;
+};
+
+// Move the optional cleanup pattern set straight into the member via the
+// initializer list, rather than default-constructing `patterns` and then
+// move-assigning over it in the constructor body.
+SliceTrackingListener::SliceTrackingListener(
+    std::optional<FrozenRewritePatternSet> p)
+    : patterns(std::move(p)) {}
+
+LogicalResult
+SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
+  for (Operation *op : ops) {
+    if (isa<tensor::ExtractSliceOp>(op))
+      worklist.push(op);
+  }
+
+  if (!patterns)
+    return success();
+
+  GreedyRewriteConfig config;
+  config.listener = this;
+  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
----------------
MaheshRavishankar wrote:

Interesting, `ExistingAndNewOps` seems like it should be effectively same as `AnyOp`.

https://github.com/llvm/llvm-project/pull/109554


More information about the Mlir-commits mailing list