[Mlir-commits] [mlir] 2cd1592 - [mlir][sparse] implement index reduction on dense level (for CSR)

Peiming Liu llvmlistbot at llvm.org
Mon Apr 17 09:36:38 PDT 2023


Author: Peiming Liu
Date: 2023-04-17T16:36:31Z
New Revision: 2cd15925f4485fc618bc33c1337e1b6f63d84ef6

URL: https://github.com/llvm/llvm-project/commit/2cd15925f4485fc618bc33c1337e1b6f63d84ef6
DIFF: https://github.com/llvm/llvm-project/commit/2cd15925f4485fc618bc33c1337e1b6f63d84ef6.diff

LOG: [mlir][sparse] implement index reduction on dense level (for CSR)

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147550

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index b624aaddd21df..c1608530b7b1f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -29,6 +29,8 @@ using namespace mlir::sparse_tensor;
   (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, l, r)           \
        .getResult())
 
+#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, lhs, rhs))
+
 #define C_IDX(v) (constantIndex(builder, loc, v))
 
 /// Generates a pointer/index load from the sparse storage scheme. Narrower
@@ -500,16 +502,17 @@ void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
       assert(sliceStack[tid].back().slicedOnLvl == lvl);
       sliceStack[tid].pop_back();
     } else {
-      // Else this is a resolved-slice, and advance posit similar to TACO.
-      Value c1 = C_IDX(1), c2 = C_IDX(2);
-
-      // pIdx += 2, we finished the current lvl, advance the pointer index of
-      // the previous level by two to skip the [pLo, pHi] for current level.
-      Value sPtrBuf = slicePosBuffer[tid][lvl].back();
-      Value curP = genIndexLoad(builder, loc, sPtrBuf, c1);
-      Value nexP = builder.create<arith::AddIOp>(loc, curP, c2);
-      // TODO: we could probably use an SSA value for it.
-      builder.create<memref::StoreOp>(loc, nexP, sPtrBuf, c1);
+      if (!isDenseDLT(lvlTypes[tid][lvl])) {
+        // Else this is a resolved-slice, and advance posit similar to TACO.
+        Value c1 = C_IDX(1), c2 = C_IDX(2);
+        // pIdx += 2, we finished the current lvl, advance the pointer index of
+        // the previous level by two to skip the [pLo, pHi] for current level.
+        Value sPtrBuf = slicePosBuffer[tid][lvl].back();
+        Value curP = genIndexLoad(builder, loc, sPtrBuf, c1);
+        Value nexP = builder.create<arith::AddIOp>(loc, curP, c2);
+        // TODO: we could probably use an SSA value for it.
+        builder.create<memref::StoreOp>(loc, nexP, sPtrBuf, c1);
+      }
     }
   }
   loopSeqStack.pop_back();
@@ -547,11 +550,9 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
   }
 }
 
-Operation *LoopEmitter::emitForLoopOverTensorAtLvl(OpBuilder &builder,
-                                                   Location loc, TensorId tid,
-                                                   Level dstLvl,
-                                                   MutableArrayRef<Value> reduc,
-                                                   bool isParallel) {
+Operation *LoopEmitter::emitForLoopOverTensorAtLvl(
+    OpBuilder &builder, Location loc, TensorId tid, Level dstLvl, Value lo,
+    Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
   bool isSparseCond = isCompressedDLT(lvlTypes[tid][dstLvl]) ||
                       isSingletonDLT(lvlTypes[tid][dstLvl]);
 
@@ -561,9 +562,6 @@ Operation *LoopEmitter::emitForLoopOverTensorAtLvl(OpBuilder &builder,
   // biggest range).
   const Level srcLvl = reassoc.front();
   Value step = C_IDX(1);
-  Value lo = isSparseCond ? posits[tid][srcLvl]        // current offset
-                          : loopSeqStack.back().first; // universal index
-  Value hi = highs[tid][srcLvl];
 
   Operation *loop = nullptr;
   Value iv;
@@ -682,7 +680,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
     ArrayRef<Level> lvls, MutableArrayRef<Value> reduc, bool isParallel) {
   // TODO: support multiple return on parallel for?
   assert(!isParallel || reduc.size() <= 1);
-  bool isSparseCond = false, isSliceCond = false;
+  bool isSparseCond = false, isSparseSliceCond = false;
   size_t tid = tids.front(), lvl = lvls.front();
 
   // Finds out the tensor level that we should use to generate loops. Amongs all
@@ -691,25 +689,25 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
     assert(lvlTypes[t].size() > l);         // Must be a valid tid, dim pair
     assert(!coords[t][l] ||                 // We cannot re-enter the same level
            !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop
-    auto dimType = lvlTypes[t][l];
+    auto lvlType = lvlTypes[t][l];
     // Must be a recognizable DLT.
-    assert(isDenseDLT(dimType) || isCompressedDLT(dimType) ||
-           isSingletonDLT(dimType));
+    assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
+           isSingletonDLT(lvlType));
 
-    // This is a slice-driven loop.
-    if (!dependentLvlMap[t][l].empty()) {
-      assert(!isSliceCond && !isSparseCond);
-      isSliceCond = true;
+    // This is a slice-driven loop on sparse level.
+    if (!dependentLvlMap[t][l].empty() && !isDenseDLT(lvlType)) {
+      assert(!isSparseSliceCond && !isSparseCond);
+      isSparseSliceCond = true;
       tid = t;
       lvl = l;
       continue;
     }
 
-    bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType);
+    bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType);
     // We can at most have one sparse input, otherwise, a while loop is
     // required to co-iterate multiple sparse tensors.
     assert(!isSparseCond || !isSparse);
-    assert(!isSliceCond || !isSparseCond);
+    assert(!isSparseSliceCond || !isSparseCond);
     if (isSparse) {
       tid = t;
       lvl = l;
@@ -717,10 +715,27 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
     isSparseCond = isSparseCond || isSparse;
   }
 
+  DimLevelType lvlType = lvlTypes[tid][lvl];
+  // TODO: Dense slice driven loop can be generated using for loop as well.
+  assert(!isSparseSliceCond || !isDenseDLT(lvlType));
+  bool isDenseSliceCond =
+      isDenseDLT(lvlType) && !dependentLvlMap[tid][lvl].empty();
+  // if the slice is fully reduced, we can now use TACO-based algorithm to
+  // iterate it.
+
+  Operation *l = nullptr;
+
+  // At most one tensor used as condition in for loop;
+  SmallVector<TensorId, 1> condTid;
+  SmallVector<Level, 1> condLvl;
+  // There Might be multiple dense slice driven tensor.
+  SmallVector<TensorId> sliceTids;
+  SmallVector<Level> sliceLvls;
+  SmallVector<bool> sliceReduc;
+
   // Generates loops differently depending on whether we need a slice-driven
   // loop or a simple level traversal loop.
-  Operation *l = nullptr;
-  if (isSliceCond) {
+  if (isSparseSliceCond) {
     bool fullyReduced = depFullyReduced(tid, lvl);
     if (!fullyReduced) {
       l = emitSliceDrivenLoopOverTensorAtLvl(builder, loc, tid, lvl, reduc);
@@ -733,22 +748,63 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
           lvl, reduc);
     }
     levelReducedDep[tid][lvl]++;
-    // We can also prepare for next dim here in advance
-    // Pushes the loop into stack.
-    loopStack.emplace_back(
-        ArrayRef<TensorId>(), ArrayRef<Level>(), ArrayRef<TensorId>(tid),
-        ArrayRef<Level>(lvl), ArrayRef<bool>(fullyReduced), l,
-        builder.getInsertionBlock(), coords[tid][lvl], loopTag);
+    sliceTids.push_back(tid);
+    sliceLvls.push_back(lvl);
+    sliceReduc.push_back(fullyReduced);
   } else {
-    l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, reduc, isParallel);
-    // We can also prepare for next dim here in advance
-    // Pushes the loop into stack.
-    loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(lvl),
-                           ArrayRef<TensorId>(), ArrayRef<Level>(),
-                           ArrayRef<bool>(), l, builder.getInsertionBlock(),
-                           coords[tid][lvl], loopTag);
+    Value lo = isSparseCond ? posits[tid][lvl]           // current offset
+                            : loopSeqStack.back().first; // universal index
+    Value hi = highs[tid][lvl];
+    if (isDenseSliceCond) {
+      bool fullyReduced = depFullyReduced(tid, lvl);
+      Value sliceSz = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1];
+      // Adjust for loop hi for dense slice-driven loop.
+      if (fullyReduced) {
+        hi = sliceSz;
+        condTid.push_back(tid);
+        condLvl.push_back(lvl);
+      } else {
+        hi = builder.create<arith::SubIOp>(loc, lvlSizes[tid][lvl], sliceSz);
+        hi = builder.create<arith::AddIOp>(loc, hi, C_IDX(1));
+      }
+    } else {
+      condTid.push_back(tid);
+      condLvl.push_back(lvl);
+    }
+    l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, reduc,
+                                   isParallel);
   }
-
+  Value iv = coords[tid][lvl];
+  for (auto [t, l] : llvm::zip(tids, lvls)) {
+    // We only need to handle slice-driven loops on dense level here.
+    // If it is a slice-driven loop on sparse level, it needs a while loop to
+    // insert break statements, and it must have been handled correctly in L692.
+    if (!dependentLvlMap[t][l].empty() && isDenseDLT(lvlTypes[t][l])) {
+      // Pushes sliced levels to build correct LoopInfo.
+      bool fullyReduc = depFullyReduced(t, l);
+      SliceInfo &info = sliceStack[t].back();
+      if (fullyReduc) {
+        posits[t][l] =
+            genAddress(builder, loc, t, l,
+                       builder.create<arith::AddIOp>(loc, info.offset, iv));
+      } else {
+        // Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to
+        // exit it.
+        sliceTids.push_back(t);
+        sliceLvls.push_back(l);
+        sliceReduc.push_back(fullyReduc);
+        // Update the slice information as we enter the new loop.
+        assert(*info.slicedOnLvl == l);
+        info.minCrd = info.offset = iv;
+        info.isNonEmpty = constantI1(builder, loc, true);
+        levelReducedDep[t][l]++;
+      }
+    }
+  }
+  // NOTE: we can also prepare for next dim here in advance
+  // Pushes the loop into stack.
+  loopStack.emplace_back(condTid, condLvl, sliceTids, sliceLvls, sliceReduc, l,
+                         builder.getInsertionBlock(), iv, loopTag);
   // Emit extra locals.
   emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
   return l;
@@ -1106,6 +1162,10 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder,
   assert(tids.size() == lvls.size());
   for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
     if (isDenseDLT(lvlTypes[tid][lvl])) {
+      // Slice-driven dense level should have be handled already.
+      if (!dependentLvlMap[tid][lvl].empty())
+        continue;
+
       auto enc = getSparseTensorEncoding(tensors[tid].getType());
       if (enc && !isSparseOutput(tid)) {
         bool validPos = lvl == 0 || posits[tid][lvl - 1];
@@ -1127,6 +1187,18 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                               MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
   rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
+  for (auto [tid, lvl, reduced] : llvm::zip(
+           loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) {
+    SliceInfo &info = sliceStack[tid].back();
+    assert(isDenseDLT(lvlTypes[tid][lvl]));
+    assert(*info.slicedOnLvl == lvl && !reduced);
+    (void)reduced;
+    // Resets slices pointers as the resolved slices are invalidated after we
+    // moves forward to the next slice.
+    invalidateSliceIterIdx(rewriter, loc, tid, lvl);
+    info.minCrd = info.offset = info.isNonEmpty = Value();
+    levelReducedDep[tid][lvl]--;
+  }
   if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
     if (!reduc.empty()) {
       assert(reduc.size() == forOp.getNumResults());
@@ -1220,6 +1292,8 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   unsigned delta = 0;
   for (auto [tid, lvl, resolved] : llvm::zip(
            loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) {
+    // TODO: handle dense.
+    assert(isCompressedDLT(lvlTypes[tid][lvl]));
     levelReducedDep[tid][lvl]--;
     if (!resolved) {
       genSliceNextInduction(builder, loc, whileOp, tid, lvl, operands, o);
@@ -1338,18 +1412,15 @@ unsigned LoopEmitter::remDepOnLevel(TensorId tid, Level lvl) const {
   return totalDependencies;
 }
 
-const LoopEmitter::SliceInfo &LoopEmitter::getFinalSliceOnLvl(TensorId tid,
-                                                              Level lvl) {
+const LoopEmitter::SliceInfo &LoopEmitter::getMostRecentSliceOnLvl(TensorId tid,
+                                                                   Level lvl) {
   // Finds the most-recent slice using a reverse iteration.
   for (auto it = sliceStack[tid].rbegin(), ie = sliceStack[tid].rend(); it < ie;
        it++) {
     if (it->slicedOnLvl == lvl) { // the level matched
-      // Must be the final slice we need to fully reduced the expression too.
-      assert(it->depth == dependentLvlMap[tid][lvl].size() - 1);
       return *it;
     }
   }
-
   llvm_unreachable("Failed to find sliceInfo");
 }
 
@@ -1366,9 +1437,7 @@ const LoopEmitter::SliceInfo &LoopEmitter::getFinalSliceOnLvl(TensorId tid,
 std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
     OpBuilder &builder, Location loc, Value loopLo, Value loopHi, Value offset,
     Value size, TensorId tid, Level lvl, ValueRange userReduc, bool genYield,
-    llvm::function_ref<void(OpBuilder &, Location, Value,
-                            MutableArrayRef<Value>)>
-        bodyBuilder) {
+    LoopBodyBuilder bodyBuilder) {
   Value c1 = C_IDX(1);
   Value sliceHi = builder.create<arith::AddIOp>(loc, offset, size);
 
@@ -1454,40 +1523,106 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
 //   }
 // }
 ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
-    OpBuilder &builder, Location loc, Value offset, TensorId tid, Level lvl,
-    size_t depth, ValueRange userReduc,
-    llvm::function_ref<void(OpBuilder &, Location, Value,
-                            MutableArrayRef<Value>)>
-        bodyBuilder) {
-
+    OpBuilder &builder, Location loc, TensorId tid,
+    ArrayRef<const SliceInfo *> unResLvls, ValueRange userReduc,
+    LoopBodyBuilder bodyBuilder) {
+  // assert(unResLvls.size() == 1 && "TODO");
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
 
-  // TODO: it only works on all compressed tensor.
-  Value sPtrBuf = slicePosBuffer[tid][lvl][depth];
-  Value pSt = c2;                                      // pointer starting index
-  Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
-
-  auto forOp =
-      scf::buildLoopNest(
-          builder, loc, pSt, mSz, c2, userReduc,
-          [this, c1, tid, lvl, offset, sPtrBuf,
-           bodyBuilder](OpBuilder &builder, Location loc, ValueRange ivs,
-                        ValueRange iterArgs) -> scf::ValueVector {
+  const SliceInfo &frontSlice = *unResLvls.back();
+  Level firstLvl = *frontSlice.slicedOnLvl;
+  assert(!lvlFullyResolved(tid, firstLvl) && "TODO");
+
+  // FIXME: it is not zero when the first level is fully resolved.
+  Value pos = c0;
+  OpBuilder::InsertPoint ip;
+  SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
+  scf::ForOp outerMost = nullptr;
+  if (!lvlFullyResolved(tid, firstLvl)) {
+    if (isCompressedDLT(lvlTypes[tid][firstLvl])) {
+      unsigned depth = frontSlice.depth - 1;
+      Value offset = frontSlice.offset;
+      Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
+      Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
+      outerMost = builder.create<scf::ForOp>(
+          loc, c2, mSz, c2, innerArgs,
+          [this, c1, tid, firstLvl, offset, sPtrBuf, &ip, &pos, &innerArgs](
+              OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
             // generate traversal for each level.
-            Value loopLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front());
-            Value loopHi = genIndexLoad(
-                builder, loc, sPtrBuf,
-                builder.create<arith::AddIOp>(loc, ivs.front(), c1));
-            return genSliceLvlTraverseLoop(builder, loc, loopLo, loopHi, offset,
-                                           sliceSizes[tid][lvl].back(), tid,
-                                           lvl, iterArgs, true, bodyBuilder)
-                .second;
-          })
-          .loops.front();
+            Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
+            Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
+            ValueRange itArgs =
+                genSliceLvlTraverseLoop(
+                    builder, loc, loopLo, loopHi, offset,
+                    sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
+                    false,
+                    [&](OpBuilder &builder, Location, Value iv,
+                        MutableArrayRef<Value> reduc) {
+                      ip = builder.saveInsertionPoint();
+                      pos = iv;
+                      innerArgs.assign(reduc.begin(), reduc.end());
+                    })
+                    .second;
+            builder.create<scf::YieldOp>(loc, itArgs);
+          });
+    } else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
+      assert(firstLvl == 0); // This must be the first level.
+      Value lb = frontSlice.offset;
+      Value sliceSz =
+          sliceSizes[tid][*frontSlice.slicedOnLvl][frontSlice.depth - 1];
+      Value ub = ADDI(lb, sliceSz);
+      outerMost = builder.create<scf::ForOp>(
+          loc, lb, ub, c1, innerArgs,
+          [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
+            ip = builder.saveInsertionPoint();
+            pos = iv;
+            innerArgs.assign(iterArgs.begin(), iterArgs.end());
+          });
+    }
+    // We generated the loop for the first slice above, now remove it.
+    unResLvls = unResLvls.drop_back();
+  }
 
+  // Reset the insertion point into the loop body.
+  builder.restoreInsertionPoint(ip);
+  if (!unResLvls.empty()) {
+    // Fills in dense slices levels in between.
+    SmallVector<Value> lbs, ubs, steps, lvlSzs;
+    for (const SliceInfo *slice : llvm::reverse(unResLvls)) {
+      Level sliceLvl = *slice->slicedOnLvl;
+      assert(isDenseDLT(lvlTypes[tid][sliceLvl]));
+      Value offset = slice->offset;
+      Value sliceSz = sliceSizes[tid][sliceLvl][slice->depth - 1];
+      lbs.push_back(offset);
+      ubs.push_back(builder.create<arith::AddIOp>(loc, offset, sliceSz));
+      steps.push_back(c1);
+      lvlSzs.push_back(lvlSizes[tid][sliceLvl]);
+    }
+    auto denseNest = scf::buildLoopNest(
+        builder, loc, lbs, ubs, steps, innerArgs,
+        [&innerArgs, &lvlSzs, &pos,
+         bodyBuilder](OpBuilder &builder, Location loc, ValueRange ivs,
+                      ValueRange iterArgs) -> scf::ValueVector {
+          for (auto em : llvm::enumerate(ivs)) {
+            // Linearizes postion: pos = (pos * lvlsize) + iv;
+            pos = builder.create<arith::MulIOp>(loc, pos, lvlSzs[em.index()]);
+            pos = builder.create<arith::AddIOp>(loc, pos, em.value());
+          }
+          innerArgs.assign(iterArgs.begin(), iterArgs.end());
+          // Generates user request loop body.
+          // TODO: we do not have to check inbound for dense levels
+          bodyBuilder(builder, loc, pos, innerArgs);
+          return innerArgs;
+        });
+    builder.create<scf::YieldOp>(loc, denseNest.results);
+  } else {
+    // Generates user request loop body.
+    bodyBuilder(builder, loc, pos, innerArgs);
+    builder.create<scf::YieldOp>(loc, innerArgs);
+  }
   // Insert after current while operation.
-  builder.setInsertionPointAfter(forOp);
-  return forOp.getResults();
+  builder.setInsertionPointAfter(outerMost);
+  return outerMost.getResults();
 }
 
 void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
@@ -1495,6 +1630,13 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
   assert(lvl == 0 && "TODO: handle non-first level");
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2), c3 = C_IDX(3),
         c4 = C_IDX(4);
+  if (isDenseDLT(lvlTypes[tid][lvl])) {
+    // Dense slice begin is trivial.
+    sliceStack[tid].emplace_back(/*minCoord=*/c0, /*offset=*/c0,
+                                 /*nonEmpty=*/constantI1(builder, loc, true),
+                                 lvl, /*depth=*/1);
+    return;
+  }
   Value size = sliceSizes[tid][0][0];
   Value sPtrBuf = slicePosBuffer[tid][0][0];
   Value pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
@@ -1540,18 +1682,41 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
 // }
 void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
                                           TensorId tid, Level lvl) {
-  assert(isCompressedDLT(lvlTypes[tid][lvl]));
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
-  const SliceInfo &sliceInfo = sliceStack[tid].back();
-  unsigned prevLvl = *sliceInfo.slicedOnLvl;
-  assert(lvl >= prevLvl);
-  // Either lvl = prevSlicedLvl, i.e., t[d0 + d1 + d2,...] (more than one
+  unsigned depth = levelReducedDep[tid][lvl];
+  Value size = sliceSizes[tid][lvl][depth];
+  // Dense slice begin is trivial
+  if (isDenseDLT(lvlTypes[tid][lvl])) {
+    sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), lvl,
+                                 depth + 1);
+    return;
+  }
+
+  assert(isCompressedDLT(lvlTypes[tid][lvl]));
+  // Unhandled Cases:
+  //
+  // 1st, lvl = prevSlicedLvl, i.e., t[d0 + d1 + d2,...] (more than one
   // variable need to be reduced on the same level).
-  // Or lvl > prevSliceLvl + 1, i.e., t[..., d2, d3 + d4] (having a
+  //
+  // 2nd, lvl > prevSliceLvl + 1, i.e., t[..., d2, d3 + d4] (having a
   // simple dim expression in between).
-  assert(lvl == prevLvl + 1 && "TODO: not yet implemented");
+  assert(lvl == *sliceStack[tid].back().slicedOnLvl + 1);
+
   // Check slice stack integrity.
-  assert(slicePosBuffer[tid][prevLvl].size() == sliceInfo.depth);
+  assert(slicePosBuffer[tid][lvl - 1].size() == sliceStack[tid].back().depth);
+
+  SmallVector<const SliceInfo *> unResSlices;
+  for (Level curLvl = lvl; curLvl >= 1; curLvl--) {
+    Level prevLvl = curLvl - 1;
+    unResSlices.push_back(&getMostRecentSliceOnLvl(tid, prevLvl));
+    if (!isDenseDLT(lvlTypes[tid][prevLvl]) || lvlFullyResolved(tid, prevLvl)) {
+      break;
+    }
+  }
+
+  assert(!unResSlices.empty() &&
+         !lvlFullyResolved(tid, *unResSlices.front()->slicedOnLvl));
+
   Value sPtrBuf = slicePosBuffer[tid][lvl].back();
   SmallVector<Value, 3> reduc = {
       constantI1(builder, loc, false), // isNonEmpty
@@ -1560,7 +1725,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   };
 
   ValueRange result = genUnResolvedSliceTreeTraverse(
-      builder, loc, sliceInfo.offset, tid, prevLvl, sliceInfo.depth - 1, reduc,
+      builder, loc, tid, unResSlices, reduc,
       [this, c1, c2, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc,
                                         Value iv,
                                         MutableArrayRef<Value> reduc) {
@@ -1606,8 +1771,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
         curMemSz = builder.create<arith::AddIOp>(loc, curMemSz, c2);
       });
 
-  unsigned depth = levelReducedDep[tid][lvl];
-  Value size = sliceSizes[tid][lvl][depth];
   Value isNonEmpty = result[0];
   Value minCrd = result[1];
   // Two metadata [memSize, idx].
@@ -1624,6 +1787,10 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   Value c1 = C_IDX(1), c2 = C_IDX(2);
 
   if (depFullyReduced(tid, lvl)) {
+    // Do not need to prepare for slice driven loop on dense level after it is
+    // fully reduced.
+    if (isDenseDLT(lvlTypes[tid][lvl]))
+      return true;
     // If constraints on the tensor is fully resolved. We do not need to
     // generates slice begin any more, instead we fall back to TACO-based
     // algorithm to (co)iterates over the slice.
@@ -1703,6 +1870,16 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   return false;
 }
 
+void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
+                                         TensorId tid, Level lvl) {
+  for (unsigned i = 0; i <= lvl; i++) {
+    if (!isDenseDLT(lvlTypes[tid][i])) {
+      builder.create<memref::StoreOp>(loc, C_IDX(0),
+                                      slicePosBuffer[tid][i].back(), C_IDX(1));
+    }
+  }
+}
+
 void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
                                         const Operation *op, TensorId tid,
                                         Level lvl,
@@ -1712,14 +1889,11 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
     llvm_unreachable("TODO");
 
   // else generate code to compute next non empty slice.
-  Value c0 = C_IDX(0);
-  Value c1 = C_IDX(1);
-  Value c2 = C_IDX(2);
+  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
 
   auto whileOp = llvm::cast<scf::WhileOp>(op);
   SliceInfo &info = sliceStack[tid].back();
   assert(info.slicedOnLvl == lvl);
-
   //
   // We forward to the next non empty slice by
   // if (minCrd > offset) {
@@ -1735,8 +1909,7 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
   Value absOffset = info.offset;
   // Resets slices pointers as the resolved slices are invalidated after we
   // moves forward to the next slice.
-  for (unsigned i = 0; i <= lvl; i++)
-    builder.create<memref::StoreOp>(loc, c0, slicePosBuffer[tid][i].back(), c1);
+  invalidateSliceIterIdx(builder, loc, tid, lvl);
 
   SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
   Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
@@ -1949,4 +2122,5 @@ Operation *LoopEmitter::emitSliceDrivenLoopOverTensorAtLvl(
 }
 
 #undef CMPI
+#undef ADDI
 #undef C_IDX

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 5bbb68198e0f5..554f24b16f8d6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -264,6 +264,9 @@ class LoopEmitter {
     unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]).
   };
 
+  using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value,
+                                                  MutableArrayRef<Value>)>;
+
   /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
   Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl,
                    Value iv);
@@ -318,11 +321,13 @@ class LoopEmitter {
                                             ArrayRef<TensorId> tids,
                                             ArrayRef<Level> lvls);
 
-  /// Emits a for loop to iterate over a dense level, or a sparse level that has
-  /// not been sliced.
+  /// Emits a for loop to iterate over a tensor level with the provided lower
+  /// bound `lo` and upper bound `hi`.
+  /// Apart from iterating just single tensor level, for loops can be used for
+  /// slice-driven loop on dense level too.
   Operation *emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
-                                        TensorId tid, Level lvl,
-                                        MutableArrayRef<Value> reduc,
+                                        TensorId tid, Level lvl, Value lo,
+                                        Value hi, MutableArrayRef<Value> reduc,
                                         bool isParallel);
 
   /// Emits a while loop to iterate over a sparse level that has been sliced.
@@ -405,9 +410,16 @@ class LoopEmitter {
 
   /// Retrieves the most recent slice on lvl. To reduce affine expression like
   /// d0 + d1 + d2, we need two slices (one of size d1 + d2, and the other of
-  /// size d2). This methods returns the latter slice (of size d2), which is
-  /// also the final slice on the level.
-  const SliceInfo &getFinalSliceOnLvl(TensorId tid, Level lvl);
+  /// size d2). This methods returns the latter slice (of size d2).
+  const SliceInfo &getMostRecentSliceOnLvl(TensorId tid, Level lvl);
+
+  /// Similar to getMostRecentSliceOnLvl, but yields error when the most recent
+  /// slice is not the final slice needed to fully reduced the dependencies.
+  const SliceInfo &getFinalSliceOnLvl(TensorId tid, Level lvl) {
+    const SliceInfo &info = getMostRecentSliceOnLvl(tid, lvl);
+    assert(info.depth == dependentLvlMap[tid][lvl].size() - 1);
+    return info;
+  }
 
   /// Get the remaining number of constraints needed to fully *resolve*
   /// dependent levels on tensor[tid].
@@ -436,18 +448,15 @@ class LoopEmitter {
   genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo,
                           Value pHi, Value offset, Value size, TensorId tid,
                           Level lvl, ValueRange userReduc, bool genYield,
-                          /*bodyBuilder=*/
-                          llvm::function_ref<void(OpBuilder &, Location, Value,
-                                                  MutableArrayRef<Value>)>);
+                          LoopBodyBuilder bodyBuilder);
 
   /// Generates a nested loop that iterates over tid on all the coordinates on
   /// lvl.
-  ValueRange genUnResolvedSliceTreeTraverse(
-      OpBuilder &builder, Location loc, Value offset, TensorId tid, Level lvl,
-      size_t depth, ValueRange userReduc,
-      /*bodyBody=*/
-      llvm::function_ref<void(OpBuilder &, Location, Value,
-                              MutableArrayRef<Value>)>);
+  ValueRange
+  genUnResolvedSliceTreeTraverse(OpBuilder &builder, Location loc, TensorId tid,
+                                 ArrayRef<const SliceInfo *> unResLvls,
+                                 ValueRange userReduc,
+                                 LoopBodyBuilder bodyBuilder);
 
   /// Generates code to get the first non-empty slice of tid on lvl, when all
   /// the previous level before `lvl` are resolved (or lvl is the first level).
@@ -465,6 +474,11 @@ class LoopEmitter {
   void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                Level lvl);
 
+  /// Invalidates the index kept in slice postion buffers (by setting it to
+  /// zero).
+  /// TODO: We should instead use an SSA value for the index.
+  void invalidateSliceIterIdx(OpBuilder &builder, Location loc, TensorId tid,
+                              Level lvl);
   /// Generates code to get the first non-empty slice of tid on lvl.
   /// return true if has already been resolved.
   bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 40a2454f779de..3a90ca513cc45 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1484,11 +1484,12 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                                            std::optional<Level> lvl,
                                            DimLevelType dlt, bool isIdxReduc) {
     assert(env.merger().loop(b) == idx);
-    // FIXME: Dense index reduction can reuse the universal index as well.
-    if (!isIdxReduc && (isDenseDLT(dlt) || isUndefDLT(dlt))) {
+    if (isDenseDLT(dlt) || isUndefDLT(dlt))
       needsUniv = true;
-    } else {
-      // sparse/singleton levels.
+    if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isIdxReduc) {
+      // Only when this is an index reduction loop can the dlt be undefined.
+      assert(!isUndefDLT(dlt) || isIdxReduc);
+      // sparse/singleton levels, or a dense/sparse index reduction loop.
       tids.push_back(tid);
       lvls.push_back(*lvl);
     }
@@ -1581,7 +1582,7 @@ static bool translateBitsToTidLvlPairs(
           tids.push_back(tid);
           lvls.push_back(*lvl);
           numloopCond++;
-        } else if (isDenseDLT(dlt)) {
+        } else if (isDenseDLT(dlt) || isIdxReduc) {
           tids.push_back(tid);
           lvls.push_back(*lvl);
         } else {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d.mlir
index 1ca6b81285e25..555d1bb232035 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d.mlir
@@ -1,9 +1,4 @@
-// UNSUPPORTED: target={{.*}}
-// FIXME: The test case is disabled (for now) because affine index on sparse tensor
-// are not handled efficiently by sparse compiler, the test case will be re-enabled
-// after new algorithm is implemented.
-
-// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{option} = "enable-runtime-library=true enable-index-reduction=true"
 // DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:  -e entry -entry-point-result=void  \
@@ -13,16 +8,16 @@
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation and vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-index-reduction=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
 // RUN: %{compile} | %{run}
 
 // Do the same run, but now with direct IR generation and, if available, VLA
 // vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-index-reduction=true enable-arm-sve=%ENABLE_VLA"
 // REDEFINE: %{run} = %lli \
 // REDEFINE:   --entry-function=entry_lli \
 // REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
@@ -33,6 +28,7 @@
 
 #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
 #CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#CDR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "dense"]}>
 #CSC = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "compressed" ],
   dimOrdering = affine_map<(i,j) -> (j,i)>
@@ -42,46 +38,55 @@
 module {
 
   func.func @conv2d(%input:  tensor<8x8xi32>,
-               %filter: tensor<3x3xi32, #DCSR>,
+               %filter: tensor<3x3xi32>,
                %output: tensor<6x6xi32>) -> tensor<6x6xi32> {
     %0 = linalg.conv_2d
-      ins  (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32, #DCSR>)
+      ins  (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32>)
       outs (%output: tensor<6x6xi32>) -> tensor<6x6xi32>
     return %0 : tensor<6x6xi32>
   }
 
   func.func @conv2d_sparse_out(%input:  tensor<8x8xi32>,
-               %filter: tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> {
+               %filter: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
     %s = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
     %0 = linalg.conv_2d
-      ins  (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32, #DCSR>)
+      ins  (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32>)
       outs (%s: tensor<6x6xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
     return %0 : tensor<6x6xi32, #DCSR>
   }
 
   func.func @conv2d_all_sparse_DCSR(%input:  tensor<8x8xi32, #DCSR>,
-               %filter: tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> {
+               %filter: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
     %s = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
     %0 = linalg.conv_2d
-      ins  (%input, %filter: tensor<8x8xi32, #DCSR>, tensor<3x3xi32, #DCSR>)
+      ins  (%input, %filter: tensor<8x8xi32, #DCSR>, tensor<3x3xi32>)
       outs (%s: tensor<6x6xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
     return %0 : tensor<6x6xi32, #DCSR>
   }
 
   func.func @conv2d_all_sparse_CSR(%input:  tensor<8x8xi32, #CSR>,
-               %filter: tensor<3x3xi32, #CSR>) -> tensor<6x6xi32, #CSR> {
+               %filter: tensor<3x3xi32>) -> tensor<6x6xi32, #CSR> {
     %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CSR>
     %0 = linalg.conv_2d
-      ins  (%input, %filter: tensor<8x8xi32, #CSR>, tensor<3x3xi32, #CSR>)
+      ins  (%input, %filter: tensor<8x8xi32, #CSR>, tensor<3x3xi32>)
       outs (%s: tensor<6x6xi32, #CSR>) -> tensor<6x6xi32, #CSR>
     return %0 : tensor<6x6xi32, #CSR>
   }
 
+  func.func @conv2d_all_sparse_CD(%input:  tensor<8x8xi32, #CDR>,
+               %filter: tensor<3x3xi32>) -> tensor<6x6xi32, #CDR> {
+    %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CDR>
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32, #CDR>, tensor<3x3xi32>)
+      outs (%s: tensor<6x6xi32, #CDR>) -> tensor<6x6xi32, #CDR>
+    return %0 : tensor<6x6xi32, #CDR>
+  }
+
   func.func @conv2d_all_sparse_CSC(%input:  tensor<8x8xi32, #CSC>,
-               %filter: tensor<3x3xi32, #CSC>) -> tensor<6x6xi32, #CSC> {
+               %filter: tensor<3x3xi32>) -> tensor<6x6xi32, #CSC> {
     %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CSC>
     %0 = linalg.conv_2d
-      ins  (%input, %filter: tensor<8x8xi32, #CSC>, tensor<3x3xi32, #CSC>)
+      ins  (%input, %filter: tensor<8x8xi32, #CSC>, tensor<3x3xi32>)
       outs (%s: tensor<6x6xi32, #CSC>) -> tensor<6x6xi32, #CSC>
     return %0 : tensor<6x6xi32, #CSC>
   }
@@ -96,12 +101,6 @@ module {
       [  0,  0,  0 ],
       [ -1,  0,  1 ]
     ]> : tensor<3x3xi32>
-    %sparse_filter_DCSR = sparse_tensor.convert %filter
-      : tensor<3x3xi32> to tensor<3x3xi32, #DCSR>
-    %sparse_filter_CSR = sparse_tensor.convert %filter
-      : tensor<3x3xi32> to tensor<3x3xi32, #CSR>
-    %sparse_filter_CSC = sparse_tensor.convert %filter
-      : tensor<3x3xi32> to tensor<3x3xi32, #CSC>
 
 
     %input = arith.constant dense<[
@@ -118,26 +117,31 @@ module {
       : tensor<8x8xi32> to tensor<8x8xi32, #DCSR>
     %sparse_input_CSR = sparse_tensor.convert %input
       : tensor<8x8xi32> to tensor<8x8xi32, #CSR>
+    %sparse_input_CD = sparse_tensor.convert %input
+      : tensor<8x8xi32> to tensor<8x8xi32, #CDR>
     %sparse_input_CSC = sparse_tensor.convert %input
       : tensor<8x8xi32> to tensor<8x8xi32, #CSC>
 
     // Call the kernel.
     %output = arith.constant dense<0> : tensor<6x6xi32>
-    %0 = call @conv2d(%input, %sparse_filter_DCSR, %output)
+    %0 = call @conv2d(%input, %filter, %output)
        : (tensor<8x8xi32>,
-          tensor<3x3xi32, #DCSR>, tensor<6x6xi32>) -> tensor<6x6xi32>
-    %1 = call @conv2d_sparse_out(%input, %sparse_filter_DCSR)
+          tensor<3x3xi32>, tensor<6x6xi32>) -> tensor<6x6xi32>
+    %1 = call @conv2d_sparse_out(%input, %filter)
        : (tensor<8x8xi32>,
-          tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
-    %2 = call @conv2d_all_sparse_DCSR(%sparse_input_DCSR, %sparse_filter_DCSR)
+          tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR>
+    %2 = call @conv2d_all_sparse_DCSR(%sparse_input_DCSR, %filter)
        : (tensor<8x8xi32, #DCSR>,
-          tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
-    %3 = call @conv2d_all_sparse_CSR(%sparse_input_CSR, %sparse_filter_CSR)
+          tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR>
+    %3 = call @conv2d_all_sparse_CSR(%sparse_input_CSR, %filter)
        : (tensor<8x8xi32, #CSR>,
-          tensor<3x3xi32, #CSR>) -> tensor<6x6xi32, #CSR>
-    %4 = call @conv2d_all_sparse_CSC(%sparse_input_CSC, %sparse_filter_CSC)
+          tensor<3x3xi32>) -> tensor<6x6xi32, #CSR>
+    %4 = call @conv2d_all_sparse_CD(%sparse_input_CD, %filter)
+       : (tensor<8x8xi32, #CDR>,
+          tensor<3x3xi32>) -> tensor<6x6xi32, #CDR>
+    %5 = call @conv2d_all_sparse_CSC(%sparse_input_CSC, %filter)
        : (tensor<8x8xi32, #CSC>,
-          tensor<3x3xi32, #CSC>) -> tensor<6x6xi32, #CSC>
+          tensor<3x3xi32>) -> tensor<6x6xi32, #CSC>
 
 
     // Verify the output.
@@ -183,6 +187,21 @@ module {
       : tensor<6x6xi32>, vector<6x6xi32>
     vector.print %v2 : vector<6x6xi32>
 
+    //
+    // Should be the same as dense output
+    // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %all_sparse_CD = sparse_tensor.convert %4
+      : tensor<6x6xi32, #CDR> to tensor<6x6xi32>
+    %v4 = vector.transfer_read %all_sparse_CD[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v4 : vector<6x6xi32>
+
     //
     // Should be the same as dense output
     // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
@@ -207,25 +226,23 @@ module {
     // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
     // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
     //
-    %all_sparse_CSC = sparse_tensor.convert %4
+    %all_sparse_CSC = sparse_tensor.convert %5
       : tensor<6x6xi32, #CSC> to tensor<6x6xi32>
-    %v4 = vector.transfer_read %all_sparse_CSC[%c0, %c0], %i0
+    %v5 = vector.transfer_read %all_sparse_CSC[%c0, %c0], %i0
       : tensor<6x6xi32>, vector<6x6xi32>
-    vector.print %v4 : vector<6x6xi32>
+    vector.print %v5 : vector<6x6xi32>
 
     // Release the resources.
-    bufferization.dealloc_tensor %sparse_filter_DCSR : tensor<3x3xi32, #DCSR>
-    bufferization.dealloc_tensor %sparse_filter_CSR : tensor<3x3xi32, #CSR>
-    bufferization.dealloc_tensor %sparse_filter_CSC : tensor<3x3xi32, #CSC>
-
     bufferization.dealloc_tensor %sparse_input_DCSR : tensor<8x8xi32, #DCSR>
     bufferization.dealloc_tensor %sparse_input_CSR : tensor<8x8xi32, #CSR>
     bufferization.dealloc_tensor %sparse_input_CSC : tensor<8x8xi32, #CSC>
+    bufferization.dealloc_tensor %sparse_input_CD : tensor<8x8xi32, #CDR>
 
     bufferization.dealloc_tensor %1 : tensor<6x6xi32, #DCSR>
     bufferization.dealloc_tensor %2 : tensor<6x6xi32, #DCSR>
     bufferization.dealloc_tensor %3 : tensor<6x6xi32, #CSR>
-    bufferization.dealloc_tensor %4 : tensor<6x6xi32, #CSC>
+    bufferization.dealloc_tensor %4 : tensor<6x6xi32, #CDR>
+    bufferization.dealloc_tensor %5 : tensor<6x6xi32, #CSC>
     return
   }
 }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
index 3c7d89f26401f..f9602ab93d259 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
@@ -1,9 +1,4 @@
-// UNSUPPORTED: target={{.*}}
-// FIXME: The test case is disabled (for now) because affine index on sparse tensor
-// are not handled efficiently by sparse compiler, the test case will be re-enabled
-// after new algorithm is implemented.
-
-// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{option} = "enable-index-reduction=true enable-runtime-library=true"
 // DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:  -e entry -entry-point-result=void  \
@@ -13,16 +8,16 @@
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation and vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 
 // Do the same run, but now with direct IR generation and, if available, VLA
 // vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA enable-index-reduction=true"
 // REDEFINE: %{run} = %lli \
 // REDEFINE:   --entry-function=entry_lli \
 // REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
@@ -39,6 +34,10 @@
   dimLevelType = [ "compressed", "dense", "compressed" ]
 }>
 
+#DDC = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed", "compressed" ]
+}>
+
 // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
 func.func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> tensor<?x?x?xf32> {
   %buf = bufferization.alloc_tensor(%s1, %s2, %s3) : tensor<?x?x?xf32>
@@ -53,24 +52,33 @@ func.func @conv_3d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: te
   return %ret : tensor<?x?x?xf32>
 }
 
-func.func @conv_3d_CCC(%arg0: tensor<?x?x?xf32, #CCC>, %arg1: tensor<?x?x?xf32, #CCC>) -> tensor<?x?x?xf32, #CCC> {
+func.func @conv_3d_CCC(%arg0: tensor<?x?x?xf32, #CCC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #CCC> {
   %c6 = arith.constant 6 : index
   %s = bufferization.alloc_tensor(%c6, %c6, %c6) : tensor<?x?x?xf32, #CCC>
   %ret = linalg.conv_3d
-     ins (%arg0, %arg1: tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32, #CCC>)
+     ins (%arg0, %arg1: tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>)
     outs (%s: tensor<?x?x?xf32, #CCC>) -> tensor<?x?x?xf32, #CCC>
   return %ret : tensor<?x?x?xf32, #CCC>
 }
 
-func.func @conv_3d_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32, #CDC>) -> tensor<?x?x?xf32, #CDC> {
+func.func @conv_3d_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #CDC> {
   %c6 = arith.constant 6 : index
   %s = bufferization.alloc_tensor(%c6, %c6, %c6) : tensor<?x?x?xf32, #CDC>
   %ret = linalg.conv_3d
-     ins (%arg0, %arg1: tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32, #CDC>)
+     ins (%arg0, %arg1: tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>)
     outs (%s: tensor<?x?x?xf32, #CDC>) -> tensor<?x?x?xf32, #CDC>
   return %ret : tensor<?x?x?xf32, #CDC>
 }
 
+func.func @conv_3d_DDC(%arg0: tensor<?x?x?xf32, #DDC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DDC> {
+  %c6 = arith.constant 6 : index
+  %s = bufferization.alloc_tensor(%c6, %c6, %c6) : tensor<?x?x?xf32, #DDC>
+  %ret = linalg.conv_3d
+     ins (%arg0, %arg1: tensor<?x?x?xf32, #DDC>, tensor<?x?x?xf32>)
+    outs (%s: tensor<?x?x?xf32, #DDC>) -> tensor<?x?x?xf32, #DDC>
+  return %ret : tensor<?x?x?xf32, #DDC>
+}
+
 func.func @entry() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -88,17 +96,15 @@ func.func @entry() {
 
   %in3D_CCC = sparse_tensor.convert %in3D
     : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
-  %filter3D_CCC = sparse_tensor.convert %filter3D
-    : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
-
   %in3D_CDC = sparse_tensor.convert %in3D
     : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
-  %filter3D_CDC = sparse_tensor.convert %filter3D
-    : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
+  %in3D_DDC = sparse_tensor.convert %in3D
+    : tensor<?x?x?xf32> to tensor<?x?x?xf32, #DDC>
 
   %dense_ret = call @conv_3d(%in3D, %filter3D, %out3D) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
-  %CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D_CCC) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32, #CCC>) -> (tensor<?x?x?xf32, #CCC>)
-  %CDC_ret = call @conv_3d_CDC(%in3D_CDC, %filter3D_CDC) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32, #CDC>) -> (tensor<?x?x?xf32, #CDC>)
+  %CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CCC>)
+  %CDC_ret = call @conv_3d_CDC(%in3D_CDC, %filter3D) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CDC>)
+  %DDC_ret = call @conv_3d_DDC(%in3D_DDC, %filter3D) : (tensor<?x?x?xf32, #DDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DDC>)
 
   //      CHECK:( ( ( 108, 108, 108, 108, 108, 108 ),
   // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
@@ -224,18 +230,59 @@ func.func @entry() {
       : tensor<?x?x?xf32>, vector<6x6x6xf32>
   vector.print %v2 : vector<6x6x6xf32>
 
+  // CHECK-NEXT:( ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ) )
+  %3 = sparse_tensor.convert %DDC_ret
+    : tensor<?x?x?xf32, #DDC> to tensor<?x?x?xf32>
+  %v3 = vector.transfer_read %3[%c0, %c0, %c0], %zero
+      : tensor<?x?x?xf32>, vector<6x6x6xf32>
+  vector.print %v2 : vector<6x6x6xf32>
+
   // Free the resources
   bufferization.dealloc_tensor %in3D : tensor<?x?x?xf32>
   bufferization.dealloc_tensor %filter3D : tensor<?x?x?xf32>
   bufferization.dealloc_tensor %out3D : tensor<?x?x?xf32>
 
   bufferization.dealloc_tensor %in3D_CDC : tensor<?x?x?xf32, #CDC>
-  bufferization.dealloc_tensor %filter3D_CDC : tensor<?x?x?xf32, #CDC>
   bufferization.dealloc_tensor %in3D_CCC : tensor<?x?x?xf32, #CCC>
-  bufferization.dealloc_tensor %filter3D_CCC : tensor<?x?x?xf32, #CCC>
+  bufferization.dealloc_tensor %in3D_DDC : tensor<?x?x?xf32, #DDC>
 
   bufferization.dealloc_tensor %CCC_ret : tensor<?x?x?xf32, #CCC>
   bufferization.dealloc_tensor %CDC_ret : tensor<?x?x?xf32, #CDC>
-
+  bufferization.dealloc_tensor %DDC_ret : tensor<?x?x?xf32, #DDC>
   return
 }


        


More information about the Mlir-commits mailing list