[Mlir-commits] [mlir] [mlir] Add option for a cleanup pattern set to SCF tiling helper (PR #109554)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 25 09:37:24 PDT 2024
================
@@ -1315,6 +1317,172 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
return generatedSlices;
}
+namespace {
+
+//===----------------------------------------------------------------------===//
+// SliceWorklist
+//===----------------------------------------------------------------------===//
+
+/// Struct for tracking the number of stale entries on the worklist and whether
+/// there is a remaining valid entry.
+/// Bookkeeping for one operation pointer on the worklist: `count` is the total
+/// number of copies of the pointer currently in the queue, and `isValid` says
+/// whether the most recently pushed copy is still live. Every other copy is
+/// stale and is discarded when popped.
+struct EntryCount {
+  bool isValid = true;
+  unsigned count = 0;
+};
+
+/// A FIFO worklist of operations with efficient removal and set semantics.
+///
+/// This class maintains a queue of operations and a mapping of operations to
+/// positions in the vector, so that operations can be removed efficiently at
+/// random. When an operation is removed, it is replaced with nullptr. Such
+/// nullptr are skipped when pop'ing elements.
+///
+/// This is similar to the worklist used by the GreedyPatternRewriteDriver,
+/// except instead FIFO so that slices for fusion can be processed breadth
+/// first.
+/// A FIFO worklist of operations with efficient removal and set semantics.
+///
+/// Operations are kept in a queue together with a map from each operation to
+/// an `EntryCount`. Removing an operation does not touch the queue: it only
+/// marks the operation's map entry invalid, so its queued copies become stale
+/// and are discarded lazily by `pop`. The same pointer may be pushed again
+/// after it was removed, so the map also counts how many copies of each
+/// pointer are queued; only the most recently pushed copy can be valid.
+///
+/// This is similar to the worklist used by the GreedyPatternRewriteDriver,
+/// except that it is FIFO so that slices for fusion can be processed
+/// breadth-first.
+class SliceWorklist {
+public:
+  SliceWorklist() = default;
+
+  /// Push an operation to the back of the worklist. This assumes that the
+  /// given operation does not already have a valid entry on the worklist.
+  void push(Operation *op);
+
+  /// Pop the first valid operation from the front of the worklist, discarding
+  /// any stale entries encountered on the way. Returns nullptr if there are no
+  /// remaining valid operations.
+  Operation *pop();
+
+  /// Remove an operation from the worklist. Removal is lazy: the operation's
+  /// queued copies are marked stale and later skipped by `pop`.
+  void remove(Operation *op);
+
+protected:
+  /// The queue of operations, possibly including stale entries.
+  std::deque<Operation *> list;
+
+  /// A mapping of each queued operation to the number of its copies in the
+  /// queue and whether the most recent copy is still valid.
+  DenseMap<Operation *, EntryCount> map;
+};
+
+/// Appends `op` to the back of the queue and records one more queued copy of
+/// it, marking that newest copy as the valid one.
+void SliceWorklist::push(Operation *op) {
+  assert(op && "cannot push nullptr to worklist");
+  list.push_back(op);
+  // Operations are only pushed on creation, so a pointer can reappear only
+  // after its earlier entry was invalidated; valid duplicates are never added.
+  // An absent key default-inserts with count 0, which satisfies the check.
+  EntryCount &entry = map[op];
+  assert((entry.count == 0 || !entry.isValid) &&
+         "cannot push a duplicate operation");
+  entry.isValid = true;
+  ++entry.count;
+}
+
+/// Removes entries from the front of the queue until a valid one is found and
+/// returns it, or returns nullptr once the queue is exhausted. Stale copies
+/// are dropped along the way, and their map bookkeeping is updated.
+Operation *SliceWorklist::pop() {
+  while (!list.empty()) {
+    Operation *op = list.front();
+    list.pop_front();
+
+    EntryCount e = map.lookup(op);
+    // The valid copy of `op`, if any, is always the most recently pushed one.
+    // So this copy is stale if newer copies follow it (count > 1) or if `op`
+    // was removed (!isValid). Drop one copy and keep going, and forget `op`
+    // entirely once its last queued copy is consumed. Compare in unsigned
+    // arithmetic directly: `count - 1` must not be formed when count is 0.
+    if (e.count > 1 || !e.isValid) {
+      if (e.count <= 1)
+        map.erase(op);
+      else
+        map[op] = {e.isValid, e.count - 1};
+      continue;
+    }
+
+    // This is the single remaining, valid copy of `op`.
+    map.erase(op);
+    return op;
+  }
+  return nullptr;
+}
+
+// Invalidates `op` if it is tracked. Its queued copies stay in the queue and
+// are discarded (and the map entry erased) when they are later popped.
+void SliceWorklist::remove(Operation *op) {
+  auto it = map.find(op);
+  if (it == map.end())
+    return;
+  it->second.isValid = false;
+}
+
+//===----------------------------------------------------------------------===//
+// SliceTrackingListener
+//===----------------------------------------------------------------------===//
+
+/// This class is a listener for tracking the insertion and removal of
+/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
+/// fusion algorithm to apply cleanup patterns in between fusion steps.
+/// This class is a listener for tracking the insertion and removal of
+/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
+/// fusion algorithm to apply cleanup patterns in between fusion steps.
+class SliceTrackingListener : public RewriterBase::Listener {
+public:
+  /// Creates a listener that, when `patterns` is set, applies them in
+  /// `insertAndApplyPatterns`.
+  explicit SliceTrackingListener(
+      std::optional<FrozenRewritePatternSet> patterns);
+  SliceTrackingListener() = default;
+
+  /// Pushes the `tensor.extract_slice` ops among `newOps` onto the worklist
+  /// (other ops are not added) and, if present, applies the list of `patterns`
+  /// to `newOps`. This only processes the given operations and any ones newly
+  /// inserted by the pattern set.
+  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
+
+  /// Add to the new operation worklist if it is an extract_slice.
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
+
+  /// Remove the operation from the worklist.
+  void notifyOperationErased(Operation *op) override;
+
+  /// Remove the operation from the worklist.
+  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
+
+  /// The worklist for this transformation keeps track of the operations that
+  /// need to be (re)visited.
+  SliceWorklist worklist;
+
+private:
+  /// Optional pattern set to apply when adding new operations to the worklist.
+  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
+};
+
+/// Takes ownership of the optional cleanup pattern set. Initializing the member
+/// directly avoids default-constructing `patterns` and then move-assigning it.
+SliceTrackingListener::SliceTrackingListener(
+    std::optional<FrozenRewritePatternSet> p)
+    : patterns(std::move(p)) {}
+
+LogicalResult
+SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
+ for (Operation *op : ops) {
+ if (isa<tensor::ExtractSliceOp>(op))
+ worklist.push(op);
+ }
+
+ if (!patterns)
+ return success();
+
+ GreedyRewriteConfig config;
+ config.listener = this;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
----------------
MaheshRavishankar wrote:
Interesting: `ExistingAndNewOps` seems like it should be effectively the same as `AnyOp`.
https://github.com/llvm/llvm-project/pull/109554
More information about the Mlir-commits
mailing list