[Mlir-commits] [mlir] [mlir] Add option to run CSE between greedy rewriter iterations (PR #193081)
Mehdi Amini
llvmlistbot at llvm.org
Tue Apr 21 01:53:03 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/193081
>From 6c284c022e6e0141ae7e99b5ad7bb881169dceb3 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 20 Apr 2026 13:50:44 -0700
Subject: [PATCH] [mlir] Add option to run CSE between greedy rewriter
iterations
The greedy pattern rewrite driver previously only deduplicated constant ops
between iterations (via the operation folder). Structurally identical
non-constant subexpressions remained distinct SSA values, blocking fold
patterns that only fire when operands match. Reaching the true fixpoint
required chaining an external `cse,canonicalize,...` pipeline.
Add an opt-in `cseBetweenIterations` flag on `GreedyRewriteConfig` that runs
full CSE on the scoped region after each pattern-application iteration, and
surface it as a `cse-between-iterations` option on the canonicalizer pass.
Off by default to preserve existing performance characteristics.
To let the greedy driver (in MLIRTransformUtils) call into CSE without
creating a layering cycle with MLIRTransforms, the CSE driver implementation
moves to `Utils/CSE.cpp`; the CSE pass in `Transforms/CSE.cpp` becomes a thin
wrapper over the public API. A region-scoped overload of
`eliminateCommonSubExpressions` is added for use by the driver.
Assisted-by: Claude Code
---
mlir/include/mlir/Transforms/CSE.h | 18 +-
.../Transforms/GreedyPatternRewriteDriver.h | 24 +
mlir/include/mlir/Transforms/Passes.td | 8 +-
.../lib/Dialect/Transform/IR/TransformOps.cpp | 48 +-
mlir/lib/Transforms/CSE.cpp | 415 +----------------
mlir/lib/Transforms/Canonicalizer.cpp | 2 +
mlir/lib/Transforms/Utils/CMakeLists.txt | 1 +
mlir/lib/Transforms/Utils/CSE.cpp | 439 ++++++++++++++++++
.../Utils/GreedyPatternRewriteDriver.cpp | 12 +
mlir/test/Pass/run-reproducer.mlir | 2 +-
.../canonicalize-cse-between-iterations.mlir | 81 ++++
mlir/test/Transforms/composite-pass.mlir | 2 +-
12 files changed, 620 insertions(+), 432 deletions(-)
create mode 100644 mlir/lib/Transforms/Utils/CSE.cpp
create mode 100644 mlir/test/Transforms/canonicalize-cse-between-iterations.mlir
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
index 3d01ece078050..4a87d585e0eb9 100644
--- a/mlir/include/mlir/Transforms/CSE.h
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -13,18 +13,34 @@
#ifndef MLIR_TRANSFORMS_CSE_H_
#define MLIR_TRANSFORMS_CSE_H_
+#include <cstdint>
+
namespace mlir {
class DominanceInfo;
class Operation;
+class Region;
class RewriterBase;
/// Eliminate common subexpressions within the given operation. This transform
/// looks for and deduplicates equivalent operations.
///
-/// `changed` indicates whether the IR was modified or not.
+/// `changed` indicates whether the IR was modified or not. `numCSE` and
+/// `numDCE` receive counts of operations deduplicated and dead operations
+/// erased, respectively.
void eliminateCommonSubExpressions(RewriterBase &rewriter,
DominanceInfo &domInfo, Operation *op,
+ bool *changed = nullptr,
+ int64_t *numCSE = nullptr,
+ int64_t *numDCE = nullptr);
+
+/// Eliminate common subexpressions within the given region.
+///
+/// `changed` indicates whether the IR was modified or not. Statistics are not
+/// reported through this overload; use the `Operation *` overload when CSE /
+/// DCE counts are needed.
+void eliminateCommonSubExpressions(RewriterBase &rewriter,
+ DominanceInfo &domInfo, Region &region,
bool *changed = nullptr);
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index d56d7e58c35f9..ddeff8a6c552d 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -137,6 +137,29 @@ class GreedyRewriteConfig {
return *this;
}
+ /// If set to "true", full common-subexpression elimination is run on the
+ /// scoped region between each pattern-application iteration. Unlike
+ /// `cseConstants` (which only deduplicates constant ops via the operation
+ /// folder) this runs the standard CSE algorithm and can unblock further
+ /// canonicalizations on the next iteration. Off by default because it
+ /// rebuilds dominance info each iteration.
+ ///
+ /// Caveats when enabling this option:
+ /// - Any listener attached via `setListener` will be notified of
+ /// `notifyOperationReplaced` / `notifyOperationErased` events generated
+ /// by CSE. Pattern authors relying on operation identity (e.g., the
+ /// transform dialect's handle tracking) must account for this.
+ /// - CSE-driven changes feed back into the iteration loop: a pattern that
+ /// re-materializes duplicates that CSE keeps collapsing can extend the
+ /// iteration count and, in the worst case, hit `maxIterations`. Under
+ /// `testConvergence=true` such pipelines will be reported as
+ /// non-convergent.
+ bool isCSEBetweenIterationsEnabled() const { return cseBetweenIterations; }
+ GreedyRewriteConfig &enableCSEBetweenIterations(bool enable = true) {
+ cseBetweenIterations = enable;
+ return *this;
+ }
+
private:
Region *scope = nullptr;
bool useTopDownTraversal = false;
@@ -148,6 +171,7 @@ class GreedyRewriteConfig {
RewriterBase::Listener *listener = nullptr;
bool fold = true;
bool cseConstants = true;
+ bool cseBetweenIterations = false;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1474e580cfc03..3822d1d2a4156 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -48,7 +48,13 @@ def CanonicalizerPass : Pass<"canonicalize"> {
Option<"maxNumRewrites", "max-num-rewrites", "int64_t", /*default=*/"-1",
"Max. number of pattern rewrites within an iteration">,
Option<"testConvergence", "test-convergence", "bool", /*default=*/"false",
- "Test only: Fail pass on non-convergence to detect cyclic pattern">
+ "Test only: Fail pass on non-convergence to detect cyclic pattern">,
+ Option<"cseBetweenIterations", "cse-between-iterations", "bool",
+ /*default=*/"false",
+ "Run full CSE between each pattern-application iteration. "
+ "CSE-driven changes trigger extra iterations, so this may push "
+ "the iteration count up to max-iterations and affect convergence "
+ "under test-convergence.">
] # RewritePassUtils.options;
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 03b28ad0acfa2..1a27e6cb5de58 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -414,37 +414,33 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
? GreedyRewriteConfig::kNoLimit
: getMaxNumRewrites());
- // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
- // was requested, apply the greedy pattern rewrite only once. (The greedy
- // pattern rewrite driver already iterates to a fixpoint internally.)
- bool cseChanged = false;
+ if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ // Op is isolated from above. The greedy driver iterates to a fixpoint
+ // internally and optionally runs full CSE between iterations.
+ config.enableCSEBetweenIterations(getApplyCse());
+ if (failed(applyPatternsGreedily(target, frozenPatterns, config))) {
+ return emitSilenceableFailure(target)
+ << "greedy pattern application failed";
+ }
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ // Non-isolated case: gather the ops manually because the op-list
+ // GreedyPatternRewriteDriver overload only performs a single iteration and
+ // does not simplify regions. CSE is driven externally to reach a fixpoint.
+ SmallVector<Operation *> ops;
+ target->walk([&](Operation *nestedOp) {
+ if (target != nestedOp)
+ ops.push_back(nestedOp);
+ });
+
// One or two iterations should be sufficient. Stop iterating after a certain
// threshold to make debugging easier.
static const int64_t kNumMaxIterations = 50;
int64_t iteration = 0;
+ bool cseChanged = false;
do {
- LogicalResult result = failure();
- if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
- // Op is isolated from above. Apply patterns and also perform region
- // simplification.
- result = applyPatternsGreedily(target, frozenPatterns, config);
- } else {
- // Manually gather list of ops because the other
- // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
- // from above. This way, patterns can be applied to ops that are not
- // isolated from above. Regions are not being simplified. Furthermore,
- // only a single greedy rewrite iteration is performed.
- SmallVector<Operation *> ops;
- target->walk([&](Operation *nestedOp) {
- if (target != nestedOp)
- ops.push_back(nestedOp);
- });
- result = applyOpPatternsGreedily(ops, frozenPatterns, config);
- }
-
- // A failure typically indicates that the pattern application did not
- // converge.
- if (failed(result)) {
+ if (failed(applyOpPatternsGreedily(ops, frozenPatterns, config))) {
return emitSilenceableFailure(target)
<< "greedy pattern application failed";
}
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index c426ac698b7ae..f7afa03e2f02b 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -6,8 +6,9 @@
//
//===----------------------------------------------------------------------===//
//
-// This transformation pass performs a simple common sub-expression elimination
-// algorithm on operations within a region.
+// This file implements the CSE pass. The actual CSE algorithm lives in
+// mlir/lib/Transforms/Utils/CSE.cpp so that it can be invoked from other
+// utilities (e.g. the greedy pattern rewrite driver).
//
//===----------------------------------------------------------------------===//
@@ -15,14 +16,7 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/ScopedHashTable.h"
-#include "llvm/Support/Allocator.h"
-#include "llvm/Support/RecyclingAllocator.h"
-#include <deque>
namespace mlir {
#define GEN_PASS_DEF_CSEPASS
@@ -31,392 +25,6 @@ namespace mlir {
using namespace mlir;
-namespace {
-struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
- static unsigned getHashValue(const Operation *opC) {
- return OperationEquivalence::computeHash(
- const_cast<Operation *>(opC),
- /*hashOperands=*/OperationEquivalence::directHashValue,
- /*hashResults=*/OperationEquivalence::ignoreHashValue,
- OperationEquivalence::IgnoreLocations);
- }
- static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
- auto *lhs = const_cast<Operation *>(lhsC);
- auto *rhs = const_cast<Operation *>(rhsC);
- if (lhs == rhs)
- return true;
- if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
- rhs == getTombstoneKey() || rhs == getEmptyKey())
- return false;
- return OperationEquivalence::isEquivalentTo(
- const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
- OperationEquivalence::IgnoreLocations);
- }
-};
-} // namespace
-
-namespace {
-/// Simple common sub-expression elimination.
-class CSEDriver {
-public:
- CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
- : rewriter(rewriter), domInfo(domInfo) {}
-
- /// Simplify all operations within the given op.
- void simplify(Operation *op, bool *changed = nullptr);
-
- int64_t getNumCSE() const { return numCSE; }
- int64_t getNumDCE() const { return numDCE; }
-
-private:
- /// Shared implementation of operation elimination and scoped map definitions.
- using AllocatorTy = llvm::RecyclingAllocator<
- llvm::BumpPtrAllocator,
- llvm::ScopedHashTableVal<Operation *, Operation *>>;
- using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
- SimpleOperationInfo, AllocatorTy>;
-
- /// Cache holding MemoryEffects information between two operations. The first
- /// operation is stored has the key. The second operation is stored inside a
- /// pair in the value. The pair also hold the MemoryEffects between those
- /// two operations. If the MemoryEffects is nullptr then we assume there is
- /// no operation with MemoryEffects::Write between the two operations.
- using MemEffectsCache =
- DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
-
- /// Represents a single entry in the depth first traversal of a CFG.
- struct CFGStackNode {
- CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
- : scope(knownValues), node(node), childIterator(node->begin()) {}
-
- /// Scope for the known values.
- ScopedMapTy::ScopeTy scope;
-
- DominanceInfoNode *node;
- DominanceInfoNode::const_iterator childIterator;
-
- /// If this node has been fully processed yet or not.
- bool processed = false;
- };
-
- /// Attempt to eliminate a redundant operation. Returns success if the
- /// operation was marked for removal, failure otherwise.
- LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
- bool hasSSADominance);
- void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
- void simplifyRegion(ScopedMapTy &knownValues, Region &region);
-
- void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
- Operation *existing, bool hasSSADominance);
-
- /// Check if there is side-effecting operations other than the given effect
- /// between the two operations.
- bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
-
- /// A rewriter for modifying the IR.
- RewriterBase &rewriter;
-
- /// Operations marked as dead and to be erased.
- std::vector<Operation *> opsToErase;
- DominanceInfo *domInfo = nullptr;
- MemEffectsCache memEffectsCache;
-
- // Various statistics.
- int64_t numCSE = 0;
- int64_t numDCE = 0;
-};
-} // namespace
-
-void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
- Operation *existing,
- bool hasSSADominance) {
- // If we find one then replace all uses of the current operation with the
- // existing one and mark it for deletion. We can only replace an operand in
- // an operation if it has not been visited yet.
- if (hasSSADominance) {
- // If the region has SSA dominance, then we are guaranteed to have not
- // visited any use of the current operation.
- if (auto *rewriteListener =
- dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
- rewriteListener->notifyOperationReplaced(op, existing);
- // Replace all uses, but do not remove the operation yet. This does not
- // notify the listener because the original op is not erased.
- rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
- opsToErase.push_back(op);
- } else {
- // When the region does not have SSA dominance, we need to check if we
- // have visited a use before replacing any use.
- auto wasVisited = [&](OpOperand &operand) {
- return !knownValues.count(operand.getOwner());
- };
- if (auto *rewriteListener =
- dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
- for (Value v : op->getResults())
- if (all_of(v.getUses(), wasVisited))
- rewriteListener->notifyOperationReplaced(op, existing);
-
- // Replace all uses, but do not remove the operation yet. This does not
- // notify the listener because the original op is not erased.
- rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
- wasVisited);
-
- // There may be some remaining uses of the operation.
- if (op->use_empty())
- opsToErase.push_back(op);
- }
-
- // If the existing operation has an unknown location and the current
- // operation doesn't, then set the existing op's location to that of the
- // current op.
- if (isa<UnknownLoc>(existing->getLoc()) && !isa<UnknownLoc>(op->getLoc()))
- existing->setLoc(op->getLoc());
-
- ++numCSE;
-}
-
-bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
- Operation *toOp) {
- assert(fromOp->getBlock() == toOp->getBlock());
- assert(hasEffect<MemoryEffects::Read>(fromOp) &&
- "expected read effect on fromOp");
- assert(hasEffect<MemoryEffects::Read>(toOp) &&
- "expected read effect on toOp");
-
- // Collect the read effects of fromOp. A write can only block CSE if it
- // can conflict with one of these reads.
- SmallVector<MemoryEffects::EffectInstance> readEffects;
- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
- SmallVector<MemoryEffects::EffectInstance> fromEffects;
- memOp.getEffects(fromEffects);
- for (MemoryEffects::EffectInstance &e : fromEffects)
- if (isa<MemoryEffects::Read>(e.getEffect()))
- readEffects.push_back(e);
- }
-
- Operation *nextOp = fromOp->getNextNode();
- auto result =
- memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
- if (!result.second) {
- auto memEffectsCachePair = result.first->second;
- if (memEffectsCachePair.second == nullptr) {
- // No MemoryEffects::Write has been detected until the cached operation.
- // Continue looking from the cached operation to toOp.
- nextOp = memEffectsCachePair.first;
- } else {
- // MemoryEffects::Write has been detected before so there is no need to
- // check further.
- return true;
- }
- }
- while (nextOp && nextOp != toOp) {
- std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
- getEffectsRecursively(nextOp);
- if (!effects) {
- // TODO: Do we need to handle other effects generically?
- // If the operation does not implement the MemoryEffectOpInterface we
- // conservatively assume it writes.
- result.first->second =
- std::make_pair(nextOp, MemoryEffects::Write::get());
- return true;
- }
-
- for (const MemoryEffects::EffectInstance &effect : *effects) {
- if (isa<MemoryEffects::Write>(effect.getEffect())) {
- // A write on a resource disjoint from all read resources cannot
- // conflict with the reads being CSE'd.
- SideEffects::Resource *writeResource = effect.getResource();
- bool canConflict =
- llvm::any_of(readEffects, [&](const auto &readEffect) {
- SideEffects::Resource *readResource = readEffect.getResource();
- if (writeResource->isDisjointFrom(readResource))
- return false;
- // A pointer-based access to an addressable resource cannot
- // conflict with a non-addressable resource.
- if (readEffect.getValue() && !writeResource->isAddressable())
- return false;
- if (effect.getValue() && !readResource->isAddressable())
- return false;
- return true;
- });
- if (canConflict) {
- result.first->second = {nextOp, MemoryEffects::Write::get()};
- return true;
- }
- }
- }
- nextOp = nextOp->getNextNode();
- }
- result.first->second = std::make_pair(toOp, nullptr);
- return false;
-}
-
-/// Attempt to eliminate a redundant operation.
-LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
- Operation *op,
- bool hasSSADominance) {
- // Don't simplify terminator operations.
- if (op->hasTrait<OpTrait::IsTerminator>())
- return failure();
-
- // Don't simplify operations with regions that have multiple blocks.
- // TODO: We need additional tests to verify that we handle such IR correctly.
- if (!llvm::all_of(op->getRegions(),
- [](Region &r) { return r.empty() || r.hasOneBlock(); }))
- return failure();
-
- // Some simple use case of operation with memory side-effect are dealt with
- // here. Operations with no side-effect are done after.
- if (!isMemoryEffectFree(op)) {
- // TODO: Only basic use case for operations with MemoryEffects::Read can be
- // eleminated now. More work needs to be done for more complicated patterns
- // and other side-effects.
- if (!hasSingleEffect<MemoryEffects::Read>(op))
- return failure();
-
- // Look for an existing definition for the operation.
- if (auto *existing = knownValues.lookup(op)) {
- if (existing->getBlock() == op->getBlock() &&
- !hasOtherSideEffectingOpInBetween(existing, op)) {
- // The operation that can be deleted has been reach with no
- // side-effecting operations in between the existing operation and
- // this one so we can remove the duplicate.
- replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
- return success();
- }
- }
- knownValues.insert(op, op);
- return failure();
- }
-
- // Look for an existing definition for the operation.
- if (auto *existing = knownValues.lookup(op)) {
- replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
- return success();
- }
-
- // Otherwise, we add this operation to the known values map.
- knownValues.insert(op, op);
- return failure();
-}
-
-void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
- bool hasSSADominance) {
- for (auto &op : llvm::make_early_inc_range(*bb)) {
- // If the operation is already trivially dead just add it to the erase list.
- // This also avoids calling `simplifyRegion` on dead region ops
- // unnecessarily.
- if (isOpTriviallyDead(&op)) {
- opsToErase.push_back(&op);
- ++numDCE;
- continue;
- }
-
- // Most operations don't have regions, so fast path that case.
- if (op.getNumRegions() != 0) {
- // If this operation is isolated above, we can't process nested regions
- // with the given 'knownValues' map. This would cause the insertion of
- // implicit captures in explicit capture only regions.
- if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
- ScopedMapTy nestedKnownValues;
- for (auto &region : op.getRegions())
- simplifyRegion(nestedKnownValues, region);
- } else {
- // Otherwise, process nested regions normally.
- for (auto &region : op.getRegions())
- simplifyRegion(knownValues, region);
- }
- }
-
- // If the operation is simplified, we don't process any held regions.
- if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
- continue;
- }
- // Clear the MemoryEffects cache since its usage is by block only.
- memEffectsCache.clear();
-}
-
-void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
- // If the region is empty there is nothing to do.
- if (region.empty())
- return;
-
- bool hasSSADominance = domInfo->hasSSADominance(&region);
-
- // If the region only contains one block, then simplify it directly.
- if (region.hasOneBlock()) {
- ScopedMapTy::ScopeTy scope(knownValues);
- simplifyBlock(knownValues, &region.front(), hasSSADominance);
- return;
- }
-
- // If the region does not have dominanceInfo, then skip it.
- // TODO: Regions without SSA dominance should define a different
- // traversal order which is appropriate and can be used here.
- if (!hasSSADominance)
- return;
-
- // Note, deque is being used here because there was significant performance
- // gains over vector when the container becomes very large due to the
- // specific access patterns. If/when these performance issues are no
- // longer a problem we can change this to vector. For more information see
- // the llvm mailing list discussion on this:
- // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
- std::deque<std::unique_ptr<CFGStackNode>> stack;
-
- // Process the nodes of the dom tree for this region.
- stack.emplace_back(std::make_unique<CFGStackNode>(
- knownValues, domInfo->getRootNode(&region)));
-
- while (!stack.empty()) {
- auto &currentNode = stack.back();
-
- // Check to see if we need to process this node.
- if (!currentNode->processed) {
- currentNode->processed = true;
- simplifyBlock(knownValues, currentNode->node->getBlock(),
- hasSSADominance);
- }
-
- // Otherwise, check to see if we need to process a child node.
- if (currentNode->childIterator != currentNode->node->end()) {
- auto *childNode = *(currentNode->childIterator++);
- stack.emplace_back(
- std::make_unique<CFGStackNode>(knownValues, childNode));
- } else {
- // Finally, if the node and all of its children have been processed
- // then we delete the node.
- stack.pop_back();
- }
- }
-}
-
-void CSEDriver::simplify(Operation *op, bool *changed) {
- /// Simplify all regions.
- ScopedMapTy knownValues;
- for (auto &region : op->getRegions())
- simplifyRegion(knownValues, region);
-
- /// Erase any operations that were marked as dead during simplification, and
- /// remove their associated dominator trees.
- for (auto *op : opsToErase) {
- for (Region &region : op->getRegions())
- domInfo->invalidate(&region);
- rewriter.eraseOp(op);
- }
- if (changed)
- *changed = !opsToErase.empty();
-
- // Note: CSE does currently not remove ops with regions, so DominanceInfo
- // does not have to be invalidated.
-}
-
-void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
- DominanceInfo &domInfo, Operation *op,
- bool *changed) {
- CSEDriver driver(rewriter, &domInfo);
- driver.simplify(op, changed);
-}
-
namespace {
/// CSE pass.
struct CSE : public impl::CSEPassBase<CSE> {
@@ -425,15 +33,18 @@ struct CSE : public impl::CSEPassBase<CSE> {
} // namespace
void CSE::runOnOperation() {
- // Simplify the IR.
IRRewriter rewriter(&getContext());
- CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
+ auto &domInfo = getAnalysis<DominanceInfo>();
bool changed = false;
- driver.simplify(getOperation(), &changed);
-
- // Set statistics.
- numCSE = driver.getNumCSE();
- numDCE = driver.getNumDCE();
+ // `numCSE` / `numDCE` are `llvm::Statistic` objects, not raw `int64_t`, so
+ // the public API's out-parameters cannot point at them directly.
+ int64_t cseCount = 0;
+ int64_t dceCount = 0;
+ eliminateCommonSubExpressions(rewriter, domInfo, getOperation(), &changed,
+ &cseCount, &dceCount);
+
+ numCSE = cseCount;
+ numDCE = dceCount;
// If there was no change to the IR, we mark all analyses as preserved.
if (!changed)
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 9f9bad1c2a678..aa3b1152f1181 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -35,6 +35,7 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
this->regionSimplifyLevel = config.getRegionSimplificationLevel();
this->maxIterations = config.getMaxIterations();
this->maxNumRewrites = config.getMaxNumRewrites();
+ this->cseBetweenIterations = config.isCSEBetweenIterationsEnabled();
this->disabledPatterns = disabledPatterns;
this->enabledPatterns = enabledPatterns;
}
@@ -47,6 +48,7 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
config.setRegionSimplificationLevel(regionSimplifyLevel);
config.setMaxIterations(maxIterations);
config.setMaxNumRewrites(maxNumRewrites);
+ config.enableCSEBetweenIterations(cseBetweenIterations);
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 3ca16239ba33c..335c2cacd2a4a 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_library(MLIRTransformUtils
CFGToSCF.cpp
CommutativityUtils.cpp
ControlFlowSinkUtils.cpp
+ CSE.cpp
DialectConversion.cpp
FoldUtils.cpp
GreedyPatternRewriteDriver.cpp
diff --git a/mlir/lib/Transforms/Utils/CSE.cpp b/mlir/lib/Transforms/Utils/CSE.cpp
new file mode 100644
index 0000000000000..6c67deb257cec
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/CSE.cpp
@@ -0,0 +1,439 @@
+//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common sub-expression elimination as a library utility.
+// The matching CSE pass is a thin wrapper over the APIs declared here.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CSE.h"
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RecyclingAllocator.h"
+#include <deque>
+
+using namespace mlir;
+
+namespace {
+/// DenseMapInfo for `Operation *` keys that hashes and compares operations
+/// through `OperationEquivalence` (ignoring locations) instead of by pointer.
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
+  /// Hash `opC` using `directHashValue` for its operands and
+  /// `ignoreHashValue` for its results.
+  static unsigned getHashValue(const Operation *opC) {
+    auto *op = const_cast<Operation *>(opC);
+    return OperationEquivalence::computeHash(
+        op, /*hashOperands=*/OperationEquivalence::directHashValue,
+        /*hashResults=*/OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
+  }
+
+  /// Two keys are equal when they are the same pointer, or when neither is a
+  /// map sentinel and `OperationEquivalence` considers them equivalent.
+  static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
+    auto *lhs = const_cast<Operation *>(lhsC);
+    auto *rhs = const_cast<Operation *>(rhsC);
+    if (lhs == rhs)
+      return true;
+
+    // The empty/tombstone keys are not real operations and must never be
+    // handed to the structural comparison.
+    auto isSentinel = [](Operation *candidate) {
+      return candidate == getTombstoneKey() || candidate == getEmptyKey();
+    };
+    if (isSentinel(lhs) || isSentinel(rhs))
+      return false;
+
+    return OperationEquivalence::isEquivalentTo(
+        lhs, rhs, OperationEquivalence::IgnoreLocations);
+  }
+};
+} // namespace
+
+namespace {
+/// Simple common sub-expression elimination.
+///
+/// The driver walks regions, redirects uses of an operation to an equivalent
+/// operation it has already seen, and queues redundant or trivially dead
+/// operations for erasure. The queue is flushed at the end of `simplify`.
+class CSEDriver {
+public:
+  CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
+      : rewriter(rewriter), domInfo(domInfo) {}
+
+  /// Simplify all operations within the given op. If `changed` is non-null,
+  /// it is set to whether any operation was erased.
+  void simplify(Operation *op, bool *changed = nullptr);
+
+  /// Simplify operations within the given region. If `changed` is non-null,
+  /// it is set to whether any operation was erased.
+  void simplify(Region &region, bool *changed = nullptr);
+
+  /// Number of operations that were replaced by an equivalent existing one.
+  int64_t getNumCSE() const { return numCSE; }
+  /// Number of trivially dead operations that were queued for erasure.
+  int64_t getNumDCE() const { return numDCE; }
+
+private:
+  /// Shared implementation of operation elimination and scoped map definitions.
+  using AllocatorTy = llvm::RecyclingAllocator<
+      llvm::BumpPtrAllocator,
+      llvm::ScopedHashTableVal<Operation *, Operation *>>;
+  using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
+                                            SimpleOperationInfo, AllocatorTy>;
+
+  /// Cache holding MemoryEffects information between two operations. The first
+  /// operation is stored as the key. The second operation is stored inside a
+  /// pair in the value. The pair also holds the MemoryEffects between those
+  /// two operations. If the MemoryEffects is nullptr then we assume there is
+  /// no operation with MemoryEffects::Write between the two operations.
+  using MemEffectsCache =
+      DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
+
+  /// Represents a single entry in the depth first traversal of a CFG.
+  struct CFGStackNode {
+    CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
+        : scope(knownValues), node(node), childIterator(node->begin()) {}
+
+    /// Scope for the known values.
+    ScopedMapTy::ScopeTy scope;
+
+    /// Dominator tree node being visited and the next child to descend into.
+    DominanceInfoNode *node;
+    DominanceInfoNode::const_iterator childIterator;
+
+    /// Whether the block of this node has already been simplified.
+    bool processed = false;
+  };
+
+  /// Attempt to eliminate a redundant operation. Returns success if the
+  /// operation was marked for removal, failure otherwise.
+  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+                                  bool hasSSADominance);
+  void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
+  void simplifyRegion(ScopedMapTy &knownValues, Region &region);
+
+  /// Erase all operations queued for deletion by the simplification routines.
+  void eraseDeadOps(bool *changed);
+
+  /// Redirect uses of `op` to `existing` and queue `op` for erasure once it
+  /// has no remaining uses.
+  void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                            Operation *existing, bool hasSSADominance);
+
+  /// Check if there are side-effecting operations other than the given effect
+  /// between the two operations.
+  bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
+
+  /// A rewriter for modifying the IR.
+  RewriterBase &rewriter;
+
+  /// Operations marked as dead and to be erased.
+  std::vector<Operation *> opsToErase;
+  DominanceInfo *domInfo = nullptr;
+  MemEffectsCache memEffectsCache;
+
+  // Various statistics.
+  int64_t numCSE = 0;
+  int64_t numDCE = 0;
+};
+} // namespace
+
+/// Redirect uses of `op` to the equivalent operation `existing`, queue `op`
+/// for erasure when no uses remain, and count the replacement in `numCSE`.
+void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                                     Operation *existing,
+                                     bool hasSSADominance) {
+  auto *listener =
+      dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener());
+
+  if (!hasSSADominance) {
+    // Without SSA dominance a use may belong to an operation that is already
+    // recorded in `knownValues`; only uses whose owner has no entry there are
+    // rewritten.
+    auto ownerNotInKnownValues = [&](OpOperand &use) {
+      return !knownValues.count(use.getOwner());
+    };
+    if (listener) {
+      for (Value result : op->getResults()) {
+        if (all_of(result.getUses(), ownerNotInKnownValues))
+          listener->notifyOperationReplaced(op, existing);
+      }
+    }
+
+    // The uses are replaced but `op` itself is kept for now; this does not
+    // notify the listener because the original op is not erased.
+    rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
+                               ownerNotInKnownValues);
+
+    // Some uses may have been skipped, so only queue `op` if none are left.
+    if (op->use_empty())
+      opsToErase.push_back(op);
+  } else {
+    // With SSA dominance no use of `op` can have been visited yet, so every
+    // use is replaced unconditionally.
+    if (listener)
+      listener->notifyOperationReplaced(op, existing);
+
+    // The uses are replaced but `op` itself is kept for now; this does not
+    // notify the listener because the original op is not erased.
+    rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
+    opsToErase.push_back(op);
+  }
+
+  // Prefer a known location: if `existing` has an unknown location and `op`
+  // does not, carry the location of `op` over to `existing`.
+  const bool existingLocUnknown = isa<UnknownLoc>(existing->getLoc());
+  const bool opLocUnknown = isa<UnknownLoc>(op->getLoc());
+  if (existingLocUnknown && !opLocUnknown)
+    existing->setLoc(op->getLoc());
+
+  ++numCSE;
+}
+
+/// Scan the operations between `fromOp` and `toOp` (which must be in the same
+/// block and must both have a read effect) and return true if one of them may
+/// write to something `fromOp` reads, or has effects that cannot be
+/// determined. Returns false if the scan reaches `toOp` without finding one.
+///
+/// Results are memoized in `memEffectsCache`, keyed by `fromOp`: a later query
+/// with the same `fromOp` either returns true immediately (a write was already
+/// found) or resumes scanning from the operation recorded by the earlier
+/// query.
+bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
+ Operation *toOp) {
+ assert(fromOp->getBlock() == toOp->getBlock());
+ assert(hasEffect<MemoryEffects::Read>(fromOp) &&
+ "expected read effect on fromOp");
+ assert(hasEffect<MemoryEffects::Read>(toOp) &&
+ "expected read effect on toOp");
+
+ // Collect the read effects of fromOp. A write can only block CSE if it
+ // can conflict with one of these reads.
+ SmallVector<MemoryEffects::EffectInstance> readEffects;
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(fromOp)) {
+ SmallVector<MemoryEffects::EffectInstance> fromEffects;
+ memOp.getEffects(fromEffects);
+ for (MemoryEffects::EffectInstance &e : fromEffects)
+ if (isa<MemoryEffects::Read>(e.getEffect()))
+ readEffects.push_back(e);
+ }
+
+ // Look up (or create) the cache entry for fromOp. A freshly created entry
+ // holds (fromOp, nullptr) and is overwritten before this function returns.
+ Operation *nextOp = fromOp->getNextNode();
+ auto result =
+ memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
+ if (!result.second) {
+ auto memEffectsCachePair = result.first->second;
+ if (memEffectsCachePair.second == nullptr) {
+ // No MemoryEffects::Write has been detected until the cached operation.
+ // Continue looking from the cached operation to toOp.
+ nextOp = memEffectsCachePair.first;
+ } else {
+ // MemoryEffects::Write has been detected before so there is no need to
+ // check further.
+ return true;
+ }
+ }
+ // Walk forward until toOp (exclusive) or the end of the block.
+ while (nextOp && nextOp != toOp) {
+ std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
+ getEffectsRecursively(nextOp);
+ if (!effects) {
+ // TODO: Do we need to handle other effects generically?
+ // If the operation does not implement the MemoryEffectOpInterface we
+ // conservatively assume it writes.
+ result.first->second =
+ std::make_pair(nextOp, MemoryEffects::Write::get());
+ return true;
+ }
+
+ for (const MemoryEffects::EffectInstance &effect : *effects) {
+ if (isa<MemoryEffects::Write>(effect.getEffect())) {
+ // A write on a resource disjoint from all read resources cannot
+ // conflict with the reads being CSE'd.
+ SideEffects::Resource *writeResource = effect.getResource();
+ bool canConflict =
+ llvm::any_of(readEffects, [&](const auto &readEffect) {
+ SideEffects::Resource *readResource = readEffect.getResource();
+ if (writeResource->isDisjointFrom(readResource))
+ return false;
+ // A pointer-based access to an addressable resource cannot
+ // conflict with a non-addressable resource.
+ if (readEffect.getValue() && !writeResource->isAddressable())
+ return false;
+ if (effect.getValue() && !readResource->isAddressable())
+ return false;
+ return true;
+ });
+ if (canConflict) {
+ // Remember the conflicting write so later queries return early.
+ result.first->second = {nextOp, MemoryEffects::Write::get()};
+ return true;
+ }
+ }
+ }
+ nextOp = nextOp->getNextNode();
+ }
+ // Nothing conflicting was found; record toOp as the resume point for the
+ // next query with the same fromOp.
+ result.first->second = std::make_pair(toOp, nullptr);
+ return false;
+}
+
+/// Try to replace `op` with an equivalent operation recorded in
+/// `knownValues`. Returns success when `op` was replaced; otherwise returns
+/// failure, recording `op` in `knownValues` when it is a CSE candidate.
+LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
+                                           Operation *op,
+                                           bool hasSSADominance) {
+  // Terminators are never simplified.
+  if (op->hasTrait<OpTrait::IsTerminator>())
+    return failure();
+
+  // Operations holding a region with more than one block are skipped.
+  // TODO: We need additional tests to verify that we handle such IR correctly.
+  auto isEmptyOrSingleBlock = [](Region &r) {
+    return r.empty() || r.hasOneBlock();
+  };
+  if (!llvm::all_of(op->getRegions(), isEmptyOrSingleBlock))
+    return failure();
+
+  // Among operations with memory effects, only those whose single effect is a
+  // read are candidates.
+  // TODO: Only the basic use case for operations with MemoryEffects::Read can
+  // be eliminated now. More work needs to be done for more complicated
+  // patterns and other side-effects.
+  const bool memoryEffectFree = isMemoryEffectFree(op);
+  if (!memoryEffectFree && !hasSingleEffect<MemoryEffects::Read>(op))
+    return failure();
+
+  // Look for an existing definition for the operation. A reading operation
+  // may only reuse one from the same block, with no conflicting
+  // side-effecting operation in between.
+  Operation *existing = knownValues.lookup(op);
+  const bool canReuse =
+      existing &&
+      (memoryEffectFree ||
+       (existing->getBlock() == op->getBlock() &&
+        !hasOtherSideEffectingOpInBetween(existing, op)));
+  if (canReuse) {
+    replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
+    return success();
+  }
+
+  // Otherwise, this operation becomes the known value for later duplicates.
+  knownValues.insert(op, op);
+  return failure();
+}
+
+/// Simplify every operation of `bb`: trivially dead operations are queued for
+/// erasure, nested regions are simplified, and then the operation itself is
+/// offered to `simplifyOperation`.
+void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+                              bool hasSSADominance) {
+  for (auto &op : llvm::make_early_inc_range(*bb)) {
+    // If the operation is already trivially dead just add it to the erase
+    // list. This also avoids calling `simplifyRegion` on dead region ops
+    // unnecessarily.
+    if (isOpTriviallyDead(&op)) {
+      opsToErase.push_back(&op);
+      ++numDCE;
+      continue;
+    }
+
+    // Most operations don't have regions, so fast path that case.
+    if (op.getNumRegions() != 0) {
+      // If this operation is isolated above, we can't process nested regions
+      // with the given 'knownValues' map. This would cause the insertion of
+      // implicit captures in explicit capture only regions.
+      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+        ScopedMapTy nestedKnownValues;
+        for (auto &region : op.getRegions())
+          simplifyRegion(nestedKnownValues, region);
+      } else {
+        // Otherwise, process nested regions normally.
+        for (auto &region : op.getRegions())
+          simplifyRegion(knownValues, region);
+      }
+    }
+
+    // Nested regions have already been handled above, so nothing further
+    // depends on whether the operation itself was simplified.
+    (void)simplifyOperation(knownValues, &op, hasSSADominance);
+  }
+  // Clear the MemoryEffects cache since its usage is by block only.
+  memEffectsCache.clear();
+}
+
+/// Simplify the blocks of `region`. A single-block region is simplified
+/// directly inside a fresh scope of `knownValues`; a multi-block region is
+/// traversed along its dominator tree, and is skipped entirely when it does
+/// not have SSA dominance.
+void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
+  // If the region is empty there is nothing to do.
+  if (region.empty())
+    return;
+
+  bool hasSSADominance = domInfo->hasSSADominance(&region);
+
+  // If the region only contains one block, then simplify it directly.
+  if (region.hasOneBlock()) {
+    ScopedMapTy::ScopeTy scope(knownValues);
+    simplifyBlock(knownValues, &region.front(), hasSSADominance);
+    return;
+  }
+
+  // If the region does not have dominanceInfo, then skip it.
+  // TODO: Regions without SSA dominance should define a different
+  // traversal order which is appropriate and can be used here.
+  if (!hasSSADominance)
+    return;
+
+  // Note, deque is being used here because there was significant performance
+  // gains over vector when the container becomes very large due to the
+  // specific access patterns. If/when these performance issues are no
+  // longer a problem we can change this to vector. For more information see
+  // the llvm mailing list discussion on this:
+  // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
+  std::deque<std::unique_ptr<CFGStackNode>> stack;
+
+  // Process the nodes of the dom tree for this region.
+  stack.emplace_back(std::make_unique<CFGStackNode>(
+      knownValues, domInfo->getRootNode(&region)));
+
+  while (!stack.empty()) {
+    auto &currentNode = stack.back();
+
+    // Check to see if we need to process this node.
+    if (!currentNode->processed) {
+      currentNode->processed = true;
+      simplifyBlock(knownValues, currentNode->node->getBlock(),
+                    hasSSADominance);
+    }
+
+    // Otherwise, check to see if we need to process a child node.
+    if (currentNode->childIterator != currentNode->node->end()) {
+      auto *childNode = *(currentNode->childIterator++);
+      stack.emplace_back(
+          std::make_unique<CFGStackNode>(knownValues, childNode));
+    } else {
+      // Finally, if the node and all of its children have been processed
+      // then we delete the node, which also closes its `knownValues` scope.
+      stack.pop_back();
+    }
+  }
+}
+
+/// Erase every operation queued in `opsToErase`. If `changed` is non-null, it
+/// is set to whether at least one operation was erased.
+void CSEDriver::eraseDeadOps(bool *changed) {
+  // Erase any operations that were marked as dead during simplification, and
+  // remove the dominator trees associated with the regions they own.
+  for (auto *op : opsToErase) {
+    for (Region &region : op->getRegions())
+      domInfo->invalidate(&region);
+    rewriter.eraseOp(op);
+  }
+  if (changed)
+    *changed = !opsToErase.empty();
+  opsToErase.clear();
+
+  // Note: only the regions owned by erased ops are invalidated above; the
+  // dominance information of the regions that contained those ops is left
+  // untouched.
+}
+
+/// Simplify every region of `op` using one shared `knownValues` table, then
+/// erase the operations that were queued for deletion.
+void CSEDriver::simplify(Operation *op, bool *changed) {
+  // Simplify all regions.
+  ScopedMapTy knownValues;
+  for (auto &region : op->getRegions())
+    simplifyRegion(knownValues, region);
+  eraseDeadOps(changed);
+}
+
+/// Simplify a single `region` with a fresh `knownValues` table, then erase
+/// the operations that were queued for deletion.
+void CSEDriver::simplify(Region &region, bool *changed) {
+  ScopedMapTy knownValues;
+  simplifyRegion(knownValues, region);
+  eraseDeadOps(changed);
+}
+
+/// Run CSE on all regions of `op`. `changed`, `numCSE` and `numDCE` are
+/// optional out-parameters; each one is written only when it is non-null.
+void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
+                                         DominanceInfo &domInfo, Operation *op,
+                                         bool *changed, int64_t *numCSE,
+                                         int64_t *numDCE) {
+  CSEDriver cse(rewriter, &domInfo);
+  cse.simplify(op, changed);
+
+  // Hand the statistics back only to callers that asked for them.
+  if (numCSE != nullptr)
+    *numCSE = cse.getNumCSE();
+  if (numDCE != nullptr)
+    *numDCE = cse.getNumDCE();
+}
+
+/// Run CSE on a single `region`. If `changed` is non-null, it is set to
+/// whether any operation was erased.
+void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
+                                         DominanceInfo &domInfo, Region &region,
+                                         bool *changed) {
+  CSEDriver driver(rewriter, &domInfo);
+  driver.simplify(region, changed);
+}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 578e680535bed..eba6a81e65cfc 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -14,12 +14,14 @@
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/CSE.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
@@ -897,6 +899,16 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
/*mergeBlocks=*/config.getRegionSimplificationLevel() ==
GreedySimplifyRegionLevel::Aggressive));
}
+
+ // Optionally run full CSE. If CSE changes the IR we iterate again so
+ // that patterns can fire on the deduplicated operations.
+ if (config.isCSEBetweenIterationsEnabled()) {
+ DominanceInfo domInfo;
+ bool cseChanged = false;
+ eliminateCommonSubExpressions(rewriter, domInfo, region,
+ &cseChanged);
+ continueRewrites |= cseChanged;
+ }
},
{&region}, iteration);
} while (continueRewrites);
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 256c9e83f30f9..68f634d2038fc 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -17,7 +17,7 @@ func.func @bar() {
// CHECK: builtin.module(
// CHECK-NEXT: func.func(
// CHECK-NEXT: cse,
- // CHECK-NEXT: canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}
+ // CHECK-NEXT: canonicalize{cse-between-iterations=false max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}
// CHECK-NEXT: )
// CHECK-NEXT: )
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
diff --git a/mlir/test/Transforms/canonicalize-cse-between-iterations.mlir b/mlir/test/Transforms/canonicalize-cse-between-iterations.mlir
new file mode 100644
index 0000000000000..36ad8f7fb4cac
--- /dev/null
+++ b/mlir/test/Transforms/canonicalize-cse-between-iterations.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s --canonicalize -split-input-file | FileCheck %s --check-prefixes=CHECK,NOCSE
+// RUN: mlir-opt %s --canonicalize='cse-between-iterations=true' -split-input-file | FileCheck %s --check-prefixes=CHECK,CSE
+// Convergence / max-iterations interaction: only one pattern-application
+// iteration is allowed, so CSE unifies the duplicates but the follow-up fold
+// cannot fire.
+// RUN: mlir-opt %s --canonicalize='cse-between-iterations=true max-iterations=1' -split-input-file | FileCheck %s --check-prefixes=CHECK,ONESHOT
+
+// Two structurally identical subexpressions cannot be folded away by
+// canonicalization alone because they are distinct SSA values. Running CSE
+// between iterations unifies them, which lets `arith.subi %a, %a -> 0` fire
+// on the next iteration and the whole body collapses to a constant.
+
+// CHECK-LABEL: @dup_subs
+func.func @dup_subs(%x: i32, %y: i32) -> i32 {
+ // Default canonicalization: all three subtractions are expected to remain.
+ // NOCSE-COUNT-3: arith.subi
+ // NOCSE-NOT: arith.subi
+
+ // With cse-between-iterations=true: no subtraction is expected to remain
+ // and the function returns the constant 0.
+ // CSE-NOT: arith.subi
+ // CSE: %[[C0:.*]] = arith.constant 0 : i32
+ // CSE: return %[[C0]]
+
+ // Max-iterations=1: CSE fires once but the downstream subi(a, a) -> 0 fold
+ // needs a second pattern-application iteration, which is disallowed.
+ // ONESHOT-COUNT-2: arith.subi
+ // ONESHOT-NOT: arith.constant
+ %a = arith.subi %x, %y : i32
+ %b = arith.subi %x, %y : i32
+ %c = arith.subi %a, %b : i32
+ return %c : i32
+}
+
+// -----
+
+// After CSE unifies the two redundant subi ops, the downstream `arith.subi
+// %a, %a` folds to 0, which in turn makes the downstream `arith.addi 0, %y`
+// fold to %y. This demonstrates that CSE-between-iterations enables a
+// cascading simplification that canonicalization alone cannot achieve.
+
+// CHECK-LABEL: @cascade
+func.func @cascade(%x: i32, %y: i32) -> i32 {
+ // Default canonicalization: the three subtractions and the addition are
+ // expected to remain.
+ // NOCSE-COUNT-3: arith.subi
+ // NOCSE: arith.addi
+ // NOCSE: return
+
+ // With cse-between-iterations=true: everything folds away and the second
+ // argument (%y, matched below as %arg1) is returned directly.
+ // CSE-NOT: arith.subi
+ // CSE-NOT: arith.addi
+ // CSE: return %arg1 : i32
+ %a = arith.subi %x, %y : i32
+ %b = arith.subi %x, %y : i32
+ %c = arith.subi %a, %b : i32
+ %d = arith.addi %c, %y : i32
+ return %d : i32
+}
+
+// -----
+
+// Nested regions must also be reached by CSE-between-iterations. The
+// duplicate `arith.subi` ops inside the scf.for body are unified, unblocking
+// the `arith.subi %a, %a -> 0` fold on the next iteration and then the
+// `arith.addi ..., 0` fold that follows. The loop body still uses `%i` so
+// the loop itself is not dead and survives canonicalization.
+
+// CHECK-LABEL: @nested
+func.func @nested(%lb: index, %ub: index, %step: index,
+ %x: i32, %y: i32, %init: i32) -> i32 {
+ // Default canonicalization: the loop and its three subtractions are
+ // expected to remain.
+ // NOCSE: scf.for
+ // NOCSE-COUNT-3: arith.subi
+
+ // With cse-between-iterations=true: the loop is expected to remain, but no
+ // subtraction may appear between the loop header and its yield.
+ // CSE: scf.for
+ // CSE-NOT: arith.subi
+ // CSE: scf.yield
+ %r = scf.for %i = %lb to %ub step %step iter_args(%acc = %init) -> i32 {
+ %a = arith.subi %x, %y : i32
+ %b = arith.subi %x, %y : i32
+ %c = arith.subi %a, %b : i32
+ %ic = arith.index_cast %i : index to i32
+ %nxt = arith.addi %acc, %ic : i32
+ %final = arith.addi %nxt, %c : i32
+ scf.yield %final : i32
+ }
+ return %r : i32
+}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
index 460cd612cde63..03c540d72185b 100644
--- a/mlir/test/Transforms/composite-pass.mlir
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -4,7 +4,7 @@
// Ensure the composite pass correctly prints its options.
// PIPELINE: builtin.module(
// PIPELINE-NEXT: composite-fixed-point-pass{max-iterations=10 name=TestCompositePass
-// PIPELINE-SAME: pipeline=canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse}
+// PIPELINE-SAME: pipeline=canonicalize{cse-between-iterations=false max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse}
// CHECK-LABEL: running `TestCompositePass`
// CHECK: running `CanonicalizerPass`
More information about the Mlir-commits
mailing list