[Mlir-commits] [mlir] 1328bb6 - [mlir][sparse] extend loop emitter and optimize lattices with the awareness of slice based iteration

Peiming Liu llvmlistbot at llvm.org
Mon Mar 20 15:20:02 PDT 2023


Author: Peiming Liu
Date: 2023-03-20T22:19:57Z
New Revision: 1328bb6ef1645951606ee3e8fa6acbbff6b2438f

URL: https://github.com/llvm/llvm-project/commit/1328bb6ef1645951606ee3e8fa6acbbff6b2438f
DIFF: https://github.com/llvm/llvm-project/commit/1328bb6ef1645951606ee3e8fa6acbbff6b2438f.diff

LOG: [mlir][sparse] extend loop emitter and optimize lattices with the awareness of slice based iteration

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142929

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 0e6c2f1553f1c..4a83237fb1634 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -399,11 +399,17 @@ class Merger {
   /// to sparse level-type.
   bool hasAnySparse(const BitVector &bits) const;
 
+  /// Returns true if bits contains a dependent index reduction condition on
+  /// sparse levels.
+  bool hasSparseIdxReduction(const BitVector &bits) const;
+
   /// Gets the level-type of the `t`th tensor on `i`th loop.
   DimLevelType getDimLevelType(TensorId t, LoopId i) const {
     assert(t < numTensors && i < numLoops);
     return lvlTypes[t][i];
   }
+
+  /// Gets the level-type of the TensorLoopId.
   DimLevelType getDimLevelType(TensorLoopId b) const {
     return getDimLevelType(tensor(b), loop(b));
   }
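
For context on the TensorLoopId parameter used by both overloads: it packs a
(tensor, loop) pair into a single bit position, so that lattice bit vectors
can be indexed per pair. Below is a minimal standalone sketch of that
encoding; the `b = t + i * numTensors` layout and the MiniMerger type are
illustrative assumptions, not part of this commit.

  #include <cassert>

  using TensorId = unsigned;
  using LoopId = unsigned;
  using TensorLoopId = unsigned;

  // Sketch: pack (tensor, loop) into one index, the way the merger addresses
  // bits in a lattice point's BitVector.
  struct MiniMerger {
    unsigned numTensors;

    TensorLoopId makeTensorLoopId(TensorId t, LoopId i) const {
      assert(t < numTensors);
      return t + i * numTensors; // all tensors of loop 0, then loop 1, ...
    }
    TensorId tensor(TensorLoopId b) const { return b % numTensors; }
    LoopId loop(TensorLoopId b) const { return b / numTensors; }
  };

  int main() {
    MiniMerger m{/*numTensors=*/3};
    TensorLoopId b = m.makeTensorLoopId(/*t=*/2, /*i=*/1);
    assert(m.tensor(b) == 2 && m.loop(b) == 1);
    return 0;
  }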

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 8e4904ad3a592..f326d5b950a31 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -28,6 +28,23 @@ static bool isMaterializing(Value val) {
          val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
+/// Sorts the `target` array's elements according to the `order` array.
+static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+                                  ArrayRef<LoopId> order) {
+  std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
+    assert(l != r);
+    int idxL = -1, idxR = -1;
+    for (int i = 0, e = order.size(); i < e; i++) {
+      if (order[i] == l)
+        idxL = i;
+      if (order[i] == r)
+        idxR = i;
+    }
+    assert(idxL >= 0 && idxR >= 0);
+    return idxL < idxR;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation environment constructor and general methods
 //===----------------------------------------------------------------------===//
@@ -57,15 +74,42 @@ void CodegenEnv::startEmit() {
     insChain = sparseOut->get();
     latticeMerger.setHasSparseOut(true);
   }
+
+  // Sort each dependent loop array so that its entries appear in the same
+  // order as they do in topoOrder.
+  // TODO: since we only handle affine addition for slice-based codegen, and
+  // addition is associative, the order in which we evaluate the expression
+  // does not matter. However, to support multiplication, the order of the
+  // loop indices should match the evaluation order of the affine expression
+  // AST.
+
   // Initialize loop emitter.
-  SmallVector<Value> tensors;
-  for (OpOperand &t : linalgOp->getOpOperands())
+  SmallVector<Value> tensors; // input tensors passed to loop emitter
+  for (OpOperand &t : linalgOp->getOpOperands()) {
     tensors.push_back(t.get());
-  loopEmitter.initialize(tensors,
-                         StringAttr::get(linalgOp.getContext(),
-                                         linalg::GenericOp::getOperationName()),
-                         /*hasOutput=*/true,
-                         /*isSparseOut=*/sparseOut != nullptr, topSort);
+    Level rank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
+    for (Level lvl = 0; lvl < rank; lvl++) {
+      sortArrayBasedOnOrder(
+          latticeMerger.getDependentLoops(t.getOperandNumber(), lvl), topSort);
+    }
+  }
+
+  loopEmitter.initialize(
+      tensors,
+      StringAttr::get(linalgOp.getContext(),
+                      linalg::GenericOp::getOperationName()),
+      /*hasOutput=*/true,
+      /*isSparseOut=*/sparseOut != nullptr, topSort,
+      // TODO: compute the map and pass it to loop emitter directly instead of
+      // passing in a callback.
+      [this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
+        // Translates a list of loop indices to a list of [tid, dim] pairs.
+        std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
+        std::vector<std::pair<TensorId, Level>> ret;
+        ret.reserve(rLoops.size());
+        for (LoopId l : rLoops)
+          ret.emplace_back(this->merger().getLoopDefiningLvl(l));
+        return ret;
+      });
 }
 
 std::optional<Operation *> CodegenEnv::genLoopBoundary(
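
To make the behavior of the sortArrayBasedOnOrder helper concrete, here is a
small self-contained usage sketch with plain `unsigned` in place of LoopId.
The linear scan per comparison makes it quadratic overall, which is acceptable
for the handful of loops involved.

  #include <algorithm>
  #include <cassert>
  #include <vector>

  // Sort `target` so its elements appear in the same relative order as in
  // `order`; every element of `target` must occur in `order`.
  static void sortArrayBasedOnOrder(std::vector<unsigned> &target,
                                    const std::vector<unsigned> &order) {
    std::sort(target.begin(), target.end(), [&order](unsigned l, unsigned r) {
      int idxL = -1, idxR = -1;
      for (int i = 0, e = static_cast<int>(order.size()); i < e; i++) {
        if (order[i] == l)
          idxL = i;
        if (order[i] == r)
          idxR = i;
      }
      assert(idxL >= 0 && idxR >= 0);
      return idxL < idxR;
    });
  }

  int main() {
    // Loops {0, 2} depend on a level; the topological order runs 1, 2, 0.
    std::vector<unsigned> loops = {0, 2};
    const std::vector<unsigned> topo = {1, 2, 0};
    sortArrayBasedOnOrder(loops, topo);
    assert(loops == std::vector<unsigned>({2, 0}));
    return 0;
  }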

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 776d7f7f47ece..8c6a7bd6433db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -99,7 +99,6 @@ class CodegenEnv {
     topSort.reserve(capacity);
   }
 
-  ArrayRef<LoopId> getTopSort() const { return topSort; };
   ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
   ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
   ArrayRef<LoopId> getCurrentLoopStack() const;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index c3823c0f204d9..459a1b38e03de 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -208,12 +208,14 @@ Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
 }
 
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
-                         bool isSparseOut, ArrayRef<LoopId> topSort) {
-  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
+                         bool isSparseOut, ArrayRef<LoopId> topSort,
+                         DependentLvlGetter dimGetter) {
+  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, dimGetter);
 }
 
 void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
-                             bool isSparseOut, ArrayRef<LoopId> topSort) {
+                             bool isSparseOut, ArrayRef<LoopId> topSort,
+                             DependentLvlGetter dimGetter) {
   // First initialize the top-level type of the fields.
   this->loopTag = loopTag;
   this->hasOutput = hasOutput;
@@ -242,6 +244,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->loopStack.reserve(numLoops);
   this->loopSeqStack.reserve(numLoops);
 
+  this->dependentLvlMap.assign(
+      numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());
+
   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
     const Value t = tensors[tid];
@@ -283,6 +288,12 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     coordinatesBuffers[tid].assign(lvlRank, Value());
     sliceOffsets[tid].assign(lvlRank, Value());
     sliceStrides[tid].assign(lvlRank, Value());
+
+    dependentLvlMap[tid].assign(lvlRank,
+                                std::vector<std::pair<TensorId, Level>>());
+    if (dimGetter)
+      for (Level l = 0; l < lvlRank; l++)
+        dependentLvlMap[tid][l] = dimGetter(tid, l);
   }
 
   // Construct the inverse of the `topSort` from the sparsifier.
@@ -997,8 +1008,8 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   }
 }
 
-void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
-                                      MutableArrayRef<Value> reduc) {
+void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
+                                MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
   auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
   builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
@@ -1082,7 +1093,7 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
   assert(loopInfo.tids.size() == loopInfo.lvls.size());
   SmallVector<Value> red;
   if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
-    exitCoIterationLoop(rewriter, loc, reduc);
+    exitWhileLoop(rewriter, loc, reduc);
   } else {
     exitForLoop(rewriter, loc, reduc);
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 8e6c65fd96c92..8cfe00100eba8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -76,6 +76,14 @@ class LoopEmitter {
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
+  // Map from [tid, dim] to a list of dependent [tid, dim] used to resolve
+  // affine index expressions on sparse tensors.
+  // E.g., the affine index (d0 + d1) depends on the two [tid, dim] pairs that
+  // define d0 and d1 (for affine expression reduction).
+  // An empty list means that there is no affine expression on the given
+  // [tid, dim].
+  using DependentLvlGetter =
+      function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
 
   LoopEmitter() = default;
 
@@ -89,11 +97,13 @@ class LoopEmitter {
   /// to `LoopId`.
   void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
                   bool hasOutput = false, bool isSparseOut = false,
-                  ArrayRef<LoopId> topSort = {});
+                  ArrayRef<LoopId> topSort = {},
+                  DependentLvlGetter getter = nullptr);
 
   explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
                        bool hasOutput = false, bool isSparseOut = false,
-                       ArrayRef<LoopId> topSort = {});
+                       ArrayRef<LoopId> topSort = {},
+                       DependentLvlGetter getter = nullptr);
 
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
@@ -295,8 +305,8 @@ class LoopEmitter {
                    MutableArrayRef<Value> reduc);
 
   /// Exits a while loop, returns the reduction results.
-  void exitCoIterationLoop(OpBuilder &builder, Location loc,
-                           MutableArrayRef<Value> reduc);
+  void exitWhileLoop(OpBuilder &builder, Location loc,
+                     MutableArrayRef<Value> reduc);
 
   //
   // View-based-reshape methods.
@@ -380,6 +390,15 @@ class LoopEmitter {
   std::vector<std::vector<Value>> sliceOffsets;
   std::vector<std::vector<Value>> sliceStrides;
 
+  // Map from [tid, level] to a list of dependent [tid, level].
+  // See the comments for `DependentLvlGetter`.
+  std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
+      dependentLvlMap;
+
+  //
+  // View-based reshape related fields and methods.
+  //
+
   /// Collapse Reassociations related to a specific tensor
   // TODO: support expand.
   std::vector<ArrayAttr> collapseReassoc;
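
To make the DependentLvlGetter contract concrete: for an access such as
in[d0 + d1], the getter reports which (tensor, level) pairs define d0 and d1.
A minimal sketch follows; the concrete tensor/level IDs are hypothetical and
chosen only for illustration.

  #include <cassert>
  #include <utility>
  #include <vector>

  using TensorId = unsigned;
  using Level = unsigned;
  using LvlList = std::vector<std::pair<TensorId, Level>>;

  // Sketch: tensor 2, level 0 is indexed by (d0 + d1), where d0 is defined by
  // (tensor 0, level 0) and d1 by (tensor 1, level 0). Every other (tid, lvl)
  // pair uses a trivial index and returns an empty list.
  static LvlList dimGetter(TensorId tid, Level lvl) {
    if (tid == 2 && lvl == 0)
      return {{0, 0}, {1, 0}};
    return {};
  }

  int main() {
    assert(dimGetter(2, 0).size() == 2); // two dependent loops for d0 + d1
    assert(dimGetter(0, 0).empty());     // trivial index: no dependences
    return 0;
  }

During initialize, the loop emitter simply caches these lists in
dependentLvlMap[tid][lvl], as shown in the LoopEmitter.cpp hunk above.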

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d7ce2b7f63f5c..f189b14c60c7e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -593,23 +593,6 @@ static void tryRelaxAffineConstraints(linalg::GenericOp op,
   }
 }
 
-/// Makes target array's elements appear in the same order as the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
-                                  ArrayRef<LoopId> order) {
-  std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
-    assert(l != r);
-    int idxL = -1, idxR = -1;
-    for (int i = 0, e = order.size(); i < e; i++) {
-      if (order[i] == l)
-        idxL = i;
-      if (order[i] == r)
-        idxR = i;
-    }
-    assert(idxL >= 0 && idxR >= 0);
-    return idxL < idxR;
-  });
-}
-
 static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
                                           OpOperand *skip, SortMask mask,
                                           std::vector<std::vector<bool>> &adjM,
@@ -1484,9 +1467,10 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   SmallVector<Level> lvls;
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
-                                           DimLevelType dlt, bool) {
+                                           DimLevelType dlt, bool isIdxReduc) {
     assert(env.merger().loop(b) == idx);
-    if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+    // FIXME: Dense index reduction can reuse the universal index as well.
+    if (!isIdxReduc && (isDenseDLT(dlt) || isUndefDLT(dlt))) {
       needsUniv = true;
     } else {
       // sparse/singleton levels.
@@ -1503,7 +1487,8 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
     unsigned lsize = env.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       const LatPointId li = env.set(lts)[i];
-      if (!env.merger().hasAnySparse(env.lat(li).simple))
+      if (!env.merger().hasAnySparse(env.lat(li).simple) &&
+          !env.merger().hasSparseIdxReduction(env.lat(li).simple))
         return true;
     }
   }
@@ -1557,75 +1542,82 @@ static bool translateBitsToTidLvlPairs(
 
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
-                                                std::optional<Level> lvl,
-                                                DimLevelType dlt, bool) {
-    if (simple.test(b)) {
-      if (isUndefDLT(dlt)) {
-        // An undefined dlt in the lattices, we probably mean to
-        // iterate based on the level of output tensor.  E.g., this
-        // could be a synthetic tensor (for invariants and sparse
-        // output tensor).
-        // out[i][j] = invariant; or a broadcast
-        // out[i][j] = in[i] (j is undef for input)
-        tid = outTid;
-        lvl = outLvl;
-        // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-        if (!lvl)
-          return;
-      }
-      hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
-      tids.push_back(tid);
-      lvls.push_back(*lvl);
-      numloopCond++;
-    } else if (isDenseDLT(dlt)) {
-      tids.push_back(tid);
-      lvls.push_back(*lvl);
-    } else {
-      assert(isUndefDLT(dlt));
-      linalg::GenericOp op = env.op();
-      if (tid >= op.getNumDpsInputs())
-        // We only handle affine expression on input tensors (for now).
-        return;
-      OpOperand *operand = &op->getOpOperand(tid);
-      const auto stt = getSparseTensorType(operand->get());
-      // Non-annotated dense tensors requires no special handling.
-      if (!stt.hasEncoding())
-        return;
-
-      ArrayRef<AffineExpr> affines =
-          op.getMatchingIndexingMap(operand).getResults();
-      const Level lvlRank = stt.getLvlRank();
-      assert(affines.size() == static_cast<size_t>(lvlRank));
-      for (Level l = 0; l < lvlRank; l++) {
-        // FIXME: `toOrigDim` is deprecated.
-        AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
-        // Skip simple affine expression and non-dense levels (which
-        // have their own filter loop).
-        if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
-          continue;
 
-        // Constant affine expression are handled in genLoop
-        if (!exp.isa<AffineConstantExpr>()) {
-          bool isAtLoop = false;
-          if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
-            // If the compound affine is invariant and we are right at the
-            // level. We need to generate the address according to the
-            // affine expression. This is also the best place we can do it
-            // to avoid putting it inside inner loops.
-            // NOTE: It assumes that the levels of the input tensor are
-            // initialized in order (and it is also currently guaranteed by
-            // computeIterationGraph), another more admissible approach
-            // might be accepting out-of-order access between consecutive
-            // dense levels.
-            affineTids.push_back(tid);
-            affineLvls.push_back(l);
-            exps.push_back(exp);
+  env.merger().foreachTensorLoopId(
+      li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+                   DimLevelType dlt, bool isIdxReduc) {
+        if (simple.test(b)) {
+          if (isIdxReduc) {
+            tids.push_back(tid);
+            lvls.push_back(*lvl);
+            numloopCond++;
+            return;
+          }
+          if (isUndefDLT(dlt)) {
+            // An undefined dlt in the lattices means that we probably want
+            // to iterate based on the level of the output tensor. E.g., this
+            // could be a synthetic tensor (for invariants and sparse
+            // output tensor).
+            // out[i][j] = invariant; or a broadcast
+            // out[i][j] = in[i] (j is undef for input)
+            tid = outTid;
+            lvl = outLvl;
+            // Skips invalid lvl (e.g., when this is a zero ranked tensor).
+            if (!lvl)
+              return;
+          }
+          hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
+          tids.push_back(tid);
+          lvls.push_back(*lvl);
+          numloopCond++;
+        } else if (isDenseDLT(dlt)) {
+          tids.push_back(tid);
+          lvls.push_back(*lvl);
+        } else {
+          assert(isUndefDLT(dlt));
+          linalg::GenericOp op = env.op();
+          if (tid >= op.getNumDpsInputs())
+            // We only handle affine expression on input tensors (for now).
+            return;
+          OpOperand *operand = &op->getOpOperand(tid);
+          const auto stt = getSparseTensorType(operand->get());
+          // Non-annotated dense tensors require no special handling.
+          if (!stt.hasEncoding())
+            return;
+
+          ArrayRef<AffineExpr> affines =
+              op.getMatchingIndexingMap(operand).getResults();
+          const Level lvlRank = stt.getLvlRank();
+          assert(affines.size() == static_cast<size_t>(lvlRank));
+          for (Level l = 0; l < lvlRank; l++) {
+            // FIXME: `toOrigDim` is deprecated.
+            AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
+            // Skip simple affine expressions and non-dense levels (which
+            // have their own filter loops).
+            if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
+              continue;
+
+            // Constant affine expressions are handled in genLoop.
+            if (!exp.isa<AffineConstantExpr>()) {
+              bool isAtLoop = false;
+              if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
+                // If the compound affine expression is invariant and we are
+                // right at the level, we need to generate the address
+                // according to the affine expression. This is also the best
+                // place to do it, to avoid putting it inside inner loops.
+                // NOTE: It assumes that the levels of the input tensor are
+                // initialized in order (and it is also currently guaranteed by
+                // computeIterationGraph), another more admissible approach
+                // might be accepting out-of-order access between consecutive
+                // dense levels.
+                affineTids.push_back(tid);
+                affineLvls.push_back(l);
+                exps.push_back(exp);
+              }
+            }
           }
         }
-      }
-    }
-  });
+      });
 
   if (isDenseDLT(env.dlt(outTid, ldx))) {
     // Note that we generate dense indices of the output tensor
@@ -1642,8 +1634,9 @@ static bool translateBitsToTidLvlPairs(
 }
 
 /// Starts a single loop in current sequence.
-static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
-                            LatPointId li, bool needsUniv) {
+static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
+                                              OpBuilder &builder, unsigned at,
+                                              unsigned li, bool needsUniv) {
   // The set of tensors + lvls to generate loops on
   SmallVector<TensorId> tids, affineTids;
   SmallVector<Level> lvls, affineLvls;
@@ -1651,11 +1644,12 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
   // becomes invariant and the address shall now be generated at the current
   // level.
   SmallVector<AffineExpr> affines;
-  bool isFor = translateBitsToTidLvlPairs(
+  bool isSingleCond = translateBitsToTidLvlPairs(
       env, li, env.topSortAt(at), tids, lvls, affineTids, affineLvls, affines);
 
   // Emit the for/while-loop control.
-  Operation *loop = genLoop(env, builder, at, needsUniv, tids, lvls, isFor);
+  Operation *loop =
+      genLoop(env, builder, at, needsUniv, tids, lvls, isSingleCond);
   Location loc = env.op().getLoc();
   for (auto [tid, lvl, exp] : llvm::zip(affineTids, affineLvls, affines)) {
     env.emitter().genDenseAffineAddress(builder, loc, tid, lvl, exp);
@@ -1671,7 +1665,7 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
       genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
   }
 
-  return loop;
+  return std::make_pair(loop, isSingleCond);
 }
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
@@ -1734,20 +1728,19 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
   for (unsigned i = 0; i < lsize; i++) {
     // Start a loop.
     const LatPointId li = env.set(lts)[i];
-    Operation *loop = startLoop(env, rewriter, at, li, needsUniv);
+    auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
     Value redInput = env.getReduc();
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
-    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
       const LatPointId lj = env.set(lts)[j];
       const ExprId ej = env.lat(lj).exp;
       if (li == lj || env.merger().latGT(li, lj)) {
         // Recurse into body of each branch.
-        if (isWhile) {
+        if (!isSingleCond) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
           genStmt(env, rewriter, ej, at + 1);
           endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
@@ -1866,18 +1859,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (!isAdmissible)
       return failure(); // inadmissible expression, reject
 
-    for (OpOperand &t : env.op()->getOpOperands()) {
-      Level rank = env.op().getMatchingIndexingMap(&t).getNumResults();
-      for (Level lvl = 0; lvl < rank; lvl++) {
-        sortArrayBasedOnOrder(
-            env.merger().getDependentLoops(t.getOperandNumber(), lvl),
-            env.getTopSort());
-      }
-    }
-
     // Recursively generates code if admissible.
     env.startEmit();
     genBuffers(env, rewriter);
+    // TODO: Constant affine expressions should be handled differently when
+    // using slice-based codegen; it does not matter now because we already
+    // reject constant expressions at an earlier stage.
     genInitConstantDenseAddress(env, rewriter);
     genStmt(env, rewriter, env.getExprId(), 0);
     genResult(env, rewriter);
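
The control-flow change in genStmt above keys the per-lattice scf.if guards on
isSingleCond (now returned from startLoop) rather than on whether the loop is
an scf.while. A minimal sketch of that threading pattern, with hypothetical
stand-in types in place of the real codegen environment:

  #include <cstdio>
  #include <utility>

  // Stand-in for the emitted loop operation.
  struct LoopOp {
    const char *kind;
  };

  // Sketch: startLoop reports both the loop and whether it covers a single
  // condition; genStmt guards lattice bodies only when it does not.
  static std::pair<LoopOp, bool> startLoop(bool isSingleCond) {
    return {LoopOp{isSingleCond ? "scf.for" : "scf.while"}, isSingleCond};
  }

  int main() {
    auto [loop, isSingleCond] = startLoop(/*isSingleCond=*/false);
    if (!isSingleCond)
      std::printf("co-iterating %s: wrap each lattice body in scf.if\n",
                  loop.kind);
    return 0;
  }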

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 7f4400188cf14..40db5411132b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -362,7 +362,8 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
   }
 
   BitVector simple(latPoints[p0].bits);
-  bool reset = isSingleton && hasAnySparse(simple);
+  bool reset =
+      isSingleton && (hasAnySparse(simple) || hasSparseIdxReduction(simple));
   const TensorLoopId be = simple.size();
   TensorLoopId offset = 0; // relative to the end
   if (!reset)
@@ -379,7 +380,9 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
   // keep the rightmost bit (which could possibly be a synthetic tensor).
   for (TensorLoopId b = be - 1 - offset, i = 0; i < be;
        b = b == 0 ? be - 1 : b - 1, i++) {
-    if (simple[b]) {
+    // FIXME: better name? A slice on a dense level has the locate property
+    // as well. Handle it correctly!
+    if (simple[b] && !isLvlWithNonTrivialIdxExp(b)) {
       const auto dlt = getDimLevelType(b);
       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
         if (reset)
@@ -407,7 +410,7 @@ bool Merger::latGT(LatPointId i, LatPointId j) const {
 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
   BitVector tmp(latPoints[j].bits);
   tmp ^= latPoints[i].bits;
-  return !hasAnySparse(tmp);
+  return !hasAnySparse(tmp) && !hasSparseIdxReduction(tmp);
 }
 
 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
@@ -555,6 +558,14 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
   return false;
 }
 
+bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
+  // TODO: return false on dense levels.
+  for (unsigned b = 0, be = bits.size(); b < be; b++)
+    if (bits[b] && isLvlWithNonTrivialIdxExp(b))
+      return true;
+  return false;
+}
+
 #ifndef NDEBUG
 
 //===----------------------------------------------------------------------===//
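
A toy illustration of the XOR test that onlyDenseDiff extends above: compute
the bits on which two lattice points differ, then reject the pair if any
differing bit is sparse or carries a sparse index reduction. Plain std::bitset
stands in for llvm::BitVector; the bit patterns are made up for illustration.

  #include <bitset>
  #include <cassert>

  int main() {
    // One bit per TensorLoopId; bit strings read right-to-left (bit 0 last).
    std::bitset<4> li("1011");               // lattice point i: bits {0, 1, 3}
    std::bitset<4> lj("1001");               // lattice point j: bits {0, 3}
    std::bitset<4> sparseOrIdxReduc("0100"); // bit 2 is the only sparse bit

    std::bitset<4> diffBits = li ^ lj; // bits on which the two points differ
    bool onlyDense = (diffBits & sparseOrIdxReduc).none();
    assert(onlyDense); // only dense bit 1 differs, so the points can merge
    return 0;
  }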


        

