[Mlir-commits] [mlir] [mlir][analysis] Introduce hoist-pure-ops logic to CSE pass (PR #180556)

lonely eagle llvmlistbot at llvm.org
Thu Feb 12 00:56:42 PST 2026


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/180556

>From 394c2c2d152b8c91d045a2418b77d83c13cb70e3 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sun, 8 Feb 2026 06:27:20 +0000
Subject: [PATCH 1/3] add hoist pure logic to CSE.

---
 mlir/lib/Transforms/CSE.cpp                   | 159 +++++++++++--
 .../ArmSMEToLLVM/tile-spills-and-fills.mlir   |  12 +-
 ...ot-bufferize-empty-tensor-elimination.mlir |  15 +-
 .../Linalg/matmul-shared-memory-padding.mlir  |   3 +-
 .../test/Dialect/Linalg/transform-op-pad.mlir |   2 +-
 .../SparseTensor/sparse_fill_zero.mlir        | 211 +++++++++---------
 .../sparse_kernels_to_iterator.mlir           | 118 +++++-----
 .../SparseTensor/sparse_vector_index.mlir     |  85 ++++---
 mlir/test/Transforms/cse.mlir                 |  24 +-
 9 files changed, 363 insertions(+), 266 deletions(-)

diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 8eaac308755fd..594460e763a3d 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -15,12 +15,14 @@
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/RecyclingAllocator.h"
 #include <deque>
 
@@ -29,6 +31,7 @@ namespace mlir {
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
+#define DEBUG_TYPE "cse"
 using namespace mlir;
 
 namespace {
@@ -101,13 +104,17 @@ class CSEDriver {
 
   /// Attempt to eliminate a redundant operation. Returns success if the
   /// operation was marked for removal, failure otherwise.
-  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+  LogicalResult simplifyOperation(ScopedMapTy &knownValues,
+                                  ScopedMapTy &knownPureOps, Operation *op,
                                   bool hasSSADominance);
-  void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
-  void simplifyRegion(ScopedMapTy &knownValues, Region &region);
+  void simplifyBlock(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+                     Block *bb, bool hasSSADominance);
+  void simplifyRegion(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+                      Region &region);
 
   void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                             Operation *existing, bool hasSSADominance);
+  LogicalResult hoistPureOp(Operation *existing, Operation *op);
 
   /// Check if there is side-effecting operations other than the given effect
   /// between the two operations.
@@ -117,7 +124,7 @@ class CSEDriver {
   RewriterBase &rewriter;
 
   /// Operations marked as dead and to be erased.
-  std::vector<Operation *> opsToErase;
+  SmallVector<Operation *> opsToErase;
   DominanceInfo *domInfo = nullptr;
   MemEffectsCache memEffectsCache;
 
@@ -127,6 +134,42 @@ class CSEDriver {
 };
 } // namespace
 
+/// Move `existing` and its duplicate `op` into their nearest common dominator
+/// block, after every dependency there. Fails (IR untouched) if impossible.
+LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
+  Block *ancestorBlock =
+      domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock());
+  Operation *insertPoint = nullptr;
+  auto visit = [&](Operation *nested) {
+    // Visit operands and region captures; skip values local to `existing`.
+    for (Value v : nested->getOperands()) {
+      if (existing->isAncestor(v.getParentRegion()->getParentOp()) ||
+          domInfo->properlyDominates(v, &ancestorBlock->front()))
+        continue;
+      Operation *def = v.getDefiningOp();
+      if (!def || def->getBlock() != ancestorBlock)
+        return WalkResult::interrupt();
+      if (!insertPoint || insertPoint->isBeforeInBlock(def))
+        insertPoint = def;
+    }
+    return WalkResult::advance();
+  };
+  if (!ancestorBlock || existing->walk(visit).wasInterrupted()) {
+    LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+           << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " failed";
+    return failure();
+  }
+  Block::iterator pos = insertPoint ? std::next(insertPoint->getIterator())
+                                    : ancestorBlock->begin();
+  rewriter.moveOpBefore(existing, ancestorBlock, pos);
+  rewriter.moveOpAfter(op, existing);
+  LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+         << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+         << " succeeded";
+  return success();
+}
+
 void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                                      Operation *existing,
                                      bool hasSSADominance) {
@@ -141,6 +184,15 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
       rewriteListener->notifyOperationReplaced(op, existing);
     // Replace all uses, but do not remove the operation yet. This does not
     // notify the listener because the original op is not erased.
+    if (!domInfo->properlyDominates(existing, op)) {
+      if (failed(hoistPureOp(existing, op)))
+        return;
+    }
+    LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " with "
+           << OpWithFlags(existing, OpPrintingFlags().skipRegions());
+    LDBG() << "add " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " to opsToErase";
     rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
     opsToErase.push_back(op);
   } else {
@@ -155,14 +207,25 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
         if (all_of(v.getUses(), wasVisited))
           rewriteListener->notifyOperationReplaced(op, existing);
 
+    if (!domInfo->properlyDominates(existing, op)) {
+      if (failed(hoistPureOp(existing, op)))
+        return;
+    }
     // Replace all uses, but do not remove the operation yet. This does not
     // notify the listener because the original op is not erased.
+    LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " with "
+           << OpWithFlags(existing, OpPrintingFlags().skipRegions());
     rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
                                wasVisited);
 
     // There may be some remaining uses of the operation.
-    if (op->use_empty())
+    if (op->use_empty()) {
+      LDBG() << "use_empty, add "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions())
+             << " to opsToErase";
       opsToErase.push_back(op);
+    }
   }
 
   // If the existing operation has an unknown location and the current
@@ -222,8 +285,11 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
 
 /// Attempt to eliminate a redundant operation.
 LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
+                                           ScopedMapTy &knownPureOps,
                                            Operation *op,
                                            bool hasSSADominance) {
+  LDBG() << "visit operation: "
+         << OpWithFlags(op, OpPrintingFlags().skipRegions());
   // Don't simplify terminator operations.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return failure();
@@ -261,7 +327,27 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
         return success();
       }
     }
-    knownValues.insert(op, op);
+    if (auto *existing = knownPureOps.lookup(op)) {
+      if (existing->getBlock() == op->getBlock() &&
+          !hasOtherSideEffectingOpInBetween(existing, op)) {
+        // The operation that can be deleted has been reach with no
+        // side-effecting operations in between the existing operation and
+        // this one so we can remove the duplicate.
+        replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+        return success();
+      }
+    }
+
+    if (mlir::isPure(op)) {
+      LDBG() << "insert op: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions())
+             << "to pureMap";
+      knownPureOps.insert(op, op);
+    } else {
+      LDBG() << "insert op: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "to map";
+      knownValues.insert(op, op);
+    }
     return failure();
   }
 
@@ -272,14 +358,31 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
     return success();
   }
 
-  // Otherwise, we add this operation to the known values map.
-  knownValues.insert(op, op);
+  if (auto *existing = knownPureOps.lookup(op)) {
+    replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+    ++numCSE;
+    return success();
+  }
+
+  if (mlir::isPure(op)) {
+    LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << "to pureMap";
+    knownPureOps.insert(op, op);
+  } else {
+    // Otherwise, we add this operation to the known values map.
+    LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << "to map";
+    knownValues.insert(op, op);
+  }
   return failure();
 }
 
-void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
+                              ScopedMapTy &knownPureOps, Block *bb,
                               bool hasSSADominance) {
-  for (auto &op : *bb) {
+  LDBG() << "visit block #" << bb->computeBlockNumber() << " of "
+         << OpWithFlags(bb->getParentOp(), OpPrintingFlags().skipRegions());
+  for (auto &op : llvm::make_early_inc_range(*bb)) {
     // Most operations don't have regions, so fast path that case.
     if (op.getNumRegions() != 0) {
       // If this operation is isolated above, we can't process nested regions
@@ -287,34 +390,45 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
       // implicit captures in explicit capture only regions.
       if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
         ScopedMapTy nestedKnownValues;
+        ScopedMapTy nestedKnownPureOps;
+        ScopedMapTy::ScopeTy scope(nestedKnownValues);
+        ScopedMapTy::ScopeTy pureScope(nestedKnownPureOps);
         for (auto &region : op.getRegions())
-          simplifyRegion(nestedKnownValues, region);
+          simplifyRegion(nestedKnownValues, nestedKnownPureOps, region);
       } else {
         // Otherwise, process nested regions normally.
         for (auto &region : op.getRegions())
-          simplifyRegion(knownValues, region);
+          simplifyRegion(knownValues, knownPureOps, region);
       }
     }
 
     // If the operation is simplified, we don't process any held regions.
-    if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
+    if (succeeded(
+            simplifyOperation(knownValues, knownPureOps, &op, hasSSADominance)))
       continue;
   }
   // Clear the MemoryEffects cache since its usage is by block only.
   memEffectsCache.clear();
 }
 
-void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
+void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
+                               ScopedMapTy &knownPureOps, Region &region) {
   // If the region is empty there is nothing to do.
   if (region.empty())
     return;
+  LDBG() << "visit region #" << region.getRegionNumber() << " of "
+         << OpWithFlags(region.getParentOp(), OpPrintingFlags().skipRegions());
 
+  // Prevent CSE of pure operations across function boundaries.
+  std::optional<ScopedMapTy::ScopeTy> funcPureScope;
+  if (isa<FunctionOpInterface>(region.getParentOp())) {
+    funcPureScope.emplace(knownPureOps);
+  }
   bool hasSSADominance = domInfo->hasSSADominance(&region);
-
   // If the region only contains one block, then simplify it directly.
   if (region.hasOneBlock()) {
     ScopedMapTy::ScopeTy scope(knownValues);
-    simplifyBlock(knownValues, &region.front(), hasSSADominance);
+    simplifyBlock(knownValues, knownPureOps, &region.front(), hasSSADominance);
     return;
   }
 
@@ -342,7 +456,7 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
     // Check to see if we need to process this node.
     if (!currentNode->processed) {
       currentNode->processed = true;
-      simplifyBlock(knownValues, currentNode->node->getBlock(),
+      simplifyBlock(knownValues, knownPureOps, currentNode->node->getBlock(),
                     hasSSADominance);
     }
 
@@ -361,9 +475,14 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
 
 void CSEDriver::simplify(Operation *op, bool *changed) {
   /// Simplify all regions.
-  ScopedMapTy knownValues;
-  for (auto &region : op->getRegions())
-    simplifyRegion(knownValues, region);
+  {
+    ScopedMapTy knownValues;
+    ScopedMapTy knownPureOps;
+    ScopedMapTy::ScopeTy scope(knownPureOps);
+    for (auto &region : op->getRegions()) {
+      simplifyRegion(knownValues, knownPureOps, region);
+    }
+  }
 
   /// Erase any operations that were marked as dead during simplification.
   for (auto *op : opsToErase)
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 2a183cb4d056a..7fd66035a4140 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -82,6 +82,8 @@ func.func @use_too_many_tiles() {
 
 //  AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]])
 // AFTER-LLVM-LOWERING-SAME:   {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
+//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
 //
 //  AFTER-LLVM-LOWERING-NOT: scf.for
 
@@ -104,8 +106,6 @@ func.func @use_too_many_tiles() {
 
 //      AFTER-LLVM-LOWERING: scf.for
 // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
-//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
 //      AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
 //      AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
 // AFTER-LLVM-LOWERING-NEXT:   "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
@@ -122,8 +122,6 @@ func.func @use_too_many_tiles() {
 
 //      AFTER-LLVM-LOWERING: scf.for
 // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
-//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
 //      AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
 //      AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
 // AFTER-LLVM-LOWERING-NEXT:   "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
@@ -156,6 +154,8 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
 //  AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
 //  AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
 // AFTER-LLVM-LOWERING-SAME:   {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
+//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
 //
 
 /// 1. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
@@ -163,8 +163,6 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
 
 //      AFTER-LLVM-LOWERING: scf.for
 // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
-//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
 //      AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
 // Read ZA tile slice -> vector
 //      AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
@@ -182,8 +180,6 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
 
 //      AFTER-LLVM-LOWERING: scf.for
 // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
-//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
 //      AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
 /// Read ZA tile slice -> vector
 //      AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 3929f5be3b4ef..8dc6364fddb2e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -199,27 +199,20 @@ func.func @eleminate_multiple_ops(%t: tensor<?xf32> {bufferization.buffer_layout
 {
   %cst1 = arith.constant 0.0: f32
   %cst2 = arith.constant 1.0: f32
-
-  // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref
+  // CHECK: %[[T_SUBVIEW_1:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+  // CHECK: scf.if %{{.*}}
   %if = scf.if %c -> tensor<?xf32> {
-    // CHECK: %[[T_SUBVIEW_1:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
     %a1 = tensor.empty(%sz) : tensor<?xf32>
     // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
     %f1 = linalg.fill ins(%cst1 : f32) outs(%a1 : tensor<?xf32>) -> tensor<?xf32>
-    // CHECK: scf.yield %[[T_SUBVIEW_1]]
     scf.yield %f1 : tensor<?xf32>
   } else {
-      // CHECK: %[[T_SUBVIEW_2:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
     %a2 = tensor.empty(%sz) : tensor<?xf32>
-    // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref<?xf32
+    // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
     %f2 = linalg.fill ins(%cst2 : f32) outs(%a2 : tensor<?xf32>) -> tensor<?xf32>
-    // CHECK: scf.yield %[[T_SUBVIEW_2]]
     scf.yield %f2 : tensor<?xf32>
   }
-
-  // Self-copy could canonicalize away later.
-  // CHECK: %[[T_SUBVIEW_3:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
-  // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]]
+  // CHECK: return %[[FUNC_ARG]] 
   %r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
   return %r1: tensor<?xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
index 6cab25b50460d..8aa59b2bf0f5f 100644
--- a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
+++ b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
@@ -5,14 +5,13 @@
 //   CHECK-NOT:     memref.copy
 //       CHECK:     linalg.fill
 //       CHECK:     scf.for
+//       CHECK:       vector.constant_mask [16, 4] : vector<128x4xi1>
 //       CHECK:       memref.alloc() : memref<128x16xf32, 3>
 //       CHECK:       scf.forall
-//       CHECK:         vector.constant_mask [16, 4] : vector<128x4xi1>
 //       CHECK:         vector.transfer_read
 //       CHECK:         vector.transfer_write
 //       CHECK:       memref.alloc() : memref<16x128xf32, 3>
 //       CHECK:       scf.forall
-//       CHECK:         vector.constant_mask [16, 4] : vector<128x4xi1>
 //       CHECK:         vector.transfer_read
 //       CHECK:         vector.transfer_write
 //       CHECK:       memref.alloc() : memref<128x128xf32, 3>
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 51bf4a23406d4..9b6716f09df37 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -315,7 +315,7 @@ module attributes {transform.with_named_sequence} {
 
 // Test dynamic padding using `use_prescribed_tensor_shapes`
 
-// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)>
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)> 
 // CHECK: @use_prescribed_tensor_shapes
 // CHECK: (%[[ARG0:.*]]: tensor<?x12xf32>, %[[ARG1:.*]]: tensor<12x?xf32>
 func.func @use_prescribed_tensor_shapes(%arg0: tensor<?x12xf32>,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index c26ba56347299..f6aa91ad2b10f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -3,118 +3,117 @@
 #DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 
 // CHECK-LABEL:   func.func @fill_zero_after_alloc(
-// CHECK-SAME:      %[[VAL_0:.*]]: !llvm.ptr,
-// CHECK-SAME:      %[[VAL_1:.*]]: !llvm.ptr) -> !llvm.ptr {
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG:       %[[ZERO:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant false
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 262144 : i64
-// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
-// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
-// CHECK:           %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK:           %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK:           memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK:           memref.store %[[VAL_10]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK:           %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
-// CHECK:           %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref<?xindex>
-// CHECK:           memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK:           memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[ZERO]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
-// CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
-// CHECK:           %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
-// CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
-// CHECK:           %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<300xi1> to memref<?xi1>
-// CHECK:           %[[VAL_24:.*]] = memref.alloc() : memref<300xindex>
-// CHECK:           %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<300xindex> to memref<?xindex>
-// CHECK:           linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_20]] : memref<300xf64>)
-// CHECK:           linalg.fill ins(%[[VAL_7]] : i1) outs(%[[VAL_22]] : memref<300xi1>)
-// CHECK-DAG:       %[[VAL_26:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG:       %[[VAL_27:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG:       %[[VAL_28:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG:       %[[VAL_29:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG:       %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
-// CHECK-DAG:       %[[VAL_31:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG:       %[[VAL_32:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG:       %[[VAL_33:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG:       %[[VAL_34:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG:       %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr) -> memref<?xf64>
-// CHECK:           %[[VAL_36:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_37:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK:           scf.for %[[VAL_38:.*]] = %[[VAL_36]] to %[[VAL_37]] step %[[VAL_6]] {
-// CHECK:             %[[VAL_39:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK:             %[[VAL_40:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK:             %[[VAL_41:.*]] = arith.addi %[[VAL_38]], %[[VAL_6]] : index
-// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK:             %[[VAL_43:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:             %[[VAL_44:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK:             %[[VAL_45:.*]]:3 = scf.while (%[[VAL_46:.*]] = %[[VAL_40]], %[[VAL_47:.*]] = %[[VAL_43]], %[[VAL_48:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
-// CHECK:               %[[VAL_49:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_42]] : index
-// CHECK:               %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_44]] : index
-// CHECK:               %[[VAL_51:.*]] = arith.andi %[[VAL_49]], %[[VAL_50]] : i1
-// CHECK:               scf.condition(%[[VAL_51]]) %[[VAL_46]], %[[VAL_47]], %[[VAL_48]] : index, index, index
+// CHECK-SAME:      %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME:      %[[ARG1:.*]]: !llvm.ptr) -> !llvm.ptr {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 1 : i32
+// CHECK:           %[[CONSTANT_2:.*]] = arith.constant 0 : i32
+// CHECK:           %[[CONSTANT_3:.*]] = arith.constant true
+// CHECK:           %[[CONSTANT_4:.*]] = arith.constant false
+// CHECK:           %[[CONSTANT_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[CONSTANT_6:.*]] = arith.constant 0 : index
+// CHECK:           %[[CONSTANT_7:.*]] = arith.constant 100 : index
+// CHECK:           %[[CONSTANT_8:.*]] = arith.constant 300 : index
+// CHECK:           %[[CONSTANT_9:.*]] = arith.constant 262144 : i64
+// CHECK:           %[[ALLOCA_0:.*]] = memref.alloca() : memref<2xi64>
+// CHECK:           %[[CAST_0:.*]] = memref.cast %[[ALLOCA_0]] : memref<2xi64> to memref<?xi64>
+// CHECK:           memref.store %[[CONSTANT_9]], %[[ALLOCA_0]]{{\[}}%[[CONSTANT_6]]] : memref<2xi64>
+// CHECK:           memref.store %[[CONSTANT_9]], %[[ALLOCA_0]]{{\[}}%[[CONSTANT_5]]] : memref<2xi64>
+// CHECK:           %[[ALLOCA_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:           %[[CAST_1:.*]] = memref.cast %[[ALLOCA_1]] : memref<2xindex> to memref<?xindex>
+// CHECK:           memref.store %[[CONSTANT_7]], %[[ALLOCA_1]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK:           memref.store %[[CONSTANT_8]], %[[ALLOCA_1]]{{\[}}%[[CONSTANT_5]]] : memref<2xindex>
+// CHECK:           %[[ALLOCA_2:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:           %[[CAST_2:.*]] = memref.cast %[[ALLOCA_2]] : memref<2xindex> to memref<?xindex>
+// CHECK:           memref.store %[[CONSTANT_6]], %[[ALLOCA_2]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK:           memref.store %[[CONSTANT_5]], %[[ALLOCA_2]]{{\[}}%[[CONSTANT_5]]] : memref<2xindex>
+// CHECK:           %[[VAL_0:.*]] = call @newSparseTensor(%[[CAST_1]], %[[CAST_1]], %[[CAST_0]], %[[CAST_2]], %[[CAST_2]], %[[CONSTANT_2]], %[[CONSTANT_2]], %[[CONSTANT_1]], %[[CONSTANT_2]], %[[MLIR_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK:           %[[ALLOC_0:.*]] = memref.alloc() : memref<300xf64>
+// CHECK:           %[[CAST_3:.*]] = memref.cast %[[ALLOC_0]] : memref<300xf64> to memref<?xf64>
+// CHECK:           %[[ALLOC_1:.*]] = memref.alloc() : memref<300xi1>
+// CHECK:           %[[CAST_4:.*]] = memref.cast %[[ALLOC_1]] : memref<300xi1> to memref<?xi1>
+// CHECK:           %[[ALLOC_2:.*]] = memref.alloc() : memref<300xindex>
+// CHECK:           %[[CAST_5:.*]] = memref.cast %[[ALLOC_2]] : memref<300xindex> to memref<?xindex>
+// CHECK:           linalg.fill ins(%[[CONSTANT_0]] : f64) outs(%[[ALLOC_0]] : memref<300xf64>)
+// CHECK:           linalg.fill ins(%[[CONSTANT_4]] : i1) outs(%[[ALLOC_1]] : memref<300xi1>)
+// CHECK:           %[[VAL_1:.*]] = call @sparseValuesF64(%[[ARG0]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK:           %[[VAL_2:.*]] = call @sparseValuesF64(%[[ARG1]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK:           %[[VAL_3:.*]] = call @sparsePositions0(%[[ARG0]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK:           %[[VAL_4:.*]] = call @sparseCoordinates0(%[[ARG0]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = call @sparsePositions0(%[[ARG0]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = call @sparseCoordinates0(%[[ARG0]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = call @sparsePositions0(%[[ARG1]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = call @sparseCoordinates0(%[[ARG1]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = call @sparsePositions0(%[[ARG1]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = call @sparseCoordinates0(%[[ARG1]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK:           %[[LOAD_0:.*]] = memref.load %[[VAL_3]]{{\[}}%[[CONSTANT_6]]] : memref<?xindex>
+// CHECK:           %[[LOAD_1:.*]] = memref.load %[[VAL_3]]{{\[}}%[[CONSTANT_5]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[LOAD_0]] to %[[LOAD_1]] step %[[CONSTANT_5]] {
+// CHECK:             %[[LOAD_2:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[LOAD_3:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[ADDI_0:.*]] = arith.addi %[[VAL_11]], %[[CONSTANT_5]] : index
+// CHECK:             %[[LOAD_4:.*]] = memref.load %[[VAL_5]]{{\[}}%[[ADDI_0]]] : memref<?xindex>
+// CHECK:             %[[LOAD_5:.*]] = memref.load %[[VAL_7]]{{\[}}%[[CONSTANT_6]]] : memref<?xindex>
+// CHECK:             %[[LOAD_6:.*]] = memref.load %[[VAL_7]]{{\[}}%[[CONSTANT_5]]] : memref<?xindex>
+// CHECK:             %[[WHILE_0:.*]]:3 = scf.while (%[[VAL_12:.*]] = %[[LOAD_3]], %[[VAL_13:.*]] = %[[LOAD_5]], %[[VAL_14:.*]] = %[[CONSTANT_6]]) : (index, index, index) -> (index, index, index) {
+// CHECK:               %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_12]], %[[LOAD_4]] : index
+// CHECK:               %[[CMPI_1:.*]] = arith.cmpi ult, %[[VAL_13]], %[[LOAD_6]] : index
+// CHECK:               %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[CMPI_1]] : i1
+// CHECK:               scf.condition(%[[ANDI_0]]) %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : index, index, index
 // CHECK:             } do {
-// CHECK:             ^bb0(%[[VAL_52:.*]]: index, %[[VAL_53:.*]]: index, %[[VAL_54:.*]]: index):
-// CHECK:               %[[VAL_55:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_52]]] : memref<?xindex>
-// CHECK:               %[[VAL_56:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// CHECK:               %[[VAL_57:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_55]] : index
-// CHECK:               %[[VAL_58:.*]] = arith.select %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : index
-// CHECK:               %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_58]] : index
-// CHECK:               %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_58]] : index
-// CHECK:               %[[VAL_61:.*]] = arith.andi %[[VAL_59]], %[[VAL_60]] : i1
-// CHECK:               %[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (index) {
-// CHECK:                 %[[VAL_63:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_52]]] : memref<?xf64>
-// CHECK:                 %[[VAL_64:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// CHECK:                 %[[VAL_65:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
-// CHECK:                 %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_65]]] : memref<?xindex>
-// CHECK:                 %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_6]] iter_args(%[[VAL_69:.*]] = %[[VAL_54]]) -> (index) {
-// CHECK:                   %[[VAL_70:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_68]]] : memref<?xindex>
-// CHECK:                   %[[VAL_71:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
-// CHECK:                   %[[VAL_72:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_68]]] : memref<?xf64>
-// CHECK:                   %[[VAL_73:.*]] = arith.mulf %[[VAL_63]], %[[VAL_72]] : f64
-// CHECK:                   %[[VAL_74:.*]] = arith.addf %[[VAL_71]], %[[VAL_73]] : f64
-// CHECK:                   %[[VAL_75:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
-// CHECK:                   %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_75]], %[[VAL_7]] : i1
-// CHECK:                   %[[VAL_77:.*]] = scf.if %[[VAL_76]] -> (index) {
-// CHECK:                       memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
-// CHECK:                       memref.store %[[VAL_70]], %[[VAL_24]]{{\[}}%[[VAL_69]]] : memref<300xindex>
-// CHECK:                       %[[VAL_78:.*]] = arith.addi %[[VAL_69]], %[[VAL_6]] : index
-// CHECK:                       scf.yield %[[VAL_78]] : index
+// CHECK:             ^bb0(%[[VAL_15:.*]]: index, %[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index):
+// CHECK:               %[[ADDI_1:.*]] = arith.addi %[[VAL_16]], %[[CONSTANT_5]] : index
+// CHECK:               %[[LOAD_7:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:               %[[LOAD_8:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:               %[[CMPI_2:.*]] = arith.cmpi ult, %[[LOAD_8]], %[[LOAD_7]] : index
+// CHECK:               %[[SELECT_0:.*]] = arith.select %[[CMPI_2]], %[[LOAD_8]], %[[LOAD_7]] : index
+// CHECK:               %[[CMPI_3:.*]] = arith.cmpi eq, %[[LOAD_7]], %[[SELECT_0]] : index
+// CHECK:               %[[CMPI_4:.*]] = arith.cmpi eq, %[[LOAD_8]], %[[SELECT_0]] : index
+// CHECK:               %[[ANDI_1:.*]] = arith.andi %[[CMPI_3]], %[[CMPI_4]] : i1
+// CHECK:               %[[IF_0:.*]] = scf.if %[[ANDI_1]] -> (index) {
+// CHECK:                 %[[LOAD_9:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_15]]] : memref<?xf64>
+// CHECK:                 %[[LOAD_10:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:                 %[[LOAD_11:.*]] = memref.load %[[VAL_9]]{{\[}}%[[ADDI_1]]] : memref<?xindex>
+// CHECK:                 %[[FOR_0:.*]] = scf.for %[[VAL_18:.*]] = %[[LOAD_10]] to %[[LOAD_11]] step %[[CONSTANT_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_17]]) -> (index) {
+// CHECK:                   %[[LOAD_12:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:                   %[[LOAD_13:.*]] = memref.load %[[ALLOC_0]]{{\[}}%[[LOAD_12]]] : memref<300xf64>
+// CHECK:                   %[[LOAD_14:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK:                   %[[MULF_0:.*]] = arith.mulf %[[LOAD_9]], %[[LOAD_14]] : f64
+// CHECK:                   %[[ADDF_0:.*]] = arith.addf %[[LOAD_13]], %[[MULF_0]] : f64
+// CHECK:                   %[[LOAD_15:.*]] = memref.load %[[ALLOC_1]]{{\[}}%[[LOAD_12]]] : memref<300xi1>
+// CHECK:                   %[[CMPI_5:.*]] = arith.cmpi eq, %[[LOAD_15]], %[[CONSTANT_4]] : i1
+// CHECK:                   %[[IF_1:.*]] = scf.if %[[CMPI_5]] -> (index) {
+// CHECK:                     memref.store %[[CONSTANT_3]], %[[ALLOC_1]]{{\[}}%[[LOAD_12]]] : memref<300xi1>
+// CHECK:                     memref.store %[[LOAD_12]], %[[ALLOC_2]]{{\[}}%[[VAL_19]]] : memref<300xindex>
+// CHECK:                     %[[ADDI_2:.*]] = arith.addi %[[VAL_19]], %[[CONSTANT_5]] : index
+// CHECK:                     scf.yield %[[ADDI_2]] : index
 // CHECK:                   } else {
-// CHECK:                       scf.yield %[[VAL_69]] : index
+// CHECK:                     scf.yield %[[VAL_19]] : index
 // CHECK:                   }
-// CHECK:                   memref.store %[[VAL_74]], %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
-// CHECK:                   scf.yield %[[VAL_77]] : index
-// CHECK:                 }
-// CHECK:                 scf.yield %[[VAL_67]] : index
+// CHECK:                   memref.store %[[ADDF_0]], %[[ALLOC_0]]{{\[}}%[[LOAD_12]]] : memref<300xf64>
+// CHECK:                   scf.yield %[[IF_1]] : index
+// CHECK:                 } {"Emitted from" = "linalg.generic"}
+// CHECK:                 scf.yield %[[FOR_0]] : index
 // CHECK:               } else {
-// CHECK:                 scf.yield %[[VAL_54]] : index
+// CHECK:                 scf.yield %[[VAL_17]] : index
 // CHECK:               }
-// CHECK:               %[[VAL_79:.*]] = arith.addi %[[VAL_52]], %[[VAL_6]] : index
-// CHECK:               %[[VAL_80:.*]] = arith.select %[[VAL_59]], %[[VAL_79]], %[[VAL_52]] : index
-// CHECK:               %[[VAL_81:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
-// CHECK:               %[[VAL_82:.*]] = arith.select %[[VAL_60]], %[[VAL_81]], %[[VAL_53]] : index
-// CHECK:               scf.yield %[[VAL_80]], %[[VAL_82]], %[[VAL_62]] : index, index, index
+// CHECK:               %[[ADDI_3:.*]] = arith.addi %[[VAL_15]], %[[CONSTANT_5]] : index
+// CHECK:               %[[SELECT_1:.*]] = arith.select %[[CMPI_3]], %[[ADDI_3]], %[[VAL_15]] : index
+// CHECK:               %[[SELECT_2:.*]] = arith.select %[[CMPI_4]], %[[ADDI_1]], %[[VAL_16]] : index
+// CHECK:               scf.yield %[[SELECT_1]], %[[SELECT_2]], %[[IF_0]] : index, index, index
 // CHECK:             }
-// CHECK:             %[[VAL_83:.*]] = memref.alloca() : memref<2xindex>
-// CHECK:             %[[VAL_84:.*]] = memref.cast %[[VAL_83]] : memref<2xindex> to memref<?xindex>
-// CHECK:             memref.store %[[VAL_39]], %[[VAL_83]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK:             func.call @expInsertF64(%[[VAL_19]], %[[VAL_84]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_85:.*]]#2) : (!llvm.ptr, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
-// CHECK:           }
-// CHECK:           memref.dealloc %[[VAL_20]] : memref<300xf64>
-// CHECK:           memref.dealloc %[[VAL_22]] : memref<300xi1>
-// CHECK:           memref.dealloc %[[VAL_24]] : memref<300xindex>
-// CHECK:           call @endLexInsert(%[[VAL_19]]) : (!llvm.ptr) -> ()
-// CHECK:           return %[[VAL_19]] : !llvm.ptr
-// CHECK:       }
+// CHECK:             %[[ALLOCA_3:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:             %[[CAST_6:.*]] = memref.cast %[[ALLOCA_3]] : memref<2xindex> to memref<?xindex>
+// CHECK:             memref.store %[[LOAD_2]], %[[ALLOCA_3]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK:             func.call @expInsertF64(%[[VAL_0]], %[[CAST_6]], %[[CAST_3]], %[[CAST_4]], %[[CAST_5]], %[[VAL_20:.*]]#2) : (!llvm.ptr, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           memref.dealloc %[[ALLOC_0]] : memref<300xf64>
+// CHECK:           memref.dealloc %[[ALLOC_1]] : memref<300xi1>
+// CHECK:           memref.dealloc %[[ALLOC_2]] : memref<300xindex>
+// CHECK:           call @endLexInsert(%[[VAL_0]]) : (!llvm.ptr) -> ()
+// CHECK:           return %[[VAL_0]] : !llvm.ptr
+// CHECK:         }
 func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,
                                        %arg1: tensor<200x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR> {
   %0 = tensor.empty() : tensor<100x300xf64, #DCSR>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index f6f7f396adab5..c4d86b6b6931f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -79,76 +79,72 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
 // ITER:         }
 
 // CHECK-LABEL:   func.func @add(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK:           %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
-// CHECK:           %[[VAL_6:.*]] = bufferization.to_buffer %[[VAL_5]] : tensor<10xi32> to memref<10xi32>
-// CHECK:           linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index
-// CHECK:             %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index
-// CHECK:             %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1
-// CHECK:             scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<10xi32, {{.*}}>,
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<10xi32, {{.*}}>) -> tensor<10xi32> {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[CONSTANT_2:.*]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK:           %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VALUES_0:.*]] = sparse_tensor.values %[[ARG1]] : tensor<10xi32, {{.*}}> to memref<?xi32>
+// CHECK:           %[[VALUES_1:.*]] = sparse_tensor.values %[[ARG0]] : tensor<10xi32, {{.*}}> to memref<?xi32>
+// CHECK:           %[[TO_BUFFER_0:.*]] = bufferization.to_buffer %[[CONSTANT_2]] : tensor<10xi32> to memref<10xi32>
+// CHECK:           linalg.fill ins(%[[CONSTANT_3]] : i32) outs(%[[TO_BUFFER_0]] : memref<10xi32>)
+// CHECK:           %[[POSITIONS_0:.*]] = sparse_tensor.positions %[[ARG0]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK:           %[[COORDINATES_0:.*]] = sparse_tensor.coordinates %[[ARG0]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK:           %[[LOAD_0:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_1]]] : memref<?xindex>
+// CHECK:           %[[LOAD_1:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_0]]] : memref<?xindex>
+// CHECK:           %[[POSITIONS_1:.*]] = sparse_tensor.positions %[[ARG1]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK:           %[[COORDINATES_1:.*]] = sparse_tensor.coordinates %[[ARG1]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK:           %[[LOAD_2:.*]] = memref.load %[[POSITIONS_1]]{{\[}}%[[CONSTANT_1]]] : memref<?xindex>
+// CHECK:           %[[LOAD_3:.*]] = memref.load %[[POSITIONS_1]]{{\[}}%[[CONSTANT_0]]] : memref<?xindex>
+// CHECK:           %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[LOAD_0]], %[[VAL_1:.*]] = %[[LOAD_2]]) : (index, index) -> (index, index) {
+// CHECK:             %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_0]], %[[LOAD_1]] : index
+// CHECK:             %[[CMPI_1:.*]] = arith.cmpi ult, %[[VAL_1]], %[[LOAD_3]] : index
+// CHECK:             %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[CMPI_1]] : i1
+// CHECK:             scf.condition(%[[ANDI_0]]) %[[VAL_0]], %[[VAL_1]] : index, index
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
-// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK:             %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index
-// CHECK:             %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index
-// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index
-// CHECK:             %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index
-// CHECK:             %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
-// CHECK:             scf.if %[[VAL_29]] {
-// CHECK:               %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref<?xi32>
-// CHECK:               %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK:               %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref<?xi32>
-// CHECK:               %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32
-// CHECK:               memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK:           ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
+// CHECK:             %[[LOAD_4:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:             %[[LOAD_5:.*]] = memref.load %[[COORDINATES_1]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:             %[[CMPI_2:.*]] = arith.cmpi ult, %[[LOAD_5]], %[[LOAD_4]] : index
+// CHECK:             %[[SELECT_0:.*]] = arith.select %[[CMPI_2]], %[[LOAD_5]], %[[LOAD_4]] : index
+// CHECK:             %[[CMPI_3:.*]] = arith.cmpi eq, %[[LOAD_4]], %[[SELECT_0]] : index
+// CHECK:             %[[CMPI_4:.*]] = arith.cmpi eq, %[[LOAD_5]], %[[SELECT_0]] : index
+// CHECK:             %[[ANDI_1:.*]] = arith.andi %[[CMPI_3]], %[[CMPI_4]] : i1
+// CHECK:             scf.if %[[ANDI_1]] {
+// CHECK:               %[[LOAD_6:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_2]]] : memref<?xi32>
+// CHECK:               %[[LOAD_7:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_3]]] : memref<?xi32>
+// CHECK:               %[[ADDI_0:.*]] = arith.addi %[[LOAD_6]], %[[LOAD_7]] : i32
+// CHECK:               memref.store %[[ADDI_0]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
 // CHECK:             } else {
-// CHECK:               scf.if %[[VAL_27]] {
-// CHECK:                 %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK:                 %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref<?xi32>
-// CHECK:                 memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK:               scf.if %[[CMPI_3]] {
+// CHECK:                 %[[LOAD_8:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_2]]] : memref<?xi32>
+// CHECK:                 memref.store %[[LOAD_8]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
 // CHECK:               } else {
-// CHECK:                 scf.if %[[VAL_28]] {
-// CHECK:                   %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK:                   %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref<?xi32>
-// CHECK:                   memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK:                 scf.if %[[CMPI_4]] {
+// CHECK:                   %[[LOAD_9:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_3]]] : memref<?xi32>
+// CHECK:                   memref.store %[[LOAD_9]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
 // CHECK:                 }
 // CHECK:               }
 // CHECK:             }
-// CHECK:             %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index
-// CHECK:             %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index
-// CHECK:             %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index
-// CHECK:             %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index
-// CHECK:             scf.yield %[[VAL_40]], %[[VAL_42]] : index, index
+// CHECK:             %[[ADDI_1:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_0]] : index
+// CHECK:             %[[SELECT_1:.*]] = arith.select %[[CMPI_3]], %[[ADDI_1]], %[[VAL_2]] : index
+// CHECK:             %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_0]] : index
+// CHECK:             %[[SELECT_2:.*]] = arith.select %[[CMPI_4]], %[[ADDI_2]], %[[VAL_3]] : index
+// CHECK:             scf.yield %[[SELECT_1]], %[[SELECT_2]] : index, index
 // CHECK:           }
-// CHECK:           %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK:           scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_10]] step %[[VAL_2]] {
-// CHECK:             %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref<?xindex>
-// CHECK:             %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref<?xi32>
-// CHECK:             memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32>
+// CHECK:           scf.for %[[VAL_4:.*]] = %[[WHILE_0]]#0 to %[[LOAD_1]] step %[[CONSTANT_0]] {
+// CHECK:             %[[LOAD_10:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:             %[[LOAD_11:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_4]]] : memref<?xi32>
+// CHECK:             memref.store %[[LOAD_11]], %[[TO_BUFFER_0]]{{\[}}%[[LOAD_10]]] : memref<10xi32>
 // CHECK:           }
-// CHECK:           %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK:           scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_14]] step %[[VAL_2]] {
-// CHECK:             %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref<?xindex>
-// CHECK:             %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref<?xi32>
-// CHECK:             memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32>
+// CHECK:           scf.for %[[VAL_6:.*]] = %[[WHILE_0]]#1 to %[[LOAD_3]] step %[[CONSTANT_0]] {
+// CHECK:             %[[LOAD_12:.*]] = memref.load %[[COORDINATES_1]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:             %[[LOAD_13:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_6]]] : memref<?xi32>
+// CHECK:             memref.store %[[LOAD_13]], %[[TO_BUFFER_0]]{{\[}}%[[LOAD_12]]] : memref<10xi32>
 // CHECK:           }
-// CHECK:           %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32>
-// CHECK:           return %[[VAL_53]] : tensor<10xi32>
+// CHECK:           %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[TO_BUFFER_0]] : memref<10xi32> to tensor<10xi32>
+// CHECK:           return %[[TO_TENSOR_0]] : tensor<10xi32>
 // CHECK:         }
 func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
   %cst = arith.constant dense<0> : tensor<10xi32>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
index e9587edef4678..165d0835b5824 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
@@ -58,56 +58,55 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8
   return %r : tensor<8xi64>
 }
 
-// CHECK-LABEL: func.func @sparse_index_1d_disj(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> {
-// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 8 : index
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : i64
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
-// CHECK-DAG:       %[[VAL_7:.*]] = tensor.empty() : tensor<8xi64>
-// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_buffer %[[VAL_7]] : tensor<8xi64> to memref<8xi64>
-// CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_11]] : memref<8xi64>)
-// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
-// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK-LABEL:   func.func @sparse_index_1d_disj(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant true
+// CHECK:           %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[CONSTANT_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[CONSTANT_4:.*]] = arith.constant 8 : index
+// CHECK:           %[[CONSTANT_5:.*]] = arith.constant 0 : i64
+// CHECK:           %[[EMPTY_0:.*]] = tensor.empty() : tensor<8xi64>
+// CHECK:           %[[VALUES_0:.*]] = sparse_tensor.values %[[ARG0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK:           %[[TO_BUFFER_0:.*]] = bufferization.to_buffer %[[EMPTY_0]] : tensor<8xi64> to memref<8xi64>
+// CHECK:           linalg.fill ins(%[[CONSTANT_5]] : i64) outs(%[[TO_BUFFER_0]] : memref<8xi64>)
+// CHECK:           %[[POSITIONS_0:.*]] = sparse_tensor.positions %[[ARG0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[COORDINATES_0:.*]] = sparse_tensor.coordinates %[[ARG0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[LOAD_0:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_3]]] : memref<?xindex>
+// CHECK:           %[[LOAD_1:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_2]]] : memref<?xindex>
+// CHECK:           %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[LOAD_0]], %[[VAL_1:.*]] = %[[CONSTANT_3]]) : (index, index) -> (index, index) {
+// CHECK:             %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_0]], %[[LOAD_1]] : index
+// CHECK:             scf.condition(%[[CMPI_0]]) %[[VAL_0]], %[[VAL_1]] : index, index
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
-// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
-// CHECK:             %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
-// CHECK:             scf.if %[[VAL_21]] {
-// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xi64>
-// CHECK:               %[[VAL_23:.*]] = arith.index_cast %[[VAL_19]] : index to i64
-// CHECK:               %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : i64
-// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64>
+// CHECK:           ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
+// CHECK:             %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_3]] : index to i64
+// CHECK:             %[[LOAD_2:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:             %[[CMPI_1:.*]] = arith.cmpi eq, %[[LOAD_2]], %[[VAL_3]] : index
+// CHECK:             scf.if %[[CMPI_1]] {
+// CHECK:               %[[LOAD_3:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_2]]] : memref<?xi64>
+// CHECK:               %[[ADDI_0:.*]] = arith.addi %[[LOAD_3]], %[[INDEX_CAST_0]] : i64
+// CHECK:               memref.store %[[ADDI_0]], %[[TO_BUFFER_0]]{{\[}}%[[VAL_3]]] : memref<8xi64>
 // CHECK:             } else {
-// CHECK:               scf.if %[[VAL_6]] {
-// CHECK:                 %[[VAL_25:.*]] = arith.index_cast %[[VAL_19]] : index to i64
-// CHECK:                 memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64>
+// CHECK:               scf.if %[[CONSTANT_1]] {
+// CHECK:                 memref.store %[[INDEX_CAST_0]], %[[TO_BUFFER_0]]{{\[}}%[[VAL_3]]] : memref<8xi64>
 // CHECK:               } else {
 // CHECK:               }
 // CHECK:             }
-// CHECK:             %[[VAL_26:.*]] = arith.addi %[[VAL_18]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_27:.*]] = arith.select %[[VAL_21]], %[[VAL_26]], %[[VAL_18]] : index
-// CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
-// CHECK:             scf.yield %[[VAL_27]], %[[VAL_28]] : index, index
+// CHECK:             %[[ADDI_1:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_2]] : index
+// CHECK:             %[[SELECT_0:.*]] = arith.select %[[CMPI_1]], %[[ADDI_1]], %[[VAL_2]] : index
+// CHECK:             %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_2]] : index
+// CHECK:             scf.yield %[[SELECT_0]], %[[ADDI_2]] : index, index
 // CHECK:           } attributes {"Emitted from" = "linalg.generic"}
-// CHECK:           scf.for %[[VAL_29:.*]] = %[[VAL_30:.*]]#1 to %[[VAL_1]] step %[[VAL_1]] {
-// CHECK:             %[[VAL_31:.*]] = affine.min #map(%[[VAL_1]], %[[VAL_29]]){{\[}}%[[VAL_1]]]
-// CHECK:             %[[VAL_32:.*]] = vector.create_mask %[[VAL_31]] : vector<8xi1>
-// CHECK:             %[[VAL_33:.*]] = vector.broadcast %[[VAL_29]] : index to vector<8xindex>
-// CHECK:             %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_2]] : vector<8xindex>
-// CHECK:             %[[VAL_35:.*]] = arith.index_cast %[[VAL_34]] : vector<8xindex> to vector<8xi64>
-// CHECK:             vector.maskedstore %[[VAL_11]]{{\[}}%[[VAL_29]]], %[[VAL_32]], %[[VAL_35]] : memref<8xi64>, vector<8xi1>, vector<8xi64>
+// CHECK:           scf.for %[[VAL_4:.*]] = %[[WHILE_0]]#1 to %[[CONSTANT_4]] step %[[CONSTANT_4]] {
+// CHECK:             %[[MIN_0:.*]] = affine.min #{{.*}}(%[[CONSTANT_4]], %[[VAL_4]]){{\[}}%[[CONSTANT_4]]]
+// CHECK:             %[[CREATE_MASK_0:.*]] = vector.create_mask %[[MIN_0]] : vector<8xi1>
+// CHECK:             %[[BROADCAST_0:.*]] = vector.broadcast %[[VAL_4]] : index to vector<8xindex>
+// CHECK:             %[[ADDI_3:.*]] = arith.addi %[[BROADCAST_0]], %[[CONSTANT_0]] : vector<8xindex>
+// CHECK:             %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ADDI_3]] : vector<8xindex> to vector<8xi64>
+// CHECK:             vector.maskedstore %[[TO_BUFFER_0]]{{\[}}%[[VAL_4]]], %[[CREATE_MASK_0]], %[[INDEX_CAST_1]] : memref<8xi64>, vector<8xi1>, vector<8xi64>
 // CHECK:           } {"Emitted from" = "linalg.generic"}
-// CHECK:           %[[VAL_36:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<8xi64>
-// CHECK:           return %[[VAL_36]] : tensor<8xi64>
+// CHECK:           %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[TO_BUFFER_0]] : memref<8xi64> to tensor<8xi64>
+// CHECK:           return %[[TO_TENSOR_0]] : tensor<8xi64>
 // CHECK:         }
 func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
   %init = tensor.empty() : tensor<8xi64>
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index b447094874d01..81b63b9152cec 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -163,16 +163,14 @@ func.func @down_propagate() -> i32 {
-/// Check that operation definitions are NOT propagated up the dominance tree.
+/// Check that pure ops in a loop body are hoisted above the loop and CSE'd.
 // CHECK-LABEL: @up_propagate_for
 func.func @up_propagate_for() -> i32 {
+  // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
   // CHECK: affine.for {{.*}} = 0 to 4 {
   affine.for %i = 0 to 4 {
-    // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
     // CHECK-NEXT: "foo"(%[[VAR_c1_i32_0]]) : (i32) -> ()
     %0 = arith.constant 1 : i32
     "foo"(%0) : (i32) -> ()
   }
-
-  // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
-  // CHECK-NEXT: return %[[VAR_c1_i32]] : i32
+  // CHECK: return %[[VAR_c1_i32_0]] : i32
   %1 = arith.constant 1 : i32
   return %1 : i32
 }
@@ -181,7 +179,8 @@ func.func @up_propagate_for() -> i32 {
 
 // CHECK-LABEL: func @up_propagate
 func.func @up_propagate() -> i32 {
-  // CHECK-NEXT:  %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
+  // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+  // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
   %0 = arith.constant 0 : i32
 
   // CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
@@ -191,17 +190,15 @@ func.func @up_propagate() -> i32 {
   cf.cond_br %cond, ^bb1, ^bb2(%0 : i32)
 
 ^bb1: // CHECK: ^bb1:
-  // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
   %1 = arith.constant 1 : i32
 
   // CHECK-NEXT: cf.br ^bb2(%[[VAR_c1_i32]] : i32)
   cf.br ^bb2(%1 : i32)
 
 ^bb2(%arg : i32): // CHECK: ^bb2
-  // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
   %2 = arith.constant 1 : i32
 
-  // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32_0]] : i32
+  // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32]] : i32
   %add = arith.addi %arg, %2 : i32
 
   // CHECK-NEXT: return %[[VAR_1]] : i32
@@ -216,6 +213,7 @@ func.func @up_propagate() -> i32 {
 func.func @up_propagate_region() -> i32 {
   // CHECK-NEXT: {{.*}} "foo.region"
   %0 = "foo.region"() ({
+    // CHECK-NEXT:  %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
     // CHECK-NEXT:  %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
     // CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
     // CHECK-NEXT: cf.cond_br
@@ -225,15 +223,13 @@ func.func @up_propagate_region() -> i32 {
     cf.cond_br %true, ^bb1, ^bb2(%1 : i32)
 
   ^bb1: // CHECK: ^bb1:
-    // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
     // CHECK-NEXT: cf.br
 
     %c1_i32 = arith.constant 1 : i32
     cf.br ^bb2(%c1_i32 : i32)
 
   ^bb2(%arg : i32): // CHECK: ^bb2(%[[VAR_1:.*]]: i32):
-    // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
-    // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32_0]] : i32
+    // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32]] : i32
     // CHECK-NEXT: "foo.yield"(%[[VAR_2]]) : (i32) -> ()
 
     %c1_i32_0 = arith.constant 1 : i32
@@ -500,8 +496,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
   return %r1, %r2 : tensor<5xf32>, tensor<5xf32>
 }
 // CHECK-LABEL: func @cse_multiple_regions
-//       CHECK:   %[[if:.*]] = scf.if {{.*}} {
-//       CHECK:     tensor.empty
+//       CHECK:   tensor.empty
+//       CHECK:   %[[if:.*]] = scf.if {{.*}}
 //       CHECK:     scf.yield
 //       CHECK:   } else {
 //       CHECK:     scf.yield

>From 223b391aa220a3d52fd35d6f3feaaafda44568ab Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 10 Feb 2026 01:37:00 +0000
Subject: [PATCH 2/3] fix CI: hold funcPureScope in std::unique_ptr instead of std::optional.

---
 mlir/lib/Transforms/CSE.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 594460e763a3d..e5aa08cb28994 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -420,9 +420,9 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
          << OpWithFlags(region.getParentOp(), OpPrintingFlags().skipRegions());
 
   // Prevent CSE of pure operations across function boundaries.
-  std::optional<ScopedMapTy::ScopeTy> funcPureScope;
+  std::unique_ptr<ScopedMapTy::ScopeTy> funcPureScope;
   if (isa<FunctionOpInterface>(region.getParentOp())) {
-    funcPureScope.emplace(knownPureOps);
+    funcPureScope = std::make_unique<ScopedMapTy::ScopeTy>(knownPureOps);
   }
   bool hasSSADominance = domInfo->hasSSADominance(&region);
   // If the region only contains one block, then simplify it directly.

>From 809d7f5f01d0bfeeb5150d75894e7e16f7731343 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Thu, 12 Feb 2026 08:56:13 +0000
Subject: [PATCH 3/3] add isBlockCrossIsIsolatedFromAbove to avoid hoisting across IsolatedFromAbove ops.

---
 mlir/lib/Transforms/CSE.cpp | 24 +++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index e5aa08cb28994..925e90c71e053 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -134,7 +134,26 @@ class CSEDriver {
 };
 } // namespace
 
-/// Hoist the pure ops to the location of the Nearest Common Dominator.
+/// Returns true if walking up the parent ops from `b` to `a` (after swapping
+/// them when `b` dominates `a`) passes an op that might be IsolatedFromAbove.
+static bool isBlockCrossIsIsolatedFromAbove(DominanceInfo *dominate, Block *a,
+                                            Block *b) {
+  if (a == b || a->getParent() == b->getParent())
+    return false;
+  if (dominate->dominates(b, a))
+    std::swap(b, a);
+  while (b && b->getParentOp()) {
+    Operation *parentOp = b->getParentOp();
+    if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>())
+      return true;
+    b = parentOp->getBlock();
+    if (b == a)
+      return false;
+  }
+  return false;
+}
+
+/// Hoist the pure ops to the location of the Nearest Common Dominator
 LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
   Block *ancestorBlock =
       domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock());
@@ -144,6 +163,9 @@ LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
            << "failed";
     return failure();
   }
+  if (isBlockCrossIsIsolatedFromAbove(domInfo, ancestorBlock,
+                                      existing->getBlock()))
+    return failure();
 
   Operation *insertPoint = nullptr;
   for (Value operand : op->getOperands()) {



More information about the Mlir-commits mailing list