[Mlir-commits] [mlir] [mlir][CSE] Introduce hoist-pure-ops logic to CSE pass (PR #180556)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 13 09:06:43 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-func

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

This PR is based on the following observation: `if an Op is a Pure Op, we have the opportunity to hoist its position based on SSA dominance`. This logic has now been incorporated into the CSE pass, so it can be used to further optimize the IR and produce more concise code.
RFC: https://discourse.llvm.org/t/rfc-mlir-introduce-hoist-pure-ops-pass/88903

---

Patch is 74.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/180556.diff


21 Files Affected:

- (modified) mlir/include/mlir/Dialect/Func/IR/FuncOps.h (+1) 
- (modified) mlir/include/mlir/Dialect/Func/IR/FuncOps.td (+3-1) 
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCF.h (+1) 
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+4-1) 
- (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1) 
- (added) mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h (+13) 
- (added) mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td (+41) 
- (modified) mlir/lib/Dialect/Func/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/SCF/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Interfaces/CMakeLists.txt (+2-1) 
- (added) mlir/lib/Interfaces/HoistingContainerOpInterface.cpp (+21) 
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/lib/Transforms/CSE.cpp (+167-20) 
- (modified) mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir (+4-8) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+4-11) 
- (modified) mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir (+1-2) 
- (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+105-106) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir (+57-61) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+42-43) 
- (modified) mlir/test/Transforms/cse.mlir (+10-14) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
index 5e10a9f50b774..ad5eac754f236 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 06ce4f16c867d..db01cff2a6937 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -15,6 +15,7 @@ include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/Interfaces/HoistingContainerOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -248,7 +249,8 @@ def ConstantOp : Func_Op<"constant",
 //===----------------------------------------------------------------------===//
 
 def FuncOp : Func_Op<"func", [
-  AffineScope, AutomaticAllocationScope,
+  AffineScope, AutomaticAllocationScope, 
+  DeclareOpInterfaceMethods<HoistingContainerOpInterface>,
   FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface
 ]> {
   let summary = "An operation with a name containing a single `SSACFG` region";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..f554147814a75 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..62061a90c0d2e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/HoistingContainerOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -162,7 +163,8 @@ def ForOp : SCF_Op<"for",
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
        SingleBlockImplicitTerminator<"scf::YieldOp">,
-       RecursiveMemoryEffects]> {
+       RecursiveMemoryEffects,
+       DeclareOpInterfaceMethods<HoistingContainerOpInterface>]> {
   let summary = "for operation";
   let description = [{
     The `scf.for` operation represents a loop taking 3 SSA value as operands
@@ -986,6 +988,7 @@ def WhileOp : SCF_Op<"while",
         ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getRegionIterArgs", "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<HoistingContainerOpInterface>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index eb96a68861116..7acefb64f61ab 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_interface(ControlFlowInterfaces)
 add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(FunctionInterfaces)
+add_mlir_interface(HoistingContainerOpInterface)
 add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferStridedMetadataInterface)
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
new file mode 100644
index 0000000000000..6953f8b2138a3
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
@@ -0,0 +1,13 @@
+
+
+#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h.inc"
+
+namespace mlir {
+bool canContainHoistedOps(Operation *op);
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
new file mode 100644
index 0000000000000..1f1c9994f09c1
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
@@ -0,0 +1,41 @@
+//===- HoistingContainerOpInterface.td - Interface Decl. -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the HoistingContainerOpInterface. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def HoistingContainerOpInterface : OpInterface<"HoistingContainerOpInterface"> {
+  let description = [{
+    This interface models whether an operation's regions are capable of 
+    acting as a container for operations hoisted from nested regions.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this operation's regions can accommodate operations 
+        hoisted from its nested scopes.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"canContainHoistedOps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return true;
+      }]
+    >
+  ];
+}
+
+#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Func/IR/CMakeLists.txt b/mlir/lib/Dialect/Func/IR/CMakeLists.txt
index 329301c6fbafd..c748fdf2b57f0 100644
--- a/mlir/lib/Dialect/Func/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRFuncDialect
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRFunctionInterfaces
+  MLIRHoistingContainerOpInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..bbf27027f37b4 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRControlFlowDialect
   MLIRDialectUtils
   MLIRFunctionInterfaces
+  MLIRHoistingContainerOpInterface
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index ad3e2b61be418..1d73e5d2c6912 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
   DestinationStyleOpInterface.cpp
   FunctionImplementation.cpp
   FunctionInterfaces.cpp
+  HoistingContainerOpInterface.cpp
   IndexingMapOpInterface.cpp
   InferIntRangeInterface.cpp
   InferStridedMetadataInterface.cpp
@@ -64,7 +65,7 @@ add_mlir_library(MLIRFunctionInterfaces
   MLIRCallInterfaces
   MLIRIR
 )
-
+add_mlir_interface_library(HoistingContainerOpInterface)
 add_mlir_interface_library(IndexingMapOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
 
diff --git a/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp b/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
new file mode 100644
index 0000000000000..33801c6509ad2
--- /dev/null
+++ b/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
@@ -0,0 +1,21 @@
+//===- HoistingContainerOpInterface.cpp -- Hoisting Container Op Interface -==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/HoistingContainerOpInterface.cpp.inc"
+} // namespace mlir
+
+bool mlir::canContainHoistedOps(Operation *op) {
+  if (auto containerOp = dyn_cast<HoistingContainerOpInterface>(op))
+    return containerOp.canContainHoistedOps();
+  return false;
+}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 8907724627386..1a2cd72691a79 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_library(MLIRTransforms
   LINK_LIBS PUBLIC
   MLIRAnalysis
   MLIRFunctionInterfaces
+  MLIRHoistingContainerOpInterface
   MLIRLoopLikeInterface
   MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 8eaac308755fd..e22166c549048 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -15,12 +15,15 @@
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/RecyclingAllocator.h"
 #include <deque>
 
@@ -29,6 +32,7 @@ namespace mlir {
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
+#define DEBUG_TYPE "cse"
 using namespace mlir;
 
 namespace {
@@ -101,13 +105,17 @@ class CSEDriver {
 
   /// Attempt to eliminate a redundant operation. Returns success if the
   /// operation was marked for removal, failure otherwise.
-  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+  LogicalResult simplifyOperation(ScopedMapTy &knownValues,
+                                  ScopedMapTy &knownPureOps, Operation *op,
                                   bool hasSSADominance);
-  void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
-  void simplifyRegion(ScopedMapTy &knownValues, Region &region);
+  void simplifyBlock(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+                     Block *bb, bool hasSSADominance);
+  void simplifyRegion(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+                      Region &region);
 
   void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                             Operation *existing, bool hasSSADominance);
+  LogicalResult hoistPureOp(Operation *existing, Operation *op);
 
   /// Check if there is side-effecting operations other than the given effect
   /// between the two operations.
@@ -117,7 +125,7 @@ class CSEDriver {
   RewriterBase &rewriter;
 
   /// Operations marked as dead and to be erased.
-  std::vector<Operation *> opsToErase;
+  SmallVector<Operation *> opsToErase;
   DominanceInfo *domInfo = nullptr;
   MemEffectsCache memEffectsCache;
 
@@ -127,6 +135,69 @@ class CSEDriver {
 };
 } // namespace
 
+static bool isBlockCrossIsIsolatedFromAbove(DominanceInfo *dominate, Block *a,
+                                            Block *b) {
+  if (a == b)
+    return false;
+  if (a->getParent() == b->getParent())
+    return false;
+  if (dominate->dominates(b, a))
+    std::swap(b, a);
+  while (b && b->getParentOp()) {
+    Operation *parnetOp = b->getParentOp();
+    if (parnetOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>())
+      return true;
+    b = parnetOp->getBlock();
+    if (b == a)
+      return false;
+  }
+  return false;
+}
+
+/// Hoist the pure ops to the location of the Nearest Common Dominator
+LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
+  Block *ancestorBlock =
+      domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock());
+  if (!ancestorBlock) {
+    LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+           << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << "failed";
+    return failure();
+  }
+
+  if (ancestorBlock->getParent() != existing->getParentRegion() &&
+      !canContainHoistedOps(ancestorBlock->getParentOp()))
+    return failure();
+
+  if (isBlockCrossIsIsolatedFromAbove(domInfo, ancestorBlock,
+                                      existing->getBlock()))
+    return failure();
+
+  Operation *insertPoint = nullptr;
+  for (Value operand : op->getOperands()) {
+    if (domInfo->properlyDominates(operand, &ancestorBlock->front()))
+      continue;
+    if (!insertPoint) {
+      insertPoint = operand.getDefiningOp();
+    } else {
+      insertPoint = domInfo->dominates(insertPoint, operand.getDefiningOp())
+                        ? operand.getDefiningOp()
+                        : insertPoint;
+    }
+  }
+  if (!insertPoint) {
+    rewriter.moveOpBefore(existing, ancestorBlock, ancestorBlock->begin());
+    rewriter.moveOpAfter(op, existing);
+  } else {
+    rewriter.moveOpAfter(existing, insertPoint);
+    rewriter.moveOpAfter(op, existing);
+  }
+  LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+         << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+         << "success";
+  return success();
+}
+
 void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                                      Operation *existing,
                                      bool hasSSADominance) {
@@ -141,6 +212,15 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
       rewriteListener->notifyOperationReplaced(op, existing);
     // Replace all uses, but do not remove the operation yet. This does not
     // notify the listener because the original op is not erased.
+    if (!domInfo->properlyDominates(existing, op)) {
+      if (failed(hoistPureOp(existing, op)))
+        return;
+    }
+    LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " with "
+           << OpWithFlags(existing, OpPrintingFlags().skipRegions());
+    LDBG() << "add " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " to opsToErase";
     rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
     opsToErase.push_back(op);
   } else {
@@ -155,14 +235,25 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
         if (all_of(v.getUses(), wasVisited))
           rewriteListener->notifyOperationReplaced(op, existing);
 
+    if (!domInfo->properlyDominates(existing, op)) {
+      if (failed(hoistPureOp(existing, op)))
+        return;
+    }
     // Replace all uses, but do not remove the operation yet. This does not
     // notify the listener because the original op is not erased.
+    LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " with "
+           << OpWithFlags(existing, OpPrintingFlags().skipRegions());
     rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
                                wasVisited);
 
     // There may be some remaining uses of the operation.
-    if (op->use_empty())
+    if (op->use_empty()) {
+      LDBG() << "use_empty, add "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions())
+             << " to opsToErase";
       opsToErase.push_back(op);
+    }
   }
 
   // If the existing operation has an unknown location and the current
@@ -222,8 +313,11 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
 
 /// Attempt to eliminate a redundant operation.
 LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
+                                           ScopedMapTy &knownPureOps,
                                            Operation *op,
                                            bool hasSSADominance) {
+  LDBG() << "visit operation: "
+         << OpWithFlags(op, OpPrintingFlags().skipRegions());
   // Don't simplify terminator operations.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return failure();
@@ -261,7 +355,27 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
         return success();
       }
     }
-    knownValues.insert(op, op);
+    if (auto *existing = knownPureOps.lookup(op)) {
+      if (existing->getBlock() == op->getBlock() &&
+          !hasOtherSideEffectingOpInBetween(existing, op)) {
+        // The operation that can be deleted has been reach with no
+        // side-effecting operations in between the existing operation and
+        // this one so we can remove the duplicate.
+        replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+        return success();
+      }
+    }
+
+    if (mlir::isPure(op)) {
+      LDBG() << "insert op: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions())
+             << "to pureMap";
+      knownPureOps.insert(op, op);
+    } else {
+      LDBG() << "insert op: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "to map";
+      knownValues.insert(op, op);
+    }
     return failure();
   }
 
@@ -272,14 +386,31 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
     return success();
   }
 
-  // Otherwise, we add this operation to the known values map.
-  knownValues.insert(op, op);
+  if (auto *existing = knownPureOps.lookup(op)) {
+    replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+    ++numCSE;
+    return success();
+  }
+
+  if (mlir::isPure(op)) {
+    LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << "to pureMap";
+    knownPureOps.insert(op, op);
+  } else {
+    // Otherwise, we add this operation to the known values map.
+    LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << "to map";
+    knownValues.insert(op, op);
+  }
   return failure();
 }
 
-void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
+                              ScopedMapTy &knownPureOps, Block *bb,
                               bool hasSSADominance) {
-  for (auto &op : *bb) {
+  LDBG() << "visit block #" << bb->computeBlockNumber() << " of "
+         << OpWithFlags(bb->getParentOp(), OpPrintingFlags().skipRegions());
+  for (auto &op : llvm::make_early_inc_range(*bb)) {
     // Most operations don't have regions, so fast path that case.
     if (op.getNumRegions() != 0) {
       // If this operation is isolated above, we can't process nested regions
@@ -287,34 +418,45 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
       // implicit captures in explicit capture only regions.
       if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
         ScopedMapTy nestedKnownValues;
+        ScopedMapTy nestedKnownPureOps;
+        ScopedMapTy::ScopeTy scope(nestedKnownValues);
+        ScopedMapTy::ScopeTy pureScope(nestedKnownPureOps);
         for (auto &region : op.getRegions())
-          simplifyRegion(nestedKnownValues, region);
+          simplifyRegion(nestedKnownValues, nestedKnownPureOps, region);
       } else {
         // Othe...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/180556


More information about the Mlir-commits mailing list