[llvm-branch-commits] [mlir] ff8815e - [mlir][sparse] code cleanup (remove topSort in CodegenEnv). (#72550)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 17 03:47:24 PST 2023


Author: Peiming Liu
Date: 2023-11-16T13:21:49-08:00
New Revision: ff8815e597597a319ffde9d18e708040d226bbae

URL: https://github.com/llvm/llvm-project/commit/ff8815e597597a319ffde9d18e708040d226bbae
DIFF: https://github.com/llvm/llvm-project/commit/ff8815e597597a319ffde9d18e708040d226bbae.diff

LOG: [mlir][sparse] code cleanup (remove topSort in CodegenEnv). (#72550)

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 3a02d5634586070..ad2f649e5babdf4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -28,21 +28,13 @@ static bool isMaterializing(Value val) {
          val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
-/// Makes target array's elements sorted according to the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopCoeffPair> &target,
-                                  ArrayRef<LoopId> order) {
+/// Sorts the dependent loops by `LoopId`, so that they are ordered in the same
+/// sequence in which the loops will be generated.
+static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
   std::sort(target.begin(), target.end(),
-            [&order](const LoopCoeffPair &l, const LoopCoeffPair &r) {
+            [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
               assert(std::addressof(l) == std::addressof(r) || l != r);
-              int idxL = -1, idxR = -1;
-              for (int i = 0, e = order.size(); i < e; i++) {
-                if (order[i] == l.first)
-                  idxL = i;
-                if (order[i] == r.first)
-                  idxR = i;
-              }
-              assert(idxL >= 0 && idxR >= 0);
-              return idxL < idxR;
+              return l.first < r.first;
             });
 }
 //===----------------------------------------------------------------------===//
@@ -54,14 +46,10 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                        unsigned numFilterLoops, unsigned maxRank)
     : linalgOp(linop), sparseOptions(opts),
       latticeMerger(numTensors, numLoops, numFilterLoops, maxRank),
-      loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
-      insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
+      loopEmitter(), sparseOut(nullptr), outerParNest(-1u), insChain(),
+      expValues(), expFilled(), expAdded(), expCount(), redVal(),
       redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
-      redValidLexInsert() {
-  // TODO: remove topSort, loops should be already sorted by previous pass.
-  for (unsigned l = 0; l < latticeMerger.getNumLoops(); l++)
-    topSort.push_back(l);
-}
+      redValidLexInsert() {}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.
@@ -97,7 +85,7 @@ void CodegenEnv::startEmit() {
     (void)enc;
     assert(!enc || lvlRank == enc.getLvlRank());
     for (Level lvl = 0; lvl < lvlRank; lvl++)
-      sortArrayBasedOnOrder(latticeMerger.getDependentLoops(tid, lvl), topSort);
+      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
   }
 
   loopEmitter.initialize(
@@ -105,7 +93,7 @@ void CodegenEnv::startEmit() {
       StringAttr::get(linalgOp.getContext(),
                       linalg::GenericOp::getOperationName()),
       /*hasOutput=*/true,
-      /*isSparseOut=*/sparseOut != nullptr, topSort,
+      /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
       // TODO: compute the map and pass it to loop emitter directly instead of
       // passing in a callback.
       /*dependentLvlGetter=*/
@@ -190,8 +178,7 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
   // needed.
   outerParNest = 0;
   const auto iteratorTypes = linalgOp.getIteratorTypesArray();
-  assert(topSortSize() == latticeMerger.getNumLoops());
-  for (const LoopId i : topSort) {
+  for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
     if (linalg::isReductionIterator(iteratorTypes[i]))
       break; // terminate at first reduction
     outerParNest++;
@@ -208,26 +195,8 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
 // Code generation environment topological sort methods
 //===----------------------------------------------------------------------===//
 
-ArrayRef<LoopId> CodegenEnv::getTopSortSlice(LoopOrd n, LoopOrd m) const {
-  return ArrayRef<LoopId>(topSort).slice(n, m);
-}
-
-ArrayRef<LoopId> CodegenEnv::getLoopStackUpTo(LoopOrd n) const {
-  return ArrayRef<LoopId>(topSort).take_front(n);
-}
-
-ArrayRef<LoopId> CodegenEnv::getCurrentLoopStack() const {
-  return getLoopStackUpTo(loopEmitter.getCurrentDepth());
-}
-
 Value CodegenEnv::getLoopVar(LoopId i) const {
-  // TODO: this class should store the inverse of `topSort` so that
-  // it can do this conversion directly, instead of searching through
-  // `topSort` every time.  (Or else, `LoopEmitter` should handle this.)
-  for (LoopOrd n = 0, numLoops = topSortSize(); n < numLoops; n++)
-    if (topSort[n] == i)
-      return loopEmitter.getLoopIV(n);
-  llvm_unreachable("invalid loop identifier");
+  return loopEmitter.getLoopIV(i);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index c0fc505d153a459..af783b9dca276e3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -83,6 +83,8 @@ class CodegenEnv {
   }
   DimLevelType dlt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
 
+  unsigned getLoopNum() const { return latticeMerger.getNumLoops(); }
+
   //
   // LoopEmitter delegates.
   //
@@ -107,6 +109,8 @@ class CodegenEnv {
     return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
   }
 
+  unsigned getLoopDepth() const { return loopEmitter.getCurrentDepth(); }
+
   //
   // Code generation environment verify functions.
   //
@@ -115,25 +119,6 @@ class CodegenEnv {
   /// It also sets the sparseOut if the output tensor is sparse.
   bool isAdmissibleTensorExp(ExprId e);
 
-  /// Whether the iteration graph is sorted in admissible topoOrder.
-  /// Sets outerParNest on success with sparse output
-  bool isAdmissibleTopoOrder();
-
-  //
-  // Topological delegate and sort methods.
-  //
-
-  LoopOrd topSortSize() const { return topSort.size(); }
-  LoopId topSortAt(LoopOrd n) const { return topSort.at(n); }
-  void topSortPushBack(LoopId i) { topSort.push_back(i); }
-  void topSortClear(size_t capacity = 0) {
-    topSort.clear();
-    topSort.reserve(capacity);
-  }
-
-  ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
-  ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
-  ArrayRef<LoopId> getCurrentLoopStack() const;
   /// Returns the induction-variable for the loop identified by the given
   /// `LoopId`.  This method handles application of the topological sort
   /// in order to convert the `LoopId` into the corresponding `LoopOrd`.
@@ -191,10 +176,6 @@ class CodegenEnv {
   // Loop emitter helper class.
   LoopEmitter loopEmitter;
 
-  // Topological sort.  This serves as a mapping from `LoopOrd` to `LoopId`
-  // (cf., `getLoopVar` and `topSortAt`).
-  std::vector<LoopId> topSort;
-
   // Sparse tensor as output. Implemented either through direct injective
   // insertion in lexicographic index order or through access pattern
   // expansion in the innermost loop nest (`expValues` through `expCount`).

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 81ce525a62d9f61..ba798f09c4d583b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -276,13 +276,13 @@ Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
 }
 
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
-                         bool isSparseOut, ArrayRef<LoopId> topSort,
+                         bool isSparseOut, unsigned numLoops,
                          DependentLvlGetter dimGetter) {
-  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, dimGetter);
+  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
 }
 
 void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
-                             bool isSparseOut, ArrayRef<LoopId> topSort,
+                             bool isSparseOut, unsigned numLoops,
                              DependentLvlGetter dimGetter) {
   // First initialize the top-level type of the fields.
   this->loopTag = loopTag;
@@ -308,10 +308,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->sliceOffsets.assign(numTensors, std::vector<Value>());
   this->sliceStrides.assign(numTensors, std::vector<Value>());
 
-  const LoopOrd numLoops = topSort.size();
   // These zeros will be overwritten below, but we need to initialize
   // them to something since we'll need random-access assignment.
-  this->loopIdToOrd.assign(numLoops, 0);
   this->loopStack.reserve(numLoops);
   this->loopSeqStack.reserve(numLoops);
 
@@ -387,13 +385,6 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
       }
     }
   }
-
-  // Construct the inverse of the `topSort` from the sparsifier.
-  // This is needed to map `AffineDimExpr`s back to the `LoopOrd`
-  // used in loop emitter.
-  // FIXME: This map should be maintained outside loop emitter.
-  for (LoopOrd n = 0; n < numLoops; n++)
-    loopIdToOrd[topSort[n]] = n;
 }
 
 void LoopEmitter::initializeLoopEmit(
@@ -611,8 +602,7 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
     // However, elsewhere we have been lead to expect that `loopIdToOrd`
     // should be indexed by `LoopId`...
     const auto loopId = cast<AffineDimExpr>(a).getPosition();
-    assert(loopId < loopIdToOrd.size());
-    return loopStack[loopIdToOrd[loopId]].iv;
+    return loopStack[loopId].iv;
   }
   case AffineExprKind::Add: {
     auto binOp = cast<AffineBinaryOpExpr>(a);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index c6518decbdee0a4..320b39765dea4a7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -113,12 +113,11 @@ class LoopEmitter {
   /// to `LoopId`.
   void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
                   bool hasOutput = false, bool isSparseOut = false,
-                  ArrayRef<LoopId> topSort = {},
-                  DependentLvlGetter getter = nullptr);
+                  unsigned numLoops = 0, DependentLvlGetter getter = nullptr);
 
   explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
                        bool hasOutput = false, bool isSparseOut = false,
-                       ArrayRef<LoopId> topSort = {},
+                       unsigned numLoops = 0,
                        DependentLvlGetter getter = nullptr);
 
   /// Starts a loop emitting session by generating all the buffers needed
@@ -751,11 +750,6 @@ class LoopEmitter {
   // TODO: maybe we should have a LoopSeqInfo
   std::vector<std::pair<Value, std::vector<std::tuple<TensorId, Level, bool>>>>
       loopSeqStack;
-
-  /// Maps `LoopId` (used by `AffineDimExpr`) to `LoopOrd` (in the `loopStack`).
-  /// TODO: We should probably use a callback function here to make it more
-  /// general.
-  std::vector<LoopOrd> loopIdToOrd;
 };
 
 } // namespace sparse_tensor

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd6f689a04acb53..ec96ce23be03f28 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -49,8 +49,8 @@ using namespace mlir::sparse_tensor;
 // (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
 
 /// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
-                              LoopId ldx, bool &isAtLoop) {
+static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
+                              bool &isAtLoop) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
@@ -59,19 +59,14 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
       // Must be invariant if we are at the given loop.
       return true;
     }
-    bool isInvariant = false;
-    for (LoopId l : loopStack) {
-      isInvariant = (l == i);
-      if (isInvariant)
-        break;
-    }
-    return isInvariant;
+    // The DimExpr is invariant if its loop has already been generated.
+    return i < loopDepth;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     auto binOp = cast<AffineBinaryOpExpr>(a);
-    return isInvariantAffine(binOp.getLHS(), loopStack, ldx, isAtLoop) &&
-           isInvariantAffine(binOp.getRHS(), loopStack, ldx, isAtLoop);
+    return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
+           isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
   }
   default: {
     assert(isa<AffineConstantExpr>(a));
@@ -80,12 +75,6 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
   }
 }
 
-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, LoopId ldx,
-                              bool &isAtLoop) {
-  return isInvariantAffine(a, env.getCurrentLoopStack(), ldx, isAtLoop);
-}
-
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
@@ -351,17 +340,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
       llvm::cast<linalg::LinalgOp>(op.getOperation())
           .createLoopRanges(builder, loc);
 
-  assert(loopRange.size() == env.merger().getStartingFilterLoopId());
-  SmallVector<Range, 4> sortedRange;
-  for (unsigned i = 0, e = env.topSortSize(); i < e; i++) {
-    LoopId ldx = env.topSortAt(i);
-    // FIXME: Gets rid of filter loops since we have a better algorithm to deal
-    // with affine index expression.
-    if (ldx < env.merger().getStartingFilterLoopId()) {
-      sortedRange.push_back(loopRange[ldx]);
-    }
-  }
-
   env.emitter().initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
@@ -396,15 +374,14 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
         }
         return init;
       },
-      [&sortedRange, &env](OpBuilder &b, Location loc, Level l) {
-        assert(l < env.topSortSize());
+      [&loopRange, &env](OpBuilder &b, Location loc, Level l) {
+        assert(l < env.getLoopNum());
         // FIXME: Remove filter loop since we have a better algorithm to
         // deal with affine index expression.
         if (l >= env.merger().getStartingFilterLoopId())
           return Value();
 
-        return mlir::getValueOrCreateConstantIndexOp(b, loc,
-                                                     sortedRange[l].size);
+        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
       });
 }
 
@@ -762,7 +739,7 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
           return;
         if (*sldx == ldx)
           isAtLoop = true;
-      } else if (!isInvariantAffine(env, a, ldx, isAtLoop))
+      } else if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
         return; // still in play
     }
     // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -890,29 +867,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
   return isParallelFor(env, isOuter, isSparse);
 }
 
-/// Generates a "filter loop" on the given tid level to locate a coordinate that
-/// is of the same value as evaluated by the affine expression in its matching
-/// indexing map.
-static Operation *genFilterLoop(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
-                                TensorLevel tidLvl) {
-  linalg::GenericOp op = env.op();
-  Location loc = op.getLoc();
-  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
-    assert(env.merger().isFilterLoop(ldx));
-    const auto [tid, lvl] = env.unpackTensorLevel(tidLvl);
-    // tids/lvls must only have one value because filter loops only
-    // corresponding to the one and only sparse tensor level.
-    OpOperand *t = &op->getOpOperand(tid);
-    auto enc = getSparseTensorEncoding(t->get().getType());
-    // Retrieves the affine expression for the filter loop.
-    // FIXME: `toOrigDim` is deprecated.
-    AffineExpr a = op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
-    return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid, lvl,
-                                                        a, reduc);
-  });
-  return loop;
-}
-
 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
 /// can either be a for loop or while loop depending on whether there is at most
 /// one sparse level in the list.
@@ -934,14 +888,8 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
-  const LoopId ldx = env.topSortAt(at);
-  if (env.merger().isFilterLoop(ldx)) {
-    assert(tidLvls.size() == 1);
-    return genFilterLoop(env, builder, ldx, tidLvls.front());
-  }
-
-  bool tryParallel = shouldTryParallize(env, ldx, at == 0, tidLvls);
-  return genCoIteration(env, builder, ldx, tidLvls, tryParallel, needsUniv);
+  bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
+  return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
 }
 
 /// Generates the induction structure for a while-loop.
@@ -1066,12 +1014,12 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 /// Starts a loop sequence at given level. Returns true if
 /// the universal loop index must be maintained at this level.
 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopOrd at, LoopId idx, LoopId ldx, LatSetId lts) {
+                         LoopOrd idx, LoopId ldx, LatSetId lts) {
   assert(!env.getLoopVar(idx));
   // Emit invariants at this loop sequence level.
   genInvariants(env, builder, exp, ldx, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, at, /*atStart=*/true);
+  genExpand(env, builder, idx, /*atStart=*/true);
   // Emit further intitialization at this loop sequence level.
   const LatPointId l0 = env.set(lts)[0];
   bool needsUniv = false;
@@ -1226,7 +1174,8 @@ static bool translateBitsToTidLvlPairs(
             // Constant affine expression are handled in genLoop
             if (!isa<AffineConstantExpr>(exp)) {
               bool isAtLoop = false;
-              if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
+              if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
+                  isAtLoop) {
                 // If the compound affine is invariant and we are right at the
                 // level. We need to generate the address according to the
                 // affine expression. This is also the best place we can do it
@@ -1273,8 +1222,8 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
   // becomes invariant and the address shall now be generated at the current
   // level.
   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
-  bool isSingleCond = translateBitsToTidLvlPairs(env, li, env.topSortAt(at),
-                                                 tidLvls, affineTidLvls);
+  bool isSingleCond =
+      translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
 
   // Emit the for/while-loop control.
   Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls);
@@ -1324,13 +1273,13 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
 
 /// Ends a loop sequence at given level.
 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
-                       unsigned at, unsigned idx, unsigned ldx) {
+                       unsigned idx, unsigned ldx) {
   assert(!env.getLoopVar(idx));
   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
   // Unmark bookkeeping of invariants and loop index.
   genInvariants(env, builder, exp, ldx, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpand(env, builder, at, /*atStart=*/false);
+  genExpand(env, builder, idx, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -1339,22 +1288,19 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopOrd at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
-  if (at == env.topSortSize()) {
-    const LoopId ldx = env.topSortAt(at - 1);
-    Value rhs = genExp(env, rewriter, exp, ldx);
+  if (at == env.getLoopNum()) {
+    Value rhs = genExp(env, rewriter, exp, at - 1);
     genTensorStore(env, rewriter, exp, rhs);
     return;
   }
 
   // Construct iteration lattices for current loop index, with L0 at top.
-  const LoopId idx = env.topSortAt(at);
-  const LoopId ldx = at == 0 ? ::mlir::sparse_tensor::detail::kInvalidId
-                             : env.topSortAt(at - 1);
+  const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
   const LatSetId lts =
-      env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
+      env.merger().optimizeSet(env.merger().buildLattices(exp, at));
 
   // Start a loop sequence.
-  bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);
+  bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
   //
@@ -1382,7 +1328,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
       if (li == lj || env.merger().latGT(li, lj)) {
         // Recurse into body of each branch.
         if (!isSingleCond) {
-          scf::IfOp ifOp = genIf(env, rewriter, idx, lj);
+          scf::IfOp ifOp = genIf(env, rewriter, at, lj);
           genStmt(env, rewriter, ej, at + 1);
           endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
         } else {
@@ -1392,11 +1338,11 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     }
 
     // End a loop.
-    needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv, isSingleCond);
+    needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
   }
 
   // End a loop sequence.
-  endLoopSeq(env, rewriter, exp, at, idx, ldx);
+  endLoopSeq(env, rewriter, exp, at, ldx);
 }
 
 /// Converts the result computed by the sparse kernel into the required form.


        


More information about the llvm-branch-commits mailing list