[Mlir-commits] [mlir] e015d38 - [mlir][sparse] Pass down constant coefficients of affine index expressions to LoopEmitter.

Peiming Liu llvmlistbot at llvm.org
Wed Aug 30 11:55:59 PDT 2023


Author: Peiming Liu
Date: 2023-08-30T18:44:50Z
New Revision: e015d385c913daae3ec9654b84104caf28940c77

URL: https://github.com/llvm/llvm-project/commit/e015d385c913daae3ec9654b84104caf28940c77
DIFF: https://github.com/llvm/llvm-project/commit/e015d385c913daae3ec9654b84104caf28940c77.diff

LOG: [mlir][sparse] Pass down constant coefficients of affine index expressions to LoopEmitter.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D158914

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 6293badb6df5c4..5e753800675728 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -56,6 +56,13 @@ using LatPointId = unsigned;
 /// for the corresponding `SmallVector<LatPointId>` object.
 using LatSetId = unsigned;
 
+/// A pair of level and its corresponding DimLevelType of a tensor.
+using LvlDLTPair = std::pair<Level, DimLevelType>;
+
+/// A pair of loop id and its coefficients. E.g., for affine expression in the
+/// affine map `2 * d0`, loop id = 0, coefficient = 2.
+using LoopCoeffPair = std::pair<LoopId, unsigned>;
+
 /// Tensor expression. Represents an MLIR expression in tensor index notation.
 struct TensorExp final {
   enum class Kind;
@@ -509,22 +516,22 @@ class Merger {
 
   /// Establishes the two-way map that i <-> <t, lvl, dlt>.
   void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl,
-                                   DimLevelType dlt) {
+                                   DimLevelType dlt, unsigned coefficient) {
     assert(isValidLoopId(i) && isValidLevel(t, lvl));
-    assert(!loopToDependencies[i][t].has_value()); // must be the first def
-    loopToDependencies[i][t] = std::make_pair(lvl, dlt);
-    levelToDependentLoop[t][lvl].push_back(i);
+    assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
+    loopToUnresolvedLvls[i][t] = std::make_pair(lvl, dlt);
+    levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
   }
 
   /// Whether the loop has dependent slice.
   bool hasDependentLvl(LoopId i, TensorId t) {
     assert(isValidTensorId(t) && isValidLoopId(i));
-    return loopToDependencies[i][t].has_value();
+    return loopToUnresolvedLvls[i][t].has_value();
   }
 
   /// Returns the list of loop indices which appear in the non-trivial index
   /// expression on t_l, e.g., A[i+j] => {i, j}
-  std::vector<LoopId> &getDependentLoops(TensorId t, Level lvl) {
+  std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) {
     assert(isValidLevel(t, lvl));
     return levelToDependentLoop[t][lvl];
   }
@@ -541,7 +548,7 @@ class Merger {
     const TensorId t = tensor(b);
     const LoopId i = loop(b);
     assert(isValidTensorId(t) && isValidLoopId(i));
-    return loopToDependencies[i][t].has_value();
+    return loopToUnresolvedLvls[i][t].has_value();
   }
 
   /// Checks whether the TensorLoopId represents a sparse tensor level contains
@@ -556,12 +563,12 @@ class Merger {
 
   Level getLoopDependentLevel(TensorLoopId b) const {
     assert(isLvlWithNonTrivialIdxExp(b));
-    return loopToDependencies[loop(b)][tensor(b)]->first;
+    return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
   }
 
   DimLevelType getLoopDependentLevelType(TensorLoopId b) const {
     assert(isLvlWithNonTrivialIdxExp(b));
-    return loopToDependencies[loop(b)][tensor(b)]->second;
+    return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
   }
 
   /// Convenience getters to immediately access the stored nodes.
@@ -715,13 +722,13 @@ class Merger {
   /// It is currently only set for non-trivial index expressions.
   /// E.g., A[i+j] => i and j will have dependencies {A0, dlt(A0)} to indicate
   /// that i and j are used in the non-trivial index expression on A0.
-  std::vector<std::vector<std::optional<std::pair<Level, DimLevelType>>>>
-      loopToDependencies;
+  std::vector<std::vector<std::optional<LvlDLTPair>>> loopToUnresolvedLvls;
 
   /// The inverse map of ldxToDependencies from tensor level -> dependent loop
-  /// E.g., A[i+j], we have A0 => {i, j}, to indicate that A0 uses both {i, j}
-  /// to compute its indices.
-  std::vector<std::vector<std::vector<LoopId>>> levelToDependentLoop;
+  /// E.g., A[2i+j], we have A0 => {(2, i), (1, j)}, to indicate that A0 uses
+  /// both {i, j} to compute its indices and the coefficients on the loop id are
+  /// 2 and 1 respectively.
+  std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
 
   /// Map from a loop to the [tid, lvl] pair that defines the loop boundary.
   std::vector<std::pair<TensorId, Level>> loopBounds;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 854b6f5587073b..924b0a0dac8113 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -29,16 +29,16 @@ static bool isMaterializing(Value val) {
 }
 
 /// Makes target array's elements sorted according to the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+static void sortArrayBasedOnOrder(std::vector<LoopCoeffPair> &target,
                                   ArrayRef<LoopId> order) {
   std::sort(target.begin(), target.end(),
-            [&order](const LoopId &l, const LoopId &r) {
+            [&order](const LoopCoeffPair &l, const LoopCoeffPair &r) {
               assert(std::addressof(l) == std::addressof(r) || l != r);
               int idxL = -1, idxR = -1;
               for (int i = 0, e = order.size(); i < e; i++) {
-                if (order[i] == l)
+                if (order[i] == l.first)
                   idxL = i;
-                if (order[i] == r)
+                if (order[i] == r.first)
                   idxR = i;
               }
               assert(idxL >= 0 && idxR >= 0);
@@ -104,13 +104,17 @@ void CodegenEnv::startEmit() {
       /*isSparseOut=*/sparseOut != nullptr, topSort,
       // TODO: compute the map and pass it to loop emitter directly instead of
       // passing in a callback.
-      [this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
-        // Translates from a list of loop index to a list of [tid, dim] pair.
-        std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
-        std::vector<std::pair<TensorId, Level>> ret;
+      /*dependentLvlGetter=*/
+      [this](TensorId t,
+             Level lvl) -> std::vector<std::pair<TensorLevel, unsigned>> {
+        // Translates from a list of loop indices to a list of [tid, lvl] pair.
+        std::vector<LoopCoeffPair> &rLoops = merger().getDependentLoops(t, lvl);
+        std::vector<std::pair<TensorLevel, unsigned>> ret;
         ret.reserve(rLoops.size());
-        for (LoopId l : rLoops)
-          ret.emplace_back(this->merger().getLoopDefiningLvl(l));
+        for (auto [loop, coeff] : rLoops) {
+          TensorLevel tl = makeTensorLevel(merger().getLoopDefiningLvl(loop));
+          ret.emplace_back(tl, coeff);
+        };
         return ret;
       });
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index dea9e740b8db64..c0fc505d153a45 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -96,6 +96,9 @@ class CodegenEnv {
            loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
     return loopEmitter.makeTensorLevel(t, l);
   }
+  TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const {
+    return makeTensorLevel(tlPair.first, tlPair.second);
+  }
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
     return loopEmitter.unpackTensorLevel(tl);
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 06db5b0ab78e35..441f29dedcdafb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -264,14 +264,11 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
 
 Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
                                 Level lvl) {
-  Value crd = C_IDX(0);
   // A load on the coordinates array yields the coordinate.
   const Value mem = coordinatesBuffers[tid][lvl];
   /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
   const Value pos = posits[tid][lvl];
-  const Value off = genIndexLoad(builder, loc, mem, pos);
-  // Linearized the coordinates within the same collapse reassociation.
-  crd = ADDI(crd, off);
+  const Value crd = genIndexLoad(builder, loc, mem, pos);
   return crd;
 }
 
@@ -317,9 +314,10 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
 
   // Index-reduction related fields.
   this->dependentLvlMap.assign(
-      numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());
+      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
   this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
-  this->sliceSizes.assign(numTensors, std::vector<std::vector<Value>>());
+  this->sliceMeta.assign(
+      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
   this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
   this->levelReducedDep.assign(numTensors, std::vector<unsigned>());
 
@@ -367,10 +365,10 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
 
     // Slice-driven loops related initialization.
     levelReducedDep[tid].assign(lvlRank, 0);
-    dependentLvlMap[tid].assign(lvlRank,
-                                std::vector<std::pair<TensorId, Level>>());
+    dependentLvlMap[tid].assign(
+        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
     slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
-    sliceSizes[tid].assign(lvlRank, std::vector<Value>());
+    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
     sliceStack[tid].emplace_back(/*minCrd=*/Value(),
                                  /*offset=*/Value(), /*isNonEmpty*/ Value(),
                                  std::nullopt, 0);
@@ -380,8 +378,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
         unsigned depends = dependentLvlMap[tid][l].size();
         if (depends == 0)
           continue;
-        // We need `depends - 1` slices to fully  the affine expression.
-        sliceSizes[tid][l].assign(depends - 1, nullptr);
+        sliceMeta[tid][l].assign(depends, std::make_pair(nullptr, 0));
+        // We need `depends - 1` slices to fully reduce the affine expression.
         slicePosBuffer[tid][l].assign(depends - 1, nullptr);
       }
     }
@@ -502,15 +500,20 @@ void LoopEmitter::initializeLoopEmit(
     Level lvlRank = SparseTensorType(rtp).getLvlRank();
     for (Level lvl = 0; lvl < lvlRank; lvl++) {
       if (!dependentLvlMap[t][lvl].empty()) {
-        ArrayRef<std::pair<TensorId, Level>> depLvls = dependentLvlMap[t][lvl];
+        ArrayRef<std::pair<TensorLevel, unsigned>> depLvls =
+            dependentLvlMap[t][lvl];
         // Needs at least two operands to form a non-trivial affine expression.
-        assert(depLvls.size() > 1);
+        assert(depLvls.size() == sliceMeta[t][lvl].size());
 
         Value size = c0;
-        for (unsigned e = depLvls.size() - 1; e >= 1; e--) {
-          auto [dt, dd] = depLvls[e];
-          size = ADDI(size, lvlSizes[dt][dd]);
-          sliceSizes[t][lvl][e - 1] = size;
+        for (int e = depLvls.size() - 1; e >= 0; e--) {
+          auto [dt, dl] = unpackTensorLevel(depLvls[e].first);
+          unsigned stride = depLvls[e].second;
+          Value stridedSize = lvlSizes[dt][dl];
+          if (stride != 1)
+            stridedSize = MULI(stridedSize, C_IDX(stride));
+          size = ADDI(size, stridedSize);
+          sliceMeta[t][lvl][e] = std::make_pair(size, stride);
         }
       }
     }
@@ -729,8 +732,9 @@ Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
       // crdHi is a loop invariant, hosit the computation outside the loop.
       if (llvm::isa_and_nonnull<scf::WhileOp>(loop))
         builder.setInsertionPoint(loop);
-      crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset,
-                   sliceSizes[tid][lvl].back());
+      auto [size, stride] = sliceMeta[tid][lvl].back();
+      assert(stride == 1 && "Not yet implemented");
+      crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, size);
     }
     assert(crdHi);
     return genSparseReducedAffineCond(builder, loc,
@@ -984,7 +988,7 @@ bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<TensorLvlCond> sparseConds,
   if (sparseConds.size() > 1)
     return false;
 
-  // We also need a while loop for levels with affine index expression for
+  // We also need a while loop for levels with affine index expression and
   // non-unique levels when deduplication is required.
   if (sparseConds.size() == 1) {
     auto [tid, lvl] = unpackTensorLevel(sparseConds.back().first);
@@ -1042,7 +1046,9 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
     if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) {
       bool unReduc = isAffineIdxUnRedCond(loopCondKind);
       assert(unReduc == !depFullyReduced(tid, lvl));
-      hi = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1];
+      auto [size, stride] = sliceMeta[tid][lvl][sliceStack[tid].back().depth];
+      assert(stride == 1 && "Not yet implemented");
+      hi = size;
       if (unReduc) {
         // Adjust for loop hi for dense slice-driven loop.
         hi = SUBI(lvlSizes[tid][lvl], hi);
@@ -1215,6 +1221,8 @@ void LoopEmitter::enterTensorsAtDenseLvls(
       SliceInfo &info = sliceStack[tid].back();
       // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
       sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
+      // FIXME: The offset and position iterator need to be adjusted when the
+      // slice is strided.
       if (unReduc) {
         assert(*info.slicedOnLvl == lvl);
         // Update the slice information as we enter the new loop.
@@ -1361,7 +1369,9 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
   while (curLvl < leafLvl && isDenseDLT(lvlTypes[tid][curLvl])) {
     // One step forward in parent level results in forwarding `slice.size` step
     // in child dense level.
-    fcnt = MULI(sliceSizes[tid][curLvl].back(), fcnt);
+    auto [size, stride] = sliceMeta[tid][curLvl].back();
+    assert(stride == 1 && "Not yet implemented");
+    fcnt = MULI(size, fcnt);
     curLvl++;
   }
 
@@ -1420,7 +1430,18 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       // TODO: support coiterating multiple slices
       assert(loopInfo.trivialTidLvls.empty() &&
              loopInfo.sliceDrivenInfo.size() == 1);
-      genSliceNextInduction(builder, loc, whileOp, tid, lvl, operands, o);
+      auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
+          genSliceNextInduction(builder, loc, tid, lvl);
+      // Update while loop induction operands.
+      operands.push_back(nxNonEmpty);
+      operands.push_back(nxMinCrd);
+      operands.push_back(nxAbsOffset);
+
+      // Update the slice stack.
+      SliceInfo &info = sliceStack[tid].back();
+      info.isNonEmpty = whileOp.getResult(o++);
+      info.minCrd = whileOp.getResult(o++);
+      info.offset = whileOp.getResult(o++);
       continue;
     }
 
@@ -1566,7 +1587,10 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
     Value size, TensorId tid, Level lvl, ValueRange userReduc,
     LoopBodyBuilder bodyBuilder) {
   Value c1 = C_IDX(1);
-  Value sliceHi = ADDI(offset, sliceSizes[tid][lvl].back());
+  auto [sliceSz, stride] = sliceMeta[tid][lvl].back();
+  assert(stride == 1 && "Not yet implemented");
+  Value sliceHi = ADDI(offset, sliceSz);
+
   SmallVector<Value> reduc{posLo}; // loop lower bounds
   const unsigned numMetaReduc = reduc.size();
 
@@ -1663,6 +1687,8 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
     reduc.back() = ADDI(reduc.back(), C_IDX(1));
   };
 
+  // FIXME: Need special handling when the previous unresolved slice is strided:
+  // We probably need to filter out coordinates that is not on stride.
   if (firstResLvl.has_value()) {
     // Overwrite position when the first level is fully resolved.
     pos = posits[firstResLvl->first][firstResLvl->second];
@@ -1694,10 +1720,13 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
               // non-consecutive segments.
               builder.create<memref::StoreOp>(loc, iterArgs.back(), sPtrBuf,
                                               ADDI(iv, c2).getResult());
+
+              auto [size, stride] = sliceMeta[tid][firstLvl].back();
+              assert(stride == 1 && "Not yet implemented");
               ValueRange itArgs =
                   genSliceLvlTraverseLoop(
-                      builder, loc, loopLo, loopHi, offset,
-                      sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
+                      builder, loc, loopLo, loopHi, offset, size, tid, firstLvl,
+                      iterArgs,
                       [&](OpBuilder &builder, Location, Value iv,
                           MutableArrayRef<Value> reduc) {
                         ip = builder.saveInsertionPoint();
@@ -1710,8 +1739,9 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
       } else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
         assert(firstLvl == 0); // This must be the first level.
         Value lb = frontSlice.offset;
-        Value sliceSz =
-            sliceSizes[tid][*frontSlice.slicedOnLvl][frontSlice.depth - 1];
+        auto [sliceSz, stride] =
+            sliceMeta[tid][*frontSlice.slicedOnLvl][frontSlice.depth];
+        assert(stride == 1 && "Not yet implemented");
         Value ub = ADDI(lb, sliceSz);
         outerMost = builder.create<scf::ForOp>(
             loc, lb, ub, c1, innerArgs,
@@ -1735,7 +1765,8 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
       Level sliceLvl = *slice->slicedOnLvl;
       assert(isDenseDLT(lvlTypes[tid][sliceLvl]));
       Value offset = slice->offset;
-      Value sliceSz = sliceSizes[tid][sliceLvl][slice->depth - 1];
+      auto [sliceSz, stride] = sliceMeta[tid][sliceLvl][slice->depth];
+      assert(stride == 1 && "Not yet implemented");
       lbs.push_back(offset);
       ubs.push_back(ADDI(offset, sliceSz));
       steps.push_back(c1);
@@ -1788,7 +1819,8 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
                                  lvl, /*depth=*/1);
     return;
   }
-  Value size = sliceSizes[tid][lvl][0];
+  auto [nxSz, stride] = sliceMeta[tid][lvl][1];
+  assert(stride == 1 && "Not yet implemented");
   Value sPtrBuf = slicePosBuffer[tid][lvl][0];
   Value pHi, pLo;
   if (lvl == 0) {
@@ -1816,7 +1848,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
   Value minCrd = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
 
   // FIXME: We need the relative offset related to the base slice.
-  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
+  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl,
                                /*depth=*/1);
 }
@@ -1845,7 +1877,12 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
                                           TensorId tid, Level lvl) {
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
   unsigned depth = levelReducedDep[tid][lvl];
-  Value size = sliceSizes[tid][lvl][depth]; // Dense slice begin is trivial
+  // TODO: handle case when the current slice stride is not one.
+  assert(sliceMeta[tid][lvl][depth].second == 1 && "Not yet implemented");
+
+  // The remaining slice size after reduction.
+  Value remSz = sliceMeta[tid][lvl][depth + 1].first;
+  // Dense slice begin is trivial
   if (isDenseDLT(lvlTypes[tid][lvl])) {
     sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), lvl,
                                  depth + 1);
@@ -1941,7 +1978,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   builder.create<memref::StoreOp>(loc, result[2], sPtrBuf, c0);
   builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);
   // FIXME: we need the relative offset related to the base slice.
-  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
+  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, depth + 1);
 }
 
@@ -2005,9 +2042,12 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
     // TODO: Maybe using allocaScopeOp inside the loop to resolve the issue?
     for (Level curLevel = lvl;
          curLevel >= 1 && !lvlFullyResolved(tid, curLevel - 1); curLevel--) {
-      auto depth = remDepOnLevel(tid, curLevel - 1);
-      assert(sliceSizes[tid][lvl].size() >= depth);
-      Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
+      // We only handle cases when all the previously unresolved levels are
+      // fully reduced.
+      assert(depFullyReduced(tid, curLevel - 1));
+      assert(!sliceMeta[tid][curLevel - 1].empty());
+      auto [sz, stride] = sliceMeta[tid][curLevel - 1].back();
+      assert(stride == 1 && "Not yet implemented");
       bufSize = MULI(bufSize, sz);
     }
     // For a triple of [pLo, pHi, pPtr]. Note that we can not compress pHi
@@ -2042,18 +2082,15 @@ void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
   }
 }
 
-void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
-                                        const Operation *op, TensorId tid,
-                                        Level lvl,
-                                        SmallVectorImpl<Value> &operands,
-                                        unsigned &retIdx) {
+std::tuple<Value, Value, Value>
+LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
+                                   TensorId tid, Level lvl) {
   if (!isCompressedDLT(lvlTypes[tid][lvl]))
     llvm_unreachable("TODO");
 
   // else generate code to compute next non empty slice.
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
 
-  auto whileOp = llvm::cast<scf::WhileOp>(op);
   SliceInfo &info = sliceStack[tid].back();
   assert(info.slicedOnLvl == lvl);
   //
@@ -2182,9 +2219,12 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
     builder.setInsertionPointAfter(forOp.loops.front());
     // minOffset = minCrd + 1 >= size ? minCrd + 1 - size : c0
     Value tmp = ADDI(forOp.results.front(), c1);
-    Value minOffset = SUBI(tmp, sliceSizes[tid][lvl][info.depth - 1]);
-    Value p = CMPI(uge, tmp, sliceSizes[tid][lvl][info.depth - 1]);
+    auto [size, stride] = sliceMeta[tid][lvl][info.depth];
+    assert(stride == 1 && "Not yet implemented");
+    Value minOffset = SUBI(tmp, size);
+    Value p = CMPI(uge, tmp, size);
     minOffset = SELECT(p, minOffset, c0);
+
     SmallVector<Value, 3> yields;
     yields.assign(forOp.results.begin(), forOp.results.end());
     yields.push_back(minOffset);
@@ -2200,7 +2240,9 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
   Value maxPred = CMPI(ugt, minOffset, nxOffset);
   Value nextAbsOffset = SELECT(maxPred, minOffset, nxOffset);
 
-  Value sliceUB = ADDI(nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]);
+  auto [size, stride] = sliceMeta[tid][lvl][info.depth];
+  assert(stride == 1 && "Not yet implemented");
+  Value sliceUB = ADDI(nextAbsOffset, size);
 
   // FIXME: this only works if there is only one parent.
   assert(info.depth - 1 == 0);
@@ -2211,15 +2253,7 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
   assert(info.depth - 1 == 0);
   Value nextRelOffset = nextAbsOffset;
   nextRelOffset = SELECT(nextNonEmpty, nextRelOffset, c0);
-
-  operands.push_back(nextNonEmpty);
-  operands.push_back(nextMinCrd);
-  operands.push_back(nextAbsOffset); // we push the absolute offset.
-
-  // Update the slice stack.
-  info.isNonEmpty = whileOp.getResult(retIdx++);
-  info.minCrd = whileOp.getResult(retIdx++);
-  info.offset = whileOp.getResult(retIdx++);
+  return std::make_tuple(nextNonEmpty, nextMinCrd, nextAbsOffset);
 }
 
 #undef CMPI

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 1b3acd68e587d7..d9948d3f4db73b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -84,18 +84,22 @@ class LoopEmitter {
   using SynTensorBoundSetter =
       function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
 
-  // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
-  // index on sparse tensors.
-  // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines
-  // d0 and d1 (for affine expression reduction).
+  // Map from [tid, lvl] to a list of dependent [tidlvl, coeffecient] for
+  // subscript expressions on sparse tensors.
+  //
+  // E.g., for affine index (2 * d0 + d1), it depends on two tidlvls that
+  // defines d0 and d1 (for affine expression reduction) and uses 2 and 1 for
+  // cofficients on d0, d1 respectively.
   // If the list is empty, it means that there is no affine expression on the
-  // input [tid, dim].
+  // input [tid, lvl].
+  //
   // NOTE: The caller is responsible to ensure that the order of the returned
   // list to be consistent with the topological order of the iteration graph,
   // otherwise the loop emitter might reduce a wrong dependent index variable
   // when generating slice-driven loops.
   using DependentLvlGetter =
-      function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
+      function_ref<std::vector<std::pair<TensorLevel, unsigned>>(TensorId,
+                                                                 Level)>;
 
   LoopEmitter() = default;
 
@@ -335,9 +339,9 @@ class LoopEmitter {
     // Whether this is the tensor that has not yet been sliced.
     bool isInitialTensor() const { return !slicedOnLvl.has_value(); }
 
-    Value minCrd;                     // the minimum coordinate of the slice.
-    Value offset;                     // the offset of the current slice.
-    Value isNonEmpty;                 // whether the slice is empty.
+    Value minCrd;     // the minimum coordinate of the slice.
+    Value offset;     // the *absolute* offset of the current slice.
+    Value isNonEmpty; // whether the slice is empty.
     std::optional<Level> slicedOnLvl; // the level on which the slice is done
     unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]).
   };
@@ -645,10 +649,12 @@ class LoopEmitter {
   bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
 
   /// Generates code to get the next non-empty slices of tid on lvl.
-  void genSliceNextInduction(OpBuilder &builder, Location loc,
-                             const Operation *whileOp, TensorId tid, Level lvl,
-                             SmallVectorImpl<Value> &operands,
-                             unsigned &retIdx);
+  /// Returns a tuple of values for <NonEmpty, MinCrd, AbsOffset> (see
+  /// SliceInfo) respectively.
+  std::tuple<Value, Value, Value> genSliceNextInduction(OpBuilder &builder,
+                                                        Location loc,
+                                                        TensorId tid,
+                                                        Level lvl);
 
   /// A optional string attribute that should be attached to the loop
   /// generated by loop emitter, it might help following passes to identify
@@ -707,9 +713,9 @@ class LoopEmitter {
   std::vector<std::vector<Value>> sliceOffsets;
   std::vector<std::vector<Value>> sliceStrides;
 
-  // Map from [tid, level] to a list of dependent [tid, level].
-  // See comments for `DependentDimGetter`.
-  std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
+  // Map from [tid, level] to a list of dependent [tidlevel, coefficient].
+  // See comments for `DependentLvlGetter`.
+  std::vector<std::vector<std::vector<std::pair<TensorLevel, unsigned>>>>
       dependentLvlMap;
 
   // The cached position buffer for the slices, they serve the same purpose as
@@ -718,8 +724,9 @@ class LoopEmitter {
   // to avoid iteration from the beginning.
   std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
 
-  // The cached size for each slices.
-  std::vector<std::vector<std::vector<Value>>> sliceSizes;
+  // The (size, stride) for each conceptual slice used for index reduction
+  // loops.
+  std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;
 
   // The number of reduced dependencies on a tensor level so far.
   std::vector<std::vector<unsigned>> levelReducedDep;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 2450fd6c7d03f6..770349d6d1db0f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -282,10 +282,14 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
 ///
 /// TODO: constant should be easy to handle.
 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
-                          AffineExpr a, DimLevelType dlt,
-                          bool isSubExp = false) {
+                          AffineExpr a, DimLevelType dlt, bool isSubExp = false,
+                          int64_t coefficient = 1) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
+    // Only allow positive coefficients on AffineDimExpr.
+    if (coefficient <= 0)
+      return false;
+
     const LoopId ldx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
     if (!isUndefDLT(merger.getLvlType(tensor, ldx)))
       return false; // used more than once, e.g., A[i][i]
@@ -293,8 +297,10 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
     // TODO: Generalizes the following two cases. A[i] (with trivial index
     // expression) can be treated as a special affine index expression. We do
     // not necessarily need to differentiate them.
-    if (!isSubExp)
+    if (!isSubExp) {
+      assert(coefficient == 1);
       merger.setLevelAndType(tensor, ldx, lvl, dlt);
+    }
 
     if (isSubExp) {
       // The current loops appears in more than one affine expressions on the
@@ -312,14 +318,26 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
         // else increase min(d0_1, d0_2).
         return false;
       }
-      merger.setLoopDependentTensorLevel(ldx, tensor, lvl, dlt);
+      merger.setLoopDependentTensorLevel(ldx, tensor, lvl, dlt, coefficient);
     }
     return true;
   }
   case AffineExprKind::Constant:
-  case AffineExprKind::Mul:
-    // TODO: Support Mul and Constant AffineExp for slice-based codegen
-    return false;
+    // TODO: Support Constant AffineExp for slice-based codegen
+  case AffineExprKind::Mul: {
+    // TODO: Support index expression like `2 * d0`, we now only support more
+    // complicated cases like `2 * d0 + d1`.
+    if (!isSubExp)
+      return false;
+    auto binOp = a.cast<AffineBinaryOpExpr>();
+    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
+    if (rhs.isa<AffineConstantExpr>())
+      std::swap(lhs, rhs);
+    // Must be in form of `constant * d`.
+    assert(lhs.isa<AffineConstantExpr>() && rhs.isa<AffineDimExpr>());
+    int64_t coefficient = lhs.cast<AffineConstantExpr>().getValue();
+    return findDepIdxSet(merger, tensor, lvl, rhs, dlt, isSubExp, coefficient);
+  }
   case AffineExprKind::Add: {
     auto binOp = a.cast<AffineBinaryOpExpr>();
     return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), dlt, true) &&

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index f39a2069a57dd8..4143efbd0ab28e 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -232,11 +232,11 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
       lvlToLoop(numTensors,
                 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
-      loopToDependencies(
-          numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>(
-                        numTensors, std::nullopt)),
-      levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>(
-                                           maxLvlRank, std::vector<LoopId>())),
+      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlDLTPair>>(
+                                         numTensors, std::nullopt)),
+      levelToDependentLoop(numTensors,
+                           std::vector<std::vector<LoopCoeffPair>>(
+                               maxLvlRank, std::vector<LoopCoeffPair>())),
       loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list