[Mlir-commits] [flang] [mlir] [mlir][CSE] Introduce hoist-pure-ops logic to CSE pass (PR #180556)
lonely eagle
llvmlistbot at llvm.org
Mon Mar 2 03:54:30 PST 2026
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/180556
>From d2c4f1ef8166c1a8449b1b4537e05f7c5665dcf4 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sun, 8 Feb 2026 06:27:20 +0000
Subject: [PATCH 1/7] add hoist pure logic to CSE.
---
mlir/lib/Transforms/CSE.cpp | 159 +++++++++++--
.../ArmSMEToLLVM/tile-spills-and-fills.mlir | 12 +-
...ot-bufferize-empty-tensor-elimination.mlir | 15 +-
.../Linalg/matmul-shared-memory-padding.mlir | 3 +-
.../test/Dialect/Linalg/transform-op-pad.mlir | 2 +-
.../SparseTensor/sparse_fill_zero.mlir | 211 +++++++++---------
.../sparse_kernels_to_iterator.mlir | 118 +++++-----
.../SparseTensor/sparse_vector_index.mlir | 85 ++++---
mlir/test/Transforms/cse.mlir | 24 +-
9 files changed, 363 insertions(+), 266 deletions(-)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 48a20655b2806..6cd9329e77fc8 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -15,12 +15,14 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/RecyclingAllocator.h"
#include <deque>
@@ -29,6 +31,7 @@ namespace mlir {
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "cse"
using namespace mlir;
namespace {
@@ -101,13 +104,17 @@ class CSEDriver {
/// Attempt to eliminate a redundant operation. Returns success if the
/// operation was marked for removal, failure otherwise.
- LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+ LogicalResult simplifyOperation(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps, Operation *op,
bool hasSSADominance);
- void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
- void simplifyRegion(ScopedMapTy &knownValues, Region &region);
+ void simplifyBlock(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+ Block *bb, bool hasSSADominance);
+ void simplifyRegion(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+ Region &region);
void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing, bool hasSSADominance);
+ LogicalResult hoistPureOp(Operation *existing, Operation *op); // Move `existing`/`op` to their nearest common dominator.
/// Check if there is side-effecting operations other than the given effect
/// between the two operations.
@@ -117,7 +124,7 @@ class CSEDriver {
RewriterBase &rewriter;
/// Operations marked as dead and to be erased.
- std::vector<Operation *> opsToErase;
+ SmallVector<Operation *> opsToErase;
DominanceInfo *domInfo = nullptr;
MemEffectsCache memEffectsCache;
@@ -127,6 +134,42 @@ class CSEDriver {
};
} // namespace
+/// Hoist `existing` and its duplicate `op` to the nearest common dominator
+/// block, after the last operand defined there; fails if that is illegal.
+LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
+  auto fail = [&]() -> LogicalResult {
+    LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+           << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " failed";
+    return failure();
+  };
+  Block *ancestorBlock =
+      domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock());
+  if (!ancestorBlock)
+    return fail();
+  Operation *insertPoint = nullptr;
+  for (Value operand : op->getOperands()) {
+    if (domInfo->properlyDominates(operand, &ancestorBlock->front()))
+      continue;
+    // A block argument, or a value defined outside `ancestorBlock`, leaves no
+    // legal position to hoist to: bail out rather than create use-before-def.
+    Operation *defOp = operand.getDefiningOp();
+    if (!defOp || defOp->getBlock() != ancestorBlock)
+      return fail();
+    if (!insertPoint || insertPoint->isBeforeInBlock(defOp))
+      insertPoint = defOp;
+  }
+  if (insertPoint)
+    rewriter.moveOpAfter(existing, insertPoint);
+  else
+    rewriter.moveOpBefore(existing, ancestorBlock, ancestorBlock->begin());
+  rewriter.moveOpAfter(op, existing);
+  LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+         << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+         << " success";
+  return success();
+}
+
void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing,
bool hasSSADominance) {
@@ -141,6 +184,15 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
rewriteListener->notifyOperationReplaced(op, existing);
// Replace all uses, but do not remove the operation yet. This does not
// notify the listener because the original op is not erased.
+ if (!domInfo->properlyDominates(existing, op)) { // `existing` must dominate `op` before its results can replace op's.
+ if (failed(hoistPureOp(existing, op)))
+ return; // NOTE(review): the listener may already have been notified above that `op` is replaced, yet `op` is kept; consider hoisting before notifying.
+ }
+ LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " with "
+ << OpWithFlags(existing, OpPrintingFlags().skipRegions());
+ LDBG() << "add " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to opsToErase";
rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
opsToErase.push_back(op);
} else {
@@ -155,14 +207,25 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
if (all_of(v.getUses(), wasVisited))
rewriteListener->notifyOperationReplaced(op, existing);
+ if (!domInfo->properlyDominates(existing, op)) { // Same dominance requirement as in the SSA-dominance branch.
+ if (failed(hoistPureOp(existing, op)))
+ return; // NOTE(review): as above, a notification may already have been sent for `op`; this early return also skips the location update below.
+ }
// Replace all uses, but do not remove the operation yet. This does not
// notify the listener because the original op is not erased.
+ LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " with "
+ << OpWithFlags(existing, OpPrintingFlags().skipRegions());
rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
wasVisited);
// There may be some remaining uses of the operation.
- if (op->use_empty())
+ if (op->use_empty()) {
+ LDBG() << "use_empty, add "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to opsToErase";
opsToErase.push_back(op);
+ }
}
// If the existing operation has an unknown location and the current
@@ -222,8 +285,11 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
/// Attempt to eliminate a redundant operation.
LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps,
Operation *op,
bool hasSSADominance) {
+ LDBG() << "visit operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Don't simplify terminator operations.
if (op->hasTrait<OpTrait::IsTerminator>())
return failure();
@@ -261,7 +327,27 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
return success();
}
}
- knownValues.insert(op, op);
+ if (auto *existing = knownPureOps.lookup(op)) {
+ if (existing->getBlock() == op->getBlock() &&
+ !hasOtherSideEffectingOpInBetween(existing, op)) {
+ // The operation that can be deleted has been reached with no
+ // side-effecting operations in between the existing operation and
+ // this one so we can remove the duplicate.
+ replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+ return success();
+ }
+ }
+
+ if (mlir::isPure(op)) {
+ LDBG() << "insert op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to pureMap";
+ knownPureOps.insert(op, op);
+ } else {
+ LDBG() << "insert op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions()) << " to map";
+ knownValues.insert(op, op);
+ }
return failure();
}
@@ -272,14 +358,36 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
return success();
}
- // Otherwise, we add this operation to the known values map.
- knownValues.insert(op, op);
+  if (auto *existing = knownPureOps.lookup(op)) {
+    // Only CSE once `existing` dominates `op`, hoisting it when needed. If
+    // the hoist is impossible, fall through and record `op` as known instead.
+    if (domInfo->properlyDominates(existing, op) ||
+        succeeded(hoistPureOp(existing, op))) {
+      replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+      ++numCSE;
+      return success();
+    }
+  }
+
+ if (mlir::isPure(op)) {
+ LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to pureMap";
+ knownPureOps.insert(op, op);
+ } else {
+ // Otherwise, we add this operation to the known values map.
+ LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to map";
+ knownValues.insert(op, op);
+ }
return failure();
}
-void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps, Block *bb,
bool hasSSADominance) {
- for (auto &op : *bb) {
+ LDBG() << "visit block #" << bb->computeBlockNumber() << " of "
+ << OpWithFlags(bb->getParentOp(), OpPrintingFlags().skipRegions());
+ for (auto &op : llvm::make_early_inc_range(*bb)) {
// Most operations don't have regions, so fast path that case.
if (op.getNumRegions() != 0) {
// If this operation is isolated above, we can't process nested regions
@@ -287,34 +390,45 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
// implicit captures in explicit capture only regions.
if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
ScopedMapTy nestedKnownValues;
+ ScopedMapTy nestedKnownPureOps;
+ ScopedMapTy::ScopeTy scope(nestedKnownValues);
+ ScopedMapTy::ScopeTy pureScope(nestedKnownPureOps);
for (auto &region : op.getRegions())
- simplifyRegion(nestedKnownValues, region);
+ simplifyRegion(nestedKnownValues, nestedKnownPureOps, region);
} else {
// Otherwise, process nested regions normally.
for (auto &region : op.getRegions())
- simplifyRegion(knownValues, region);
+ simplifyRegion(knownValues, knownPureOps, region);
}
}
// If the operation is simplified, we don't process any held regions.
- if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
+ if (succeeded(
+ simplifyOperation(knownValues, knownPureOps, &op, hasSSADominance)))
continue;
}
// Clear the MemoryEffects cache since its usage is by block only.
memEffectsCache.clear();
}
-void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
+void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps, Region &region) {
// If the region is empty there is nothing to do.
if (region.empty())
return;
+ LDBG() << "visit region #" << region.getRegionNumber() << " of "
+ << OpWithFlags(region.getParentOp(), OpPrintingFlags().skipRegions());
+ // Prevent CSE of pure operations across function boundaries.
+ std::optional<ScopedMapTy::ScopeTy> funcPureScope;
+ if (isa<FunctionOpInterface>(region.getParentOp())) {
+ funcPureScope.emplace(knownPureOps);
+ }
bool hasSSADominance = domInfo->hasSSADominance(&region);
-
// If the region only contains one block, then simplify it directly.
if (region.hasOneBlock()) {
ScopedMapTy::ScopeTy scope(knownValues);
- simplifyBlock(knownValues, &region.front(), hasSSADominance);
+ simplifyBlock(knownValues, knownPureOps, &region.front(), hasSSADominance);
return;
}
@@ -342,7 +456,7 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
// Check to see if we need to process this node.
if (!currentNode->processed) {
currentNode->processed = true;
- simplifyBlock(knownValues, currentNode->node->getBlock(),
+ simplifyBlock(knownValues, knownPureOps, currentNode->node->getBlock(),
hasSSADominance);
}
@@ -361,9 +475,14 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
void CSEDriver::simplify(Operation *op, bool *changed) {
/// Simplify all regions.
- ScopedMapTy knownValues;
- for (auto &region : op->getRegions())
- simplifyRegion(knownValues, region);
+ {
+ ScopedMapTy knownValues;
+ ScopedMapTy knownPureOps;
+ ScopedMapTy::ScopeTy scope(knownPureOps);
+ for (auto &region : op->getRegions()) {
+ simplifyRegion(knownValues, knownPureOps, region);
+ }
+ }
/// Erase any operations that were marked as dead during simplification.
for (auto *op : opsToErase)
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 2a183cb4d056a..7fd66035a4140 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -82,6 +82,8 @@ func.func @use_too_many_tiles() {
// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]])
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
+// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
//
// AFTER-LLVM-LOWERING-NOT: scf.for
@@ -104,8 +106,6 @@ func.func @use_too_many_tiles() {
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
@@ -122,8 +122,6 @@ func.func @use_too_many_tiles() {
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
@@ -156,6 +154,8 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
// AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
+// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
//
/// 1. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
@@ -163,8 +163,6 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
// Read ZA tile slice -> vector
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
@@ -182,8 +180,6 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
/// Read ZA tile slice -> vector
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 3929f5be3b4ef..8dc6364fddb2e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -199,27 +199,20 @@ func.func @eleminate_multiple_ops(%t: tensor<?xf32> {bufferization.buffer_layout
{
%cst1 = arith.constant 0.0: f32
%cst2 = arith.constant 1.0: f32
-
- // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref
+ // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+ // CHECK: scf.if %{{.*}}
%if = scf.if %c -> tensor<?xf32> {
- // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
%a1 = tensor.empty(%sz) : tensor<?xf32>
// CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
%f1 = linalg.fill ins(%cst1 : f32) outs(%a1 : tensor<?xf32>) -> tensor<?xf32>
- // CHECK: scf.yield %[[T_SUBVIEW_1]]
scf.yield %f1 : tensor<?xf32>
} else {
- // CHECK: %[[T_SUBVIEW_2:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
%a2 = tensor.empty(%sz) : tensor<?xf32>
- // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref<?xf32
+ // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
%f2 = linalg.fill ins(%cst2 : f32) outs(%a2 : tensor<?xf32>) -> tensor<?xf32>
- // CHECK: scf.yield %[[T_SUBVIEW_2]]
scf.yield %f2 : tensor<?xf32>
}
-
- // Self-copy could canonicalize away later.
- // CHECK: %[[T_SUBVIEW_3:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
- // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]]
+ // CHECK: return %[[FUNC_ARG]]
%r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
return %r1: tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
index 6cab25b50460d..8aa59b2bf0f5f 100644
--- a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
+++ b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
@@ -5,14 +5,13 @@
// CHECK-NOT: memref.copy
// CHECK: linalg.fill
// CHECK: scf.for
+// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1>
// CHECK: memref.alloc() : memref<128x16xf32, 3>
// CHECK: scf.forall
-// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1>
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: memref.alloc() : memref<16x128xf32, 3>
// CHECK: scf.forall
-// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1>
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: memref.alloc() : memref<128x128xf32, 3>
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 51bf4a23406d4..9b6716f09df37 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -315,7 +315,7 @@ module attributes {transform.with_named_sequence} {
// Test dynamic padding using `use_prescribed_tensor_shapes`
-// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)>
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)>
// CHECK: @use_prescribed_tensor_shapes
// CHECK: (%[[ARG0:.*]]: tensor<?x12xf32>, %[[ARG1:.*]]: tensor<12x?xf32>
func.func @use_prescribed_tensor_shapes(%arg0: tensor<?x12xf32>,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index c26ba56347299..f6aa91ad2b10f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -3,118 +3,117 @@
#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
// CHECK-LABEL: func.func @fill_zero_after_alloc(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) -> !llvm.ptr {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 262144 : i64
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[ZERO]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
-// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
-// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
-// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
-// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<300xi1> to memref<?xi1>
-// CHECK: %[[VAL_24:.*]] = memref.alloc() : memref<300xindex>
-// CHECK: %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<300xindex> to memref<?xindex>
-// CHECK: linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_20]] : memref<300xf64>)
-// CHECK: linalg.fill ins(%[[VAL_7]] : i1) outs(%[[VAL_22]] : memref<300xi1>)
-// CHECK-DAG: %[[VAL_26:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_27:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_28:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_29:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
-// CHECK-DAG: %[[VAL_31:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_32:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_33:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_34:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr) -> memref<?xf64>
-// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_36]] to %[[VAL_37]] step %[[VAL_6]] {
-// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_38]], %[[VAL_6]] : index
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: %[[VAL_45:.*]]:3 = scf.while (%[[VAL_46:.*]] = %[[VAL_40]], %[[VAL_47:.*]] = %[[VAL_43]], %[[VAL_48:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
-// CHECK: %[[VAL_49:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_42]] : index
-// CHECK: %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_44]] : index
-// CHECK: %[[VAL_51:.*]] = arith.andi %[[VAL_49]], %[[VAL_50]] : i1
-// CHECK: scf.condition(%[[VAL_51]]) %[[VAL_46]], %[[VAL_47]], %[[VAL_48]] : index, index, index
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr) -> !llvm.ptr {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i32
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant true
+// CHECK: %[[CONSTANT_4:.*]] = arith.constant false
+// CHECK: %[[CONSTANT_5:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_6:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_7:.*]] = arith.constant 100 : index
+// CHECK: %[[CONSTANT_8:.*]] = arith.constant 300 : index
+// CHECK: %[[CONSTANT_9:.*]] = arith.constant 262144 : i64
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<2xi64>
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOCA_0]] : memref<2xi64> to memref<?xi64>
+// CHECK: memref.store %[[CONSTANT_9]], %[[ALLOCA_0]]{{\[}}%[[CONSTANT_6]]] : memref<2xi64>
+// CHECK: memref.store %[[CONSTANT_9]], %[[ALLOCA_0]]{{\[}}%[[CONSTANT_5]]] : memref<2xi64>
+// CHECK: %[[ALLOCA_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOCA_1]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[CONSTANT_7]], %[[ALLOCA_1]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK: memref.store %[[CONSTANT_8]], %[[ALLOCA_1]]{{\[}}%[[CONSTANT_5]]] : memref<2xindex>
+// CHECK: %[[ALLOCA_2:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[CAST_2:.*]] = memref.cast %[[ALLOCA_2]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[CONSTANT_6]], %[[ALLOCA_2]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK: memref.store %[[CONSTANT_5]], %[[ALLOCA_2]]{{\[}}%[[CONSTANT_5]]] : memref<2xindex>
+// CHECK: %[[VAL_0:.*]] = call @newSparseTensor(%[[CAST_1]], %[[CAST_1]], %[[CAST_0]], %[[CAST_2]], %[[CAST_2]], %[[CONSTANT_2]], %[[CONSTANT_2]], %[[CONSTANT_1]], %[[CONSTANT_2]], %[[MLIR_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<300xf64>
+// CHECK: %[[CAST_3:.*]] = memref.cast %[[ALLOC_0]] : memref<300xf64> to memref<?xf64>
+// CHECK: %[[ALLOC_1:.*]] = memref.alloc() : memref<300xi1>
+// CHECK: %[[CAST_4:.*]] = memref.cast %[[ALLOC_1]] : memref<300xi1> to memref<?xi1>
+// CHECK: %[[ALLOC_2:.*]] = memref.alloc() : memref<300xindex>
+// CHECK: %[[CAST_5:.*]] = memref.cast %[[ALLOC_2]] : memref<300xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[CONSTANT_0]] : f64) outs(%[[ALLOC_0]] : memref<300xf64>)
+// CHECK: linalg.fill ins(%[[CONSTANT_4]] : i1) outs(%[[ALLOC_1]] : memref<300xi1>)
+// CHECK: %[[VAL_1:.*]] = call @sparseValuesF64(%[[ARG0]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK: %[[VAL_2:.*]] = call @sparseValuesF64(%[[ARG1]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK: %[[VAL_3:.*]] = call @sparsePositions0(%[[ARG0]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_4:.*]] = call @sparseCoordinates0(%[[ARG0]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = call @sparsePositions0(%[[ARG0]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = call @sparseCoordinates0(%[[ARG0]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = call @sparsePositions0(%[[ARG1]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = call @sparseCoordinates0(%[[ARG1]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = call @sparsePositions0(%[[ARG1]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = call @sparseCoordinates0(%[[ARG1]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[VAL_3]]{{\[}}%[[CONSTANT_6]]] : memref<?xindex>
+// CHECK: %[[LOAD_1:.*]] = memref.load %[[VAL_3]]{{\[}}%[[CONSTANT_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[LOAD_0]] to %[[LOAD_1]] step %[[CONSTANT_5]] {
+// CHECK: %[[LOAD_2:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[LOAD_3:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_11]], %[[CONSTANT_5]] : index
+// CHECK: %[[LOAD_4:.*]] = memref.load %[[VAL_5]]{{\[}}%[[ADDI_0]]] : memref<?xindex>
+// CHECK: %[[LOAD_5:.*]] = memref.load %[[VAL_7]]{{\[}}%[[CONSTANT_6]]] : memref<?xindex>
+// CHECK: %[[LOAD_6:.*]] = memref.load %[[VAL_7]]{{\[}}%[[CONSTANT_5]]] : memref<?xindex>
+// CHECK: %[[WHILE_0:.*]]:3 = scf.while (%[[VAL_12:.*]] = %[[LOAD_3]], %[[VAL_13:.*]] = %[[LOAD_5]], %[[VAL_14:.*]] = %[[CONSTANT_6]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_12]], %[[LOAD_4]] : index
+// CHECK: %[[CMPI_1:.*]] = arith.cmpi ult, %[[VAL_13]], %[[LOAD_6]] : index
+// CHECK: %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[CMPI_1]] : i1
+// CHECK: scf.condition(%[[ANDI_0]]) %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : index, index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_52:.*]]: index, %[[VAL_53:.*]]: index, %[[VAL_54:.*]]: index):
-// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_52]]] : memref<?xindex>
-// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// CHECK: %[[VAL_57:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_55]] : index
-// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : index
-// CHECK: %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_58]] : index
-// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_58]] : index
-// CHECK: %[[VAL_61:.*]] = arith.andi %[[VAL_59]], %[[VAL_60]] : i1
-// CHECK: %[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (index) {
-// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_52]]] : memref<?xf64>
-// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
-// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_65]]] : memref<?xindex>
-// CHECK: %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_6]] iter_args(%[[VAL_69:.*]] = %[[VAL_54]]) -> (index) {
-// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_68]]] : memref<?xindex>
-// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
-// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_68]]] : memref<?xf64>
-// CHECK: %[[VAL_73:.*]] = arith.mulf %[[VAL_63]], %[[VAL_72]] : f64
-// CHECK: %[[VAL_74:.*]] = arith.addf %[[VAL_71]], %[[VAL_73]] : f64
-// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
-// CHECK: %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_75]], %[[VAL_7]] : i1
-// CHECK: %[[VAL_77:.*]] = scf.if %[[VAL_76]] -> (index) {
-// CHECK: memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
-// CHECK: memref.store %[[VAL_70]], %[[VAL_24]]{{\[}}%[[VAL_69]]] : memref<300xindex>
-// CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_69]], %[[VAL_6]] : index
-// CHECK: scf.yield %[[VAL_78]] : index
+// CHECK: ^bb0(%[[VAL_15:.*]]: index, %[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index):
+// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_16]], %[[CONSTANT_5]] : index
+// CHECK: %[[LOAD_7:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[LOAD_8:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[CMPI_2:.*]] = arith.cmpi ult, %[[LOAD_8]], %[[LOAD_7]] : index
+// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_2]], %[[LOAD_8]], %[[LOAD_7]] : index
+// CHECK: %[[CMPI_3:.*]] = arith.cmpi eq, %[[LOAD_7]], %[[SELECT_0]] : index
+// CHECK: %[[CMPI_4:.*]] = arith.cmpi eq, %[[LOAD_8]], %[[SELECT_0]] : index
+// CHECK: %[[ANDI_1:.*]] = arith.andi %[[CMPI_3]], %[[CMPI_4]] : i1
+// CHECK: %[[IF_0:.*]] = scf.if %[[ANDI_1]] -> (index) {
+// CHECK: %[[LOAD_9:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_15]]] : memref<?xf64>
+// CHECK: %[[LOAD_10:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[LOAD_11:.*]] = memref.load %[[VAL_9]]{{\[}}%[[ADDI_1]]] : memref<?xindex>
+// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_18:.*]] = %[[LOAD_10]] to %[[LOAD_11]] step %[[CONSTANT_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_17]]) -> (index) {
+// CHECK: %[[LOAD_12:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[LOAD_13:.*]] = memref.load %[[ALLOC_0]]{{\[}}%[[LOAD_12]]] : memref<300xf64>
+// CHECK: %[[LOAD_14:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_9]], %[[LOAD_14]] : f64
+// CHECK: %[[ADDF_0:.*]] = arith.addf %[[LOAD_13]], %[[MULF_0]] : f64
+// CHECK: %[[LOAD_15:.*]] = memref.load %[[ALLOC_1]]{{\[}}%[[LOAD_12]]] : memref<300xi1>
+// CHECK: %[[CMPI_5:.*]] = arith.cmpi eq, %[[LOAD_15]], %[[CONSTANT_4]] : i1
+// CHECK: %[[IF_1:.*]] = scf.if %[[CMPI_5]] -> (index) {
+// CHECK: memref.store %[[CONSTANT_3]], %[[ALLOC_1]]{{\[}}%[[LOAD_12]]] : memref<300xi1>
+// CHECK: memref.store %[[LOAD_12]], %[[ALLOC_2]]{{\[}}%[[VAL_19]]] : memref<300xindex>
+// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_19]], %[[CONSTANT_5]] : index
+// CHECK: scf.yield %[[ADDI_2]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_69]] : index
+// CHECK: scf.yield %[[VAL_19]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_74]], %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
-// CHECK: scf.yield %[[VAL_77]] : index
-// CHECK: }
-// CHECK: scf.yield %[[VAL_67]] : index
+// CHECK: memref.store %[[ADDF_0]], %[[ALLOC_0]]{{\[}}%[[LOAD_12]]] : memref<300xf64>
+// CHECK: scf.yield %[[IF_1]] : index
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: scf.yield %[[FOR_0]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_54]] : index
+// CHECK: scf.yield %[[VAL_17]] : index
// CHECK: }
-// CHECK: %[[VAL_79:.*]] = arith.addi %[[VAL_52]], %[[VAL_6]] : index
-// CHECK: %[[VAL_80:.*]] = arith.select %[[VAL_59]], %[[VAL_79]], %[[VAL_52]] : index
-// CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
-// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_60]], %[[VAL_81]], %[[VAL_53]] : index
-// CHECK: scf.yield %[[VAL_80]], %[[VAL_82]], %[[VAL_62]] : index, index, index
+// CHECK: %[[ADDI_3:.*]] = arith.addi %[[VAL_15]], %[[CONSTANT_5]] : index
+// CHECK: %[[SELECT_1:.*]] = arith.select %[[CMPI_3]], %[[ADDI_3]], %[[VAL_15]] : index
+// CHECK: %[[SELECT_2:.*]] = arith.select %[[CMPI_4]], %[[ADDI_1]], %[[VAL_16]] : index
+// CHECK: scf.yield %[[SELECT_1]], %[[SELECT_2]], %[[IF_0]] : index, index, index
// CHECK: }
-// CHECK: %[[VAL_83:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_84:.*]] = memref.cast %[[VAL_83]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_39]], %[[VAL_83]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: func.call @expInsertF64(%[[VAL_19]], %[[VAL_84]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_85:.*]]#2) : (!llvm.ptr, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
-// CHECK: }
-// CHECK: memref.dealloc %[[VAL_20]] : memref<300xf64>
-// CHECK: memref.dealloc %[[VAL_22]] : memref<300xi1>
-// CHECK: memref.dealloc %[[VAL_24]] : memref<300xindex>
-// CHECK: call @endLexInsert(%[[VAL_19]]) : (!llvm.ptr) -> ()
-// CHECK: return %[[VAL_19]] : !llvm.ptr
-// CHECK: }
+// CHECK: %[[ALLOCA_3:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[CAST_6:.*]] = memref.cast %[[ALLOCA_3]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[LOAD_2]], %[[ALLOCA_3]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK: func.call @expInsertF64(%[[VAL_0]], %[[CAST_6]], %[[CAST_3]], %[[CAST_4]], %[[CAST_5]], %[[VAL_20:.*]]#2) : (!llvm.ptr, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: memref.dealloc %[[ALLOC_0]] : memref<300xf64>
+// CHECK: memref.dealloc %[[ALLOC_1]] : memref<300xi1>
+// CHECK: memref.dealloc %[[ALLOC_2]] : memref<300xindex>
+// CHECK: call @endLexInsert(%[[VAL_0]]) : (!llvm.ptr) -> ()
+// CHECK: return %[[VAL_0]] : !llvm.ptr
+// CHECK: }
func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,
%arg1: tensor<200x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR> {
%0 = tensor.empty() : tensor<100x300xf64, #DCSR>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index f6f7f396adab5..c4d86b6b6931f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -79,76 +79,72 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
// ITER: }
// CHECK-LABEL: func.func @add(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
-// CHECK: %[[VAL_6:.*]] = bufferization.to_buffer %[[VAL_5]] : tensor<10xi32> to memref<10xi32>
-// CHECK: linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
-// CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index
-// CHECK: %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index
-// CHECK: %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1
-// CHECK: scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK-SAME: %[[ARG0:.*]]: tensor<10xi32, {{.*}}>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<10xi32, {{.*}}>) -> tensor<10xi32> {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+// CHECK: %[[VALUES_0:.*]] = sparse_tensor.values %[[ARG1]] : tensor<10xi32, {{.*}}> to memref<?xi32>
+// CHECK: %[[VALUES_1:.*]] = sparse_tensor.values %[[ARG0]] : tensor<10xi32, {{.*}}> to memref<?xi32>
+// CHECK: %[[TO_BUFFER_0:.*]] = bufferization.to_buffer %[[CONSTANT_2]] : tensor<10xi32> to memref<10xi32>
+// CHECK: linalg.fill ins(%[[CONSTANT_3]] : i32) outs(%[[TO_BUFFER_0]] : memref<10xi32>)
+// CHECK: %[[POSITIONS_0:.*]] = sparse_tensor.positions %[[ARG0]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK: %[[COORDINATES_0:.*]] = sparse_tensor.coordinates %[[ARG0]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_1]]] : memref<?xindex>
+// CHECK: %[[LOAD_1:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_0]]] : memref<?xindex>
+// CHECK: %[[POSITIONS_1:.*]] = sparse_tensor.positions %[[ARG1]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK: %[[COORDINATES_1:.*]] = sparse_tensor.coordinates %[[ARG1]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK: %[[LOAD_2:.*]] = memref.load %[[POSITIONS_1]]{{\[}}%[[CONSTANT_1]]] : memref<?xindex>
+// CHECK: %[[LOAD_3:.*]] = memref.load %[[POSITIONS_1]]{{\[}}%[[CONSTANT_0]]] : memref<?xindex>
+// CHECK: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[LOAD_0]], %[[VAL_1:.*]] = %[[LOAD_2]]) : (index, index) -> (index, index) {
+// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_0]], %[[LOAD_1]] : index
+// CHECK: %[[CMPI_1:.*]] = arith.cmpi ult, %[[VAL_1]], %[[LOAD_3]] : index
+// CHECK: %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[CMPI_1]] : i1
+// CHECK: scf.condition(%[[ANDI_0]]) %[[VAL_0]], %[[VAL_1]] : index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK: %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index
-// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index
-// CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index
-// CHECK: %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index
-// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
-// CHECK: scf.if %[[VAL_29]] {
-// CHECK: %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref<?xi32>
-// CHECK: %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref<?xi32>
-// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32
-// CHECK: memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
+// CHECK: %[[LOAD_4:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[LOAD_5:.*]] = memref.load %[[COORDINATES_1]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[CMPI_2:.*]] = arith.cmpi ult, %[[LOAD_5]], %[[LOAD_4]] : index
+// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_2]], %[[LOAD_5]], %[[LOAD_4]] : index
+// CHECK: %[[CMPI_3:.*]] = arith.cmpi eq, %[[LOAD_4]], %[[SELECT_0]] : index
+// CHECK: %[[CMPI_4:.*]] = arith.cmpi eq, %[[LOAD_5]], %[[SELECT_0]] : index
+// CHECK: %[[ANDI_1:.*]] = arith.andi %[[CMPI_3]], %[[CMPI_4]] : i1
+// CHECK: scf.if %[[ANDI_1]] {
+// CHECK: %[[LOAD_6:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_2]]] : memref<?xi32>
+// CHECK: %[[LOAD_7:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_3]]] : memref<?xi32>
+// CHECK: %[[ADDI_0:.*]] = arith.addi %[[LOAD_6]], %[[LOAD_7]] : i32
+// CHECK: memref.store %[[ADDI_0]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
// CHECK: } else {
-// CHECK: scf.if %[[VAL_27]] {
-// CHECK: %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref<?xi32>
-// CHECK: memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: scf.if %[[CMPI_3]] {
+// CHECK: %[[LOAD_8:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_2]]] : memref<?xi32>
+// CHECK: memref.store %[[LOAD_8]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
// CHECK: } else {
-// CHECK: scf.if %[[VAL_28]] {
-// CHECK: %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref<?xi32>
-// CHECK: memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: scf.if %[[CMPI_4]] {
+// CHECK: %[[LOAD_9:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_3]]] : memref<?xi32>
+// CHECK: memref.store %[[LOAD_9]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index
-// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index
-// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index
-// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index
-// CHECK: scf.yield %[[VAL_40]], %[[VAL_42]] : index, index
+// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_0]] : index
+// CHECK: %[[SELECT_1:.*]] = arith.select %[[CMPI_3]], %[[ADDI_1]], %[[VAL_2]] : index
+// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_0]] : index
+// CHECK: %[[SELECT_2:.*]] = arith.select %[[CMPI_4]], %[[ADDI_2]], %[[VAL_3]] : index
+// CHECK: scf.yield %[[SELECT_1]], %[[SELECT_2]] : index, index
// CHECK: }
-// CHECK: %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_10]] step %[[VAL_2]] {
-// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref<?xindex>
-// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref<?xi32>
-// CHECK: memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32>
+// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_5:.*]]#0 to %[[LOAD_1]] step %[[CONSTANT_0]] {
+// CHECK: %[[LOAD_10:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[LOAD_11:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_4]]] : memref<?xi32>
+// CHECK: memref.store %[[LOAD_11]], %[[TO_BUFFER_0]]{{\[}}%[[LOAD_10]]] : memref<10xi32>
// CHECK: }
-// CHECK: %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_14]] step %[[VAL_2]] {
-// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref<?xindex>
-// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref<?xi32>
-// CHECK: memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32>
+// CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_7:.*]]#1 to %[[LOAD_3]] step %[[CONSTANT_0]] {
+// CHECK: %[[LOAD_12:.*]] = memref.load %[[COORDINATES_1]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[LOAD_13:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_6]]] : memref<?xi32>
+// CHECK: memref.store %[[LOAD_13]], %[[TO_BUFFER_0]]{{\[}}%[[LOAD_12]]] : memref<10xi32>
// CHECK: }
-// CHECK: %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32>
-// CHECK: return %[[VAL_53]] : tensor<10xi32>
+// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[TO_BUFFER_0]] : memref<10xi32> to tensor<10xi32>
+// CHECK: return %[[TO_TENSOR_0]] : tensor<10xi32>
// CHECK: }
func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
%cst = arith.constant dense<0> : tensor<10xi32>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
index e9587edef4678..165d0835b5824 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
@@ -58,56 +58,55 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8
return %r : tensor<8xi64>
}
-// CHECK-LABEL: func.func @sparse_index_1d_disj(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant true
-// CHECK-DAG: %[[VAL_7:.*]] = tensor.empty() : tensor<8xi64>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_buffer %[[VAL_7]] : tensor<8xi64> to memref<8xi64>
-// CHECK-DAG: linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_11]] : memref<8xi64>)
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
-// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK-LABEL: func.func @sparse_index_1d_disj(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant true
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_4:.*]] = arith.constant 8 : index
+// CHECK: %[[CONSTANT_5:.*]] = arith.constant 0 : i64
+// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<8xi64>
+// CHECK: %[[VALUES_0:.*]] = sparse_tensor.values %[[ARG0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK: %[[TO_BUFFER_0:.*]] = bufferization.to_buffer %[[EMPTY_0]] : tensor<8xi64> to memref<8xi64>
+// CHECK: linalg.fill ins(%[[CONSTANT_5]] : i64) outs(%[[TO_BUFFER_0]] : memref<8xi64>)
+// CHECK: %[[POSITIONS_0:.*]] = sparse_tensor.positions %[[ARG0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: %[[COORDINATES_0:.*]] = sparse_tensor.coordinates %[[ARG0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_3]]] : memref<?xindex>
+// CHECK: %[[LOAD_1:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_2]]] : memref<?xindex>
+// CHECK: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[LOAD_0]], %[[VAL_1:.*]] = %[[CONSTANT_3]]) : (index, index) -> (index, index) {
+// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_0]], %[[LOAD_1]] : index
+// CHECK: scf.condition(%[[CMPI_0]]) %[[VAL_0]], %[[VAL_1]] : index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
-// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
-// CHECK: scf.if %[[VAL_21]] {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xi64>
-// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_19]] : index to i64
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : i64
-// CHECK: memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64>
+// CHECK: ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
+// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_3]] : index to i64
+// CHECK: %[[LOAD_2:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[CMPI_1:.*]] = arith.cmpi eq, %[[LOAD_2]], %[[VAL_3]] : index
+// CHECK: scf.if %[[CMPI_1]] {
+// CHECK: %[[LOAD_3:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_2]]] : memref<?xi64>
+// CHECK: %[[ADDI_0:.*]] = arith.addi %[[LOAD_3]], %[[INDEX_CAST_0]] : i64
+// CHECK: memref.store %[[ADDI_0]], %[[TO_BUFFER_0]]{{\[}}%[[VAL_3]]] : memref<8xi64>
// CHECK: } else {
-// CHECK: scf.if %[[VAL_6]] {
-// CHECK: %[[VAL_25:.*]] = arith.index_cast %[[VAL_19]] : index to i64
-// CHECK: memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64>
+// CHECK: scf.if %[[CONSTANT_1]] {
+// CHECK: memref.store %[[INDEX_CAST_0]], %[[TO_BUFFER_0]]{{\[}}%[[VAL_3]]] : memref<8xi64>
// CHECK: } else {
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_18]], %[[VAL_5]] : index
-// CHECK: %[[VAL_27:.*]] = arith.select %[[VAL_21]], %[[VAL_26]], %[[VAL_18]] : index
-// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
-// CHECK: scf.yield %[[VAL_27]], %[[VAL_28]] : index, index
+// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_2]] : index
+// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_1]], %[[ADDI_1]], %[[VAL_2]] : index
+// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_2]] : index
+// CHECK: scf.yield %[[SELECT_0]], %[[ADDI_2]] : index, index
// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_30:.*]]#1 to %[[VAL_1]] step %[[VAL_1]] {
-// CHECK: %[[VAL_31:.*]] = affine.min #map(%[[VAL_1]], %[[VAL_29]]){{\[}}%[[VAL_1]]]
-// CHECK: %[[VAL_32:.*]] = vector.create_mask %[[VAL_31]] : vector<8xi1>
-// CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_29]] : index to vector<8xindex>
-// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_2]] : vector<8xindex>
-// CHECK: %[[VAL_35:.*]] = arith.index_cast %[[VAL_34]] : vector<8xindex> to vector<8xi64>
-// CHECK: vector.maskedstore %[[VAL_11]]{{\[}}%[[VAL_29]]], %[[VAL_32]], %[[VAL_35]] : memref<8xi64>, vector<8xi1>, vector<8xi64>
+// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_5:.*]]#1 to %[[CONSTANT_4]] step %[[CONSTANT_4]] {
+// CHECK: %[[MIN_0:.*]] = affine.min #{{.*}}(%[[CONSTANT_4]], %[[VAL_4]]){{\[}}%[[CONSTANT_4]]]
+// CHECK: %[[CREATE_MASK_0:.*]] = vector.create_mask %[[MIN_0]] : vector<8xi1>
+// CHECK: %[[BROADCAST_0:.*]] = vector.broadcast %[[VAL_4]] : index to vector<8xindex>
+// CHECK: %[[ADDI_3:.*]] = arith.addi %[[BROADCAST_0]], %[[CONSTANT_0]] : vector<8xindex>
+// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ADDI_3]] : vector<8xindex> to vector<8xi64>
+// CHECK: vector.maskedstore %[[TO_BUFFER_0]]{{\[}}%[[VAL_4]]], %[[CREATE_MASK_0]], %[[INDEX_CAST_1]] : memref<8xi64>, vector<8xi1>, vector<8xi64>
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_36:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<8xi64>
-// CHECK: return %[[VAL_36]] : tensor<8xi64>
+// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[TO_BUFFER_0]] : memref<8xi64> to tensor<8xi64>
+// CHECK: return %[[TO_TENSOR_0]] : tensor<8xi64>
// CHECK: }
func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
%init = tensor.empty() : tensor<8xi64>
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index b447094874d01..81b63b9152cec 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -163,16 +163,14 @@ func.func @down_propagate() -> i32 {
/// Check that operation definitions are NOT propagated up the dominance tree.
// CHECK-LABEL: @up_propagate_for
func.func @up_propagate_for() -> i32 {
+ // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
// CHECK: affine.for {{.*}} = 0 to 4 {
- affine.for %i = 0 to 4 {
- // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+ affine.for %i = 0 to 4 {
// CHECK-NEXT: "foo"(%[[VAR_c1_i32_0]]) : (i32) -> ()
%0 = arith.constant 1 : i32
"foo"(%0) : (i32) -> ()
}
-
- // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
- // CHECK-NEXT: return %[[VAR_c1_i32]] : i32
+ // CHECK: return %[[VAR_c1_i32_0]] : i32
%1 = arith.constant 1 : i32
return %1 : i32
}
@@ -181,7 +179,8 @@ func.func @up_propagate_for() -> i32 {
// CHECK-LABEL: func @up_propagate
func.func @up_propagate() -> i32 {
- // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
+ // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+ // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
%0 = arith.constant 0 : i32
// CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
@@ -191,17 +190,15 @@ func.func @up_propagate() -> i32 {
cf.cond_br %cond, ^bb1, ^bb2(%0 : i32)
^bb1: // CHECK: ^bb1:
- // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
%1 = arith.constant 1 : i32
// CHECK-NEXT: cf.br ^bb2(%[[VAR_c1_i32]] : i32)
cf.br ^bb2(%1 : i32)
^bb2(%arg : i32): // CHECK: ^bb2
- // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
%2 = arith.constant 1 : i32
- // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32_0]] : i32
+ // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32]] : i32
%add = arith.addi %arg, %2 : i32
// CHECK-NEXT: return %[[VAR_1]] : i32
@@ -216,6 +213,7 @@ func.func @up_propagate() -> i32 {
func.func @up_propagate_region() -> i32 {
// CHECK-NEXT: {{.*}} "foo.region"
%0 = "foo.region"() ({
+ // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
// CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
// CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
// CHECK-NEXT: cf.cond_br
@@ -225,15 +223,13 @@ func.func @up_propagate_region() -> i32 {
cf.cond_br %true, ^bb1, ^bb2(%1 : i32)
^bb1: // CHECK: ^bb1:
- // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
// CHECK-NEXT: cf.br
%c1_i32 = arith.constant 1 : i32
cf.br ^bb2(%c1_i32 : i32)
^bb2(%arg : i32): // CHECK: ^bb2(%[[VAR_1:.*]]: i32):
- // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
- // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32_0]] : i32
+ // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32]] : i32
// CHECK-NEXT: "foo.yield"(%[[VAR_2]]) : (i32) -> ()
%c1_i32_0 = arith.constant 1 : i32
@@ -500,8 +496,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
return %r1, %r2 : tensor<5xf32>, tensor<5xf32>
}
// CHECK-LABEL: func @cse_multiple_regions
-// CHECK: %[[if:.*]] = scf.if {{.*}} {
-// CHECK: tensor.empty
+// CHECK: tensor.empty
+// CHECK: %[[if:.*]] = scf.if {{.*}}
// CHECK: scf.yield
// CHECK: } else {
// CHECK: scf.yield
>From 1ba837985d0f5073771f2e2e51b356cee06c06e8 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 10 Feb 2026 01:37:00 +0000
Subject: [PATCH 2/7] fix CI.
---
mlir/lib/Transforms/CSE.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 6cd9329e77fc8..2c046273da36c 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -420,9 +420,9 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
<< OpWithFlags(region.getParentOp(), OpPrintingFlags().skipRegions());
// Prevent CSE of pure operations across function boundaries.
- std::optional<ScopedMapTy::ScopeTy> funcPureScope;
+ std::unique_ptr<ScopedMapTy::ScopeTy> funcPureScope;
if (isa<FunctionOpInterface>(region.getParentOp())) {
- funcPureScope.emplace(knownPureOps);
+ funcPureScope = std::make_unique<ScopedMapTy::ScopeTy>(knownPureOps);
}
bool hasSSADominance = domInfo->hasSSADominance(®ion);
// If the region only contains one block, then simplify it directly.
>From 6b7dadefff964b1d5c412049b5a765f0ddadfd75 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Thu, 12 Feb 2026 08:56:13 +0000
Subject: [PATCH 3/7] add isBlockCrossIsIsolatedFromAbove func.
---
mlir/lib/Transforms/CSE.cpp | 24 +++++++++++++++++++++++-
1 file changed, 23 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 2c046273da36c..3de980e2ada02 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -134,7 +134,26 @@ class CSEDriver {
};
} // namespace
-/// Hoist the pure ops to the location of the Nearest Common Dominator.
+/// Return true if moving from the inner of `a`/`b` (the dominated one) out to
+/// the other's region crosses an op that might be IsolatedFromAbove.
+static bool isBlockCrossIsIsolatedFromAbove(DominanceInfo *dominate, Block *a,
+                                            Block *b) {
+  if (a == b || a->getParent() == b->getParent())
+    return false;
+  if (dominate->dominates(b, a))
+    std::swap(b, a);
+  // Stop once `b` reaches `a`'s region: `a` may be a different block of that
+  // region than the one enclosing `b`, so comparing `b == a` is not enough.
+  while (b && b->getParentOp() && b->getParent() != a->getParent()) {
+    Operation *parentOp = b->getParentOp();
+    if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>())
+      return true;
+    b = parentOp->getBlock();
+  }
+  return false;
+}
+
+/// Hoist the pure ops to the location of the Nearest Common Dominator
LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
Block *ancestorBlock =
domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock());
@@ -144,6 +163,9 @@ LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
<< "failed";
return failure();
}
+ if (isBlockCrossIsIsolatedFromAbove(domInfo, ancestorBlock,
+ existing->getBlock()))
+ return failure();
Operation *insertPoint = nullptr;
for (Value operand : op->getOperands()) {
>From 9ea1a9191accedced9c912db3f7bb67b4cec87c6 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 13 Feb 2026 17:05:38 +0000
Subject: [PATCH 4/7] add HoistingContainerOpInterface.cpp.
---
flang/test/Lower/vector-subscript-io.f90 | 5 +--
mlir/include/mlir/Dialect/Func/IR/FuncOps.h | 1 +
mlir/include/mlir/Dialect/Func/IR/FuncOps.td | 4 +-
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 1 +
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 5 ++-
mlir/include/mlir/Interfaces/CMakeLists.txt | 1 +
.../Interfaces/HoistingContainerOpInterface.h | 19 +++++++++
.../HoistingContainerOpInterface.td | 41 +++++++++++++++++++
mlir/lib/Dialect/Func/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/SCF/IR/CMakeLists.txt | 1 +
mlir/lib/Interfaces/CMakeLists.txt | 3 +-
.../HoistingContainerOpInterface.cpp | 21 ++++++++++
mlir/lib/Transforms/CMakeLists.txt | 1 +
mlir/lib/Transforms/CSE.cpp | 6 +++
14 files changed, 104 insertions(+), 6 deletions(-)
create mode 100644 mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
create mode 100644 mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
create mode 100644 mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
diff --git a/flang/test/Lower/vector-subscript-io.f90 b/flang/test/Lower/vector-subscript-io.f90
index 0f64e99e03a20..f0a2b875d2e62 100644
--- a/flang/test/Lower/vector-subscript-io.f90
+++ b/flang/test/Lower/vector-subscript-io.f90
@@ -541,11 +541,11 @@ subroutine iostat_in_io_loop(k, j, stat)
! CHECK: fir.call @_FortranAioEnableHandlers(%[[VAL_378]], %[[VAL_369]], %[[VAL_370]], %[[VAL_370]], %[[VAL_370]], %[[VAL_370]]) {{.*}}: (!fir.ref<i8>, i1, i1, i1, i1, i1) -> ()
! CHECK: cf.br ^bb1(%[[VAL_371]], %[[VAL_369]] : index, i1)
! CHECK: ^bb1(%[[VAL_380:.*]]: index, %[[VAL_381:.*]]: i1):
+! CHECK: %[[VAL_384:.*]] = fir.convert %[[VAL_380]] : (index) -> i32
! CHECK: %[[VAL_382:.*]] = arith.cmpi sle, %[[VAL_380]], %[[VAL_368]] : index
! CHECK: %[[VAL_383:.*]] = arith.andi %[[VAL_381]], %[[VAL_382]] : i1
! CHECK: cf.cond_br %[[VAL_383]], ^bb2, ^bb7
! CHECK: ^bb2:
-! CHECK: %[[VAL_384:.*]] = fir.convert %[[VAL_380]] : (index) -> i32
! CHECK: fir.store %[[VAL_384]] to %[[VAL_375]] : !fir.ref<i32>
! CHECK: cf.cond_br %[[VAL_381]], ^bb3, ^bb6(%[[VAL_370]] : i1)
! CHECK: ^bb3:
@@ -573,8 +573,7 @@ subroutine iostat_in_io_loop(k, j, stat)
! CHECK: %[[VAL_405:.*]] = arith.addi %[[VAL_380]], %[[VAL_371]] overflow<nsw> : index
! CHECK: cf.br ^bb1(%[[VAL_405]], %[[VAL_404]] : index, i1)
! CHECK: ^bb7:
-! CHECK: %[[VAL_406:.*]] = fir.convert %[[VAL_380]] : (index) -> i32
-! CHECK: fir.store %[[VAL_406]] to %[[VAL_375]] : !fir.ref<i32>
+! CHECK: fir.store %[[VAL_384]] to %[[VAL_375]] : !fir.ref<i32>
! CHECK: %[[VAL_407:.*]] = fir.call @_FortranAioEndIoStatement(%[[VAL_378]]) {{.*}}: (!fir.ref<i8>) -> i32
! CHECK: fir.store %[[VAL_407]] to %[[VAL_408]] : !fir.ref<i32>
! CHECK: return
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
index 5e10a9f50b774..ad5eac754f236 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
@@ -18,6 +18,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 06ce4f16c867d..db01cff2a6937 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -15,6 +15,7 @@ include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/Interfaces/HoistingContainerOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -248,7 +249,8 @@ def ConstantOp : Func_Op<"constant",
//===----------------------------------------------------------------------===//
def FuncOp : Func_Op<"func", [
- AffineScope, AutomaticAllocationScope,
+ AffineScope, AutomaticAllocationScope,
+ DeclareOpInterfaceMethods<HoistingContainerOpInterface>,
FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface
]> {
let summary = "An operation with a name containing a single `SSACFG` region";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..f554147814a75 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -20,6 +20,7 @@
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..62061a90c0d2e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/HoistingContainerOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -162,7 +163,8 @@ def ForOp : SCF_Op<"for",
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
SingleBlockImplicitTerminator<"scf::YieldOp">,
- RecursiveMemoryEffects]> {
+ RecursiveMemoryEffects,
+ DeclareOpInterfaceMethods<HoistingContainerOpInterface>]> {
let summary = "for operation";
let description = [{
The `scf.for` operation represents a loop taking 3 SSA value as operands
@@ -986,6 +988,7 @@ def WhileOp : SCF_Op<"while",
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<HoistingContainerOpInterface>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 3cbc9df05f3d7..ecfdb98cc9f76 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
+add_mlir_interface(HoistingContainerOpInterface)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferStridedMetadataInterface)
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
new file mode 100644
index 0000000000000..f9e08fb2fc2ce
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
@@ -0,0 +1,19 @@
+//===- HoistingContainerOpInterface.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h.inc"
+
+namespace mlir {
+bool canContainHoistedOps(Operation *op);
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
new file mode 100644
index 0000000000000..1f1c9994f09c1
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
@@ -0,0 +1,41 @@
+//===- HoistingContainerOpInterface.td - Interface Decl. -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the HoistingContainerOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def HoistingContainerOpInterface : OpInterface<"HoistingContainerOpInterface"> {
+ let description = [{
+ This interface models whether an operation's regions are capable of
+ acting as a container for operations hoisted from nested regions.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if this operation's regions can accommodate operations
+ hoisted from its nested scopes.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"canContainHoistedOps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return true;
+ }]
+ >
+ ];
+}
+
+#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Func/IR/CMakeLists.txt b/mlir/lib/Dialect/Func/IR/CMakeLists.txt
index 329301c6fbafd..c748fdf2b57f0 100644
--- a/mlir/lib/Dialect/Func/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRFuncDialect
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRFunctionInterfaces
+ MLIRHoistingContainerOpInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..bbf27027f37b4 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRControlFlowDialect
MLIRDialectUtils
MLIRFunctionInterfaces
+ MLIRHoistingContainerOpInterface
MLIRIR
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index ad3e2b61be418..1d73e5d2c6912 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
DestinationStyleOpInterface.cpp
FunctionImplementation.cpp
FunctionInterfaces.cpp
+ HoistingContainerOpInterface.cpp
IndexingMapOpInterface.cpp
InferIntRangeInterface.cpp
InferStridedMetadataInterface.cpp
@@ -64,7 +65,7 @@ add_mlir_library(MLIRFunctionInterfaces
MLIRCallInterfaces
MLIRIR
)
-
+add_mlir_interface_library(HoistingContainerOpInterface)
add_mlir_interface_library(IndexingMapOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
diff --git a/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp b/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
new file mode 100644
index 0000000000000..33801c6509ad2
--- /dev/null
+++ b/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
@@ -0,0 +1,21 @@
+//===- HoistingContainerOpInterface.cpp -- Hoisting Container Op Interface -==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/HoistingContainerOpInterface.cpp.inc"
+} // namespace mlir
+
+bool mlir::canContainHoistedOps(Operation *op) {
+ if (auto containerOp = dyn_cast<HoistingContainerOpInterface>(op))
+ return containerOp.canContainHoistedOps();
+ return false;
+}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 8907724627386..1a2cd72691a79 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_library(MLIRTransforms
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRFunctionInterfaces
+ MLIRHoistingContainerOpInterface
MLIRLoopLikeInterface
MLIRMemOpInterfaces
MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3de980e2ada02..2e36520ed3586 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
@@ -163,6 +164,11 @@ LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
<< "failed";
return failure();
}
+
+ if (ancestorBlock->getParent() != existing->getParentRegion() &&
+ !canContainHoistedOps(ancestorBlock->getParentOp()))
+ return failure();
+
if (isBlockCrossIsIsolatedFromAbove(domInfo, ancestorBlock,
existing->getBlock()))
return failure();
>From 2f02dd090c31ff4c329a1cab593c2141f3ccf16a Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 14 Feb 2026 08:07:47 +0000
Subject: [PATCH 5/7] add comment and fix nit.
---
.../mlir/Interfaces/HoistingContainerOpInterface.h | 10 +++++++---
mlir/lib/Transforms/CSE.cpp | 11 +++++++++--
2 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
index f9e08fb2fc2ce..b6a6addd89173 100644
--- a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
+++ b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
@@ -6,14 +6,18 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
-#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE_H_
+#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/HoistingContainerOpInterface.h.inc"
namespace mlir {
+/// Returns true if the given operation implements HoistingContainerOpInterface
+/// and its implementation allows hosting hoisted operations. Returns false
+/// if the operation does not implement the interface, or if the operation
+/// explicitly disallows hoisting.
bool canContainHoistedOps(Operation *op);
} // namespace mlir
-#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
+#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE_H_
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 2e36520ed3586..a8ffda195e29a 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -135,6 +135,8 @@ class CSEDriver {
};
} // namespace
+/// Returns true if the path between block 'a' and block 'b' in the region
+/// hierarchy crosses an operation with the 'IsIsolatedFromAbove' trait.
static bool isBlockCrossIsIsolatedFromAbove(DominanceInfo *dominate, Block *a,
Block *b) {
if (a == b)
@@ -154,7 +156,7 @@ static bool isBlockCrossIsIsolatedFromAbove(DominanceInfo *dominate, Block *a,
return false;
}
-/// Hoist the pure ops to the location of the Nearest Common Dominator
+/// Hoist the pure ops to the location of the Nearest Common Dominator.
LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
Block *ancestorBlock =
domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock());
@@ -165,6 +167,9 @@ LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
return failure();
}
+ // If the ancestorBlock is in a different region than the existing operation,
+ // we need to check if the parentOp of the ancestorBlock can contain hoisted
+ // ops.
if (ancestorBlock->getParent() != existing->getParentRegion() &&
!canContainHoistedOps(ancestorBlock->getParentOp()))
return failure();
@@ -173,6 +178,7 @@ LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
existing->getBlock()))
return failure();
+ // Find the insertion point based on dominance relationships.
Operation *insertPoint = nullptr;
for (Value operand : op->getOperands()) {
if (domInfo->properlyDominates(operand, &ancestorBlock->front()))
@@ -502,7 +508,8 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
}
void CSEDriver::simplify(Operation *op, bool *changed) {
- /// Simplify all regions.
+ /// Simplify all regions. Added a new scope using curly braces to release the
+ /// knownPureOps scope before deleting the operation.
{
ScopedMapTy knownValues;
ScopedMapTy knownPureOps;
>From b4aa36695ce6c1d78433d97e4a527e06d081b37b Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 27 Feb 2026 03:07:57 +0000
Subject: [PATCH 6/7] remove HoistingContainerOpInterface and add
hoist-pure-ops option.
---
flang/lib/Optimizer/Passes/Pipelines.cpp | 4 +-
mlir/include/mlir/Dialect/Func/IR/FuncOps.h | 1 -
mlir/include/mlir/Dialect/Func/IR/FuncOps.td | 4 +-
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 1 -
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 5 +--
mlir/include/mlir/Interfaces/CMakeLists.txt | 1 -
.../Interfaces/HoistingContainerOpInterface.h | 23 -----------
.../HoistingContainerOpInterface.td | 41 -------------------
mlir/include/mlir/Transforms/Passes.td | 4 ++
mlir/lib/Dialect/Func/IR/CMakeLists.txt | 1 -
mlir/lib/Dialect/SCF/IR/CMakeLists.txt | 1 -
mlir/lib/Interfaces/CMakeLists.txt | 3 +-
.../HoistingContainerOpInterface.cpp | 21 ----------
mlir/lib/Transforms/CMakeLists.txt | 1 -
mlir/lib/Transforms/CSE.cpp | 21 ++++------
mlir/test/Pass/run-reproducer.mlir | 4 +-
mlir/test/Transforms/composite-pass.mlir | 2 +-
mlir/test/python/pass_manager.py | 4 +-
18 files changed, 23 insertions(+), 119 deletions(-)
delete mode 100644 mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
delete mode 100644 mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
delete mode 100644 mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index bb1ad84523c82..a6a721da60bc8 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -273,7 +273,9 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
pm, hlfir::createInlineElementals);
if (optLevel.isOptimizingForSpeed()) {
addCanonicalizerPassWithoutRegionSimplification(pm);
- pm.addPass(mlir::createCSEPass());
+ mlir::CSEOptions options;
+ options.hoistPureOps = false;
+ pm.addPass(mlir::createCSEPass(options));
// Run SimplifyHLFIRIntrinsics pass late after CSE,
// and allow introducing operations with new side effects.
addNestedPassToAllTopLevelOperations<PassConstructor>(pm, []() {
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
index ad5eac754f236..5e10a9f50b774 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h
@@ -18,7 +18,6 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index db01cff2a6937..06ce4f16c867d 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -15,7 +15,6 @@ include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
-include "mlir/Interfaces/HoistingContainerOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -249,8 +248,7 @@ def ConstantOp : Func_Op<"constant",
//===----------------------------------------------------------------------===//
def FuncOp : Func_Op<"func", [
- AffineScope, AutomaticAllocationScope,
- DeclareOpInterfaceMethods<HoistingContainerOpInterface>,
+ AffineScope, AutomaticAllocationScope,
FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface
]> {
let summary = "An operation with a name containing a single `SSACFG` region";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index f554147814a75..e754a04b0903a 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -20,7 +20,6 @@
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
-#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 62061a90c0d2e..a08cf3c95e6ce 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -18,7 +18,6 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
-include "mlir/Interfaces/HoistingContainerOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -163,8 +162,7 @@ def ForOp : SCF_Op<"for",
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
SingleBlockImplicitTerminator<"scf::YieldOp">,
- RecursiveMemoryEffects,
- DeclareOpInterfaceMethods<HoistingContainerOpInterface>]> {
+ RecursiveMemoryEffects]> {
let summary = "for operation";
let description = [{
The `scf.for` operation represents a loop taking 3 SSA value as operands
@@ -988,7 +986,6 @@ def WhileOp : SCF_Op<"while",
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
- DeclareOpInterfaceMethods<HoistingContainerOpInterface>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index ecfdb98cc9f76..3cbc9df05f3d7 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,7 +5,6 @@ add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
-add_mlir_interface(HoistingContainerOpInterface)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferStridedMetadataInterface)
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
deleted file mode 100644
index b6a6addd89173..0000000000000
--- a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.h
+++ /dev/null
@@ -1,23 +0,0 @@
-//===- HoistingContainerOpInterface.h ---------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE_H_
-#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE_H_
-
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Interfaces/HoistingContainerOpInterface.h.inc"
-
-namespace mlir {
-/// Returns true if the given operation implements HoistingContainerOpInterface
-/// and its implementation allows hosting hoisted operations. Returns false
-/// if the operation does not implement the interface, or if the operation
-/// explicitly disallows hoisting.
-bool canContainHoistedOps(Operation *op);
-} // namespace mlir
-
-#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td b/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
deleted file mode 100644
index 1f1c9994f09c1..0000000000000
--- a/mlir/include/mlir/Interfaces/HoistingContainerOpInterface.td
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- HoistingContainerOpInterface.td - Interface Decl. -*- tablegen -*---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This is the definition file for the HoistingContainerOpInterface.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
-#define MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
-
-include "mlir/IR/OpBase.td"
-
-def HoistingContainerOpInterface : OpInterface<"HoistingContainerOpInterface"> {
- let description = [{
- This interface models whether an operation's regions are capable of
- acting as a container for operations hoisted from nested regions.
- }];
- let cppNamespace = "::mlir";
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Returns true if this operation's regions can accommodate operations
- hoisted from its nested scopes.
- }],
- /*retTy=*/"bool",
- /*methodName=*/"canContainHoistedOps",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return true;
- }]
- >
- ];
-}
-
-#endif // MLIR_INTERFACES_HOISTING_CONTAINER_OP_INTERFACE
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 20af90e56ee67..dcaeef144fb65 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -84,6 +84,10 @@ def CSEPass : Pass<"cse"> {
operations. See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination)
for more general details on this optimization.
}];
+ let options = [
+ Option<"hoistPureOps", "hoist-pure-ops", "bool", /*default=*/"true",
+ "Allow hoisting of pure operations out of regions">,
+ ];
let statistics = [
Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">,
Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd">
diff --git a/mlir/lib/Dialect/Func/IR/CMakeLists.txt b/mlir/lib/Dialect/Func/IR/CMakeLists.txt
index c748fdf2b57f0..329301c6fbafd 100644
--- a/mlir/lib/Dialect/Func/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/IR/CMakeLists.txt
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRFuncDialect
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRFunctionInterfaces
- MLIRHoistingContainerOpInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index bbf27027f37b4..b111117410ba3 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRControlFlowDialect
MLIRDialectUtils
MLIRFunctionInterfaces
- MLIRHoistingContainerOpInterface
MLIRIR
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 1d73e5d2c6912..ad3e2b61be418 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,7 +8,6 @@ set(LLVM_OPTIONAL_SOURCES
DestinationStyleOpInterface.cpp
FunctionImplementation.cpp
FunctionInterfaces.cpp
- HoistingContainerOpInterface.cpp
IndexingMapOpInterface.cpp
InferIntRangeInterface.cpp
InferStridedMetadataInterface.cpp
@@ -65,7 +64,7 @@ add_mlir_library(MLIRFunctionInterfaces
MLIRCallInterfaces
MLIRIR
)
-add_mlir_interface_library(HoistingContainerOpInterface)
+
add_mlir_interface_library(IndexingMapOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
diff --git a/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp b/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
deleted file mode 100644
index 33801c6509ad2..0000000000000
--- a/mlir/lib/Interfaces/HoistingContainerOpInterface.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//===- HoistingContainerOpInterface.cpp -- Hoisting Container Op Interface -==//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/HoistingContainerOpInterface.h"
-
-using namespace mlir;
-
-namespace mlir {
-#include "mlir/Interfaces/HoistingContainerOpInterface.cpp.inc"
-} // namespace mlir
-
-bool mlir::canContainHoistedOps(Operation *op) {
- if (auto containerOp = dyn_cast<HoistingContainerOpInterface>(op))
- return containerOp.canContainHoistedOps();
- return false;
-}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 1a2cd72691a79..8907724627386 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -32,7 +32,6 @@ add_mlir_library(MLIRTransforms
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRFunctionInterfaces
- MLIRHoistingContainerOpInterface
MLIRLoopLikeInterface
MLIRMemOpInterfaces
MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index a8ffda195e29a..437dcf8d4e45d 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -16,7 +16,6 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Interfaces/HoistingContainerOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
@@ -63,8 +62,9 @@ namespace {
/// Simple common sub-expression elimination.
class CSEDriver {
public:
- CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
- : rewriter(rewriter), domInfo(domInfo) {}
+ CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo,
+ bool hoistPureOps = true)
+ : rewriter(rewriter), domInfo(domInfo), hoistPureOps(hoistPureOps) {}
/// Simplify all operations within the given op.
void simplify(Operation *op, bool *changed = nullptr);
@@ -127,6 +127,7 @@ class CSEDriver {
/// Operations marked as dead and to be erased.
SmallVector<Operation *> opsToErase;
DominanceInfo *domInfo = nullptr;
+ bool hoistPureOps = true;
MemEffectsCache memEffectsCache;
// Various statistics.
@@ -167,13 +168,6 @@ LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
return failure();
}
- // If the ancestorBlock is in a different region than the existing operation,
- // we need to check if the parentOp of the ancestorBlock can contain hoisted
- // ops.
- if (ancestorBlock->getParent() != existing->getParentRegion() &&
- !canContainHoistedOps(ancestorBlock->getParentOp()))
- return failure();
-
if (isBlockCrossIsIsolatedFromAbove(domInfo, ancestorBlock,
existing->getBlock()))
return failure();
@@ -219,7 +213,7 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
// Replace all uses, but do not remove the operation yet. This does not
// notify the listener because the original op is not erased.
if (!domInfo->properlyDominates(existing, op)) {
- if (failed(hoistPureOp(existing, op)))
+ if (!hoistPureOps || failed(hoistPureOp(existing, op)))
return;
}
LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
@@ -242,7 +236,7 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
rewriteListener->notifyOperationReplaced(op, existing);
if (!domInfo->properlyDominates(existing, op)) {
- if (failed(hoistPureOp(existing, op)))
+ if (!hoistPureOps || failed(hoistPureOp(existing, op)))
return;
}
// Replace all uses, but do not remove the operation yet. This does not
@@ -539,6 +533,7 @@ void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
namespace {
/// CSE pass.
struct CSE : public impl::CSEPassBase<CSE> {
+ using impl::CSEPassBase<CSE>::CSEPassBase;
void runOnOperation() override;
};
} // namespace
@@ -546,7 +541,7 @@ struct CSE : public impl::CSEPassBase<CSE> {
void CSE::runOnOperation() {
// Simplify the IR.
IRRewriter rewriter(&getContext());
- CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
+ CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>(), hoistPureOps);
bool changed = false;
driver.simplify(getOperation(), &changed);
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 256c9e83f30f9..288194dad6275 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -16,11 +16,11 @@ func.func @bar() {
verify_each: true,
// CHECK: builtin.module(
// CHECK-NEXT: func.func(
- // CHECK-NEXT: cse,
+ // CHECK-NEXT: cse{hoist-pure-ops=true},
// CHECK-NEXT: canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}
// CHECK-NEXT: )
// CHECK-NEXT: )
- pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
+ pipeline: "builtin.module(func.func(cse{hoist-pure-ops=1},canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
disable_threading: true
}
}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
index 460cd612cde63..bee7c502f642f 100644
--- a/mlir/test/Transforms/composite-pass.mlir
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -4,7 +4,7 @@
// Ensure the composite pass correctly prints its options.
// PIPELINE: builtin.module(
// PIPELINE-NEXT: composite-fixed-point-pass{max-iterations=10 name=TestCompositePass
-// PIPELINE-SAME: pipeline=canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse}
+// PIPELINE-SAME: pipeline=canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse{hoist-pure-ops=true}}
// CHECK-LABEL: running `TestCompositePass`
// CHECK: running `CanonicalizerPass`
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index a097af92d1f0a..22d5a6e9cfadd 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -115,10 +115,10 @@ def testAdd():
pm = PassManager("any", Context())
# CHECK: pm: 'any()'
log(f"pm: '{pm}'")
- # CHECK: pm: 'any(cse)'
+ # CHECK: pm: 'any(cse{hoist-pure-ops=true})'
pm.add("cse")
log(f"pm: '{pm}'")
- # CHECK: pm: 'any(cse,cse)'
+ # CHECK: pm: 'any(cse{hoist-pure-ops=true},cse{hoist-pure-ops=true})'
pm.add("cse")
log(f"pm: '{pm}'")
>From 86a0dedc1925d3a221e1086fc1eada38dfac1ae1 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 2 Mar 2026 08:21:06 +0000
Subject: [PATCH 7/7] rebase on main and remove unused logic.
---
flang/lib/Optimizer/Passes/Pipelines.cpp | 2 +-
mlir/lib/Transforms/CSE.cpp | 5 -----
2 files changed, 1 insertion(+), 6 deletions(-)
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index a6a721da60bc8..c630fed0d009b 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -273,7 +273,7 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
pm, hlfir::createInlineElementals);
if (optLevel.isOptimizingForSpeed()) {
addCanonicalizerPassWithoutRegionSimplification(pm);
- mlir::CSEOptions options;
+ mlir::CSEPassOptions options;
options.hoistPureOps = false;
pm.addPass(mlir::createCSEPass(options));
// Run SimplifyHLFIRIntrinsics pass late after CSE,
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 437dcf8d4e45d..a1b732c35b4d7 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -447,11 +447,6 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
LDBG() << "visit region #" << region.getRegionNumber() << " of "
<< OpWithFlags(region.getParentOp(), OpPrintingFlags().skipRegions());
- // Prevent CSE of pure operations across function boundaries.
- std::unique_ptr<ScopedMapTy::ScopeTy> funcPureScope;
- if (isa<FunctionOpInterface>(region.getParentOp())) {
- funcPureScope = std::make_unique<ScopedMapTy::ScopeTy>(knownPureOps);
- }
bool hasSSADominance = domInfo->hasSSADominance(®ion);
// If the region only contains one block, then simplify it directly.
if (region.hasOneBlock()) {
More information about the Mlir-commits
mailing list