[Mlir-commits] [mlir] [mlir][analysis] Introduce hoist-pure-ops logic to CSE pass (PR #180556)
lonely eagle
llvmlistbot at llvm.org
Mon Feb 9 08:48:35 PST 2026
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/180556
>From c41619a54a77e33a4cb83c36c2c32f4919ac82c9 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sun, 8 Feb 2026 06:27:20 +0000
Subject: [PATCH] add hoist pure logic to CSE.
---
mlir/lib/Transforms/CSE.cpp | 159 +++++++++++--
.../ArmSMEToLLVM/tile-spills-and-fills.mlir | 12 +-
...ot-bufferize-empty-tensor-elimination.mlir | 15 +-
.../Linalg/matmul-shared-memory-padding.mlir | 3 +-
.../test/Dialect/Linalg/transform-op-pad.mlir | 2 +-
.../SparseTensor/sparse_fill_zero.mlir | 211 +++++++++---------
.../sparse_kernels_to_iterator.mlir | 118 +++++-----
.../SparseTensor/sparse_vector_index.mlir | 85 ++++---
mlir/test/Transforms/cse.mlir | 24 +-
9 files changed, 363 insertions(+), 266 deletions(-)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 8eaac308755fd..594460e763a3d 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -15,12 +15,14 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/RecyclingAllocator.h"
#include <deque>
@@ -29,6 +31,7 @@ namespace mlir {
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "cse"
using namespace mlir;
namespace {
@@ -101,13 +104,17 @@ class CSEDriver {
/// Attempt to eliminate a redundant operation. Returns success if the
/// operation was marked for removal, failure otherwise.
- LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+ LogicalResult simplifyOperation(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps, Operation *op,
bool hasSSADominance);
- void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
- void simplifyRegion(ScopedMapTy &knownValues, Region &region);
+ void simplifyBlock(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+ Block *bb, bool hasSSADominance);
+ void simplifyRegion(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps,
+ Region &region);
void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing, bool hasSSADominance);
+ LogicalResult hoistPureOp(Operation *existing, Operation *op);
/// Check if there is side-effecting operations other than the given effect
/// between the two operations.
@@ -117,7 +124,7 @@ class CSEDriver {
RewriterBase &rewriter;
/// Operations marked as dead and to be erased.
- std::vector<Operation *> opsToErase;
+ SmallVector<Operation *> opsToErase;
DominanceInfo *domInfo = nullptr;
MemEffectsCache memEffectsCache;
@@ -127,6 +134,42 @@ class CSEDriver {
};
} // namespace
+/// Hoist the pure ops to the location of the Nearest Common Dominator.
+LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) {
+ Block *ancestorBlock =
+ domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock());
+ if (!ancestorBlock) {
+ LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+ << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "failed";
+ return failure();
+ }
+
+ Operation *insertPoint = nullptr;
+ for (Value operand : op->getOperands()) {
+ if (domInfo->properlyDominates(operand, &ancestorBlock->front()))
+ continue;
+ if (!insertPoint) {
+ insertPoint = operand.getDefiningOp();
+ } else {
+ insertPoint = domInfo->dominates(insertPoint, operand.getDefiningOp())
+ ? operand.getDefiningOp()
+ : insertPoint;
+ }
+ }
+ if (!insertPoint) {
+ rewriter.moveOpBefore(existing, ancestorBlock, ancestorBlock->begin());
+ rewriter.moveOpAfter(op, existing);
+ } else {
+ rewriter.moveOpAfter(existing, insertPoint);
+ rewriter.moveOpAfter(op, existing);
+ }
+ LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions())
+ << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "success";
+ return success();
+}
+
void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing,
bool hasSSADominance) {
@@ -141,6 +184,15 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
rewriteListener->notifyOperationReplaced(op, existing);
// Replace all uses, but do not remove the operation yet. This does not
// notify the listener because the original op is not erased.
+ if (!domInfo->properlyDominates(existing, op)) {
+ if (failed(hoistPureOp(existing, op)))
+ return;
+ }
+ LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " with "
+ << OpWithFlags(existing, OpPrintingFlags().skipRegions());
+ LDBG() << "add " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to opsToErase";
rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
opsToErase.push_back(op);
} else {
@@ -155,14 +207,25 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
if (all_of(v.getUses(), wasVisited))
rewriteListener->notifyOperationReplaced(op, existing);
+ if (!domInfo->properlyDominates(existing, op)) {
+ if (failed(hoistPureOp(existing, op)))
+ return;
+ }
// Replace all uses, but do not remove the operation yet. This does not
// notify the listener because the original op is not erased.
+ LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " with "
+ << OpWithFlags(existing, OpPrintingFlags().skipRegions());
rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
wasVisited);
// There may be some remaining uses of the operation.
- if (op->use_empty())
+ if (op->use_empty()) {
+ LDBG() << "use_empty, add "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " to opsToErase";
opsToErase.push_back(op);
+ }
}
// If the existing operation has an unknown location and the current
@@ -222,8 +285,11 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
/// Attempt to eliminate a redundant operation.
LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps,
Operation *op,
bool hasSSADominance) {
+ LDBG() << "visit operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Don't simplify terminator operations.
if (op->hasTrait<OpTrait::IsTerminator>())
return failure();
@@ -261,7 +327,27 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
return success();
}
}
- knownValues.insert(op, op);
+ if (auto *existing = knownPureOps.lookup(op)) {
+ if (existing->getBlock() == op->getBlock() &&
+ !hasOtherSideEffectingOpInBetween(existing, op)) {
+ // The operation that can be deleted has been reached with no
+ // side-effecting operations in between the existing operation and
+ // this one so we can remove the duplicate.
+ replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+ return success();
+ }
+ }
+
+ if (mlir::isPure(op)) {
+ LDBG() << "insert op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "to pureMap";
+ knownPureOps.insert(op, op);
+ } else {
+ LDBG() << "insert op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "to map";
+ knownValues.insert(op, op);
+ }
return failure();
}
@@ -272,14 +358,31 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
return success();
}
- // Otherwise, we add this operation to the known values map.
- knownValues.insert(op, op);
+ if (auto *existing = knownPureOps.lookup(op)) {
+ replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance);
+ ++numCSE;
+ return success();
+ }
+
+ if (mlir::isPure(op)) {
+ LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "to pureMap";
+ knownPureOps.insert(op, op);
+ } else {
+ // Otherwise, we add this operation to the known values map.
+ LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << "to map";
+ knownValues.insert(op, op);
+ }
return failure();
}
-void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+void CSEDriver::simplifyBlock(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps, Block *bb,
bool hasSSADominance) {
- for (auto &op : *bb) {
+ LDBG() << "visit block #" << bb->computeBlockNumber() << " of "
+ << OpWithFlags(bb->getParentOp(), OpPrintingFlags().skipRegions());
+ for (auto &op : llvm::make_early_inc_range(*bb)) {
// Most operations don't have regions, so fast path that case.
if (op.getNumRegions() != 0) {
// If this operation is isolated above, we can't process nested regions
@@ -287,34 +390,45 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
// implicit captures in explicit capture only regions.
if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
ScopedMapTy nestedKnownValues;
+ ScopedMapTy nestedKnownPureOps;
+ ScopedMapTy::ScopeTy scope(nestedKnownValues);
+ ScopedMapTy::ScopeTy pureScope(nestedKnownPureOps);
for (auto &region : op.getRegions())
- simplifyRegion(nestedKnownValues, region);
+ simplifyRegion(nestedKnownValues, nestedKnownPureOps, region);
} else {
// Otherwise, process nested regions normally.
for (auto &region : op.getRegions())
- simplifyRegion(knownValues, region);
+ simplifyRegion(knownValues, knownPureOps, region);
}
}
// If the operation is simplified, we don't process any held regions.
- if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
+ if (succeeded(
+ simplifyOperation(knownValues, knownPureOps, &op, hasSSADominance)))
continue;
}
// Clear the MemoryEffects cache since its usage is by block only.
memEffectsCache.clear();
}
-void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
+void CSEDriver::simplifyRegion(ScopedMapTy &knownValues,
+ ScopedMapTy &knownPureOps, Region &region) {
// If the region is empty there is nothing to do.
if (region.empty())
return;
+ LDBG() << "visit region #" << region.getRegionNumber() << " of "
+ << OpWithFlags(region.getParentOp(), OpPrintingFlags().skipRegions());
+ // Prevent CSE of pure operations across function boundaries.
+ std::optional<ScopedMapTy::ScopeTy> funcPureScope;
+ if (isa<FunctionOpInterface>(region.getParentOp())) {
+ funcPureScope.emplace(knownPureOps);
+ }
bool hasSSADominance = domInfo->hasSSADominance(&region);
-
// If the region only contains one block, then simplify it directly.
if (region.hasOneBlock()) {
ScopedMapTy::ScopeTy scope(knownValues);
- simplifyBlock(knownValues, &region.front(), hasSSADominance);
+ simplifyBlock(knownValues, knownPureOps, &region.front(), hasSSADominance);
return;
}
@@ -342,7 +456,7 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
// Check to see if we need to process this node.
if (!currentNode->processed) {
currentNode->processed = true;
- simplifyBlock(knownValues, currentNode->node->getBlock(),
+ simplifyBlock(knownValues, knownPureOps, currentNode->node->getBlock(),
hasSSADominance);
}
@@ -361,9 +475,14 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
void CSEDriver::simplify(Operation *op, bool *changed) {
/// Simplify all regions.
- ScopedMapTy knownValues;
- for (auto &region : op->getRegions())
- simplifyRegion(knownValues, region);
+ {
+ ScopedMapTy knownValues;
+ ScopedMapTy knownPureOps;
+ ScopedMapTy::ScopeTy scope(knownPureOps);
+ for (auto &region : op->getRegions()) {
+ simplifyRegion(knownValues, knownPureOps, region);
+ }
+ }
/// Erase any operations that were marked as dead during simplification.
for (auto *op : opsToErase)
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 2a183cb4d056a..7fd66035a4140 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -82,6 +82,8 @@ func.func @use_too_many_tiles() {
// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]])
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
+// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
//
// AFTER-LLVM-LOWERING-NOT: scf.for
@@ -104,8 +106,6 @@ func.func @use_too_many_tiles() {
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
@@ -122,8 +122,6 @@ func.func @use_too_many_tiles() {
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
@@ -156,6 +154,8 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
// AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
+// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
//
/// 1. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
@@ -163,8 +163,6 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
// Read ZA tile slice -> vector
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
@@ -182,8 +180,6 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
-// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
/// Read ZA tile slice -> vector
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 3929f5be3b4ef..8dc6364fddb2e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -199,27 +199,20 @@ func.func @eleminate_multiple_ops(%t: tensor<?xf32> {bufferization.buffer_layout
{
%cst1 = arith.constant 0.0: f32
%cst2 = arith.constant 1.0: f32
-
- // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref
+ // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+ // CHECK: scf.if %{{.*}}
%if = scf.if %c -> tensor<?xf32> {
- // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
%a1 = tensor.empty(%sz) : tensor<?xf32>
// CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
%f1 = linalg.fill ins(%cst1 : f32) outs(%a1 : tensor<?xf32>) -> tensor<?xf32>
- // CHECK: scf.yield %[[T_SUBVIEW_1]]
scf.yield %f1 : tensor<?xf32>
} else {
- // CHECK: %[[T_SUBVIEW_2:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
%a2 = tensor.empty(%sz) : tensor<?xf32>
- // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref<?xf32
+ // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
%f2 = linalg.fill ins(%cst2 : f32) outs(%a2 : tensor<?xf32>) -> tensor<?xf32>
- // CHECK: scf.yield %[[T_SUBVIEW_2]]
scf.yield %f2 : tensor<?xf32>
}
-
- // Self-copy could canonicalize away later.
- // CHECK: %[[T_SUBVIEW_3:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
- // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]]
+ // CHECK: return %[[FUNC_ARG]]
%r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
return %r1: tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
index 6cab25b50460d..8aa59b2bf0f5f 100644
--- a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
+++ b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
@@ -5,14 +5,13 @@
// CHECK-NOT: memref.copy
// CHECK: linalg.fill
// CHECK: scf.for
+// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1>
// CHECK: memref.alloc() : memref<128x16xf32, 3>
// CHECK: scf.forall
-// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1>
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: memref.alloc() : memref<16x128xf32, 3>
// CHECK: scf.forall
-// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1>
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: memref.alloc() : memref<128x128xf32, 3>
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 51bf4a23406d4..9b6716f09df37 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -315,7 +315,7 @@ module attributes {transform.with_named_sequence} {
// Test dynamic padding using `use_prescribed_tensor_shapes`
-// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)>
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)>
// CHECK: @use_prescribed_tensor_shapes
// CHECK: (%[[ARG0:.*]]: tensor<?x12xf32>, %[[ARG1:.*]]: tensor<12x?xf32>
func.func @use_prescribed_tensor_shapes(%arg0: tensor<?x12xf32>,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index c26ba56347299..f6aa91ad2b10f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -3,118 +3,117 @@
#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
// CHECK-LABEL: func.func @fill_zero_after_alloc(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) -> !llvm.ptr {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 262144 : i64
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[ZERO]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
-// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
-// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
-// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
-// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<300xi1> to memref<?xi1>
-// CHECK: %[[VAL_24:.*]] = memref.alloc() : memref<300xindex>
-// CHECK: %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<300xindex> to memref<?xindex>
-// CHECK: linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_20]] : memref<300xf64>)
-// CHECK: linalg.fill ins(%[[VAL_7]] : i1) outs(%[[VAL_22]] : memref<300xi1>)
-// CHECK-DAG: %[[VAL_26:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_27:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_28:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_29:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
-// CHECK-DAG: %[[VAL_31:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_32:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_33:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_34:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-DAG: %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr) -> memref<?xf64>
-// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_36]] to %[[VAL_37]] step %[[VAL_6]] {
-// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_38]], %[[VAL_6]] : index
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: %[[VAL_45:.*]]:3 = scf.while (%[[VAL_46:.*]] = %[[VAL_40]], %[[VAL_47:.*]] = %[[VAL_43]], %[[VAL_48:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
-// CHECK: %[[VAL_49:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_42]] : index
-// CHECK: %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_44]] : index
-// CHECK: %[[VAL_51:.*]] = arith.andi %[[VAL_49]], %[[VAL_50]] : i1
-// CHECK: scf.condition(%[[VAL_51]]) %[[VAL_46]], %[[VAL_47]], %[[VAL_48]] : index, index, index
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr) -> !llvm.ptr {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i32
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant true
+// CHECK: %[[CONSTANT_4:.*]] = arith.constant false
+// CHECK: %[[CONSTANT_5:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_6:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_7:.*]] = arith.constant 100 : index
+// CHECK: %[[CONSTANT_8:.*]] = arith.constant 300 : index
+// CHECK: %[[CONSTANT_9:.*]] = arith.constant 262144 : i64
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<2xi64>
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOCA_0]] : memref<2xi64> to memref<?xi64>
+// CHECK: memref.store %[[CONSTANT_9]], %[[ALLOCA_0]]{{\[}}%[[CONSTANT_6]]] : memref<2xi64>
+// CHECK: memref.store %[[CONSTANT_9]], %[[ALLOCA_0]]{{\[}}%[[CONSTANT_5]]] : memref<2xi64>
+// CHECK: %[[ALLOCA_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOCA_1]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[CONSTANT_7]], %[[ALLOCA_1]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK: memref.store %[[CONSTANT_8]], %[[ALLOCA_1]]{{\[}}%[[CONSTANT_5]]] : memref<2xindex>
+// CHECK: %[[ALLOCA_2:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[CAST_2:.*]] = memref.cast %[[ALLOCA_2]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[CONSTANT_6]], %[[ALLOCA_2]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK: memref.store %[[CONSTANT_5]], %[[ALLOCA_2]]{{\[}}%[[CONSTANT_5]]] : memref<2xindex>
+// CHECK: %[[VAL_0:.*]] = call @newSparseTensor(%[[CAST_1]], %[[CAST_1]], %[[CAST_0]], %[[CAST_2]], %[[CAST_2]], %[[CONSTANT_2]], %[[CONSTANT_2]], %[[CONSTANT_1]], %[[CONSTANT_2]], %[[MLIR_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<300xf64>
+// CHECK: %[[CAST_3:.*]] = memref.cast %[[ALLOC_0]] : memref<300xf64> to memref<?xf64>
+// CHECK: %[[ALLOC_1:.*]] = memref.alloc() : memref<300xi1>
+// CHECK: %[[CAST_4:.*]] = memref.cast %[[ALLOC_1]] : memref<300xi1> to memref<?xi1>
+// CHECK: %[[ALLOC_2:.*]] = memref.alloc() : memref<300xindex>
+// CHECK: %[[CAST_5:.*]] = memref.cast %[[ALLOC_2]] : memref<300xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[CONSTANT_0]] : f64) outs(%[[ALLOC_0]] : memref<300xf64>)
+// CHECK: linalg.fill ins(%[[CONSTANT_4]] : i1) outs(%[[ALLOC_1]] : memref<300xi1>)
+// CHECK: %[[VAL_1:.*]] = call @sparseValuesF64(%[[ARG0]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK: %[[VAL_2:.*]] = call @sparseValuesF64(%[[ARG1]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK: %[[VAL_3:.*]] = call @sparsePositions0(%[[ARG0]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_4:.*]] = call @sparseCoordinates0(%[[ARG0]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = call @sparsePositions0(%[[ARG0]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = call @sparseCoordinates0(%[[ARG0]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = call @sparsePositions0(%[[ARG1]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = call @sparseCoordinates0(%[[ARG1]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = call @sparsePositions0(%[[ARG1]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = call @sparseCoordinates0(%[[ARG1]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[VAL_3]]{{\[}}%[[CONSTANT_6]]] : memref<?xindex>
+// CHECK: %[[LOAD_1:.*]] = memref.load %[[VAL_3]]{{\[}}%[[CONSTANT_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[LOAD_0]] to %[[LOAD_1]] step %[[CONSTANT_5]] {
+// CHECK: %[[LOAD_2:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[LOAD_3:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_11]], %[[CONSTANT_5]] : index
+// CHECK: %[[LOAD_4:.*]] = memref.load %[[VAL_5]]{{\[}}%[[ADDI_0]]] : memref<?xindex>
+// CHECK: %[[LOAD_5:.*]] = memref.load %[[VAL_7]]{{\[}}%[[CONSTANT_6]]] : memref<?xindex>
+// CHECK: %[[LOAD_6:.*]] = memref.load %[[VAL_7]]{{\[}}%[[CONSTANT_5]]] : memref<?xindex>
+// CHECK: %[[WHILE_0:.*]]:3 = scf.while (%[[VAL_12:.*]] = %[[LOAD_3]], %[[VAL_13:.*]] = %[[LOAD_5]], %[[VAL_14:.*]] = %[[CONSTANT_6]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_12]], %[[LOAD_4]] : index
+// CHECK: %[[CMPI_1:.*]] = arith.cmpi ult, %[[VAL_13]], %[[LOAD_6]] : index
+// CHECK: %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[CMPI_1]] : i1
+// CHECK: scf.condition(%[[ANDI_0]]) %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : index, index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_52:.*]]: index, %[[VAL_53:.*]]: index, %[[VAL_54:.*]]: index):
-// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_52]]] : memref<?xindex>
-// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// CHECK: %[[VAL_57:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_55]] : index
-// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : index
-// CHECK: %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_58]] : index
-// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_58]] : index
-// CHECK: %[[VAL_61:.*]] = arith.andi %[[VAL_59]], %[[VAL_60]] : i1
-// CHECK: %[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (index) {
-// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_52]]] : memref<?xf64>
-// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
-// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_65]]] : memref<?xindex>
-// CHECK: %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_6]] iter_args(%[[VAL_69:.*]] = %[[VAL_54]]) -> (index) {
-// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_68]]] : memref<?xindex>
-// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
-// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_68]]] : memref<?xf64>
-// CHECK: %[[VAL_73:.*]] = arith.mulf %[[VAL_63]], %[[VAL_72]] : f64
-// CHECK: %[[VAL_74:.*]] = arith.addf %[[VAL_71]], %[[VAL_73]] : f64
-// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
-// CHECK: %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_75]], %[[VAL_7]] : i1
-// CHECK: %[[VAL_77:.*]] = scf.if %[[VAL_76]] -> (index) {
-// CHECK: memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
-// CHECK: memref.store %[[VAL_70]], %[[VAL_24]]{{\[}}%[[VAL_69]]] : memref<300xindex>
-// CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_69]], %[[VAL_6]] : index
-// CHECK: scf.yield %[[VAL_78]] : index
+// CHECK: ^bb0(%[[VAL_15:.*]]: index, %[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index):
+// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_16]], %[[CONSTANT_5]] : index
+// CHECK: %[[LOAD_7:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[LOAD_8:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[CMPI_2:.*]] = arith.cmpi ult, %[[LOAD_8]], %[[LOAD_7]] : index
+// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_2]], %[[LOAD_8]], %[[LOAD_7]] : index
+// CHECK: %[[CMPI_3:.*]] = arith.cmpi eq, %[[LOAD_7]], %[[SELECT_0]] : index
+// CHECK: %[[CMPI_4:.*]] = arith.cmpi eq, %[[LOAD_8]], %[[SELECT_0]] : index
+// CHECK: %[[ANDI_1:.*]] = arith.andi %[[CMPI_3]], %[[CMPI_4]] : i1
+// CHECK: %[[IF_0:.*]] = scf.if %[[ANDI_1]] -> (index) {
+// CHECK: %[[LOAD_9:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_15]]] : memref<?xf64>
+// CHECK: %[[LOAD_10:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[LOAD_11:.*]] = memref.load %[[VAL_9]]{{\[}}%[[ADDI_1]]] : memref<?xindex>
+// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_18:.*]] = %[[LOAD_10]] to %[[LOAD_11]] step %[[CONSTANT_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_17]]) -> (index) {
+// CHECK: %[[LOAD_12:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[LOAD_13:.*]] = memref.load %[[ALLOC_0]]{{\[}}%[[LOAD_12]]] : memref<300xf64>
+// CHECK: %[[LOAD_14:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_9]], %[[LOAD_14]] : f64
+// CHECK: %[[ADDF_0:.*]] = arith.addf %[[LOAD_13]], %[[MULF_0]] : f64
+// CHECK: %[[LOAD_15:.*]] = memref.load %[[ALLOC_1]]{{\[}}%[[LOAD_12]]] : memref<300xi1>
+// CHECK: %[[CMPI_5:.*]] = arith.cmpi eq, %[[LOAD_15]], %[[CONSTANT_4]] : i1
+// CHECK: %[[IF_1:.*]] = scf.if %[[CMPI_5]] -> (index) {
+// CHECK: memref.store %[[CONSTANT_3]], %[[ALLOC_1]]{{\[}}%[[LOAD_12]]] : memref<300xi1>
+// CHECK: memref.store %[[LOAD_12]], %[[ALLOC_2]]{{\[}}%[[VAL_19]]] : memref<300xindex>
+// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_19]], %[[CONSTANT_5]] : index
+// CHECK: scf.yield %[[ADDI_2]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_69]] : index
+// CHECK: scf.yield %[[VAL_19]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_74]], %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
-// CHECK: scf.yield %[[VAL_77]] : index
-// CHECK: }
-// CHECK: scf.yield %[[VAL_67]] : index
+// CHECK: memref.store %[[ADDF_0]], %[[ALLOC_0]]{{\[}}%[[LOAD_12]]] : memref<300xf64>
+// CHECK: scf.yield %[[IF_1]] : index
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: scf.yield %[[FOR_0]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_54]] : index
+// CHECK: scf.yield %[[VAL_17]] : index
// CHECK: }
-// CHECK: %[[VAL_79:.*]] = arith.addi %[[VAL_52]], %[[VAL_6]] : index
-// CHECK: %[[VAL_80:.*]] = arith.select %[[VAL_59]], %[[VAL_79]], %[[VAL_52]] : index
-// CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
-// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_60]], %[[VAL_81]], %[[VAL_53]] : index
-// CHECK: scf.yield %[[VAL_80]], %[[VAL_82]], %[[VAL_62]] : index, index, index
+// CHECK: %[[ADDI_3:.*]] = arith.addi %[[VAL_15]], %[[CONSTANT_5]] : index
+// CHECK: %[[SELECT_1:.*]] = arith.select %[[CMPI_3]], %[[ADDI_3]], %[[VAL_15]] : index
+// CHECK: %[[SELECT_2:.*]] = arith.select %[[CMPI_4]], %[[ADDI_1]], %[[VAL_16]] : index
+// CHECK: scf.yield %[[SELECT_1]], %[[SELECT_2]], %[[IF_0]] : index, index, index
// CHECK: }
-// CHECK: %[[VAL_83:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_84:.*]] = memref.cast %[[VAL_83]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_39]], %[[VAL_83]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: func.call @expInsertF64(%[[VAL_19]], %[[VAL_84]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_85:.*]]#2) : (!llvm.ptr, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
-// CHECK: }
-// CHECK: memref.dealloc %[[VAL_20]] : memref<300xf64>
-// CHECK: memref.dealloc %[[VAL_22]] : memref<300xi1>
-// CHECK: memref.dealloc %[[VAL_24]] : memref<300xindex>
-// CHECK: call @endLexInsert(%[[VAL_19]]) : (!llvm.ptr) -> ()
-// CHECK: return %[[VAL_19]] : !llvm.ptr
-// CHECK: }
+// CHECK: %[[ALLOCA_3:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[CAST_6:.*]] = memref.cast %[[ALLOCA_3]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[LOAD_2]], %[[ALLOCA_3]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex>
+// CHECK: func.call @expInsertF64(%[[VAL_0]], %[[CAST_6]], %[[CAST_3]], %[[CAST_4]], %[[CAST_5]], %[[VAL_20:.*]]#2) : (!llvm.ptr, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: memref.dealloc %[[ALLOC_0]] : memref<300xf64>
+// CHECK: memref.dealloc %[[ALLOC_1]] : memref<300xi1>
+// CHECK: memref.dealloc %[[ALLOC_2]] : memref<300xindex>
+// CHECK: call @endLexInsert(%[[VAL_0]]) : (!llvm.ptr) -> ()
+// CHECK: return %[[VAL_0]] : !llvm.ptr
+// CHECK: }
func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,
%arg1: tensor<200x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR> {
%0 = tensor.empty() : tensor<100x300xf64, #DCSR>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index f6f7f396adab5..c4d86b6b6931f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -79,76 +79,72 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
// ITER: }
// CHECK-LABEL: func.func @add(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
-// CHECK: %[[VAL_6:.*]] = bufferization.to_buffer %[[VAL_5]] : tensor<10xi32> to memref<10xi32>
-// CHECK: linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
-// CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index
-// CHECK: %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index
-// CHECK: %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1
-// CHECK: scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK-SAME: %[[ARG0:.*]]: tensor<10xi32, {{.*}}>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<10xi32, {{.*}}>) -> tensor<10xi32> {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+// CHECK: %[[VALUES_0:.*]] = sparse_tensor.values %[[ARG1]] : tensor<10xi32, {{.*}}> to memref<?xi32>
+// CHECK: %[[VALUES_1:.*]] = sparse_tensor.values %[[ARG0]] : tensor<10xi32, {{.*}}> to memref<?xi32>
+// CHECK: %[[TO_BUFFER_0:.*]] = bufferization.to_buffer %[[CONSTANT_2]] : tensor<10xi32> to memref<10xi32>
+// CHECK: linalg.fill ins(%[[CONSTANT_3]] : i32) outs(%[[TO_BUFFER_0]] : memref<10xi32>)
+// CHECK: %[[POSITIONS_0:.*]] = sparse_tensor.positions %[[ARG0]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK: %[[COORDINATES_0:.*]] = sparse_tensor.coordinates %[[ARG0]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_1]]] : memref<?xindex>
+// CHECK: %[[LOAD_1:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_0]]] : memref<?xindex>
+// CHECK: %[[POSITIONS_1:.*]] = sparse_tensor.positions %[[ARG1]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK: %[[COORDINATES_1:.*]] = sparse_tensor.coordinates %[[ARG1]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref<?xindex>
+// CHECK: %[[LOAD_2:.*]] = memref.load %[[POSITIONS_1]]{{\[}}%[[CONSTANT_1]]] : memref<?xindex>
+// CHECK: %[[LOAD_3:.*]] = memref.load %[[POSITIONS_1]]{{\[}}%[[CONSTANT_0]]] : memref<?xindex>
+// CHECK: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[LOAD_0]], %[[VAL_1:.*]] = %[[LOAD_2]]) : (index, index) -> (index, index) {
+// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_0]], %[[LOAD_1]] : index
+// CHECK: %[[CMPI_1:.*]] = arith.cmpi ult, %[[VAL_1]], %[[LOAD_3]] : index
+// CHECK: %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[CMPI_1]] : i1
+// CHECK: scf.condition(%[[ANDI_0]]) %[[VAL_0]], %[[VAL_1]] : index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK: %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index
-// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index
-// CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index
-// CHECK: %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index
-// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
-// CHECK: scf.if %[[VAL_29]] {
-// CHECK: %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref<?xi32>
-// CHECK: %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref<?xi32>
-// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32
-// CHECK: memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
+// CHECK: %[[LOAD_4:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[LOAD_5:.*]] = memref.load %[[COORDINATES_1]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[CMPI_2:.*]] = arith.cmpi ult, %[[LOAD_5]], %[[LOAD_4]] : index
+// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_2]], %[[LOAD_5]], %[[LOAD_4]] : index
+// CHECK: %[[CMPI_3:.*]] = arith.cmpi eq, %[[LOAD_4]], %[[SELECT_0]] : index
+// CHECK: %[[CMPI_4:.*]] = arith.cmpi eq, %[[LOAD_5]], %[[SELECT_0]] : index
+// CHECK: %[[ANDI_1:.*]] = arith.andi %[[CMPI_3]], %[[CMPI_4]] : i1
+// CHECK: scf.if %[[ANDI_1]] {
+// CHECK: %[[LOAD_6:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_2]]] : memref<?xi32>
+// CHECK: %[[LOAD_7:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_3]]] : memref<?xi32>
+// CHECK: %[[ADDI_0:.*]] = arith.addi %[[LOAD_6]], %[[LOAD_7]] : i32
+// CHECK: memref.store %[[ADDI_0]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
// CHECK: } else {
-// CHECK: scf.if %[[VAL_27]] {
-// CHECK: %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref<?xi32>
-// CHECK: memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: scf.if %[[CMPI_3]] {
+// CHECK: %[[LOAD_8:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_2]]] : memref<?xi32>
+// CHECK: memref.store %[[LOAD_8]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
// CHECK: } else {
-// CHECK: scf.if %[[VAL_28]] {
-// CHECK: %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref<?xi32>
-// CHECK: memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: scf.if %[[CMPI_4]] {
+// CHECK: %[[LOAD_9:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_3]]] : memref<?xi32>
+// CHECK: memref.store %[[LOAD_9]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32>
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index
-// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index
-// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index
-// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index
-// CHECK: scf.yield %[[VAL_40]], %[[VAL_42]] : index, index
+// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_0]] : index
+// CHECK: %[[SELECT_1:.*]] = arith.select %[[CMPI_3]], %[[ADDI_1]], %[[VAL_2]] : index
+// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_0]] : index
+// CHECK: %[[SELECT_2:.*]] = arith.select %[[CMPI_4]], %[[ADDI_2]], %[[VAL_3]] : index
+// CHECK: scf.yield %[[SELECT_1]], %[[SELECT_2]] : index, index
// CHECK: }
-// CHECK: %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_10]] step %[[VAL_2]] {
-// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref<?xindex>
-// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref<?xi32>
-// CHECK: memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32>
+// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_5:.*]]#0 to %[[LOAD_1]] step %[[CONSTANT_0]] {
+// CHECK: %[[LOAD_10:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[LOAD_11:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_4]]] : memref<?xi32>
+// CHECK: memref.store %[[LOAD_11]], %[[TO_BUFFER_0]]{{\[}}%[[LOAD_10]]] : memref<10xi32>
// CHECK: }
-// CHECK: %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
-// CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_14]] step %[[VAL_2]] {
-// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref<?xindex>
-// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref<?xi32>
-// CHECK: memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32>
+// CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_7:.*]]#1 to %[[LOAD_3]] step %[[CONSTANT_0]] {
+// CHECK: %[[LOAD_12:.*]] = memref.load %[[COORDINATES_1]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[LOAD_13:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_6]]] : memref<?xi32>
+// CHECK: memref.store %[[LOAD_13]], %[[TO_BUFFER_0]]{{\[}}%[[LOAD_12]]] : memref<10xi32>
// CHECK: }
-// CHECK: %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32>
-// CHECK: return %[[VAL_53]] : tensor<10xi32>
+// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[TO_BUFFER_0]] : memref<10xi32> to tensor<10xi32>
+// CHECK: return %[[TO_TENSOR_0]] : tensor<10xi32>
// CHECK: }
func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
%cst = arith.constant dense<0> : tensor<10xi32>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
index e9587edef4678..165d0835b5824 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
@@ -58,56 +58,55 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8
return %r : tensor<8xi64>
}
-// CHECK-LABEL: func.func @sparse_index_1d_disj(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant true
-// CHECK-DAG: %[[VAL_7:.*]] = tensor.empty() : tensor<8xi64>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_buffer %[[VAL_7]] : tensor<8xi64> to memref<8xi64>
-// CHECK-DAG: linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_11]] : memref<8xi64>)
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
-// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK-LABEL: func.func @sparse_index_1d_disj(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant true
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_4:.*]] = arith.constant 8 : index
+// CHECK: %[[CONSTANT_5:.*]] = arith.constant 0 : i64
+// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<8xi64>
+// CHECK: %[[VALUES_0:.*]] = sparse_tensor.values %[[ARG0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK: %[[TO_BUFFER_0:.*]] = bufferization.to_buffer %[[EMPTY_0]] : tensor<8xi64> to memref<8xi64>
+// CHECK: linalg.fill ins(%[[CONSTANT_5]] : i64) outs(%[[TO_BUFFER_0]] : memref<8xi64>)
+// CHECK: %[[POSITIONS_0:.*]] = sparse_tensor.positions %[[ARG0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: %[[COORDINATES_0:.*]] = sparse_tensor.coordinates %[[ARG0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_3]]] : memref<?xindex>
+// CHECK: %[[LOAD_1:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_2]]] : memref<?xindex>
+// CHECK: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[LOAD_0]], %[[VAL_1:.*]] = %[[CONSTANT_3]]) : (index, index) -> (index, index) {
+// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_0]], %[[LOAD_1]] : index
+// CHECK: scf.condition(%[[CMPI_0]]) %[[VAL_0]], %[[VAL_1]] : index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
-// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
-// CHECK: scf.if %[[VAL_21]] {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xi64>
-// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_19]] : index to i64
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : i64
-// CHECK: memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64>
+// CHECK: ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
+// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_3]] : index to i64
+// CHECK: %[[LOAD_2:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[CMPI_1:.*]] = arith.cmpi eq, %[[LOAD_2]], %[[VAL_3]] : index
+// CHECK: scf.if %[[CMPI_1]] {
+// CHECK: %[[LOAD_3:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_2]]] : memref<?xi64>
+// CHECK: %[[ADDI_0:.*]] = arith.addi %[[LOAD_3]], %[[INDEX_CAST_0]] : i64
+// CHECK: memref.store %[[ADDI_0]], %[[TO_BUFFER_0]]{{\[}}%[[VAL_3]]] : memref<8xi64>
// CHECK: } else {
-// CHECK: scf.if %[[VAL_6]] {
-// CHECK: %[[VAL_25:.*]] = arith.index_cast %[[VAL_19]] : index to i64
-// CHECK: memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64>
+// CHECK: scf.if %[[CONSTANT_1]] {
+// CHECK: memref.store %[[INDEX_CAST_0]], %[[TO_BUFFER_0]]{{\[}}%[[VAL_3]]] : memref<8xi64>
// CHECK: } else {
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_18]], %[[VAL_5]] : index
-// CHECK: %[[VAL_27:.*]] = arith.select %[[VAL_21]], %[[VAL_26]], %[[VAL_18]] : index
-// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
-// CHECK: scf.yield %[[VAL_27]], %[[VAL_28]] : index, index
+// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_2]] : index
+// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_1]], %[[ADDI_1]], %[[VAL_2]] : index
+// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_2]] : index
+// CHECK: scf.yield %[[SELECT_0]], %[[ADDI_2]] : index, index
// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_30:.*]]#1 to %[[VAL_1]] step %[[VAL_1]] {
-// CHECK: %[[VAL_31:.*]] = affine.min #map(%[[VAL_1]], %[[VAL_29]]){{\[}}%[[VAL_1]]]
-// CHECK: %[[VAL_32:.*]] = vector.create_mask %[[VAL_31]] : vector<8xi1>
-// CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_29]] : index to vector<8xindex>
-// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_2]] : vector<8xindex>
-// CHECK: %[[VAL_35:.*]] = arith.index_cast %[[VAL_34]] : vector<8xindex> to vector<8xi64>
-// CHECK: vector.maskedstore %[[VAL_11]]{{\[}}%[[VAL_29]]], %[[VAL_32]], %[[VAL_35]] : memref<8xi64>, vector<8xi1>, vector<8xi64>
+// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_5:.*]]#1 to %[[CONSTANT_4]] step %[[CONSTANT_4]] {
+// CHECK: %[[MIN_0:.*]] = affine.min #{{.*}}(%[[CONSTANT_4]], %[[VAL_4]]){{\[}}%[[CONSTANT_4]]]
+// CHECK: %[[CREATE_MASK_0:.*]] = vector.create_mask %[[MIN_0]] : vector<8xi1>
+// CHECK: %[[BROADCAST_0:.*]] = vector.broadcast %[[VAL_4]] : index to vector<8xindex>
+// CHECK: %[[ADDI_3:.*]] = arith.addi %[[BROADCAST_0]], %[[CONSTANT_0]] : vector<8xindex>
+// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ADDI_3]] : vector<8xindex> to vector<8xi64>
+// CHECK: vector.maskedstore %[[TO_BUFFER_0]]{{\[}}%[[VAL_4]]], %[[CREATE_MASK_0]], %[[INDEX_CAST_1]] : memref<8xi64>, vector<8xi1>, vector<8xi64>
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_36:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<8xi64>
-// CHECK: return %[[VAL_36]] : tensor<8xi64>
+// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[TO_BUFFER_0]] : memref<8xi64> to tensor<8xi64>
+// CHECK: return %[[TO_TENSOR_0]] : tensor<8xi64>
// CHECK: }
func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
%init = tensor.empty() : tensor<8xi64>
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index b447094874d01..81b63b9152cec 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -163,16 +163,14 @@ func.func @down_propagate() -> i32 {
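-/// Check that operation definitions are NOT propagated up the dominance tree.
+/// Check that a pure operation defined inside a loop body is hoisted out of the
+/// loop and reused by later identical operations.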
// CHECK-LABEL: @up_propagate_for
func.func @up_propagate_for() -> i32 {
+ // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
// CHECK: affine.for {{.*}} = 0 to 4 {
- affine.for %i = 0 to 4 {
- // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+ affine.for %i = 0 to 4 {
// CHECK-NEXT: "foo"(%[[VAR_c1_i32_0]]) : (i32) -> ()
%0 = arith.constant 1 : i32
"foo"(%0) : (i32) -> ()
}
-
- // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
- // CHECK-NEXT: return %[[VAR_c1_i32]] : i32
+ // CHECK: return %[[VAR_c1_i32_0]] : i32
%1 = arith.constant 1 : i32
return %1 : i32
}
@@ -181,7 +179,8 @@ func.func @up_propagate_for() -> i32 {
// CHECK-LABEL: func @up_propagate
func.func @up_propagate() -> i32 {
- // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
+ // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+ // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
%0 = arith.constant 0 : i32
// CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
@@ -191,17 +190,15 @@ func.func @up_propagate() -> i32 {
cf.cond_br %cond, ^bb1, ^bb2(%0 : i32)
^bb1: // CHECK: ^bb1:
- // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
%1 = arith.constant 1 : i32
// CHECK-NEXT: cf.br ^bb2(%[[VAR_c1_i32]] : i32)
cf.br ^bb2(%1 : i32)
^bb2(%arg : i32): // CHECK: ^bb2
- // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
%2 = arith.constant 1 : i32
- // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32_0]] : i32
+ // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32]] : i32
%add = arith.addi %arg, %2 : i32
// CHECK-NEXT: return %[[VAR_1]] : i32
@@ -216,6 +213,7 @@ func.func @up_propagate() -> i32 {
func.func @up_propagate_region() -> i32 {
// CHECK-NEXT: {{.*}} "foo.region"
%0 = "foo.region"() ({
+ // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
// CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
// CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
// CHECK-NEXT: cf.cond_br
@@ -225,15 +223,13 @@ func.func @up_propagate_region() -> i32 {
cf.cond_br %true, ^bb1, ^bb2(%1 : i32)
^bb1: // CHECK: ^bb1:
- // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
// CHECK-NEXT: cf.br
%c1_i32 = arith.constant 1 : i32
cf.br ^bb2(%c1_i32 : i32)
^bb2(%arg : i32): // CHECK: ^bb2(%[[VAR_1:.*]]: i32):
- // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
- // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32_0]] : i32
+ // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32]] : i32
// CHECK-NEXT: "foo.yield"(%[[VAR_2]]) : (i32) -> ()
%c1_i32_0 = arith.constant 1 : i32
@@ -500,8 +496,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
return %r1, %r2 : tensor<5xf32>, tensor<5xf32>
}
// CHECK-LABEL: func @cse_multiple_regions
-// CHECK: %[[if:.*]] = scf.if {{.*}} {
-// CHECK: tensor.empty
+// CHECK: tensor.empty
+// CHECK: %[[if:.*]] = scf.if {{.*}}
// CHECK: scf.yield
// CHECK: } else {
// CHECK: scf.yield
More information about the Mlir-commits mailing list