[Mlir-commits] [mlir] [mlir][CSE] Introduce hoist-pure-ops logic to CSE pass (PR #180556)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 13 09:06:43 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-func
Author: lonely eagle (linuxlonelyeagle)
<details>
<summary>Changes</summary>
This PR is based on the following observation: `if an Op is a Pure Op, we have the opportunity to hoist its position based on SSA dominance`. This logic has now been incorporated into the CSE pass, so we can use it to further optimize the IR and produce more concise code.
RFC: https://discourse.llvm.org/t/rfc-mlir-introduce-hoist-pure-ops-pass/88903
---
Patch is 74.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/180556.diff
21 Files Affected:
- (modified) mlir/include/mlir/Dialect/Func/IR/FuncOps.h (+1)
- (modified) mlir/include/mlir/Dialect/Func/IR/FuncOps.td (+3-1)
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCF.h (+1)
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+4-1)
- (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1)
- (added) mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h (+13)
- (added) mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td (+41)
- (modified) mlir/lib/Dialect/Func/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/SCF/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/Interfaces/CMakeLists.txt (+2-1)
- (added) mlir/lib/Interfaces/HoistingContainerOpInterface.cpp (+21)
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
- (modified) mlir/lib/Transforms/CSE.cpp (+167-20)
- (modified) mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir (+4-8)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+4-11)
- (modified) mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir (+1-2)
- (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+105-106)
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir (+57-61)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+42-43)
- (modified) mlir/test/Transforms/cse.mlir (+10-14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
index 5e10a9f50b774..ad5eac754f236 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
@@ -18,6 +18,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 06ce4f16c867d..db01cff2a6937 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -15,6 +15,7 @@ include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/Interfaces/HoistingContainerOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -248,7 +249,8 @@ def ConstantOp : Func_Op<"constant",
//===----------------------------------------------------------------------===//
def FuncOp : Func_Op<"func", [
- AffineScope, AutomaticAllocationScope,
+ AffineScope, AutomaticAllocationScope,
+ DeclareOpInterfaceMethods<HoistingContainerOpInterface>,
FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface
]> {
let summary = "An operation with a name containing a single `SSACFG` region";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..f554147814a75 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -20,6 +20,7 @@
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..62061a90c0d2e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/HoistingContainerOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -162,7 +163,8 @@ def ForOp : SCF_Op<"for",
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
SingleBlockImplicitTerminator<"scf::YieldOp">,
- RecursiveMemoryEffects]> {
+ RecursiveMemoryEffects,
+ DeclareOpInterfaceMethods<HoistingContainerOpInterface>]> {
let summary = "for operation";
let description = [{
The `scf.for` operation represents a loop taking 3 SSA value as operands
@@ -986,6 +988,7 @@ def WhileOp : SCF_Op<"while",
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<HoistingContainerOpInterface>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index eb96a68861116..7acefb64f61ab 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
+add_mlir_interface(HoistingContainerOpInterface)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferStridedMetadataInterface)
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
new file mode 100644
index 0000000000000..6953f8b2138a3
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
@@ -0,0 +1,13 @@
+
+
+#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h.inc"
+
+namespace mlir {
+bool canContainHoistedOps(Operation *op);
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
new file mode 100644
index 0000000000000..1f1c9994f09c1
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
@@ -0,0 +1,41 @@
+//===- HoistingContainerOpInterface.td - Interface Decl. -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the HoistingContainerOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def HoistingContainerOpInterface : OpInterface<"HoistingContainerOpInterface"> {
+ let description = [{
+ This interface models whether an operation's regions are capable of
+ acting as a container for operations hoisted from nested regions.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if this operation's regions can accommodate operations
+ hoisted from its nested scopes.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"canContainHoistedOps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return true;
+ }]
+ >
+ ];
+}
+
+#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Func/IR/CMakeLists.txt b/mlir/lib/Dialect/Func/IR/CMakeLists.txt
index 329301c6fbafd..c748fdf2b57f0 100644
--- a/mlir/lib/Dialect/Func/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRFuncDialect
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRFunctionInterfaces
+ MLIRHoistingContainerOpInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..bbf27027f37b4 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRControlFlowDialect
MLIRDialectUtils
MLIRFunctionInterfaces
+ MLIRHoistingContainerOpInterface
MLIRIR
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index ad3e2b61be418..1d73e5d2c6912 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
DestinationStyleOpInterface.cpp
FunctionImplementation.cpp
FunctionInterfaces.cpp
+ HoistingContainerOpInterface.cpp
IndexingMapOpInterface.cpp
InferIntRangeInterface.cpp
InferStridedMetadataInterface.cpp
@@ -64,7 +65,7 @@ add_mlir_library(MLIRFunctionInterfaces
MLIRCallInterfaces
MLIRIR
)
-
+add_mlir_interface_library(HoistingContainerOpInterface)
add_mlir_interface_library(IndexingMapOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
diff --git a/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp b/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
new file mode 100644
index 0000000000000..33801c6509ad2
--- /dev/null
+++ b/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
@@ -0,0 +1,21 @@
+//===- HoistingContainerOpInterface.cpp -- Hoisting Container Op Interface -==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/HoistingContainerOpInterface.cpp.inc"
+} // namespace mlir
+
+bool mlir::canContainHoistedOps(Operation *op) {
+ if (auto containerOp = dyn_cast<HoistingContainerOpInterface>(op))
+ return containerOp.canContainHoistedOps();
+ return false;
+}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 8907724627386..1a2cd72691a79 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_library(MLIRTransforms
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRFunctionInterfaces
+ MLIRHoistingContainerOpInterface
MLIRLoopLikeInterface
MLIRMemOpInterfaces
MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 8eaac308755fd..e22166c549048 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -15,12 +15,15 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/RecyclingAllocator.h"
#include <deque>
@@ -29,6 +32,7 @@ namespace mlir {
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "cse"
using namespace mlir;
namespace {
@@ -101,13 +105,17 @@ class CSEDriver {
/// Attempt to eliminate a redundant operation. Returns success if the
/// operation was marked for removal, failure otherwise.
- LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+ LogicalResult simplifyOperation(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps, Operation *op,
bool hasSSADominance);
- void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
- void simplifyRegion(ScopedMapTy &knownValues, Region ®ion);
+ void simplifyBlock(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+ Block *bb, bool hasSSADominance);
+ void simplifyRegion(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+ Region ®ion);
void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing, bool hasSSADominance);
+ LogicalResult hoistPureOp(Operation *existing, Operation *op);
/// Check if there is side-effecting operations other than the given effect
/// between the two operations.
@@ -117,7 +125,7 @@ class CSEDriver {
RewriterBase &rewriter;
/// Operations marked as dead and to be erased.
- std::vector<Operation *> opsToErase;
+ SmallVector<Operation *> opsToErase;
DominanceInfo *domInfo = nullptr;
MemEffectsCache memEffectsCache;
@@ -127,6 +135,69 @@ class CSEDriver {
};
} // namespace
+static bool isBlockCrossIsIsolatedFromAbove(DominanceInfo *dominate, Block *a,
+ Block *b) {
+ if (a == b)
+ return false;
+ if (a->getParent() == b->getParent())
+ return false;
+ if (dominate->dominates(b, a))
+ std::swap(b, a);
+ while (b && b->getParentOp()) {
+ Operation *parnetOp = b->getParentOp();
+ if (parnetOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>())
+ return true;
+ b = parnetOp->getBlock();
+ if (b == a)
+ return false;
+ }
+ return false;
+}
+
+/// Hoist the pure ops to the location of the Nearest Common Dominator
+LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
+ Block *ancestorBlock =
+ domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock());
+ if (!ancestorBlock) {
+ LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+ << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "failed";
+ return failure();
+ }
+
+ if (ancestorBlock->getParent() != existing->getParentRegion() &&
+ !canContainHoistedOps(ancestorBlock->getParentOp()))
+ return failure();
+
+ if (isBlockCrossIsIsolatedFromAbove(domInfo, ancestorBlock,
+ existing->getBlock()))
+ return failure();
+
+ Operation *insertPoint = nullptr;
+ for (Value operand : op->getOperands()) {
+ if (domInfo->properlyDominates(operand, &ancestorBlock->front()))
+ continue;
+ if (!insertPoint) {
+ insertPoint = operand.getDefiningOp();
+ } else {
+ insertPoint = domInfo->dominates(insertPoint, operand.getDefiningOp())
+ ? operand.getDefiningOp()
+ : insertPoint;
+ }
+ }
+ if (!insertPoint) {
+ rewriter.moveOpBefore(existing, ancestorBlock, ancestorBlock->begin());
+ rewriter.moveOpAfter(op, existing);
+ } else {
+ rewriter.moveOpAfter(existing, insertPoint);
+ rewriter.moveOpAfter(op, existing);
+ }
+ LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+ << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "success";
+ return success();
+}
+
void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing,
bool hasSSADominance) {
@@ -141,6 +212,15 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
rewriteListener->notifyOperationReplaced(op, existing);
// Replace all uses, but do not remove the operation yet. This does not
// notify the listener because the original op is not erased.
+ if (!domInfo->properlyDominates(existing, op)) {
+ if (failed(hoistPureOp(existing, op)))
+ return;
+ }
+ LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " with "
+ << OpWithFlags(existing, OpPrintingFlags().skipRegions());
+ LDBG() << "add " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to opsToErase";
rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
opsToErase.push_back(op);
} else {
@@ -155,14 +235,25 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
if (all_of(v.getUses(), wasVisited))
rewriteListener->notifyOperationReplaced(op, existing);
+ if (!domInfo->properlyDominates(existing, op)) {
+ if (failed(hoistPureOp(existing, op)))
+ return;
+ }
// Replace all uses, but do not remove the operation yet. This does not
// notify the listener because the original op is not erased.
+ LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " with "
+ << OpWithFlags(existing, OpPrintingFlags().skipRegions());
rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
wasVisited);
// There may be some remaining uses of the operation.
- if (op->use_empty())
+ if (op->use_empty()) {
+ LDBG() << "use_empty, add "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to opsToErase";
opsToErase.push_back(op);
+ }
}
// If the existing operation has an unknown location and the current
@@ -222,8 +313,11 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
/// Attempt to eliminate a redundant operation.
LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps,
Operation *op,
bool hasSSADominance) {
+ LDBG() << "visit operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Don't simplify terminator operations.
if (op->hasTrait<OpTrait::IsTerminator>())
return failure();
@@ -261,7 +355,27 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
return success();
}
}
- knownValues.insert(op, op);
+ if (auto *existing = knownPureOps.lookup(op)) {
+ if (existing->getBlock() == op->getBlock() &&
+ !hasOtherSideEffectingOpInBetween(existing, op)) {
+ // The operation that can be deleted has been reach with no
+ // side-effecting operations in between the existing operation and
+ // this one so we can remove the duplicate.
+ replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+ return success();
+ }
+ }
+
+ if (mlir::isPure(op)) {
+ LDBG() << "insert op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "to pureMap";
+ knownPureOps.insert(op, op);
+ } else {
+ LDBG() << "insert op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "to map";
+ knownValues.insert(op, op);
+ }
return failure();
}
@@ -272,14 +386,31 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
return success();
}
- // Otherwise, we add this operation to the known values map.
- knownValues.insert(op, op);
+ if (auto *existing = knownPureOps.lookup(op)) {
+ replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+ ++numCSE;
+ return success();
+ }
+
+ if (mlir::isPure(op)) {
+ LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "to pureMap";
+ knownPureOps.insert(op, op);
+ } else {
+ // Otherwise, we add this operation to the known values map.
+ LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "to map";
+ knownValues.insert(op, op);
+ }
return failure();
}
-void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps, Block *bb,
bool hasSSADominance) {
- for (auto &op : *bb) {
+ LDBG() << "visit block #" << bb->computeBlockNumber() << " of "
+ << OpWithFlags(bb->getParentOp(), OpPrintingFlags().skipRegions());
+ for (auto &op : llvm::make_early_inc_range(*bb)) {
// Most operations don't have regions, so fast path that case.
if (op.getNumRegions() != 0) {
// If this operation is isolated above, we can't process nested regions
@@ -287,34 +418,45 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
// implicit captures in explicit capture only regions.
if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
ScopedMapTy nestedKnownValues;
+ ScopedMapTy nestedKnownPureOps;
+ ScopedMapTy::ScopeTy scope(nestedKnownValues);
+ ScopedMapTy::ScopeTy pureScope(nestedKnownPureOps);
for (auto ®ion : op.getRegions())
- simplifyRegion(nestedKnownValues, region);
+ simplifyRegion(nestedKnownValues, nestedKnownPureOps, region);
} else {
// Othe...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/180556
More information about the Mlir-commits
mailing list