[Mlir-commits] [mlir] 98f93e3 - [mlir][sparse] factorized merger/emitter/codegen into single environment

Aart Bik llvmlistbot at llvm.org
Tue Dec 20 11:34:21 PST 2022


Author: Aart Bik
Date: 2022-12-20T11:34:12-08:00
New Revision: 98f93e3b726ac061153ff381b0119d6f2711abe4

URL: https://github.com/llvm/llvm-project/commit/98f93e3b726ac061153ff381b0119d6f2711abe4
DIFF: https://github.com/llvm/llvm-project/commit/98f93e3b726ac061153ff381b0119d6f2711abe4.diff

LOG: [mlir][sparse] factorized merger/emitter/codegen into single environment

This cleans up a lot of parameter passing. It also prepares for adding
proper "delegate" functions to the new environment and for moving this
environment out into its own class with a better OO design.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D140257

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 680ef8d00b037..3f67f66a93097 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -41,7 +41,7 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-// Iteration graph sorting.
+/// Iteration graph sorting.
 enum SortMask {
   kSparseOnly = 0x0,
   kIncludeDense = 0x1,
@@ -49,37 +49,91 @@ enum SortMask {
   kIncludeAll = 0x3
 };
 
-// Reduction kinds.
+/// Reduction kinds.
 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
 
-// Code generation.
-struct CodeGen {
-  CodeGen(SparsificationOptions o, MLIRContext *context, ValueRange tensors,
-          unsigned numTensors, unsigned numLoops, OpOperand *op, unsigned nest,
-          std::vector<unsigned> &ts)
-      : options(o),
-        loopEmitter(
-            tensors,
-            StringAttr::get(context, linalg::GenericOp::getOperationName()),
-            /*hasOutput=*/true,
-            /*isSparseOut=*/op != nullptr, ts),
-        sparseOut(op), outerParNest(nest), topSort(ts) {
-    if (op)
-      insChain = op->get();
+/// Code generation environment. This structure aggregates a number
+/// of data structures needed during code generation. Such an environment
+/// simplifies passing around data during sparsification (rather than
+/// passing around all the individual compoments where needed).
+//
+// TODO: refactor further, move into own file
+//
+struct CodeGenEnv {
+  CodeGenEnv(linalg::GenericOp linop, SparsificationOptions opts,
+             unsigned numTensors, unsigned numLoops, unsigned numFilterLoops)
+      : linalgOp(linop), options(opts), topSort(),
+        merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
+        redExp(-1u), redKind(kNoReduc), redCustom(-1u), sparseOut(nullptr) {}
+
+  // Start emitting.
+  void startEmit(SparseTensorLoopEmitter *le) {
+    assert(!loopEmitter && "must only start emitting once");
+    loopEmitter = le;
+    if (sparseOut) {
+      insChain = sparseOut->get();
+      merger.setHasSparseOut(true);
+    }
+  }
+
+  // Delegate methods to merger.
+  TensorExp &exp(unsigned e) { return merger.exp(e); }
+  LatPoint &lat(unsigned l) { return merger.lat(l); }
+  SmallVector<unsigned> &set(unsigned s) { return merger.set(s); }
+  DimLevelType dimLevelType(unsigned t, unsigned i) const {
+    return merger.getDimLevelType(t, i);
+  }
+  DimLevelType dimLevelType(unsigned b) const {
+    return merger.getDimLevelType(b);
+  }
+  bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); }
+
+  // Delegate methods to loop emitter.
+  Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); }
+  const std::vector<Value> &getValBuffer() const {
+    return loopEmitter->getValBuffer();
+  }
+
+  // Convenience method to slice topsort.
+  ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const {
+    return ArrayRef<unsigned>(topSort).slice(n, m);
+  }
+
+  // Convenience method to get current loop stack.
+  ArrayRef<unsigned> getLoopCurStack() const {
+    return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+  }
+
+  // Convenience method to get the IV of the given loop index.
+  Value getLoopIdxValue(size_t loopIdx) const {
+    for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
+      if (topSort[lv] == loopIdx)
+        return getLoopIV(lv);
+    llvm_unreachable("invalid loop index");
   }
+
+  // TODO: make private
+
+  /// Linalg operation.
+  linalg::GenericOp linalgOp;
   /// Sparsification options.
   SparsificationOptions options;
-  /// Loop emitter helper class.
-  SparseTensorLoopEmitter loopEmitter;
+  // Topological sort.
+  std::vector<unsigned> topSort;
+  /// Merger helper class.
+  Merger merger;
+  /// Loop emitter helper class (keep reference in scope!).
+  /// TODO: move emitter constructor up in time?
+  SparseTensorLoopEmitter *loopEmitter;
   /// Current reduction, updated during code generation. When indices of a
   /// reduction are exhausted, all inner loops can use a scalarized reduction.
-  unsigned redExp = -1u;
+  unsigned redExp;
   Value redVal;
-  Reduction redKind = kNoReduc;
-  unsigned redCustom = -1u;
-  // Sparse tensor as output. Implemented either through direct injective
-  // insertion in lexicographic index order or through access pattern expansion
-  // in the innermost loop nest (`expValues` through `expCount`).
+  Reduction redKind;
+  unsigned redCustom;
+  /// Sparse tensor as output. Implemented either through direct injective
+  /// insertion in lexicographic index order or through access pattern expansion
+  /// in the innermost loop nest (`expValues` through `expCount`).
   OpOperand *sparseOut;
   unsigned outerParNest;
   Value insChain; // bookkeeping for insertion chain
@@ -87,21 +141,6 @@ struct CodeGen {
   Value expFilled;
   Value expAdded;
   Value expCount;
-  // Topsort (reference should remain in scope).
-  std::vector<unsigned> &topSort;
-
-  ArrayRef<unsigned> getLoopCurStack() const {
-    ArrayRef<unsigned> topSortRef = topSort;
-    return topSortRef.slice(0, loopEmitter.getCurrentDepth());
-  }
-
-  Value getLoopIdxValue(size_t loopIdx) const {
-    for (unsigned lv = 0; lv < topSort.size(); lv++)
-      if (topSort[lv] == loopIdx)
-        return loopEmitter.getLoopIV(lv);
-
-    llvm_unreachable("invalid loop index");
-  }
 };
 
 /// A helper class that visits an affine expression and tries to find an
@@ -133,6 +172,7 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
   /// The mapping between dim=>iterator type.
   SmallVector<utils::IteratorType> iterTypes;
 };
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -172,17 +212,17 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
 }
 
 /// Determines if affine expression is invariant.
-static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
-                              unsigned ldx, bool &atLevel) {
-  return isInvariantAffine(a, codegen.getLoopCurStack(), ldx, atLevel);
+static bool isInvariantAffine(CodeGenEnv &env, AffineExpr a, unsigned ldx,
+                              bool &atLevel) {
+  return isInvariantAffine(a, env.getLoopCurStack(), ldx, atLevel);
 }
 
 /// Helper method to construct a permuted dimension ordering
 /// that adheres to the given topological sort.
-static AffineMap permute(const Merger &merger, MLIRContext *context,
-                         AffineMap m, ArrayRef<unsigned> topSort) {
-  assert(m.getNumDims() + merger.getNumFilterLoops() == topSort.size() &&
-         "TopoSort/AffineMap size mismatch");
+static AffineMap permute(CodeGenEnv &env, AffineMap m) {
+  assert(m.getNumDims() + env.merger.getNumFilterLoops() ==
+             env.topSort.size() &&
+         "size mismatch");
   // Construct the inverse of `m`; to avoid the asymptotic complexity
   // of calling `m.getPermutedPosition` repeatedly.
   SmallVector<unsigned> perm;
@@ -191,13 +231,14 @@ static AffineMap permute(const Merger &merger, MLIRContext *context,
   unsigned loopDepth = 1;
 
   // Construct the permutation.
-  while (worklist.any() && loopDepth <= topSort.size()) {
+  while (worklist.any() && loopDepth <= env.topSort.size()) {
     unsigned preSize = perm.size();
     for (auto dim : worklist.set_bits()) {
       bool atLevel = false;
       if (m.getResult(dim).isa<AffineConstantExpr>() ||
-          (isInvariantAffine(m.getResult(dim), topSort.slice(0, loopDepth),
-                             topSort[loopDepth - 1], atLevel) &&
+          (isInvariantAffine(m.getResult(dim),
+                             env.getTopSortSlice(0, loopDepth),
+                             env.topSort[loopDepth - 1], atLevel) &&
            atLevel)) {
         // If the matching affine is constant expression or just become
         // invariant. We can visit the dimension now without breaking the
@@ -215,7 +256,7 @@ static AffineMap permute(const Merger &merger, MLIRContext *context,
   }
 
   assert(perm.size() == numResults);
-  return AffineMap::getPermutationMap(perm, context);
+  return AffineMap::getPermutationMap(perm, env.linalgOp.getContext());
 }
 
 /// Helper method to inspect affine expressions. Rejects cases where the
@@ -305,24 +346,24 @@ static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
 /// Returns true if the sparse annotations and affine subscript
 /// expressions of all tensors are admissible. Returns false if
 /// no annotations are found or inadmissible constructs occur.
-static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
+static bool findSparseAnnotations(CodeGenEnv &env) {
   bool annotated = false;
-  unsigned filterLdx = merger.getFilterLoopStartingIdx();
-  for (OpOperand &t : op->getOpOperands()) {
-    auto map = op.getMatchingIndexingMap(&t);
+  unsigned filterLdx = env.merger.getFilterLoopStartingIdx();
+  for (OpOperand &t : env.linalgOp->getOpOperands()) {
+    auto map = env.linalgOp.getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
-    assert(map.getNumResults() == op.getRank(&t));
-
+    assert(map.getNumResults() == env.linalgOp.getRank(&t));
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d), filterLdx))
+      if (!findAffine(env.merger, tensor, d, a, getDimLevelType(enc, d),
+                      filterLdx))
         return false; // inadmissible affine expression
     }
   }
-  assert(filterLdx == merger.getNumLoops());
+  assert(filterLdx == env.merger.getNumLoops());
   return annotated;
 }
 
@@ -330,9 +371,8 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
 /// as we use adj matrix for the graph.
 /// The sorted result will put the first Reduction iterator to the
 /// latest possible index.
-static bool topSortOptimal(unsigned n,
+static bool topSortOptimal(CodeGenEnv &env, unsigned n,
                            ArrayRef<utils::IteratorType> iteratorTypes,
-                           const Merger &merger, std::vector<unsigned> &topSort,
                            std::vector<unsigned> &inDegree,
                            std::vector<std::vector<bool>> &adjM) {
   std::vector<unsigned> redIt;    // reduce iterator with 0 degree
@@ -340,7 +380,7 @@ static bool topSortOptimal(unsigned n,
   std::vector<unsigned> filterIt; // filter loop with 0 degree
   for (unsigned i = 0; i < n; i++) {
     if (inDegree[i] == 0) {
-      if (merger.isFilterLoop(i))
+      if (env.isFilterLoop(i))
         filterIt.push_back(i);
       else if (linalg::isReductionIterator(iteratorTypes[i]))
         redIt.push_back(i);
@@ -371,12 +411,12 @@ static bool topSortOptimal(unsigned n,
     //        O(X) computation  => O(NK+NMX) time complexity
     auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
     auto src = it.back();
-    topSort.push_back(src);
+    env.topSort.push_back(src);
     it.pop_back();
     // Update in-degree, and push 0-degree node into worklist.
     for (unsigned dst = 0; dst < n; dst++) {
       if (adjM[src][dst] && --inDegree[dst] == 0) {
-        if (merger.isFilterLoop(dst))
+        if (env.isFilterLoop(dst))
           filterIt.push_back(dst);
         else if (linalg::isReductionIterator(iteratorTypes[dst]))
           redIt.push_back(dst);
@@ -385,7 +425,7 @@ static bool topSortOptimal(unsigned n,
       }
     }
   }
-  return topSort.size() == n;
+  return env.topSort.size() == n;
 }
 
 /// Helper method to add all constraints from the indices in one affine
@@ -477,21 +517,21 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
 /// along fixed dimensions. Even for dense storage formats, however, the
 /// natural index order yields innermost unit-stride access with better
 /// spatial locality.
-static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
-                                  std::vector<unsigned> &topSort, unsigned mask,
+static bool computeIterationGraph(CodeGenEnv &env, unsigned mask,
                                   OpOperand *skip = nullptr) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
-  unsigned n = merger.getNumLoops();
+  unsigned n = env.merger.getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
   std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
-  auto iteratorTypes = op.getIteratorTypesArray();
+  auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
-  for (OpOperand &t : op->getOpOperands()) {
+  for (OpOperand &t : env.linalgOp->getOpOperands()) {
     // Get map and encoding.
-    auto map = op.getMatchingIndexingMap(&t);
+    auto map = env.linalgOp.getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
-    assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(op) == n);
+    assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.linalgOp) ==
+           n);
     // Skip dense tensor constraints when not requested.
     if (!(mask & SortMask::kIncludeDense) && !enc)
       continue;
@@ -501,11 +541,11 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     // on the loop indices if no explicit dimension ordering is given.
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr ta = map.getResult(toOrigDim(enc, d));
-      Optional<unsigned> tldx = merger.getLoopIdx(t.getOperandNumber(), d);
+      Optional<unsigned> tldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
 
       // Filter loops should be constructed after all the dependent loops,
       // i.e., d0 + d1 < filter_loop(d0 + d1)
-      if (tldx && merger.isFilterLoop(*tldx)) {
+      if (tldx && env.isFilterLoop(*tldx)) {
         assert(!ta.isa<AffineDimExpr>() &&
                !isDenseDLT(getDimLevelType(enc, d)));
         addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
@@ -525,7 +565,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
       if (d > 0) {
         AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
         Optional<unsigned> fldx =
-            merger.getLoopIdx(t.getOperandNumber(), d - 1);
+            env.merger.getLoopIdx(t.getOperandNumber(), d - 1);
 
         // Applying order constraints on every pair of dimExpr between two
         // compound affine expressions can sometime too strict:
@@ -533,7 +573,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
         // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
         // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
         if (!(mask & SortMask::kIncludeDense))
-          tryLoosenAffineDenseConstraints(op, fldx, fa, tldx, ta);
+          tryLoosenAffineDenseConstraints(env.linalgOp, fldx, fa, tldx, ta);
 
         // (d0 + d1) < (d2 + d3), or
         // filter_loop_d-1 < (d2 + d3), or
@@ -548,24 +588,24 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     if (mask & SortMask::kIncludeUndef) {
       unsigned tensor = t.getOperandNumber();
       for (unsigned i = 0; i < n; i++)
-        if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
-            isSingletonDLT(merger.getDimLevelType(tensor, i))) {
+        if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
+            isSingletonDLT(env.dimLevelType(tensor, i))) {
           for (unsigned j = 0; j < n; j++)
-            if (isUndefDLT(merger.getDimLevelType(tensor, j))) {
+            if (isUndefDLT(env.dimLevelType(tensor, j))) {
               adjM[i][j] = true;
               inDegree[j]++;
             }
         } else {
-          assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
-                 isUndefDLT(merger.getDimLevelType(tensor, i)));
+          assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
+                 isUndefDLT(env.dimLevelType(tensor, i)));
         }
     }
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
-  topSort.clear();
-  topSort.reserve(n);
-  return topSortOptimal(n, iteratorTypes, merger, topSort, inDegree, adjM);
+  env.topSort.clear();
+  env.topSort.reserve(n);
+  return topSortOptimal(env, n, iteratorTypes, inDegree, adjM);
 }
 
 /// Returns true if tensor materializes uninitialized into the computation.
@@ -574,29 +614,25 @@ static bool isMaterializing(Value val) {
          val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
-/// Returns true when the tensor expression is admissible for codegen.
+/// Returns true when the tensor expression is admissible for env.
 /// Since all sparse input tensors are admissible, we just need to check
 /// whether the out tensor in the tensor expression codegen is admissible.
 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
-static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
-                                  std::vector<unsigned> &topSort, unsigned exp,
-                                  OpOperand **sparseOut,
-                                  unsigned &outerParNest) {
+static bool isAdmissibleTensorExp(CodeGenEnv &env, unsigned exp) {
   // We reject any expression that makes a reduction from `-outTensor`, as those
-  // expression create dependency between the current iteration (i) and the
-  // previous iteration (i-1). It would then require iterating over the whole
-  // coordinate space, which prevent us from exploiting sparsity for faster
-  // code.
-  for (utils::IteratorType it : op.getIteratorTypesArray()) {
+  // expressions create a dependency between the current iteration (i) and the
+  // previous iteration (i-1). It would require iterating over the whole
+  // coordinate space, which prevent exploiting sparsity for faster code.
+  for (utils::IteratorType it : env.linalgOp.getIteratorTypesArray()) {
     if (it == utils::IteratorType::reduction) {
-      if (merger.hasNegateOnOut(exp))
+      if (env.merger.hasNegateOnOut(exp))
         return false;
       break;
     }
   }
 
-  OpOperand *lhs = op.getDpsInitOperand(0);
+  OpOperand *lhs = env.linalgOp.getDpsInitOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
   // An non-annotated output tensor is assumed dense, and becomes a random
@@ -606,40 +642,41 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
   // An all-dense annotated "sparse" output tensor becomes a linearized random
   // access 1-dim memref. Also admissible since insertions cannot occur.
   bool allDense = true;
-  unsigned numLoops = merger.getNumLoops(); // numNativeLoops + numFilterLoops
-  for (unsigned i = 0; i < merger.getNumLoops(); i++)
-    if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
-        isSingletonDLT(merger.getDimLevelType(tensor, i))) {
+  unsigned numLoops =
+      env.merger.getNumLoops(); // numNativeLoops + numFilterLoops
+  for (unsigned i = 0; i < env.merger.getNumLoops(); i++)
+    if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
+        isSingletonDLT(env.dimLevelType(tensor, i))) {
       allDense = false;
       break;
     } else {
-      assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
-             isUndefDLT(merger.getDimLevelType(tensor, i)));
+      assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
+             isUndefDLT(env.dimLevelType(tensor, i)));
     }
   if (allDense)
     return true;
 
   // TODO: support compound affine expression on sparse output.
-  if (getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(lhs),
+  if (getNumCompoundAffineOnSparseDims(env.linalgOp.getMatchingIndexingMap(lhs),
                                        lhs->get()) != 0)
     return false;
 
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
-  // [Bik96,Ch9], is also admissible without special codegen.
-  if (merger.isSingleCondition(tensor, exp))
+  // [Bik96,Ch9], is also admissible without special env.
+  if (env.merger.isSingleCondition(tensor, exp))
     return true;
 
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
   if (isMaterializing(lhs->get())) {
-    auto iteratorTypes = op.getIteratorTypesArray();
+    auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
     unsigned nest = 0;
     for (unsigned i = 0; i < numLoops; i++) {
-      if (!merger.isFilterLoop(topSort[i])) {
+      if (!env.isFilterLoop(env.topSort[i])) {
         // We only count non-filter loops as filter loops should be considered
         // as a special type of parallel loops.
-        if (linalg::isReductionIterator(iteratorTypes[topSort[i]]))
+        if (linalg::isReductionIterator(iteratorTypes[env.topSort[i]]))
           break; // terminate at first reduction
         nest++;
       }
@@ -647,9 +684,9 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
     // Determine admissible dynamic insertion situations:
     // (1) fully injective, since there are no reductions,
     // (2) admissible 1-d expansion in innermost dimension.
-    if (nest >= op.getRank(lhs) - 1) {
-      *sparseOut = lhs;
-      outerParNest = nest;
+    if (nest >= env.linalgOp.getRank(lhs) - 1) {
+      env.sparseOut = lhs;
+      env.outerParNest = nest;
       return true;
     }
   }
@@ -688,9 +725,9 @@ static Reduction getReduction(Kind kind) {
 }
 
 /// Updates scalarized reduction value.
-static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
-  assert(codegen.redKind != kNoReduc);
-  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
+static void updateReduc(CodeGenEnv &env, Value reduc) {
+  assert(env.redKind != kNoReduc);
+  env.redVal = env.exp(env.redExp).val = reduc;
 }
 
 /// Extracts identity from custom reduce.
@@ -705,38 +742,38 @@ static Value getCustomRedId(Operation *op) {
 /// Generates loop boundary statements (entering/exiting loops). The function
 /// passes and updates the reduction value.
 static Optional<Operation *> genLoopBoundary(
-    CodeGen &codegen, Merger &merger,
+    CodeGenEnv &env,
     function_ref<Optional<Operation *>(MutableArrayRef<Value> reduc)>
         callback) {
   SmallVector<Value> reduc;
-  if (codegen.redVal)
-    reduc.push_back(codegen.redVal);
-  if (codegen.expValues)
-    reduc.push_back(codegen.expCount);
-  if (codegen.insChain)
-    reduc.push_back(codegen.insChain);
+  if (env.redVal)
+    reduc.push_back(env.redVal);
+  if (env.expValues)
+    reduc.push_back(env.expCount);
+  if (env.insChain)
+    reduc.push_back(env.insChain);
 
   auto r = callback(reduc);
 
   // Callback should do in-place update on reduction value vector.
   unsigned i = 0;
-  if (codegen.redVal)
-    updateReduc(merger, codegen, reduc[i++]);
-  if (codegen.expValues)
-    codegen.expCount = reduc[i++];
-  if (codegen.insChain)
-    codegen.insChain = reduc[i];
+  if (env.redVal)
+    updateReduc(env, reduc[i++]);
+  if (env.expValues)
+    env.expCount = reduc[i++];
+  if (env.insChain)
+    env.insChain = reduc[i];
 
   return r;
 }
 
 /// Local bufferization of all dense and sparse data structures.
-static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                       linalg::GenericOp op) {
+static void genBuffers(CodeGenEnv &env, OpBuilder &builder) {
+  linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
 
-  codegen.loopEmitter.initializeLoopEmit(
+  env.loopEmitter->initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
       /// Note that all sparse kernels assume that when all elements are written
@@ -749,10 +786,9 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
             Value tensor) -> Value {
         // Must not be a sparse tensor.
         assert(!getSparseTensorEncoding(tensor.getType()));
+        // Two output tensor references should point to the same object.
         OpOperand *lhs = op.getDpsInitOperand(0);
-        // Two output tensors references should pointed to the same object.
         assert(lhs->get() == tensor);
-        bool isInit = op.isInitTensor(lhs);
         // An output tensor can simply materialize from the buffer of the tensor
         // that appears in the outs() clause. For updates, this has the
         // advantage that only the nonzero value are involved in the
@@ -761,6 +797,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
         // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
         // O(nnz) for matrices).
         // TODO: use better analysis to avoid zeroing out the buffer?
+        bool isInit = op.isInitTensor(lhs);
         Value init = memref;
         if (!isInit) {
           Value zero = constantZero(builder, loc,
@@ -773,83 +810,82 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
 }
 
 /// Generates index for load/store on sparse tensor.
-static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
-  auto map = op.getMatchingIndexingMap(t);
+static Value genIndex(CodeGenEnv &env, OpOperand *t) {
+  auto map = env.linalgOp.getMatchingIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
   AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
   assert(a.getKind() == AffineExprKind::DimId);
   unsigned idx = a.cast<AffineDimExpr>().getPosition();
-  return codegen.getLoopIdxValue(idx);
+  return env.getLoopIdxValue(idx);
 }
 
 /// Generates subscript for load/store on a dense or sparse tensor.
-static Value genSubscript(CodeGen &codegen, OpBuilder &builder,
-                          linalg::GenericOp op, OpOperand *t,
+static Value genSubscript(CodeGenEnv &env, OpBuilder &builder, OpOperand *t,
                           SmallVectorImpl<Value> &args) {
+  linalg::GenericOp op = env.linalgOp;
   unsigned tensor = t->getOperandNumber();
   auto map = op.getMatchingIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
   unsigned rank = map.getNumResults();
   if (enc) {
-    Value pidx = codegen.loopEmitter.getPidxs()[tensor].back();
+    Value pidx = env.loopEmitter->getPidxs()[tensor].back();
     assert(pidx);
     args.push_back(pidx); // position index
   } else {
     for (unsigned d = 0; d < rank; d++) {
       AffineExpr a = map.getResult(d);
-      args.push_back(codegen.loopEmitter.genAffine(builder, a, op.getLoc()));
+      args.push_back(env.loopEmitter->genAffine(builder, a, op.getLoc()));
     }
   }
-  return codegen.loopEmitter.getValBuffer()[tensor];
+  return env.getValBuffer()[tensor];
 }
 
 /// Generates insertion code to implement dynamic tensor load.
-static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder,
-                              linalg::GenericOp op, OpOperand *t) {
+static Value genInsertionLoad(CodeGenEnv &env, OpBuilder &builder,
+                              OpOperand *t) {
+  linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
   // Direct lexicographic index order, tensor loads as zero.
-  if (!codegen.expValues) {
+  if (!env.expValues) {
     Type tp = getElementTypeOrSelf(t->get().getType());
     return constantZero(builder, loc, tp);
   }
   // Load from expanded access pattern.
-  Value index = genIndex(codegen, op, t);
-  return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
+  Value index = genIndex(env, t);
+  return builder.create<memref::LoadOp>(loc, env.expValues, index);
 }
 
 /// Generates insertion code to implement dynamic tensor load for reduction.
-static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen,
-                                    OpBuilder &builder, linalg::GenericOp op,
+static Value genInsertionLoadReduce(CodeGenEnv &env, OpBuilder &builder,
                                     OpOperand *t) {
+  linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
-  Value identity = getCustomRedId(merger.exp(codegen.redCustom).op);
+  Value identity = getCustomRedId(env.exp(env.redCustom).op);
   // Direct lexicographic index order, tensor loads as identity.
-  if (!codegen.expValues) {
+  if (!env.expValues) {
     return identity;
   }
   // Load from expanded access pattern if filled, identity otherwise.
-  Value index = genIndex(codegen, op, t);
-  Value isFilled =
-      builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
-  Value valAtIndex =
-      builder.create<memref::LoadOp>(loc, codegen.expValues, index);
+  Value index = genIndex(env, t);
+  Value isFilled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
+  Value valAtIndex = builder.create<memref::LoadOp>(loc, env.expValues, index);
   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
 }
 
 /// Generates insertion code to implement dynamic tensor store.
-static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
-                              linalg::GenericOp op, OpOperand *t, Value rhs) {
+static void genInsertionStore(CodeGenEnv &env, OpBuilder &builder, OpOperand *t,
+                              Value rhs) {
+  linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
   // Direct insertion in lexicographic index order.
-  if (!codegen.expValues) {
+  if (!env.expValues) {
     unsigned rank = op.getRank(t);
     SmallVector<Value> indices;
     for (unsigned i = 0; i < rank; i++) {
-      assert(codegen.loopEmitter.getLoopIV(i));
-      indices.push_back(codegen.loopEmitter.getLoopIV(i));
+      assert(env.getLoopIV(i));
+      indices.push_back(env.getLoopIV(i));
     }
-    codegen.insChain =
-        builder.create<InsertOp>(loc, rhs, codegen.insChain, indices);
+    env.insChain = builder.create<InsertOp>(loc, rhs, env.insChain, indices);
     return;
   }
   // Generates insertion code along expanded access pattern.
@@ -858,110 +894,108 @@ static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
   //     expAdded[inserts++] = i
   //   endif
   //   values[i] = rhs
-  Value index = genIndex(codegen, op, t);
+  Value index = genIndex(env, t);
   Value fval = constantI1(builder, loc, false);
   Value tval = constantI1(builder, loc, true);
   // If statement.
-  Value filled = builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
+  Value filled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              filled, fval);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                              /*else=*/true);
   // True branch.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  builder.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
-  builder.create<memref::StoreOp>(loc, index, codegen.expAdded,
-                                  codegen.expCount);
+  builder.create<memref::StoreOp>(loc, tval, env.expFilled, index);
+  builder.create<memref::StoreOp>(loc, index, env.expAdded, env.expCount);
   Value one = constantIndex(builder, loc, 1);
-  Value add = builder.create<arith::AddIOp>(loc, codegen.expCount, one);
+  Value add = builder.create<arith::AddIOp>(loc, env.expCount, one);
   builder.create<scf::YieldOp>(loc, add);
   // False branch.
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  builder.create<scf::YieldOp>(loc, codegen.expCount);
+  builder.create<scf::YieldOp>(loc, env.expCount);
   builder.setInsertionPointAfter(ifOp);
   // Value assignment.
-  codegen.expCount = ifOp.getResult(0);
-  builder.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
+  env.expCount = ifOp.getResult(0);
+  builder.create<memref::StoreOp>(loc, rhs, env.expValues, index);
 }
 
 /// Generates a load on a dense or sparse tensor.
-static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                           linalg::GenericOp op, unsigned exp) {
+static Value genTensorLoad(CodeGenEnv &env, OpBuilder &builder, unsigned exp) {
   // Test if the load was hoisted to a higher loop nest.
-  Value val = merger.exp(exp).val;
+  Value val = env.exp(exp).val;
   if (val)
     return val;
 
   // Load during insertion.
-  OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
-  if (&t == codegen.sparseOut) {
-    if (codegen.redCustom != -1u)
-      return genInsertionLoadReduce(merger, codegen, builder, op, &t);
-    return genInsertionLoad(codegen, builder, op, &t);
+  linalg::GenericOp op = env.linalgOp;
+  OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
+  if (&t == env.sparseOut) {
+    if (env.redCustom != -1u)
+      return genInsertionLoadReduce(env, builder, &t);
+    return genInsertionLoad(env, builder, &t);
   }
   // Actual load.
   SmallVector<Value> args;
-  Value ptr = genSubscript(codegen, builder, op, &t, args);
+  Value ptr = genSubscript(env, builder, &t, args);
   return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
 }
 
 /// Generates a store on a dense or sparse tensor.
-static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                           linalg::GenericOp op, unsigned exp, Value rhs) {
+static void genTensorStore(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+                           Value rhs) {
+  linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
   // Test if this is a scalarized reduction.
-  if (codegen.redVal) {
-    updateReduc(merger, codegen, rhs);
+  if (env.redVal) {
+    updateReduc(env, rhs);
     return;
   }
   // Store during insertion.
   OpOperand *t = op.getDpsInitOperand(0);
-  if (t == codegen.sparseOut) {
+  if (t == env.sparseOut) {
     if (!rhs) {
       // Only unary and binary are allowed to return uninitialized rhs
       // to indicate missing output.
-      assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
-    } else if (merger.exp(exp).kind == kSelect) {
+      assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary);
+    } else if (env.exp(exp).kind == kSelect) {
       // Select operation insertion.
-      Value insChain = codegen.insChain;
+      Value insChain = env.insChain;
       assert(insChain);
       scf::IfOp ifOp = builder.create<scf::IfOp>(loc, insChain.getType(), rhs,
                                                  /*else=*/true);
       builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
       // Existing value was preserved to be used here.
-      assert(merger.exp(exp).val);
-      Value v0 = merger.exp(exp).val;
-      genInsertionStore(codegen, builder, op, t, v0);
-      merger.exp(exp).val = Value();
+      assert(env.exp(exp).val);
+      Value v0 = env.exp(exp).val;
+      genInsertionStore(env, builder, t, v0);
+      env.exp(exp).val = Value();
       // Yield modified insertion chain along true branch.
-      builder.create<scf::YieldOp>(op.getLoc(), codegen.insChain);
+      builder.create<scf::YieldOp>(op.getLoc(), env.insChain);
       // Yield original insertion chain along false branch.
       builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
       builder.create<scf::YieldOp>(loc, insChain);
       // Done with if statement.
-      codegen.insChain = ifOp->getResult(0);
+      env.insChain = ifOp->getResult(0);
       builder.setInsertionPointAfter(ifOp);
     } else {
-      genInsertionStore(codegen, builder, op, t, rhs);
+      genInsertionStore(env, builder, t, rhs);
     }
     return;
   }
   // Actual store.
   SmallVector<Value> args;
-  Value ptr = genSubscript(codegen, builder, op, t, args);
+  Value ptr = genSubscript(env, builder, t, args);
   builder.create<memref::StoreOp>(loc, rhs, ptr, args);
 }
 
 /// Generates an invariant value.
-inline static Value genInvariantValue(Merger &merger, CodeGen &codegen,
-                                      OpBuilder &builder, unsigned exp) {
-  return merger.exp(exp).val;
+inline static Value genInvariantValue(CodeGenEnv &env, unsigned exp) {
+  return env.exp(exp).val;
 }
 
 /// Generates an index value.
-inline static Value genIndexValue(CodeGen &codegen, OpBuilder &builder,
-                                  unsigned idx) {
-  return codegen.getLoopIdxValue(idx);
+inline static Value genIndexValue(CodeGenEnv &env, unsigned idx) {
+  return env.getLoopIdxValue(idx);
 }
 
 /// Semi-ring branches are simply inlined by the sparse compiler. Prior
@@ -969,86 +1003,84 @@ inline static Value genIndexValue(CodeGen &codegen, OpBuilder &builder,
 /// branch or otherwise invariantly defined outside the loop nest, with the
 /// exception of index computations, which need to be relinked to actual
 /// inlined cloned code.
-static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter,
-                          Block *block, Value e, unsigned ldx) {
+static Value relinkBranch(CodeGenEnv &env, RewriterBase &rewriter, Block *block,
+                          Value e, unsigned ldx) {
   if (Operation *def = e.getDefiningOp()) {
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
-      return genIndexValue(codegen, rewriter, indexOp.getDim());
+      return genIndexValue(env, indexOp.getDim());
     if (def->getBlock() == block) {
       for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
         def->setOperand(
-            i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx));
+            i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
     }
   }
   return e;
 }
 
 /// Recursively generates tensor expression.
-static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
-                    linalg::GenericOp op, unsigned exp, unsigned ldx) {
+static Value genExp(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
+                    unsigned ldx) {
+  linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
+
   if (exp == -1u)
     return Value();
-  if (merger.exp(exp).kind == Kind::kTensor)
-    return genTensorLoad(merger, codegen, rewriter, op, exp);
-  if (merger.exp(exp).kind == Kind::kInvariant)
-    return genInvariantValue(merger, codegen, rewriter, exp);
-  if (merger.exp(exp).kind == Kind::kIndex)
-    return genIndexValue(codegen, rewriter, merger.exp(exp).index);
-
-  if (merger.exp(exp).kind == Kind::kReduce) {
-    // Make custom reduction identity accessible for expanded access pattern.
-    assert(codegen.redCustom == -1u);
-    codegen.redCustom = exp;
+  if (env.exp(exp).kind == Kind::kTensor)
+    return genTensorLoad(env, rewriter, exp);
+  if (env.exp(exp).kind == Kind::kInvariant)
+    return genInvariantValue(env, exp);
+  if (env.exp(exp).kind == Kind::kIndex)
+    return genIndexValue(env, env.exp(exp).index);
+
+  // Make custom reduction identity accessible for expanded access pattern.
+  if (env.exp(exp).kind == Kind::kReduce) {
+    assert(env.redCustom == -1u);
+    env.redCustom = exp;
   }
 
-  Value v0 =
-      genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
-  Value v1 =
-      genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx);
-  Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
-  if (ee && (merger.exp(exp).kind == Kind::kUnary ||
-             merger.exp(exp).kind == Kind::kBinary ||
-             merger.exp(exp).kind == Kind::kBinaryBranch ||
-             merger.exp(exp).kind == Kind::kReduce ||
-             merger.exp(exp).kind == Kind::kSelect))
-    ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
-
-  if (merger.exp(exp).kind == kSelect) {
-    assert(!merger.exp(exp).val);
-    merger.exp(exp).val = v0; // Preserve value for later use.
-  }
-
-  if (merger.exp(exp).kind == Kind::kReduce) {
-    assert(codegen.redCustom != -1u);
-    codegen.redCustom = -1u;
+  Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx);
+  Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx);
+  Value ee = env.merger.buildExp(rewriter, loc, exp, v0, v1);
+  if (ee && (env.exp(exp).kind == Kind::kUnary ||
+             env.exp(exp).kind == Kind::kBinary ||
+             env.exp(exp).kind == Kind::kBinaryBranch ||
+             env.exp(exp).kind == Kind::kReduce ||
+             env.exp(exp).kind == Kind::kSelect))
+    ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+
+  if (env.exp(exp).kind == kSelect) {
+    assert(!env.exp(exp).val);
+    env.exp(exp).val = v0; // Preserve value for later use.
+  } else if (env.exp(exp).kind == Kind::kReduce) {
+    assert(env.redCustom != -1u);
+    env.redCustom = -1u;
   }
 
   return ee;
 }
 
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
-static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                          linalg::GenericOp op, unsigned exp, unsigned ldx,
-                          bool atStart, unsigned last = -1u) {
+static void genInvariants(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+                          unsigned ldx, bool atStart, unsigned last = -1u) {
   if (exp == -1u)
     return;
-  if (merger.exp(exp).kind == Kind::kTensor) {
+  if (env.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
-    OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
+    linalg::GenericOp op = env.linalgOp;
+    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     auto map = op.getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      Optional<unsigned> sldx = merger.getLoopIdx(t.getOperandNumber(), d);
-      if (sldx && merger.isFilterLoop(*sldx)) {
-        if (!codegen.getLoopIdxValue(*sldx))
+      Optional<unsigned> sldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
+      if (sldx && env.isFilterLoop(*sldx)) {
+        if (!env.getLoopIdxValue(*sldx))
           // The filter loops has not been constructed.
           return;
         if (*sldx == ldx)
           atLevel = true;
-      } else if (!isInvariantAffine(codegen, a, ldx, atLevel))
+      } else if (!isInvariantAffine(env, a, ldx, atLevel))
         return; // still in play
     }
     // All exhausted at this level (atLevel denotes exactly at this level).
@@ -1058,45 +1090,43 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     if (lhs == &t) {
       // Start or end a scalarized reduction
       if (atStart) {
-        Kind kind = merger.exp(last).kind;
-        Value load = kind == Kind::kReduce
-                         ? getCustomRedId(merger.exp(last).op)
-                         : genTensorLoad(merger, codegen, builder, op, exp);
-        codegen.redKind = getReduction(kind);
-        codegen.redExp = exp;
-        updateReduc(merger, codegen, load);
+        Kind kind = env.exp(last).kind;
+        Value load = kind == Kind::kReduce ? getCustomRedId(env.exp(last).op)
+                                           : genTensorLoad(env, builder, exp);
+        env.redKind = getReduction(kind);
+        env.redExp = exp;
+        updateReduc(env, load);
       } else {
-        Value redVal = codegen.redVal;
-        updateReduc(merger, codegen, Value());
-        codegen.redExp = -1u;
-        codegen.redKind = kNoReduc;
-        genTensorStore(merger, codegen, builder, op, exp, redVal);
+        Value redVal = env.redVal;
+        updateReduc(env, Value());
+        env.redExp = -1u;
+        env.redKind = kNoReduc;
+        genTensorStore(env, builder, exp, redVal);
       }
     } else {
       // Start or end loop invariant hoisting of a tensor load.
-      merger.exp(exp).val =
-          atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value();
+      env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
     }
-  } else if (merger.exp(exp).kind != Kind::kInvariant &&
-             merger.exp(exp).kind != Kind::kIndex) {
+  } else if (env.exp(exp).kind != Kind::kInvariant &&
+             env.exp(exp).kind != Kind::kIndex) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
-    unsigned e0 = merger.exp(exp).children.e0;
-    unsigned e1 = merger.exp(exp).children.e1;
-    genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp);
-    genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp);
+    unsigned e0 = env.exp(exp).children.e0;
+    unsigned e1 = env.exp(exp).children.e1;
+    genInvariants(env, builder, e0, ldx, atStart, exp);
+    genInvariants(env, builder, e1, ldx, atStart, exp);
   }
 }
 
 /// Generates an expanded access pattern in innermost dimension.
-static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                         linalg::GenericOp op, unsigned at, bool atStart) {
-  OpOperand *lhs = codegen.sparseOut;
-  if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
-      at != codegen.outerParNest)
+static void genExpansion(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+                         bool atStart) {
+  linalg::GenericOp op = env.linalgOp;
+  OpOperand *lhs = env.sparseOut;
+  if (!lhs || env.outerParNest != op.getRank(lhs) - 1 || at != env.outerParNest)
     return; // not needed at this level
-  assert(codegen.redVal == nullptr);
+  assert(env.redVal == nullptr);
   // Generate start or end of an expanded access pattern. Note that because
   // an expension does not rely on the ongoing contents of the sparse storage
   // scheme, we can use the original tensor as incoming SSA value (which
@@ -1114,38 +1144,37 @@ static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     auto res =
         builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
     assert(res.getNumResults() == 4);
-    assert(!codegen.expValues);
-    codegen.expValues = res.getResult(0);
-    codegen.expFilled = res.getResult(1);
-    codegen.expAdded = res.getResult(2);
-    codegen.expCount = res.getResult(3);
+    assert(!env.expValues);
+    env.expValues = res.getResult(0);
+    env.expFilled = res.getResult(1);
+    env.expAdded = res.getResult(2);
+    env.expCount = res.getResult(3);
   } else {
-    assert(codegen.expValues);
+    assert(env.expValues);
     SmallVector<Value> indices;
     for (unsigned i = 0; i < at; i++) {
-      assert(codegen.loopEmitter.getLoopIV(i));
-      indices.push_back(codegen.loopEmitter.getLoopIV(i));
+      assert(env.getLoopIV(i));
+      indices.push_back(env.getLoopIV(i));
     }
-    codegen.insChain = builder.create<CompressOp>(
-        loc, codegen.expValues, codegen.expFilled, codegen.expAdded,
-        codegen.expCount, codegen.insChain, indices);
-    codegen.expValues = codegen.expFilled = codegen.expAdded =
-        codegen.expCount = Value();
+    env.insChain = builder.create<CompressOp>(loc, env.expValues, env.expFilled,
+                                              env.expAdded, env.expCount,
+                                              env.insChain, indices);
+    env.expValues = env.expFilled = env.expAdded = env.expCount = Value();
   }
 }
 
 /// Returns parallelization strategy. Any implicit loop in the Linalg
 /// operation that is marked "parallel" is a candidate. Whether it is actually
 /// converted to a parallel operation depends on the requested strategy.
-static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isSparse) {
+static bool isParallelFor(CodeGenEnv &env, bool isOuter, bool isSparse) {
   // Reject parallelization of sparse output.
-  if (codegen.sparseOut)
+  if (env.sparseOut)
     return false;
   // Parallel loops on tensor expansion can cause data races.
-  if (codegen.expCount)
+  if (env.expCount)
     return false;
   // Inspect strategy.
-  switch (codegen.options.parallelizationStrategy) {
+  switch (env.options.parallelizationStrategy) {
   case SparseParallelizationStrategy::kNone:
     return false;
   case SparseParallelizationStrategy::kDenseOuterLoop:
@@ -1161,98 +1190,94 @@ static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isSparse) {
 }
 
 /// Generates a for-loop on a single index.
-static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                         linalg::GenericOp op, bool isOuter, bool isInner,
-                         unsigned idx, size_t tid, size_t dim,
+static Operation *genFor(CodeGenEnv &env, OpBuilder &builder, bool isOuter,
+                         bool isInner, unsigned idx, size_t tid, size_t dim,
                          ArrayRef<size_t> extraTids,
                          ArrayRef<size_t> extraDims) {
+  linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
-  bool isSparse = isCompressedDLT(merger.getDimLevelType(tid, idx)) ||
-                  isSingletonDLT(merger.getDimLevelType(tid, idx));
-  bool isParallel = isParallelFor(codegen, isOuter, isSparse);
-
-  Operation *loop =
-      *genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
-        if (merger.isFilterLoop(idx)) {
-          // extraTids/extraDims must be empty because filter loops only
-          // corresponding to the one and only sparse tensor level.
-          assert(isSparse && extraTids.empty() && extraDims.empty());
-          OpOperand *t = &op->getOpOperand(tid);
-          auto enc = getSparseTensorEncoding(t->get().getType());
-          // Retrieves the affine expression for the filter loop.
-          AffineExpr a =
-              op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
-          return codegen.loopEmitter.enterFilterLoopOverTensorAtDim(
-              builder, loc, tid, dim, a, reduc);
-        }
-        return codegen.loopEmitter.enterLoopOverTensorAtDim(
-            builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
-      });
+  auto iteratorTypes = op.getIteratorTypesArray();
+  bool isSparse = isCompressedDLT(env.dimLevelType(tid, idx)) ||
+                  isSingletonDLT(env.dimLevelType(tid, idx));
+  bool isParallel = isParallelFor(env, isOuter, isSparse);
+
+  Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+    if (env.isFilterLoop(idx)) {
+      // extraTids/extraDims must be empty because filter loops only
+      // corresponding to the one and only sparse tensor level.
+      assert(isSparse && extraTids.empty() && extraDims.empty());
+      OpOperand *t = &op->getOpOperand(tid);
+      auto enc = getSparseTensorEncoding(t->get().getType());
+      // Retrieves the affine expression for the filter loop.
+      AffineExpr a =
+          op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
+      return env.loopEmitter->enterFilterLoopOverTensorAtDim(builder, loc, tid,
+                                                             dim, a, reduc);
+    }
+    return env.loopEmitter->enterLoopOverTensorAtDim(
+        builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
+  });
   assert(loop);
   return loop;
 }
 
 /// Emit a while-loop for co-iteration over multiple indices.
-static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                           linalg::GenericOp op, unsigned idx, bool needsUniv,
-                           ArrayRef<size_t> condTids, ArrayRef<size_t> condDims,
+static Operation *genWhile(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
+                           bool needsUniv, ArrayRef<size_t> condTids,
+                           ArrayRef<size_t> condDims,
                            ArrayRef<size_t> extraTids,
                            ArrayRef<size_t> extraDims) {
-
-  Operation *loop =
-      *genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
-        // Construct the while-loop with a parameter for each index.
-        return codegen.loopEmitter.enterCoIterationOverTensorsAtDims(
-            builder, op.getLoc(), condTids, condDims, needsUniv, reduc,
-            extraTids, extraDims);
-      });
+  Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+    // Construct the while-loop with a parameter for each index.
+    return env.loopEmitter->enterCoIterationOverTensorsAtDims(
+        builder, env.linalgOp.getLoc(), condTids, condDims, needsUniv, reduc,
+        extraTids, extraDims);
+  });
   assert(loop);
   return loop;
 }
 
 /// Generates a for-loop or a while-loop, depending on whether it implements
 /// singleton iteration or co-iteration over the given conjunction.
-static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                          linalg::GenericOp op, unsigned at, bool needsUniv,
-                          ArrayRef<size_t> condTids, ArrayRef<size_t> condDims,
-                          ArrayRef<size_t> extraTids,
+static Operation *genLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+                          bool needsUniv, ArrayRef<size_t> condTids,
+                          ArrayRef<size_t> condDims, ArrayRef<size_t> extraTids,
                           ArrayRef<size_t> extraDims) {
   assert(condTids.size() == condDims.size());
   assert(extraTids.size() == extraDims.size());
-  unsigned idx = codegen.topSort[at];
+  unsigned idx = env.topSort[at];
   if (condTids.size() == 1) {
     bool isOuter = at == 0;
-    bool isInner = at == codegen.topSort.size() - 1;
-    return genFor(merger, codegen, builder, op, isOuter, isInner, idx,
-                  condTids.front(), condDims.front(), extraTids, extraDims);
+    bool isInner = at == env.topSort.size() - 1;
+    return genFor(env, builder, isOuter, isInner, idx, condTids.front(),
+                  condDims.front(), extraTids, extraDims);
   }
-  return genWhile(merger, codegen, builder, op, idx, needsUniv, condTids,
-                  condDims, extraTids, extraDims);
+  return genWhile(env, builder, idx, needsUniv, condTids, condDims, extraTids,
+                  extraDims);
 }
 
 /// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(Merger &merger, CodeGen &codegen,
-                            OpBuilder &builder, linalg::GenericOp op,
-                            unsigned idx, bool needsUniv, BitVector &induction,
+static void finalizeWhileOp(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
+                            bool needsUniv, BitVector &induction,
                             scf::WhileOp whileOp) {
-  Location loc = op.getLoc();
+  Location loc = env.linalgOp.getLoc();
   // Finalize each else branch of all if statements.
-  if (codegen.redVal || codegen.expValues || codegen.insChain) {
+  if (env.redVal || env.expValues || env.insChain) {
     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
                builder.getInsertionBlock()->getParentOp())) {
       unsigned y = 0;
       SmallVector<Value> yields;
-      if (codegen.redVal) {
-        yields.push_back(codegen.redVal);
-        updateReduc(merger, codegen, ifOp.getResult(y++));
+      if (env.redVal) {
+        yields.push_back(env.redVal);
+        updateReduc(env, ifOp.getResult(y++));
       }
-      if (codegen.expValues) {
-        yields.push_back(codegen.expCount);
-        codegen.expCount = ifOp->getResult(y++);
+      if (env.expValues) {
+        yields.push_back(env.expCount);
+        env.expCount = ifOp->getResult(y++);
       }
-      if (codegen.insChain) {
-        yields.push_back(codegen.insChain);
-        codegen.insChain = ifOp->getResult(y++);
+      if (env.insChain) {
+        yields.push_back(env.insChain);
+        env.insChain = ifOp->getResult(y++);
       }
       assert(y == yields.size());
       builder.create<scf::YieldOp>(loc, yields);
@@ -1263,62 +1288,61 @@ static void finalizeWhileOp(Merger &merger, CodeGen &codegen,
 }
 
 /// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                       linalg::GenericOp op, unsigned idx,
+static scf::IfOp genIf(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
                        BitVector &conditions) {
-  Location loc = op.getLoc();
+  Location loc = env.linalgOp.getLoc();
   SmallVector<Type> types;
   Value cond;
   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
     if (!conditions[b])
       continue;
-    unsigned tensor = merger.tensor(b);
-    assert(idx == merger.index(b));
+    unsigned tensor = env.merger.tensor(b);
+    assert(idx == env.merger.index(b));
     Value clause;
-    if (isCompressedDLT(merger.getDimLevelType(b)) ||
-        isSingletonDLT(merger.getDimLevelType(b))) {
-      auto dim = *merger.getDimNum(tensor, idx);
-      Value op1 = codegen.loopEmitter.getCoord()[tensor][dim];
-      Value op2 = codegen.getLoopIdxValue(idx);
+    if (isCompressedDLT(env.dimLevelType(b)) ||
+        isSingletonDLT(env.dimLevelType(b))) {
+      auto dim = *env.merger.getDimNum(tensor, idx);
+      Value op1 = env.loopEmitter->getCoord()[tensor][dim];
+      Value op2 = env.getLoopIdxValue(idx);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
                                              op2);
     } else {
-      assert(isDenseDLT(merger.getDimLevelType(b)) ||
-             isUndefDLT(merger.getDimLevelType(b)));
+      assert(isDenseDLT(env.dimLevelType(b)) ||
+             isUndefDLT(env.dimLevelType(b)));
       clause = constantI1(builder, loc, true);
     }
     cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
   }
-  if (codegen.redVal)
-    types.push_back(codegen.redVal.getType());
-  if (codegen.expValues)
+  if (env.redVal)
+    types.push_back(env.redVal.getType());
+  if (env.expValues)
     types.push_back(builder.getIndexType());
-  if (codegen.insChain)
-    types.push_back(codegen.insChain.getType());
+  if (env.insChain)
+    types.push_back(env.insChain.getType());
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   return ifOp;
 }
 
 /// Generates end of true branch of if-statement within a while-loop.
-static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
-                  Value redInput, Value cntInput, Value insInput) {
+static void endIf(CodeGenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
+                  Operation *loop, Value redInput, Value cntInput,
+                  Value insInput) {
   SmallVector<Value> operands;
-  if (codegen.redVal) {
-    operands.push_back(codegen.redVal);
-    updateReduc(merger, codegen, redInput);
+  if (env.redVal) {
+    operands.push_back(env.redVal);
+    updateReduc(env, redInput);
   }
-  if (codegen.expValues) {
-    operands.push_back(codegen.expCount);
-    codegen.expCount = cntInput;
+  if (env.expValues) {
+    operands.push_back(env.expCount);
+    env.expCount = cntInput;
   }
-  if (codegen.insChain) {
-    operands.push_back(codegen.insChain);
-    codegen.insChain = insInput;
+  if (env.insChain) {
+    operands.push_back(env.insChain);
+    env.insChain = insInput;
   }
   if (!operands.empty())
-    builder.create<scf::YieldOp>(op.getLoc(), operands);
+    builder.create<scf::YieldOp>(env.linalgOp.getLoc(), operands);
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
 }
 
@@ -1328,24 +1352,24 @@ static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
 
 /// Starts a loop sequence at given level. Returns true if
 /// the universal loop index must be maintained at this level.
-static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                         linalg::GenericOp op, unsigned exp, unsigned at,
-                         unsigned idx, unsigned ldx, unsigned lts) {
-  assert(!codegen.getLoopIdxValue(idx));
+static bool startLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+                         unsigned at, unsigned idx, unsigned ldx,
+                         unsigned lts) {
+  assert(!env.getLoopIdxValue(idx));
   // Emit invariants at this loop sequence level.
-  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true);
+  genInvariants(env, builder, exp, ldx, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpansion(merger, codegen, builder, op, at, /*atStart=*/true);
+  genExpansion(env, builder, at, /*atStart=*/true);
   // Emit further intitialization at this loop sequence level.
-  unsigned l0 = merger.set(lts)[0];
+  unsigned l0 = env.set(lts)[0];
   bool needsUniv = false;
 
   SmallVector<size_t> tids;
   SmallVector<size_t> dims;
-  merger.foreachTidDimPairInBits(
-      merger.lat(l0).bits,
+  env.merger.foreachTidDimPairInBits(
+      env.lat(l0).bits,
       [&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
-        assert(merger.index(b) == idx);
+        assert(env.merger.index(b) == idx);
         if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
           needsUniv = true;
         } else {
@@ -1355,28 +1379,27 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
         }
       });
 
-  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims);
+  env.loopEmitter->enterNewLoopSeq(builder, env.linalgOp.getLoc(), tids, dims);
 
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
   if (needsUniv) {
-    unsigned lsize = merger.set(lts).size();
+    unsigned lsize = env.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
-      unsigned li = merger.set(lts)[i];
-      if (!merger.hasAnySparse(merger.lat(li).simple))
+      unsigned li = env.set(lts)[i];
+      if (!env.merger.hasAnySparse(env.lat(li).simple))
         return true;
     }
   }
   return false;
 }
 
-static void genConstantDenseAddressFromLevel(CodeGen &codegen,
-                                             OpBuilder &builder,
-                                             linalg::GenericOp op, unsigned tid,
+static void genConstantDenseAddressFromLevel(CodeGenEnv &env,
+                                             OpBuilder &builder, unsigned tid,
                                              unsigned lvl) {
   // TODO: Handle affine expression on output tensor.
+  linalg::GenericOp op = env.linalgOp;
   assert(tid < op.getNumDpsInputs());
-
   OpOperand *input = op.getDpsInputOperands()[tid];
   ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
   auto enc = getSparseTensorEncoding(input->get().getType());
@@ -1384,42 +1407,38 @@ static void genConstantDenseAddressFromLevel(CodeGen &codegen,
     for (unsigned i = lvl, e = affines.size(); i < e; i++) {
       AffineExpr affine = affines[toOrigDim(enc, i)];
       if (isDenseDLT(getDimLevelType(enc, i)) &&
-          affine.isa<AffineConstantExpr>()) {
-        codegen.loopEmitter.genDenseAffineAddressAtCurLevel(
+          affine.isa<AffineConstantExpr>())
+        env.loopEmitter->genDenseAffineAddressAtCurLevel(
             builder, op.getLoc(), input->getOperandNumber(), i, affine);
-      } else {
-        // Breaks on first non-dense non-constant level.
-        return;
-      }
+      else
+        return; // break on first non-dense non-constant level
     }
   }
 }
 
-static void genInitConstantDenseAddress(CodeGen &codegen,
-                                        RewriterBase &rewriter,
-                                        linalg::GenericOp op) {
-  // We can generates address for constant affine expression before any loops
+static void genInitConstantDenseAddress(CodeGenEnv &env,
+                                        RewriterBase &rewriter) {
+  // We can generate address for constant affine expression before any loops
   // starting from the first level as they do not depend on any thing.
   // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
   // levels can be determined before loops.
-  for (unsigned tid = 0, e = op.getNumDpsInputs(); tid < e; tid++)
-    genConstantDenseAddressFromLevel(codegen, rewriter, op, tid, 0);
+  for (unsigned tid = 0, e = env.linalgOp.getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }
 
 static void translateBitsToTidDimPairs(
-    Merger &merger, CodeGen &codegen, linalg::GenericOp op, unsigned li,
-    unsigned idx, SmallVectorImpl<size_t> &condTids,
-    SmallVectorImpl<size_t> &condDims, SmallVectorImpl<size_t> &extraTids,
-    SmallVectorImpl<size_t> &extraDims, SmallVectorImpl<size_t> &affineTids,
-    SmallVectorImpl<size_t> &affineDims, SmallVectorImpl<AffineExpr> &exps) {
-
-  const BitVector &all = merger.lat(li).bits;
-  const BitVector &simple = merger.lat(li).simple;
+    CodeGenEnv &env, unsigned li, unsigned idx,
+    SmallVectorImpl<size_t> &condTids, SmallVectorImpl<size_t> &condDims,
+    SmallVectorImpl<size_t> &extraTids, SmallVectorImpl<size_t> &extraDims,
+    SmallVectorImpl<size_t> &affineTids, SmallVectorImpl<size_t> &affineDims,
+    SmallVectorImpl<AffineExpr> &exps) {
+  const BitVector &all = env.lat(li).bits;
+  const BitVector &simple = env.lat(li).simple;
 
   // Converts bits to array + dim pair
-  merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
-                                               Optional<unsigned> dim,
-                                               DimLevelType dlt) {
+  env.merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+                                                   Optional<unsigned> dim,
+                                                   DimLevelType dlt) {
     if (simple.test(b)) {
       if (isUndefDLT(dlt)) {
         // An undefined dlt in the lattices, we probably mean to iterate based
@@ -1428,8 +1447,8 @@ static void translateBitsToTidDimPairs(
         // output tensor).
         // out[i][j] = invariant; or a broadcast
         // out[i][j] = in[i] (j is undef for input)
-        tid = merger.getOutTensorID();
-        dim = merger.getDimNum(tid, idx);
+        tid = env.merger.getOutTensorID();
+        dim = env.merger.getDimNum(tid, idx);
         // Skips invalid dim (e.g., when this is a zero ranked tensor).
         if (!dim)
           return;
@@ -1442,6 +1461,7 @@ static void translateBitsToTidDimPairs(
       extraDims.push_back(*dim);
     } else {
       assert(isUndefDLT(dlt));
+      linalg::GenericOp op = env.linalgOp;
       if (tid >= op.getNumDpsInputs())
         // We only handle affine expression on input tensors (for now).
         return;
@@ -1464,7 +1484,7 @@ static void translateBitsToTidDimPairs(
         // Constant affine expression are handled in genLoop
         if (!exp.isa<AffineConstantExpr>()) {
           bool atLevel = false;
-          if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) {
+          if (isInvariantAffine(env, exp, idx, atLevel) && atLevel) {
             // If the compound affine is invariant and we are right at the
             // level. We need to generate the address according to the affine
             // expression. This is also the best place we can do it to avoid
@@ -1482,20 +1502,19 @@ static void translateBitsToTidDimPairs(
     }
   });
 
-  if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) {
+  if (isDenseDLT(env.dimLevelType(env.merger.getOutTensorID(), idx))) {
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized codegen.
-    auto dim = *merger.getDimNum(merger.getOutTensorID(), idx);
-    extraTids.push_back(merger.getOutTensorID());
+    auto dim = *env.merger.getDimNum(env.merger.getOutTensorID(), idx);
+    extraTids.push_back(env.merger.getOutTensorID());
     extraDims.push_back(dim);
   }
 }
 
 /// Starts a single loop in current sequence.
-static Operation *startLoop(Merger &merger, CodeGen &codegen,
-                            OpBuilder &builder, linalg::GenericOp op,
-                            unsigned at, unsigned li, bool needsUniv) {
+static Operation *startLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+                            unsigned li, bool needsUniv) {
   // The set of tensors + dims to generate loops on
   SmallVector<size_t> condTids, condDims;
   // The set of (dense) tensors that is optimized from condition, yet still
@@ -1506,17 +1525,16 @@ static Operation *startLoop(Merger &merger, CodeGen &codegen,
   // level.
   SmallVector<size_t> affineTids, affineDims;
   SmallVector<AffineExpr> affines;
+  translateBitsToTidDimPairs(env, li, env.topSort[at], condTids, condDims,
+                             extraTids, extraDims, affineTids, affineDims,
+                             affines);
 
-  translateBitsToTidDimPairs(merger, codegen, op, li, codegen.topSort[at],
-                             condTids, condDims, extraTids, extraDims,
-                             affineTids, affineDims, affines);
   // Emit the for/while-loop control.
-  Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv,
-                            condTids, condDims, extraTids, extraDims);
-
+  Operation *loop = genLoop(env, builder, at, needsUniv, condTids, condDims,
+                            extraTids, extraDims);
   for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
-    codegen.loopEmitter.genDenseAffineAddressAtCurLevel(builder, op.getLoc(),
-                                                        tid, dim, exp);
+    env.loopEmitter->genDenseAffineAddressAtCurLevel(
+        builder, env.linalgOp.getLoc(), tid, dim, exp);
   }
 
   // Until now, we have entered every <tid, dim> pair in {cond, extra,
@@ -1525,27 +1543,25 @@ static Operation *startLoop(Merger &merger, CodeGen &codegen,
   auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
   auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
   for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
-    if (tid != merger.getOutTensorID())
-      genConstantDenseAddressFromLevel(codegen, builder, op, tid, dim + 1);
+    if (tid != env.merger.getOutTensorID())
+      genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);
   }
 
   return loop;
 }
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
-static bool endLoop(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
-                    linalg::GenericOp op, Operation *loop, unsigned idx,
-                    unsigned li, bool needsUniv) {
+static bool endLoop(CodeGenEnv &env, RewriterBase &rewriter, Operation *loop,
+                    unsigned idx, unsigned li, bool needsUniv) {
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
-    finalizeWhileOp(merger, codegen, rewriter, op, idx, needsUniv,
-                    merger.lat(li).bits, whileOp);
+    finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
   } else {
     needsUniv = false;
   }
 
-  genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
-    codegen.loopEmitter.exitCurrentLoop(rewriter, op.getLoc(), reduc);
+  genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+    env.loopEmitter->exitCurrentLoop(rewriter, env.linalgOp.getLoc(), reduc);
     return std::nullopt;
   });
 
@@ -1553,85 +1569,79 @@ static bool endLoop(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
 }
 
 /// Ends a loop sequence at given level.
-static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                       linalg::GenericOp op, unsigned exp, unsigned at,
-                       unsigned idx, unsigned ldx) {
-  assert(codegen.getLoopIdxValue(idx) == nullptr);
-  codegen.loopEmitter.exitCurrentLoopSeq();
+static void endLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+                       unsigned at, unsigned idx, unsigned ldx) {
+  assert(env.getLoopIdxValue(idx) == nullptr);
+  env.loopEmitter->exitCurrentLoopSeq();
   // Unmark bookkeeping of invariants and loop index.
-  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false);
+  genInvariants(env, builder, exp, ldx, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpansion(merger, codegen, builder, op, at, /*atStart=*/false);
+  genExpansion(env, builder, at, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
 /// to manage the complexity of implementing co-iteration over unions
 /// and intersections of sparse iterations spaces.
-static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
-                    linalg::GenericOp op, unsigned exp, unsigned at) {
+static void genStmt(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
+                    unsigned at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
-  if (at == codegen.topSort.size()) {
-    unsigned ldx = codegen.topSort[at - 1];
-    Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
-    genTensorStore(merger, codegen, rewriter, op, exp, rhs);
+  if (at == env.topSort.size()) {
+    unsigned ldx = env.topSort[at - 1];
+    Value rhs = genExp(env, rewriter, exp, ldx);
+    genTensorStore(env, rewriter, exp, rhs);
     return;
   }
 
   // Construct iteration lattices for current loop index, with L0 at top.
-  unsigned idx = codegen.topSort[at];
-  unsigned ldx = at == 0 ? -1u : codegen.topSort[at - 1];
-  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
+  unsigned idx = env.topSort[at];
+  unsigned ldx = at == 0 ? -1u : env.topSort[at - 1];
+  unsigned lts = env.merger.optimizeSet(env.merger.buildLattices(exp, idx));
 
   // TODO: sort
   // TODO: dedup
 
   // Start a loop sequence.
-  bool needsUniv =
-      startLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx, lts);
+  bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  unsigned lsize = merger.set(lts).size();
+  unsigned lsize = env.set(lts).size();
   for (unsigned i = 0; i < lsize; i++) {
     // Start a loop.
-    unsigned li = merger.set(lts)[i];
-    Operation *loop =
-        startLoop(merger, codegen, rewriter, op, at, li, needsUniv);
+    unsigned li = env.set(lts)[i];
+    Operation *loop = startLoop(env, rewriter, at, li, needsUniv);
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
-    Value redInput = codegen.redVal;
-    Value cntInput = codegen.expCount;
-    Value insInput = codegen.insChain;
+    Value redInput = env.redVal;
+    Value cntInput = env.expCount;
+    Value insInput = env.insChain;
     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
-      unsigned lj = merger.set(lts)[j];
-      unsigned ej = merger.lat(lj).exp;
-      if (li == lj || merger.latGT(li, lj)) {
+      unsigned lj = env.set(lts)[j];
+      unsigned ej = env.lat(lj).exp;
+      if (li == lj || env.merger.latGT(li, lj)) {
         // Recurse into body of each branch.
         if (isWhile) {
-          scf::IfOp ifOp =
-              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
-          genStmt(merger, codegen, rewriter, op, ej, at + 1);
-          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput,
-                insInput);
+          scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
+          genStmt(env, rewriter, ej, at + 1);
+          endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
         } else {
-          genStmt(merger, codegen, rewriter, op, ej, at + 1);
+          genStmt(env, rewriter, ej, at + 1);
         }
       }
     }
 
     // End a loop.
-    needsUniv =
-        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
+    needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv);
   }
 
   // End a loop sequence.
-  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
+  endLoopSeq(env, rewriter, exp, at, idx, ldx);
 }
 
 /// Converts the result computed by the sparse kernel into the required form.
-static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
-                      linalg::GenericOp op) {
+static void genResult(CodeGenEnv &env, RewriterBase &rewriter) {
+  linalg::GenericOp op = env.linalgOp;
   OpOperand *lhs = op.getDpsInitOperand(0);
   Value tensor = lhs->get();
   Type resType = tensor.getType();
@@ -1639,14 +1649,14 @@ static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
     // The sparse tensor rematerializes from the original sparse tensor's
     // underlying sparse storage format. For an insertion chain, the
     // tensor materializes from the chain with 'hasInserts' enabled.
-    bool hasInserts = codegen.sparseOut == lhs;
+    bool hasInserts = env.sparseOut == lhs;
     if (hasInserts)
-      tensor = codegen.insChain;
+      tensor = env.insChain;
     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
   } else {
     // To rematerialize an non-annotated tensor, simply load it
     // from the bufferized value.
-    Value val = codegen.loopEmitter.getValBuffer().back(); // value array
+    Value val = env.getValBuffer().back(); // value array
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
   }
 }
@@ -1656,6 +1666,7 @@ static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 
 namespace {
+
 /// Sparse rewriting rule for generic Lingalg operation.
 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 public:
@@ -1664,86 +1675,84 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
-    // Detects sparse annotations and translate the per-dimension sparsity
-    // information for all tensors to loop indices in the kernel.
+    // Only accept single output operations.
     if (op.getNumDpsInits() != 1)
       return failure();
+
+    // Sets up a code generation environment.
     unsigned numTensors = op->getNumOperands();
     unsigned numLoops = op.getNumLoops();
     unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op);
-    Merger merger(numTensors, numLoops, numFilterLoops);
-    if (!findSparseAnnotations(merger, op))
+    CodeGenEnv env(op, options, numTensors, numLoops, numFilterLoops);
+
+    // Detects sparse annotations and translates the per-dimension sparsity
+    // information for all tensors to loop indices in the kernel.
+    if (!findSparseAnnotations(env))
       return failure();
 
     // Builds the tensor expression for the Linalg operation in SSA form.
-    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
+    Optional<unsigned> optExp = env.merger.buildTensorExpFromLinalg(op);
     if (!optExp.has_value())
       return failure();
-
     unsigned exp = *optExp;
-    OpOperand *sparseOut = nullptr;
-    unsigned outerParNest = 0;
+
     // Computes a topologically sorted iteration graph to ensure tensors
     // are visited in natural index order. Gradually relaxes the considered
     // constraints until an acyclic iteration graph results, such that sparse
     // code generation can proceed. As a last resort, an attempt is made
     // to resolve cycles by inserting a conversion.
-    std::vector<unsigned> topSort;
-    // Whether the current GenericOp is admissible.
     bool isAdmissible = false;
     bool hasCycle = true;
     // An const list of all masks that we used for interation graph
-    // computation. Must be ordered from strict -> loose.
+    // computation. Must be ordered from more strict to less strict.
     const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
                           SortMask::kIncludeDense, SortMask::kSparseOnly};
     for (auto mask : allMask)
-      if (computeIterationGraph(merger, op, topSort, mask)) {
+      if (computeIterationGraph(env, mask)) {
         hasCycle = false;
-        if (isAdmissibleTensorExp(merger, op, topSort, exp, &sparseOut,
-                                  outerParNest)) {
+        if (isAdmissibleTensorExp(env, exp)) {
           isAdmissible = true;
           break;
         }
         // else try a set of less strict constraints.
       }
-
     if (hasCycle)
-      // Give it one last shot to resolve the cycle.
-      return resolveCycle(merger, rewriter, op);
+      return resolveCycle(env, rewriter); // one last shot
     if (!isAdmissible)
-      // Inadmissible expression, reject.
-      return failure();
-
-    merger.setHasSparseOut(sparseOut != nullptr);
+      return failure(); // inadmissible expression, reject
 
+    // Updates environment with a loop emitter.
+    // TODO: refactor so that emitter can be constructed earlier
+    //       and updating is made easy, i.e. remove this whole block?
     SmallVector<Value> tensors;
     for (OpOperand &t : op->getOpOperands())
       tensors.push_back(t.get());
+    SparseTensorLoopEmitter lpe(
+        tensors,
+        StringAttr::get(op.getContext(), linalg::GenericOp::getOperationName()),
+        /*hasOutput=*/true, /*isSparseOut=*/env.sparseOut != nullptr,
+        env.topSort);
+    env.startEmit(&lpe);
 
     // Recursively generates code if admissible.
-    CodeGen codegen(options, op.getContext(), tensors, numTensors, numLoops,
-                    sparseOut, outerParNest, topSort);
-    genBuffers(merger, codegen, rewriter, op);
-    genInitConstantDenseAddress(codegen, rewriter, op);
-    genStmt(merger, codegen, rewriter, op, exp, 0);
-    genResult(merger, codegen, rewriter, op);
+    genBuffers(env, rewriter);
+    genInitConstantDenseAddress(env, rewriter);
+    genStmt(env, rewriter, exp, 0);
+    genResult(env, rewriter);
     return success();
   }
 
 private:
   // Last resort cycle resolution.
-  LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter,
-                             linalg::GenericOp op) const {
+  LogicalResult resolveCycle(CodeGenEnv &env, PatternRewriter &rewriter) const {
     // Compute topological sort while leaving out every
     // sparse input tensor in succession until an acylic
     // iteration graph results.
-    std::vector<unsigned> topSort;
-    for (OpOperand *t : op.getDpsInputOperands()) {
+    for (OpOperand *t : env.linalgOp.getDpsInputOperands()) {
       unsigned tensor = t->getOperandNumber();
       Value tval = t->get();
       auto srcEnc = getSparseTensorEncoding(tval.getType());
-      if (!srcEnc ||
-          !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t))
+      if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t))
         continue;
       // Found an input tensor that resolves the cycle by inserting a
       // conversion into a sparse tensor that adheres to the iteration
@@ -1754,16 +1763,15 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       //
       auto srcTp = tval.getType().cast<RankedTensorType>();
       auto dstEnc = SparseTensorEncodingAttr::get(
-          op->getContext(), srcEnc.getDimLevelType(),
-          permute(merger, getContext(), op.getMatchingIndexingMap(t),
-                  topSort), // new order
+          getContext(), srcEnc.getDimLevelType(),
+          permute(env, env.linalgOp.getMatchingIndexingMap(t)), // new order
           srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
           srcEnc.getIndexBitWidth());
       auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                          srcTp.getElementType(), dstEnc);
       auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      op->setOperand(tensor, convert);
-      rewriter.setInsertionPointAfter(op);
+      env.linalgOp->setOperand(tensor, convert);
+      rewriter.setInsertionPointAfter(env.linalgOp);
       rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
       return success();
     }


        


More information about the Mlir-commits mailing list