[Mlir-commits] [mlir] [mlir][sparse] support sparsifying sparse kernels to sparse-iterator-based loop (PR #95858)

Peiming Liu llvmlistbot at llvm.org
Mon Jun 17 15:17:33 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/95858

>From 77b00da55fa46dafc6173d288de9168544936e2d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 17 Jun 2024 21:30:21 +0000
Subject: [PATCH 1/2] [mlir][sparse] sparsify sparse kernels to (experimental)
 sparse-iterator-based loop

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |   2 +-
 .../SparseTensor/IR/SparseTensorOps.td        |  28 +++-
 .../Dialect/SparseTensor/Pipelines/Passes.h   |  17 ++-
 .../Dialect/SparseTensor/Transforms/Passes.h  |   1 +
 .../Dialect/SparseTensor/Transforms/Passes.td |   4 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  38 +++++
 .../Transforms/SparseIterationToScf.cpp       |   3 +-
 .../Transforms/Sparsification.cpp             |   4 +-
 .../SparsificationAndBufferizationPass.cpp    |   6 +
 .../Transforms/Utils/LoopEmitter.cpp          | 137 ++++++++++++------
 .../Transforms/Utils/LoopEmitter.h            |   7 +
 .../fuse_sparse_pad_with_consumer.mlir        |  14 +-
 mlir/test/Dialect/SparseTensor/sparse_2d.mlir |  40 ++---
 mlir/test/Dialect/SparseTensor/sparse_3d.mlir |  34 ++---
 .../Dialect/SparseTensor/sparse_affine.mlir   |  30 ++--
 .../Dialect/SparseTensor/sparse_batch.mlir    |  22 +--
 .../SparseTensor/sparse_broadcast.mlir        |  12 +-
 .../Dialect/SparseTensor/sparse_concat.mlir   | 100 ++++++-------
 .../SparseTensor/sparse_fill_zero.mlir        |  20 +--
 .../Dialect/SparseTensor/sparse_fp_ops.mlir   |  88 +++++------
 .../Dialect/SparseTensor/sparse_fusion.mlir   |   6 +-
 .../Dialect/SparseTensor/sparse_int_ops.mlir  | 100 ++++++-------
 .../Dialect/SparseTensor/sparse_kernels.mlir  |  12 +-
 .../sparse_kernels_to_iterator.mlir           |  37 +++++
 .../SparseTensor/sparse_lower_inplace.mlir    |  26 ++--
 .../SparseTensor/sparse_matmul_codegen.mlir   |  24 +--
 mlir/test/Dialect/SparseTensor/sparse_nd.mlir |   2 +-
 .../test/Dialect/SparseTensor/sparse_out.mlir |  60 ++++----
 .../Dialect/SparseTensor/sparse_outbuf.mlir   |  34 ++---
 .../SparseTensor/sparse_parallel_reduce.mlir  |  10 +-
 .../Dialect/SparseTensor/sparse_perm.mlir     |   2 +-
 .../Dialect/SparseTensor/sparse_reshape.mlir  |  48 +++---
 .../Dialect/SparseTensor/sparse_sddmm.mlir    |   2 +-
 .../SparseTensor/sparse_sddmm_org.mlir        |  16 +-
 .../SparseTensor/sparse_tensor_reshape.mlir   |  12 +-
 .../SparseTensor/sparse_vector_chain.mlir     |  14 +-
 .../SparseTensor/sparse_vector_index.mlir     |  20 +--
 .../Dialect/SparseTensor/spy_sddmm_bsr.mlir   |  16 +-
 .../SparseTensor/vectorize_reduction.mlir     |  84 +++++------
 39 files changed, 658 insertions(+), 474 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 04a6386a199de..68ca036121520 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,7 +89,7 @@ class LevelSet {
     assert(i < 64);
     return (bits & (1 << i)) != 0;
   }
-
+  unsigned max() const { return 64 - llvm::countl_zero(bits); }
   unsigned count() const { return llvm::popcount(bits); }
   bool empty() const { return bits == 0; }
 };
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index b2089924291cd..f31df080d7811 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1493,6 +1493,10 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
       ```
   }];
 
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter,
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
+  let results = (outs AnySparseIterSpace:$extractedSpace);
 
   let extraClassDeclaration = [{
     std::pair<Level, Level> getLvlRange() {
@@ -1506,10 +1510,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
     }
   }];
 
-  let arguments = (ins AnySparseTensor:$tensor,
-                       Optional<AnySparseIterator>:$parentIter,
-                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
-  let results = (outs AnySparseIterSpace:$extractedSpace);
+  let builders = [
+    // Construct a 1-D iteration space covering levels [loLvl, loLvl + 1).
+    OpBuilder<(ins "Value":$tensor, "Value":$parentIter,
+                   "sparse_tensor::Level":$loLvl),
+    [{
+      build($_builder, $_state, tensor, parentIter, loLvl, loLvl + 1);
+    }]>,
+    // Construct a 1-D root iteration space (no parent iterator, level 0).
+    OpBuilder<(ins "Value":$tensor),
+    [{
+      build($_builder, $_state, tensor, nullptr, 0);
+    }]>
+  ];
+
   let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
                        "`->` qualified(type($extractedSpace))";
@@ -1594,6 +1608,12 @@ def IterateOp : SparseTensor_Op<"iterate",
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+  ];
+
   let extraClassDeclaration = [{
     unsigned getSpaceDim() {
       return getIterSpace().getType().getSpaceDim();
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 85bfee336f848..90021ffa7c380 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -51,6 +51,20 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
               mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
               "any-storage-any-loop",
               "Enable sparse parallelization for any storage and loop."))};
+  PassOptions::Option<mlir::SparseEmitStrategy> emitStrategy{
+      *this, "sparse-emit-strategy",
+      ::llvm::cl::desc(
+          "Emit functional code or interfaces (to debug) for sparse loops"),
+      ::llvm::cl::init(mlir::SparseEmitStrategy::kFunctional),
+      llvm::cl::values(
+          clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
+                     "Emit functional code (with scf.for/while)."),
+          clEnumValN(mlir::SparseEmitStrategy::kSparseIterator,
+                     "sparse-iterator",
+                     "Emit (experimental) loops (with sparse.iterate)."),
+          clEnumValN(
+              mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
+              "Emit non-functional but easy-to-read interfaces to debug."))};
 
   PassOptions::Option<bool> enableRuntimeLibrary{
       *this, "enable-runtime-library",
@@ -143,7 +157,8 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization, enableRuntimeLibrary);
+    return SparsificationOptions(parallelization, emitStrategy,
+                                 enableRuntimeLibrary);
   }
 
   /// Projects out the options for `createConvertVectorToLLVMPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c9164e39a3a75..2edb9a61d1876 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -51,6 +51,7 @@ enum class ReinterpretMapScope {
 /// Defines a scope for reinterpret map pass.
 enum class SparseEmitStrategy {
   kFunctional,     // generate fully inlined (and functional) sparse iteration
+  kSparseIterator, // generate (experimental) loop using sparse iterator.
   kDebugInterface, // generate only place-holder for sparse iteration
 };
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index b18c975105b75..7173cde67384b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -163,7 +163,9 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
            "mlir::SparseEmitStrategy::kFunctional",
            "Emit functional code or interfaces (to debug) for sparse loops", [{llvm::cl::values(
              clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
-                        "Emit functional code."),
+                        "Emit functional code (with scf.for/while)."),
+             clEnumValN(mlir::SparseEmitStrategy::kSparseIterator, "sparse-iterator",
+                        "Emit (experimental) loops (with sparse.iterate)."),
              clEnumValN(mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
                         "Emit non-functional but easy-to-read interfaces to debug."))}]>,
     Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ac711769ed2ea..504888be4c85f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2300,6 +2300,41 @@ void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
   results.add<RemoveUnusedLvlCrds>(context);
 }
 
+void IterateOp::build(OpBuilder &builder, OperationState &odsState,
+                      Value iterSpace, ValueRange initArgs) {
+  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
+  // All ones; shift in 64 bits since `1 << rank` overflows int for rank >= 31.
+  LevelSet set(rank >= 64 ? ~uint64_t(0) : (uint64_t(1) << rank) - 1);
+  return build(builder, odsState, iterSpace, initArgs, set);
+}
+
+void IterateOp::build(OpBuilder &builder, OperationState &odsState,
+                      Value iterSpace, ValueRange initArgs,
+                      LevelSet crdUsedLvls) {
+  OpBuilder::InsertionGuard guard(builder); // Restored after createBlock.
+
+  odsState.addOperands(iterSpace);
+  odsState.addOperands(initArgs);
+  odsState.getOrAddProperties<Properties>().crdUsedLvls =
+      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
+  Region *bodyRegion = odsState.addRegion();
+  odsState.addTypes(initArgs.getTypes());
+  Block *bodyBlock = builder.createBlock(bodyRegion);
+
+  // 1st arg: the sparse iterator (the iteration space's iterator type).
+  bodyBlock->addArgument(
+      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
+      odsState.location);
+
+  // Followed by one index coordinate per level set in `crdUsedLvls`.
+  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
+    bodyBlock->addArgument(builder.getIndexType(), odsState.location);
+
+  // Followed by a list of user-provided loop arguments.
+  for (Value v : initArgs)
+    bodyBlock->addArgument(v.getType(), v.getLoc());
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;
@@ -2384,6 +2419,9 @@ LogicalResult IterateOp::verify() {
     return emitOpError(
         "mismatch in number of loop-carried values and defined values");
   }
+  if (getCrdUsedLvls().max() > getSpaceDim())
+    return emitOpError("required out-of-bound coordinates");
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 4224925147c84..1d614b7b29361 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -164,7 +164,6 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
       // replace sparse_tensor.yield with scf.yield.
       rewriter.eraseOp(yieldOp);
       rewriter.create<scf::YieldOp>(loc, yields);
-
       const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
       rewriter.replaceOp(
           op, whileOp.getResults().drop_front(it->getCursor().size()),
@@ -192,6 +191,8 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
+  // Also include IterateOp's canonicalizations (e.g. RemoveUnusedLvlCrds).
+  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
       converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0c8e431d8c996..e8f85aab887ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1071,7 +1071,9 @@ static bool getAllTidLvlsInLatPoints(
   }
   // If we just need to one loop conditions and the conditions is not imposed on
   // non-unique level, the loop can be generated by a for loop.
-  return numloopCond == 1 && !hasNonUnique;
+  return numloopCond == 1 &&
+         (!hasNonUnique || env.options().sparseEmitStrategy ==
+                               SparseEmitStrategy::kSparseIterator);
 }
 
 /// Starts a loop sequence at given level. Returns true if
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 13c750e83d045..3ae6732f900fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -159,6 +159,12 @@ class SparsificationAndBufferizationPass
         pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
       pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
       pm.addPass(createSparsificationPass(sparsificationOptions));
+      if (sparsificationOptions.sparseEmitStrategy ==
+          SparseEmitStrategy::kSparseIterator) {
+        pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass());
+        pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass());
+      }
+
       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
       pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                    /*enableConvert=*/true));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index fe0e515a2d180..2be0193f0de83 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -139,6 +139,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->valBuffer.assign(numTensors, nullptr);
   this->lvls.resize(numTensors);
   this->iters.resize(numTensors);
+  this->spIterVals.resize(numTensors);
 
   // These zeros will be overwritten below, but we need to initialize
   // them to something since we'll need random-access assignment.
@@ -173,6 +174,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
 
     lvls[tid].resize(lvlRank);
     iters[tid].resize(lvlRank);
+    spIterVals[tid].resize(lvlRank);
     loopHighs.assign(numLoops, nullptr);
 
     // Slice-driven loops related initialization.
@@ -229,6 +231,57 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
 void LoopEmitter::initializeLoopEmit(
     OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
     LoopEmitter::SynTensorBoundSetter synSetter) {
+
+  // For every manifest tensor, set up the values buffer.
+  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
+       t++) {
+    // TODO: this should be done through a folding pass after switching to
+    // `sparse_tensor.iterate`-based sparsification.
+    const Value tensor = tryFoldTensors(tensors[t]);
+    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
+    // Skips only scalar, zero ranked tensor still need to be bufferized and
+    // (probably) filled with zeros by users.
+    if (!rtp)
+      continue;
+
+    auto stt = getSparseTensorType(tensor);
+    const auto shape = rtp.getShape();
+
+    // Perform the required bufferization. Dense inputs materialize from the
+    // input tensors. Sparse inputs use sparse primitives to obtain the values.
+    // Delegates extra output initialization to clients.
+    bool isOutput = isOutputTensor(t);
+    Type elementType = stt.getElementType();
+    if (!stt.hasEncoding()) {
+      // Non-annotated dense tensors.
+      BaseMemRefType denseTp = MemRefType::get(shape, elementType);
+
+      // TODO: if we unconditionally use fully dynamic layout here, it breaks
+      // some vectorization passes which requires static stride = 1.
+      // Is it possible to call vectorization pass after bufferization?
+      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
+        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);
+
+      Value denseVal =
+          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
+      // Dense outputs need special handling.
+      if (isOutput && updater)
+        denseVal = updater(builder, loc, denseVal, tensor);
+
+      valBuffer[t] = denseVal;
+    } else {
+      // Annotated sparse tensors.
+      // We also need the value buffer for all-dense annotated "sparse"
+      // tensors.
+      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
+    }
+  }
+
+  // The sparse iterator values will only be available after the loop is
+  // constructed.
+  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
+    return;
+
   // For every synthetic tensor, set the high bound by calling the callback.
   if (synSetter) {
     TensorId synId = getSynTensorId();
@@ -241,7 +294,6 @@ void LoopEmitter::initializeLoopEmit(
   }
 
   // For every manifest tensor:
-  // * get the values buffer.
   // * For every level:
   //   * get the positions and coordinates buffers
   //   * get/compute the level-size, which is also used as the upper-bound
@@ -256,12 +308,9 @@ void LoopEmitter::initializeLoopEmit(
       // Skips only scalar, zero ranked tensor still need to be bufferized and
       // (probably) filled with zeros by users.
       continue;
-    // FIXME: the definition of `lvlRank` looks more like a dim-rank;
-    // but the variable is used as a level everywhere below, which
-    // suggests there may be some dim/lvl confusion going on here.
+
     auto stt = getSparseTensorType(tensor);
     const Level lvlRank = stt.getLvlRank();
-    const auto shape = rtp.getShape();
 
     // Scan all levels of current tensor.
     for (Level l = 0; l < lvlRank; l++) {
@@ -273,36 +322,6 @@ void LoopEmitter::initializeLoopEmit(
       auto it = makeLevelIterator(builder, loc, t, l);
       iters[t][l].emplace_back(std::move(it));
     }
-
-    // Perform the required bufferization. Dense inputs materialize
-    // from the input tensors. Sparse inputs use sparse primitives to obtain the
-    // values.
-    // Delegates extra output initialization to clients.
-    bool isOutput = isOutputTensor(t);
-    Type elementType = stt.getElementType();
-    if (!stt.hasEncoding()) {
-      // Non-annotated dense tensors.
-      BaseMemRefType denseTp = MemRefType::get(shape, elementType);
-
-      // TODO: if we unconditionally use fully dynamic layout here, it breaks
-      // some vectorization passes which requires static stride = 1.
-      // Is it possible to call vectorization pass after bufferization?
-      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
-        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);
-
-      Value denseVal =
-          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
-      // Dense outputs need special handling.
-      if (isOutput && updater)
-        denseVal = updater(builder, loc, denseVal, tensor);
-
-      valBuffer[t] = denseVal;
-    } else {
-      // Annotated sparse tensors.
-      // We also need the value buffer for all-dense annotated "sparse"
-      // tensors.
-      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
-    }
     // NOTE: we can also prepare for 0 lvl here in advance, this will hoist
     // some loop preparation from tensor iteration, but will also (undesirably)
     // hoist the code ouside if-conditions.
@@ -396,11 +415,13 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                   ArrayRef<TensorLevel> tidLvls) {
   // TODO: sort
   assert(loopSeqStack.size() == loopStack.size());
-  // Prepares for all the tensors used in the current loop sequence.
 
-  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
-    levelReducedDep[tid][lvl]++;
-    prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
+    // Prepares for all the tensors used in the current loop sequence.
+    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+      levelReducedDep[tid][lvl]++;
+      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+    }
   }
 
   // Universal Index starts from 0.
@@ -598,6 +619,31 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
     OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
     MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {
 
+  // TODO: handle coiteration with sparse iterator.
+  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
+    assert(tidLvls.size() == 1);
+    auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
+    Value t = tensors[tid];
+
+    // Extract the space (under the parent iterator if lvl > 0); iterate it.
+    ExtractIterSpaceOp extractSpaceOp =
+        lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
+                 : builder.create<ExtractIterSpaceOp>(
+                       loc, t, spIterVals[tid][lvl - 1], lvl);
+
+    IterateOp iterOp = builder.create<IterateOp>(
+        loc, extractSpaceOp.getExtractedSpace(), reduc);
+    spIterVals[tid][lvl] = iterOp.getIterator();
+
+    // Update the reduction variables to the loop-carried block arguments.
+    llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
+    // Set the insertion point to the start of the loop body.
+    builder.setInsertionPointToStart(iterOp.getBody());
+    loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
+                           iterOp.getIterator(), loopTag);
+    return iterOp;
+  }
+
   // TODO: support multiple return on parallel for?
   tryParallel = tryParallel && reduc.size() <= 1;
 
@@ -685,6 +731,16 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
 void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                               MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
+  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
+    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
+    assert(reduc.size() == iterateOp.getNumResults());
+    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
+    // Exit the loop.
+    rewriter.setInsertionPointAfter(iterateOp);
+    // In-place update reduction variables.
+    llvm::copy(iterateOp.getResults(), reduc.begin());
+    return;
+  }
   if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
     if (!reduc.empty()) {
       assert(reduc.size() == forOp.getNumResults());
@@ -693,8 +749,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
     // Exit the loop.
     rewriter.setInsertionPointAfter(forOp);
     // In-place update reduction variables.
-    for (unsigned i = 0, e = forOp.getResults().size(); i < e; i++)
-      reduc[i] = forOp.getResult(i);
+    llvm::copy(forOp.getResults(), reduc.begin());
   } else {
     auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
     if (!reduc.empty()) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 34312df912997..2a884b10e36b0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -418,6 +418,13 @@ class LoopEmitter {
   // Loop Sequence Stack, stores the universal index for the current loop
   // sequence. and a list of tid level that the loop sequence traverse.
   std::vector<std::pair<Value, std::vector<TensorLevel>>> loopSeqStack;
+
+  //
+  // EXPERIMENTAL:
+  // Fields for generating sparse-iterator-based loops.
+  //
+  // spIterVals[tid][lvl]: iterator value of the loop over (tid, lvl).
+  std::vector<std::vector<Value>> spIterVals;
 };
 
 } // namespace sparse_tensor
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
index 4f509bf747ab6..0ef143a1a2f38 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
@@ -25,13 +25,13 @@
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<8x8xf32>
-// CHECK:           %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_9]] : tensor<8x8xf32>) -> tensor<8x8xf32>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
-// CHECK:           %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_10]] : memref<8x8xf32>
-// CHECK:           linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
+// CHECK-DAG:       %[[VAL_9:.*]] = tensor.empty() : tensor<8x8xf32>
+// CHECK-DAG:       %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_9]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_10]] : memref<8x8xf32>
+// CHECK-DAG:       linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
 // CHECK:             %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_7]] : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 4afa0a8ceccd4..06670ab096fcd 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -139,7 +139,7 @@ func.func @mul_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
 // CHECK:             %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_7]] : index
@@ -204,7 +204,7 @@ func.func @add_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xi1>
-// CHECK:           linalg.fill ins(%[[VAL_5]] : i1) outs(%[[VAL_14]] : memref<32x16xi1>)
+// CHECK-DAG:       linalg.fill ins(%[[VAL_5]] : i1) outs(%[[VAL_14]] : memref<32x16xi1>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<?xindex>
 // CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_7]] : index
@@ -267,7 +267,7 @@ func.func @cmp_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_11]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_11]] : memref<32x16xf32>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
 // CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
@@ -308,7 +308,7 @@ func.func @mul_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]]:2 = scf.while (%[[VAL_17:.*]] = %[[VAL_14]], %[[VAL_18:.*]] = %[[VAL_6]]) : (index, index) -> (index, index) {
@@ -376,9 +376,9 @@ func.func @add_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK:           %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xi1>
-// CHECK:           linalg.fill ins(%[[VAL_5]] : i1) outs(%[[VAL_14]] : memref<32x16xi1>)
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xi1>
+// CHECK-DAG:       linalg.fill ins(%[[VAL_5]] : i1) outs(%[[VAL_14]] : memref<32x16xi1>)
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]]:2 = scf.while (%[[VAL_18:.*]] = %[[VAL_15]], %[[VAL_19:.*]] = %[[VAL_6]]) : (index, index) -> (index, index) {
@@ -446,7 +446,7 @@ func.func @cmp_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_11]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_11]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_5]] {
@@ -490,7 +490,7 @@ func.func @mul_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_15]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_15]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]]:2 = scf.while (%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_6]]) : (index, index) -> (index, index) {
@@ -584,9 +584,9 @@ func.func @add_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK:           %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xi1>
-// CHECK:           linalg.fill ins(%[[VAL_5]] : i1) outs(%[[VAL_16]] : memref<32x16xi1>)
+// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xi1>
+// CHECK-DAG:       linalg.fill ins(%[[VAL_5]] : i1) outs(%[[VAL_16]] : memref<32x16xi1>)
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]]:2 = scf.while (%[[VAL_20:.*]] = %[[VAL_17]], %[[VAL_21:.*]] = %[[VAL_6]]) : (index, index) -> (index, index) {
@@ -681,7 +681,7 @@ func.func @cmp_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_12]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_12]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_4]] {
@@ -727,7 +727,7 @@ func.func @mul_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_16]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_16]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -892,7 +892,7 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
 // CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:           %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xi1>
-// CHECK:           linalg.fill ins(%[[VAL_3]] : i1) outs(%[[VAL_17]] : memref<32x16xi1>)
+// CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : i1) outs(%[[VAL_17]] : memref<32x16xi1>)
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -1167,7 +1167,7 @@ func.func @sub_ss_batched(%0: tensor<2x3xf64, #BatchedVector>, %1: tensor<2x3xf6
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_16]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_16]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -1261,7 +1261,7 @@ func.func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_15]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_15]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]]:2 = scf.while (%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_5]]) : (index, index) -> (index, index) {
@@ -1363,7 +1363,7 @@ func.func @add_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #T
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
@@ -1512,7 +1512,7 @@ func.func @sum_reduction(%arga: tensor<10x20xf32, #Tds>, %argx: tensor<f32>) ->
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse{{[0-9]*}}> to memref<?xf64>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse{{[0-9]*}}>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf64>
-// CHECK:           linalg.fill ins(%{{.*}} : f64) outs(%[[VAL_11]] : memref<?x?xf64>)
+// CHECK-DAG:       linalg.fill ins(%{{.*}} : f64) outs(%[[VAL_11]] : memref<?x?xf64>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
 // CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_12]]] : memref<?xindex>
 // CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index
@@ -1634,7 +1634,7 @@ func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_18:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index b2f528fc7a25e..427a5c3d03a73 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -35,7 +35,7 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
 // CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
@@ -77,7 +77,7 @@ func.func @add_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
 // CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
@@ -122,7 +122,7 @@ func.func @mul_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_9]] {
 // CHECK:             %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_5]] : index
 // CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_9]] {
@@ -189,7 +189,7 @@ func.func @add_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
 // CHECK:             %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_5]] : index
 // CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
@@ -236,7 +236,7 @@ func.func @mul_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_8]] {
 // CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xindex>
 // CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_8]] : index
@@ -307,7 +307,7 @@ func.func @add_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
 // CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : index
@@ -356,7 +356,7 @@ func.func @mul_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_18:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_9]] {
 // CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xindex>
 // CHECK:             %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_9]] : index
@@ -452,7 +452,7 @@ func.func @add_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
 // CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
 // CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_6]] : index
@@ -501,7 +501,7 @@ func.func @mul_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]]:2 = scf.while (%[[VAL_18:.*]] = %[[VAL_15]], %[[VAL_19:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
@@ -577,7 +577,7 @@ func.func @add_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_6]] {
@@ -627,7 +627,7 @@ func.func @mul_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
 // CHECK:           %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_18]], %[[VAL_22:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
@@ -728,7 +728,7 @@ func.func @add_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_6]] {
@@ -780,7 +780,7 @@ func.func @mul_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_16]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_16]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]]:2 = scf.while (%[[VAL_20:.*]] = %[[VAL_17]], %[[VAL_21:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
@@ -885,7 +885,7 @@ func.func @add_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
@@ -939,7 +939,7 @@ func.func @mul_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_19:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_19]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_19]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
 // CHECK:           %[[VAL_22:.*]]:2 = scf.while (%[[VAL_23:.*]] = %[[VAL_20]], %[[VAL_24:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
@@ -1069,7 +1069,7 @@ func.func @add_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_18:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_5]] {
@@ -1308,7 +1308,7 @@ func.func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<30xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_3]] : memref<10x20x30xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<10x20x30xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<10x20x30xf32>)
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
 // CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<?xf32>
 // CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index 2128ca7539fa0..1ec6fb586d434 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -114,7 +114,7 @@ func.func @mul_inv_enc_dense1d(%arga: tensor<32xf32, #EncDenseVec>,
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi32, #sparse{{[0-9]*}}>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<34xi32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : i32) outs(%[[VAL_11]] : memref<32xi32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : i32) outs(%[[VAL_11]] : memref<32xi32>)
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_4]] {
@@ -217,13 +217,13 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:           %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_5]] {
@@ -281,13 +281,13 @@ func.func @mul_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>,
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:           %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<34x16xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x19xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_6]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_batch.mlir b/mlir/test/Dialect/SparseTensor/sparse_batch.mlir
index f6d2d0d4f7669..f158fc6108a13 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_batch.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_batch.mlir
@@ -5,17 +5,17 @@
 
 // CHECK-LABEL:   func.func @main(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x4x2xf32, #sparse{{[0-9]*}}>) -> tensor<8x4x2xf32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_4:.*]] = arith.constant 4 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 8 : index
-// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<8x4x2xf32>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xf32>
-// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_6]] : memref<8x4x2xf32>
-// CHECK:           linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_10]] : memref<8x4x2xf32>)
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.empty() : tensor<8x4x2xf32>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xindex>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xf32>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_6]] : memref<8x4x2xf32>
+// CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_10]] : memref<8x4x2xf32>)
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_1]] {
 // CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_1]] {
 // CHECK:               %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<8x?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
index a409329700ffd..f1cf90b2c06b2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
@@ -16,12 +16,12 @@
 //   CHECK-DAG:  %[[TMP_c3:.*]] = arith.constant 3 : index
 //   CHECK-DAG:  %[[TMP_c0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:  %[[TMP_c1:.*]] = arith.constant 1 : index
-//       CHECK:  %[[TMP_0:.*]] = tensor.empty()
-//       CHECK:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index}
-//       CHECK:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index}
-//       CHECK:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index}
-//       CHECK:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index}
-//       CHECK:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]]
+//   CHECK-DAG:  %[[TMP_0:.*]] = tensor.empty()
+//   CHECK-DAG:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index}
+//   CHECK-DAG:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index}
+//   CHECK-DAG:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index}
+//   CHECK-DAG:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index}
+//   CHECK-DAG:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]]
 //       CHECK:  %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
 //       CHECK:  %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
 //       CHECK:  %[[T:.*]] = scf.for %[[TMP_arg1:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] {{.*}} {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index bbf0b7c7c341d..5fd47971ea3be 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -14,12 +14,12 @@
 //   CHECK-DAG:  %[[TMP_c1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:  %[[TMP_c5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:  %[[TMP_c2:.*]] = arith.constant 2 : index
-//       CHECK:  %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<9x4xf64, #sparse>
-//       CHECK:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
-//       CHECK:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
-//       CHECK:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
-//       CHECK:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
-//       CHECK:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<9x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse>
 //       CHECK:  %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
 //       CHECK:  %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
 //       CHECK:  %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
@@ -35,11 +35,11 @@
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_4]]
 //       CHECK:  }
-//       CHECK:  %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse>
 //       CHECK:  %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
 //       CHECK:  %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
 //       CHECK:  %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
@@ -56,11 +56,11 @@
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_5]]
 //       CHECK:  }
-//       CHECK:  %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
-//       CHECK:  %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
-//       CHECK:  %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
-//       CHECK:  %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
-//       CHECK:  %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse>
 //       CHECK:  %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
 //       CHECK:  %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
 //       CHECK:  %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
@@ -100,12 +100,12 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
 //   CHECK-DAG:  %[[TMP_c2:.*]] = arith.constant 2 : index
 //   CHECK-DAG:  %[[TMP_c9:.*]] = arith.constant 9 : index
 //   CHECK-DAG:  %[[TMP_c4:.*]] = arith.constant 4 : index
-//       CHECK:  %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse>
-//       CHECK:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
-//       CHECK:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
-//       CHECK:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
-//       CHECK:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
-//       CHECK:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse>
 //       CHECK:  %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
 //       CHECK:  %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
 //       CHECK:  %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]])
@@ -121,12 +121,12 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_4]]
 //       CHECK:  }
-//       CHECK:  %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse>
-//       CHECK:  %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
+//   CHECK-DAG:  %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
 //       CHECK:  %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
 //       CHECK:  %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]])
 //       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
@@ -142,11 +142,11 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_5]]
 //       CHECK:  }
-//       CHECK:  %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
-//       CHECK:  %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
-//       CHECK:  %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
-//       CHECK:  %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
-//       CHECK:  %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
+//   CHECK-DAG:  %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse>
 //       CHECK:  %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
 //       CHECK:  %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
 //       CHECK:  %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]])
@@ -187,13 +187,13 @@ func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_10:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_3]]) : tensor<?x?xf64>
-// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : f64) outs(%[[VAL_10]] : tensor<?x?xf64>) -> tensor<?x?xf64>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_3]]) : tensor<?x?xf64>
+// CHECK-DAG:       %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : f64) outs(%[[VAL_10]] : tensor<?x?xf64>) -> tensor<?x?xf64>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<2x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x4xf64, #sparse>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_8]] iter_args(%[[VAL_21:.*]] = %[[VAL_11]]) -> (tensor<?x?xf64>) {
@@ -209,11 +209,11 @@ func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
 // CHECK:             }
 // CHECK:             scf.yield %[[VAL_26]] : tensor<?x?xf64>
 // CHECK:           }
-// CHECK:           %[[VAL_32:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
-// CHECK:           %[[VAL_33:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
-// CHECK:           %[[VAL_34:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
-// CHECK:           %[[VAL_35:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
-// CHECK:           %[[VAL_36:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_32:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_33:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<3x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_34:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_35:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<3x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_36:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x4xf64, #sparse>
 // CHECK:           %[[VAL_37:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_38:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_39:.*]] = scf.for %[[VAL_40:.*]] = %[[VAL_37]] to %[[VAL_38]] step %[[VAL_8]] iter_args(%[[VAL_41:.*]] = %[[VAL_19]]) -> (tensor<?x?xf64>) {
@@ -230,11 +230,11 @@ func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
 // CHECK:             }
 // CHECK:             scf.yield %[[VAL_46]] : tensor<?x?xf64>
 // CHECK:           }
-// CHECK:           %[[VAL_53:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
-// CHECK:           %[[VAL_54:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
-// CHECK:           %[[VAL_55:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
-// CHECK:           %[[VAL_56:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
-// CHECK:           %[[VAL_57:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<4x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_53:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_54:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 0 : index} : tensor<4x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_55:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_56:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<4x4xf64, #sparse>
+// CHECK-DAG:       %[[VAL_57:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<4x4xf64, #sparse>
 // CHECK:           %[[VAL_58:.*]] = memref.load %[[VAL_53]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_59:.*]] = memref.load %[[VAL_53]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_58]] to %[[VAL_59]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_39]]) -> (tensor<?x?xf64>) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index df3e4b0ed60c7..c26ba56347299 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -37,16 +37,16 @@
 // CHECK:           %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<300xindex> to memref<?xindex>
 // CHECK:           linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_20]] : memref<300xf64>)
 // CHECK:           linalg.fill ins(%[[VAL_7]] : i1) outs(%[[VAL_22]] : memref<300xi1>)
-// CHECK:           %[[VAL_26:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK:           %[[VAL_27:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK:           %[[VAL_28:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK:           %[[VAL_29:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK:           %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
-// CHECK:           %[[VAL_31:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK:           %[[VAL_32:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK:           %[[VAL_33:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK:           %[[VAL_34:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK:           %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK-DAG:       %[[VAL_26:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-DAG:       %[[VAL_27:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-DAG:       %[[VAL_28:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-DAG:       %[[VAL_29:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-DAG:       %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK-DAG:       %[[VAL_31:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-DAG:       %[[VAL_32:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-DAG:       %[[VAL_33:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-DAG:       %[[VAL_34:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-DAG:       %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr) -> memref<?xf64>
 // CHECK:           %[[VAL_36:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_37:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_38:.*]] = %[[VAL_36]] to %[[VAL_37]] step %[[VAL_6]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index 07b2c3c22995d..b6c7b771394b1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -35,10 +35,10 @@
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:         %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:         %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:     %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
 // CHECK:         %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:         %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:         scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
@@ -67,10 +67,10 @@ func.func @abs(%arga: tensor<32xf64, #SV>,
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:         %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:         %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:     %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
 // CHECK:         %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:         %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:         scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
@@ -99,10 +99,10 @@ func.func @ceil(%arga: tensor<32xf64, #SV>,
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:         %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:         %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:     %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
 // CHECK:         %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:         %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:         scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
@@ -131,10 +131,10 @@ func.func @floor(%arga: tensor<32xf64, #SV>,
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:         %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:         %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:     %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
 // CHECK:         %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:         %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:         scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
@@ -166,11 +166,11 @@ func.func @neg(%arga: tensor<32xf64, #SV>,
 // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_5:.*]] = arith.constant true
 // CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK:         %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
-// CHECK:         %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK:         %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:         %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:         %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
@@ -226,11 +226,11 @@ func.func @add(%arga: tensor<32xf64, #SV>,
 // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_5:.*]] = arith.constant true
 // CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK:         %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:         %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:         %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
-// CHECK:         %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:     %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK:         %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:         %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:         %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
@@ -286,11 +286,11 @@ func.func @sub(%arga: tensor<32xf64, #SV>,
 // CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:         %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
-// CHECK:         %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK:         %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:         %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:         scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
@@ -322,10 +322,10 @@ func.func @mul(%arga: tensor<32xf64, #SV>,
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 2.000000e+00 : f64
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:         %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:         %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
+// CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:     %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf64>
 // CHECK:         %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:         %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:         scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
@@ -354,10 +354,10 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xf64, #sparse{{[0-9]*}}>) -> tensor<32xf64, #sparse{{[0-9]*}}> {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<32xf64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_3:.*]] = tensor.empty() : tensor<32xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
 // CHECK:           %[[VAL_7:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[T:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_2]] {{.*}} {
@@ -401,11 +401,11 @@ func.func @zero_preserving_math(%arga: tensor<32xf64, #SV>) -> tensor<32xf64, #S
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xcomplex<f64>, #sparse{{.*}}> {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_3:.*]] = complex.constant [0.000000e+00, 1.000000e+00] : complex<f64>
-// CHECK:           %[[VAL_4:.*]] = tensor.empty() : tensor<32xcomplex<f64>, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xcomplex<f64>, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xcomplex<f64>, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xcomplex<f64>, #sparse{{[0-9]*}}> to memref<?xcomplex<f64>>
+// CHECK-DAG:       %[[VAL_3:.*]] = complex.constant [0.000000e+00, 1.000000e+00] : complex<f64>
+// CHECK-DAG:       %[[VAL_4:.*]] = tensor.empty() : tensor<32xcomplex<f64>, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xcomplex<f64>, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xcomplex<f64>, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xcomplex<f64>, #sparse{{[0-9]*}}> to memref<?xcomplex<f64>>
 // CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[T:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_2]] {{.*}} {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
index 2cc64434a1d8f..50f21416f5a74 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
@@ -26,9 +26,9 @@
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<100xf64, #sparse> to memref<?xf64>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_8]] : memref<100xf64>
-// CHECK:           linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_12]] : memref<100xf64>)
-// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK-DAG:           linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_12]] : memref<100xf64>)
+// CHECK-DAG:           %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK-DAG:           %[[VAL_14:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_3]]) : (index, index) -> (index, index) {
 // CHECK:             %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
 // CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
index 868b11c91fe35..8fa473b5a9dba 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -30,11 +30,11 @@
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_5:.*]] = arith.constant true
 // CHECK-DAG:           %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
@@ -91,11 +91,11 @@ func.func @add(%arga: tensor<32xi64, #SV>,
 // CHECK-DAG:           %[[VAL_5:.*]] = arith.constant true
 // CHECK-DAG:           %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK-DAG:           %[[VAL_7:.*]] = arith.constant 0 : i64
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
-// CHECK:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
@@ -151,11 +151,11 @@ func.func @sub(%arga: tensor<32xi64, #SV>,
 // CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
-// CHECK:           %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
@@ -187,10 +187,10 @@ func.func @mul(%arga: tensor<32xi64, #SV>,
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
@@ -221,10 +221,10 @@ func.func @divsbyc(%arga: tensor<32xi64, #SV>,
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
@@ -255,11 +255,11 @@ func.func @divubyc(%arga: tensor<32xi64, #SV>,
 // CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
-// CHECK:           %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK-DAG:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
@@ -293,11 +293,11 @@ func.func @and(%arga: tensor<32xi64, #SV>,
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_5:.*]] = arith.constant true
 // CHECK-DAG:           %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK-DAG:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
@@ -353,11 +353,11 @@ func.func @or(%arga: tensor<32xi64, #SV>,
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_5:.*]] = arith.constant true
 // CHECK-DAG:           %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK-DAG:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi64>
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
@@ -411,10 +411,10 @@ func.func @xor(%arga: tensor<32xi64, #SV>,
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK-DAG:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
@@ -445,10 +445,10 @@ func.func @ashrbyc(%arga: tensor<32xi64, #SV>,
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK-DAG:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
@@ -479,10 +479,10 @@ func.func @lsrbyc(%arga: tensor<32xi64, #SV>,
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
+// CHECK-DAG:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK-DAG:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xi64>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index 912d8bf5145e6..78e29979ca1ac 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -18,8 +18,8 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<10x20xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<10x20xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30xf32>
-// CHECK:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_5]] {
@@ -64,7 +64,7 @@ func.func @matmul1(%a: tensor<10x20xf32, #DCSR>,
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index}
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index}
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]]
-// CHECK:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
 // CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -209,7 +209,7 @@ func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<3x3xi32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<3x3xi32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x3xi32, #sparse{{[0-9]*}}> to memref<?xi32>
-// CHECK:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
 // CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:               %[[VAL_15:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] : memref<6x6xi32>
@@ -261,7 +261,7 @@ func.func @conv2d(%input:  tensor<8x8xi32>,
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<3x6xi8, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<3x6xi8, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x6xi8, #sparse{{[0-9]*}}> to memref<?xi8>
-// CHECK:           %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<5x6xi64>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<5x6xi64>
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -309,7 +309,7 @@ func.func @quantized_matmul(%input1: tensor<5x3xi8>,
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<1024xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<1024xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<1024xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<f32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<f32>
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_11]][] : memref<f32>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
new file mode 100644
index 0000000000000..35cf9aaf446b2
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s
+
+
+#COO = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (
+    d0 : compressed(nonunique),
+    d1 : singleton(nonunique, soa),
+    d2 : singleton(nonunique, soa),
+    d3 : singleton(soa)
+  ),
+  explicitVal = 1 : i32
+}>
+
+// CHECK-LABEL:   func.func @sqsum(
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse> to memref<?xindex>
+// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:           %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} {
+// CHECK:             %[[SUM:.*]] = arith.addi
+// CHECK:             scf.yield %[[SUM]] : i32
+// CHECK:           }
+// CHECK:           memref.store
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor
+// CHECK:           return %[[RET]] : tensor<i32>
+// CHECK:         }
+func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
+  %cst = arith.constant dense<0> : tensor<i32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} ins(%arg0 : tensor<?x?x?x?xi32, #COO>) outs(%cst : tensor<i32>) {
+  ^bb0(%in: i32, %out: i32):
+    %1 = arith.muli %in, %in : i32
+    %2 = arith.addi %out, %1 : i32
+    linalg.yield %2 : i32
+  } -> tensor<i32>
+  return %0 : tensor<i32>
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
index a827360abb426..773c5677eea55 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
@@ -26,11 +26,11 @@
 // CHECK-HIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-HIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-HIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}>
-// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}>
-// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}>
-// CHECK-HIR:           %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
-// CHECK-HIR:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
+// CHECK-HIR-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}>
+// CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}>
+// CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}>
+// CHECK-HIR-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
+// CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-HIR:           scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-HIR-DAG:         %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
 // CHECK-HIR-DAG:         %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index
@@ -57,11 +57,11 @@
 // CHECK-MIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-MIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-MIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-MIR:           %[[VAL_6:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-MIR:           %[[VAL_7:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-MIR:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
-// CHECK-MIR:           %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
-// CHECK-MIR:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
+// CHECK-MIR-DAG:       %[[VAL_6:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-MIR-DAG:       %[[VAL_7:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK-MIR-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
+// CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-MIR:           scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-MIR-DAG:         %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
 // CHECK-MIR-DAG:         %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index
@@ -88,9 +88,9 @@
 // CHECK-LIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-LIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-LIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-LIR:           %[[VAL_6:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-LIR:           %[[VAL_7:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
-// CHECK-LIR:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
+// CHECK-LIR-DAG:       %[[VAL_6:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-LIR-DAG:       %[[VAL_7:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
+// CHECK-LIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
 // CHECK-LIR:           scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-LIR-DAG:         %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xindex>
 // CHECK-LIR-DAG:         %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_5]] : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index ad12b637d0c52..0362ab4607528 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -44,18 +44,18 @@
 // CHECK:           %[[VAL_33:.*]] = memref.cast %[[VAL_32]] : memref<4xindex> to memref<?xindex>
 // CHECK:           linalg.fill ins(%[[VAL_8]] : f64) outs(%[[VAL_30]] : memref<4xf64>)
 // CHECK:           linalg.fill ins(%[[VAL_10]] : i1) outs(%[[VAL_31]] : memref<4xi1>)
-// CHECK:           %[[VAL_34:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  pos_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_35:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_34]]] [1] : memref<?xindex> to memref<?xindex>
-// CHECK:           %[[VAL_36:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  crd_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_37:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_36]]] [1] : memref<?xindex> to memref<?xindex>
-// CHECK:           %[[VAL_38:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  val_mem_sz : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_39:.*]] = memref.subview %[[VAL_2]][0] {{\[}}%[[VAL_38]]] [1] : memref<?xf64> to memref<?xf64>
-// CHECK:           %[[VAL_40:.*]] = sparse_tensor.storage_specifier.get %[[VAL_7]]  pos_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_41:.*]] = memref.subview %[[VAL_4]][0] {{\[}}%[[VAL_40]]] [1] : memref<?xindex> to memref<?xindex>
-// CHECK:           %[[VAL_42:.*]] = sparse_tensor.storage_specifier.get %[[VAL_7]]  crd_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_43:.*]] = memref.subview %[[VAL_5]][0] {{\[}}%[[VAL_42]]] [1] : memref<?xindex> to memref<?xindex>
-// CHECK:           %[[VAL_44:.*]] = sparse_tensor.storage_specifier.get %[[VAL_7]]  val_mem_sz : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_45:.*]] = memref.subview %[[VAL_6]][0] {{\[}}%[[VAL_44]]] [1] : memref<?xf64> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_34:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  pos_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK-DAG:       %[[VAL_35:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_34]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_36:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  crd_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK-DAG:       %[[VAL_37:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_36]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_38:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  val_mem_sz : !sparse_tensor.storage_specifier
+// CHECK-DAG:       %[[VAL_39:.*]] = memref.subview %[[VAL_2]][0] {{\[}}%[[VAL_38]]] [1] : memref<?xf64> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_40:.*]] = sparse_tensor.storage_specifier.get %[[VAL_7]]  pos_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK-DAG:       %[[VAL_41:.*]] = memref.subview %[[VAL_4]][0] {{\[}}%[[VAL_40]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_42:.*]] = sparse_tensor.storage_specifier.get %[[VAL_7]]  crd_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK-DAG:       %[[VAL_43:.*]] = memref.subview %[[VAL_5]][0] {{\[}}%[[VAL_42]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_44:.*]] = sparse_tensor.storage_specifier.get %[[VAL_7]]  val_mem_sz : !sparse_tensor.storage_specifier
+// CHECK-DAG:       %[[VAL_45:.*]] = memref.subview %[[VAL_6]][0] {{\[}}%[[VAL_44]]] [1] : memref<?xf64> to memref<?xf64>
 // CHECK:           %[[VAL_46:.*]]:4 = scf.for %[[VAL_47:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_11]] iter_args(%[[VAL_48:.*]] = %[[VAL_27]], %[[VAL_49:.*]] = %[[VAL_17]], %[[VAL_50:.*]] = %[[VAL_19]], %[[VAL_51:.*]] = %[[VAL_29]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK:             %[[VAL_52:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_47]]] : memref<?xindex>
 // CHECK:             %[[VAL_53:.*]] = arith.addi %[[VAL_47]], %[[VAL_11]] : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
index 5b77591c1c08d..2ac36fa6d8996 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
@@ -42,7 +42,7 @@
 // CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 4 : index} : tensor<80x70x60x50x40x30x20x10xf32, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_20:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x20x30x40x50x60x70x80xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_20]] : memref<10x20x30x40x50x60x70x80xf32>
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_20]] : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK:           scf.for %[[VAL_21:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_12]] {
 // CHECK:             %[[VAL_23:.*]] = arith.muli %[[VAL_21]], %[[VAL_9]] : index
 // CHECK:             scf.for %[[VAL_22:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_12]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 08b81b54a9e63..4dff06b8155dd 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -153,23 +153,23 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[VAL_FALSE:.*]] = arith.constant false
 // CHECK-DAG:       %[[VAL_TRUE:.*]] = arith.constant true
-// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_7:.*]] = tensor.empty(%[[VAL_5]], %[[VAL_6]]) : tensor<?x?xi32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_19:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_20:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_21:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_7:.*]] = tensor.empty(%[[VAL_5]], %[[VAL_6]]) : tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_18:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_19:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_20:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_21:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
 // CHECK:           %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_24:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_2]]] : memref<?xindex>
@@ -316,19 +316,19 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant false
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant true
-// CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_8:.*]] = tensor.empty(%[[VAL_6]], %[[VAL_7]]) : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_8:.*]] = tensor.empty(%[[VAL_6]], %[[VAL_7]]) : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_21:.*]] = scf.for %[[VAL_22:.*]] = %[[VAL_19]] to %[[VAL_20]] step %[[VAL_3]] iter_args(%[[VAL_23:.*]] = %[[VAL_8]]) -> (tensor<?x?xf32, #sparse{{[0-9]*}}>) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir b/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
index 1028b58be37df..5b453e9a736a2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
@@ -16,11 +16,11 @@
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<10xf32>
-// CHECK:           linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<10xf32>
+// CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
@@ -49,12 +49,12 @@ func.func @allout_inplace(%arga: tensor<10xi32, #SV>,
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = tensor.empty() : tensor<10xf32>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_4]] : memref<10xf32>
-// CHECK:           linalg.fill ins(%[[VAL_2]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
+// CHECK-DAG:       %[[VAL_4:.*]] = tensor.empty() : tensor<10xf32>
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_4]] : memref<10xf32>
+// CHECK-DAG:       linalg.fill ins(%[[VAL_2]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
@@ -83,12 +83,12 @@ func.func @allout_materialize(%arga: tensor<10xi32, #SV>) -> tensor<10xf32> {
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xf32>) -> tensor<10xf32> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xf32, #{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xf32, #{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #{{.*}}> to memref<?xf32>
-// CHECK:           %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<10xf32>
-// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xf32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xf32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #{{.*}}> to memref<?xf32>
+// CHECK-DAG:       %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<10xf32>
+// CHECK-DAG:       %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
 // CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
 // CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
index 61b50bcd7d0c6..44a551464c860 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
@@ -21,11 +21,11 @@
 //   CHECK-DAG:  %[[TMP_c16:.*]] = arith.constant 16 : index
 //   CHECK-DAG:  %[[TMP_c0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:  %[[TMP_c1:.*]] = arith.constant 1 : index
-//       CHECK:  %[[TMP_0:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index}
-//       CHECK:  %[[TMP_1:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index}
-//       CHECK:  %[[TMP_2:.*]] = sparse_tensor.values %[[TMP_arg0]]
-//       CHECK:  %[[TMP_3:.*]] = bufferization.to_memref %[[TMP_arg1]] : memref<32xf32>
-//       CHECK:  %[[TMP_4:.*]] = bufferization.to_memref %[[TMP_arg2]] : memref<16xf32>
+//   CHECK-DAG:  %[[TMP_0:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index}
+//   CHECK-DAG:  %[[TMP_1:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index}
+//   CHECK-DAG:  %[[TMP_2:.*]] = sparse_tensor.values %[[TMP_arg0]]
+//   CHECK-DAG:  %[[TMP_3:.*]] = bufferization.to_memref %[[TMP_arg1]] : memref<32xf32>
+//   CHECK-DAG:  %[[TMP_4:.*]] = bufferization.to_memref %[[TMP_arg2]] : memref<16xf32>
 //       CHECK:  scf.parallel (%[[TMP_arg3:.*]]) = (%[[TMP_c0]]) to (%[[TMP_c16]]) step (%[[TMP_c1]]) {
 //       CHECK:    %[[TMP_6:.*]] = memref.load %[[TMP_4]][%[[TMP_arg3]]] : memref<16xf32>
 //       CHECK:    %[[TMP_7:.*]] = memref.load %[[TMP_0]][%[[TMP_arg3]]] : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
index 173c69a969218..07c273fcddc3b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
@@ -65,7 +65,7 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?x?xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref<?x?x?xf32>)
+// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref<?x?x?xf32>)
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
 // CHECK:             %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_8]] : index
 // CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 492dcd05dc909..846dd8560a998 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -20,10 +20,10 @@
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[B:.*]] = bufferization.alloc_tensor()
-// CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
-// CHECK:         %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
-// CHECK:         %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
 // CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
 // CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
 // CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R:.*]] = %[[B]])
@@ -57,12 +57,12 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[B:.*]] = bufferization.alloc_tensor()
-// CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
-// CHECK:         %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
-// CHECK:         %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
-// CHECK:         %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
-// CHECK:         %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
+// CHECK-DAG:     %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
+// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
 // CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
 // CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
 // CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[A0:.*]] = %[[B]])
@@ -103,12 +103,12 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[SD:.*]] = sparse_tensor.lvl %[[S]], %[[C0]]
-// CHECK:         %[[DD0:.*]] = arith.divui %[[SD]], %[[C10]] : index
-// CHECK:         %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
-// CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
-// CHECK:         %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
-// CHECK:         %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-DAG:     %[[SD:.*]] = sparse_tensor.lvl %[[S]], %[[C0]]
+// CHECK-DAG:     %[[DD0:.*]] = arith.divui %[[SD]], %[[C10]] : index
+// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
+// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
 // CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
 // CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
 // CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R:.*]] = %[[B]])
@@ -146,14 +146,14 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>, %sz0: inde
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[SD1:.*]] = sparse_tensor.lvl %[[S]], %[[C1]]
-// CHECK:         %[[DD0:.*]] = arith.muli %[[SD1]], %[[C10]] : index
-// CHECK:         %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
-// CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
-// CHECK:         %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
-// CHECK:         %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
-// CHECK:         %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
-// CHECK:         %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-DAG:     %[[SD1:.*]] = sparse_tensor.lvl %[[S]], %[[C1]]
+// CHECK-DAG:     %[[DD0:.*]] = arith.muli %[[SD1]], %[[C10]] : index
+// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
+// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
+// CHECK-DAG:     %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
+// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
 // CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
 // CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
 // CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R0:.*]] = %[[B]])
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index fcd69bea426d6..a03b97684a7a4 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -71,7 +71,7 @@ func.func @fold_yield_direct_zero() -> tensor<32xf64> {
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:           %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_8]] : memref<8x8xf64>
+// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_8]] : memref<8x8xf64>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_5]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
index 5fa332f8f1819..a66028c61b22f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
@@ -29,14 +29,14 @@
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant false
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
-// CHECK:           %[[VAL_8:.*]] = tensor.empty() : tensor<8x8xf64, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
-// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_8:.*]] = tensor.empty() : tensor<8x8xf64, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xf64>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_8]]) -> (tensor<8x8xf64, #sparse{{[0-9]*}}>) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
index 89826ebfe14d1..ad4f8e08b3e3d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
@@ -9,12 +9,12 @@
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[B:.*]] = bufferization.alloc_tensor()
-// CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
-// CHECK:         %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
-// CHECK:         %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
-// CHECK:         %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
-// CHECK:         %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
+// CHECK-DAG:     %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
+// CHECK-DAG:     %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
+// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
 // CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
 // CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
 // CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[A0:.*]] = %[[B]])
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index e3508f11cb758..281e7858ce25e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -25,13 +25,13 @@
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 64 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK:           %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f64>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f64>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_14]][] : memref<f64>
 // CHECK:           %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
 // CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
index c9d432924c0db..ac357d9dc2485 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
@@ -25,11 +25,11 @@
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = tensor.empty() : tensor<8xi64>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_7]] : memref<8xi64>
-// CHECK:           linalg.fill ins(%[[VAL_4]] : i64) outs(%[[VAL_11]] : memref<8xi64>)
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_7]] : memref<8xi64>
+// CHECK-DAG:       linalg.fill ins(%[[VAL_4]] : i64) outs(%[[VAL_11]] : memref<8xi64>)
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_1]] {
@@ -67,11 +67,11 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_7:.*]] = tensor.empty() : tensor<8xi64>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_7]] : memref<8xi64>
-// CHECK:           linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_11]] : memref<8xi64>)
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_7]] : memref<8xi64>
+// CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_11]] : memref<8xi64>)
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir
index eac834b946c2e..10a7ac5802ec9 100755
--- a/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir
+++ b/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir
@@ -35,14 +35,14 @@
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<?x?xf32, #[[$BSR]]> to tensor<?x?x2x2xf32, #[[$MAP]]>
-// CHECK:           %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf32>
-// CHECK:           %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf32>
-// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf32>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.lvl %[[VAL_7]], %[[VAL_4]] : tensor<?x?x2x2xf32, #[[$MAP]]>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_7]] {level = 1 : index} : tensor<?x?x2x2xf32, #[[$MAP]]> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_7]] {level = 1 : index} : tensor<?x?x2x2xf32, #[[$MAP]]> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_7]] : tensor<?x?x2x2xf32, #[[$MAP]]> to memref<?xf32>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<?x?xf32, #[[$BSR]]> to tensor<?x?x2x2xf32, #[[$MAP]]>
+// CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf32>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.lvl %[[VAL_7]], %[[VAL_4]] : tensor<?x?x2x2xf32, #[[$MAP]]>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_7]] {level = 1 : index} : tensor<?x?x2x2xf32, #[[$MAP]]> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_7]] {level = 1 : index} : tensor<?x?x2x2xf32, #[[$MAP]]> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_7]] : tensor<?x?x2x2xf32, #[[$MAP]]> to memref<?xf32>
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_4]] to %[[VAL_11]] step %[[VAL_3]] {
 // CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]]] : memref<?xindex>
 // CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_3]] : index
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index d2b51322e0be4..578e86a793f90 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -14,9 +14,9 @@
 // CHECK-ON-DAG:       %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi13>
 // CHECK-ON-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-ON-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-ON:           %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-ON:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xi13>
-// CHECK-ON:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i13>
+// CHECK-ON-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-ON-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xi13>
+// CHECK-ON-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i13>
 // CHECK-ON:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i13>
 // CHECK-ON:           %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -40,9 +40,9 @@
 // CHECK-OFF-SAME:      %[[VAL_1:.*]]: tensor<?xi13, #sparse{{[0-9]*}}>) -> tensor<i13> {
 // CHECK-OFF-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-OFF-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-OFF:           %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-OFF:           %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xi13>
-// CHECK-OFF:           %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i13>
+// CHECK-OFF-DAG:       %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-OFF-DAG:       %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xi13>
+// CHECK-OFF-DAG:       %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i13>
 // CHECK-OFF:           %[[VAL_7:.*]] = memref.load %[[VAL_6]][] : memref<i13>
 // CHECK-OFF:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK-OFF:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -91,9 +91,9 @@ func.func @sparse_reduction_ori(%argx: tensor<i13>,
 // CHECK-ON-DAG:       %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi13>
 // CHECK-ON-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-ON-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-ON:           %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-ON:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xi13>
-// CHECK-ON:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i13>
+// CHECK-ON-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-ON-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xi13>
+// CHECK-ON-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i13>
 // CHECK-ON:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i13>
 // CHECK-ON:           %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -117,9 +117,9 @@ func.func @sparse_reduction_ori(%argx: tensor<i13>,
 // CHECK-OFF-SAME:      %[[VAL_1:.*]]: tensor<?xi13, #sparse{{[0-9]*}}>) -> tensor<i13> {
 // CHECK-OFF-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-OFF-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-OFF:           %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-OFF:           %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xi13>
-// CHECK-OFF:           %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i13>
+// CHECK-OFF-DAG:       %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-OFF-DAG:       %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi13, #sparse{{[0-9]*}}> to memref<?xi13>
+// CHECK-OFF-DAG:       %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i13>
 // CHECK-OFF:           %[[VAL_7:.*]] = memref.load %[[VAL_6]][] : memref<i13>
 // CHECK-OFF:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK-OFF:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -166,9 +166,9 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
 // CHECK-ON-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-ON-DAG:       %[[VAL_4:.*]] = arith.constant dense<0> : vector<8xi32>
 // CHECK-ON-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-ON:           %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-ON:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
-// CHECK-ON:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
+// CHECK-ON-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-ON-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
+// CHECK-ON-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
 // CHECK-ON:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:           %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-ON:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -192,9 +192,9 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
 // CHECK-OFF-SAME:      %[[VAL_1:.*]]: tensor<?xi32, #sparse{{[0-9]*}}>) -> tensor<i32> {
 // CHECK-OFF-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-OFF-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-OFF:           %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-OFF:           %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
-// CHECK-OFF:           %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
+// CHECK-OFF-DAG:       %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-OFF-DAG:       %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
+// CHECK-OFF-DAG:       %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
 // CHECK-OFF:           %[[VAL_7:.*]] = memref.load %[[VAL_6]][] : memref<i32>
 // CHECK-OFF:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK-OFF:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -241,9 +241,9 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
 // CHECK-ON-DAG:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi32>
 // CHECK-ON-DAG:  %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-ON-DAG:  %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-ON:  %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-ON:  %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
-// CHECK-ON:  %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
+// CHECK-ON-DAG:  %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-ON-DAG:  %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
+// CHECK-ON-DAG:  %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
 // CHECK-ON:  %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:  %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:  %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -267,9 +267,9 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
 // CHECK-OFF-SAME:  %[[VAL_1:.*]]: tensor<?xi32, #sparse{{[0-9]*}}>) -> tensor<i32> {
 // CHECK-OFF-DAG:   %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-OFF-DAG:   %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-OFF:   %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-OFF:   %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
-// CHECK-OFF:   %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
+// CHECK-OFF-DAG:   %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-OFF-DAG:   %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
+// CHECK-OFF-DAG:   %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
 // CHECK-OFF:   %[[VAL_7:.*]] = memref.load %[[VAL_6]][] : memref<i32>
 // CHECK-OFF:   %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK-OFF:   %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -317,9 +317,9 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
 // CHECK-ON-DAG:   %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi32>
 // CHECK-ON-DAG:   %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-ON-DAG:   %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-ON:   %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-ON:   %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
-// CHECK-ON:   %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
+// CHECK-ON-DAG:   %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-ON-DAG:   %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
+// CHECK-ON-DAG:   %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -343,9 +343,9 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
 // CHECK-OFF-SAME:   %[[VAL_1:.*]]: tensor<?xi32, #sparse{{[0-9]*}}>) -> tensor<i32> {
 // CHECK-OFF-DAG:   %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-OFF-DAG:   %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-OFF:   %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-OFF:   %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
-// CHECK-OFF:   %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
+// CHECK-OFF-DAG:   %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-OFF-DAG:   %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xi32, #sparse{{[0-9]*}}> to memref<?xi32>
+// CHECK-OFF-DAG:   %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<i32>
 // CHECK-OFF:   %[[VAL_7:.*]] = memref.load %[[VAL_6]][] : memref<i32>
 // CHECK-OFF:   %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK-OFF:   %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -393,9 +393,9 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
 // CHECK-ON-DAG:   %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
 // CHECK-ON-DAG:   %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-ON-DAG:   %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-ON:   %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-ON:   %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK-ON:   %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f32>
+// CHECK-ON-DAG:   %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-ON-DAG:   %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK-ON-DAG:   %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f32>
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -419,9 +419,9 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
 // CHECK-OFF-SAME:   %[[VAL_1:.*]]: tensor<?xf32, #sparse{{[0-9]*}}>) -> tensor<f32> {
 // CHECK-OFF-DAG:   %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-OFF-DAG:   %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-OFF:   %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-OFF:   %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK-OFF:   %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f32>
+// CHECK-OFF-DAG:   %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-OFF-DAG:   %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK-OFF-DAG:   %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f32>
 // CHECK-OFF:   %[[VAL_7:.*]] = memref.load %[[VAL_6]][] : memref<f32>
 // CHECK-OFF:   %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK-OFF:   %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -469,9 +469,9 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
 // CHECK-ON-DAG:   %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
 // CHECK-ON-DAG:   %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-ON-DAG:   %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-ON:   %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-ON:   %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK-ON:   %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f32>
+// CHECK-ON-DAG:   %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-ON-DAG:   %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK-ON-DAG:   %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f32>
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -495,9 +495,9 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
 // CHECK-OFF-SAME:  %[[VAL_1:.*]]: tensor<?xf32, #sparse{{[0-9]*}}>) -> tensor<f32> {
 // CHECK-OFF-DAG:   %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-OFF-DAG:   %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-OFF:   %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-OFF:   %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK-OFF:   %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f32>
+// CHECK-OFF-DAG:   %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-OFF-DAG:   %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK-OFF-DAG:   %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f32>
 // CHECK-OFF:   %[[VAL_7:.*]] = memref.load %[[VAL_6]][] : memref<f32>
 // CHECK-OFF:   %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK-OFF:   %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>

>From 824cea11bdf6ea0c5f8fca82953219c8da693a24 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 17 Jun 2024 22:14:32 +0000
Subject: [PATCH 2/2] add tests

---
 .../sparse_kernels_to_iterator.mlir           |  8 +-
 .../CPU/iterator-based-sqsum.mlir             | 80 +++++++++++++++++++
 2 files changed, 87 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir

diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index 35cf9aaf446b2..f5bbea0d340fb 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -27,7 +27,13 @@
 // CHECK:         }
 func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
   %cst = arith.constant dense<0> : tensor<i32>
-  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} ins(%arg0 : tensor<?x?x?x?xi32, #COO>) outs(%cst : tensor<i32>) {
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> ()>
+    ],
+    iterator_types = ["reduction", "reduction", "reduction", "reduction"]
+  } ins(%arg0 : tensor<?x?x?x?xi32, #COO>) outs(%cst : tensor<i32>) {
   ^bb0(%in: i32, %out: i32):
     %1 = arith.muli %in, %in : i32
     %2 = arith.addi %out, %1 : i32
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir
new file mode 100644
index 0000000000000..758c780e10cdd
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir
@@ -0,0 +1,80 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+//  do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now do sparsification using sparse-iterator-based loops.
+// REDEFINE: %{sparsifier_opts} = sparse-emit-strategy=sparse-iterator
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#COO = #sparse_tensor.encoding<{  // compressed(nonunique) outer level followed by singleton inner levels
+  map = (d0, d1, d2, d3) -> (
+    d0 : compressed(nonunique),
+    d1 : singleton(nonunique, soa),
+    d2 : singleton(nonunique, soa),
+    d3 : singleton(soa)
+  ),
+  explicitVal = 1 : i32  // NOTE(review): stored values declared as 1, matching the inputs built in @main
+}>
+
+// Sum of squares over a 4-D COO tensor, run with both default and sparse-iterator loops.
+module {
+  // Sum of squares over all four (reduction) dimensions, returned as a 0-d tensor.
+  func.func @sqsum(%arg0: tensor<2x3x4x5xi32, #COO>) -> tensor<i32> {
+    %cst = arith.constant dense<0> : tensor<i32>
+    %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+        affine_map<(d0, d1, d2, d3) -> ()>
+      ],
+      iterator_types = ["reduction", "reduction", "reduction", "reduction"]
+    } ins(%arg0 : tensor<2x3x4x5xi32, #COO>) outs(%cst : tensor<i32>) {
+    ^bb0(%in: i32, %out: i32):
+      %1 = arith.muli %in, %in : i32
+      %2 = arith.addi %out, %1 : i32
+      linalg.yield %2 : i32
+    } -> tensor<i32>
+    return %0 : tensor<i32>
+  }
+  // Builds a 2x3x4x5 COO input with four stored 1s, so the expected sum of squares is 4.
+  func.func @main() {
+    %cst = arith.constant sparse<
+     [
+       [0, 1, 2, 3],
+       [1, 1, 2, 3],
+       [1, 2, 2, 3],
+       [1, 2, 3, 4]
+     ],
+     [1, 1, 1, 1]
+    > : tensor<2x3x4x5xi32>
+
+    %input = sparse_tensor.convert %cst : tensor<2x3x4x5xi32> to tensor<2x3x4x5xi32, #COO>
+    %0 = call @sqsum(%input) : (tensor<2x3x4x5xi32, #COO>) -> tensor<i32>
+    %v = tensor.extract %0[] : tensor<i32>
+
+    // CHECK: 4
+    vector.print %v : i32
+
+    // Deallocate the converted sparse input and the result tensor.
+    bufferization.dealloc_tensor %input : tensor<2x3x4x5xi32, #COO>
+    bufferization.dealloc_tensor %0 : tensor<i32>
+    return
+  }
+}



More information about the Mlir-commits mailing list