[Mlir-commits] [mlir] [mlir] Add option to run CSE between greedy rewriter iterations (PR #193081)

Matthias Springer llvmlistbot at llvm.org
Tue Apr 21 03:09:42 PDT 2026


================
@@ -0,0 +1,439 @@
+//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common sub-expression elimination as a library utility.
+// The matching CSE pass is a thin wrapper over the APIs declared here.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CSE.h"
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RecyclingAllocator.h"
+#include <deque>
+
+using namespace mlir;
+
+namespace {
+/// DenseMapInfo for `Operation *` keys that treats two operations as equal
+/// when they are structurally equivalent (ignoring locations). This lets the
+/// scoped hash table used by CSE find a previously seen, equivalent operation.
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
+  /// Hash an operation by its structure and the identity of its operands.
+  /// Results are not hashed (`ignoreHashValue`) and locations are ignored, so
+  /// that equivalent operations in different places hash identically.
+  static unsigned getHashValue(const Operation *opC) {
+    // The OperationEquivalence API takes a non-const `Operation *`.
+    return OperationEquivalence::computeHash(
+        const_cast<Operation *>(opC),
+        /*hashOperands=*/OperationEquivalence::directHashValue,
+        /*hashResults=*/OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
+  }
+
+  /// Two keys are equal if they are the same pointer, or if neither is a
+  /// DenseMap sentinel (empty/tombstone key) and the operations are
+  /// equivalent modulo locations.
+  static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
+    auto *lhs = const_cast<Operation *>(lhsC);
+    auto *rhs = const_cast<Operation *>(rhsC);
+    if (lhs == rhs)
+      return true;
+    // Sentinel keys do not point at real operations, so they must never be
+    // handed to OperationEquivalence for a structural comparison.
+    if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
+        rhs == getTombstoneKey() || rhs == getEmptyKey())
+      return false;
+    return OperationEquivalence::isEquivalentTo(
+        lhs, rhs, OperationEquivalence::IgnoreLocations);
+  }
+};
+} // namespace
+
+namespace {
+/// Driver for simple common sub-expression elimination.
+///
+/// Redundant operations are rewired to an existing equivalent operation via
+/// `rewriter` and queued in `opsToErase`; queued operations are erased later
+/// by `eraseDeadOps`.
+class CSEDriver {
+public:
+  /// Create a driver that performs all IR modifications through `rewriter`
+  /// and uses the given dominance information.
+  CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
+      : rewriter(rewriter), domInfo(domInfo) {}
+
+  /// Simplify all operations within the given op.
+  /// NOTE(review): `changed` (optional) presumably reports whether the IR was
+  /// modified — confirm against the out-of-line definition.
+  void simplify(Operation *op, bool *changed = nullptr);
+
+  /// Simplify operations within the given region.
+  /// NOTE(review): `changed` has the same (presumed) meaning as above.
+  void simplify(Region &region, bool *changed = nullptr);
+
+  /// Statistics accessors: return the `numCSE` / `numDCE` counters,
+  /// presumably the number of ops removed by CSE and by dead-code elimination.
+  int64_t getNumCSE() const { return numCSE; }
+  int64_t getNumDCE() const { return numDCE; }
+
+private:
+  /// Scoped hash table mapping an operation to a previously seen equivalent
+  /// operation (keys compared with SimpleOperationInfo), with its entries
+  /// allocated from a recycling bump-pointer allocator.
+  using AllocatorTy = llvm::RecyclingAllocator<
+      llvm::BumpPtrAllocator,
+      llvm::ScopedHashTableVal<Operation *, Operation *>>;
+  using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
+                                            SimpleOperationInfo, AllocatorTy>;
+
+  /// Cache holding MemoryEffects information between two operations. The first
+  /// operation is stored as the key. The second operation is stored inside a
+  /// pair in the value. The pair also holds the MemoryEffects::Effect between
+  /// those two operations. If the effect is nullptr then we assume there is
+  /// no operation with MemoryEffects::Write between the two operations.
+  using MemEffectsCache =
+      DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
+
+  /// Represents a single entry in the depth-first traversal of a dominator
+  /// tree (DominanceInfoNode) over a region's CFG.
+  struct CFGStackNode {
+    /// Opens a new scope in `knownValues`; entries inserted while this scope
+    /// is live are dropped when the node (and its `scope`) is destroyed.
+    CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
+        : scope(knownValues), node(node), childIterator(node->begin()) {}
+
+    /// Scope for the known values.
+    ScopedMapTy::ScopeTy scope;
+
+    /// The dominator-tree node this entry represents.
+    DominanceInfoNode *node;
+    /// Next child of `node` to visit (starts at the first child).
+    DominanceInfoNode::const_iterator childIterator;
+
+    /// Whether this node has been fully processed yet or not.
+    bool processed = false;
+  };
+
+  /// Attempt to eliminate a redundant operation. Returns success if the
+  /// operation was marked for removal, failure otherwise.
+  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+                                  bool hasSSADominance);
+  void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
+  void simplifyRegion(ScopedMapTy &knownValues, Region &region);
+
+  /// Erase all operations queued for deletion by the simplification routines.
+  void eraseDeadOps(bool *changed);
+
+  /// Replace uses of `op` with the results of the equivalent `existing` op and
+  /// queue `op` for deletion. Without SSA dominance, only uses that have not
+  /// been visited yet may be replaced.
+  void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                            Operation *existing, bool hasSSADominance);
+
+  /// Check whether there are side-effecting operations between `fromOp` and
+  /// `toOp`. NOTE(review): an earlier comment mentioned "other than the given
+  /// effect", but no effect is passed in — confirm the exact semantics in the
+  /// definition (results are presumably memoized in `memEffectsCache`).
+  bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
+
+  /// A rewriter for modifying the IR.
+  RewriterBase &rewriter;
+
+  /// Operations marked as dead and to be erased.
+  std::vector<Operation *> opsToErase;
+  /// Dominance information passed to the constructor.
+  DominanceInfo *domInfo = nullptr;
+  /// Memoized memory-effect queries; see MemEffectsCache.
+  MemEffectsCache memEffectsCache;
+
+  // Various statistics.
+  int64_t numCSE = 0;
+  int64_t numDCE = 0;
+};
+} // namespace
+
+void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                                     Operation *existing,
+                                     bool hasSSADominance) {
+  // If we find one then replace all uses of the current operation with the
+  // existing one and mark it for deletion. We can only replace an operand in
+  // an operation if it has not been visited yet.
+  if (hasSSADominance) {
+    // If the region has SSA dominance, then we are guaranteed to have not
+    // visited any use of the current operation.
+    if (auto *rewriteListener =
+            dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
+      rewriteListener->notifyOperationReplaced(op, existing);
+    // Replace all uses, but do not remove the operation yet. This does not
+    // notify the listener because the original op is not erased.
+    rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
+    opsToErase.push_back(op);
+  } else {
+    // When the region does not have SSA dominance, we need to check if we
+    // have visited a use before replacing any use.
+    auto wasVisited = [&](OpOperand &operand) {
+      return !knownValues.count(operand.getOwner());
+    };
+    if (auto *rewriteListener =
+            dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
+      for (Value v : op->getResults())
+        if (all_of(v.getUses(), wasVisited))
+          rewriteListener->notifyOperationReplaced(op, existing);
----------------
matthias-springer wrote:

`rewriter.replaceUsesWithIf` already does that.


https://github.com/llvm/llvm-project/pull/193081


More information about the Mlir-commits mailing list