[Mlir-commits] [mlir] [mlir] Add option to run CSE between greedy rewriter iterations (PR #193081)
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 21 03:09:42 PDT 2026
================
@@ -0,0 +1,439 @@
+//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common sub-expression elimination as a library utility.
+// The matching CSE pass is a thin wrapper over the APIs declared here.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CSE.h"
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RecyclingAllocator.h"
+#include <deque>
+
+using namespace mlir;
+
+namespace {
+/// DenseMapInfo traits that let a hash table treat two distinct operations as
+/// the same key when OperationEquivalence considers them equivalent (with
+/// locations ignored), instead of comparing only pointer identity.
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
+  /// Hashes `opC` via OperationEquivalence::computeHash: operands are hashed
+  /// with `directHashValue`, results with `ignoreHashValue`, and locations are
+  /// not part of the hash.
+  static unsigned getHashValue(const Operation *opC) {
+    Operation *op = const_cast<Operation *>(opC);
+    return OperationEquivalence::computeHash(
+        op, /*hashOperands=*/OperationEquivalence::directHashValue,
+        /*hashResults=*/OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
+  }
+
+  /// Identical pointers always compare equal. The empty and tombstone
+  /// sentinels never compare equal to any other key, so they are filtered out
+  /// before OperationEquivalence::isEquivalentTo is consulted.
+  static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
+    Operation *first = const_cast<Operation *>(lhsC);
+    Operation *second = const_cast<Operation *>(rhsC);
+    if (first == second)
+      return true;
+
+    auto isSentinel = [](const Operation *candidate) {
+      return candidate == getEmptyKey() || candidate == getTombstoneKey();
+    };
+    if (isSentinel(first) || isSentinel(second))
+      return false;
+
+    return OperationEquivalence::isEquivalentTo(
+        first, second, OperationEquivalence::IgnoreLocations);
+  }
+};
+} // namespace
+
+namespace {
+/// Driver for simple common sub-expression elimination. Operations seen so far
+/// are recorded in a scoped hash table keyed with SimpleOperationInfo; when a
+/// redundant operation is found, its uses are redirected to the existing
+/// equivalent operation and it is queued in `opsToErase`.
+class CSEDriver {
+public:
+  CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
+      : rewriter(rewriter), domInfo(domInfo) {}
+
+  /// Simplify all operations within the given op. When `changed` is non-null
+  /// it presumably receives whether the IR was modified (TODO confirm against
+  /// the out-of-line definition).
+  void simplify(Operation *op, bool *changed = nullptr);
+
+  /// Simplify operations within the given region. `changed` behaves as in the
+  /// Operation overload.
+  void simplify(Region &region, bool *changed = nullptr);
+
+  /// Statistic counters; presumably the number of operations removed by CSE
+  /// and by dead-code elimination, respectively.
+  int64_t getNumCSE() const { return numCSE; }
+  int64_t getNumDCE() const { return numDCE; }
+
+private:
+  /// Allocator and scoped hash table holding the operations seen so far.
+  /// Keys are hashed and compared with SimpleOperationInfo, so looking up a
+  /// new operation finds an equivalent one recorded in a still-open scope.
+  using AllocatorTy = llvm::RecyclingAllocator<
+      llvm::BumpPtrAllocator,
+      llvm::ScopedHashTableVal<Operation *, Operation *>>;
+  using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
+                                            SimpleOperationInfo, AllocatorTy>;
+
+  /// Cache of MemoryEffects queries between two operations. The first
+  /// operation is stored as the key. The value pairs the second operation with
+  /// the MemoryEffects found between the two. A null effect means there is no
+  /// operation with MemoryEffects::Write between the two operations.
+  using MemEffectsCache =
+      DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
+
+  /// A single entry in the depth-first traversal of the dominance tree.
+  struct CFGStackNode {
+    CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
+        : scope(knownValues), node(node), childIterator(node->begin()) {}
+
+    /// Scope for the known values; entries inserted while this scope is the
+    /// innermost one are dropped when the node is destroyed
+    /// (llvm::ScopedHashTableScope semantics).
+    ScopedMapTy::ScopeTy scope;
+
+    /// The dominance-tree node being visited and the next child to visit.
+    DominanceInfoNode *node;
+    DominanceInfoNode::const_iterator childIterator;
+
+    /// Whether this node has been fully processed.
+    bool processed = false;
+  };
+
+  /// Attempt to eliminate a redundant operation. Returns success if the
+  /// operation was marked for removal, failure otherwise.
+  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+                                  bool hasSSADominance);
+  /// Simplify the operations of a single block, using `knownValues` as the
+  /// table of previously seen operations.
+  void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
+  /// Simplify the blocks of a region, using `knownValues` as the table of
+  /// previously seen operations.
+  void simplifyRegion(ScopedMapTy &knownValues, Region &region);
+
+  /// Erase all operations queued for deletion by the simplification routines.
+  void eraseDeadOps(bool *changed);
+
+  /// Redirect uses of `op` to the results of the equivalent `existing` op and
+  /// queue `op` for erasure. With SSA dominance, all uses are replaced;
+  /// otherwise, uses that were already visited must not be rewritten.
+  void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                            Operation *existing, bool hasSSADominance);
+
+  /// Check whether there is a side-effecting operation between `fromOp` and
+  /// `toOp` (NOTE(review): the exact exclusion rule for the two endpoints is
+  /// defined out of line; results are presumably memoized in
+  /// `memEffectsCache`).
+  bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
+
+  /// A rewriter for modifying the IR.
+  RewriterBase &rewriter;
+
+  /// Operations marked as dead and to be erased.
+  std::vector<Operation *> opsToErase;
+  /// Dominance information supplied by the caller. The driver declares no
+  /// destructor, so it never releases this object.
+  DominanceInfo *domInfo = nullptr;
+  MemEffectsCache memEffectsCache;
+
+  // Various statistics.
+  int64_t numCSE = 0;
+  int64_t numDCE = 0;
+};
+} // namespace
+
+void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+ Operation *existing,
+ bool hasSSADominance) {
+ // If we find one then replace all uses of the current operation with the
+ // existing one and mark it for deletion. We can only replace an operand in
+ // an operation if it has not been visited yet.
+ if (hasSSADominance) {
+ // If the region has SSA dominance, then we are guaranteed to have not
+ // visited any use of the current operation.
+ if (auto *rewriteListener =
+ dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
+ rewriteListener->notifyOperationReplaced(op, existing);
+ // Replace all uses, but do not remove the operation yet. This does not
+ // notify the listener because the original op is not erased.
+ rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
+ opsToErase.push_back(op);
+ } else {
+ // When the region does not have SSA dominance, we need to check if we
+ // have visited a use before replacing any use.
+ auto wasVisited = [&](OpOperand &operand) {
+ return !knownValues.count(operand.getOwner());
+ };
+ if (auto *rewriteListener =
+ dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
+ for (Value v : op->getResults())
+ if (all_of(v.getUses(), wasVisited))
+ rewriteListener->notifyOperationReplaced(op, existing);
----------------
matthias-springer wrote:
`rewriter.replaceUsesWithIf` already does that.
https://github.com/llvm/llvm-project/pull/193081
More information about the Mlir-commits
mailing list