[Mlir-commits] [mlir] [mlir][sparse] support sparsifying sparse kernels to sparse-iterator-based loop (PR #95858)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 17 15:16:18 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---

Patch is 195.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95858.diff


40 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+1-1) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+24-4) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h (+16-1) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+1) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+3-1) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+38) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+2-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+3-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+6) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+96-41) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+7) 
- (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+7-7) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+20-20) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_3d.mlir (+17-17) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+15-15) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_batch.mlir (+11-11) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir (+6-6) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_concat.mlir (+50-50) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+10-10) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+44-44) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir (+50-50) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels.mlir (+6-6) 
- (added) mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir (+43) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+13-13) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir (+12-12) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_nd.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_out.mlir (+30-30) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (+17-17) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir (+5-5) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_reshape.mlir (+24-24) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir (+8-8) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir (+6-6) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+7-7) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+10-10) 
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir (+8-8) 
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+42-42) 
- (added) mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir (+80) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 04a6386a199de..68ca036121520 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,7 +89,7 @@ class LevelSet {
     assert(i < 64);
     return (bits & (1 << i)) != 0;
   }
-
+  unsigned max() const { return 64 - llvm::countl_zero(bits); }
   unsigned count() const { return llvm::popcount(bits); }
   bool empty() const { return bits == 0; }
 };
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index b2089924291cd..f31df080d7811 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1493,6 +1493,10 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
       ```
   }];
 
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter,
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
+  let results = (outs AnySparseIterSpace:$extractedSpace);
 
   let extraClassDeclaration = [{
     std::pair<Level, Level> getLvlRange() {
@@ -1506,10 +1510,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
     }
   }];
 
-  let arguments = (ins AnySparseTensor:$tensor,
-                       Optional<AnySparseIterator>:$parentIter,
-                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
-  let results = (outs AnySparseIterSpace:$extractedSpace);
+  let builders = [
+    // Construct a 1-D iteration space.
+    OpBuilder<(ins "Value":$tensor, "Value":$parentIter,
+                   "sparse_tensor::Level":$loLvl),
+    [{
+      build($_builder, $_state, tensor, parentIter, loLvl, loLvl + 1);
+    }]>,
+    // Construct a 1-D root iteration space (level 0, no parent iterator).
+    OpBuilder<(ins "Value":$tensor),
+    [{
+      build($_builder, $_state, tensor, nullptr, 0);
+    }]>
+  ];
+
   let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
                        "`->` qualified(type($extractedSpace))";
@@ -1594,6 +1608,12 @@ def IterateOp : SparseTensor_Op<"iterate",
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+  ];
+
   let extraClassDeclaration = [{
     unsigned getSpaceDim() {
       return getIterSpace().getType().getSpaceDim();
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 85bfee336f848..90021ffa7c380 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -51,6 +51,20 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
               mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
               "any-storage-any-loop",
               "Enable sparse parallelization for any storage and loop."))};
+  PassOptions::Option<mlir::SparseEmitStrategy> emitStrategy{
+      *this, "sparse-emit-strategy",
+      ::llvm::cl::desc(
+          "Emit functional code or interfaces (to debug) for sparse loops"),
+      ::llvm::cl::init(mlir::SparseEmitStrategy::kFunctional),
+      llvm::cl::values(
+          clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
+                     "Emit functional code (with scf.for/while)."),
+          clEnumValN(mlir::SparseEmitStrategy::kSparseIterator,
+                     "sparse-iterator",
+                     "Emit (experimental) loops (with sparse.iterate)."),
+          clEnumValN(
+              mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
+              "Emit non-functional but easy-to-read interfaces to debug."))};
 
   PassOptions::Option<bool> enableRuntimeLibrary{
       *this, "enable-runtime-library",
@@ -143,7 +157,8 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization, enableRuntimeLibrary);
+    return SparsificationOptions(parallelization, emitStrategy,
+                                 enableRuntimeLibrary);
   }
 
   /// Projects out the options for `createConvertVectorToLLVMPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c9164e39a3a75..2edb9a61d1876 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -51,6 +51,7 @@ enum class ReinterpretMapScope {
 /// Defines a scope for reinterpret map pass.
 enum class SparseEmitStrategy {
   kFunctional,     // generate fully inlined (and functional) sparse iteration
+  kSparseIterator, // generate (experimental) loop using sparse iterator.
   kDebugInterface, // generate only place-holder for sparse iteration
 };
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index b18c975105b75..7173cde67384b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -163,7 +163,9 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
            "mlir::SparseEmitStrategy::kFunctional",
            "Emit functional code or interfaces (to debug) for sparse loops", [{llvm::cl::values(
              clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
-                        "Emit functional code."),
+                        "Emit functional code (with scf.for/while)."),
+             clEnumValN(mlir::SparseEmitStrategy::kSparseIterator, "sparse-iterator",
+                        "Emit (experimental) loops (with sparse.iterate)."),
              clEnumValN(mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
                         "Emit non-functional but easy-to-read interfaces to debug."))}]>,
     Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ac711769ed2ea..504888be4c85f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2300,6 +2300,41 @@ void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
   results.add<RemoveUnusedLvlCrds>(context);
 }
 
+void IterateOp::build(OpBuilder &builder, OperationState &odsState,
+                      Value iterSpace, ValueRange initArgs) {
+  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
+  // All ones: mark the coordinates of every level as used.
+  LevelSet set((1 << rank) - 1);
+  return build(builder, odsState, iterSpace, initArgs, set);
+}
+
+void IterateOp::build(OpBuilder &builder, OperationState &odsState,
+                      Value iterSpace, ValueRange initArgs,
+                      LevelSet crdUsedLvls) {
+  OpBuilder::InsertionGuard guard(builder);
+
+  odsState.addOperands(iterSpace);
+  odsState.addOperands(initArgs);
+  odsState.getOrAddProperties<Properties>().crdUsedLvls =
+      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
+  Region *bodyRegion = odsState.addRegion();
+  odsState.addTypes(initArgs.getTypes());
+  Block *bodyBlock = builder.createBlock(bodyRegion);
+
+  // The first block argument is the sparse iterator.
+  bodyBlock->addArgument(
+      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
+      odsState.location);
+
+  // Followed by a list of used coordinates.
+  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
+    bodyBlock->addArgument(builder.getIndexType(), odsState.location);
+
+  // Followed by a list of user-provided loop arguments.
+  for (Value v : initArgs)
+    bodyBlock->addArgument(v.getType(), v.getLoc());
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;
@@ -2384,6 +2419,9 @@ LogicalResult IterateOp::verify() {
     return emitOpError(
         "mismatch in number of loop-carried values and defined values");
   }
+  if (getCrdUsedLvls().max() > getSpaceDim())
+    return emitOpError("required out-of-bound coordinates");
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 4224925147c84..1d614b7b29361 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -164,7 +164,6 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
       // replace sparse_tensor.yield with scf.yield.
       rewriter.eraseOp(yieldOp);
       rewriter.create<scf::YieldOp>(loc, yields);
-
       const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
       rewriter.replaceOp(
           op, whileOp.getResults().drop_front(it->getCursor().size()),
@@ -192,6 +191,8 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
+
+  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
       converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0c8e431d8c996..e8f85aab887ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1071,7 +1071,9 @@ static bool getAllTidLvlsInLatPoints(
   }
   // If we just need to one loop conditions and the conditions is not imposed on
   // non-unique level, the loop can be generated by a for loop.
-  return numloopCond == 1 && !hasNonUnique;
+  return numloopCond == 1 &&
+         (!hasNonUnique || env.options().sparseEmitStrategy ==
+                               SparseEmitStrategy::kSparseIterator);
 }
 
 /// Starts a loop sequence at given level. Returns true if
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 13c750e83d045..3ae6732f900fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -159,6 +159,12 @@ class SparsificationAndBufferizationPass
         pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
       pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
       pm.addPass(createSparsificationPass(sparsificationOptions));
+      if (sparsificationOptions.sparseEmitStrategy ==
+          SparseEmitStrategy::kSparseIterator) {
+        pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass());
+        pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass());
+      }
+
       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
       pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                    /*enableConvert=*/true));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index fe0e515a2d180..2be0193f0de83 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -139,6 +139,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->valBuffer.assign(numTensors, nullptr);
   this->lvls.resize(numTensors);
   this->iters.resize(numTensors);
+  this->spIterVals.resize(numTensors);
 
   // These zeros will be overwritten below, but we need to initialize
   // them to something since we'll need random-access assignment.
@@ -173,6 +174,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
 
     lvls[tid].resize(lvlRank);
     iters[tid].resize(lvlRank);
+    spIterVals[tid].resize(lvlRank);
     loopHighs.assign(numLoops, nullptr);
 
     // Slice-driven loops related initialization.
@@ -229,6 +231,57 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
 void LoopEmitter::initializeLoopEmit(
     OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
     LoopEmitter::SynTensorBoundSetter synSetter) {
+
+  // For every manifest tensor, set up the values buffer.
+  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
+       t++) {
+    // TODO: this should be done through a folding pass after switching to
+    // `sparse_tensor.iterate`-based sparsification.
+    const Value tensor = tryFoldTensors(tensors[t]);
+    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
+    // Skip only scalars; zero-ranked tensors still need to be bufferized and
+    // (probably) filled with zeros by users.
+    if (!rtp)
+      continue;
+
+    auto stt = getSparseTensorType(tensor);
+    const auto shape = rtp.getShape();
+
+    // Perform the required bufferization. Dense inputs materialize from the
+    // input tensors. Sparse inputs use sparse primitives to obtain the values.
+    // Delegates extra output initialization to clients.
+    bool isOutput = isOutputTensor(t);
+    Type elementType = stt.getElementType();
+    if (!stt.hasEncoding()) {
+      // Non-annotated dense tensors.
+      BaseMemRefType denseTp = MemRefType::get(shape, elementType);
+
+      // TODO: if we unconditionally use fully dynamic layout here, it breaks
+      // some vectorization passes, which require a static stride of 1.
+      // Is it possible to call vectorization pass after bufferization?
+      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
+        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);
+
+      Value denseVal =
+          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
+      // Dense outputs need special handling.
+      if (isOutput && updater)
+        denseVal = updater(builder, loc, denseVal, tensor);
+
+      valBuffer[t] = denseVal;
+    } else {
+      // Annotated sparse tensors.
+      // We also need the value buffer for all-dense annotated "sparse"
+      // tensors.
+      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
+    }
+  }
+
+  // The sparse iterator values will only be available after the loop is
+  // constructed.
+  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
+    return;
+
   // For every synthetic tensor, set the high bound by calling the callback.
   if (synSetter) {
     TensorId synId = getSynTensorId();
@@ -241,7 +294,6 @@ void LoopEmitter::initializeLoopEmit(
   }
 
   // For every manifest tensor:
-  // * get the values buffer.
   // * For every level:
   //   * get the positions and coordinates buffers
   //   * get/compute the level-size, which is also used as the upper-bound
@@ -256,12 +308,9 @@ void LoopEmitter::initializeLoopEmit(
       // Skips only scalar, zero ranked tensor still need to be bufferized and
       // (probably) filled with zeros by users.
       continue;
-    // FIXME: the definition of `lvlRank` looks more like a dim-rank;
-    // but the variable is used as a level everywhere below, which
-    // suggests there may be some dim/lvl confusion going on here.
+
     auto stt = getSparseTensorType(tensor);
     const Level lvlRank = stt.getLvlRank();
-    const auto shape = rtp.getShape();
 
     // Scan all levels of current tensor.
     for (Level l = 0; l < lvlRank; l++) {
@@ -273,36 +322,6 @@ void LoopEmitter::initializeLoopEmit(
       auto it = makeLevelIterator(builder, loc, t, l);
       iters[t][l].emplace_back(std::move(it));
     }
-
-    // Perform the required bufferization. Dense inputs materialize
-    // from the input tensors. Sparse inputs use sparse primitives to obtain the
-    // values.
-    // Delegates extra output initialization to clients.
-    bool isOutput = isOutputTensor(t);
-    Type elementType = stt.getElementType();
-    if (!stt.hasEncoding()) {
-      // Non-annotated dense tensors.
-      BaseMemRefType denseTp = MemRefType::get(shape, elementType);
-
-      // TODO: if we unconditionally use fully dynamic layout here, it breaks
-      // some vectorization passes which requires static stride = 1.
-      // Is it possible to call vectorization pass after bufferization?
-      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
-        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);
-
-      Value denseVal =
-          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
-      // Dense outputs need special handling.
-      if (isOutput && updater)
-        denseVal = updater(builder, loc, denseVal, tensor);
-
-      valBuffer[t] = denseVal;
-    } else {
-      // Annotated sparse tensors.
-      // We also need the value buffer for all-dense annotated "sparse"
-      // tensors.
-      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
-    }
     // NOTE: we can also prepare for 0 lvl here in advance, this will hoist
     // some loop preparation from tensor iteration, but will also (undesirably)
     // hoist the code ouside if-conditions.
@@ -396,11 +415,13 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                   ArrayRef<TensorLevel> tidLvls) {
   // TODO: sort
   assert(loopSeqStack.size() == loopStack.size());
-  // Prepares for all the tensors used in the current loop sequence.
 
-  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
-    levelReducedDep[tid][lvl]++;
-    prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
+    // Prepares for all the tensors used in the current loop sequence.
+    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+      levelReducedDep[tid][lvl]++;
+      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+    }
   }
 
   // Universal Index starts from 0.
@@ -598,6 +619,31 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
     OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
     MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {
 
+  // TODO: handle coiteration with sparse iterator.
+  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
+    assert(tidLvls.size() == 1);
+    auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
+    Value t = tensors[tid];
+
+    // Extract and iterate over the iteration space.
+    ExtractIterSpaceOp extractSpaceOp =
+        lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
+                 : builder.create<ExtractIterSpaceOp>(
+                       loc, t, spIterVals[tid][lvl - 1], lvl);
+
+    IterateOp iterOp = builder.create<IterateOp>(
+        loc, extractSpaceOp.getExtractedSpace(), reduc);
+    spIterVals[tid][lvl] = iterOp.getIterator();
+
+    // Update the reduction variables.
+    llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
+    // Set the insertion point to loop body.
+    builder.setInsertionPointToStart(iterOp.getBody());
+    loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
+                           iterOp.getIterator(), loopTag);
+    return iterOp;
+  }
+
   // TODO: support multiple return on parallel for?
   tryParallel = tryParallel && reduc.size() <= 1;
 
@@ -685,6 +731,16 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
 void LoopEmitter::exitFo...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/95858


More information about the Mlir-commits mailing list