[Mlir-commits] [mlir] 384049a - [mlir][sparse] completed codegen environment privatization

Aart Bik llvmlistbot at llvm.org
Thu Dec 22 10:34:51 PST 2022


Author: Aart Bik
Date: 2022-12-22T10:34:43-08:00
New Revision: 384049a755fa9f3a58655599f86e4fdf4633c390

URL: https://github.com/llvm/llvm-project/commit/384049a755fa9f3a58655599f86e4fdf4633c390
DIFF: https://github.com/llvm/llvm-project/commit/384049a755fa9f3a58655599f86e4fdf4633c390.diff

LOG: [mlir][sparse] completed codegen environment privatization

All members are now private and access is through delegate
or convenience methods only (except the loop emitter, which
is still under refactoring).

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D140519

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 52d6d5822e577..0be15d6995366 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -12,27 +12,84 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
-// Code generation environment constructor and setup
+// Code generation environment constructor and general methods
 //===----------------------------------------------------------------------===//
 
 CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                        unsigned numTensors, unsigned numLoops,
                        unsigned numFilterLoops)
-    : linalgOp(linop), options(opts), topSort(),
-      merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
-      sparseOut(nullptr), redVal(nullptr), redExp(-1u), redCustom(-1u) {}
+    : linalgOp(linop), sparseOptions(opts),
+      latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
+      topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
+      expFilled(), expAdded(), expCount(), redVal(), redExp(-1u),
+      redCustom(-1u) {}
 
-void CodegenEnv::startEmit(SparseTensorLoopEmitter *le) {
-  assert(!loopEmitter && "must only start emitting once");
+void CodegenEnv::startEmit(OpOperand *so, unsigned lv,
+                           SparseTensorLoopEmitter *le) {
+  assert(sparseOut == nullptr && loopEmitter == nullptr &&
+         insChain == nullptr && "must only start emitting once");
+  sparseOut = so;
+  outerParNest = lv;
   loopEmitter = le;
   if (sparseOut) {
     insChain = sparseOut->get();
-    merger.setHasSparseOut(true);
+    latticeMerger.setHasSparseOut(true);
   }
 }
 
 //===----------------------------------------------------------------------===//
-// Code generation environment methods
+// Code generation environment topological sort methods
+//===----------------------------------------------------------------------===//
+
+ArrayRef<unsigned> CodegenEnv::getTopSortSlice(size_t n, size_t m) const {
+  return ArrayRef<unsigned>(topSort).slice(n, m);
+}
+
+ArrayRef<unsigned> CodegenEnv::getLoopCurStack() const {
+  return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+}
+
+Value CodegenEnv::getLoopIdxValue(size_t loopIdx) const {
+  for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
+    if (topSort[lv] == loopIdx)
+      return loopEmitter->getLoopIV(lv);
+  llvm_unreachable("invalid loop index");
+}
+
+//===----------------------------------------------------------------------===//
+// Code generation environment sparse tensor output and expansion methods
+//===----------------------------------------------------------------------===//
+
+void CodegenEnv::updateInsertionChain(Value chain) {
+  assert(sparseOut != nullptr && insChain != nullptr);
+  insChain = chain;
+}
+
+bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const {
+  return sparseOut == o && outerParNest == rank - 1 && outerParNest == lv;
+}
+
+void CodegenEnv::startExpand(Value values, Value filled, Value added,
+                             Value count) {
+  assert(sparseOut != nullptr && expValues == nullptr);
+  expValues = values;
+  expFilled = filled;
+  expAdded = added;
+  expCount = count;
+}
+
+void CodegenEnv::updateExpandCount(Value count) {
+  assert(sparseOut != nullptr && expValues != nullptr);
+  expCount = count;
+}
+
+void CodegenEnv::endExpand() {
+  assert(sparseOut != nullptr && expValues != nullptr);
+  expValues = expFilled = expAdded = expCount = Value();
+}
+
+//===----------------------------------------------------------------------===//
+// Code generation environment reduction methods
 //===----------------------------------------------------------------------===//
 
 void CodegenEnv::startReduc(unsigned exp, Value val) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 1b0362203f2fd..cb5ba99a93065 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -28,59 +28,87 @@ namespace sparse_tensor {
 /// sparsification. This environment simplifies passing around such
 /// data during sparsification (rather than passing around all the
 /// individual compoments where needed). Furthermore, it provides
-/// a number of delegate and convience methods that keep some of the
-/// implementation details transparent to sparsification.
+/// convenience methods that keep implementation details transparent
+/// to sparsification while asserting on internal consistency.
 class CodegenEnv {
 public:
+  /// Constructs a code generation environment which can be
+  /// passed around during sparsification for bookkeeping
+  /// together with some consistency asserts.
   CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
              unsigned numTensors, unsigned numLoops, unsigned numFilterLoops);
 
-  // Start emitting.
-  void startEmit(SparseTensorLoopEmitter *le);
+  //
+  // General methods.
+  //
+
+  linalg::GenericOp op() const { return linalgOp; }
+  const SparsificationOptions &options() const { return sparseOptions; }
+  Merger &merger() { return latticeMerger; }
+  SparseTensorLoopEmitter *emitter() { return loopEmitter; }
+
+  void startEmit(OpOperand *so, unsigned lv, SparseTensorLoopEmitter *le);
 
-  // Delegate methods to merger.
-  TensorExp &exp(unsigned e) { return merger.exp(e); }
-  LatPoint &lat(unsigned l) { return merger.lat(l); }
-  SmallVector<unsigned> &set(unsigned s) { return merger.set(s); }
-  DimLevelType dimLevelType(unsigned t, unsigned i) const {
-    return merger.getDimLevelType(t, i);
+  //
+  // Merger delegates.
+  //
+
+  TensorExp &exp(unsigned e) { return latticeMerger.exp(e); }
+  LatPoint &lat(unsigned l) { return latticeMerger.lat(l); }
+  SmallVector<unsigned> &set(unsigned s) { return latticeMerger.set(s); }
+  DimLevelType dlt(unsigned t, unsigned i) const {
+    return latticeMerger.getDimLevelType(t, i);
   }
-  DimLevelType dimLevelType(unsigned b) const {
-    return merger.getDimLevelType(b);
+  DimLevelType dlt(unsigned b) const {
+    return latticeMerger.getDimLevelType(b);
   }
-  bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); }
 
-  // Delegate methods to loop emitter.
-  Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); }
-  const std::vector<Value> &getValBuffer() const {
-    return loopEmitter->getValBuffer();
-  }
+  //
+  // Topological delegate and sort methods.
+  //
 
-  // Convenience method to slice topsort.
-  ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const {
-    return ArrayRef<unsigned>(topSort).slice(n, m);
-  }
+  // TODO: get rid of this one!
+  std::vector<unsigned> &topSortRef() { return topSort; }
 
-  // Convenience method to get current loop stack.
-  ArrayRef<unsigned> getLoopCurStack() const {
-    return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+  size_t topSortSize() const { return topSort.size(); }
+  unsigned topSortAt(unsigned i) const { return topSort.at(i); }
+  void topSortPushBack(unsigned i) { topSort.push_back(i); }
+  void topSortClear(unsigned capacity = 0) {
+    topSort.clear();
+    topSort.reserve(capacity);
   }
 
-  // Convenience method to get the IV of the given loop index.
-  Value getLoopIdxValue(size_t loopIdx) const {
-    for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
-      if (topSort[lv] == loopIdx)
-        return getLoopIV(lv);
-    llvm_unreachable("invalid loop index");
-  }
+  ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const;
+  ArrayRef<unsigned> getLoopCurStack() const;
+  Value getLoopIdxValue(size_t loopIdx) const;
+
+  //
+  // Sparse tensor output and expansion methods.
+  //
+
+  bool hasSparseOutput() const { return sparseOut != nullptr; }
+  bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }
+
+  Value getInsertionChain() const { return insChain; }
+  void updateInsertionChain(Value chain);
+
+  bool atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const;
+  void startExpand(Value values, Value filled, Value added, Value count);
+  bool isExpand() const { return expValues != nullptr; }
+  void updateExpandCount(Value count);
+  Value getExpandValues() const { return expValues; }
+  Value getExpandFilled() const { return expFilled; }
+  Value getExpandAdded() const { return expAdded; }
+  Value getExpandCount() const { return expCount; }
+  void endExpand();
 
   //
-  // Reductions.
+  // Reduction methods.
   //
 
   void startReduc(unsigned exp, Value val);
-  void updateReduc(Value val);
   bool isReduc() const { return redExp != -1u; }
+  void updateReduc(Value val);
   Value getReduc() const { return redVal; }
   Value endReduc();
 
@@ -89,39 +117,34 @@ class CodegenEnv {
   Value getCustomRedId();
   void endCustomReduc();
 
-public:
-  //
-  // TODO make this section private too, using similar refactoring as for reduc
-  //
-
+private:
   // Linalg operation.
   linalg::GenericOp linalgOp;
 
   // Sparsification options.
-  SparsificationOptions options;
-
-  // Topological sort.
-  std::vector<unsigned> topSort;
+  SparsificationOptions sparseOptions;
 
   // Merger helper class.
-  Merger merger;
+  Merger latticeMerger;
 
   // Loop emitter helper class (keep reference in scope!).
   // TODO: move emitter constructor up in time?
   SparseTensorLoopEmitter *loopEmitter;
 
+  // Topological sort.
+  std::vector<unsigned> topSort;
+
   // Sparse tensor as output. Implemented either through direct injective
-  // insertion in lexicographic index order or through access pattern expansion
-  // in the innermost loop nest (`expValues` through `expCount`).
+  // insertion in lexicographic index order or through access pattern
+  // expansion in the innermost loop nest (`expValues` through `expCount`).
   OpOperand *sparseOut;
   unsigned outerParNest;
-  Value insChain; // bookkeeping for insertion chain
+  Value insChain;
   Value expValues;
   Value expFilled;
   Value expAdded;
   Value expCount;
 
-private:
   // Bookkeeping for reductions (up-to-date value of the reduction, and indices
   // into the merger's expression tree. When the indices of a tensor reduction
   // expression are exhausted, all inner loops can use a scalarized reduction.

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index a6c031486c7e3..eb71c4cb8ff0e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -127,8 +127,8 @@ static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, unsigned ldx,
 /// Helper method to construct a permuted dimension ordering
 /// that adheres to the given topological sort.
 static AffineMap permute(CodegenEnv &env, AffineMap m) {
-  assert(m.getNumDims() + env.merger.getNumFilterLoops() ==
-             env.topSort.size() &&
+  assert(m.getNumDims() + env.merger().getNumFilterLoops() ==
+             env.topSortSize() &&
          "size mismatch");
   // Construct the inverse of `m`; to avoid the asymptotic complexity
   // of calling `m.getPermutedPosition` repeatedly.
@@ -138,14 +138,14 @@ static AffineMap permute(CodegenEnv &env, AffineMap m) {
   unsigned loopDepth = 1;
 
   // Construct the permutation.
-  while (worklist.any() && loopDepth <= env.topSort.size()) {
+  while (worklist.any() && loopDepth <= env.topSortSize()) {
     unsigned preSize = perm.size();
     for (auto dim : worklist.set_bits()) {
       bool atLevel = false;
       if (m.getResult(dim).isa<AffineConstantExpr>() ||
           (isInvariantAffine(m.getResult(dim),
                              env.getTopSortSlice(0, loopDepth),
-                             env.topSort[loopDepth - 1], atLevel) &&
+                             env.topSortAt(loopDepth - 1), atLevel) &&
            atLevel)) {
         // If the matching affine is constant expression or just become
         // invariant. We can visit the dimension now without breaking the
@@ -163,7 +163,7 @@ static AffineMap permute(CodegenEnv &env, AffineMap m) {
   }
 
   assert(perm.size() == numResults);
-  return AffineMap::getPermutationMap(perm, env.linalgOp.getContext());
+  return AffineMap::getPermutationMap(perm, env.op().getContext());
 }
 
 /// Helper method to inspect affine expressions. Rejects cases where the
@@ -255,22 +255,22 @@ static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
 /// no annotations are found or inadmissible constructs occur.
 static bool findSparseAnnotations(CodegenEnv &env) {
   bool annotated = false;
-  unsigned filterLdx = env.merger.getFilterLoopStartingIdx();
-  for (OpOperand &t : env.linalgOp->getOpOperands()) {
-    auto map = env.linalgOp.getMatchingIndexingMap(&t);
+  unsigned filterLdx = env.merger().getFilterLoopStartingIdx();
+  for (OpOperand &t : env.op()->getOpOperands()) {
+    auto map = env.op().getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
-    assert(map.getNumResults() == env.linalgOp.getRank(&t));
+    assert(map.getNumResults() == env.op().getRank(&t));
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(env.merger, tensor, d, a, getDimLevelType(enc, d),
+      if (!findAffine(env.merger(), tensor, d, a, getDimLevelType(enc, d),
                       filterLdx))
         return false; // inadmissible affine expression
     }
   }
-  assert(filterLdx == env.merger.getNumLoops());
+  assert(filterLdx == env.merger().getNumLoops());
   return annotated;
 }
 
@@ -287,7 +287,7 @@ static bool topSortOptimal(CodegenEnv &env, unsigned n,
   std::vector<unsigned> filterIt; // filter loop with 0 degree
   for (unsigned i = 0; i < n; i++) {
     if (inDegree[i] == 0) {
-      if (env.isFilterLoop(i))
+      if (env.merger().isFilterLoop(i))
         filterIt.push_back(i);
       else if (linalg::isReductionIterator(iteratorTypes[i]))
         redIt.push_back(i);
@@ -318,12 +318,12 @@ static bool topSortOptimal(CodegenEnv &env, unsigned n,
     //        O(X) computation  => O(NK+NMX) time complexity
     auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
     auto src = it.back();
-    env.topSort.push_back(src);
+    env.topSortPushBack(src);
     it.pop_back();
     // Update in-degree, and push 0-degree node into worklist.
     for (unsigned dst = 0; dst < n; dst++) {
       if (adjM[src][dst] && --inDegree[dst] == 0) {
-        if (env.isFilterLoop(dst))
+        if (env.merger().isFilterLoop(dst))
           filterIt.push_back(dst);
         else if (linalg::isReductionIterator(iteratorTypes[dst]))
           redIt.push_back(dst);
@@ -332,7 +332,7 @@ static bool topSortOptimal(CodegenEnv &env, unsigned n,
       }
     }
   }
-  return env.topSort.size() == n;
+  return env.topSortSize() == n;
 }
 
 /// Helper method to add all constraints from the indices in one affine
@@ -428,17 +428,16 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
                                   OpOperand *skip = nullptr) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
-  unsigned n = env.merger.getNumLoops();
+  unsigned n = env.merger().getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
   std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
-  auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
+  auto iteratorTypes = env.op().getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
-  for (OpOperand &t : env.linalgOp->getOpOperands()) {
+  for (OpOperand &t : env.op()->getOpOperands()) {
     // Get map and encoding.
-    auto map = env.linalgOp.getMatchingIndexingMap(&t);
+    auto map = env.op().getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
-    assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.linalgOp) ==
-           n);
+    assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
     // Skip dense tensor constraints when not requested.
     if (!(mask & SortMask::kIncludeDense) && !enc)
       continue;
@@ -448,11 +447,12 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
     // on the loop indices if no explicit dimension ordering is given.
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr ta = map.getResult(toOrigDim(enc, d));
-      Optional<unsigned> tldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
+      Optional<unsigned> tldx =
+          env.merger().getLoopIdx(t.getOperandNumber(), d);
 
       // Filter loops should be constructed after all the dependent loops,
       // i.e., d0 + d1 < filter_loop(d0 + d1)
-      if (tldx && env.isFilterLoop(*tldx)) {
+      if (tldx && env.merger().isFilterLoop(*tldx)) {
         assert(!ta.isa<AffineDimExpr>() &&
                !isDenseDLT(getDimLevelType(enc, d)));
         addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
@@ -472,7 +472,7 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
       if (d > 0) {
         AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
         Optional<unsigned> fldx =
-            env.merger.getLoopIdx(t.getOperandNumber(), d - 1);
+            env.merger().getLoopIdx(t.getOperandNumber(), d - 1);
 
         // Applying order constraints on every pair of dimExpr between two
         // compound affine expressions can sometime too strict:
@@ -480,7 +480,7 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
         // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
         // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
         if (!(mask & SortMask::kIncludeDense))
-          tryLoosenAffineDenseConstraints(env.linalgOp, fldx, fa, tldx, ta);
+          tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta);
 
         // (d0 + d1) < (d2 + d3), or
         // filter_loop_d-1 < (d2 + d3), or
@@ -495,23 +495,22 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
     if (mask & SortMask::kIncludeUndef) {
       unsigned tensor = t.getOperandNumber();
       for (unsigned i = 0; i < n; i++)
-        if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
-            isSingletonDLT(env.dimLevelType(tensor, i))) {
+        if (isCompressedDLT(env.dlt(tensor, i)) ||
+            isSingletonDLT(env.dlt(tensor, i))) {
           for (unsigned j = 0; j < n; j++)
-            if (isUndefDLT(env.dimLevelType(tensor, j))) {
+            if (isUndefDLT(env.dlt(tensor, j))) {
               adjM[i][j] = true;
               inDegree[j]++;
             }
         } else {
-          assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
-                 isUndefDLT(env.dimLevelType(tensor, i)));
+          assert(isDenseDLT(env.dlt(tensor, i)) ||
+                 isUndefDLT(env.dlt(tensor, i)));
         }
     }
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
-  env.topSort.clear();
-  env.topSort.reserve(n);
+  env.topSortClear(n);
   return topSortOptimal(env, n, iteratorTypes, inDegree, adjM);
 }
 
@@ -526,20 +525,22 @@ static bool isMaterializing(Value val) {
 /// whether the out tensor in the tensor expression codegen is admissible.
 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
-static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) {
+static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp,
+                                  OpOperand **sparseOut,
+                                  unsigned *outerParNest) {
   // We reject any expression that makes a reduction from `-outTensor`, as those
   // expressions create a dependency between the current iteration (i) and the
   // previous iteration (i-1). It would require iterating over the whole
   // coordinate space, which prevent exploiting sparsity for faster code.
-  for (utils::IteratorType it : env.linalgOp.getIteratorTypesArray()) {
+  for (utils::IteratorType it : env.op().getIteratorTypesArray()) {
     if (it == utils::IteratorType::reduction) {
-      if (env.merger.hasNegateOnOut(exp))
+      if (env.merger().hasNegateOnOut(exp))
         return false;
       break;
     }
   }
 
-  OpOperand *lhs = env.linalgOp.getDpsInitOperand(0);
+  OpOperand *lhs = env.op().getDpsInitOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
   // An non-annotated output tensor is assumed dense, and becomes a random
@@ -550,40 +551,39 @@ static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) {
   // access 1-dim memref. Also admissible since insertions cannot occur.
   bool allDense = true;
   unsigned numLoops =
-      env.merger.getNumLoops(); // numNativeLoops + numFilterLoops
-  for (unsigned i = 0; i < env.merger.getNumLoops(); i++)
-    if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
-        isSingletonDLT(env.dimLevelType(tensor, i))) {
+      env.merger().getNumLoops(); // numNativeLoops + numFilterLoops
+  for (unsigned i = 0; i < env.merger().getNumLoops(); i++)
+    if (isCompressedDLT(env.dlt(tensor, i)) ||
+        isSingletonDLT(env.dlt(tensor, i))) {
       allDense = false;
       break;
     } else {
-      assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
-             isUndefDLT(env.dimLevelType(tensor, i)));
+      assert(isDenseDLT(env.dlt(tensor, i)) || isUndefDLT(env.dlt(tensor, i)));
     }
   if (allDense)
     return true;
 
   // TODO: support compound affine expression on sparse output.
-  if (getNumCompoundAffineOnSparseDims(env.linalgOp.getMatchingIndexingMap(lhs),
+  if (getNumCompoundAffineOnSparseDims(env.op().getMatchingIndexingMap(lhs),
                                        lhs->get()) != 0)
     return false;
 
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
   // [Bik96,Ch9], is also admissible without special env.
-  if (env.merger.isSingleCondition(tensor, exp))
+  if (env.merger().isSingleCondition(tensor, exp))
     return true;
 
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
   if (isMaterializing(lhs->get())) {
-    auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
+    auto iteratorTypes = env.op().getIteratorTypesArray();
     unsigned nest = 0;
     for (unsigned i = 0; i < numLoops; i++) {
-      if (!env.isFilterLoop(env.topSort[i])) {
+      if (!env.merger().isFilterLoop(env.topSortAt(i))) {
         // We only count non-filter loops as filter loops should be considered
         // as a special type of parallel loops.
-        if (linalg::isReductionIterator(iteratorTypes[env.topSort[i]]))
+        if (linalg::isReductionIterator(iteratorTypes[env.topSortAt(i)]))
           break; // terminate at first reduction
         nest++;
       }
@@ -591,9 +591,9 @@ static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) {
     // Determine admissible dynamic insertion situations:
     // (1) fully injective, since there are no reductions,
     // (2) admissible 1-d expansion in innermost dimension.
-    if (nest >= env.linalgOp.getRank(lhs) - 1) {
-      env.sparseOut = lhs;
-      env.outerParNest = nest;
+    if (nest >= env.op().getRank(lhs) - 1) {
+      *sparseOut = lhs;
+      *outerParNest = nest;
       return true;
     }
   }
@@ -613,10 +613,10 @@ static Optional<Operation *> genLoopBoundary(
   SmallVector<Value> reduc;
   if (env.isReduc())
     reduc.push_back(env.getReduc());
-  if (env.expValues)
-    reduc.push_back(env.expCount);
-  if (env.insChain)
-    reduc.push_back(env.insChain);
+  if (env.isExpand())
+    reduc.push_back(env.getExpandCount());
+  if (env.getInsertionChain())
+    reduc.push_back(env.getInsertionChain());
 
   auto r = callback(reduc);
 
@@ -624,21 +624,21 @@ static Optional<Operation *> genLoopBoundary(
   unsigned i = 0;
   if (env.isReduc())
     env.updateReduc(reduc[i++]);
-  if (env.expValues)
-    env.expCount = reduc[i++];
-  if (env.insChain)
-    env.insChain = reduc[i];
+  if (env.isExpand())
+    env.updateExpandCount(reduc[i++]);
+  if (env.getInsertionChain())
+    env.updateInsertionChain(reduc[i]);
 
   return r;
 }
 
 /// Local bufferization of all dense and sparse data structures.
 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
 
-  env.loopEmitter->initializeLoopEmit(
+  env.emitter()->initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
       /// Note that all sparse kernels assume that when all elements are written
@@ -676,7 +676,7 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
 
 /// Generates index for load/store on sparse tensor.
 static Value genIndex(CodegenEnv &env, OpOperand *t) {
-  auto map = env.linalgOp.getMatchingIndexingMap(t);
+  auto map = env.op().getMatchingIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
   AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
   assert(a.getKind() == AffineExprKind::DimId);
@@ -687,69 +687,73 @@ static Value genIndex(CodegenEnv &env, OpOperand *t) {
 /// Generates subscript for load/store on a dense or sparse tensor.
 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                           SmallVectorImpl<Value> &args) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   unsigned tensor = t->getOperandNumber();
   auto map = op.getMatchingIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
   unsigned rank = map.getNumResults();
   if (enc) {
-    Value pidx = env.loopEmitter->getPidxs()[tensor].back();
+    Value pidx = env.emitter()->getPidxs()[tensor].back();
     assert(pidx);
     args.push_back(pidx); // position index
   } else {
     for (unsigned d = 0; d < rank; d++) {
       AffineExpr a = map.getResult(d);
-      args.push_back(env.loopEmitter->genAffine(builder, a, op.getLoc()));
+      args.push_back(env.emitter()->genAffine(builder, a, op.getLoc()));
     }
   }
-  return env.getValBuffer()[tensor];
+  return env.emitter()->getValBuffer()[tensor];
 }
 
 /// Generates insertion code to implement dynamic tensor load.
 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                               OpOperand *t) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   // Direct lexicographic index order, tensor loads as zero.
-  if (!env.expValues) {
+  if (!env.isExpand()) {
     Type tp = getElementTypeOrSelf(t->get().getType());
     return constantZero(builder, loc, tp);
   }
   // Load from expanded access pattern.
   Value index = genIndex(env, t);
-  return builder.create<memref::LoadOp>(loc, env.expValues, index);
+  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
 }
 
 /// Generates insertion code to implement dynamic tensor load for reduction.
 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                     OpOperand *t) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   Value identity = env.getCustomRedId();
   // Direct lexicographic index order, tensor loads as identity.
-  if (!env.expValues)
+  if (!env.isExpand())
     return identity;
   // Load from expanded access pattern if filled, identity otherwise.
+  Value values = env.getExpandValues();
+  Value filled = env.getExpandFilled();
   Value index = genIndex(env, t);
-  Value isFilled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
-  Value valAtIndex = builder.create<memref::LoadOp>(loc, env.expValues, index);
+  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
+  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
 }
 
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                               Value rhs) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   // Direct insertion in lexicographic index order.
-  if (!env.expValues) {
+  if (!env.isExpand()) {
     unsigned rank = op.getRank(t);
     SmallVector<Value> indices;
     for (unsigned i = 0; i < rank; i++) {
-      assert(env.getLoopIV(i));
-      indices.push_back(env.getLoopIV(i));
+      assert(env.emitter()->getLoopIV(i));
+      indices.push_back(env.emitter()->getLoopIV(i));
     }
-    env.insChain = builder.create<InsertOp>(loc, rhs, env.insChain, indices);
+    Value chain = env.getInsertionChain();
+    env.updateInsertionChain(
+        builder.create<InsertOp>(loc, rhs, chain, indices));
     return;
   }
   // Generates insertion code along expanded access pattern.
@@ -758,29 +762,33 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
   //     expAdded[inserts++] = i
   //   endif
   //   values[i] = rhs
+  Value values = env.getExpandValues();
+  Value filled = env.getExpandFilled();
+  Value added = env.getExpandAdded();
+  Value count = env.getExpandCount();
   Value index = genIndex(env, t);
   Value fval = constantI1(builder, loc, false);
   Value tval = constantI1(builder, loc, true);
   // If statement.
-  Value filled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
+  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                             filled, fval);
+                                             isFilled, fval);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                              /*else=*/true);
   // True branch.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  builder.create<memref::StoreOp>(loc, tval, env.expFilled, index);
-  builder.create<memref::StoreOp>(loc, index, env.expAdded, env.expCount);
+  builder.create<memref::StoreOp>(loc, tval, filled, index);
+  builder.create<memref::StoreOp>(loc, index, added, count);
   Value one = constantIndex(builder, loc, 1);
-  Value add = builder.create<arith::AddIOp>(loc, env.expCount, one);
+  Value add = builder.create<arith::AddIOp>(loc, count, one);
   builder.create<scf::YieldOp>(loc, add);
   // False branch.
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  builder.create<scf::YieldOp>(loc, env.expCount);
+  builder.create<scf::YieldOp>(loc, count);
   builder.setInsertionPointAfter(ifOp);
   // Value assignment.
-  env.expCount = ifOp.getResult(0);
-  builder.create<memref::StoreOp>(loc, rhs, env.expValues, index);
+  env.updateExpandCount(ifOp.getResult(0));
+  builder.create<memref::StoreOp>(loc, rhs, values, index);
 }
 
 /// Generates a load on a dense or sparse tensor.
@@ -791,23 +799,23 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, unsigned exp) {
     return val;
 
   // Load during insertion.
-  linalg::GenericOp op = env.linalgOp;
-  OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
-  if (&t == env.sparseOut) {
+  linalg::GenericOp op = env.op();
+  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
+  if (env.isSparseOutput(t)) {
     if (env.isCustomReduc())
-      return genInsertionLoadReduce(env, builder, &t);
-    return genInsertionLoad(env, builder, &t);
+      return genInsertionLoadReduce(env, builder, t);
+    return genInsertionLoad(env, builder, t);
   }
   // Actual load.
   SmallVector<Value> args;
-  Value ptr = genSubscript(env, builder, &t, args);
+  Value ptr = genSubscript(env, builder, t, args);
   return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
 }
 
 /// Generates a store on a dense or sparse tensor.
 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                            Value rhs) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   // Test if this is a scalarized reduction.
   if (env.isReduc()) {
@@ -816,17 +824,16 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
   }
   // Store during insertion.
   OpOperand *t = op.getDpsInitOperand(0);
-  if (t == env.sparseOut) {
+  if (env.isSparseOutput(t)) {
     if (!rhs) {
       // Only unary and binary are allowed to return uninitialized rhs
       // to indicate missing output.
       assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary);
     } else if (env.exp(exp).kind == kSelect) {
       // Select operation insertion.
-      Value insChain = env.insChain;
-      assert(insChain);
-      scf::IfOp ifOp = builder.create<scf::IfOp>(loc, insChain.getType(), rhs,
-                                                 /*else=*/true);
+      Value chain = env.getInsertionChain();
+      scf::IfOp ifOp =
+          builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
       builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
       // Existing value was preserved to be used here.
       assert(env.exp(exp).val);
@@ -834,12 +841,13 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
       genInsertionStore(env, builder, t, v0);
       env.exp(exp).val = Value();
       // Yield modified insertion chain along true branch.
-      builder.create<scf::YieldOp>(op.getLoc(), env.insChain);
+      Value mchain = env.getInsertionChain();
+      builder.create<scf::YieldOp>(op.getLoc(), mchain);
       // Yield original insertion chain along false branch.
       builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-      builder.create<scf::YieldOp>(loc, insChain);
+      builder.create<scf::YieldOp>(loc, chain);
       // Done with if statement.
-      env.insChain = ifOp->getResult(0);
+      env.updateInsertionChain(ifOp->getResult(0));
       builder.setInsertionPointAfter(ifOp);
     } else {
       genInsertionStore(env, builder, t, rhs);
@@ -884,7 +892,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
 /// Recursively generates tensor expression.
 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
                     unsigned ldx) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
 
   if (exp == -1u)
@@ -901,7 +909,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
 
   Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx);
   Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx);
-  Value ee = env.merger.buildExp(rewriter, loc, exp, v0, v1);
+  Value ee = env.merger().buildExp(rewriter, loc, exp, v0, v1);
   if (ee && (env.exp(exp).kind == Kind::kUnary ||
              env.exp(exp).kind == Kind::kBinary ||
              env.exp(exp).kind == Kind::kBinaryBranch ||
@@ -928,14 +936,15 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
   if (env.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
-    linalg::GenericOp op = env.linalgOp;
+    linalg::GenericOp op = env.op();
     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     auto map = op.getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      Optional<unsigned> sldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
-      if (sldx && env.isFilterLoop(*sldx)) {
+      Optional<unsigned> sldx =
+          env.merger().getLoopIdx(t.getOperandNumber(), d);
+      if (sldx && env.merger().isFilterLoop(*sldx)) {
         if (!env.getLoopIdxValue(*sldx))
           // The filter loops has not been constructed.
           return;
@@ -978,11 +987,11 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
 }
 
 /// Generates an expanded access pattern in innermost dimension.
-static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at,
-                         bool atStart) {
-  linalg::GenericOp op = env.linalgOp;
-  OpOperand *lhs = env.sparseOut;
-  if (!lhs || env.outerParNest != op.getRank(lhs) - 1 || at != env.outerParNest)
+static void genExpand(CodegenEnv &env, OpBuilder &builder, unsigned at,
+                      bool atStart) {
+  linalg::GenericOp op = env.op();
+  OpOperand *lhs = op.getDpsInitOperand(0);
+  if (!env.atExpandLevel(lhs, op.getRank(lhs), at))
     return; // not needed at this level
   assert(!env.isReduc());
   // Generate start or end of an expanded access pattern. Note that because
@@ -999,25 +1008,23 @@ static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at,
     Type t2 = MemRefType::get(dynShape, builder.getI1Type());
     Type t3 = MemRefType::get(dynShape, builder.getIndexType());
     Type t4 = builder.getIndexType();
-    auto res =
-        builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
-    assert(res.getNumResults() == 4);
-    assert(!env.expValues);
-    env.expValues = res.getResult(0);
-    env.expFilled = res.getResult(1);
-    env.expAdded = res.getResult(2);
-    env.expCount = res.getResult(3);
+    auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
+    assert(r.getNumResults() == 4);
+    env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
+                    r.getResult(3));
   } else {
-    assert(env.expValues);
     SmallVector<Value> indices;
-    for (unsigned i = 0; i < at; i++) {
-      assert(env.getLoopIV(i));
-      indices.push_back(env.getLoopIV(i));
-    }
-    env.insChain = builder.create<CompressOp>(loc, env.expValues, env.expFilled,
-                                              env.expAdded, env.expCount,
-                                              env.insChain, indices);
-    env.expValues = env.expFilled = env.expAdded = env.expCount = Value();
+    for (unsigned i = 0; i < at; i++)
+      indices.push_back(env.emitter()->getLoopIV(i));
+    Value values = env.getExpandValues();
+    Value filled = env.getExpandFilled();
+    Value added = env.getExpandAdded();
+    Value count = env.getExpandCount();
+    Value chain = env.getInsertionChain();
+    Value compress = builder.create<CompressOp>(loc, values, filled, added,
+                                                count, chain, indices);
+    env.updateInsertionChain(compress);
+    env.endExpand();
   }
 }
 
@@ -1026,13 +1033,13 @@ static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at,
 /// converted to a parallel operation depends on the requested strategy.
 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
   // Reject parallelization of sparse output.
-  if (env.sparseOut)
+  if (env.hasSparseOutput())
     return false;
   // Parallel loops on tensor expansion can cause data races.
-  if (env.expCount)
+  if (env.isExpand())
     return false;
   // Inspect strategy.
-  switch (env.options.parallelizationStrategy) {
+  switch (env.options().parallelizationStrategy) {
   case SparseParallelizationStrategy::kNone:
     return false;
   case SparseParallelizationStrategy::kDenseOuterLoop:
@@ -1052,15 +1059,15 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
                          bool isInner, unsigned idx, size_t tid, size_t dim,
                          ArrayRef<size_t> extraTids,
                          ArrayRef<size_t> extraDims) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse = isCompressedDLT(env.dimLevelType(tid, idx)) ||
-                  isSingletonDLT(env.dimLevelType(tid, idx));
+  bool isSparse =
+      isCompressedDLT(env.dlt(tid, idx)) || isSingletonDLT(env.dlt(tid, idx));
   bool isParallel = isParallelFor(env, isOuter, isSparse);
 
   Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
-    if (env.isFilterLoop(idx)) {
+    if (env.merger().isFilterLoop(idx)) {
       // extraTids/extraDims must be empty because filter loops only
       // corresponding to the one and only sparse tensor level.
       assert(isSparse && extraTids.empty() && extraDims.empty());
@@ -1069,10 +1076,10 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
       // Retrieves the affine expression for the filter loop.
       AffineExpr a =
           op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
-      return env.loopEmitter->enterFilterLoopOverTensorAtDim(builder, loc, tid,
-                                                             dim, a, reduc);
+      return env.emitter()->enterFilterLoopOverTensorAtDim(builder, loc, tid,
+                                                           dim, a, reduc);
     }
-    return env.loopEmitter->enterLoopOverTensorAtDim(
+    return env.emitter()->enterLoopOverTensorAtDim(
         builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
   });
   assert(loop);
@@ -1088,8 +1095,8 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx,
   Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
     // Construct the while-loop with a parameter for each
     // index.
-    return env.loopEmitter->enterCoIterationOverTensorsAtDims(
-        builder, env.linalgOp.getLoc(), condTids, condDims, needsUniv, reduc,
+    return env.emitter()->enterCoIterationOverTensorsAtDims(
+        builder, env.op().getLoc(), condTids, condDims, needsUniv, reduc,
         extraTids, extraDims);
   });
   assert(loop);
@@ -1104,10 +1111,10 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
                           ArrayRef<size_t> extraDims) {
   assert(condTids.size() == condDims.size());
   assert(extraTids.size() == extraDims.size());
-  unsigned idx = env.topSort[at];
+  unsigned idx = env.topSortAt(at);
   if (condTids.size() == 1) {
     bool isOuter = at == 0;
-    bool isInner = at == env.topSort.size() - 1;
+    bool isInner = at == env.topSortSize() - 1;
     return genFor(env, builder, isOuter, isInner, idx, condTids.front(),
                   condDims.front(), extraTids, extraDims);
   }
@@ -1119,9 +1126,9 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
                             bool needsUniv, BitVector &induction,
                             scf::WhileOp whileOp) {
-  Location loc = env.linalgOp.getLoc();
+  Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
-  if (env.isReduc() || env.expValues || env.insChain) {
+  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
                builder.getInsertionBlock()->getParentOp())) {
       unsigned y = 0;
@@ -1130,13 +1137,13 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
         yields.push_back(env.getReduc());
         env.updateReduc(ifOp.getResult(y++));
       }
-      if (env.expValues) {
-        yields.push_back(env.expCount);
-        env.expCount = ifOp->getResult(y++);
+      if (env.isExpand()) {
+        yields.push_back(env.getExpandCount());
+        env.updateExpandCount(ifOp->getResult(y++));
       }
-      if (env.insChain) {
-        yields.push_back(env.insChain);
-        env.insChain = ifOp->getResult(y++);
+      if (env.getInsertionChain()) {
+        yields.push_back(env.getInsertionChain());
+        env.updateInsertionChain(ifOp->getResult(y++));
       }
       assert(y == yields.size());
       builder.create<scf::YieldOp>(loc, yields);
@@ -1149,35 +1156,34 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
 /// Generates a single if-statement within a while-loop.
 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx,
                        BitVector &conditions) {
-  Location loc = env.linalgOp.getLoc();
+  Location loc = env.op().getLoc();
   SmallVector<Type> types;
   Value cond;
   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
     if (!conditions[b])
       continue;
-    unsigned tensor = env.merger.tensor(b);
-    assert(idx == env.merger.index(b));
+    unsigned tensor = env.merger().tensor(b);
+    assert(idx == env.merger().index(b));
     Value clause;
-    if (isCompressedDLT(env.dimLevelType(b)) ||
-        isSingletonDLT(env.dimLevelType(b))) {
-      auto dim = *env.merger.getDimNum(tensor, idx);
-      Value op1 = env.loopEmitter->getCoord()[tensor][dim];
+    if (isCompressedDLT(env.dlt(b)) || isSingletonDLT(env.dlt(b))) {
+      auto dim = *env.merger().getDimNum(tensor, idx);
+      Value op1 = env.emitter()->getCoord()[tensor][dim];
       Value op2 = env.getLoopIdxValue(idx);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
                                              op2);
     } else {
-      assert(isDenseDLT(env.dimLevelType(b)) ||
-             isUndefDLT(env.dimLevelType(b)));
+      assert(isDenseDLT(env.merger().getDimLevelType(b)) ||
+             isUndefDLT(env.merger().getDimLevelType(b)));
       clause = constantI1(builder, loc, true);
     }
     cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
   }
   if (env.isReduc())
     types.push_back(env.getReduc().getType());
-  if (env.expValues)
+  if (env.isExpand())
     types.push_back(builder.getIndexType());
-  if (env.insChain)
-    types.push_back(env.insChain.getType());
+  if (env.getInsertionChain())
+    types.push_back(env.getInsertionChain().getType());
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   return ifOp;
@@ -1192,16 +1198,16 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
     operands.push_back(env.getReduc());
     env.updateReduc(redInput);
   }
-  if (env.expValues) {
-    operands.push_back(env.expCount);
-    env.expCount = cntInput;
+  if (env.isExpand()) {
+    operands.push_back(env.getExpandCount());
+    env.updateExpandCount(cntInput);
   }
-  if (env.insChain) {
-    operands.push_back(env.insChain);
-    env.insChain = insInput;
+  if (env.getInsertionChain()) {
+    operands.push_back(env.getInsertionChain());
+    env.updateInsertionChain(insInput);
   }
   if (!operands.empty())
-    builder.create<scf::YieldOp>(env.linalgOp.getLoc(), operands);
+    builder.create<scf::YieldOp>(env.op().getLoc(), operands);
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
 }
 
@@ -1218,17 +1224,17 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
   // Emit invariants at this loop sequence level.
   genInvariants(env, builder, exp, ldx, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpansion(env, builder, at, /*atStart=*/true);
+  genExpand(env, builder, at, /*atStart=*/true);
   // Emit further intitialization at this loop sequence level.
   unsigned l0 = env.set(lts)[0];
   bool needsUniv = false;
 
   SmallVector<size_t> tids;
   SmallVector<size_t> dims;
-  env.merger.foreachTidDimPairInBits(
+  env.merger().foreachTidDimPairInBits(
       env.lat(l0).bits,
       [&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
-        assert(env.merger.index(b) == idx);
+        assert(env.merger().index(b) == idx);
         if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
           needsUniv = true;
         } else {
@@ -1238,7 +1244,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
         }
       });
 
-  env.loopEmitter->enterNewLoopSeq(builder, env.linalgOp.getLoc(), tids, dims);
+  env.emitter()->enterNewLoopSeq(builder, env.op().getLoc(), tids, dims);
 
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
@@ -1246,7 +1252,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
     unsigned lsize = env.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = env.set(lts)[i];
-      if (!env.merger.hasAnySparse(env.lat(li).simple))
+      if (!env.merger().hasAnySparse(env.lat(li).simple))
         return true;
     }
   }
@@ -1257,7 +1263,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                              OpBuilder &builder, unsigned tid,
                                              unsigned lvl) {
   // TODO: Handle affine expression on output tensor.
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   assert(tid < op.getNumDpsInputs());
   OpOperand *input = op.getDpsInputOperands()[tid];
   ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
@@ -1267,7 +1273,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
       AffineExpr affine = affines[toOrigDim(enc, i)];
       if (isDenseDLT(getDimLevelType(enc, i)) &&
           affine.isa<AffineConstantExpr>())
-        env.loopEmitter->genDenseAffineAddressAtCurLevel(
+        env.emitter()->genDenseAffineAddressAtCurLevel(
             builder, op.getLoc(), input->getOperandNumber(), i, affine);
       else
         return; // break on first non-dense non-constant level
@@ -1281,7 +1287,7 @@ static void genInitConstantDenseAddress(CodegenEnv &env,
   // starting from the first level as they do not depend on any thing.
   // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
   // levels can be determined before loops.
-  for (unsigned tid = 0, e = env.linalgOp.getNumDpsInputs(); tid < e; tid++)
+  for (unsigned tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }
 
@@ -1295,9 +1301,9 @@ static void translateBitsToTidDimPairs(
   const BitVector &simple = env.lat(li).simple;
 
   // Converts bits to array + dim pair
-  env.merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
-                                                   Optional<unsigned> dim,
-                                                   DimLevelType dlt) {
+  env.merger().foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+                                                     Optional<unsigned> dim,
+                                                     DimLevelType dlt) {
     if (simple.test(b)) {
       if (isUndefDLT(dlt)) {
         // An undefined dlt in the lattices, we probably mean to iterate based
@@ -1306,8 +1312,8 @@ static void translateBitsToTidDimPairs(
         // output tensor).
         // out[i][j] = invariant; or a broadcast
         // out[i][j] = in[i] (j is undef for input)
-        tid = env.merger.getOutTensorID();
-        dim = env.merger.getDimNum(tid, idx);
+        tid = env.merger().getOutTensorID();
+        dim = env.merger().getDimNum(tid, idx);
         // Skips invalid dim (e.g., when this is a zero ranked tensor).
         if (!dim)
           return;
@@ -1320,7 +1326,7 @@ static void translateBitsToTidDimPairs(
       extraDims.push_back(*dim);
     } else {
       assert(isUndefDLT(dlt));
-      linalg::GenericOp op = env.linalgOp;
+      linalg::GenericOp op = env.op();
       if (tid >= op.getNumDpsInputs())
         // We only handle affine expression on input tensors (for now).
         return;
@@ -1361,12 +1367,12 @@ static void translateBitsToTidDimPairs(
     }
   });
 
-  if (isDenseDLT(env.dimLevelType(env.merger.getOutTensorID(), idx))) {
+  if (isDenseDLT(env.dlt(env.merger().getOutTensorID(), idx))) {
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
-    // needed for linearized codegen.
-    auto dim = *env.merger.getDimNum(env.merger.getOutTensorID(), idx);
-    extraTids.push_back(env.merger.getOutTensorID());
+    // needed for linearized codegen.
+    auto dim = *env.merger().getDimNum(env.merger().getOutTensorID(), idx);
+    extraTids.push_back(env.merger().getOutTensorID());
     extraDims.push_back(dim);
   }
 }
@@ -1384,7 +1390,7 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
   // level.
   SmallVector<size_t> affineTids, affineDims;
   SmallVector<AffineExpr> affines;
-  translateBitsToTidDimPairs(env, li, env.topSort[at], condTids, condDims,
+  translateBitsToTidDimPairs(env, li, env.topSortAt(at), condTids, condDims,
                              extraTids, extraDims, affineTids, affineDims,
                              affines);
 
@@ -1392,8 +1398,8 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
   Operation *loop = genLoop(env, builder, at, needsUniv, condTids, condDims,
                             extraTids, extraDims);
   for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
-    env.loopEmitter->genDenseAffineAddressAtCurLevel(
-        builder, env.linalgOp.getLoc(), tid, dim, exp);
+    env.emitter()->genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(),
+                                                   tid, dim, exp);
   }
 
   // Until now, we have entered every <tid, dim> pair in {cond, extra,
@@ -1402,7 +1408,7 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
   auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
   auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
   for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
-    if (tid != env.merger.getOutTensorID())
+    if (tid != env.merger().getOutTensorID())
       genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);
   }
 
@@ -1420,7 +1426,7 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
   }
 
   genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
-    env.loopEmitter->exitCurrentLoop(rewriter, env.linalgOp.getLoc(), reduc);
+    env.emitter()->exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
     return std::nullopt;
   });
 
@@ -1431,11 +1437,11 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                        unsigned at, unsigned idx, unsigned ldx) {
   assert(env.getLoopIdxValue(idx) == nullptr);
-  env.loopEmitter->exitCurrentLoopSeq();
+  env.emitter()->exitCurrentLoopSeq();
   // Unmark bookkeeping of invariants and loop index.
   genInvariants(env, builder, exp, ldx, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpansion(env, builder, at, /*atStart=*/false);
+  genExpand(env, builder, at, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -1444,17 +1450,17 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
                     unsigned at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
-  if (at == env.topSort.size()) {
-    unsigned ldx = env.topSort[at - 1];
+  if (at == env.topSortSize()) {
+    unsigned ldx = env.topSortAt(at - 1);
     Value rhs = genExp(env, rewriter, exp, ldx);
     genTensorStore(env, rewriter, exp, rhs);
     return;
   }
 
   // Construct iteration lattices for current loop index, with L0 at top.
-  unsigned idx = env.topSort[at];
-  unsigned ldx = at == 0 ? -1u : env.topSort[at - 1];
-  unsigned lts = env.merger.optimizeSet(env.merger.buildLattices(exp, idx));
+  unsigned idx = env.topSortAt(at);
+  unsigned ldx = at == 0 ? -1u : env.topSortAt(at - 1);
+  unsigned lts = env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
 
   // TODO: sort
   // TODO: dedup
@@ -1472,13 +1478,13 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
     Value redInput = env.getReduc();
-    Value cntInput = env.expCount;
-    Value insInput = env.insChain;
+    Value cntInput = env.getExpandCount();
+    Value insInput = env.getInsertionChain();
     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
       unsigned lj = env.set(lts)[j];
       unsigned ej = env.lat(lj).exp;
-      if (li == lj || env.merger.latGT(li, lj)) {
+      if (li == lj || env.merger().latGT(li, lj)) {
         // Recurse into body of each branch.
         if (isWhile) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
@@ -1500,7 +1506,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
 
 /// Converts the result computed by the sparse kernel into the required form.
 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   OpOperand *lhs = op.getDpsInitOperand(0);
   Value tensor = lhs->get();
   Type resType = tensor.getType();
@@ -1508,14 +1514,16 @@ static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
     // The sparse tensor rematerializes from the original sparse tensor's
     // underlying sparse storage format. For an insertion chain, the
     // tensor materializes from the chain with 'hasInserts' enabled.
-    bool hasInserts = env.sparseOut == lhs;
-    if (hasInserts)
-      tensor = env.insChain;
+    bool hasInserts = false;
+    if (Value chain = env.getInsertionChain()) {
+      hasInserts = true;
+      tensor = chain;
+    }
     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
   } else {
     // To rematerialize an non-annotated tensor, simply load it
     // from the bufferized value.
-    Value val = env.getValBuffer().back(); // value array
+    Value val = env.emitter()->getValBuffer().back(); // value array
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
   }
 }
@@ -1550,8 +1558,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       return failure();
 
     // Builds the tensor expression for the Linalg operation in SSA form.
-    Optional<unsigned> optExp = env.merger.buildTensorExpFromLinalg(op);
-    if (!optExp.has_value())
+    Optional<unsigned> optExp = env.merger().buildTensorExpFromLinalg(op);
+    if (!optExp)
       return failure();
     unsigned exp = *optExp;
 
@@ -1562,6 +1570,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     // to resolve cycles by inserting a conversion.
     bool isAdmissible = false;
     bool hasCycle = true;
+    OpOperand *sparseOut = nullptr;
+    unsigned outerParNest = -1u;
     // An const list of all masks that we used for interation graph
     // computation. Must be ordered from more strict to less strict.
     const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
@@ -1569,7 +1579,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     for (auto mask : allMask)
       if (computeIterationGraph(env, mask)) {
         hasCycle = false;
-        if (isAdmissibleTensorExp(env, exp)) {
+        if (isAdmissibleTensorExp(env, exp, &sparseOut, &outerParNest)) {
           isAdmissible = true;
           break;
         }
@@ -1589,9 +1599,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     SparseTensorLoopEmitter lpe(
         tensors,
         StringAttr::get(op.getContext(), linalg::GenericOp::getOperationName()),
-        /*hasOutput=*/true, /*isSparseOut=*/env.sparseOut != nullptr,
-        env.topSort);
-    env.startEmit(&lpe);
+        /*hasOutput=*/true, /*isSparseOut=*/sparseOut != nullptr,
+        env.topSortRef());
+    env.startEmit(sparseOut, outerParNest, &lpe);
 
     // Recursively generates code if admissible.
     genBuffers(env, rewriter);
@@ -1607,7 +1617,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     // Compute topological sort while leaving out every
     // sparse input tensor in succession until an acylic
     // iteration graph results.
-    for (OpOperand *t : env.linalgOp.getDpsInputOperands()) {
+    for (OpOperand *t : env.op().getDpsInputOperands()) {
       unsigned tensor = t->getOperandNumber();
       Value tval = t->get();
       auto srcEnc = getSparseTensorEncoding(tval.getType());
@@ -1623,14 +1633,14 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       auto srcTp = tval.getType().cast<RankedTensorType>();
       auto dstEnc = SparseTensorEncodingAttr::get(
           getContext(), srcEnc.getDimLevelType(),
-          permute(env, env.linalgOp.getMatchingIndexingMap(t)), // new order
+          permute(env, env.op().getMatchingIndexingMap(t)), // new order
           srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
           srcEnc.getIndexBitWidth());
       auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                          srcTp.getElementType(), dstEnc);
       auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      env.linalgOp->setOperand(tensor, convert);
-      rewriter.setInsertionPointAfter(env.linalgOp);
+      env.op()->setOperand(tensor, convert);
+      rewriter.setInsertionPointAfter(env.op());
       rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
       return success();
     }


        


More information about the Mlir-commits mailing list