[Mlir-commits] [mlir] 7373cab - [mlir][sparse] implement full reduction "scalarization" across loop nests

Aart Bik llvmlistbot at llvm.org
Thu Nov 4 17:39:05 PDT 2021


Author: Aart Bik
Date: 2021-11-04T17:38:47-07:00
New Revision: 7373cabcda8f5c0ed83cf40034ff69bc47a4a8c9

URL: https://github.com/llvm/llvm-project/commit/7373cabcda8f5c0ed83cf40034ff69bc47a4a8c9
DIFF: https://github.com/llvm/llvm-project/commit/7373cabcda8f5c0ed83cf40034ff69bc47a4a8c9.diff

LOG: [mlir][sparse] implement full reduction "scalarization" across loop nests

The earlier reduction "scalarization" was only applied to a chain of
*innermost* *for*-loops. This revision generalizes this to any
nesting of for- and while-loops. This implies that reductions can be
implemented with far fewer load and store operations. The chaining
is implemented with a forest of yield statements (though not as
elaborate as it would be if the while-induction were included as well).

Fixes https://bugs.llvm.org/show_bug.cgi?id=52311

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D113078

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sparse_1d.mlir
    mlir/test/Dialect/SparseTensor/sparse_2d.mlir
    mlir/test/Dialect/SparseTensor/sparse_3d.mlir
    mlir/test/Dialect/SparseTensor/sparse_lower.mlir
    mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
    mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cfab38616d55f..f8db7eb00319a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -39,7 +39,7 @@ namespace {
 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
 
 // Reduction kinds.
-enum Reduction { kSum, kProduct, kAnd, kOr, kXor };
+enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
 
 // Code generation.
 struct CodeGen {
@@ -50,7 +50,7 @@ struct CodeGen {
         highs(numTensors, std::vector<Value>(numLoops)),
         pidxs(numTensors, std::vector<Value>(numLoops)),
         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
-        curVecLength(1), curVecMask() {}
+        redKind(kNoReduc), curVecLength(1), curVecMask() {}
   /// Sparsification options.
   SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
@@ -71,9 +71,7 @@ struct CodeGen {
   std::vector<std::vector<Value>> pidxs;
   std::vector<std::vector<Value>> idxs;
   /// Current reduction, updated during code generation. When indices of a
-  /// reduction are exhausted,  all inner loops can "scalarize" the reduction.
-  // TODO: currently only done for (a chain of) innermost for-loops, where it
-  // is most effective; we could generalize to more outer and while-loops.
+  /// reduction are exhausted, all inner loops can use a scalarized reduction.
   unsigned redExp;
   Value redVal;
   Reduction redKind;
@@ -314,12 +312,14 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
 }
 
 //===----------------------------------------------------------------------===//
-// Sparse compiler synthesis methods (statements and expressions).
+// Sparse compiler synthesis methods (reductions).
 //===----------------------------------------------------------------------===//
 
 /// Maps reduction kind to name encoding.
 static StringRef getReductionName(Reduction kind) {
   switch (kind) {
+  case kNoReduc:
+    break;
   case kSum:
     return "add";
   case kProduct:
@@ -356,13 +356,16 @@ static Reduction getReduction(Kind kind) {
   }
 }
 
-/// Generates an initial value for a vector reductions, following the scheme
+/// Generates an initial value for a vector reduction, following the scheme
 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
 /// initial scalar value is correctly embedded in the vector reduction value,
 /// and a straightforward horizontal reduction will complete the operation.
-static Value genReductionInit(PatternRewriter &rewriter, Location loc,
-                              Reduction kind, VectorType vtp, Value r) {
-  switch (kind) {
+static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
+                                Location loc, VectorType vtp) {
+  Value r = codegen.redVal;
+  switch (codegen.redKind) {
+  case kNoReduc:
+    break;
   case kSum:
   case kXor: {
     // Initialize reduction vector to: | 0 | .. | 0 | r |
@@ -390,6 +393,25 @@ static Value genReductionInit(PatternRewriter &rewriter, Location loc,
   llvm_unreachable("unknown reduction kind");
 }
 
+/// Generates final value for a vector reduction.
+static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
+                               Location loc, VectorType vtp) {
+  StringRef name = getReductionName(codegen.redKind);
+  StringAttr kind = rewriter.getStringAttr(name);
+  return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
+                                              codegen.redVal, ValueRange{});
+}
+
+/// Updates scalarized reduction value.
+static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
+  assert(codegen.redKind != kNoReduc);
+  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse compiler synthesis methods (statements and expressions).
+//===----------------------------------------------------------------------===//
+
 /// Maps sparse integer option to actual integral storage type.
 static Type genIntType(PatternRewriter &rewriter, unsigned width) {
   if (width == 0)
@@ -516,7 +538,7 @@ static VectorType vectorType(CodeGen &codegen, Value ptr) {
 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                            Value iv, Value lo, Value hi, Value step) {
   Location loc = iv.getLoc();
-  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
+  VectorType mtp = vectorType(codegen, genIntType(rewriter, 1));
   // Special case if the vector length evenly divides the trip count (for
   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
   // so that all subsequent masked memory operations are immediately folded
@@ -671,7 +693,7 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
     if (codegen.curVecLength > 1)
       rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
                                       codegen.redVal);
-    codegen.redVal = rhs;
+    updateReduc(merger, codegen, rhs);
     return;
   }
   // Actual store.
@@ -708,11 +730,11 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
     if (!etp.isa<IndexType>()) {
       if (etp.getIntOrFloatBitWidth() < 32)
         vload = rewriter.create<arith::ExtUIOp>(
-            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
+            loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
       else if (etp.getIntOrFloatBitWidth() < 64 &&
                !codegen.options.enableSIMDIndex32)
         vload = rewriter.create<arith::ExtUIOp>(
-            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
+            loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
     }
     return vload;
   }
@@ -723,8 +745,8 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
   if (!load.getType().isa<IndexType>()) {
     if (load.getType().getIntOrFloatBitWidth() < 64)
-      load = rewriter.create<arith::ExtUIOp>(loc, load,
-                                             rewriter.getIntegerType(64));
+      load =
+          rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
     load =
         rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
   }
@@ -752,43 +774,6 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
   return rewriter.create<arith::AddIOp>(loc, mul, i);
 }
 
-/// Generates start of a reduction.
-static Value genReductionStart(Merger &merger, CodeGen &codegen,
-                               PatternRewriter &rewriter,
-                               linalg::GenericOp op) {
-  if (codegen.redVal)
-    return codegen.redVal; // chained with previous for-loop
-  // Generate vector or scalar start of a reduction.
-  unsigned vl = codegen.curVecLength;
-  if (vl > 1) {
-    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
-    assert(!merger.exp(codegen.redExp).val);
-    codegen.curVecLength = 1;
-    Value load = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
-    codegen.curVecLength = vl;
-    return genReductionInit(rewriter, op.getLoc(), codegen.redKind, vtp, load);
-  }
-  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
-}
-
-/// Generates end of a reduction.
-static void genReductionEnd(Merger &merger, CodeGen &codegen,
-                            PatternRewriter &rewriter, linalg::GenericOp op) {
-  Value red = codegen.redVal;
-  if (!red)
-    return;
-  assert(codegen.curVecLength == 1);
-  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
-  // Generate vector or scalar end of a reduction.
-  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
-    StringRef name = getReductionName(codegen.redKind);
-    StringAttr kind = rewriter.getStringAttr(name);
-    red = rewriter.create<vector::ReductionOp>(
-        op.getLoc(), vtp.getElementType(), kind, red, ValueRange{});
-  }
-  genTensorStore(merger, codegen, rewriter, op, red);
-}
-
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
@@ -828,7 +813,7 @@ static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
-                          unsigned exp, unsigned ldx, bool hoist,
+                          unsigned exp, unsigned ldx, bool atStart,
                           Kind last = Kind::kTensor) {
   if (exp == -1u)
     return;
@@ -844,14 +829,27 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
         return; // still in play
     }
     // All exhausted at this level (atLevel denotes exactly at this level).
+    if (!atLevel)
+      return;
     OpOperand *lhs = op.getOutputOperand(0);
     if (lhs == t) {
-      codegen.redExp = hoist ? exp : -1u;
-      codegen.redKind = getReduction(last);
-      assert(!codegen.redVal);
-    } else if (atLevel) {
+      // Start or end a scalarized reduction.
+      if (atStart) {
+        Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
+        codegen.redKind = getReduction(last);
+        codegen.redExp = exp;
+        updateReduc(merger, codegen, load);
+      } else {
+        Value redVal = codegen.redVal;
+        updateReduc(merger, codegen, Value());
+        codegen.redExp = -1u;
+        codegen.redKind = kNoReduc;
+        genTensorStore(merger, codegen, rewriter, op, redVal);
+      }
+    } else {
+      // Start or end loop invariant hoisting of a tensor load.
       merger.exp(exp).val =
-          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
+          atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
     }
   } else if (merger.exp(exp).kind != Kind::kInvariant) {
     // Traverse into the binary operations. Note that we only hoist
@@ -860,8 +858,8 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
     Kind last = merger.exp(exp).kind;
     unsigned e0 = merger.exp(exp).children.e0;
     unsigned e1 = merger.exp(exp).children.e1;
-    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist, last);
-    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist, last);
+    genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
+    genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
   }
 }
 
@@ -1005,18 +1003,20 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
     return parOp;
   }
 
-  // Emit a sequential loop, potentially with a scalarized reduction.
-  bool scalarRed = isInner && codegen.redExp != -1u;
+  // Emit a sequential or vector loop.
   SmallVector<Value, 4> operands;
-  if (scalarRed) {
-    Value load = genReductionStart(merger, codegen, rewriter, op);
-    operands.push_back(load);
+  if (codegen.redVal) {
+    // In a vector loop, bring reduction into SIMD form, if not already.
+    if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
+      VectorType vtp = vectorType(codegen, codegen.redVal.getType());
+      Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
+      updateReduc(merger, codegen, vred);
+    }
+    operands.push_back(codegen.redVal);
   }
   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
-  if (scalarRed) {
-    codegen.redVal = merger.exp(codegen.redExp).val =
-        forOp.getRegionIterArgs().front();
-  }
+  if (codegen.redVal)
+    updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
   // Assign induction variable to sparse or dense index.
   Value iv = forOp.getInductionVar();
   if (isSparse)
@@ -1044,17 +1044,18 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       types.push_back(indexType);
-      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
-             "type mismatch for sparse index");
       operands.push_back(codegen.pidxs[tensor][idx]);
     }
   }
+  if (codegen.redVal) {
+    types.push_back(codegen.redVal.getType());
+    operands.push_back(codegen.redVal);
+  }
   if (needsUniv) {
     types.push_back(indexType);
-    assert(codegen.loops[idx].getType().isa<IndexType>() &&
-           "type mismatch for universal index");
     operands.push_back(codegen.loops[idx]);
   }
+  assert(types.size() == operands.size());
   Location loc = op.getLoc();
   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
   Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
@@ -1077,6 +1078,8 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
       codegen.pidxs[tensor][idx] = after->getArgument(o++);
     }
   }
+  if (codegen.redVal)
+    updateReduc(merger, codegen, after->getArgument(o++));
   if (needsUniv)
     codegen.loops[idx] = after->getArgument(o++);
   assert(o == operands.size());
@@ -1098,7 +1101,6 @@ static Operation *genLoop(Merger &merger, CodeGen &codegen,
     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                   indices);
   }
-  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
 }
 
@@ -1163,8 +1165,24 @@ static void genLocals(Merger &merger, CodeGen &codegen,
 static void genWhileInduction(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, linalg::GenericOp op,
                               unsigned idx, bool needsUniv,
-                              llvm::BitVector &induction, ResultRange results) {
+                              llvm::BitVector &induction,
+                              scf::WhileOp whileOp) {
   Location loc = op.getLoc();
+  // Finalize each else branch of all if statements.
+  if (codegen.redVal) {
+    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
+               rewriter.getInsertionBlock()->getParentOp())) {
+      rewriter.create<scf::YieldOp>(loc, codegen.redVal);
+      updateReduc(merger, codegen, ifOp.getResult(0));
+      rewriter.setInsertionPointAfter(ifOp);
+    }
+  }
+  rewriter.setInsertionPointToEnd(&whileOp.after().front());
+  // Finalize the induction. Note that the induction could be performed
+  // in the individual if-branches to avoid re-evaluating the conditions.
+  // However, that would result in a rather elaborate forest of yield
+  // instructions during code generation. Moreover, performing the induction
+  // after the if-statements more closely resembles code generated by TACO.
   unsigned o = 0;
   SmallVector<Value, 4> operands;
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -1179,16 +1197,38 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
                                                  op1, op2);
       Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
       operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
-      codegen.pidxs[tensor][idx] = results[o++];
+      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
     }
   }
+  if (codegen.redVal) {
+    operands.push_back(codegen.redVal);
+    updateReduc(merger, codegen, whileOp->getResult(o++));
+  }
   if (needsUniv) {
     operands.push_back(
         rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
-    codegen.loops[idx] = results[o++];
+    codegen.loops[idx] = whileOp->getResult(o++);
   }
   assert(o == operands.size());
   rewriter.create<scf::YieldOp>(loc, operands);
+  rewriter.setInsertionPointAfter(whileOp);
+}
+
+/// Generates the induction structure for a for-loop.
+static void genForInduction(Merger &merger, CodeGen &codegen,
+                            PatternRewriter &rewriter, linalg::GenericOp op,
+                            Operation *loop) {
+  Location loc = op.getLoc();
+  unsigned o = 0;
+  SmallVector<Value, 4> operands;
+  if (codegen.redVal) {
+    operands.push_back(codegen.redVal);
+    updateReduc(merger, codegen, loop->getResult(o++));
+  }
+  assert(o == operands.size());
+  if (o > 0)
+    rewriter.create<scf::YieldOp>(loc, operands);
+  rewriter.setInsertionPointAfter(loop);
 }
 
 /// Generates a single if-statement within a while-loop.
@@ -1196,6 +1236,7 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                        PatternRewriter &rewriter, linalg::GenericOp op,
                        unsigned idx, llvm::BitVector &conditions) {
   Location loc = op.getLoc();
+  SmallVector<Type, 4> types;
   Value cond;
   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
     if (conditions[b]) {
@@ -1213,11 +1254,23 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
     }
   }
-  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
+  if (codegen.redVal)
+    types.push_back(codegen.redVal.getType());
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   return ifOp;
 }
 
+/// Generates end of true branch of if-statement within a while-loop.
+static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
+                  linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) {
+  if (codegen.redVal) {
+    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
+    updateReduc(merger, codegen, ifInput);
+  }
+  rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse compiler synthesis methods (loop sequence).
 //===----------------------------------------------------------------------===//
@@ -1230,14 +1283,16 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                          unsigned at, unsigned idx, unsigned ldx,
                          unsigned lts) {
   assert(codegen.curVecLength == 1);
+  assert(!codegen.loops[idx]);
   // Emit invariants at this loop sequence level.
-  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
+  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
   // Emit further initialization at this loop sequence level.
   unsigned l0 = merger.set(lts)[0];
-  if (genInit(merger, codegen, rewriter, op, topSort, at,
-              merger.lat(l0).bits)) {
-    // Maintain the universal index only if it is actually
-    // consumed by a subsequent lattice point.
+  bool needsUniv =
+      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
+  // Maintain the universal index only if it is actually
+  // consumed by a subsequent lattice point.
+  if (needsUniv) {
     unsigned lsize = merger.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = merger.set(lts)[i];
@@ -1270,16 +1325,12 @@ static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   codegen.curVecLength = 1;
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
-    rewriter.setInsertionPointToEnd(&whileOp.after().front());
     genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
-                      merger.lat(li).bits, whileOp.results());
+                      merger.lat(li).bits, whileOp);
     return needsUniv;
   }
   // End a for-loop.
-  if (codegen.redVal) {
-    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
-    codegen.redVal = loop->getResult(0);
-  }
+  genForInduction(merger, codegen, rewriter, op, loop);
   return false;
 }
 
@@ -1288,11 +1339,14 @@ static void endLoopSeq(Merger &merger, CodeGen &codegen,
                        PatternRewriter &rewriter, linalg::GenericOp op,
                        unsigned exp, unsigned idx, unsigned ldx) {
   assert(codegen.curVecLength == 1);
-  // Finalize any pending reduction.
-  genReductionEnd(merger, codegen, rewriter, op);
-  // Unmark bookkeeping of invariants and loop index.
-  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
   codegen.loops[idx] = Value();
+  // Bring a pending reduction back from SIMD form when sequence ends.
+  if (codegen.redVal)
+    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
+      updateReduc(merger, codegen,
+                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
+  // Unmark bookkeeping of invariants and loop index.
+  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -1327,6 +1381,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
+    Value ifInput = codegen.redVal;
     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
       unsigned lj = merger.set(lts)[j];
@@ -1337,7 +1392,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
           scf::IfOp ifOp =
               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
-          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
+          endIf(merger, codegen, rewriter, op, ifOp, ifInput);
         } else {
           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
         }
@@ -1347,7 +1402,6 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     // End a loop.
     needsUniv =
         endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
-    rewriter.setInsertionPointAfter(loop);
   }
 
   // End a loop sequence.
@@ -1426,18 +1480,19 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       return failure();
 
     // Builds the tensor expression for the Linalg operation in SSA form.
-    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
-    if (!exp.hasValue())
+    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
+    if (!optExp.hasValue())
       return failure();
+    unsigned exp = optExp.getValue();
 
     // Rejects an inadmissable tensor expression.
-    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
+    if (!isAdmissableTensorExp(merger, op, exp))
       return failure();
 
     // Recursively generates code.
     CodeGen codegen(options, numTensors, numLoops);
     genBuffers(merger, codegen, rewriter, op);
-    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
+    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
     genResult(merger, codegen, rewriter, op);
     return success();
   }

diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
index 01cbf09785396..7c5fc72c63288 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -842,24 +842,24 @@ func @two_way_inv_alt(%arga: tensor<16xf32, #SV>,
 }
 
 // CHECK-LABEL:   func @sum_reduction(
-// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_6:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
 // CHECK:           %[[VAL_7:.*]] = memref.alloc() : memref<f32>
 // CHECK:           memref.copy %[[VAL_6]], %[[VAL_7]] : memref<f32> to memref<f32>
-// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_7]][] : memref<f32>
+// CHECK-DAG:       %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = memref.load %[[VAL_7]][] : memref<f32>
 // CHECK:           %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) {
 // CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_12]]] : memref<?xf32>
 // CHECK:             %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
 // CHECK:             scf.yield %[[VAL_15]] : f32
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_16:.*]], %[[VAL_7]][] : memref<f32>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_7]][] : memref<f32>
 // CHECK:           %[[VAL_17:.*]] = memref.tensor_load %[[VAL_7]] : memref<f32>
 // CHECK:           return %[[VAL_17]] : tensor<f32>
 // CHECK:         }
@@ -885,11 +885,11 @@ func @sum_reduction(%arga: tensor<?xf32, #SV>, %argx: tensor<f32>) -> tensor<f32
 }
 
 // CHECK-LABEL:   func @sum_reduction_ss(
-// CHECK-SAME:                           %[[VAL_0:.*0]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                           %[[VAL_1:.*1]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                           %[[VAL_2:.*2]]: tensor<f32>) -> tensor<f32> {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: tensor<f32>) -> tensor<f32> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
@@ -899,71 +899,71 @@ func @sum_reduction(%arga: tensor<?xf32, #SV>, %argx: tensor<f32>) -> tensor<f32
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<f32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<f32>
 // CHECK:           memref.copy %[[VAL_11]], %[[VAL_12]] : memref<f32> to memref<f32>
-// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_17:.*]]:2 = scf.while (%[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_15]]) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_20:.*]] = arith.cmpi ult, %[[VAL_18]], %[[VAL_14]] : index
-// CHECK:             %[[VAL_21:.*]] = arith.cmpi ult, %[[VAL_19]], %[[VAL_16]] : index
-// CHECK:             %[[VAL_22:.*]] = arith.andi %[[VAL_20]], %[[VAL_21]] : i1
-// CHECK:             scf.condition(%[[VAL_22]]) %[[VAL_18]], %[[VAL_19]] : index, index
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_12]][] : memref<f32>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_18:.*]]:3 = scf.while (%[[VAL_19:.*]] = %[[VAL_14]], %[[VAL_20:.*]] = %[[VAL_16]], %[[VAL_21:.*]] = %[[VAL_13]]) : (index, index, f32) -> (index, index, f32) {
+// CHECK:             %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_19]], %[[VAL_15]] : index
+// CHECK:             %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_17]] : index
+// CHECK:             %[[VAL_24:.*]] = arith.andi %[[VAL_22]], %[[VAL_23]] : i1
+// CHECK:             scf.condition(%[[VAL_24]]) %[[VAL_19]], %[[VAL_20]], %[[VAL_21]] : index, index, f32
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index):
-// CHECK:             %[[VAL_25:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_23]]] : memref<?xindex>
-// CHECK:             %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref<?xindex>
-// CHECK:             %[[VAL_27:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_25]] : index
-// CHECK:             %[[VAL_28:.*]] = select %[[VAL_27]], %[[VAL_26]], %[[VAL_25]] : index
-// CHECK:             %[[VAL_29:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_28]] : index
-// CHECK:             %[[VAL_30:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_28]] : index
-// CHECK:             %[[VAL_31:.*]] = arith.andi %[[VAL_29]], %[[VAL_30]] : i1
-// CHECK:             scf.if %[[VAL_31]] {
-// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_12]][] : memref<f32>
-// CHECK:               %[[VAL_33:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_23]]] : memref<?xf32>
-// CHECK:               %[[VAL_34:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK:               %[[VAL_35:.*]] = arith.addf %[[VAL_33]], %[[VAL_34]] : f32
-// CHECK:               %[[VAL_36:.*]] = arith.addf %[[VAL_32]], %[[VAL_35]] : f32
-// CHECK:               memref.store %[[VAL_36]], %[[VAL_12]][] : memref<f32>
+// CHECK:           ^bb0(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index, %[[VAL_27:.*]]: f32):
+// CHECK:             %[[VAL_28:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK:             %[[VAL_29:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK:             %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
+// CHECK:             %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
+// CHECK:             %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// CHECK:             %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// CHECK:             %[[VAL_34:.*]] = arith.andi %[[VAL_32]], %[[VAL_33]] : i1
+// CHECK:             %[[VAL_35:.*]] = scf.if %[[VAL_34]] -> (f32) {
+// CHECK:               %[[VAL_36:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:               %[[VAL_37:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xf32>
+// CHECK:               %[[VAL_38:.*]] = arith.addf %[[VAL_36]], %[[VAL_37]] : f32
+// CHECK:               %[[VAL_39:.*]] = arith.addf %[[VAL_27]], %[[VAL_38]] : f32
+// CHECK:               scf.yield %[[VAL_39]] : f32
 // CHECK:             } else {
-// CHECK:               %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_28]] : index
-// CHECK:               scf.if %[[VAL_37]] {
-// CHECK:                 %[[VAL_38:.*]] = memref.load %[[VAL_12]][] : memref<f32>
-// CHECK:                 %[[VAL_39:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_23]]] : memref<?xf32>
-// CHECK:                 %[[VAL_40:.*]] = arith.addf %[[VAL_38]], %[[VAL_39]] : f32
-// CHECK:                 memref.store %[[VAL_40]], %[[VAL_12]][] : memref<f32>
+// CHECK:               %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// CHECK:               %[[VAL_41:.*]] = scf.if %[[VAL_40]] -> (f32) {
+// CHECK:                 %[[VAL_42:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:                 %[[VAL_43:.*]] = arith.addf %[[VAL_27]], %[[VAL_42]] : f32
+// CHECK:                 scf.yield %[[VAL_43]] : f32
 // CHECK:               } else {
-// CHECK:                 %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_28]] : index
-// CHECK:                 scf.if %[[VAL_41]] {
-// CHECK:                   %[[VAL_42:.*]] = memref.load %[[VAL_12]][] : memref<f32>
-// CHECK:                   %[[VAL_43:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK:                   %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
-// CHECK:                   memref.store %[[VAL_44]], %[[VAL_12]][] : memref<f32>
+// CHECK:                 %[[VAL_44:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// CHECK:                 %[[VAL_45:.*]] = scf.if %[[VAL_44]] -> (f32) {
+// CHECK:                   %[[VAL_46:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xf32>
+// CHECK:                   %[[VAL_47:.*]] = arith.addf %[[VAL_27]], %[[VAL_46]] : f32
+// CHECK:                   scf.yield %[[VAL_47]] : f32
 // CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_27]] : f32
 // CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_48:.*]] : f32
 // CHECK:               }
+// CHECK:               scf.yield %[[VAL_49:.*]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_28]] : index
-// CHECK:             %[[VAL_46:.*]] = arith.addi %[[VAL_23]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_47:.*]] = select %[[VAL_45]], %[[VAL_46]], %[[VAL_23]] : index
-// CHECK:             %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_28]] : index
-// CHECK:             %[[VAL_49:.*]] = arith.addi %[[VAL_24]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_50:.*]] = select %[[VAL_48]], %[[VAL_49]], %[[VAL_24]] : index
-// CHECK:             scf.yield %[[VAL_47]], %[[VAL_50]] : index, index
+// CHECK:             %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// CHECK:             %[[VAL_51:.*]] = arith.addi %[[VAL_25]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_52:.*]] = select %[[VAL_50]], %[[VAL_51]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// CHECK:             %[[VAL_54:.*]] = arith.addi %[[VAL_26]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_55:.*]] = select %[[VAL_53]], %[[VAL_54]], %[[VAL_26]] : index
+// CHECK:             scf.yield %[[VAL_52]], %[[VAL_55]], %[[VAL_56:.*]] : index, index, f32
 // CHECK:           }
-// CHECK:           %[[VAL_51:.*]] = memref.load %[[VAL_12]][] : memref<f32>
-// CHECK:           %[[VAL_52:.*]] = scf.for %[[VAL_53:.*]] = %[[VAL_54:.*]]#0 to %[[VAL_14]] step %[[VAL_4]] iter_args(%[[VAL_55:.*]] = %[[VAL_51]]) -> (f32) {
-// CHECK:             %[[VAL_56:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_53]]] : memref<?xf32>
-// CHECK:             %[[VAL_57:.*]] = arith.addf %[[VAL_55]], %[[VAL_56]] : f32
-// CHECK:             scf.yield %[[VAL_57]] : f32
+// CHECK:           %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_59:.*]]#0 to %[[VAL_15]] step %[[VAL_4]] iter_args(%[[VAL_60:.*]] = %[[VAL_59]]#2) -> (f32) {
+// CHECK:             %[[VAL_61:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_58]]] : memref<?xf32>
+// CHECK:             %[[VAL_62:.*]] = arith.addf %[[VAL_60]], %[[VAL_61]] : f32
+// CHECK:             scf.yield %[[VAL_62]] : f32
 // CHECK:           }
-// CHECK:           %[[VAL_58:.*]] = scf.for %[[VAL_59:.*]] = %[[VAL_60:.*]]#1 to %[[VAL_16]] step %[[VAL_4]] iter_args(%[[VAL_61:.*]] = %[[VAL_62:.*]]) -> (f32) {
-// CHECK:             %[[VAL_63:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_59]]] : memref<?xf32>
-// CHECK:             %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f32
-// CHECK:             scf.yield %[[VAL_64]] : f32
+// CHECK:           %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_65:.*]]#1 to %[[VAL_17]] step %[[VAL_4]] iter_args(%[[VAL_66:.*]] = %[[VAL_67:.*]]) -> (f32) {
+// CHECK:             %[[VAL_68:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_64]]] : memref<?xf32>
+// CHECK:             %[[VAL_69:.*]] = arith.addf %[[VAL_66]], %[[VAL_68]] : f32
+// CHECK:             scf.yield %[[VAL_69]] : f32
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_65:.*]], %[[VAL_12]][] : memref<f32>
-// CHECK:           %[[VAL_66:.*]] = memref.tensor_load %[[VAL_12]] : memref<f32>
-// CHECK:           return %[[VAL_66]] : tensor<f32>
+// CHECK:           memref.store %[[VAL_70:.*]], %[[VAL_12]][] : memref<f32>
+// CHECK:           %[[VAL_71:.*]] = memref.tensor_load %[[VAL_12]] : memref<f32>
+// CHECK:           return %[[VAL_71]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction_ss(%arga: tensor<16xf32, #SV>,
                        %argb: tensor<16xf32, #SV>,
@@ -993,12 +993,12 @@ func @sum_reduction_ss(%arga: tensor<16xf32, #SV>,
 }
 
 // CHECK-LABEL:   func @sum_reduction_inv(
-// CHECK-SAME:                            %[[VAL_0:.*0]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                            %[[VAL_1:.*1]]: tensor<f32>,
-// CHECK-SAME:                            %[[VAL_2:.*2]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                            %[[VAL_3:.*3]]: tensor<f32>) -> tensor<f32> {
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<f32>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_3:.*3]]: tensor<f32>) -> tensor<f32> {
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
@@ -1009,75 +1009,75 @@ func @sum_reduction_ss(%arga: tensor<16xf32, #SV>,
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_3]] : memref<f32>
 // CHECK:           %[[VAL_14:.*]] = memref.alloc() : memref<f32>
 // CHECK:           memref.copy %[[VAL_13]], %[[VAL_14]] : memref<f32> to memref<f32>
-// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_9]][] : memref<f32>
-// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_21]], %[[VAL_17]] : index
-// CHECK:             %[[VAL_24:.*]] = arith.cmpi ult, %[[VAL_22]], %[[VAL_19]] : index
-// CHECK:             %[[VAL_25:.*]] = arith.andi %[[VAL_23]], %[[VAL_24]] : i1
-// CHECK:             scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_14]][] : memref<f32>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_9]][] : memref<f32>
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_21:.*]]:3 = scf.while (%[[VAL_22:.*]] = %[[VAL_17]], %[[VAL_23:.*]] = %[[VAL_19]], %[[VAL_24:.*]] = %[[VAL_15]]) : (index, index, f32) -> (index, index, f32) {
+// CHECK:             %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_22]], %[[VAL_18]] : index
+// CHECK:             %[[VAL_26:.*]] = arith.cmpi ult, %[[VAL_23]], %[[VAL_20]] : index
+// CHECK:             %[[VAL_27:.*]] = arith.andi %[[VAL_25]], %[[VAL_26]] : i1
+// CHECK:             scf.condition(%[[VAL_27]]) %[[VAL_22]], %[[VAL_23]], %[[VAL_24]] : index, index, f32
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
-// CHECK:             %[[VAL_28:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_26]]] : memref<?xindex>
-// CHECK:             %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:             %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
-// CHECK:             %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
-// CHECK:             %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
-// CHECK:             %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
-// CHECK:             %[[VAL_34:.*]] = arith.andi %[[VAL_32]], %[[VAL_33]] : i1
-// CHECK:             scf.if %[[VAL_34]] {
-// CHECK:               %[[VAL_35:.*]] = memref.load %[[VAL_14]][] : memref<f32>
-// CHECK:               %[[VAL_36:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xf32>
-// CHECK:               %[[VAL_37:.*]] = arith.mulf %[[VAL_36]], %[[VAL_15]] : f32
-// CHECK:               %[[VAL_38:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xf32>
-// CHECK:               %[[VAL_39:.*]] = arith.addf %[[VAL_37]], %[[VAL_38]] : f32
-// CHECK:               %[[VAL_40:.*]] = arith.addf %[[VAL_35]], %[[VAL_39]] : f32
-// CHECK:               memref.store %[[VAL_40]], %[[VAL_14]][] : memref<f32>
+// CHECK:           ^bb0(%[[VAL_28:.*]]: index, %[[VAL_29:.*]]: index, %[[VAL_30:.*]]: f32):
+// CHECK:             %[[VAL_31:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK:             %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xindex>
+// CHECK:             %[[VAL_33:.*]] = arith.cmpi ult, %[[VAL_32]], %[[VAL_31]] : index
+// CHECK:             %[[VAL_34:.*]] = select %[[VAL_33]], %[[VAL_32]], %[[VAL_31]] : index
+// CHECK:             %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_31]], %[[VAL_34]] : index
+// CHECK:             %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_34]] : index
+// CHECK:             %[[VAL_37:.*]] = arith.andi %[[VAL_35]], %[[VAL_36]] : i1
+// CHECK:             %[[VAL_38:.*]] = scf.if %[[VAL_37]] -> (f32) {
+// CHECK:               %[[VAL_39:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_28]]] : memref<?xf32>
+// CHECK:               %[[VAL_40:.*]] = arith.mulf %[[VAL_39]], %[[VAL_16]] : f32
+// CHECK:               %[[VAL_41:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_29]]] : memref<?xf32>
+// CHECK:               %[[VAL_42:.*]] = arith.addf %[[VAL_40]], %[[VAL_41]] : f32
+// CHECK:               %[[VAL_43:.*]] = arith.addf %[[VAL_30]], %[[VAL_42]] : f32
+// CHECK:               scf.yield %[[VAL_43]] : f32
 // CHECK:             } else {
-// CHECK:               %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
-// CHECK:               scf.if %[[VAL_41]] {
-// CHECK:                 %[[VAL_42:.*]] = memref.load %[[VAL_14]][] : memref<f32>
-// CHECK:                 %[[VAL_43:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xf32>
-// CHECK:                 %[[VAL_44:.*]] = arith.mulf %[[VAL_43]], %[[VAL_15]] : f32
-// CHECK:                 %[[VAL_45:.*]] = arith.addf %[[VAL_42]], %[[VAL_44]] : f32
-// CHECK:                 memref.store %[[VAL_45]], %[[VAL_14]][] : memref<f32>
+// CHECK:               %[[VAL_44:.*]] = arith.cmpi eq, %[[VAL_31]], %[[VAL_34]] : index
+// CHECK:               %[[VAL_45:.*]] = scf.if %[[VAL_44]] -> (f32) {
+// CHECK:                 %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_28]]] : memref<?xf32>
+// CHECK:                 %[[VAL_47:.*]] = arith.mulf %[[VAL_46]], %[[VAL_16]] : f32
+// CHECK:                 %[[VAL_48:.*]] = arith.addf %[[VAL_30]], %[[VAL_47]] : f32
+// CHECK:                 scf.yield %[[VAL_48]] : f32
 // CHECK:               } else {
-// CHECK:                 %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
-// CHECK:                 scf.if %[[VAL_46]] {
-// CHECK:                   %[[VAL_47:.*]] = memref.load %[[VAL_14]][] : memref<f32>
-// CHECK:                   %[[VAL_48:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xf32>
-// CHECK:                   %[[VAL_49:.*]] = arith.addf %[[VAL_47]], %[[VAL_48]] : f32
-// CHECK:                   memref.store %[[VAL_49]], %[[VAL_14]][] : memref<f32>
+// CHECK:                 %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_34]] : index
+// CHECK:                 %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (f32) {
+// CHECK:                   %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_29]]] : memref<?xf32>
+// CHECK:                   %[[VAL_52:.*]] = arith.addf %[[VAL_30]], %[[VAL_51]] : f32
+// CHECK:                   scf.yield %[[VAL_52]] : f32
 // CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_30]] : f32
 // CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_53:.*]] : f32
 // CHECK:               }
+// CHECK:               scf.yield %[[VAL_54:.*]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
-// CHECK:             %[[VAL_51:.*]] = arith.addi %[[VAL_26]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_52:.*]] = select %[[VAL_50]], %[[VAL_51]], %[[VAL_26]] : index
-// CHECK:             %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
-// CHECK:             %[[VAL_54:.*]] = arith.addi %[[VAL_27]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_55:.*]] = select %[[VAL_53]], %[[VAL_54]], %[[VAL_27]] : index
-// CHECK:             scf.yield %[[VAL_52]], %[[VAL_55]] : index, index
+// CHECK:             %[[VAL_55:.*]] = arith.cmpi eq, %[[VAL_31]], %[[VAL_34]] : index
+// CHECK:             %[[VAL_56:.*]] = arith.addi %[[VAL_28]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_57:.*]] = select %[[VAL_55]], %[[VAL_56]], %[[VAL_28]] : index
+// CHECK:             %[[VAL_58:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_34]] : index
+// CHECK:             %[[VAL_59:.*]] = arith.addi %[[VAL_29]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_60:.*]] = select %[[VAL_58]], %[[VAL_59]], %[[VAL_29]] : index
+// CHECK:             scf.yield %[[VAL_57]], %[[VAL_60]], %[[VAL_61:.*]] : index, index, f32
 // CHECK:           }
-// CHECK:           %[[VAL_56:.*]] = memref.load %[[VAL_14]][] : memref<f32>
-// CHECK:           %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_59:.*]]#0 to %[[VAL_17]] step %[[VAL_5]] iter_args(%[[VAL_60:.*]] = %[[VAL_56]]) -> (f32) {
-// CHECK:             %[[VAL_61:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_58]]] : memref<?xf32>
-// CHECK:             %[[VAL_62:.*]] = arith.mulf %[[VAL_61]], %[[VAL_15]] : f32
-// CHECK:             %[[VAL_63:.*]] = arith.addf %[[VAL_60]], %[[VAL_62]] : f32
-// CHECK:             scf.yield %[[VAL_63]] : f32
+// CHECK:           %[[VAL_62:.*]] = scf.for %[[VAL_63:.*]] = %[[VAL_64:.*]]#0 to %[[VAL_18]] step %[[VAL_5]] iter_args(%[[VAL_65:.*]] = %[[VAL_64]]#2) -> (f32) {
+// CHECK:             %[[VAL_66:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_63]]] : memref<?xf32>
+// CHECK:             %[[VAL_67:.*]] = arith.mulf %[[VAL_66]], %[[VAL_16]] : f32
+// CHECK:             %[[VAL_68:.*]] = arith.addf %[[VAL_65]], %[[VAL_67]] : f32
+// CHECK:             scf.yield %[[VAL_68]] : f32
 // CHECK:           }
-// CHECK:           %[[VAL_64:.*]] = scf.for %[[VAL_65:.*]] = %[[VAL_66:.*]]#1 to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_67:.*]] = %[[VAL_68:.*]]) -> (f32) {
-// CHECK:             %[[VAL_69:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_65]]] : memref<?xf32>
-// CHECK:             %[[VAL_70:.*]] = arith.addf %[[VAL_67]], %[[VAL_69]] : f32
-// CHECK:             scf.yield %[[VAL_70]] : f32
+// CHECK:           %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_71:.*]]#1 to %[[VAL_20]] step %[[VAL_5]] iter_args(%[[VAL_72:.*]] = %[[VAL_73:.*]]) -> (f32) {
+// CHECK:             %[[VAL_74:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_70]]] : memref<?xf32>
+// CHECK:             %[[VAL_75:.*]] = arith.addf %[[VAL_72]], %[[VAL_74]] : f32
+// CHECK:             scf.yield %[[VAL_75]] : f32
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_71:.*]], %[[VAL_14]][] : memref<f32>
-// CHECK:           %[[VAL_72:.*]] = memref.tensor_load %[[VAL_14]] : memref<f32>
-// CHECK:           return %[[VAL_72]] : tensor<f32>
+// CHECK:           memref.store %[[VAL_76:.*]], %[[VAL_14]][] : memref<f32>
+// CHECK:           %[[VAL_77:.*]] = memref.tensor_load %[[VAL_14]] : memref<f32>
+// CHECK:           return %[[VAL_77]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction_inv(%arga: tensor<16xf32, #SV>,
                         %argb: tensor<f32>,
@@ -1289,12 +1289,12 @@ func @four_tensors_op(%arga: tensor<?xf64>,
 }
 
 // CHECK-LABEL:   func @red3s(
-// CHECK-SAME:                %[[VAL_0:.*0]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                %[[VAL_1:.*1]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                %[[VAL_2:.*2]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                %[[VAL_3:.*3]]: tensor<f64>) -> tensor<f64> {
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_3:.*3]]: tensor<f64>) -> tensor<f64> {
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
@@ -1307,277 +1307,275 @@ func @four_tensors_op(%arga: tensor<?xf64>,
 // CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_3]] : memref<f64>
 // CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<f64>
 // CHECK:           memref.copy %[[VAL_15]], %[[VAL_16]] : memref<f64> to memref<f64>
-// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_23:.*]]:3 = scf.while (%[[VAL_24:.*]] = %[[VAL_17]], %[[VAL_25:.*]] = %[[VAL_19]], %[[VAL_26:.*]] = %[[VAL_21]]) : (index, index, index) -> (index, index, index) {
-// CHECK:             %[[VAL_27:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_18]] : index
-// CHECK:             %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_20]] : index
-// CHECK:             %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
-// CHECK:             %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_22]] : index
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_16]][] : memref<f64>
+// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_24:.*]]:4 = scf.while (%[[VAL_25:.*]] = %[[VAL_18]], %[[VAL_26:.*]] = %[[VAL_20]], %[[VAL_27:.*]] = %[[VAL_22]], %[[VAL_28:.*]] = %[[VAL_17]]) : (index, index, index, f64) -> (index, index, index, f64) {
+// CHECK:             %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_21]] : index
 // CHECK:             %[[VAL_31:.*]] = arith.andi %[[VAL_29]], %[[VAL_30]] : i1
-// CHECK:             scf.condition(%[[VAL_31]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, index
+// CHECK:             %[[VAL_32:.*]] = arith.cmpi ult, %[[VAL_27]], %[[VAL_23]] : index
+// CHECK:             %[[VAL_33:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
+// CHECK:             scf.condition(%[[VAL_33]]) %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_28]] : index, index, index, f64
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_32:.*]]: index, %[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index):
-// CHECK:             %[[VAL_35:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_32]]] : memref<?xindex>
-// CHECK:             %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_33]]] : memref<?xindex>
-// CHECK:             %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_36]], %[[VAL_35]] : index
-// CHECK:             %[[VAL_38:.*]] = select %[[VAL_37]], %[[VAL_36]], %[[VAL_35]] : index
-// CHECK:             %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK:           ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index, %[[VAL_36:.*]]: index, %[[VAL_37:.*]]: f64):
+// CHECK:             %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK:             %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref<?xindex>
 // CHECK:             %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_39]], %[[VAL_38]] : index
 // CHECK:             %[[VAL_41:.*]] = select %[[VAL_40]], %[[VAL_39]], %[[VAL_38]] : index
-// CHECK:             %[[VAL_42:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
-// CHECK:             %[[VAL_43:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
-// CHECK:             %[[VAL_44:.*]] = arith.andi %[[VAL_42]], %[[VAL_43]] : i1
-// CHECK:             %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
-// CHECK:             %[[VAL_46:.*]] = arith.andi %[[VAL_44]], %[[VAL_45]] : i1
-// CHECK:             scf.if %[[VAL_46]] {
-// CHECK:               %[[VAL_47:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:               %[[VAL_48:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
-// CHECK:               %[[VAL_49:.*]] = arith.addf %[[VAL_47]], %[[VAL_48]] : f64
-// CHECK:               %[[VAL_50:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xf64>
-// CHECK:               %[[VAL_51:.*]] = arith.addf %[[VAL_49]], %[[VAL_50]] : f64
-// CHECK:               %[[VAL_52:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref<?xf64>
-// CHECK:               %[[VAL_53:.*]] = arith.addf %[[VAL_51]], %[[VAL_52]] : f64
-// CHECK:               memref.store %[[VAL_53]], %[[VAL_16]][] : memref<f64>
+// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
+// CHECK:             %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_42]], %[[VAL_41]] : index
+// CHECK:             %[[VAL_44:.*]] = select %[[VAL_43]], %[[VAL_42]], %[[VAL_41]] : index
+// CHECK:             %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index
+// CHECK:             %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_44]] : index
+// CHECK:             %[[VAL_47:.*]] = arith.andi %[[VAL_45]], %[[VAL_46]] : i1
+// CHECK:             %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_42]], %[[VAL_44]] : index
+// CHECK:             %[[VAL_49:.*]] = arith.andi %[[VAL_47]], %[[VAL_48]] : i1
+// CHECK:             %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (f64) {
+// CHECK:               %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_34]]] : memref<?xf64>
+// CHECK:               %[[VAL_52:.*]] = arith.addf %[[VAL_37]], %[[VAL_51]] : f64
+// CHECK:               %[[VAL_53:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_35]]] : memref<?xf64>
+// CHECK:               %[[VAL_54:.*]] = arith.addf %[[VAL_52]], %[[VAL_53]] : f64
+// CHECK:               %[[VAL_55:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_36]]] : memref<?xf64>
+// CHECK:               %[[VAL_56:.*]] = arith.addf %[[VAL_54]], %[[VAL_55]] : f64
+// CHECK:               scf.yield %[[VAL_56]] : f64
 // CHECK:             } else {
-// CHECK:               %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
-// CHECK:               %[[VAL_55:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
-// CHECK:               %[[VAL_56:.*]] = arith.andi %[[VAL_54]], %[[VAL_55]] : i1
-// CHECK:               scf.if %[[VAL_56]] {
-// CHECK:                 %[[VAL_57:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                 %[[VAL_58:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xf64>
-// CHECK:                 %[[VAL_59:.*]] = arith.addf %[[VAL_57]], %[[VAL_58]] : f64
-// CHECK:                 %[[VAL_60:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref<?xf64>
-// CHECK:                 %[[VAL_61:.*]] = arith.addf %[[VAL_59]], %[[VAL_60]] : f64
-// CHECK:                 memref.store %[[VAL_61]], %[[VAL_16]][] : memref<f64>
+// CHECK:               %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_44]] : index
+// CHECK:               %[[VAL_58:.*]] = arith.cmpi eq, %[[VAL_42]], %[[VAL_44]] : index
+// CHECK:               %[[VAL_59:.*]] = arith.andi %[[VAL_57]], %[[VAL_58]] : i1
+// CHECK:               %[[VAL_60:.*]] = scf.if %[[VAL_59]] -> (f64) {
+// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_35]]] : memref<?xf64>
+// CHECK:                 %[[VAL_62:.*]] = arith.addf %[[VAL_37]], %[[VAL_61]] : f64
+// CHECK:                 %[[VAL_63:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_36]]] : memref<?xf64>
+// CHECK:                 %[[VAL_64:.*]] = arith.addf %[[VAL_62]], %[[VAL_63]] : f64
+// CHECK:                 scf.yield %[[VAL_64]] : f64
 // CHECK:               } else {
-// CHECK:                 %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
-// CHECK:                 %[[VAL_63:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
-// CHECK:                 %[[VAL_64:.*]] = arith.andi %[[VAL_62]], %[[VAL_63]] : i1
-// CHECK:                 scf.if %[[VAL_64]] {
-// CHECK:                   %[[VAL_65:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                   %[[VAL_66:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
-// CHECK:                   %[[VAL_67:.*]] = arith.addf %[[VAL_65]], %[[VAL_66]] : f64
-// CHECK:                   %[[VAL_68:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref<?xf64>
-// CHECK:                   %[[VAL_69:.*]] = arith.addf %[[VAL_67]], %[[VAL_68]] : f64
-// CHECK:                   memref.store %[[VAL_69]], %[[VAL_16]][] : memref<f64>
+// CHECK:                 %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index
+// CHECK:                 %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_42]], %[[VAL_44]] : index
+// CHECK:                 %[[VAL_67:.*]] = arith.andi %[[VAL_65]], %[[VAL_66]] : i1
+// CHECK:                 %[[VAL_68:.*]] = scf.if %[[VAL_67]] -> (f64) {
+// CHECK:                   %[[VAL_69:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_34]]] : memref<?xf64>
+// CHECK:                   %[[VAL_70:.*]] = arith.addf %[[VAL_37]], %[[VAL_69]] : f64
+// CHECK:                   %[[VAL_71:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_36]]] : memref<?xf64>
+// CHECK:                   %[[VAL_72:.*]] = arith.addf %[[VAL_70]], %[[VAL_71]] : f64
+// CHECK:                   scf.yield %[[VAL_72]] : f64
 // CHECK:                 } else {
-// CHECK:                   %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
-// CHECK:                   scf.if %[[VAL_70]] {
-// CHECK:                     %[[VAL_71:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                     %[[VAL_72:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref<?xf64>
-// CHECK:                     %[[VAL_73:.*]] = arith.addf %[[VAL_71]], %[[VAL_72]] : f64
-// CHECK:                     memref.store %[[VAL_73]], %[[VAL_16]][] : memref<f64>
+// CHECK:                   %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_42]], %[[VAL_44]] : index
+// CHECK:                   %[[VAL_74:.*]] = scf.if %[[VAL_73]] -> (f64) {
+// CHECK:                     %[[VAL_75:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_36]]] : memref<?xf64>
+// CHECK:                     %[[VAL_76:.*]] = arith.addf %[[VAL_37]], %[[VAL_75]] : f64
+// CHECK:                     scf.yield %[[VAL_76]] : f64
 // CHECK:                   } else {
-// CHECK:                     %[[VAL_74:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
-// CHECK:                     %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
-// CHECK:                     %[[VAL_76:.*]] = arith.andi %[[VAL_74]], %[[VAL_75]] : i1
-// CHECK:                     scf.if %[[VAL_76]] {
-// CHECK:                       %[[VAL_77:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                       %[[VAL_78:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
-// CHECK:                       %[[VAL_79:.*]] = arith.addf %[[VAL_77]], %[[VAL_78]] : f64
-// CHECK:                       %[[VAL_80:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xf64>
-// CHECK:                       %[[VAL_81:.*]] = arith.addf %[[VAL_79]], %[[VAL_80]] : f64
-// CHECK:                       memref.store %[[VAL_81]], %[[VAL_16]][] : memref<f64>
+// CHECK:                     %[[VAL_77:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index
+// CHECK:                     %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_44]] : index
+// CHECK:                     %[[VAL_79:.*]] = arith.andi %[[VAL_77]], %[[VAL_78]] : i1
+// CHECK:                     %[[VAL_80:.*]] = scf.if %[[VAL_79]] -> (f64) {
+// CHECK:                       %[[VAL_81:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_34]]] : memref<?xf64>
+// CHECK:                       %[[VAL_82:.*]] = arith.addf %[[VAL_37]], %[[VAL_81]] : f64
+// CHECK:                       %[[VAL_83:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_35]]] : memref<?xf64>
+// CHECK:                       %[[VAL_84:.*]] = arith.addf %[[VAL_82]], %[[VAL_83]] : f64
+// CHECK:                       scf.yield %[[VAL_84]] : f64
 // CHECK:                     } else {
-// CHECK:                       %[[VAL_82:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
-// CHECK:                       scf.if %[[VAL_82]] {
-// CHECK:                         %[[VAL_83:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                         %[[VAL_84:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xf64>
-// CHECK:                         %[[VAL_85:.*]] = arith.addf %[[VAL_83]], %[[VAL_84]] : f64
-// CHECK:                         memref.store %[[VAL_85]], %[[VAL_16]][] : memref<f64>
+// CHECK:                       %[[VAL_85:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_44]] : index
+// CHECK:                       %[[VAL_86:.*]] = scf.if %[[VAL_85]] -> (f64) {
+// CHECK:                         %[[VAL_87:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_35]]] : memref<?xf64>
+// CHECK:                         %[[VAL_88:.*]] = arith.addf %[[VAL_37]], %[[VAL_87]] : f64
+// CHECK:                         scf.yield %[[VAL_88]] : f64
 // CHECK:                       } else {
-// CHECK:                         %[[VAL_86:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
-// CHECK:                         scf.if %[[VAL_86]] {
-// CHECK:                           %[[VAL_87:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                           %[[VAL_88:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
-// CHECK:                           %[[VAL_89:.*]] = arith.addf %[[VAL_87]], %[[VAL_88]] : f64
-// CHECK:                           memref.store %[[VAL_89]], %[[VAL_16]][] : memref<f64>
+// CHECK:                         %[[VAL_89:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index
+// CHECK:                         %[[VAL_90:.*]] = scf.if %[[VAL_89]] -> (f64) {
+// CHECK:                           %[[VAL_91:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_34]]] : memref<?xf64>
+// CHECK:                           %[[VAL_92:.*]] = arith.addf %[[VAL_37]], %[[VAL_91]] : f64
+// CHECK:                           scf.yield %[[VAL_92]] : f64
 // CHECK:                         } else {
+// CHECK:                           scf.yield %[[VAL_37]] : f64
 // CHECK:                         }
+// CHECK:                         scf.yield %[[VAL_93:.*]] : f64
 // CHECK:                       }
+// CHECK:                       scf.yield %[[VAL_94:.*]] : f64
 // CHECK:                     }
+// CHECK:                     scf.yield %[[VAL_95:.*]] : f64
 // CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_96:.*]] : f64
 // CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_97:.*]] : f64
 // CHECK:               }
+// CHECK:               scf.yield %[[VAL_98:.*]] : f64
 // CHECK:             }
-// CHECK:             %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
-// CHECK:             %[[VAL_91:.*]] = arith.addi %[[VAL_32]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_92:.*]] = select %[[VAL_90]], %[[VAL_91]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_93:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
-// CHECK:             %[[VAL_94:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_95:.*]] = select %[[VAL_93]], %[[VAL_94]], %[[VAL_33]] : index
-// CHECK:             %[[VAL_96:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
-// CHECK:             %[[VAL_97:.*]] = arith.addi %[[VAL_34]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_98:.*]] = select %[[VAL_96]], %[[VAL_97]], %[[VAL_34]] : index
-// CHECK:             scf.yield %[[VAL_92]], %[[VAL_95]], %[[VAL_98]] : index, index, index
+// CHECK:             %[[VAL_99:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index
+// CHECK:             %[[VAL_100:.*]] = arith.addi %[[VAL_34]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_101:.*]] = select %[[VAL_99]], %[[VAL_100]], %[[VAL_34]] : index
+// CHECK:             %[[VAL_102:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_44]] : index
+// CHECK:             %[[VAL_103:.*]] = arith.addi %[[VAL_35]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_104:.*]] = select %[[VAL_102]], %[[VAL_103]], %[[VAL_35]] : index
+// CHECK:             %[[VAL_105:.*]] = arith.cmpi eq, %[[VAL_42]], %[[VAL_44]] : index
+// CHECK:             %[[VAL_106:.*]] = arith.addi %[[VAL_36]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_107:.*]] = select %[[VAL_105]], %[[VAL_106]], %[[VAL_36]] : index
+// CHECK:             scf.yield %[[VAL_101]], %[[VAL_104]], %[[VAL_107]], %[[VAL_108:.*]] : index, index, index, f64
 // CHECK:           }
-// CHECK:           %[[VAL_99:.*]]:2 = scf.while (%[[VAL_100:.*]] = %[[VAL_101:.*]]#1, %[[VAL_102:.*]] = %[[VAL_101]]#2) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_100]], %[[VAL_20]] : index
-// CHECK:             %[[VAL_104:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_22]] : index
-// CHECK:             %[[VAL_105:.*]] = arith.andi %[[VAL_103]], %[[VAL_104]] : i1
-// CHECK:             scf.condition(%[[VAL_105]]) %[[VAL_100]], %[[VAL_102]] : index, index
+// CHECK:           %[[VAL_109:.*]]:3 = scf.while (%[[VAL_110:.*]] = %[[VAL_111:.*]]#1, %[[VAL_112:.*]] = %[[VAL_111]]#2, %[[VAL_113:.*]] = %[[VAL_111]]#3) : (index, index, f64) -> (index, index, f64) {
+// CHECK:             %[[VAL_114:.*]] = arith.cmpi ult, %[[VAL_110]], %[[VAL_21]] : index
+// CHECK:             %[[VAL_115:.*]] = arith.cmpi ult, %[[VAL_112]], %[[VAL_23]] : index
+// CHECK:             %[[VAL_116:.*]] = arith.andi %[[VAL_114]], %[[VAL_115]] : i1
+// CHECK:             scf.condition(%[[VAL_116]]) %[[VAL_110]], %[[VAL_112]], %[[VAL_113]] : index, index, f64
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_106:.*]]: index, %[[VAL_107:.*]]: index):
-// CHECK:             %[[VAL_108:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_106]]] : memref<?xindex>
-// CHECK:             %[[VAL_109:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_107]]] : memref<?xindex>
-// CHECK:             %[[VAL_110:.*]] = arith.cmpi ult, %[[VAL_109]], %[[VAL_108]] : index
-// CHECK:             %[[VAL_111:.*]] = select %[[VAL_110]], %[[VAL_109]], %[[VAL_108]] : index
-// CHECK:             %[[VAL_112:.*]] = arith.cmpi eq, %[[VAL_108]], %[[VAL_111]] : index
-// CHECK:             %[[VAL_113:.*]] = arith.cmpi eq, %[[VAL_109]], %[[VAL_111]] : index
-// CHECK:             %[[VAL_114:.*]] = arith.andi %[[VAL_112]], %[[VAL_113]] : i1
-// CHECK:             scf.if %[[VAL_114]] {
-// CHECK:               %[[VAL_115:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:               %[[VAL_116:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_106]]] : memref<?xf64>
-// CHECK:               %[[VAL_117:.*]] = arith.addf %[[VAL_115]], %[[VAL_116]] : f64
-// CHECK:               %[[VAL_118:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_107]]] : memref<?xf64>
-// CHECK:               %[[VAL_119:.*]] = arith.addf %[[VAL_117]], %[[VAL_118]] : f64
-// CHECK:               memref.store %[[VAL_119]], %[[VAL_16]][] : memref<f64>
+// CHECK:           ^bb0(%[[VAL_117:.*]]: index, %[[VAL_118:.*]]: index, %[[VAL_119:.*]]: f64):
+// CHECK:             %[[VAL_120:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_117]]] : memref<?xindex>
+// CHECK:             %[[VAL_121:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_118]]] : memref<?xindex>
+// CHECK:             %[[VAL_122:.*]] = arith.cmpi ult, %[[VAL_121]], %[[VAL_120]] : index
+// CHECK:             %[[VAL_123:.*]] = select %[[VAL_122]], %[[VAL_121]], %[[VAL_120]] : index
+// CHECK:             %[[VAL_124:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_123]] : index
+// CHECK:             %[[VAL_125:.*]] = arith.cmpi eq, %[[VAL_121]], %[[VAL_123]] : index
+// CHECK:             %[[VAL_126:.*]] = arith.andi %[[VAL_124]], %[[VAL_125]] : i1
+// CHECK:             %[[VAL_127:.*]] = scf.if %[[VAL_126]] -> (f64) {
+// CHECK:               %[[VAL_128:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_117]]] : memref<?xf64>
+// CHECK:               %[[VAL_129:.*]] = arith.addf %[[VAL_119]], %[[VAL_128]] : f64
+// CHECK:               %[[VAL_130:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_118]]] : memref<?xf64>
+// CHECK:               %[[VAL_131:.*]] = arith.addf %[[VAL_129]], %[[VAL_130]] : f64
+// CHECK:               scf.yield %[[VAL_131]] : f64
 // CHECK:             } else {
-// CHECK:               %[[VAL_120:.*]] = arith.cmpi eq, %[[VAL_109]], %[[VAL_111]] : index
-// CHECK:               scf.if %[[VAL_120]] {
-// CHECK:                 %[[VAL_121:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                 %[[VAL_122:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_107]]] : memref<?xf64>
-// CHECK:                 %[[VAL_123:.*]] = arith.addf %[[VAL_121]], %[[VAL_122]] : f64
-// CHECK:                 memref.store %[[VAL_123]], %[[VAL_16]][] : memref<f64>
+// CHECK:               %[[VAL_132:.*]] = arith.cmpi eq, %[[VAL_121]], %[[VAL_123]] : index
+// CHECK:               %[[VAL_133:.*]] = scf.if %[[VAL_132]] -> (f64) {
+// CHECK:                 %[[VAL_134:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_118]]] : memref<?xf64>
+// CHECK:                 %[[VAL_135:.*]] = arith.addf %[[VAL_119]], %[[VAL_134]] : f64
+// CHECK:                 scf.yield %[[VAL_135]] : f64
 // CHECK:               } else {
-// CHECK:                 %[[VAL_124:.*]] = arith.cmpi eq, %[[VAL_108]], %[[VAL_111]] : index
-// CHECK:                 scf.if %[[VAL_124]] {
-// CHECK:                   %[[VAL_125:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                   %[[VAL_126:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_106]]] : memref<?xf64>
-// CHECK:                   %[[VAL_127:.*]] = arith.addf %[[VAL_125]], %[[VAL_126]] : f64
-// CHECK:                   memref.store %[[VAL_127]], %[[VAL_16]][] : memref<f64>
+// CHECK:                 %[[VAL_136:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_123]] : index
+// CHECK:                 %[[VAL_137:.*]] = scf.if %[[VAL_136]] -> (f64) {
+// CHECK:                   %[[VAL_138:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_117]]] : memref<?xf64>
+// CHECK:                   %[[VAL_139:.*]] = arith.addf %[[VAL_119]], %[[VAL_138]] : f64
+// CHECK:                   scf.yield %[[VAL_139]] : f64
 // CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_119]] : f64
 // CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_140:.*]] : f64
 // CHECK:               }
+// CHECK:               scf.yield %[[VAL_141:.*]] : f64
 // CHECK:             }
-// CHECK:             %[[VAL_128:.*]] = arith.cmpi eq, %[[VAL_108]], %[[VAL_111]] : index
-// CHECK:             %[[VAL_129:.*]] = arith.addi %[[VAL_106]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_130:.*]] = select %[[VAL_128]], %[[VAL_129]], %[[VAL_106]] : index
-// CHECK:             %[[VAL_131:.*]] = arith.cmpi eq, %[[VAL_109]], %[[VAL_111]] : index
-// CHECK:             %[[VAL_132:.*]] = arith.addi %[[VAL_107]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_133:.*]] = select %[[VAL_131]], %[[VAL_132]], %[[VAL_107]] : index
-// CHECK:             scf.yield %[[VAL_130]], %[[VAL_133]] : index, index
+// CHECK:             %[[VAL_142:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_123]] : index
+// CHECK:             %[[VAL_143:.*]] = arith.addi %[[VAL_117]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_144:.*]] = select %[[VAL_142]], %[[VAL_143]], %[[VAL_117]] : index
+// CHECK:             %[[VAL_145:.*]] = arith.cmpi eq, %[[VAL_121]], %[[VAL_123]] : index
+// CHECK:             %[[VAL_146:.*]] = arith.addi %[[VAL_118]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_147:.*]] = select %[[VAL_145]], %[[VAL_146]], %[[VAL_118]] : index
+// CHECK:             scf.yield %[[VAL_144]], %[[VAL_147]], %[[VAL_148:.*]] : index, index, f64
 // CHECK:           }
-// CHECK:           %[[VAL_134:.*]]:2 = scf.while (%[[VAL_135:.*]] = %[[VAL_136:.*]]#0, %[[VAL_137:.*]] = %[[VAL_138:.*]]#1) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_139:.*]] = arith.cmpi ult, %[[VAL_135]], %[[VAL_18]] : index
-// CHECK:             %[[VAL_140:.*]] = arith.cmpi ult, %[[VAL_137]], %[[VAL_22]] : index
-// CHECK:             %[[VAL_141:.*]] = arith.andi %[[VAL_139]], %[[VAL_140]] : i1
-// CHECK:             scf.condition(%[[VAL_141]]) %[[VAL_135]], %[[VAL_137]] : index, index
+// CHECK:           %[[VAL_149:.*]]:3 = scf.while (%[[VAL_150:.*]] = %[[VAL_151:.*]]#0, %[[VAL_152:.*]] = %[[VAL_153:.*]]#1, %[[VAL_154:.*]] = %[[VAL_153]]#2) : (index, index, f64) -> (index, index, f64) {
+// CHECK:             %[[VAL_155:.*]] = arith.cmpi ult, %[[VAL_150]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_156:.*]] = arith.cmpi ult, %[[VAL_152]], %[[VAL_23]] : index
+// CHECK:             %[[VAL_157:.*]] = arith.andi %[[VAL_155]], %[[VAL_156]] : i1
+// CHECK:             scf.condition(%[[VAL_157]]) %[[VAL_150]], %[[VAL_152]], %[[VAL_154]] : index, index, f64
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_142:.*]]: index, %[[VAL_143:.*]]: index):
-// CHECK:             %[[VAL_144:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_142]]] : memref<?xindex>
-// CHECK:             %[[VAL_145:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_143]]] : memref<?xindex>
-// CHECK:             %[[VAL_146:.*]] = arith.cmpi ult, %[[VAL_145]], %[[VAL_144]] : index
-// CHECK:             %[[VAL_147:.*]] = select %[[VAL_146]], %[[VAL_145]], %[[VAL_144]] : index
-// CHECK:             %[[VAL_148:.*]] = arith.cmpi eq, %[[VAL_144]], %[[VAL_147]] : index
-// CHECK:             %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_145]], %[[VAL_147]] : index
-// CHECK:             %[[VAL_150:.*]] = arith.andi %[[VAL_148]], %[[VAL_149]] : i1
-// CHECK:             scf.if %[[VAL_150]] {
-// CHECK:               %[[VAL_151:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:               %[[VAL_152:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_142]]] : memref<?xf64>
-// CHECK:               %[[VAL_153:.*]] = arith.addf %[[VAL_151]], %[[VAL_152]] : f64
-// CHECK:               %[[VAL_154:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_143]]] : memref<?xf64>
-// CHECK:               %[[VAL_155:.*]] = arith.addf %[[VAL_153]], %[[VAL_154]] : f64
-// CHECK:               memref.store %[[VAL_155]], %[[VAL_16]][] : memref<f64>
+// CHECK:           ^bb0(%[[VAL_158:.*]]: index, %[[VAL_159:.*]]: index, %[[VAL_160:.*]]: f64):
+// CHECK:             %[[VAL_161:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_158]]] : memref<?xindex>
+// CHECK:             %[[VAL_162:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_159]]] : memref<?xindex>
+// CHECK:             %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_162]], %[[VAL_161]] : index
+// CHECK:             %[[VAL_164:.*]] = select %[[VAL_163]], %[[VAL_162]], %[[VAL_161]] : index
+// CHECK:             %[[VAL_165:.*]] = arith.cmpi eq, %[[VAL_161]], %[[VAL_164]] : index
+// CHECK:             %[[VAL_166:.*]] = arith.cmpi eq, %[[VAL_162]], %[[VAL_164]] : index
+// CHECK:             %[[VAL_167:.*]] = arith.andi %[[VAL_165]], %[[VAL_166]] : i1
+// CHECK:             %[[VAL_168:.*]] = scf.if %[[VAL_167]] -> (f64) {
+// CHECK:               %[[VAL_169:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_158]]] : memref<?xf64>
+// CHECK:               %[[VAL_170:.*]] = arith.addf %[[VAL_160]], %[[VAL_169]] : f64
+// CHECK:               %[[VAL_171:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_159]]] : memref<?xf64>
+// CHECK:               %[[VAL_172:.*]] = arith.addf %[[VAL_170]], %[[VAL_171]] : f64
+// CHECK:               scf.yield %[[VAL_172]] : f64
 // CHECK:             } else {
-// CHECK:               %[[VAL_156:.*]] = arith.cmpi eq, %[[VAL_145]], %[[VAL_147]] : index
-// CHECK:               scf.if %[[VAL_156]] {
-// CHECK:                 %[[VAL_157:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                 %[[VAL_158:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_143]]] : memref<?xf64>
-// CHECK:                 %[[VAL_159:.*]] = arith.addf %[[VAL_157]], %[[VAL_158]] : f64
-// CHECK:                 memref.store %[[VAL_159]], %[[VAL_16]][] : memref<f64>
+// CHECK:               %[[VAL_173:.*]] = arith.cmpi eq, %[[VAL_162]], %[[VAL_164]] : index
+// CHECK:               %[[VAL_174:.*]] = scf.if %[[VAL_173]] -> (f64) {
+// CHECK:                 %[[VAL_175:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_159]]] : memref<?xf64>
+// CHECK:                 %[[VAL_176:.*]] = arith.addf %[[VAL_160]], %[[VAL_175]] : f64
+// CHECK:                 scf.yield %[[VAL_176]] : f64
 // CHECK:               } else {
-// CHECK:                 %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_144]], %[[VAL_147]] : index
-// CHECK:                 scf.if %[[VAL_160]] {
-// CHECK:                   %[[VAL_161:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                   %[[VAL_162:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_142]]] : memref<?xf64>
-// CHECK:                   %[[VAL_163:.*]] = arith.addf %[[VAL_161]], %[[VAL_162]] : f64
-// CHECK:                   memref.store %[[VAL_163]], %[[VAL_16]][] : memref<f64>
+// CHECK:                 %[[VAL_177:.*]] = arith.cmpi eq, %[[VAL_161]], %[[VAL_164]] : index
+// CHECK:                 %[[VAL_178:.*]] = scf.if %[[VAL_177]] -> (f64) {
+// CHECK:                   %[[VAL_179:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_158]]] : memref<?xf64>
+// CHECK:                   %[[VAL_180:.*]] = arith.addf %[[VAL_160]], %[[VAL_179]] : f64
+// CHECK:                   scf.yield %[[VAL_180]] : f64
 // CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_160]] : f64
 // CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_181:.*]] : f64
 // CHECK:               }
+// CHECK:               scf.yield %[[VAL_182:.*]] : f64
 // CHECK:             }
-// CHECK:             %[[VAL_164:.*]] = arith.cmpi eq, %[[VAL_144]], %[[VAL_147]] : index
-// CHECK:             %[[VAL_165:.*]] = arith.addi %[[VAL_142]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_166:.*]] = select %[[VAL_164]], %[[VAL_165]], %[[VAL_142]] : index
-// CHECK:             %[[VAL_167:.*]] = arith.cmpi eq, %[[VAL_145]], %[[VAL_147]] : index
-// CHECK:             %[[VAL_168:.*]] = arith.addi %[[VAL_143]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_169:.*]] = select %[[VAL_167]], %[[VAL_168]], %[[VAL_143]] : index
-// CHECK:             scf.yield %[[VAL_166]], %[[VAL_169]] : index, index
+// CHECK:             %[[VAL_183:.*]] = arith.cmpi eq, %[[VAL_161]], %[[VAL_164]] : index
+// CHECK:             %[[VAL_184:.*]] = arith.addi %[[VAL_158]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_185:.*]] = select %[[VAL_183]], %[[VAL_184]], %[[VAL_158]] : index
+// CHECK:             %[[VAL_186:.*]] = arith.cmpi eq, %[[VAL_162]], %[[VAL_164]] : index
+// CHECK:             %[[VAL_187:.*]] = arith.addi %[[VAL_159]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_188:.*]] = select %[[VAL_186]], %[[VAL_187]], %[[VAL_159]] : index
+// CHECK:             scf.yield %[[VAL_185]], %[[VAL_188]], %[[VAL_189:.*]] : index, index, f64
 // CHECK:           }
-// CHECK:           %[[VAL_170:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:           %[[VAL_171:.*]] = scf.for %[[VAL_172:.*]] = %[[VAL_173:.*]]#1 to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_174:.*]] = %[[VAL_170]]) -> (f64) {
-// CHECK:             %[[VAL_175:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_172]]] : memref<?xf64>
-// CHECK:             %[[VAL_176:.*]] = arith.addf %[[VAL_174]], %[[VAL_175]] : f64
-// CHECK:             scf.yield %[[VAL_176]] : f64
+// CHECK:           %[[VAL_190:.*]] = scf.for %[[VAL_191:.*]] = %[[VAL_192:.*]]#1 to %[[VAL_23]] step %[[VAL_5]] iter_args(%[[VAL_193:.*]] = %[[VAL_192]]#2) -> (f64) {
+// CHECK:             %[[VAL_194:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_191]]] : memref<?xf64>
+// CHECK:             %[[VAL_195:.*]] = arith.addf %[[VAL_193]], %[[VAL_194]] : f64
+// CHECK:             scf.yield %[[VAL_195]] : f64
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_177:.*]], %[[VAL_16]][] : memref<f64>
-// CHECK:           %[[VAL_178:.*]]:2 = scf.while (%[[VAL_179:.*]] = %[[VAL_180:.*]]#0, %[[VAL_181:.*]] = %[[VAL_182:.*]]#0) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_183:.*]] = arith.cmpi ult, %[[VAL_179]], %[[VAL_18]] : index
-// CHECK:             %[[VAL_184:.*]] = arith.cmpi ult, %[[VAL_181]], %[[VAL_20]] : index
-// CHECK:             %[[VAL_185:.*]] = arith.andi %[[VAL_183]], %[[VAL_184]] : i1
-// CHECK:             scf.condition(%[[VAL_185]]) %[[VAL_179]], %[[VAL_181]] : index, index
+// CHECK:           %[[VAL_196:.*]]:3 = scf.while (%[[VAL_197:.*]] = %[[VAL_198:.*]]#0, %[[VAL_199:.*]] = %[[VAL_200:.*]]#0, %[[VAL_201:.*]] = %[[VAL_202:.*]]) : (index, index, f64) -> (index, index, f64) {
+// CHECK:             %[[VAL_203:.*]] = arith.cmpi ult, %[[VAL_197]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_204:.*]] = arith.cmpi ult, %[[VAL_199]], %[[VAL_21]] : index
+// CHECK:             %[[VAL_205:.*]] = arith.andi %[[VAL_203]], %[[VAL_204]] : i1
+// CHECK:             scf.condition(%[[VAL_205]]) %[[VAL_197]], %[[VAL_199]], %[[VAL_201]] : index, index, f64
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_186:.*]]: index, %[[VAL_187:.*]]: index):
-// CHECK:             %[[VAL_188:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_186]]] : memref<?xindex>
-// CHECK:             %[[VAL_189:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_187]]] : memref<?xindex>
-// CHECK:             %[[VAL_190:.*]] = arith.cmpi ult, %[[VAL_189]], %[[VAL_188]] : index
-// CHECK:             %[[VAL_191:.*]] = select %[[VAL_190]], %[[VAL_189]], %[[VAL_188]] : index
-// CHECK:             %[[VAL_192:.*]] = arith.cmpi eq, %[[VAL_188]], %[[VAL_191]] : index
-// CHECK:             %[[VAL_193:.*]] = arith.cmpi eq, %[[VAL_189]], %[[VAL_191]] : index
-// CHECK:             %[[VAL_194:.*]] = arith.andi %[[VAL_192]], %[[VAL_193]] : i1
-// CHECK:             scf.if %[[VAL_194]] {
-// CHECK:               %[[VAL_195:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:               %[[VAL_196:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_186]]] : memref<?xf64>
-// CHECK:               %[[VAL_197:.*]] = arith.addf %[[VAL_195]], %[[VAL_196]] : f64
-// CHECK:               %[[VAL_198:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_187]]] : memref<?xf64>
-// CHECK:               %[[VAL_199:.*]] = arith.addf %[[VAL_197]], %[[VAL_198]] : f64
-// CHECK:               memref.store %[[VAL_199]], %[[VAL_16]][] : memref<f64>
+// CHECK:           ^bb0(%[[VAL_206:.*]]: index, %[[VAL_207:.*]]: index, %[[VAL_208:.*]]: f64):
+// CHECK:             %[[VAL_209:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_206]]] : memref<?xindex>
+// CHECK:             %[[VAL_210:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_207]]] : memref<?xindex>
+// CHECK:             %[[VAL_211:.*]] = arith.cmpi ult, %[[VAL_210]], %[[VAL_209]] : index
+// CHECK:             %[[VAL_212:.*]] = select %[[VAL_211]], %[[VAL_210]], %[[VAL_209]] : index
+// CHECK:             %[[VAL_213:.*]] = arith.cmpi eq, %[[VAL_209]], %[[VAL_212]] : index
+// CHECK:             %[[VAL_214:.*]] = arith.cmpi eq, %[[VAL_210]], %[[VAL_212]] : index
+// CHECK:             %[[VAL_215:.*]] = arith.andi %[[VAL_213]], %[[VAL_214]] : i1
+// CHECK:             %[[VAL_216:.*]] = scf.if %[[VAL_215]] -> (f64) {
+// CHECK:               %[[VAL_217:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_206]]] : memref<?xf64>
+// CHECK:               %[[VAL_218:.*]] = arith.addf %[[VAL_208]], %[[VAL_217]] : f64
+// CHECK:               %[[VAL_219:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_207]]] : memref<?xf64>
+// CHECK:               %[[VAL_220:.*]] = arith.addf %[[VAL_218]], %[[VAL_219]] : f64
+// CHECK:               scf.yield %[[VAL_220]] : f64
 // CHECK:             } else {
-// CHECK:               %[[VAL_200:.*]] = arith.cmpi eq, %[[VAL_189]], %[[VAL_191]] : index
-// CHECK:               scf.if %[[VAL_200]] {
-// CHECK:                 %[[VAL_201:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                 %[[VAL_202:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_187]]] : memref<?xf64>
-// CHECK:                 %[[VAL_203:.*]] = arith.addf %[[VAL_201]], %[[VAL_202]] : f64
-// CHECK:                 memref.store %[[VAL_203]], %[[VAL_16]][] : memref<f64>
+// CHECK:               %[[VAL_221:.*]] = arith.cmpi eq, %[[VAL_210]], %[[VAL_212]] : index
+// CHECK:               %[[VAL_222:.*]] = scf.if %[[VAL_221]] -> (f64) {
+// CHECK:                 %[[VAL_223:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_207]]] : memref<?xf64>
+// CHECK:                 %[[VAL_224:.*]] = arith.addf %[[VAL_208]], %[[VAL_223]] : f64
+// CHECK:                 scf.yield %[[VAL_224]] : f64
 // CHECK:               } else {
-// CHECK:                 %[[VAL_204:.*]] = arith.cmpi eq, %[[VAL_188]], %[[VAL_191]] : index
-// CHECK:                 scf.if %[[VAL_204]] {
-// CHECK:                   %[[VAL_205:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:                   %[[VAL_206:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_186]]] : memref<?xf64>
-// CHECK:                   %[[VAL_207:.*]] = arith.addf %[[VAL_205]], %[[VAL_206]] : f64
-// CHECK:                   memref.store %[[VAL_207]], %[[VAL_16]][] : memref<f64>
+// CHECK:                 %[[VAL_225:.*]] = arith.cmpi eq, %[[VAL_209]], %[[VAL_212]] : index
+// CHECK:                 %[[VAL_226:.*]] = scf.if %[[VAL_225]] -> (f64) {
+// CHECK:                   %[[VAL_227:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_206]]] : memref<?xf64>
+// CHECK:                   %[[VAL_228:.*]] = arith.addf %[[VAL_208]], %[[VAL_227]] : f64
+// CHECK:                   scf.yield %[[VAL_228]] : f64
 // CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_208]] : f64
 // CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_229:.*]] : f64
 // CHECK:               }
+// CHECK:               scf.yield %[[VAL_230:.*]] : f64
 // CHECK:             }
-// CHECK:             %[[VAL_208:.*]] = arith.cmpi eq, %[[VAL_188]], %[[VAL_191]] : index
-// CHECK:             %[[VAL_209:.*]] = arith.addi %[[VAL_186]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_210:.*]] = select %[[VAL_208]], %[[VAL_209]], %[[VAL_186]] : index
-// CHECK:             %[[VAL_211:.*]] = arith.cmpi eq, %[[VAL_189]], %[[VAL_191]] : index
-// CHECK:             %[[VAL_212:.*]] = arith.addi %[[VAL_187]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_213:.*]] = select %[[VAL_211]], %[[VAL_212]], %[[VAL_187]] : index
-// CHECK:             scf.yield %[[VAL_210]], %[[VAL_213]] : index, index
+// CHECK:             %[[VAL_231:.*]] = arith.cmpi eq, %[[VAL_209]], %[[VAL_212]] : index
+// CHECK:             %[[VAL_232:.*]] = arith.addi %[[VAL_206]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_233:.*]] = select %[[VAL_231]], %[[VAL_232]], %[[VAL_206]] : index
+// CHECK:             %[[VAL_234:.*]] = arith.cmpi eq, %[[VAL_210]], %[[VAL_212]] : index
+// CHECK:             %[[VAL_235:.*]] = arith.addi %[[VAL_207]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_236:.*]] = select %[[VAL_234]], %[[VAL_235]], %[[VAL_207]] : index
+// CHECK:             scf.yield %[[VAL_233]], %[[VAL_236]], %[[VAL_237:.*]] : index, index, f64
 // CHECK:           }
-// CHECK:           %[[VAL_214:.*]] = memref.load %[[VAL_16]][] : memref<f64>
-// CHECK:           %[[VAL_215:.*]] = scf.for %[[VAL_216:.*]] = %[[VAL_217:.*]]#1 to %[[VAL_20]] step %[[VAL_5]] iter_args(%[[VAL_218:.*]] = %[[VAL_214]]) -> (f64) {
-// CHECK:             %[[VAL_219:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_216]]] : memref<?xf64>
-// CHECK:             %[[VAL_220:.*]] = arith.addf %[[VAL_218]], %[[VAL_219]] : f64
-// CHECK:             scf.yield %[[VAL_220]] : f64
+// CHECK:           %[[VAL_238:.*]] = scf.for %[[VAL_239:.*]] = %[[VAL_240:.*]]#1 to %[[VAL_21]] step %[[VAL_5]] iter_args(%[[VAL_241:.*]] = %[[VAL_240]]#2) -> (f64) {
+// CHECK:             %[[VAL_242:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_239]]] : memref<?xf64>
+// CHECK:             %[[VAL_243:.*]] = arith.addf %[[VAL_241]], %[[VAL_242]] : f64
+// CHECK:             scf.yield %[[VAL_243]] : f64
 // CHECK:           }
-// CHECK:           %[[VAL_221:.*]] = scf.for %[[VAL_222:.*]] = %[[VAL_223:.*]]#0 to %[[VAL_18]] step %[[VAL_5]] iter_args(%[[VAL_224:.*]] = %[[VAL_225:.*]]) -> (f64) {
-// CHECK:             %[[VAL_226:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_222]]] : memref<?xf64>
-// CHECK:             %[[VAL_227:.*]] = arith.addf %[[VAL_224]], %[[VAL_226]] : f64
-// CHECK:             scf.yield %[[VAL_227]] : f64
+// CHECK:           %[[VAL_244:.*]] = scf.for %[[VAL_245:.*]] = %[[VAL_246:.*]]#0 to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_247:.*]] = %[[VAL_248:.*]]) -> (f64) {
+// CHECK:             %[[VAL_249:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_245]]] : memref<?xf64>
+// CHECK:             %[[VAL_250:.*]] = arith.addf %[[VAL_247]], %[[VAL_249]] : f64
+// CHECK:             scf.yield %[[VAL_250]] : f64
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_228:.*]], %[[VAL_16]][] : memref<f64>
-// CHECK:           %[[VAL_229:.*]] = memref.tensor_load %[[VAL_16]] : memref<f64>
-// CHECK:           return %[[VAL_229]] : tensor<f64>
+// CHECK:           memref.store %[[VAL_251:.*]], %[[VAL_16]][] : memref<f64>
+// CHECK:           %[[VAL_252:.*]] = memref.tensor_load %[[VAL_16]] : memref<f64>
+// CHECK:           return %[[VAL_252]] : tensor<f64>
 // CHECK:         }
 func @red3s(%arga: tensor<?xf64, #SV>,
             %argb: tensor<?xf64, #SV>,

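The red3s checks above exercise the generalized scalarization for while-loops: the running reduction value now rides along as an extra iter argument of the scf.while (forwarded unchanged by scf.condition), and the scf.if forest yields the updated value, so the 0-d accumulator memref is loaded once before the loop nest and stored once after it. A minimal sketch of that shape, simplified to a single sum over one sparse operand and using illustrative names (%acc, %vals, %lo, %hi are assumptions, not taken from the patch):

%c1 = arith.constant 1 : index
// load the accumulator once, before entering the loop nest
%init = memref.load %acc[] : memref<f64>
%res:2 = scf.while (%i = %lo, %r = %init) : (index, f64) -> (index, f64) {
  %cond = arith.cmpi ult, %i, %hi : index
  // the reduction value is simply forwarded by the condition
  scf.condition(%cond) %i, %r : index, f64
} do {
^bb0(%i: index, %r: f64):
  %v = memref.load %vals[%i] : memref<?xf64>
  %r1 = arith.addf %r, %v : f64
  %i1 = arith.addi %i, %c1 : index
  // chain the partial sum into the next iteration
  scf.yield %i1, %r1 : index, f64
}
// store the final value once, after the loop nest
memref.store %res#1, %acc[] : memref<f64>

Before this change, the loop body loaded from and stored to the 0-d memref on every iteration, as the removed CHECK lines above still show.
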
diff  --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 1de13929b36b6..954ad0622663e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -870,12 +870,12 @@ func @mul_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #Tds>,
 }
 
 // CHECK-LABEL:   func @matvec(
-// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<16xf32> {
-// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 16 : index
-// CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32xf32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<16xf32> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
@@ -884,10 +884,10 @@ func @mul_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #Tds>,
 // CHECK:           %[[VAL_11:.*]] = memref.alloc() : memref<16xf32>
 // CHECK:           memref.copy %[[VAL_10]], %[[VAL_11]] : memref<16xf32> to memref<16xf32>
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
+// CHECK-DAG:         %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK-DAG:         %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
+// CHECK-DAG:         %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK-DAG:         %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
 // CHECK:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f32) {
 // CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
 // CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
@@ -896,7 +896,7 @@ func @mul_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #Tds>,
 // CHECK:               %[[VAL_24:.*]] = arith.addf %[[VAL_23]], %[[VAL_19]] : f32
 // CHECK:               scf.yield %[[VAL_24]] : f32
 // CHECK:             }
-// CHECK:             memref.store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
+// CHECK:             memref.store %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
 // CHECK:           }
 // CHECK:           %[[VAL_26:.*]] = memref.tensor_load %[[VAL_11]] : memref<16xf32>
 // CHECK:           return %[[VAL_26]] : tensor<16xf32>
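
In the for-loop cases here (@matvec above, @sum_reduction next), the same scalarization shows up as an iter_args-carried value: the fixed check now stores the named scf.for result %[[VAL_17]] instead of a fresh wildcard, and @sum_reduction hoists the scalar load/store out of the outer loop entirely. A minimal sketch of the for-loop form, again with illustrative names (%acc, %vals, %j, %lo, %hi are assumptions, not from the patch):

%c1 = arith.constant 1 : index
%init = memref.load %acc[%j] : memref<16xf32>
%red = scf.for %i = %lo to %hi step %c1 iter_args(%r = %init) -> (f32) {
  %v = memref.load %vals[%i] : memref<?xf32>
  %r1 = arith.addf %r, %v : f32
  // chain the partial sum to the next iteration
  scf.yield %r1 : f32
}
// a single store of the scf.for result replaces the per-iteration stores
memref.store %red, %acc[%j] : memref<16xf32>
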
@@ -923,30 +923,31 @@ func @matvec(%argA: tensor<16x32xf32, #Tds>, %argb: tensor<32xf32>, %argx: tenso
 }
 
 // CHECK-LABEL:   func @sum_reduction(
-// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 10 : index
-// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 10 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
 // CHECK:           %[[VAL_8:.*]] = memref.alloc() : memref<f32>
 // CHECK:           memref.copy %[[VAL_7]], %[[VAL_8]] : memref<f32> to memref<f32>
-// CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] {
-// CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK:             %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_8]][] : memref<f32>
-// CHECK:             %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_4]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
-// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xf32>
-// CHECK:               %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : f32
-// CHECK:               scf.yield %[[VAL_18]] : f32
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
+// CHECK:           %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_3]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (f32) {
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK:             %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_3]] iter_args(%[[VAL_18:.*]] = %[[VAL_12]]) -> (f32) {
+// CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_17]]] : memref<?xf32>
+// CHECK:               %[[VAL_20:.*]] = arith.addf %[[VAL_18]], %[[VAL_19]] : f32
+// CHECK:               scf.yield %[[VAL_20]] : f32
 // CHECK:             }
-// CHECK:             memref.store %[[VAL_19:.*]], %[[VAL_8]][] : memref<f32>
+// CHECK:             scf.yield %[[VAL_16]] : f32
 // CHECK:           }
-// CHECK:           %[[VAL_20:.*]] = memref.tensor_load %[[VAL_8]] : memref<f32>
-// CHECK:           return %[[VAL_20]] : tensor<f32>
+// CHECK:           memref.store %[[VAL_10]], %[[VAL_8]][] : memref<f32>
+// CHECK:           %[[VAL_23:.*]] = memref.tensor_load %[[VAL_8]] : memref<f32>
+// CHECK:           return %[[VAL_23]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction(%arga: tensor<10x20xf32, #Tds>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction
@@ -1020,10 +1021,10 @@ func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor<?x?x
 }
 
 // CHECK-LABEL:   func @sampled_dense_dense(
-// CHECK-SAME:                              %[[VAL_0:.*0]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                              %[[VAL_1:.*1]]: tensor<?x?xf32>,
-// CHECK-SAME:                              %[[VAL_2:.*2]]: tensor<?x?xf32>,
-// CHECK-SAME:                              %[[VAL_3:.*3]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[VAL_3:.*3]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
@@ -1047,9 +1048,9 @@ func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor<?x?x
 // CHECK:             %[[VAL_23:.*]] = arith.addi %[[VAL_20]], %[[VAL_5]] : index
 // CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
 // CHECK:             scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
-// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
-// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK-DAG:           %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK-DAG:           %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK-DAG:           %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
 // CHECK:               %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
 // CHECK:                 %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<?x?xf32>
 // CHECK:                 %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref<?x?xf32>
@@ -1058,7 +1059,7 @@ func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor<?x?x
 // CHECK:                 %[[VAL_36:.*]] = arith.addf %[[VAL_31]], %[[VAL_35]] : f32
 // CHECK:                 scf.yield %[[VAL_36]] : f32
 // CHECK:               }
-// CHECK:               memref.store %[[VAL_37:.*]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK:               memref.store %[[VAL_29]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_38:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
@@ -1094,25 +1095,25 @@ func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
 }
 
 // CHECK-LABEL:   func @sum_kernel_with_inv(
-// CHECK-SAME:                              %[[VAL_0:.*0]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                              %[[VAL_1:.*1]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                              %[[VAL_2:.*2]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                              %[[VAL_3:.*3]]: tensor<?xf32>,
-// CHECK-SAME:                              %[[VAL_4:.*4]]: tensor<f32>,
-// CHECK-SAME:                              %[[VAL_5:.*5]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK-DAG:           %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG:           %[[VAL_7:.*]] = arith.constant true
-// CHECK-DAG:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_3:.*3]]: tensor<?xf32>,
+// CHECK-SAME:      %[[VAL_4:.*4]]: tensor<f32>,
+// CHECK-SAME:      %[[VAL_5:.*5]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_7]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_7]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.pointers %[[VAL_2]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.pointers %[[VAL_2]], %[[VAL_7]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_7]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_20:.*]] = memref.buffer_cast %[[VAL_3]] : memref<?xf32>
 // CHECK:           %[[VAL_21:.*]] = memref.buffer_cast %[[VAL_4]] : memref<f32>
@@ -1122,7 +1123,7 @@ func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
 // CHECK:           memref.copy %[[VAL_23]], %[[VAL_24]] : memref<?xf32> to memref<?xf32>
 // CHECK:           %[[VAL_25:.*]] = memref.load %[[VAL_21]][] : memref<f32>
 // CHECK:           %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK:           %[[VAL_27:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK:           %[[VAL_27:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_28:.*]]:2 = scf.while (%[[VAL_29:.*]] = %[[VAL_26]], %[[VAL_30:.*]] = %[[VAL_6]]) : (index, index) -> (index, index) {
 // CHECK:             %[[VAL_31:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_27]] : index
 // CHECK:             scf.condition(%[[VAL_31]]) %[[VAL_29]], %[[VAL_30]] : index, index
@@ -1131,158 +1132,158 @@ func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
 // CHECK:             %[[VAL_34:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_32]]] : memref<?xindex>
 // CHECK:             %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_33]] : index
 // CHECK:             scf.if %[[VAL_35]] {
-// CHECK:               %[[VAL_36:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_33]]] : memref<?xf32>
-// CHECK:               %[[VAL_37:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<?xindex>
-// CHECK:               %[[VAL_38:.*]] = arith.addi %[[VAL_32]], %[[VAL_8]] : index
-// CHECK:               %[[VAL_39:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK:               %[[VAL_40:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_33]]] : memref<?xindex>
-// CHECK:               %[[VAL_41:.*]] = arith.addi %[[VAL_33]], %[[VAL_8]] : index
-// CHECK:               %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK:               %[[VAL_43:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_33]]] : memref<?xindex>
-// CHECK:               %[[VAL_44:.*]] = arith.addi %[[VAL_33]], %[[VAL_8]] : index
-// CHECK:               %[[VAL_45:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_44]]] : memref<?xindex>
-// CHECK:               %[[VAL_46:.*]]:3 = scf.while (%[[VAL_47:.*]] = %[[VAL_37]], %[[VAL_48:.*]] = %[[VAL_40]], %[[VAL_49:.*]] = %[[VAL_43]]) : (index, index, index) -> (index, index, index) {
-// CHECK:                 %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_39]] : index
-// CHECK:                 %[[VAL_51:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_42]] : index
-// CHECK:                 %[[VAL_52:.*]] = arith.andi %[[VAL_50]], %[[VAL_51]] : i1
-// CHECK:                 %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_49]], %[[VAL_45]] : index
+// CHECK:               %[[VAL_36:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:               %[[VAL_37:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:               %[[VAL_38:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<?xindex>
+// CHECK:               %[[VAL_39:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_40:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_39]]] : memref<?xindex>
+// CHECK:               %[[VAL_41:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_33]]] : memref<?xindex>
+// CHECK:               %[[VAL_42:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_43:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_42]]] : memref<?xindex>
+// CHECK:               %[[VAL_44:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_33]]] : memref<?xindex>
+// CHECK:               %[[VAL_45:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_46:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_45]]] : memref<?xindex>
+// CHECK:               %[[VAL_47:.*]]:4 = scf.while (%[[VAL_48:.*]] = %[[VAL_38]], %[[VAL_49:.*]] = %[[VAL_41]], %[[VAL_50:.*]] = %[[VAL_44]], %[[VAL_51:.*]] = %[[VAL_36]]) : (index, index, index, f32) -> (index, index, index, f32) {
+// CHECK:                 %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_40]] : index
+// CHECK:                 %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_49]], %[[VAL_43]] : index
 // CHECK:                 %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1
-// CHECK:                 scf.condition(%[[VAL_54]]) %[[VAL_47]], %[[VAL_48]], %[[VAL_49]] : index, index, index
+// CHECK:                 %[[VAL_55:.*]] = arith.cmpi ult, %[[VAL_50]], %[[VAL_46]] : index
+// CHECK:                 %[[VAL_56:.*]] = arith.andi %[[VAL_54]], %[[VAL_55]] : i1
+// CHECK:                 scf.condition(%[[VAL_56]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]] : index, index, index, f32
 // CHECK:               } do {
-// CHECK:               ^bb0(%[[VAL_55:.*]]: index, %[[VAL_56:.*]]: index, %[[VAL_57:.*]]: index):
-// CHECK:                 %[[VAL_58:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK:                 %[[VAL_59:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_56]]] : memref<?xindex>
-// CHECK:                 %[[VAL_60:.*]] = arith.cmpi ult, %[[VAL_59]], %[[VAL_58]] : index
-// CHECK:                 %[[VAL_61:.*]] = select %[[VAL_60]], %[[VAL_59]], %[[VAL_58]] : index
-// CHECK:                 %[[VAL_62:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_57]]] : memref<?xindex>
+// CHECK:               ^bb0(%[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index, %[[VAL_59:.*]]: index, %[[VAL_60:.*]]: f32):
+// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref<?xindex>
+// CHECK:                 %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
 // CHECK:                 %[[VAL_63:.*]] = arith.cmpi ult, %[[VAL_62]], %[[VAL_61]] : index
 // CHECK:                 %[[VAL_64:.*]] = select %[[VAL_63]], %[[VAL_62]], %[[VAL_61]] : index
-// CHECK:                 %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_64]] : index
-// CHECK:                 %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_59]], %[[VAL_64]] : index
-// CHECK:                 %[[VAL_67:.*]] = arith.andi %[[VAL_65]], %[[VAL_66]] : i1
-// CHECK:                 %[[VAL_68:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_64]] : index
-// CHECK:                 %[[VAL_69:.*]] = arith.andi %[[VAL_67]], %[[VAL_68]] : i1
-// CHECK:                 scf.if %[[VAL_69]] {
-// CHECK:                   %[[VAL_70:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
-// CHECK:                   %[[VAL_71:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_55]]] : memref<?xf32>
-// CHECK:                   %[[VAL_72:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_56]]] : memref<?xf32>
-// CHECK:                   %[[VAL_73:.*]] = arith.mulf %[[VAL_71]], %[[VAL_72]] : f32
-// CHECK:                   %[[VAL_74:.*]] = arith.mulf %[[VAL_73]], %[[VAL_36]] : f32
-// CHECK:                   %[[VAL_75:.*]] = arith.mulf %[[VAL_74]], %[[VAL_25]] : f32
-// CHECK:                   %[[VAL_76:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_57]]] : memref<?xf32>
-// CHECK:                   %[[VAL_77:.*]] = arith.addf %[[VAL_75]], %[[VAL_76]] : f32
-// CHECK:                   %[[VAL_78:.*]] = arith.addf %[[VAL_70]], %[[VAL_77]] : f32
-// CHECK:                   memref.store %[[VAL_78]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
+// CHECK:                 %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_64]] : index
+// CHECK:                 %[[VAL_67:.*]] = select %[[VAL_66]], %[[VAL_65]], %[[VAL_64]] : index
+// CHECK:                 %[[VAL_68:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index
+// CHECK:                 %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_67]] : index
+// CHECK:                 %[[VAL_70:.*]] = arith.andi %[[VAL_68]], %[[VAL_69]] : i1
+// CHECK:                 %[[VAL_71:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_67]] : index
+// CHECK:                 %[[VAL_72:.*]] = arith.andi %[[VAL_70]], %[[VAL_71]] : i1
+// CHECK:                 %[[VAL_73:.*]] = scf.if %[[VAL_72]] -> (f32) {
+// CHECK:                   %[[VAL_74:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_57]]] : memref<?xf32>
+// CHECK:                   %[[VAL_75:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_58]]] : memref<?xf32>
+// CHECK:                   %[[VAL_76:.*]] = arith.mulf %[[VAL_74]], %[[VAL_75]] : f32
+// CHECK:                   %[[VAL_77:.*]] = arith.mulf %[[VAL_76]], %[[VAL_37]] : f32
+// CHECK:                   %[[VAL_78:.*]] = arith.mulf %[[VAL_77]], %[[VAL_25]] : f32
+// CHECK:                   %[[VAL_79:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_59]]] : memref<?xf32>
+// CHECK:                   %[[VAL_80:.*]] = arith.addf %[[VAL_78]], %[[VAL_79]] : f32
+// CHECK:                   %[[VAL_81:.*]] = arith.addf %[[VAL_60]], %[[VAL_80]] : f32
+// CHECK:                   scf.yield %[[VAL_81]] : f32
 // CHECK:                 } else {
-// CHECK:                   %[[VAL_79:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_64]] : index
-// CHECK:                   %[[VAL_80:.*]] = arith.cmpi eq, %[[VAL_59]], %[[VAL_64]] : index
-// CHECK:                   %[[VAL_81:.*]] = arith.andi %[[VAL_79]], %[[VAL_80]] : i1
-// CHECK:                   scf.if %[[VAL_81]] {
-// CHECK:                     %[[VAL_82:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
-// CHECK:                     %[[VAL_83:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_55]]] : memref<?xf32>
-// CHECK:                     %[[VAL_84:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_56]]] : memref<?xf32>
-// CHECK:                     %[[VAL_85:.*]] = arith.mulf %[[VAL_83]], %[[VAL_84]] : f32
-// CHECK:                     %[[VAL_86:.*]] = arith.mulf %[[VAL_85]], %[[VAL_36]] : f32
-// CHECK:                     %[[VAL_87:.*]] = arith.mulf %[[VAL_86]], %[[VAL_25]] : f32
-// CHECK:                     %[[VAL_88:.*]] = arith.addf %[[VAL_82]], %[[VAL_87]] : f32
-// CHECK:                     memref.store %[[VAL_88]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                   %[[VAL_82:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index
+// CHECK:                   %[[VAL_83:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_67]] : index
+// CHECK:                   %[[VAL_84:.*]] = arith.andi %[[VAL_82]], %[[VAL_83]] : i1
+// CHECK:                   %[[VAL_85:.*]] = scf.if %[[VAL_84]] -> (f32) {
+// CHECK:                     %[[VAL_86:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_57]]] : memref<?xf32>
+// CHECK:                     %[[VAL_87:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_58]]] : memref<?xf32>
+// CHECK:                     %[[VAL_88:.*]] = arith.mulf %[[VAL_86]], %[[VAL_87]] : f32
+// CHECK:                     %[[VAL_89:.*]] = arith.mulf %[[VAL_88]], %[[VAL_37]] : f32
+// CHECK:                     %[[VAL_90:.*]] = arith.mulf %[[VAL_89]], %[[VAL_25]] : f32
+// CHECK:                     %[[VAL_91:.*]] = arith.addf %[[VAL_60]], %[[VAL_90]] : f32
+// CHECK:                     scf.yield %[[VAL_91]] : f32
 // CHECK:                   } else {
-// CHECK:                     %[[VAL_89:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_64]] : index
-// CHECK:                     scf.if %[[VAL_89]] {
-// CHECK:                       %[[VAL_90:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
-// CHECK:                       %[[VAL_91:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_57]]] : memref<?xf32>
-// CHECK:                       %[[VAL_92:.*]] = arith.addf %[[VAL_90]], %[[VAL_91]] : f32
-// CHECK:                       memref.store %[[VAL_92]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                     %[[VAL_92:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_67]] : index
+// CHECK:                     %[[VAL_93:.*]] = scf.if %[[VAL_92]] -> (f32) {
+// CHECK:                       %[[VAL_94:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_59]]] : memref<?xf32>
+// CHECK:                       %[[VAL_95:.*]] = arith.addf %[[VAL_60]], %[[VAL_94]] : f32
+// CHECK:                       scf.yield %[[VAL_95]] : f32
 // CHECK:                     } else {
+// CHECK:                       scf.yield %[[VAL_60]] : f32
 // CHECK:                     }
+// CHECK:                     scf.yield %[[VAL_96:.*]] : f32
 // CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_97:.*]] : f32
 // CHECK:                 }
-// CHECK:                 %[[VAL_93:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_64]] : index
-// CHECK:                 %[[VAL_94:.*]] = arith.addi %[[VAL_55]], %[[VAL_8]] : index
-// CHECK:                 %[[VAL_95:.*]] = select %[[VAL_93]], %[[VAL_94]], %[[VAL_55]] : index
-// CHECK:                 %[[VAL_96:.*]] = arith.cmpi eq, %[[VAL_59]], %[[VAL_64]] : index
-// CHECK:                 %[[VAL_97:.*]] = arith.addi %[[VAL_56]], %[[VAL_8]] : index
-// CHECK:                 %[[VAL_98:.*]] = select %[[VAL_96]], %[[VAL_97]], %[[VAL_56]] : index
-// CHECK:                 %[[VAL_99:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_64]] : index
-// CHECK:                 %[[VAL_100:.*]] = arith.addi %[[VAL_57]], %[[VAL_8]] : index
-// CHECK:                 %[[VAL_101:.*]] = select %[[VAL_99]], %[[VAL_100]], %[[VAL_57]] : index
-// CHECK:                 scf.yield %[[VAL_95]], %[[VAL_98]], %[[VAL_101]] : index, index, index
+// CHECK:                 %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index
+// CHECK:                 %[[VAL_99:.*]] = arith.addi %[[VAL_57]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_100:.*]] = select %[[VAL_98]], %[[VAL_99]], %[[VAL_57]] : index
+// CHECK:                 %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_67]] : index
+// CHECK:                 %[[VAL_102:.*]] = arith.addi %[[VAL_58]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_103:.*]] = select %[[VAL_101]], %[[VAL_102]], %[[VAL_58]] : index
+// CHECK:                 %[[VAL_104:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_67]] : index
+// CHECK:                 %[[VAL_105:.*]] = arith.addi %[[VAL_59]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_106:.*]] = select %[[VAL_104]], %[[VAL_105]], %[[VAL_59]] : index
+// CHECK:                 scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_106]], %[[VAL_107:.*]] : index, index, index, f32
 // CHECK:               }
-// CHECK:               %[[VAL_102:.*]]:2 = scf.while (%[[VAL_103:.*]] = %[[VAL_104:.*]]#0, %[[VAL_105:.*]] = %[[VAL_104]]#1) : (index, index) -> (index, index) {
-// CHECK:                 %[[VAL_106:.*]] = arith.cmpi ult, %[[VAL_103]], %[[VAL_39]] : index
-// CHECK:                 %[[VAL_107:.*]] = arith.cmpi ult, %[[VAL_105]], %[[VAL_42]] : index
-// CHECK:                 %[[VAL_108:.*]] = arith.andi %[[VAL_106]], %[[VAL_107]] : i1
-// CHECK:                 scf.condition(%[[VAL_108]]) %[[VAL_103]], %[[VAL_105]] : index, index
+// CHECK:               %[[VAL_108:.*]]:3 = scf.while (%[[VAL_109:.*]] = %[[VAL_110:.*]]#0, %[[VAL_111:.*]] = %[[VAL_110]]#1, %[[VAL_112:.*]] = %[[VAL_110]]#3) : (index, index, f32) -> (index, index, f32) {
+// CHECK:                 %[[VAL_113:.*]] = arith.cmpi ult, %[[VAL_109]], %[[VAL_40]] : index
+// CHECK:                 %[[VAL_114:.*]] = arith.cmpi ult, %[[VAL_111]], %[[VAL_43]] : index
+// CHECK:                 %[[VAL_115:.*]] = arith.andi %[[VAL_113]], %[[VAL_114]] : i1
+// CHECK:                 scf.condition(%[[VAL_115]]) %[[VAL_109]], %[[VAL_111]], %[[VAL_112]] : index, index, f32
 // CHECK:               } do {
-// CHECK:               ^bb0(%[[VAL_109:.*]]: index, %[[VAL_110:.*]]: index):
-// CHECK:                 %[[VAL_111:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_109]]] : memref<?xindex>
-// CHECK:                 %[[VAL_112:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_110]]] : memref<?xindex>
-// CHECK:                 %[[VAL_113:.*]] = arith.cmpi ult, %[[VAL_112]], %[[VAL_111]] : index
-// CHECK:                 %[[VAL_114:.*]] = select %[[VAL_113]], %[[VAL_112]], %[[VAL_111]] : index
-// CHECK:                 %[[VAL_115:.*]] = arith.cmpi eq, %[[VAL_111]], %[[VAL_114]] : index
-// CHECK:                 %[[VAL_116:.*]] = arith.cmpi eq, %[[VAL_112]], %[[VAL_114]] : index
-// CHECK:                 %[[VAL_117:.*]] = arith.andi %[[VAL_115]], %[[VAL_116]] : i1
-// CHECK:                 scf.if %[[VAL_117]] {
-// CHECK:                   %[[VAL_118:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
-// CHECK:                   %[[VAL_119:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_109]]] : memref<?xf32>
-// CHECK:                   %[[VAL_120:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_110]]] : memref<?xf32>
-// CHECK:                   %[[VAL_121:.*]] = arith.mulf %[[VAL_119]], %[[VAL_120]] : f32
-// CHECK:                   %[[VAL_122:.*]] = arith.mulf %[[VAL_121]], %[[VAL_36]] : f32
-// CHECK:                   %[[VAL_123:.*]] = arith.mulf %[[VAL_122]], %[[VAL_25]] : f32
-// CHECK:                   %[[VAL_124:.*]] = arith.addf %[[VAL_118]], %[[VAL_123]] : f32
-// CHECK:                   memref.store %[[VAL_124]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:               ^bb0(%[[VAL_116:.*]]: index, %[[VAL_117:.*]]: index, %[[VAL_118:.*]]: f32):
+// CHECK:                 %[[VAL_119:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_116]]] : memref<?xindex>
+// CHECK:                 %[[VAL_120:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_117]]] : memref<?xindex>
+// CHECK:                 %[[VAL_121:.*]] = arith.cmpi ult, %[[VAL_120]], %[[VAL_119]] : index
+// CHECK:                 %[[VAL_122:.*]] = select %[[VAL_121]], %[[VAL_120]], %[[VAL_119]] : index
+// CHECK:                 %[[VAL_123:.*]] = arith.cmpi eq, %[[VAL_119]], %[[VAL_122]] : index
+// CHECK:                 %[[VAL_124:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_122]] : index
+// CHECK:                 %[[VAL_125:.*]] = arith.andi %[[VAL_123]], %[[VAL_124]] : i1
+// CHECK:                 %[[VAL_126:.*]] = scf.if %[[VAL_125]] -> (f32) {
+// CHECK:                   %[[VAL_127:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_116]]] : memref<?xf32>
+// CHECK:                   %[[VAL_128:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_117]]] : memref<?xf32>
+// CHECK:                   %[[VAL_129:.*]] = arith.mulf %[[VAL_127]], %[[VAL_128]] : f32
+// CHECK:                   %[[VAL_130:.*]] = arith.mulf %[[VAL_129]], %[[VAL_37]] : f32
+// CHECK:                   %[[VAL_131:.*]] = arith.mulf %[[VAL_130]], %[[VAL_25]] : f32
+// CHECK:                   %[[VAL_132:.*]] = arith.addf %[[VAL_118]], %[[VAL_131]] : f32
+// CHECK:                   scf.yield %[[VAL_132]] : f32
 // CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_118]] : f32
 // CHECK:                 }
-// CHECK:                 %[[VAL_125:.*]] = arith.cmpi eq, %[[VAL_111]], %[[VAL_114]] : index
-// CHECK:                 %[[VAL_126:.*]] = arith.addi %[[VAL_109]], %[[VAL_8]] : index
-// CHECK:                 %[[VAL_127:.*]] = select %[[VAL_125]], %[[VAL_126]], %[[VAL_109]] : index
-// CHECK:                 %[[VAL_128:.*]] = arith.cmpi eq, %[[VAL_112]], %[[VAL_114]] : index
-// CHECK:                 %[[VAL_129:.*]] = arith.addi %[[VAL_110]], %[[VAL_8]] : index
-// CHECK:                 %[[VAL_130:.*]] = select %[[VAL_128]], %[[VAL_129]], %[[VAL_110]] : index
-// CHECK:                 scf.yield %[[VAL_127]], %[[VAL_130]] : index, index
+// CHECK:                 %[[VAL_133:.*]] = arith.cmpi eq, %[[VAL_119]], %[[VAL_122]] : index
+// CHECK:                 %[[VAL_134:.*]] = arith.addi %[[VAL_116]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_135:.*]] = select %[[VAL_133]], %[[VAL_134]], %[[VAL_116]] : index
+// CHECK:                 %[[VAL_136:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_122]] : index
+// CHECK:                 %[[VAL_137:.*]] = arith.addi %[[VAL_117]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_138:.*]] = select %[[VAL_136]], %[[VAL_137]], %[[VAL_117]] : index
+// CHECK:                 scf.yield %[[VAL_135]], %[[VAL_138]], %[[VAL_139:.*]] : index, index, f32
 // CHECK:               }
-// CHECK:               %[[VAL_131:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
-// CHECK:               %[[VAL_132:.*]] = scf.for %[[VAL_133:.*]] = %[[VAL_134:.*]]#2 to %[[VAL_45]] step %[[VAL_8]] iter_args(%[[VAL_135:.*]] = %[[VAL_131]]) -> (f32) {
-// CHECK:                 %[[VAL_136:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_133]]] : memref<?xf32>
-// CHECK:                 %[[VAL_137:.*]] = arith.addf %[[VAL_135]], %[[VAL_136]] : f32
-// CHECK:                 scf.yield %[[VAL_137]] : f32
+// CHECK:               %[[VAL_140:.*]] = scf.for %[[VAL_141:.*]] = %[[VAL_142:.*]]#2 to %[[VAL_46]] step %[[VAL_7]] iter_args(%[[VAL_143:.*]] = %[[VAL_144:.*]]#2) -> (f32) {
+// CHECK:                 %[[VAL_145:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_141]]] : memref<?xf32>
+// CHECK:                 %[[VAL_146:.*]] = arith.addf %[[VAL_143]], %[[VAL_145]] : f32
+// CHECK:                 scf.yield %[[VAL_146]] : f32
 // CHECK:               }
-// CHECK:               memref.store %[[VAL_138:.*]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:               memref.store %[[VAL_147:.*]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:             } else {
-// CHECK:               scf.if %[[VAL_7]] {
-// CHECK:                 %[[VAL_139:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_33]]] : memref<?xindex>
-// CHECK:                 %[[VAL_140:.*]] = arith.addi %[[VAL_33]], %[[VAL_8]] : index
-// CHECK:                 %[[VAL_141:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_140]]] : memref<?xindex>
-// CHECK:                 %[[VAL_142:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
-// CHECK:                 %[[VAL_143:.*]] = scf.for %[[VAL_144:.*]] = %[[VAL_139]] to %[[VAL_141]] step %[[VAL_8]] iter_args(%[[VAL_145:.*]] = %[[VAL_142]]) -> (f32) {
-// CHECK:                   %[[VAL_146:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_144]]] : memref<?xf32>
-// CHECK:                   %[[VAL_147:.*]] = arith.addf %[[VAL_145]], %[[VAL_146]] : f32
-// CHECK:                   scf.yield %[[VAL_147]] : f32
+// CHECK:               scf.if %[[VAL_8]] {
+// CHECK:                 %[[VAL_148:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                 %[[VAL_149:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_33]]] : memref<?xindex>
+// CHECK:                 %[[VAL_150:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_151:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_150]]] : memref<?xindex>
+// CHECK:                 %[[VAL_152:.*]] = scf.for %[[VAL_153:.*]] = %[[VAL_149]] to %[[VAL_151]] step %[[VAL_7]] iter_args(%[[VAL_154:.*]] = %[[VAL_148]]) -> (f32) {
+// CHECK:                   %[[VAL_155:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_153]]] : memref<?xf32>
+// CHECK:                   %[[VAL_156:.*]] = arith.addf %[[VAL_154]], %[[VAL_155]] : f32
+// CHECK:                   scf.yield %[[VAL_156]] : f32
 // CHECK:                 }
-// CHECK:                 memref.store %[[VAL_148:.*]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                 memref.store %[[VAL_157:.*]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:               } else {
 // CHECK:               }
 // CHECK:             }
-// CHECK:             %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_33]] : index
-// CHECK:             %[[VAL_150:.*]] = arith.addi %[[VAL_32]], %[[VAL_8]] : index
-// CHECK:             %[[VAL_151:.*]] = select %[[VAL_149]], %[[VAL_150]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_152:.*]] = arith.addi %[[VAL_33]], %[[VAL_8]] : index
-// CHECK:             scf.yield %[[VAL_151]], %[[VAL_152]] : index, index
+// CHECK:             %[[VAL_158:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_33]] : index
+// CHECK:             %[[VAL_159:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_160:.*]] = select %[[VAL_158]], %[[VAL_159]], %[[VAL_32]] : index
+// CHECK:             %[[VAL_161:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK:             scf.yield %[[VAL_160]], %[[VAL_161]] : index, index
 // CHECK:           }
-// CHECK:           scf.for %[[VAL_153:.*]] = %[[VAL_154:.*]]#1 to %[[VAL_22]] step %[[VAL_8]] {
-// CHECK:             %[[VAL_155:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_153]]] : memref<?xindex>
-// CHECK:             %[[VAL_156:.*]] = arith.addi %[[VAL_153]], %[[VAL_8]] : index
-// CHECK:             %[[VAL_157:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_156]]] : memref<?xindex>
-// CHECK:             %[[VAL_158:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_153]]] : memref<?xf32>
-// CHECK:             %[[VAL_159:.*]] = scf.for %[[VAL_160:.*]] = %[[VAL_155]] to %[[VAL_157]] step %[[VAL_8]] iter_args(%[[VAL_161:.*]] = %[[VAL_158]]) -> (f32) {
-// CHECK:               %[[VAL_162:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_160]]] : memref<?xf32>
-// CHECK:               %[[VAL_163:.*]] = arith.addf %[[VAL_161]], %[[VAL_162]] : f32
-// CHECK:               scf.yield %[[VAL_163]] : f32
+// CHECK:           scf.for %[[VAL_162:.*]] = %[[VAL_163:.*]]#1 to %[[VAL_22]] step %[[VAL_7]] {
+// CHECK:             %[[VAL_164:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_162]]] : memref<?xf32>
+// CHECK:             %[[VAL_165:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_162]]] : memref<?xindex>
+// CHECK:             %[[VAL_166:.*]] = arith.addi %[[VAL_162]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_167:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_166]]] : memref<?xindex>
+// CHECK:             %[[VAL_168:.*]] = scf.for %[[VAL_169:.*]] = %[[VAL_165]] to %[[VAL_167]] step %[[VAL_7]] iter_args(%[[VAL_170:.*]] = %[[VAL_164]]) -> (f32) {
+// CHECK:               %[[VAL_171:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_169]]] : memref<?xf32>
+// CHECK:               %[[VAL_172:.*]] = arith.addf %[[VAL_170]], %[[VAL_171]] : f32
+// CHECK:               scf.yield %[[VAL_172]] : f32
 // CHECK:             }
-// CHECK:             memref.store %[[VAL_164:.*]], %[[VAL_24]]{{\[}}%[[VAL_153]]] : memref<?xf32>
+// CHECK:             memref.store %[[VAL_173:.*]], %[[VAL_24]]{{\[}}%[[VAL_162]]] : memref<?xf32>
 // CHECK:           }
-// CHECK:           %[[VAL_165:.*]] = memref.tensor_load %[[VAL_24]] : memref<?xf32>
-// CHECK:           return %[[VAL_165]] : tensor<?xf32>
+// CHECK:           %[[VAL_174:.*]] = memref.tensor_load %[[VAL_24]] : memref<?xf32>
+// CHECK:           return %[[VAL_174]] : tensor<?xf32>
 // CHECK:         }
 func @sum_kernel_with_inv(%arga: tensor<?x?xf32, #Tss>,
                           %argb: tensor<?x?xf32, #Tds>,

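The sampled_dense_dense hunk above shows the while-loop side of the
generalization: each co-iterating scf.while now carries the partial sum as
an extra f32 loop value, and the scf.if forest yields the updated scalar
instead of storing into the output buffer on every arm. Reduced to its
essentials, with hypothetical names rather than the test's captures, the
rewritten pattern is:

  func @scalarized_while_sketch(%vals: memref<?xf32>, %buf: memref<?xf32>,
                                %row: index, %lo: index, %hi: index) {
    %c1 = arith.constant 1 : index
    // Load the accumulator once, before entering the loop.
    %init = memref.load %buf[%row] : memref<?xf32>
    %res:2 = scf.while (%i = %lo, %r = %init) : (index, f32) -> (index, f32) {
      %cond = arith.cmpi ult, %i, %hi : index
      scf.condition(%cond) %i, %r : index, f32
    } do {
    ^bb0(%i: index, %r: f32):
      // The reduction update stays in an SSA value, not in memory.
      %v = memref.load %vals[%i] : memref<?xf32>
      %r1 = arith.addf %r, %v : f32
      %i1 = arith.addi %i, %c1 : index
      scf.yield %i1, %r1 : index, f32
    }
    // Store the final value once, after the loop terminates.
    memref.store %res#1, %buf[%row] : memref<?xf32>
    return
  }

Where the old code paid a load and a store of the output element on every
matching index, the new code keeps exactly one load before the loop and one
store after it.
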
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index 734ea159b8207..9070ac36d3b16 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -1194,39 +1194,41 @@ func @kernel_3d(%arga: tensor<?x?xf32>,
 }
 
 // CHECK-LABEL:   func @sum_reduction(
-// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
 // CHECK:           %[[VAL_10:.*]] = memref.alloc() : memref<f32>
 // CHECK:           memref.copy %[[VAL_9]], %[[VAL_10]] : memref<f32> to memref<f32>
-// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_12]] step %[[VAL_4]] {
-// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
-// CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_4]] {
-// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
-// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_10]][] : memref<f32>
-// CHECK:               %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_4]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (f32) {
-// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xf32>
-// CHECK:                 %[[VAL_26:.*]] = arith.addf %[[VAL_24]], %[[VAL_25]] : f32
-// CHECK:                 scf.yield %[[VAL_26]] : f32
+// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:             %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_3]] iter_args(%[[VAL_22:.*]] = %[[VAL_16]]) -> (f32) {
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK:               %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:               %[[VAL_26:.*]] = scf.for %[[VAL_27:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_3]] iter_args(%[[VAL_28:.*]] = %[[VAL_22]]) -> (f32) {
+// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref<?xf32>
+// CHECK:                 %[[VAL_30:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : f32
+// CHECK:                 scf.yield %[[VAL_30]] : f32
 // CHECK:               }
-// CHECK:               memref.store %[[VAL_27:.*]], %[[VAL_10]][] : memref<f32>
+// CHECK:               scf.yield %[[VAL_26]] : f32
 // CHECK:             }
+// CHECK:             scf.yield %[[VAL_20]] : f32
 // CHECK:           }
-// CHECK:           %[[VAL_28:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
-// CHECK:           return %[[VAL_28]] : tensor<f32>
+// CHECK:           memref.store %[[VAL_14]], %[[VAL_10]][] : memref<f32>
+// CHECK:           %[[VAL_34:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
+// CHECK:           return %[[VAL_34]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction
@@ -1250,35 +1252,37 @@ func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>) -> t
 }
 
 // CHECK-LABEL:   func @sum_reduction_inv(
-// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<?x?x?xf32>,
-// CHECK-SAME:                            %[[VAL_1:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-SAME:                            %[[VAL_2:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:           %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_0]] : memref<?x?x?xf32>
-// CHECK:           %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<f32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<f32>
 // CHECK:           memref.copy %[[VAL_11]], %[[VAL_12]] : memref<f32> to memref<f32>
-// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_9]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_13]]] : memref<?xf32>
-// CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_4]] to %[[VAL_6]] step %[[VAL_5]] {
-// CHECK:               %[[VAL_16:.*]] = memref.load %[[VAL_12]][] : memref<f32>
-// CHECK:               %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_4]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f32) {
-// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]], %[[VAL_15]], %[[VAL_18]]] : memref<?x?x?xf32>
-// CHECK:                 %[[VAL_21:.*]] = arith.mulf %[[VAL_20]], %[[VAL_14]] : f32
-// CHECK:                 %[[VAL_22:.*]] = arith.addf %[[VAL_19]], %[[VAL_21]] : f32
-// CHECK:                 scf.yield %[[VAL_22]] : f32
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_12]][] : memref<f32>
+// CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_9]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
+// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<?xf32>
+// CHECK:             %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_20:.*]] = %[[VAL_16]]) -> (f32) {
+// CHECK:               %[[VAL_21:.*]] = scf.for %[[VAL_22:.*]] = %[[VAL_5]] to %[[VAL_7]] step %[[VAL_3]] iter_args(%[[VAL_23:.*]] = %[[VAL_20]]) -> (f32) {
+// CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]], %[[VAL_19]], %[[VAL_22]]] : memref<?x?x?xf32>
+// CHECK:                 %[[VAL_25:.*]] = arith.mulf %[[VAL_24]], %[[VAL_17]] : f32
+// CHECK:                 %[[VAL_26:.*]] = arith.addf %[[VAL_23]], %[[VAL_25]] : f32
+// CHECK:                 scf.yield %[[VAL_26]] : f32
 // CHECK:               }
-// CHECK:               memref.store %[[VAL_23:.*]], %[[VAL_12]][] : memref<f32>
+// CHECK:               scf.yield %[[VAL_21]] : f32
 // CHECK:             }
+// CHECK:             scf.yield %[[VAL_18]] : f32
 // CHECK:           }
-// CHECK:           %[[VAL_24:.*]] = memref.tensor_load %[[VAL_12]] : memref<f32>
-// CHECK:           return %[[VAL_24]] : tensor<f32>
+// CHECK:           memref.store %[[VAL_14]], %[[VAL_12]][] : memref<f32>
+// CHECK:           %[[VAL_30:.*]] = memref.tensor_load %[[VAL_12]] : memref<f32>
+// CHECK:           return %[[VAL_30]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
                         %argb: tensor<?xf32, #Td>,

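Both sparse_3d kernels above show the for-loop side: the accumulator is
loaded once before the loop nest, the partial sum travels through every
scf.for as an iter_args value, and each level forwards the inner result
outward with scf.yield, so a single store after the outermost loop
suffices. A minimal self-contained sketch, with hypothetical names and
bounds, and two loop levels instead of three:

  func @sum_sketch(%vals: memref<?xf32>, %acc: memref<f32>,
                   %lo0: index, %hi0: index, %lo1: index, %hi1: index) {
    %c1 = arith.constant 1 : index
    // Hoisted: read the accumulator once.
    %init = memref.load %acc[] : memref<f32>
    %sum = scf.for %i = %lo0 to %hi0 step %c1 iter_args(%x0 = %init) -> (f32) {
      %s1 = scf.for %j = %lo1 to %hi1 step %c1 iter_args(%x1 = %x0) -> (f32) {
        %v = memref.load %vals[%j] : memref<?xf32>
        %a = arith.addf %x1, %v : f32
        scf.yield %a : f32
      }
      // Chain the inner partial sum outward.
      scf.yield %s1 : f32
    }
    // Sunk: write the accumulator once.
    memref.store %sum, %acc[] : memref<f32>
    return
  }
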
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
index 2ad36917530dd..363a0281af7b6 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
@@ -21,24 +21,24 @@
 }
 
 // CHECK-HIR-LABEL:   func @matvec(
-// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-HIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-HIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
+// CHECK-HIR-SAME:      %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-HIR-SAME:      %[[VAL_1:.*]]: tensor<64xf64>,
+// CHECK-HIR-SAME:      %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-HIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-HIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-HIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
+// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
 // CHECK-HIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
 // CHECK-HIR:           %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
 // CHECK-HIR:           memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xf64> to memref<32xf64>
 // CHECK-HIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-HIR:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK-HIR:             %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
-// CHECK-HIR:             %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK-HIR:             %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK-HIR-DAG:         %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK-HIR-DAG:         %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
+// CHECK-HIR-DAG:         %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK-HIR-DAG:         %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-HIR:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) {
 // CHECK-HIR:               %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
 // CHECK-HIR:               %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf64>
@@ -47,16 +47,16 @@
 // CHECK-HIR:               %[[VAL_24:.*]] = arith.addf %[[VAL_19]], %[[VAL_23]] : f64
 // CHECK-HIR:               scf.yield %[[VAL_24]] : f64
 // CHECK-HIR:             }
-// CHECK-HIR:             memref.store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK-HIR:             memref.store %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-HIR:           }
 // CHECK-HIR:           %[[VAL_26:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
 // CHECK-HIR:           return %[[VAL_26]] : tensor<32xf64>
 // CHECK-HIR:         }
 
 // CHECK-MIR-LABEL:   func @matvec(
-// CHECK-MIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
-// CHECK-MIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-MIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
+// CHECK-MIR-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-MIR-SAME:      %[[VAL_1:.*]]: tensor<64xf64>,
+// CHECK-MIR-SAME:      %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-MIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-MIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-MIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -68,10 +68,10 @@
 // CHECK-MIR:           %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
 // CHECK-MIR:           memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xf64> to memref<32xf64>
 // CHECK-MIR:           scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-MIR:             %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK-MIR:             %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_5]] : index
-// CHECK-MIR:             %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK-MIR:             %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64>
+// CHECK-MIR-DAG:         %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK-MIR-DAG:         %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_5]] : index
+// CHECK-MIR-DAG:         %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK-MIR-DAG:         %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64>
 // CHECK-MIR:             %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_5]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (f64) {
 // CHECK-MIR:               %[[VAL_22:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
 // CHECK-MIR:               %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xf64>
@@ -80,16 +80,16 @@
 // CHECK-MIR:               %[[VAL_26:.*]] = arith.addf %[[VAL_21]], %[[VAL_25]] : f64
 // CHECK-MIR:               scf.yield %[[VAL_26]] : f64
 // CHECK-MIR:             }
-// CHECK-MIR:             memref.store %[[VAL_27:.*]], %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64>
+// CHECK-MIR:             memref.store %[[VAL_19]], %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64>
 // CHECK-MIR:           }
 // CHECK-MIR:           %[[VAL_28:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
 // CHECK-MIR:           return %[[VAL_28]] : tensor<32xf64>
 // CHECK-MIR:         }
 
 // CHECK-LIR-LABEL:   func @matvec(
-// CHECK-LIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
-// CHECK-LIR-SAME:                 %[[VAL_1:.*]]: memref<64xf64>,
-// CHECK-LIR-SAME:                 %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> {
+// CHECK-LIR-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-LIR-SAME:      %[[VAL_1:.*]]: memref<64xf64>,
+// CHECK-LIR-SAME:      %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> {
 // CHECK-LIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-LIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-LIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -99,10 +99,10 @@
 // CHECK-LIR:           %[[VAL_9:.*]] = memref.alloc() : memref<32xf64>
 // CHECK-LIR:           memref.copy %[[VAL_2]], %[[VAL_9]] : memref<32xf64> to memref<32xf64>
 // CHECK-LIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-LIR:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK-LIR:             %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
-// CHECK-LIR:             %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK-LIR:             %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK-LIR-DAG:         %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK-LIR-DAG:         %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
+// CHECK-LIR-DAG:         %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK-LIR-DAG:         %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-LIR:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) {
 // CHECK-LIR:               %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
 // CHECK-LIR:               %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf64>
@@ -111,7 +111,7 @@
 // CHECK-LIR:               %[[VAL_24:.*]] = arith.addf %[[VAL_19]], %[[VAL_23]] : f64
 // CHECK-LIR:               scf.yield %[[VAL_24]] : f64
 // CHECK-LIR:             }
-// CHECK-LIR:             memref.store %[[VAL_25:.*]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK-LIR:             memref.store %[[VAL_17]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-LIR:           }
 // CHECK-LIR:           return %[[VAL_9]] : memref<32xf64>
 // CHECK-LIR:         }

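The matvec hunks above make two adjustments. The hoisted operations at the
head of the row loop become CHECK-*-DAG, since their relative order is not
what the test is about, and the final store now reuses the bound scf.for
result instead of a fresh capture, which makes the check strictly stronger.
Abbreviated from the CHECK-HIR hunk:

  // Before: a fresh capture matches any SSA value, so a store of the
  // wrong value would still pass.
  // CHECK-HIR: memref.store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]]
  // After: the stored operand must be the loop result bound earlier.
  // CHECK-HIR: %[[VAL_17:.*]] = scf.for {{.*}} -> (f64) {
  // CHECK-HIR: memref.store %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_12]]]
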
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
index 40bc39f4605b0..728ef0d184fcd 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
@@ -21,22 +21,22 @@
 }
 
 // CHECK-HIR-LABEL:   func @matvec(
-// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
-// CHECK-HIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-HIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-HIR-SAME:      %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-HIR-SAME:      %[[VAL_1:.*]]: tensor<64xf64>,
+// CHECK-HIR-SAME:      %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
 // CHECK-HIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-HIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-HIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
+// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
 // CHECK-HIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
 // CHECK-HIR:           scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-HIR:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK-HIR:             %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index
-// CHECK-HIR:             %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
-// CHECK-HIR:             %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK-HIR-DAG:         %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK-HIR-DAG:         %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index
+// CHECK-HIR-DAG:         %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK-HIR-DAG:         %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK-HIR:             %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
 // CHECK-HIR:               %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
 // CHECK-HIR:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xf64>
@@ -45,16 +45,16 @@
 // CHECK-HIR:               %[[VAL_23:.*]] = arith.addf %[[VAL_18]], %[[VAL_22]] : f64
 // CHECK-HIR:               scf.yield %[[VAL_23]] : f64
 // CHECK-HIR:             }
-// CHECK-HIR:             memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK-HIR:             memref.store %[[VAL_16]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK-HIR:           }
 // CHECK-HIR:           %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf64>
 // CHECK-HIR:           return %[[VAL_25]] : tensor<32xf64>
 // CHECK-HIR:         }
 
 // CHECK-MIR-LABEL:   func @matvec(
-// CHECK-MIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
-// CHECK-MIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-MIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-MIR-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-MIR-SAME:      %[[VAL_1:.*]]: tensor<64xf64>,
+// CHECK-MIR-SAME:      %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
 // CHECK-MIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-MIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-MIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -64,10 +64,10 @@
 // CHECK-MIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
 // CHECK-MIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
 // CHECK-MIR:           scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-MIR:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK-MIR:             %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index
-// CHECK-MIR:             %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
-// CHECK-MIR:             %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK-MIR-DAG:         %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK-MIR-DAG:         %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index
+// CHECK-MIR-DAG:         %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK-MIR-DAG:         %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK-MIR:             %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
 // CHECK-MIR:               %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
 // CHECK-MIR:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xf64>
@@ -76,16 +76,16 @@
 // CHECK-MIR:               %[[VAL_23:.*]] = arith.addf %[[VAL_18]], %[[VAL_22]] : f64
 // CHECK-MIR:               scf.yield %[[VAL_23]] : f64
 // CHECK-MIR:             }
-// CHECK-MIR:             memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK-MIR:             memref.store %[[VAL_16]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK-MIR:           }
 // CHECK-MIR:           %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf64>
 // CHECK-MIR:           return %[[VAL_25]] : tensor<32xf64>
 // CHECK-MIR:         }
 
 // CHECK-LIR-LABEL:   func @matvec(
-// CHECK-LIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
-// CHECK-LIR-SAME:                 %[[VAL_1:.*]]: memref<64xf64>,
-// CHECK-LIR-SAME:                 %[[VAL_2:.*]]: memref<32xf64> {linalg.inplaceable = true}) -> memref<32xf64> {
+// CHECK-LIR-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-LIR-SAME:      %[[VAL_1:.*]]: memref<64xf64>,
+// CHECK-LIR-SAME:      %[[VAL_2:.*]]: memref<32xf64> {linalg.inplaceable = true}) -> memref<32xf64> {
 // CHECK-LIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-LIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-LIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -93,10 +93,10 @@
 // CHECK-LIR:           %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-LIR:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 // CHECK-LIR:           scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-LIR:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK-LIR:             %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_5]] : index
-// CHECK-LIR:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK-LIR:             %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64>
+// CHECK-LIR-DAG:         %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+// CHECK-LIR-DAG:         %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_5]] : index
+// CHECK-LIR-DAG:         %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK-LIR-DAG:         %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64>
 // CHECK-LIR:             %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f64) {
 // CHECK-LIR:               %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
 // CHECK-LIR:               %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf64>
@@ -105,7 +105,7 @@
 // CHECK-LIR:               %[[VAL_21:.*]] = arith.addf %[[VAL_16]], %[[VAL_20]] : f64
 // CHECK-LIR:               scf.yield %[[VAL_21]] : f64
 // CHECK-LIR:             }
-// CHECK-LIR:             memref.store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64>
+// CHECK-LIR:             memref.store %[[VAL_14]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64>
 // CHECK-LIR:           }
 // CHECK-LIR:           return %[[VAL_2]] : memref<32xf64>
 // CHECK-LIR:         }
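Across the HIR, MIR, and LIR checks above, the updated expectations capture the same change: the per-row reduction for matvec is now carried through the iter_args of the inner scf.for, so the loop result itself is stored once per row, replacing the old unmatched store captures (%[[VAL_24]], %[[VAL_22]]) with the for-loop's yielded value (%[[VAL_16]], %[[VAL_14]]). The switch from CHECK to CHECK-DAG on the four hoisted loads simply lets FileCheck accept those independent loads in any order. A minimal sketch of the emitted pattern, with illustrative names (%x, %indices, %values, %dense are placeholders, not the actual test captures):

  %init = memref.load %x[%i] : memref<32xf64>
  %sum = scf.for %jj = %lo to %hi step %c1 iter_args(%acc = %init) -> (f64) {
    %j = memref.load %indices[%jj] : memref<?xindex>
    %a = memref.load %values[%jj] : memref<?xf64>
    %b = memref.load %dense[%j] : memref<64xf64>
    %t = arith.mulf %a, %b : f64
    %s = arith.addf %acc, %t : f64
    scf.yield %s : f64        // running sum stays in a register, not memory
  }
  memref.store %sum, %x[%i] : memref<32xf64>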

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 30d36b5655a13..43135912e41c8 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -17,67 +17,71 @@
 }
 
 // CHECK-HIR-LABEL:   func @sparse_dynamic_dims(
-// CHECK-HIR-SAME:                                      %[[VAL_0:.*]]: tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-HIR-SAME:                                      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-HIR-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-HIR-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-HIR-DAG:       %[[C2:.*]] = arith.constant 2 : index
-// CHECK-HIR:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[C2]] : tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[C1]] : tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-SAME:      %[[VAL_0:.*]]: tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-HIR-SAME:      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-HIR-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-HIR-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-HIR-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-HIR:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
 // CHECK-HIR:           %[[VAL_10:.*]] = memref.alloc() : memref<f32>
 // CHECK-HIR:           memref.copy %[[VAL_9]], %[[VAL_10]] : memref<f32> to memref<f32>
-// CHECK-HIR:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[VAL_5]] step %[[C1]] {
-// CHECK-HIR:             scf.for %[[VAL_12:.*]] = %[[C0]] to %[[VAL_6]] step %[[C1]] {
-// CHECK-HIR:               %[[VAL_13:.*]] = arith.muli %[[VAL_6]], %[[VAL_11]] : index
-// CHECK-HIR:               %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
-// CHECK-HIR:               %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
-// CHECK-HIR:               %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[C0]] to %[[VAL_7]] step %[[C1]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
-// CHECK-HIR:                 %[[VAL_19:.*]] = arith.muli %[[VAL_7]], %[[VAL_14]] : index
-// CHECK-HIR:                 %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index
-// CHECK-HIR:                 %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xf32>
-// CHECK-HIR:                 %[[VAL_22:.*]] = arith.addf %[[VAL_18]], %[[VAL_21]] : f32
-// CHECK-HIR:                 scf.yield %[[VAL_22]] : f32
+// CHECK-HIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK-HIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-HIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
+// CHECK-HIR:               %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index
+// CHECK-HIR:               %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_16]] : index
+// CHECK-HIR:               %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_2]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
+// CHECK-HIR:                 %[[VAL_23:.*]] = arith.muli %[[VAL_7]], %[[VAL_19]] : index
+// CHECK-HIR:                 %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index
+// CHECK-HIR:                 %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK-HIR:                 %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
+// CHECK-HIR:                 scf.yield %[[VAL_26]] : f32
 // CHECK-HIR:               }
-// CHECK-HIR:               memref.store %[[VAL_23:.*]], %[[VAL_10]][] : memref<f32>
+// CHECK-HIR:               scf.yield %[[VAL_20]] : f32
 // CHECK-HIR:             }
+// CHECK-HIR:             scf.yield %[[VAL_15]] : f32
 // CHECK-HIR:           }
-// CHECK-HIR:           %[[VAL_24:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
-// CHECK-HIR:           return %[[VAL_24]] : tensor<f32>
+// CHECK-HIR:           memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
+// CHECK-HIR:           %[[VAL_30:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
+// CHECK-HIR:           return %[[VAL_30]] : tensor<f32>
 // CHECK-HIR:         }
 //
 // CHECK-MIR-LABEL:   func @sparse_dynamic_dims(
-// CHECK-MIR-SAME:                                      %[[VAL_0:.*]]: !llvm.ptr<i8>,
-// CHECK-MIR-SAME:                                      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-MIR-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-MIR-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-MIR-DAG:       %[[C2:.*]] = arith.constant 2 : index
-// CHECK-MIR:           %[[VAL_5:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C0]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR:           %[[VAL_6:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C1]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR:           %[[VAL_7:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C2]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-MIR-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-MIR-SAME:      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-MIR-DAG:       %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK-MIR-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-MIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-MIR:           %[[VAL_5:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_4]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-MIR:           %[[VAL_6:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_3]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-MIR:           %[[VAL_7:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_2]]) : (!llvm.ptr<i8>, index) -> index
 // CHECK-MIR:           %[[VAL_8:.*]] = call @sparseValuesF32(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf32>
 // CHECK-MIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
 // CHECK-MIR:           %[[VAL_10:.*]] = memref.alloc() : memref<f32>
 // CHECK-MIR:           memref.copy %[[VAL_9]], %[[VAL_10]] : memref<f32> to memref<f32>
-// CHECK-MIR:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[VAL_5]] step %[[C1]] {
-// CHECK-MIR:             scf.for %[[VAL_12:.*]] = %[[C0]] to %[[VAL_6]] step %[[C1]] {
-// CHECK-MIR:               %[[VAL_13:.*]] = arith.muli %[[VAL_6]], %[[VAL_11]] : index
-// CHECK-MIR:               %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
-// CHECK-MIR:               %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
-// CHECK-MIR:               %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[C0]] to %[[VAL_7]] step %[[C1]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
-// CHECK-MIR:                 %[[VAL_19:.*]] = arith.muli %[[VAL_7]], %[[VAL_14]] : index
-// CHECK-MIR:                 %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index
-// CHECK-MIR:                 %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xf32>
-// CHECK-MIR:                 %[[VAL_22:.*]] = arith.addf %[[VAL_18]], %[[VAL_21]] : f32
-// CHECK-MIR:                 scf.yield %[[VAL_22]] : f32
+// CHECK-MIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
+// CHECK-MIR:               %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index
+// CHECK-MIR:               %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_16]] : index
+// CHECK-MIR:               %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_4]] to %[[VAL_7]] step %[[VAL_3]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
+// CHECK-MIR:                 %[[VAL_23:.*]] = arith.muli %[[VAL_7]], %[[VAL_19]] : index
+// CHECK-MIR:                 %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index
+// CHECK-MIR:                 %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK-MIR:                 %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
+// CHECK-MIR:                 scf.yield %[[VAL_26]] : f32
 // CHECK-MIR:               }
-// CHECK-MIR:               memref.store %[[VAL_23:.*]], %[[VAL_10]][] : memref<f32>
+// CHECK-MIR:               scf.yield %[[VAL_20]] : f32
 // CHECK-MIR:             }
+// CHECK-MIR:             scf.yield %[[VAL_15]] : f32
 // CHECK-MIR:           }
-// CHECK-MIR:           %[[VAL_24:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
-// CHECK-MIR:           return %[[VAL_24]] : tensor<f32>
+// CHECK-MIR:           memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
+// CHECK-MIR:           %[[VAL_30:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
+// CHECK-MIR:           return %[[VAL_30]] : tensor<f32>
 // CHECK-MIR:         }
 func @sparse_dynamic_dims(%arga: tensor<?x?x?xf32, #X>,
                           %argx: tensor<f32>) -> tensor<f32> {


        

